diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c2fe9f5..58463426 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,55 @@ # Changelog +## [0.4.0] - 2025-XX-XX + +### Major changes +* Support for complex numbers (c32 and c64) with optimized complex GEMM +* Support for half precision and bfloat16 numbers +* (Experimental) Support for accelerating low precision GEMMs with tensor cores (XMX) +* Compiler directly generates SPIR-V binaries (instead of OpenCL-C); JIT compilation speed-up of about 2x +* Removed dependency on clir and ocloc libraries + +### Tiny Tensor Language +* Clarified execution model +* Added align attribute +* Added unroll attribute +* Added boolean, f16 (half precision), bf16 (bfloat16), c32 (single complex), and c64 (double complex) type +* Added address space to memref type +* Added cooperative matrix type +* Added new instructions: barrier, builtin, constant, cooperative_matrix_load/mul_add/scale/store, parallel, + subgroup_broadcast, work_group +* Extended foreach for iteration spaces with dim > 1 +* For-loops can yield values +* Removed immediate operands in favour of constant instruction +* Instructions annotate the return type instead of operand types +* Added cumulative sum instruction +* Change 1d group id to 3d group id; introduce 2d subgroup id + +### Core API +* Introduced compiler_context class for compilations flags and control of error reporting +* Removed functions to create kernel bundle from source text +* Removed source_context in favour of compiler_context +* Added functions to compile tinytc prog to SPIR-V +* Added API to run function passes +* Added float <-> half and float <-> bfloat16 conversion functions +* Added array_view class in C++-API + +### Builder API +* Changed ownership model for data type, inst, region, and func +* Normalized function names in builder API +* Instruction creation function always include return type + +### Compiler +* Overhauled infrastructure for writing advanced compiler passes +* Introduced data flow analysis to properly insert barriers automatically when for-loops are present +* Added constant folding, constant propagation, and dead code elimination passes +* Added memref alignment analysis +* Added analysis to check applicability of tensor core acceleration +* Added SPIR-V support; Tiny Tensor Language to SPIR-V conversion and SPIR-V binary generation +* Introduced GEMM to cooperative matrix conversion (instead of direct GEMM to OpenCL-C conversion) +* Added fast cooperative matrix mul add implementation for complex numbers +* Added attributes + ## [0.3.1] - 2024-05-22 * Bugfix: Add alias analysis for stack; needed to correctly insert barriers * Bugfix: Disable block writes as alignment analysis is missing diff --git a/CMakeLists.txt b/CMakeLists.txt index ca7f1f1e..df439147 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.23) -project(tiny_tensor_compiler VERSION 0.3.1 LANGUAGES C CXX) +project(tiny-tensor-compiler VERSION 0.3.1 LANGUAGES C CXX) include(CMakeDependentOption) option(BUILD_DOCUMENTATION "Build documentation" OFF) @@ -14,6 +14,7 @@ cmake_dependent_option(BUILD_LEVEL_ZERO cmake_dependent_option(BUILD_OPENCL "Build support for OpenCL run-time; required when SYCL is enabled" ON "NOT BUILD_SYCL" ON) +option(BUILD_DOUBLE_PRECISION_TESTS "Build double precision unit tests" ON) include(CTest) diff --git a/cmake/CPackGeneratedFiles.cmake.in b/cmake/CPackGeneratedFiles.cmake.in new file mode 100644 index 00000000..320e53cc --- /dev/null +++ b/cmake/CPackGeneratedFiles.cmake.in @@ -0,0 +1,11 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +if(CPACK_SOURCE_INSTALLED_DIRECTORIES) + foreach(_src @GENERATED_FILES@) + file(RELATIVE_PATH _src_rel @PROJECT_BINARY_DIR@ ${_src}) + get_filename_component(_src_dir ${_src_rel} DIRECTORY) + message(STATUS "Adding generated file ${_src_rel} to ${_src_dir}") + file(INSTALL ${_src} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/${_src_dir}") + endforeach() +endif() diff --git a/cmake/CPackSetup.cmake b/cmake/CPackSetup.cmake new file mode 100644 index 00000000..b4802865 --- /dev/null +++ b/cmake/CPackSetup.cmake @@ -0,0 +1,43 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +include(GitVersion) + +function(cpack_setup) + git_version() + set(CPACK_SOURCE_GENERATOR "TXZ") + set(CPACK_SOURCE_PACKAGE_FILE_NAME + "${CMAKE_PROJECT_NAME}-${GIT_MAJOR_VERSION}.${GIT_MINOR_VERSION}.${GIT_PATCH_VERSION}" + ) + if(NOT "${GIT_COMMITS_SINCE_RELEASE}" STREQUAL "0") + set(CPACK_SOURCE_PACKAGE_FILE_NAME + "${CPACK_SOURCE_PACKAGE_FILE_NAME}-${GIT_COMMITS_SINCE_RELEASE}-${GIT_COMMIT}" + ) + endif() + set(CPACK_SOURCE_IGNORE_FILES + "/\\\\.git/" + "/\\\\.gitignore$" + "\\\\.swp$" + "/build/" + "/build_debug/" + "/build_iwyu/" + "/coverity_scan/" + "/__pycache__/" + ) + + set(GENERATED_FILES ${GENERATED_FILES}) + + configure_file("${PROJECT_SOURCE_DIR}/cmake/CPackGeneratedFiles.cmake.in" "CPackGeneratedFiles.cmake" @ONLY) + set(CPACK_INSTALL_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/CPackGeneratedFiles.cmake") + include(CPack) + add_custom_target(dist + COMMAND "${CMAKE_COMMAND}" --build "${PROJECT_BINARY_DIR}" --target package_source + DEPENDS tinytc-objects + VERBATIM + USES_TERMINAL + ) + if(BUILD_OPENCL) + add_dependencies(dist tinytc_cl-objects) + endif() +endfunction() + diff --git a/cmake/CommonOptions.cmake b/cmake/CommonOptions.cmake index 1d5ba3a1..2eabf540 100644 --- a/cmake/CommonOptions.cmake +++ b/cmake/CommonOptions.cmake @@ -7,6 +7,7 @@ include(CheckCompilerFlag) include(CheckLinkerFlag) +include(GitVersion) function(add_flag_if_available lang target flag) string(MAKE_C_IDENTIFIER ${flag} flag_c) @@ -54,6 +55,11 @@ function(set_common_options lang target) SOVERSION ${tiny_tensor_compiler_VERSION_MAJOR}) set_property(TARGET ${target} PROPERTY POSITION_INDEPENDENT_CODE ON) target_compile_definitions(${target} PRIVATE -D_FORTIFY_SOURCE=2) + + git_version() + set_target_properties(${target} PROPERTIES + VERSION ${GIT_MAJOR_VERSION}.${GIT_MINOR_VERSION}.${GIT_PATCH_VERSION} + SOVERSION ${GIT_MAJOR_VERSION}) endfunction() function(set_c_common_options target) diff --git a/cmake/FindSPIRVTools.cmake b/cmake/FindSPIRVTools.cmake new file mode 100644 index 00000000..b1f4de66 --- /dev/null +++ b/cmake/FindSPIRVTools.cmake @@ -0,0 +1,47 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +# Try to find SPIR-V Tools + +# The following definitions are added on success +# +# SPIRVTools_FOUND - SPIR-V Tools was found +# SPIRVTools_SPIRV_VAL - spirv-val executable +# SPIRVTools_VERSION - SPIR-V Tools version +# +# The followings hints may be passed in the environment: +# +# RE2C_ROOT +# + +if(SPIRVTools_SPIRV_VAL) + set(SPIRVTools_FOUND TRUE) +else() + find_program(SPIRVTools_SPIRV_VAL NAMES spirv-val + HINTS + ENV SPIRVTools_ROOT + ENV PATH + ) + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(SPIRVTools DEFAULT_MSG SPIRVTools_SPIRV_VAL) + + execute_process(COMMAND ${SPIRVTools_SPIRV_VAL} --version + OUTPUT_VARIABLE SPIRVTools_version_output + ERROR_VARIABLE SPIRVTools_version_error + RESULT_VARIABLE SPIRVTools_version_result + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT ${SPIRVTools_version_result} EQUAL 0) + set(SPIRVTools_message "Command \"{SPIRVTools_SPIRV_VAL} --version\" failed:\n${SPIRVTools_version_output}\n${SPIRVTools_version_error}") + if(SPIRVTools_FIND_REQUIRED) + message(SEND_ERROR ${SPIRVTools_message}) + else() + message(${SPIRVTools_message}) + endif() + else() + string(REGEX REPLACE "SPIRV-Tools ([v0-9\.]+) .*" "\\1" SPIRVTools_VERSION "${SPIRVTools_version_output}") + endif() + + mark_as_advanced(SPIRVTools_SPIRV_VAL) +endif() diff --git a/cmake/GeneratedFiles.cmake b/cmake/GeneratedFiles.cmake new file mode 100644 index 00000000..21244083 --- /dev/null +++ b/cmake/GeneratedFiles.cmake @@ -0,0 +1,14 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +function(get_generated_files VAR target) + get_target_property(_sources ${target} SOURCES) + set(_generated_files "") + foreach(_src ${_sources}) + get_source_file_property(_generated "${_src}" GENERATED) + if(${_generated} GREATER 0) + list(APPEND _generated_files ${_src}) + endif() + endforeach() + set(${VAR} ${_generated_files} PARENT_SCOPE) +endfunction() diff --git a/cmake/GitVersion.cmake b/cmake/GitVersion.cmake index 0ab27061..bc39bf84 100644 --- a/cmake/GitVersion.cmake +++ b/cmake/GitVersion.cmake @@ -13,7 +13,7 @@ function(git_version) if(GIT_FOUND) execute_process( - COMMAND ${GIT_EXECUTABLE} describe --tags --long + COMMAND ${GIT_EXECUTABLE} describe --tags --long --match "v[0-9]*" WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" RESULT_VARIABLE status_code OUTPUT_VARIABLE output diff --git a/cmake/tinytc-config.cmake b/cmake/tinytc-config.cmake index a7a1126b..72ca66f0 100644 --- a/cmake/tinytc-config.cmake +++ b/cmake/tinytc-config.cmake @@ -10,10 +10,6 @@ function(load_dependencies type) if ("${type}" STREQUAL "shared") set(is_shared ON) endif () - if (NOT DEFINED clir_SHARED_LIBS) - set(clir_SHARED_LIBS ${is_shared}) - endif () - find_dependency(clir REQUIRED) endfunction() @SHARED_STATIC_TEMPLATE@ diff --git a/docs/CMakeLists.txt b/docs/CMakeLists.txt index 388abf9d..0aa8b9bf 100644 --- a/docs/CMakeLists.txt +++ b/docs/CMakeLists.txt @@ -51,6 +51,7 @@ set(DOC_FILES api/ze/capi.rst api/ze/cxxapi.rst api/ze/index.rst + dev/coopmatrix_layout.rst manual/build.rst manual/builder.rst manual/calling_convention.rst diff --git a/docs/Doxyfile.in b/docs/Doxyfile.in index faae40d4..bbedf22e 100644 --- a/docs/Doxyfile.in +++ b/docs/Doxyfile.in @@ -1370,15 +1370,6 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_TIMESTAMP = NO - # If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML # documentation will contain a main index with vertical navigation menus that # are dynamically created via JavaScript. If disabled, the navigation index will @@ -1703,17 +1694,6 @@ HTML_FORMULA_FORMAT = png FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANSPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. - -FORMULA_TRANSPARENT = YES - # The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands # to create new LaTeX commands to be used in formulas as building blocks. See # the section "Including formulas" for details. @@ -2050,14 +2030,6 @@ LATEX_HIDE_INDICES = NO LATEX_BIB_STYLE = plain -# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated -# page will contain the date and time when the page was generated. Setting this -# to NO can help when comparing the output of multiple runs. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_TIMESTAMP = NO - # The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) # path from which the emoji images will be read. If a relative path is entered, # it will be relative to the LATEX_OUTPUT directory. If left blank the @@ -2429,23 +2401,6 @@ HAVE_DOT = NO DOT_NUM_THREADS = 0 -# When you want a differently looking font in the dot files that doxygen -# generates you can specify the font name using DOT_FONTNAME. You need to make -# sure dot is able to find the font, which can be done by putting it in a -# standard location or by setting the DOTFONTPATH environment variable or by -# setting DOT_FONTPATH to the directory containing the font. -# The default value is: Helvetica. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_FONTNAME = Helvetica - -# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of -# dot graphs. -# Minimum value: 4, maximum value: 24, default value: 10. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_FONTSIZE = 10 - # By default doxygen will tell dot to use the default font as specified with # DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set # the path where dot can find it using this tag. @@ -2692,18 +2647,6 @@ DOT_GRAPH_MAX_NODES = 50 MAX_DOT_GRAPH_DEPTH = 0 -# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent -# background. This is disabled by default, because dot on Windows does not seem -# to support this out of the box. -# -# Warning: Depending on the platform used, enabling this option may lead to -# badly anti-aliased labels on the edges of a graph (i.e. they become hard to -# read). -# The default value is: NO. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_TRANSPARENT = NO - # Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output # files in one run (i.e. multiple -o and -T options on the command line). This # makes dot run faster, but since only newer versions of dot (>1.8.10) support diff --git a/docs/api/api_gen.py b/docs/api/api_gen.py index 01a7057a..939e38c5 100755 --- a/docs/api/api_gen.py +++ b/docs/api/api_gen.py @@ -4,6 +4,7 @@ from argparse import ArgumentParser from yaml import load, dump, Loader, Dumper +import re parser = ArgumentParser() parser.add_argument('input_yaml') @@ -35,6 +36,13 @@ def escape_ref(symbol): symbol = symbol.replace('>', '\\>') return symbol.replace('*', '\\*') +def get_label_and_title(rst_title): + m = re.match('([^<]+) <([^>]+)>', rst_title) + if m: + return m.groups() + return (rst_title, rst_title) + + api = dict() with open(args.input_yaml, 'r') as y: api = load(y, Loader) @@ -42,17 +50,21 @@ def escape_ref(symbol): with open(args.output_rst, 'w') as f: f.write('.. Copyright (C) 2024 Intel Corporation\n') f.write(' SPDX-License-Identifier: BSD-3-Clause\n\n') + for rst_title, categories in api.items(): - f.write('=' * len(rst_title) + '\n') - write_underline(f, rst_title, '=') + title_text, title_label = get_label_and_title(rst_title) + f.write(f'.. _{title_label}:\n\n') + f.write('=' * len(title_text) + '\n') + write_underline(f, title_text, '=') for category_name, category in categories.items(): write_underline(f, category_name, '=') for symbol_type, symbol_list in category.items(): f.write(f'* {title(symbol_type).title()}s\n\n') for symbol in symbol_list: - f.write(f' * :ref:`{escape_ref(strip_symbol_name(symbol))}`\n\n') + f.write(f' * :ref:`{escape_ref(symbol)}`\n\n') for symbol_type, symbol_list in category.items(): write_underline(f, f'{category_name} {title(symbol_type).title()}s', '-') for symbol in symbol_list: + f.write(f'.. _{escape_ref(symbol)}:\n\n') write_underline(f, strip_symbol_name(symbol), '.') f.write(f'.. doxygen{symbol_type}:: {symbol}\n\n') diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 6af8efdd..41cb2e1d 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Builder C-API: + ============= Builder C-API ============= @@ -10,14 +12,30 @@ Common * Enumerations + * :ref:`tinytc_address_space_t` + * :ref:`tinytc_arithmetic_t` * :ref:`tinytc_arithmetic_unary_t` + * :ref:`tinytc_builtin_t` + + * :ref:`tinytc_checked_flag_t` + * :ref:`tinytc_cmp_condition_t` + * :ref:`tinytc_group_arithmetic_t` + + * :ref:`tinytc_group_operation_t` + + * :ref:`tinytc_math_unary_t` + + * :ref:`tinytc_matrix_use_t` + * :ref:`tinytc_scalar_type_t` + * :ref:`tinytc_store_flag_t` + * :ref:`tinytc_transpose_t` * Definitions @@ -26,30 +44,54 @@ Common * Functions + * :ref:`tinytc_address_space_to_string` + * :ref:`tinytc_arithmetic_to_string` * :ref:`tinytc_arithmetic_unary_to_string` + * :ref:`tinytc_builtin_to_string` + + * :ref:`tinytc_checked_flag_to_string` + * :ref:`tinytc_cmp_condition_to_string` + * :ref:`tinytc_group_arithmetic_to_string` + + * :ref:`tinytc_group_operation_to_string` + + * :ref:`tinytc_math_unary_to_string` + + * :ref:`tinytc_matrix_use_to_string` + * :ref:`tinytc_scalar_type_size` * :ref:`tinytc_scalar_type_to_string` + * :ref:`tinytc_store_flag_to_string` + * :ref:`tinytc_transpose_to_string` * Structures - * :ref:`tinytc_position` + * :ref:`tinytc_named_attr` * :ref:`tinytc_location` + * :ref:`tinytc_position` + * Typedefs + * :ref:`tinytc_address_spaces_t` + + * :ref:`tinytc_attr_t` + * :ref:`tinytc_data_type_t` * :ref:`tinytc_func_t` + * :ref:`tinytc_named_attr_t` + * :ref:`tinytc_location_t` * :ref:`tinytc_position_t` @@ -58,10 +100,14 @@ Common * :ref:`tinytc_inst_t` + * :ref:`tinytc_inst_iterator_t` + * :ref:`tinytc_region_t` * :ref:`tinytc_value_t` + * :ref:`const_tinytc_attr_t` + * :ref:`const_tinytc_data_type_t` * :ref:`const_tinytc_func_t` @@ -72,29 +118,97 @@ Common * :ref:`const_tinytc_region_t` + * :ref:`const_tinytc_value_t` + Common Enumerations ------------------- +.. _tinytc_address_space_t: + +tinytc_address_space_t +...................... + +.. doxygenenum:: tinytc_address_space_t + +.. _tinytc_arithmetic_t: + tinytc_arithmetic_t ................... .. doxygenenum:: tinytc_arithmetic_t +.. _tinytc_arithmetic_unary_t: + tinytc_arithmetic_unary_t ......................... .. doxygenenum:: tinytc_arithmetic_unary_t +.. _tinytc_builtin_t: + +tinytc_builtin_t +................ + +.. doxygenenum:: tinytc_builtin_t + +.. _tinytc_checked_flag_t: + +tinytc_checked_flag_t +..................... + +.. doxygenenum:: tinytc_checked_flag_t + +.. _tinytc_cmp_condition_t: + tinytc_cmp_condition_t ...................... .. doxygenenum:: tinytc_cmp_condition_t +.. _tinytc_group_arithmetic_t: + +tinytc_group_arithmetic_t +......................... + +.. doxygenenum:: tinytc_group_arithmetic_t + +.. _tinytc_group_operation_t: + +tinytc_group_operation_t +........................ + +.. doxygenenum:: tinytc_group_operation_t + +.. _tinytc_math_unary_t: + +tinytc_math_unary_t +................... + +.. doxygenenum:: tinytc_math_unary_t + +.. _tinytc_matrix_use_t: + +tinytc_matrix_use_t +................... + +.. doxygenenum:: tinytc_matrix_use_t + +.. _tinytc_scalar_type_t: + tinytc_scalar_type_t .................... .. doxygenenum:: tinytc_scalar_type_t +.. _tinytc_store_flag_t: + +tinytc_store_flag_t +................... + +.. doxygenenum:: tinytc_store_flag_t + +.. _tinytc_transpose_t: + tinytc_transpose_t .................. @@ -103,6 +217,8 @@ tinytc_transpose_t Common Definitions ------------------ +.. _TINYTC_DYNAMIC: + TINYTC_DYNAMIC .............. @@ -111,31 +227,99 @@ TINYTC_DYNAMIC Common Functions ---------------- +.. _tinytc_address_space_to_string: + +tinytc_address_space_to_string +.............................. + +.. doxygenfunction:: tinytc_address_space_to_string + +.. _tinytc_arithmetic_to_string: + tinytc_arithmetic_to_string ........................... .. doxygenfunction:: tinytc_arithmetic_to_string +.. _tinytc_arithmetic_unary_to_string: + tinytc_arithmetic_unary_to_string ................................. .. doxygenfunction:: tinytc_arithmetic_unary_to_string +.. _tinytc_builtin_to_string: + +tinytc_builtin_to_string +........................ + +.. doxygenfunction:: tinytc_builtin_to_string + +.. _tinytc_checked_flag_to_string: + +tinytc_checked_flag_to_string +............................. + +.. doxygenfunction:: tinytc_checked_flag_to_string + +.. _tinytc_cmp_condition_to_string: + tinytc_cmp_condition_to_string .............................. .. doxygenfunction:: tinytc_cmp_condition_to_string +.. _tinytc_group_arithmetic_to_string: + +tinytc_group_arithmetic_to_string +................................. + +.. doxygenfunction:: tinytc_group_arithmetic_to_string + +.. _tinytc_group_operation_to_string: + +tinytc_group_operation_to_string +................................ + +.. doxygenfunction:: tinytc_group_operation_to_string + +.. _tinytc_math_unary_to_string: + +tinytc_math_unary_to_string +........................... + +.. doxygenfunction:: tinytc_math_unary_to_string + +.. _tinytc_matrix_use_to_string: + +tinytc_matrix_use_to_string +........................... + +.. doxygenfunction:: tinytc_matrix_use_to_string + +.. _tinytc_scalar_type_size: + tinytc_scalar_type_size ....................... .. doxygenfunction:: tinytc_scalar_type_size +.. _tinytc_scalar_type_to_string: + tinytc_scalar_type_to_string ............................ .. doxygenfunction:: tinytc_scalar_type_to_string +.. _tinytc_store_flag_to_string: + +tinytc_store_flag_to_string +........................... + +.. doxygenfunction:: tinytc_store_flag_to_string + +.. _tinytc_transpose_to_string: + tinytc_transpose_to_string .......................... @@ -144,176 +328,348 @@ tinytc_transpose_to_string Common Structures ----------------- -tinytc_position -............... +.. _tinytc_named_attr: -.. doxygenstruct:: tinytc_position +tinytc_named_attr +................. + +.. doxygenstruct:: tinytc_named_attr + +.. _tinytc_location: tinytc_location ............... .. doxygenstruct:: tinytc_location +.. _tinytc_position: + +tinytc_position +............... + +.. doxygenstruct:: tinytc_position + Common Typedefs --------------- +.. _tinytc_address_spaces_t: + +tinytc_address_spaces_t +....................... + +.. doxygentypedef:: tinytc_address_spaces_t + +.. _tinytc_attr_t: + +tinytc_attr_t +............. + +.. doxygentypedef:: tinytc_attr_t + +.. _tinytc_data_type_t: + tinytc_data_type_t .................. .. doxygentypedef:: tinytc_data_type_t +.. _tinytc_func_t: + tinytc_func_t ............. .. doxygentypedef:: tinytc_func_t +.. _tinytc_named_attr_t: + +tinytc_named_attr_t +................... + +.. doxygentypedef:: tinytc_named_attr_t + +.. _tinytc_location_t: + tinytc_location_t ................. .. doxygentypedef:: tinytc_location_t +.. _tinytc_position_t: + tinytc_position_t ................. .. doxygentypedef:: tinytc_position_t +.. _tinytc_prog_t: + tinytc_prog_t ............. .. doxygentypedef:: tinytc_prog_t +.. _tinytc_inst_t: + tinytc_inst_t ............. .. doxygentypedef:: tinytc_inst_t +.. _tinytc_inst_iterator_t: + +tinytc_inst_iterator_t +...................... + +.. doxygentypedef:: tinytc_inst_iterator_t + +.. _tinytc_region_t: + tinytc_region_t ............... .. doxygentypedef:: tinytc_region_t +.. _tinytc_value_t: + tinytc_value_t .............. .. doxygentypedef:: tinytc_value_t +.. _const_tinytc_attr_t: + +const_tinytc_attr_t +................... + +.. doxygentypedef:: const_tinytc_attr_t + +.. _const_tinytc_data_type_t: + const_tinytc_data_type_t ........................ .. doxygentypedef:: const_tinytc_data_type_t +.. _const_tinytc_func_t: + const_tinytc_func_t ................... .. doxygentypedef:: const_tinytc_func_t +.. _const_tinytc_inst_t: + const_tinytc_inst_t ................... .. doxygentypedef:: const_tinytc_inst_t +.. _const_tinytc_prog_t: + const_tinytc_prog_t ................... .. doxygentypedef:: const_tinytc_prog_t +.. _const_tinytc_region_t: + const_tinytc_region_t ..................... .. doxygentypedef:: const_tinytc_region_t +.. _const_tinytc_value_t: + +const_tinytc_value_t +.................... + +.. doxygentypedef:: const_tinytc_value_t + +Attribute +========= + +* Functions + + * :ref:`tinytc_array_attr_get` + + * :ref:`tinytc_boolean_attr_get` + + * :ref:`tinytc_dictionary_attr_get` + + * :ref:`tinytc_dictionary_attr_get_with_sorted` + + * :ref:`tinytc_dictionary_attr_sort` + + * :ref:`tinytc_integer_attr_get` + + * :ref:`tinytc_string_attr_get` + +Attribute Functions +------------------- + +.. _tinytc_array_attr_get: + +tinytc_array_attr_get +..................... + +.. doxygenfunction:: tinytc_array_attr_get + +.. _tinytc_boolean_attr_get: + +tinytc_boolean_attr_get +....................... + +.. doxygenfunction:: tinytc_boolean_attr_get + +.. _tinytc_dictionary_attr_get: + +tinytc_dictionary_attr_get +.......................... + +.. doxygenfunction:: tinytc_dictionary_attr_get + +.. _tinytc_dictionary_attr_get_with_sorted: + +tinytc_dictionary_attr_get_with_sorted +...................................... + +.. doxygenfunction:: tinytc_dictionary_attr_get_with_sorted + +.. _tinytc_dictionary_attr_sort: + +tinytc_dictionary_attr_sort +........................... + +.. doxygenfunction:: tinytc_dictionary_attr_sort + +.. _tinytc_integer_attr_get: + +tinytc_integer_attr_get +....................... + +.. doxygenfunction:: tinytc_integer_attr_get + +.. _tinytc_string_attr_get: + +tinytc_string_attr_get +...................... + +.. doxygenfunction:: tinytc_string_attr_get + Data Type ========= * Functions - * :ref:`tinytc_group_type_create` + * :ref:`tinytc_boolean_type_get` - * :ref:`tinytc_memref_type_create` + * :ref:`tinytc_coopmatrix_type_get` - * :ref:`tinytc_scalar_type_create` + * :ref:`tinytc_group_type_get` - * :ref:`tinytc_data_type_release` + * :ref:`tinytc_memref_type_get` - * :ref:`tinytc_data_type_retain` + * :ref:`tinytc_scalar_type_get` + + * :ref:`tinytc_void_type_get` Data Type Functions ------------------- -tinytc_group_type_create -........................ +.. _tinytc_boolean_type_get: -.. doxygenfunction:: tinytc_group_type_create +tinytc_boolean_type_get +....................... -tinytc_memref_type_create -......................... +.. doxygenfunction:: tinytc_boolean_type_get -.. doxygenfunction:: tinytc_memref_type_create +.. _tinytc_coopmatrix_type_get: -tinytc_scalar_type_create -......................... +tinytc_coopmatrix_type_get +.......................... -.. doxygenfunction:: tinytc_scalar_type_create +.. doxygenfunction:: tinytc_coopmatrix_type_get -tinytc_data_type_release -........................ +.. _tinytc_group_type_get: -.. doxygenfunction:: tinytc_data_type_release +tinytc_group_type_get +..................... -tinytc_data_type_retain -....................... +.. doxygenfunction:: tinytc_group_type_get -.. doxygenfunction:: tinytc_data_type_retain +.. _tinytc_memref_type_get: + +tinytc_memref_type_get +...................... + +.. doxygenfunction:: tinytc_memref_type_get + +.. _tinytc_scalar_type_get: + +tinytc_scalar_type_get +...................... + +.. doxygenfunction:: tinytc_scalar_type_get + +.. _tinytc_void_type_get: + +tinytc_void_type_get +.................... + +.. doxygenfunction:: tinytc_void_type_get Function ======== * Functions - * :ref:`tinytc_function_create` - - * :ref:`tinytc_function_prototype_create` + * :ref:`tinytc_func_create` - * :ref:`tinytc_function_set_subgroup_size` + * :ref:`tinytc_func_destroy` - * :ref:`tinytc_function_set_work_group_size` + * :ref:`tinytc_func_get_body` - * :ref:`tinytc_func_release` + * :ref:`tinytc_func_set_attr` - * :ref:`tinytc_func_retain` + * :ref:`tinytc_func_set_parameter_attr` Function Functions ------------------ -tinytc_function_create -...................... +.. _tinytc_func_create: -.. doxygenfunction:: tinytc_function_create +tinytc_func_create +.................. -tinytc_function_prototype_create -................................ +.. doxygenfunction:: tinytc_func_create -.. doxygenfunction:: tinytc_function_prototype_create +.. _tinytc_func_destroy: -tinytc_function_set_subgroup_size -................................. +tinytc_func_destroy +................... -.. doxygenfunction:: tinytc_function_set_subgroup_size +.. doxygenfunction:: tinytc_func_destroy -tinytc_function_set_work_group_size -................................... +.. _tinytc_func_get_body: -.. doxygenfunction:: tinytc_function_set_work_group_size +tinytc_func_get_body +.................... -tinytc_func_release -................... +.. doxygenfunction:: tinytc_func_get_body -.. doxygenfunction:: tinytc_func_release +.. _tinytc_func_set_attr: -tinytc_func_retain -.................. +tinytc_func_set_attr +.................... -.. doxygenfunction:: tinytc_func_retain +.. doxygenfunction:: tinytc_func_set_attr + +.. _tinytc_func_set_parameter_attr: + +tinytc_func_set_parameter_attr +.............................. + +.. doxygenfunction:: tinytc_func_set_parameter_attr Instruction =========== @@ -328,10 +684,44 @@ Instruction * :ref:`tinytc_arith_unary_inst_create` + * :ref:`tinytc_barrier_inst_create` + + * :ref:`tinytc_builtin_inst_create` + * :ref:`tinytc_cast_inst_create` * :ref:`tinytc_cmp_inst_create` + * :ref:`tinytc_constant_inst_create_boolean` + + * :ref:`tinytc_constant_inst_create_complex` + + * :ref:`tinytc_constant_inst_create_float` + + * :ref:`tinytc_constant_inst_create_int` + + * :ref:`tinytc_constant_inst_create_one` + + * :ref:`tinytc_constant_inst_create_zero` + + * :ref:`tinytc_cooperative_matrix_apply_inst_create` + + * :ref:`tinytc_cooperative_matrix_extract_inst_create` + + * :ref:`tinytc_cooperative_matrix_insert_inst_create` + + * :ref:`tinytc_cooperative_matrix_load_inst_create` + + * :ref:`tinytc_cooperative_matrix_mul_add_inst_create` + + * :ref:`tinytc_cooperative_matrix_prefetch_inst_create` + + * :ref:`tinytc_cooperative_matrix_scale_inst_create` + + * :ref:`tinytc_cooperative_matrix_store_inst_create` + + * :ref:`tinytc_cumsum_inst_create` + * :ref:`tinytc_expand_inst_create` * :ref:`tinytc_for_inst_create` @@ -346,18 +736,22 @@ Instruction * :ref:`tinytc_ger_inst_create` - * :ref:`tinytc_group_id_inst_create` - - * :ref:`tinytc_group_size_inst_create` - * :ref:`tinytc_hadamard_inst_create` * :ref:`tinytc_if_inst_create` * :ref:`tinytc_load_inst_create` + * :ref:`tinytc_math_unary_inst_create` + + * :ref:`tinytc_parallel_inst_create` + * :ref:`tinytc_size_inst_create` + * :ref:`tinytc_subgroup_broadcast_inst_create` + + * :ref:`tinytc_subgroup_operation_inst_create` + * :ref:`tinytc_store_inst_create` * :ref:`tinytc_subview_inst_create` @@ -366,161 +760,361 @@ Instruction * :ref:`tinytc_yield_inst_create` - * :ref:`tinytc_inst_get_value` + * :ref:`tinytc_inst_get_parent_region` + + * :ref:`tinytc_inst_get_regions` * :ref:`tinytc_inst_get_values` - * :ref:`tinytc_inst_release` + * :ref:`tinytc_inst_destroy` - * :ref:`tinytc_inst_retain` + * :ref:`tinytc_inst_set_attr` Instruction Functions --------------------- +.. _tinytc_alloca_inst_create: + tinytc_alloca_inst_create ......................... .. doxygenfunction:: tinytc_alloca_inst_create +.. _tinytc_axpby_inst_create: + tinytc_axpby_inst_create ........................ .. doxygenfunction:: tinytc_axpby_inst_create +.. _tinytc_arith_inst_create: + tinytc_arith_inst_create ........................ .. doxygenfunction:: tinytc_arith_inst_create +.. _tinytc_arith_unary_inst_create: + tinytc_arith_unary_inst_create .............................. .. doxygenfunction:: tinytc_arith_unary_inst_create +.. _tinytc_barrier_inst_create: + +tinytc_barrier_inst_create +.......................... + +.. doxygenfunction:: tinytc_barrier_inst_create + +.. _tinytc_builtin_inst_create: + +tinytc_builtin_inst_create +.......................... + +.. doxygenfunction:: tinytc_builtin_inst_create + +.. _tinytc_cast_inst_create: + tinytc_cast_inst_create ....................... .. doxygenfunction:: tinytc_cast_inst_create +.. _tinytc_cmp_inst_create: + tinytc_cmp_inst_create ...................... .. doxygenfunction:: tinytc_cmp_inst_create +.. _tinytc_constant_inst_create_boolean: + +tinytc_constant_inst_create_boolean +................................... + +.. doxygenfunction:: tinytc_constant_inst_create_boolean + +.. _tinytc_constant_inst_create_complex: + +tinytc_constant_inst_create_complex +................................... + +.. doxygenfunction:: tinytc_constant_inst_create_complex + +.. _tinytc_constant_inst_create_float: + +tinytc_constant_inst_create_float +................................. + +.. doxygenfunction:: tinytc_constant_inst_create_float + +.. _tinytc_constant_inst_create_int: + +tinytc_constant_inst_create_int +............................... + +.. doxygenfunction:: tinytc_constant_inst_create_int + +.. _tinytc_constant_inst_create_one: + +tinytc_constant_inst_create_one +............................... + +.. doxygenfunction:: tinytc_constant_inst_create_one + +.. _tinytc_constant_inst_create_zero: + +tinytc_constant_inst_create_zero +................................ + +.. doxygenfunction:: tinytc_constant_inst_create_zero + +.. _tinytc_cooperative_matrix_apply_inst_create: + +tinytc_cooperative_matrix_apply_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_apply_inst_create + +.. _tinytc_cooperative_matrix_extract_inst_create: + +tinytc_cooperative_matrix_extract_inst_create +............................................. + +.. doxygenfunction:: tinytc_cooperative_matrix_extract_inst_create + +.. _tinytc_cooperative_matrix_insert_inst_create: + +tinytc_cooperative_matrix_insert_inst_create +............................................ + +.. doxygenfunction:: tinytc_cooperative_matrix_insert_inst_create + +.. _tinytc_cooperative_matrix_load_inst_create: + +tinytc_cooperative_matrix_load_inst_create +.......................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_load_inst_create + +.. _tinytc_cooperative_matrix_mul_add_inst_create: + +tinytc_cooperative_matrix_mul_add_inst_create +............................................. + +.. doxygenfunction:: tinytc_cooperative_matrix_mul_add_inst_create + +.. _tinytc_cooperative_matrix_prefetch_inst_create: + +tinytc_cooperative_matrix_prefetch_inst_create +.............................................. + +.. doxygenfunction:: tinytc_cooperative_matrix_prefetch_inst_create + +.. _tinytc_cooperative_matrix_scale_inst_create: + +tinytc_cooperative_matrix_scale_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_scale_inst_create + +.. _tinytc_cooperative_matrix_store_inst_create: + +tinytc_cooperative_matrix_store_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_store_inst_create + +.. _tinytc_cumsum_inst_create: + +tinytc_cumsum_inst_create +......................... + +.. doxygenfunction:: tinytc_cumsum_inst_create + +.. _tinytc_expand_inst_create: + tinytc_expand_inst_create ......................... .. doxygenfunction:: tinytc_expand_inst_create +.. _tinytc_for_inst_create: + tinytc_for_inst_create ...................... .. doxygenfunction:: tinytc_for_inst_create +.. _tinytc_foreach_inst_create: + tinytc_foreach_inst_create .......................... .. doxygenfunction:: tinytc_foreach_inst_create +.. _tinytc_fuse_inst_create: + tinytc_fuse_inst_create ....................... .. doxygenfunction:: tinytc_fuse_inst_create +.. _tinytc_gemm_inst_create: + tinytc_gemm_inst_create ....................... .. doxygenfunction:: tinytc_gemm_inst_create +.. _tinytc_gemv_inst_create: + tinytc_gemv_inst_create ....................... .. doxygenfunction:: tinytc_gemv_inst_create +.. _tinytc_ger_inst_create: + tinytc_ger_inst_create ...................... .. doxygenfunction:: tinytc_ger_inst_create -tinytc_group_id_inst_create -........................... - -.. doxygenfunction:: tinytc_group_id_inst_create - -tinytc_group_size_inst_create -............................. - -.. doxygenfunction:: tinytc_group_size_inst_create +.. _tinytc_hadamard_inst_create: tinytc_hadamard_inst_create ........................... .. doxygenfunction:: tinytc_hadamard_inst_create +.. _tinytc_if_inst_create: + tinytc_if_inst_create ..................... .. doxygenfunction:: tinytc_if_inst_create +.. _tinytc_load_inst_create: + tinytc_load_inst_create ....................... .. doxygenfunction:: tinytc_load_inst_create +.. _tinytc_math_unary_inst_create: + +tinytc_math_unary_inst_create +............................. + +.. doxygenfunction:: tinytc_math_unary_inst_create + +.. _tinytc_parallel_inst_create: + +tinytc_parallel_inst_create +........................... + +.. doxygenfunction:: tinytc_parallel_inst_create + +.. _tinytc_size_inst_create: + tinytc_size_inst_create ....................... .. doxygenfunction:: tinytc_size_inst_create +.. _tinytc_subgroup_broadcast_inst_create: + +tinytc_subgroup_broadcast_inst_create +..................................... + +.. doxygenfunction:: tinytc_subgroup_broadcast_inst_create + +.. _tinytc_subgroup_operation_inst_create: + +tinytc_subgroup_operation_inst_create +..................................... + +.. doxygenfunction:: tinytc_subgroup_operation_inst_create + +.. _tinytc_store_inst_create: + tinytc_store_inst_create ........................ .. doxygenfunction:: tinytc_store_inst_create +.. _tinytc_subview_inst_create: + tinytc_subview_inst_create .......................... .. doxygenfunction:: tinytc_subview_inst_create +.. _tinytc_sum_inst_create: + tinytc_sum_inst_create ...................... .. doxygenfunction:: tinytc_sum_inst_create +.. _tinytc_yield_inst_create: + tinytc_yield_inst_create ........................ .. doxygenfunction:: tinytc_yield_inst_create -tinytc_inst_get_value -..................... +.. _tinytc_inst_get_parent_region: + +tinytc_inst_get_parent_region +............................. -.. doxygenfunction:: tinytc_inst_get_value +.. doxygenfunction:: tinytc_inst_get_parent_region + +.. _tinytc_inst_get_regions: + +tinytc_inst_get_regions +....................... + +.. doxygenfunction:: tinytc_inst_get_regions + +.. _tinytc_inst_get_values: tinytc_inst_get_values ...................... .. doxygenfunction:: tinytc_inst_get_values -tinytc_inst_release +.. _tinytc_inst_destroy: + +tinytc_inst_destroy ................... -.. doxygenfunction:: tinytc_inst_release +.. doxygenfunction:: tinytc_inst_destroy -tinytc_inst_retain -.................. +.. _tinytc_inst_set_attr: + +tinytc_inst_set_attr +.................... -.. doxygenfunction:: tinytc_inst_retain +.. doxygenfunction:: tinytc_inst_set_attr Program ======= * Functions - * :ref:`tinytc_program_create` + * :ref:`tinytc_prog_create` + + * :ref:`tinytc_prog_add_function` * :ref:`tinytc_prog_dump` + * :ref:`tinytc_prog_get_compiler_context` + * :ref:`tinytc_prog_print_to_file` * :ref:`tinytc_prog_print_to_string` @@ -532,31 +1126,57 @@ Program Program Functions ----------------- -tinytc_program_create -..................... +.. _tinytc_prog_create: -.. doxygenfunction:: tinytc_program_create +tinytc_prog_create +.................. + +.. doxygenfunction:: tinytc_prog_create + +.. _tinytc_prog_add_function: + +tinytc_prog_add_function +........................ + +.. doxygenfunction:: tinytc_prog_add_function + +.. _tinytc_prog_dump: tinytc_prog_dump ................ .. doxygenfunction:: tinytc_prog_dump +.. _tinytc_prog_get_compiler_context: + +tinytc_prog_get_compiler_context +................................ + +.. doxygenfunction:: tinytc_prog_get_compiler_context + +.. _tinytc_prog_print_to_file: + tinytc_prog_print_to_file ......................... .. doxygenfunction:: tinytc_prog_print_to_file +.. _tinytc_prog_print_to_string: + tinytc_prog_print_to_string ........................... .. doxygenfunction:: tinytc_prog_print_to_string +.. _tinytc_prog_release: + tinytc_prog_release ................... .. doxygenfunction:: tinytc_prog_release +.. _tinytc_prog_retain: + tinytc_prog_retain .................. @@ -567,84 +1187,122 @@ Region * Functions - * :ref:`tinytc_region_create` + * :ref:`tinytc_region_append` + + * :ref:`tinytc_region_begin` + + * :ref:`tinytc_region_end` - * :ref:`tinytc_region_release` + * :ref:`tinytc_region_erase` - * :ref:`tinytc_region_retain` + * :ref:`tinytc_region_insert` + + * :ref:`tinytc_next_inst` + + * :ref:`tinytc_prev_inst` + + * :ref:`tinytc_region_get_parameters` Region Functions ---------------- -tinytc_region_create +.. _tinytc_region_append: + +tinytc_region_append .................... -.. doxygenfunction:: tinytc_region_create +.. doxygenfunction:: tinytc_region_append -tinytc_region_release -..................... +.. _tinytc_region_begin: + +tinytc_region_begin +................... -.. doxygenfunction:: tinytc_region_release +.. doxygenfunction:: tinytc_region_begin -tinytc_region_retain +.. _tinytc_region_end: + +tinytc_region_end +................. + +.. doxygenfunction:: tinytc_region_end + +.. _tinytc_region_erase: + +tinytc_region_erase +................... + +.. doxygenfunction:: tinytc_region_erase + +.. _tinytc_region_insert: + +tinytc_region_insert .................... -.. doxygenfunction:: tinytc_region_retain +.. doxygenfunction:: tinytc_region_insert -Value -===== +.. _tinytc_next_inst: -* Functions +tinytc_next_inst +................ - * :ref:`tinytc_float_imm_create` +.. doxygenfunction:: tinytc_next_inst - * :ref:`tinytc_int_imm_create` +.. _tinytc_prev_inst: - * :ref:`tinytc_value_create` +tinytc_prev_inst +................ - * :ref:`tinytc_value_get_name` +.. doxygenfunction:: tinytc_prev_inst - * :ref:`tinytc_value_set_name` +.. _tinytc_region_get_parameters: - * :ref:`tinytc_value_release` +tinytc_region_get_parameters +............................ - * :ref:`tinytc_value_retain` +.. doxygenfunction:: tinytc_region_get_parameters -Value Functions ---------------- +Value +===== -tinytc_float_imm_create -....................... +* Functions -.. doxygenfunction:: tinytc_float_imm_create + * :ref:`tinytc_value_get_name` -tinytc_int_imm_create -..................... + * :ref:`tinytc_value_get_type` + + * :ref:`tinytc_value_set_name` -.. doxygenfunction:: tinytc_int_imm_create + * :ref:`tinytc_value_set_name_n` -tinytc_value_create -................... +Value Functions +--------------- -.. doxygenfunction:: tinytc_value_create +.. _tinytc_value_get_name: tinytc_value_get_name ..................... .. doxygenfunction:: tinytc_value_get_name +.. _tinytc_value_get_type: + +tinytc_value_get_type +..................... + +.. doxygenfunction:: tinytc_value_get_type + +.. _tinytc_value_set_name: + tinytc_value_set_name ..................... .. doxygenfunction:: tinytc_value_set_name -tinytc_value_release -.................... - -.. doxygenfunction:: tinytc_value_release +.. _tinytc_value_set_name_n: -tinytc_value_retain -................... +tinytc_value_set_name_n +....................... -.. doxygenfunction:: tinytc_value_retain +.. doxygenfunction:: tinytc_value_set_name_n diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 4412c745..6e60ca90 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -3,60 +3,111 @@ Builder C-API: Common: enum: + - tinytc_address_space_t - tinytc_arithmetic_t - tinytc_arithmetic_unary_t + - tinytc_builtin_t + - tinytc_checked_flag_t - tinytc_cmp_condition_t + - tinytc_group_arithmetic_t + - tinytc_group_operation_t + - tinytc_math_unary_t + - tinytc_matrix_use_t + - tinytc_reduce_mode_t - tinytc_scalar_type_t + - tinytc_store_flag_t - tinytc_transpose_t define: - TINYTC_DYNAMIC function: + - tinytc_address_space_to_string - tinytc_arithmetic_to_string - tinytc_arithmetic_unary_to_string + - tinytc_builtin_to_string + - tinytc_checked_flag_to_string - tinytc_cmp_condition_to_string + - tinytc_group_arithmetic_to_string + - tinytc_group_operation_to_string + - tinytc_math_unary_to_string + - tinytc_matrix_use_to_string + - tinytc_reduce_mode_to_string - tinytc_scalar_type_size - tinytc_scalar_type_to_string + - tinytc_store_flag_to_string - tinytc_transpose_to_string struct: - - tinytc_position + - tinytc_named_attr - tinytc_location + - tinytc_position typedef: + - tinytc_address_spaces_t + - tinytc_attr_t - tinytc_data_type_t - tinytc_func_t + - tinytc_named_attr_t - tinytc_location_t - tinytc_position_t - tinytc_prog_t - tinytc_inst_t + - tinytc_inst_iterator_t - tinytc_region_t - tinytc_value_t + - const_tinytc_attr_t - const_tinytc_data_type_t - const_tinytc_func_t - const_tinytc_inst_t - const_tinytc_prog_t - const_tinytc_region_t + - const_tinytc_value_t + Attribute: + function: + - tinytc_array_attr_get + - tinytc_boolean_attr_get + - tinytc_dictionary_attr_get + - tinytc_dictionary_attr_get_with_sorted + - tinytc_dictionary_attr_sort + - tinytc_integer_attr_get + - tinytc_string_attr_get Data Type: function: - - tinytc_group_type_create - - tinytc_memref_type_create - - tinytc_scalar_type_create - - tinytc_data_type_release - - tinytc_data_type_retain + - tinytc_boolean_type_get + - tinytc_coopmatrix_type_get + - tinytc_group_type_get + - tinytc_memref_type_get + - tinytc_scalar_type_get + - tinytc_void_type_get Function: function: - - tinytc_function_create - - tinytc_function_prototype_create - - tinytc_function_set_subgroup_size - - tinytc_function_set_work_group_size - - tinytc_func_release - - tinytc_func_retain + - tinytc_func_create + - tinytc_func_destroy + - tinytc_func_get_body + - tinytc_func_set_attr + - tinytc_func_set_parameter_attr Instruction: function: - tinytc_alloca_inst_create - tinytc_axpby_inst_create - tinytc_arith_inst_create - tinytc_arith_unary_inst_create + - tinytc_barrier_inst_create + - tinytc_builtin_inst_create - tinytc_cast_inst_create - tinytc_cmp_inst_create + - tinytc_constant_inst_create_boolean + - tinytc_constant_inst_create_complex + - tinytc_constant_inst_create_float + - tinytc_constant_inst_create_int + - tinytc_constant_inst_create_one + - tinytc_constant_inst_create_zero + - tinytc_cooperative_matrix_apply_inst_create + - tinytc_cooperative_matrix_extract_inst_create + - tinytc_cooperative_matrix_insert_inst_create + - tinytc_cooperative_matrix_load_inst_create + - tinytc_cooperative_matrix_mul_add_inst_create + - tinytc_cooperative_matrix_prefetch_inst_create + - tinytc_cooperative_matrix_scale_inst_create + - tinytc_cooperative_matrix_store_inst_create + - tinytc_cumsum_inst_create - tinytc_expand_inst_create - tinytc_for_inst_create - tinytc_foreach_inst_create @@ -64,39 +115,46 @@ Builder C-API: - tinytc_gemm_inst_create - tinytc_gemv_inst_create - tinytc_ger_inst_create - - tinytc_group_id_inst_create - - tinytc_group_size_inst_create - tinytc_hadamard_inst_create - tinytc_if_inst_create - tinytc_load_inst_create + - tinytc_math_unary_inst_create + - tinytc_parallel_inst_create - tinytc_size_inst_create + - tinytc_subgroup_broadcast_inst_create + - tinytc_subgroup_operation_inst_create - tinytc_store_inst_create - tinytc_subview_inst_create - tinytc_sum_inst_create - tinytc_yield_inst_create - - tinytc_inst_get_value + - tinytc_inst_get_parent_region + - tinytc_inst_get_regions - tinytc_inst_get_values - - tinytc_inst_release - - tinytc_inst_retain + - tinytc_inst_destroy + - tinytc_inst_set_attr Program: function: - - tinytc_program_create + - tinytc_prog_create + - tinytc_prog_add_function - tinytc_prog_dump + - tinytc_prog_get_compiler_context - tinytc_prog_print_to_file - tinytc_prog_print_to_string - tinytc_prog_release - tinytc_prog_retain Region: function: - - tinytc_region_create - - tinytc_region_release - - tinytc_region_retain + - tinytc_region_append + - tinytc_region_begin + - tinytc_region_end + - tinytc_region_erase + - tinytc_region_insert + - tinytc_next_inst + - tinytc_prev_inst + - tinytc_region_get_parameters Value: function: - - tinytc_float_imm_create - - tinytc_int_imm_create - - tinytc_value_create - tinytc_value_get_name + - tinytc_value_get_type - tinytc_value_set_name - - tinytc_value_release - - tinytc_value_retain + - tinytc_value_set_name_n diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 57668d4a..75c70972 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Builder C++-API: + =============== Builder C++-API =============== @@ -10,69 +12,158 @@ Common * Enumerations - * :ref:`arithmetic` + * :ref:`tinytc::address_space` + + * :ref:`tinytc::arithmetic` + + * :ref:`tinytc::arithmetic_unary` + + * :ref:`tinytc::builtin` + + * :ref:`tinytc::cmp_condition` + + * :ref:`tinytc::group_arithmetic` - * :ref:`arithmetic_unary` + * :ref:`tinytc::group_operation` - * :ref:`cmp_condition` + * :ref:`tinytc::math_unary` - * :ref:`scalar_type` + * :ref:`tinytc::matrix_use` - * :ref:`transpose` + * :ref:`tinytc::scalar_type` + + * :ref:`tinytc::store_flag` + + * :ref:`tinytc::transpose` * Functions - * :ref:`is_dynamic_value` + * :ref:`tinytc::is_dynamic_value` + + * :ref:`tinytc::to_string(address_space)` + + * :ref:`tinytc::to_string(arithmetic)` + + * :ref:`tinytc::to_string(arithmetic_unary)` + + * :ref:`tinytc::to_string(builtin)` - * :ref:`to_string(arithmetic)` + * :ref:`tinytc::to_string(checked_flag)` - * :ref:`to_string(arithmetic_unary)` + * :ref:`tinytc::to_string(cmp_condition)` - * :ref:`to_string(cmp_condition)` + * :ref:`tinytc::to_string(group_arithmetic)` - * :ref:`to_string(scalar_type)` + * :ref:`tinytc::to_string(group_operation)` - * :ref:`to_string(transpose)` + * :ref:`tinytc::to_string(math_unary)` - * :ref:`size` + * :ref:`tinytc::to_string(matrix_use)` + + * :ref:`tinytc::to_string(scalar_type)` + + * :ref:`tinytc::to_string(store_flag)` + + * :ref:`tinytc::to_string(transpose)` + + * :ref:`tinytc::size` * Classes - * :ref:`builder_error` + * :ref:`tinytc::builder_error` * Typedefs - * :ref:`position` + * :ref:`tinytc::location` - * :ref:`location` + * :ref:`tinytc::position` * Variables - * :ref:`dynamic` + * :ref:`tinytc::dynamic` Common Enumerations ------------------- +.. _tinytc::address_space: + +address_space +............. + +.. doxygenenum:: tinytc::address_space + +.. _tinytc::arithmetic: + arithmetic .......... .. doxygenenum:: tinytc::arithmetic +.. _tinytc::arithmetic_unary: + arithmetic_unary ................ .. doxygenenum:: tinytc::arithmetic_unary +.. _tinytc::builtin: + +builtin +....... + +.. doxygenenum:: tinytc::builtin + +.. _tinytc::cmp_condition: + cmp_condition ............. .. doxygenenum:: tinytc::cmp_condition +.. _tinytc::group_arithmetic: + +group_arithmetic +................ + +.. doxygenenum:: tinytc::group_arithmetic + +.. _tinytc::group_operation: + +group_operation +............... + +.. doxygenenum:: tinytc::group_operation + +.. _tinytc::math_unary: + +math_unary +.......... + +.. doxygenenum:: tinytc::math_unary + +.. _tinytc::matrix_use: + +matrix_use +.......... + +.. doxygenenum:: tinytc::matrix_use + +.. _tinytc::scalar_type: + scalar_type ........... .. doxygenenum:: tinytc::scalar_type +.. _tinytc::store_flag: + +store_flag +.......... + +.. doxygenenum:: tinytc::store_flag + +.. _tinytc::transpose: + transpose ......... @@ -81,36 +172,106 @@ transpose Common Functions ---------------- +.. _tinytc::is_dynamic_value: + is_dynamic_value ................ .. doxygenfunction:: tinytc::is_dynamic_value +.. _tinytc::to_string(address_space): + +to_string(address_space) +........................ + +.. doxygenfunction:: tinytc::to_string(address_space) + +.. _tinytc::to_string(arithmetic): + to_string(arithmetic) ..................... .. doxygenfunction:: tinytc::to_string(arithmetic) +.. _tinytc::to_string(arithmetic_unary): + to_string(arithmetic_unary) ........................... .. doxygenfunction:: tinytc::to_string(arithmetic_unary) +.. _tinytc::to_string(builtin): + +to_string(builtin) +.................. + +.. doxygenfunction:: tinytc::to_string(builtin) + +.. _tinytc::to_string(checked_flag): + +to_string(checked_flag) +....................... + +.. doxygenfunction:: tinytc::to_string(checked_flag) + +.. _tinytc::to_string(cmp_condition): + to_string(cmp_condition) ........................ .. doxygenfunction:: tinytc::to_string(cmp_condition) +.. _tinytc::to_string(group_arithmetic): + +to_string(group_arithmetic) +........................... + +.. doxygenfunction:: tinytc::to_string(group_arithmetic) + +.. _tinytc::to_string(group_operation): + +to_string(group_operation) +.......................... + +.. doxygenfunction:: tinytc::to_string(group_operation) + +.. _tinytc::to_string(math_unary): + +to_string(math_unary) +..................... + +.. doxygenfunction:: tinytc::to_string(math_unary) + +.. _tinytc::to_string(matrix_use): + +to_string(matrix_use) +..................... + +.. doxygenfunction:: tinytc::to_string(matrix_use) + +.. _tinytc::to_string(scalar_type): + to_string(scalar_type) ...................... .. doxygenfunction:: tinytc::to_string(scalar_type) +.. _tinytc::to_string(store_flag): + +to_string(store_flag) +..................... + +.. doxygenfunction:: tinytc::to_string(store_flag) + +.. _tinytc::to_string(transpose): + to_string(transpose) .................... .. doxygenfunction:: tinytc::to_string(transpose) +.. _tinytc::size: + size .... @@ -119,6 +280,8 @@ size Common Classes -------------- +.. _tinytc::builder_error: + builder_error ............. @@ -127,84 +290,223 @@ builder_error Common Typedefs --------------- -position -........ - -.. doxygentypedef:: tinytc::position +.. _tinytc::location: location ........ .. doxygentypedef:: tinytc::location +.. _tinytc::position: + +position +........ + +.. doxygentypedef:: tinytc::position + Common Variables ---------------- +.. _tinytc::dynamic: + dynamic ....... .. doxygenvariable:: tinytc::dynamic +Attribute +========= + +* Functions + + * :ref:`get_array_attr` + + * :ref:`get_boolean_attr` + + * :ref:`get_dictionary_attr` + + * :ref:`get_dictionary_attr_with_sorted` + + * :ref:`get_integer_attr` + + * :ref:`get_string_attr` + + * :ref:`sort_items` + +* Typedefs + + * :ref:`tinytc::attr` + + * :ref:`tinytc::named_attr` + +Attribute Functions +------------------- + +.. _get_array_attr: + +get_array_attr +.............. + +.. doxygenfunction:: get_array_attr + +.. _get_boolean_attr: + +get_boolean_attr +................ + +.. doxygenfunction:: get_boolean_attr + +.. _get_dictionary_attr: + +get_dictionary_attr +................... + +.. doxygenfunction:: get_dictionary_attr + +.. _get_dictionary_attr_with_sorted: + +get_dictionary_attr_with_sorted +............................... + +.. doxygenfunction:: get_dictionary_attr_with_sorted + +.. _get_integer_attr: + +get_integer_attr +................ + +.. doxygenfunction:: get_integer_attr + +.. _get_string_attr: + +get_string_attr +............... + +.. doxygenfunction:: get_string_attr + +.. _sort_items: + +sort_items +.......... + +.. doxygenfunction:: sort_items + +Attribute Typedefs +------------------ + +.. _tinytc::attr: + +attr +.... + +.. doxygentypedef:: tinytc::attr + +.. _tinytc::named_attr: + +named_attr +.......... + +.. doxygentypedef:: tinytc::named_attr + Data Type ========= * Functions - * :ref:`make_memref` + * :ref:`tinytc::get_boolean` - * :ref:`make_group` + * :ref:`tinytc::get_coopmatrix` - * :ref:`make_scalar` + * :ref:`tinytc::get_group` -* Classes + * :ref:`tinytc::get_memref` + + * :ref:`tinytc::get_scalar` - * :ref:`data_type` + * :ref:`tinytc::get_void` * Structures - * :ref:`to_scalar_type` + * :ref:`tinytc::to_scalar_type` + +* Typedefs + + * :ref:`tinytc::data_type` * Variables - * :ref:`to_scalar_type_v` + * :ref:`tinytc::to_scalar_type_v` Data Type Functions ------------------- -make_memref +.. _tinytc::get_boolean: + +get_boolean ........... -.. doxygenfunction:: tinytc::make_memref +.. doxygenfunction:: tinytc::get_boolean + +.. _tinytc::get_coopmatrix: + +get_coopmatrix +.............. + +.. doxygenfunction:: tinytc::get_coopmatrix + +.. _tinytc::get_group: + +get_group +......... + +.. doxygenfunction:: tinytc::get_group + +.. _tinytc::get_memref: -make_group +get_memref .......... -.. doxygenfunction:: tinytc::make_group +.. doxygenfunction:: tinytc::get_memref -make_scalar -........... +.. _tinytc::get_scalar: -.. doxygenfunction:: tinytc::make_scalar +get_scalar +.......... -Data Type Classes ------------------ +.. doxygenfunction:: tinytc::get_scalar -data_type -......... +.. _tinytc::get_void: + +get_void +........ -.. doxygenclass:: tinytc::data_type +.. doxygenfunction:: tinytc::get_void Data Type Structures -------------------- +.. _tinytc::to_scalar_type: + to_scalar_type .............. .. doxygenstruct:: tinytc::to_scalar_type +Data Type Typedefs +------------------ + +.. _tinytc::data_type: + +data_type +......... + +.. doxygentypedef:: tinytc::data_type + Data Type Variables ------------------- +.. _tinytc::to_scalar_type_v: + to_scalar_type_v ................ @@ -215,224 +517,426 @@ Function * Functions - * :ref:`make_function` - - * :ref:`make_function_prototype` - - * :ref:`set_work_group_size` - - * :ref:`set_subgroup_size` + * :ref:`tinytc::make_func` * Classes - * :ref:`func` - - * :ref:`function_builder` + * :ref:`tinytc::func` Function Functions ------------------ -make_function -............. - -.. doxygenfunction:: tinytc::make_function - -make_function_prototype -....................... - -.. doxygenfunction:: tinytc::make_function_prototype - -set_work_group_size -................... - -.. doxygenfunction:: tinytc::set_work_group_size +.. _tinytc::make_func: -set_subgroup_size -................. +make_func +......... -.. doxygenfunction:: tinytc::set_subgroup_size +.. doxygenfunction:: tinytc::make_func Function Classes ---------------- +.. _tinytc::func: + func .... .. doxygenclass:: tinytc::func -function_builder -................ - -.. doxygenclass:: tinytc::function_builder - Instruction =========== * Functions - * :ref:`make_alloca` + * :ref:`tinytc::make_alloca` + + * :ref:`tinytc::make_axpby` + + * :ref:`tinytc::make_arith(arithmetic,value,value,data_type,location const&)` + + * :ref:`tinytc::make_arith(arithmetic_unary,value,data_type,location const&)` + + * :ref:`tinytc::make_barrier` + + * :ref:`tinytc::make_builtin` + + * :ref:`tinytc::make_cast` + + * :ref:`tinytc::make_cmp` + + * :ref:`tinytc::make_constant(bool,data_type,location const&)` + + * :ref:`tinytc::make_constant(std::complex\,data_type,location const&)` + + * :ref:`tinytc::make_constant(double,data_type,location const&)` + + * :ref:`tinytc::make_constant(std::int32_t,data_type,location const&)` + + * :ref:`tinytc::make_constant(std::int64_t,data_type,location const&)` + + * :ref:`tinytc::make_constant_one` + + * :ref:`tinytc::make_constant_zero` - * :ref:`make_axpby` + * :ref:`tinytc::make_cooperative_matrix_apply` - * :ref:`make_arith(arithmetic,value const&,value const&,location const&)` + * :ref:`tinytc::make_cooperative_matrix_extract` - * :ref:`make_arith(arithmetic_unary,value const&,location const&)` + * :ref:`tinytc::make_cooperative_matrix_insert` - * :ref:`make_cast` + * :ref:`tinytc::make_cooperative_matrix_load` - * :ref:`make_cmp` + * :ref:`tinytc::make_cooperative_matrix_mul_add` - * :ref:`make_expand` + * :ref:`tinytc::make_cooperative_matrix_prefetch` - * :ref:`make_for` + * :ref:`tinytc::make_cooperative_matrix_scale` - * :ref:`make_foreach` + * :ref:`tinytc::make_cooperative_matrix_store` - * :ref:`make_fuse` + * :ref:`tinytc::make_cumsum` - * :ref:`make_gemm` + * :ref:`tinytc::make_expand` - * :ref:`make_gemv` + * :ref:`tinytc::make_for` - * :ref:`make_ger` + * :ref:`tinytc::make_foreach` - * :ref:`make_group_id` + * :ref:`tinytc::make_fuse` - * :ref:`make_group_size` + * :ref:`tinytc::make_gemm` - * :ref:`make_hadamard` + * :ref:`tinytc::make_gemv` - * :ref:`make_if` + * :ref:`tinytc::make_ger` - * :ref:`make_load` + * :ref:`tinytc::make_hadamard` - * :ref:`make_size` + * :ref:`tinytc::make_if` - * :ref:`make_store` + * :ref:`tinytc::make_load` - * :ref:`make_subview` + * :ref:`tinytc::make_math(math_unary,value,data_type,location const&)` - * :ref:`make_sum` + * :ref:`tinytc::make_parallel` - * :ref:`make_yield` + * :ref:`tinytc::make_size` + + * :ref:`tinytc::make_store` + + * :ref:`tinytc::make_subgroup_broadcast` + + * :ref:`tinytc::make_subgroup_operation` + + * :ref:`tinytc::make_subview` + + * :ref:`tinytc::make_sum` + + * :ref:`tinytc::make_yield` * Classes - * :ref:`inst` + * :ref:`tinytc::inst` Instruction Functions --------------------- +.. _tinytc::make_alloca: + make_alloca ........... .. doxygenfunction:: tinytc::make_alloca +.. _tinytc::make_axpby: + make_axpby .......... .. doxygenfunction:: tinytc::make_axpby -make_arith(arithmetic,value const&,value const&,location const&) -................................................................ +.. _tinytc::make_arith(arithmetic,value,value,data_type,location const&): + +make_arith(arithmetic,value,value,data_type,location const&) +............................................................ + +.. doxygenfunction:: tinytc::make_arith(arithmetic,value,value,data_type,location const&) + +.. _tinytc::make_arith(arithmetic_unary,value,data_type,location const&): + +make_arith(arithmetic_unary,value,data_type,location const&) +............................................................ -.. doxygenfunction:: tinytc::make_arith(arithmetic,value const&,value const&,location const&) +.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,value,data_type,location const&) -make_arith(arithmetic_unary,value const&,location const&) -......................................................... +.. _tinytc::make_barrier: -.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,value const&,location const&) +make_barrier +............ + +.. doxygenfunction:: tinytc::make_barrier + +.. _tinytc::make_builtin: + +make_builtin +............ + +.. doxygenfunction:: tinytc::make_builtin + +.. _tinytc::make_cast: make_cast ......... .. doxygenfunction:: tinytc::make_cast +.. _tinytc::make_cmp: + make_cmp ........ .. doxygenfunction:: tinytc::make_cmp +.. _tinytc::make_constant(bool,data_type,location const&): + +make_constant(bool,data_type,location const&) +............................................. + +.. doxygenfunction:: tinytc::make_constant(bool,data_type,location const&) + +.. _tinytc::make_constant(std::complex\,data_type,location const&): + +make_constant(std::complex,data_type,location const&) +............................................................. + +.. doxygenfunction:: tinytc::make_constant(std::complex,data_type,location const&) + +.. _tinytc::make_constant(double,data_type,location const&): + +make_constant(double,data_type,location const&) +............................................... + +.. doxygenfunction:: tinytc::make_constant(double,data_type,location const&) + +.. _tinytc::make_constant(std::int32_t,data_type,location const&): + +make_constant(std::int32_t,data_type,location const&) +..................................................... + +.. doxygenfunction:: tinytc::make_constant(std::int32_t,data_type,location const&) + +.. _tinytc::make_constant(std::int64_t,data_type,location const&): + +make_constant(std::int64_t,data_type,location const&) +..................................................... + +.. doxygenfunction:: tinytc::make_constant(std::int64_t,data_type,location const&) + +.. _tinytc::make_constant_one: + +make_constant_one +................. + +.. doxygenfunction:: tinytc::make_constant_one + +.. _tinytc::make_constant_zero: + +make_constant_zero +.................. + +.. doxygenfunction:: tinytc::make_constant_zero + +.. _tinytc::make_cooperative_matrix_apply: + +make_cooperative_matrix_apply +............................. + +.. doxygenfunction:: tinytc::make_cooperative_matrix_apply + +.. _tinytc::make_cooperative_matrix_extract: + +make_cooperative_matrix_extract +............................... + +.. doxygenfunction:: tinytc::make_cooperative_matrix_extract + +.. _tinytc::make_cooperative_matrix_insert: + +make_cooperative_matrix_insert +.............................. + +.. doxygenfunction:: tinytc::make_cooperative_matrix_insert + +.. _tinytc::make_cooperative_matrix_load: + +make_cooperative_matrix_load +............................ + +.. doxygenfunction:: tinytc::make_cooperative_matrix_load + +.. _tinytc::make_cooperative_matrix_mul_add: + +make_cooperative_matrix_mul_add +............................... + +.. doxygenfunction:: tinytc::make_cooperative_matrix_mul_add + +.. _tinytc::make_cooperative_matrix_prefetch: + +make_cooperative_matrix_prefetch +................................ + +.. doxygenfunction:: tinytc::make_cooperative_matrix_prefetch + +.. _tinytc::make_cooperative_matrix_scale: + +make_cooperative_matrix_scale +............................. + +.. doxygenfunction:: tinytc::make_cooperative_matrix_scale + +.. _tinytc::make_cooperative_matrix_store: + +make_cooperative_matrix_store +............................. + +.. doxygenfunction:: tinytc::make_cooperative_matrix_store + +.. _tinytc::make_cumsum: + +make_cumsum +........... + +.. doxygenfunction:: tinytc::make_cumsum + +.. _tinytc::make_expand: + make_expand ........... .. doxygenfunction:: tinytc::make_expand +.. _tinytc::make_for: + make_for ........ .. doxygenfunction:: tinytc::make_for +.. _tinytc::make_foreach: + make_foreach ............ .. doxygenfunction:: tinytc::make_foreach +.. _tinytc::make_fuse: + make_fuse ......... .. doxygenfunction:: tinytc::make_fuse +.. _tinytc::make_gemm: + make_gemm ......... .. doxygenfunction:: tinytc::make_gemm +.. _tinytc::make_gemv: + make_gemv ......... .. doxygenfunction:: tinytc::make_gemv +.. _tinytc::make_ger: + make_ger ........ .. doxygenfunction:: tinytc::make_ger -make_group_id -............. - -.. doxygenfunction:: tinytc::make_group_id - -make_group_size -............... - -.. doxygenfunction:: tinytc::make_group_size +.. _tinytc::make_hadamard: make_hadamard ............. .. doxygenfunction:: tinytc::make_hadamard +.. _tinytc::make_if: + make_if ....... .. doxygenfunction:: tinytc::make_if +.. _tinytc::make_load: + make_load ......... .. doxygenfunction:: tinytc::make_load +.. _tinytc::make_math(math_unary,value,data_type,location const&): + +make_math(math_unary,value,data_type,location const&) +..................................................... + +.. doxygenfunction:: tinytc::make_math(math_unary,value,data_type,location const&) + +.. _tinytc::make_parallel: + +make_parallel +............. + +.. doxygenfunction:: tinytc::make_parallel + +.. _tinytc::make_size: + make_size ......... .. doxygenfunction:: tinytc::make_size +.. _tinytc::make_store: + make_store .......... .. doxygenfunction:: tinytc::make_store +.. _tinytc::make_subgroup_broadcast: + +make_subgroup_broadcast +....................... + +.. doxygenfunction:: tinytc::make_subgroup_broadcast + +.. _tinytc::make_subgroup_operation: + +make_subgroup_operation +....................... + +.. doxygenfunction:: tinytc::make_subgroup_operation + +.. _tinytc::make_subview: + make_subview ............ .. doxygenfunction:: tinytc::make_subview +.. _tinytc::make_sum: + make_sum ........ .. doxygenfunction:: tinytc::make_sum +.. _tinytc::make_yield: + make_yield .......... @@ -441,6 +945,8 @@ make_yield Instruction Classes ------------------- +.. _tinytc::inst: + inst .... @@ -451,64 +957,76 @@ Program * Functions - * :ref:`make_program` + * :ref:`tinytc::make_prog` * Classes - * :ref:`prog` - - * :ref:`program_builder` + * :ref:`tinytc::prog` Program Functions ----------------- -make_program -............ +.. _tinytc::make_prog: + +make_prog +......... -.. doxygenfunction:: tinytc::make_program +.. doxygenfunction:: tinytc::make_prog Program Classes --------------- +.. _tinytc::prog: + prog .... .. doxygenclass:: tinytc::prog -program_builder -............... - -.. doxygenclass:: tinytc::program_builder - Region ====== * Functions - * :ref:`make_region` + * :ref:`tinytc::next` + + * :ref:`tinytc::prev` * Classes - * :ref:`region` + * :ref:`tinytc::region` - * :ref:`region_builder` + * :ref:`tinytc::region_builder` Region Functions ---------------- -make_region -........... +.. _tinytc::next: -.. doxygenfunction:: tinytc::make_region +next +.... + +.. doxygenfunction:: tinytc::next + +.. _tinytc::prev: + +prev +.... + +.. doxygenfunction:: tinytc::prev Region Classes -------------- +.. _tinytc::region: + region ...... .. doxygenclass:: tinytc::region +.. _tinytc::region_builder: + region_builder .............. @@ -517,95 +1035,15 @@ region_builder Value ===== -* Functions - - * :ref:`make_dynamic(location const&)` - - * :ref:`make_imm(float,location const&)` - - * :ref:`make_imm(double,scalar_type,location const&)` - - * :ref:`make_imm(std::int8_t,location const&)` - - * :ref:`make_imm(std::int16_t,location const&)` - - * :ref:`make_imm(std::int32_t,location const&)` - - * :ref:`make_imm(std::int64_t,scalar_type,location const&)` - - * :ref:`make_index(std::int32_t,location const&)` - - * :ref:`make_index(std::int64_t,location const&)` - - * :ref:`make_value(data_type const&,location const&)` - - * :ref:`make_value(scalar_type,location const&)` - * Classes - * :ref:`value` - -Value Functions ---------------- - -make_dynamic(location const&) -............................. - -.. doxygenfunction:: tinytc::make_dynamic(location const&) - -make_imm(float,location const&) -............................... - -.. doxygenfunction:: tinytc::make_imm(float,location const&) - -make_imm(double,scalar_type,location const&) -............................................ - -.. doxygenfunction:: tinytc::make_imm(double,scalar_type,location const&) - -make_imm(std::int8_t,location const&) -..................................... - -.. doxygenfunction:: tinytc::make_imm(std::int8_t,location const&) - -make_imm(std::int16_t,location const&) -...................................... - -.. doxygenfunction:: tinytc::make_imm(std::int16_t,location const&) - -make_imm(std::int32_t,location const&) -...................................... - -.. doxygenfunction:: tinytc::make_imm(std::int32_t,location const&) - -make_imm(std::int64_t,scalar_type,location const&) -.................................................. - -.. doxygenfunction:: tinytc::make_imm(std::int64_t,scalar_type,location const&) - -make_index(std::int32_t,location const&) -........................................ - -.. doxygenfunction:: tinytc::make_index(std::int32_t,location const&) - -make_index(std::int64_t,location const&) -........................................ - -.. doxygenfunction:: tinytc::make_index(std::int64_t,location const&) - -make_value(data_type const&,location const&) -............................................ - -.. doxygenfunction:: tinytc::make_value(data_type const&,location const&) - -make_value(scalar_type,location const&) -....................................... - -.. doxygenfunction:: tinytc::make_value(scalar_type,location const&) + * :ref:`tinytc::value` Value Classes ------------- +.. _tinytc::value: + value ..... diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index b2358d2a..372dc214 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -3,54 +3,100 @@ Builder C++-API: Common: enum: + - tinytc::address_space - tinytc::arithmetic - tinytc::arithmetic_unary + - tinytc::builtin - tinytc::cmp_condition + - tinytc::group_arithmetic + - tinytc::group_operation + - tinytc::math_unary + - tinytc::matrix_use + - tinytc::reduce_mode - tinytc::scalar_type + - tinytc::store_flag - tinytc::transpose function: - tinytc::is_dynamic_value + - tinytc::to_string(address_space) - tinytc::to_string(arithmetic) - tinytc::to_string(arithmetic_unary) + - tinytc::to_string(builtin) + - tinytc::to_string(checked_flag) - tinytc::to_string(cmp_condition) + - tinytc::to_string(group_arithmetic) + - tinytc::to_string(group_operation) + - tinytc::to_string(math_unary) + - tinytc::to_string(matrix_use) + - tinytc::to_string(reduce_mode) - tinytc::to_string(scalar_type) + - tinytc::to_string(store_flag) - tinytc::to_string(transpose) - tinytc::size class: - tinytc::builder_error typedef: - - tinytc::position - tinytc::location + - tinytc::position variable: - tinytc::dynamic + Attribute: + function: + - get_array_attr + - get_boolean_attr + - get_dictionary_attr + - get_dictionary_attr_with_sorted + - get_integer_attr + - get_string_attr + - sort_items + typedef: + - tinytc::attr + - tinytc::named_attr Data Type: function: - - tinytc::make_memref - - tinytc::make_group - - tinytc::make_scalar - class: - - tinytc::data_type + - tinytc::get_boolean + - tinytc::get_coopmatrix + - tinytc::get_group + - tinytc::get_memref + - tinytc::get_scalar + - tinytc::get_void struct: - tinytc::to_scalar_type + typedef: + - tinytc::data_type variable: - tinytc::to_scalar_type_v Function: function: - - tinytc::make_function - - tinytc::make_function_prototype - - tinytc::set_work_group_size - - tinytc::set_subgroup_size + - tinytc::make_func class: - tinytc::func - - tinytc::function_builder Instruction: function: - tinytc::make_alloca - tinytc::make_axpby - - tinytc::make_arith(arithmetic,value const&,value const&,location const&) - - tinytc::make_arith(arithmetic_unary,value const&,location const&) + - tinytc::make_arith(arithmetic,value,value,data_type,location const&) + - tinytc::make_arith(arithmetic_unary,value,data_type,location const&) + - tinytc::make_barrier + - tinytc::make_builtin - tinytc::make_cast - tinytc::make_cmp + - tinytc::make_constant(bool,data_type,location const&) + - tinytc::make_constant(std::complex,data_type,location const&) + - tinytc::make_constant(double,data_type,location const&) + - tinytc::make_constant(std::int32_t,data_type,location const&) + - tinytc::make_constant(std::int64_t,data_type,location const&) + - tinytc::make_constant_one + - tinytc::make_constant_zero + - tinytc::make_cooperative_matrix_apply + - tinytc::make_cooperative_matrix_extract + - tinytc::make_cooperative_matrix_insert + - tinytc::make_cooperative_matrix_load + - tinytc::make_cooperative_matrix_mul_add + - tinytc::make_cooperative_matrix_prefetch + - tinytc::make_cooperative_matrix_scale + - tinytc::make_cooperative_matrix_store + - tinytc::make_cumsum - tinytc::make_expand - tinytc::make_for - tinytc::make_foreach @@ -58,13 +104,15 @@ Builder C++-API: - tinytc::make_gemm - tinytc::make_gemv - tinytc::make_ger - - tinytc::make_group_id - - tinytc::make_group_size - tinytc::make_hadamard - tinytc::make_if - tinytc::make_load + - tinytc::make_math(math_unary,value,data_type,location const&) + - tinytc::make_parallel - tinytc::make_size - tinytc::make_store + - tinytc::make_subgroup_broadcast + - tinytc::make_subgroup_operation - tinytc::make_subview - tinytc::make_sum - tinytc::make_yield @@ -72,28 +120,16 @@ Builder C++-API: - tinytc::inst Program: function: - - tinytc::make_program + - tinytc::make_prog class: - tinytc::prog - - tinytc::program_builder Region: function: - - tinytc::make_region + - tinytc::next + - tinytc::prev class: - tinytc::region - tinytc::region_builder Value: - function: - - tinytc::make_dynamic(location const&) - - tinytc::make_imm(float,location const&) - - tinytc::make_imm(double,scalar_type,location const&) - - tinytc::make_imm(std::int8_t,location const&) - - tinytc::make_imm(std::int16_t,location const&) - - tinytc::make_imm(std::int32_t,location const&) - - tinytc::make_imm(std::int64_t,scalar_type,location const&) - - tinytc::make_index(std::int32_t,location const&) - - tinytc::make_index(std::int64_t,location const&) - - tinytc::make_value(data_type const&,location const&) - - tinytc::make_value(scalar_type,location const&) class: - tinytc::value diff --git a/docs/api/cl/capi.rst b/docs/api/cl/capi.rst index f4d2d8a2..a2e70e6b 100644 --- a/docs/api/cl/capi.rst +++ b/docs/api/cl/capi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _OpenCL C-API: + ===== C-API ===== @@ -15,6 +17,8 @@ Common Common Functions ---------------- +.. _tinytc_cl_convert_status: + tinytc_cl_convert_status ........................ @@ -32,11 +36,15 @@ Device Info Device Info Functions --------------------- +.. _tinytc_cl_core_info_create: + tinytc_cl_core_info_create .......................... .. doxygenfunction:: tinytc_cl_core_info_create +.. _tinytc_cl_get_support_level: + tinytc_cl_get_support_level ........................... @@ -51,8 +59,6 @@ Kernel * :ref:`tinytc_cl_get_group_size` - * :ref:`tinytc_cl_kernel_bundle_create_with_source` - * :ref:`tinytc_cl_kernel_bundle_create_with_program` * :ref:`tinytc_cl_kernel_bundle_create_with_binary` @@ -60,26 +66,29 @@ Kernel Kernel Functions ---------------- +.. _tinytc_cl_get_global_size: + tinytc_cl_get_global_size ......................... .. doxygenfunction:: tinytc_cl_get_global_size +.. _tinytc_cl_get_group_size: + tinytc_cl_get_group_size ........................ .. doxygenfunction:: tinytc_cl_get_group_size -tinytc_cl_kernel_bundle_create_with_source -.......................................... - -.. doxygenfunction:: tinytc_cl_kernel_bundle_create_with_source +.. _tinytc_cl_kernel_bundle_create_with_program: tinytc_cl_kernel_bundle_create_with_program ........................................... .. doxygenfunction:: tinytc_cl_kernel_bundle_create_with_program +.. _tinytc_cl_kernel_bundle_create_with_binary: + tinytc_cl_kernel_bundle_create_with_binary .......................................... @@ -97,11 +106,15 @@ Recipe Recipe Functions ---------------- +.. _tinytc_cl_recipe_handler_create: + tinytc_cl_recipe_handler_create ............................... .. doxygenfunction:: tinytc_cl_recipe_handler_create +.. _tinytc_cl_recipe_handler_submit: + tinytc_cl_recipe_handler_submit ............................... diff --git a/docs/api/cl/capi.yaml b/docs/api/cl/capi.yaml index 67ed7e61..89e60ac7 100644 --- a/docs/api/cl/capi.yaml +++ b/docs/api/cl/capi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C-API: +C-API : Common: function: - tinytc_cl_convert_status @@ -12,7 +12,6 @@ C-API: function: - tinytc_cl_get_global_size - tinytc_cl_get_group_size - - tinytc_cl_kernel_bundle_create_with_source - tinytc_cl_kernel_bundle_create_with_program - tinytc_cl_kernel_bundle_create_with_binary Recipe: diff --git a/docs/api/cl/cxxapi.rst b/docs/api/cl/cxxapi.rst index 2245e883..b6350da2 100644 --- a/docs/api/cl/cxxapi.rst +++ b/docs/api/cl/cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _OpenCL C++-API: + ======= C++-API ======= @@ -10,11 +12,13 @@ Common * Functions - * :ref:`CL_CHECK_STATUS` + * :ref:`tinytc::CL_CHECK_STATUS` Common Functions ---------------- +.. _tinytc::CL_CHECK_STATUS: + CL_CHECK_STATUS ............... @@ -25,18 +29,22 @@ Device Info * Functions - * :ref:`get_support_level(cl_device_id)` + * :ref:`tinytc::get_support_level(cl_device_id)` - * :ref:`make_core_info(cl_device_id)` + * :ref:`tinytc::make_core_info(cl_device_id)` Device Info Functions --------------------- +.. _tinytc::get_support_level(cl_device_id): + get_support_level(cl_device_id) ............................... .. doxygenfunction:: tinytc::get_support_level(cl_device_id) +.. _tinytc::make_core_info(cl_device_id): + make_core_info(cl_device_id) ............................ @@ -47,77 +55,84 @@ Kernel * Functions - * :ref:`get_global_size(std::int64_t,std::array\ const &)` - - * :ref:`get_group_size(cl_kernel)` + * :ref:`tinytc::get_global_size(std::array\ const &,std::array\ const &)` - * :ref:`make_kernel(cl_program,char const\\*)` + * :ref:`tinytc::get_group_size(cl_kernel)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context)` + * :ref:`tinytc::make_kernel(cl_program,char const\*)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,source const&,source_context)` + * :ref:`tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t)` Kernel Functions ---------------- -get_global_size(std::int64_t,std::array const &) -................................................................. +.. _tinytc::get_global_size(std::array\ const &,std::array\ const &): + +get_global_size(std::array const &,std::array const &) +........................................................................................ -.. doxygenfunction:: tinytc::get_global_size(std::int64_t,std::array const &) +.. doxygenfunction:: tinytc::get_global_size(std::array const &,std::array const &) + +.. _tinytc::get_group_size(cl_kernel): get_group_size(cl_kernel) ......................... .. doxygenfunction:: tinytc::get_group_size(cl_kernel) +.. _tinytc::make_kernel(cl_program,char const\*): + make_kernel(cl_program,char const\*) .................................... .. doxygenfunction:: tinytc::make_kernel(cl_program,char const*) -make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) -........................................................................ +.. _tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&): -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) +make_kernel_bundle(cl_context,cl_device_id,binary const&) +......................................................... -make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) -........................................................................................... +.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&) -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) +.. _tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t): -make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) -........................................................................ +make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) +............................................................................ -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) Recipe ====== * Functions - * :ref:`make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context)` + * :ref:`tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&)` * Classes - * :ref:`opencl_recipe_handler` + * :ref:`tinytc::opencl_recipe_handler` * Structures - * :ref:`auto_mem_type\` + * :ref:`tinytc::auto_mem_type\< cl_mem \>` Recipe Functions ---------------- -make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) -......................................................................... +.. _tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&): + +make_recipe_handler(cl_context,cl_device_id,recipe const&) +.......................................................... -.. doxygenfunction:: tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&) Recipe Classes -------------- +.. _tinytc::opencl_recipe_handler: + opencl_recipe_handler ..................... @@ -126,6 +141,8 @@ opencl_recipe_handler Recipe Structures ----------------- +.. _tinytc::auto_mem_type\< cl_mem \>: + auto_mem_type ..................... diff --git a/docs/api/cl/cxxapi.yaml b/docs/api/cl/cxxapi.yaml index a61f974d..f41717cd 100644 --- a/docs/api/cl/cxxapi.yaml +++ b/docs/api/cl/cxxapi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C++-API: +C++-API : Common: function: - tinytc::CL_CHECK_STATUS @@ -10,15 +10,14 @@ C++-API: - tinytc::make_core_info(cl_device_id) Kernel: function: - - tinytc::get_global_size(std::int64_t,std::array const &) + - tinytc::get_global_size(std::array const &,std::array const &) - tinytc::get_group_size(cl_kernel) - tinytc::make_kernel(cl_program,char const*) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) + - tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&) + - tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t) Recipe: function: - - tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) + - tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&) class: - tinytc::opencl_recipe_handler struct: diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 81c1901d..69a66afd 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Core C-API: + ========== Core C-API ========== @@ -46,9 +48,9 @@ Common * :ref:`tinytc_recipe_handler_t` - * :ref:`tinytc_source_t` + * :ref:`tinytc_spv_mod_t` - * :ref:`tinytc_source_context_t` + * :ref:`tinytc_compiler_context_t` * :ref:`const_tinytc_binary_t` @@ -58,18 +60,24 @@ Common * :ref:`const_tinytc_recipe_handler_t` - * :ref:`const_tinytc_source_t` + * :ref:`const_tinytc_spv_mod_t` + + * :ref:`const_tinytc_compiler_context_t` - * :ref:`const_tinytc_source_context_t` + * :ref:`tinytc_error_reporter_t` Common Enumerations ------------------- +.. _tinytc_status_t: + tinytc_status_t ............... .. doxygenenum:: tinytc_status_t +.. _tinytc_support_level_t: + tinytc_support_level_t ...................... @@ -78,31 +86,43 @@ tinytc_support_level_t Common Definitions ------------------ +.. _TINYTC_VERSION_MAJOR: + TINYTC_VERSION_MAJOR .................... .. doxygendefine:: TINYTC_VERSION_MAJOR +.. _TINYTC_VERSION_MINOR: + TINYTC_VERSION_MINOR .................... .. doxygendefine:: TINYTC_VERSION_MINOR +.. _TINYTC_VERSION_PATCH: + TINYTC_VERSION_PATCH .................... .. doxygendefine:: TINYTC_VERSION_PATCH +.. _TINYTC_VERSION_HASH: + TINYTC_VERSION_HASH ................... .. doxygendefine:: TINYTC_VERSION_HASH +.. _TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE: + TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE .............................................. .. doxygendefine:: TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE +.. _TINYTC_VERSION_DESCRIPTION: + TINYTC_VERSION_DESCRIPTION .......................... @@ -111,11 +131,15 @@ TINYTC_VERSION_DESCRIPTION Common Functions ---------------- +.. _tinytc_error_string: + tinytc_error_string ................... .. doxygenfunction:: tinytc_error_string +.. _tinytc_string_destroy: + tinytc_string_destroy ..................... @@ -124,70 +148,103 @@ tinytc_string_destroy Common Typedefs --------------- +.. _tinytc_binary_t: + tinytc_binary_t ............... .. doxygentypedef:: tinytc_binary_t +.. _tinytc_bool_t: + tinytc_bool_t ............. .. doxygentypedef:: tinytc_bool_t +.. _tinytc_core_info_t: + tinytc_core_info_t .................. .. doxygentypedef:: tinytc_core_info_t +.. _tinytc_recipe_t: + tinytc_recipe_t ............... .. doxygentypedef:: tinytc_recipe_t +.. _tinytc_recipe_handler_t: + tinytc_recipe_handler_t ....................... .. doxygentypedef:: tinytc_recipe_handler_t -tinytc_source_t -............... +.. _tinytc_spv_mod_t: -.. doxygentypedef:: tinytc_source_t +tinytc_spv_mod_t +................ -tinytc_source_context_t -....................... +.. doxygentypedef:: tinytc_spv_mod_t + +.. _tinytc_compiler_context_t: + +tinytc_compiler_context_t +......................... -.. doxygentypedef:: tinytc_source_context_t +.. doxygentypedef:: tinytc_compiler_context_t + +.. _const_tinytc_binary_t: const_tinytc_binary_t ..................... .. doxygentypedef:: const_tinytc_binary_t +.. _const_tinytc_core_info_t: + const_tinytc_core_info_t ........................ .. doxygentypedef:: const_tinytc_core_info_t +.. _const_tinytc_recipe_t: + const_tinytc_recipe_t ..................... .. doxygentypedef:: const_tinytc_recipe_t +.. _const_tinytc_recipe_handler_t: + const_tinytc_recipe_handler_t ............................. .. doxygentypedef:: const_tinytc_recipe_handler_t -const_tinytc_source_t -..................... +.. _const_tinytc_spv_mod_t: + +const_tinytc_spv_mod_t +...................... -.. doxygentypedef:: const_tinytc_source_t +.. doxygentypedef:: const_tinytc_spv_mod_t -const_tinytc_source_context_t -............................. +.. _const_tinytc_compiler_context_t: -.. doxygentypedef:: const_tinytc_source_context_t +const_tinytc_compiler_context_t +............................... + +.. doxygentypedef:: const_tinytc_compiler_context_t + +.. _tinytc_error_reporter_t: + +tinytc_error_reporter_t +....................... + +.. doxygentypedef:: tinytc_error_reporter_t Binary ====== @@ -196,6 +253,8 @@ Binary * :ref:`tinytc_binary_create` + * :ref:`tinytc_binary_get_compiler_context` + * :ref:`tinytc_binary_get_core_features` * :ref:`tinytc_binary_get_raw` @@ -207,26 +266,43 @@ Binary Binary Functions ---------------- +.. _tinytc_binary_create: + tinytc_binary_create .................... .. doxygenfunction:: tinytc_binary_create +.. _tinytc_binary_get_compiler_context: + +tinytc_binary_get_compiler_context +.................................. + +.. doxygenfunction:: tinytc_binary_get_compiler_context + +.. _tinytc_binary_get_core_features: + tinytc_binary_get_core_features ............................... .. doxygenfunction:: tinytc_binary_get_core_features +.. _tinytc_binary_get_raw: + tinytc_binary_get_raw ..................... .. doxygenfunction:: tinytc_binary_get_raw +.. _tinytc_binary_release: + tinytc_binary_release ..................... .. doxygenfunction:: tinytc_binary_release +.. _tinytc_binary_retain: + tinytc_binary_retain .................... @@ -239,25 +315,154 @@ Compiler * :ref:`tinytc_bundle_format_t` + * :ref:`tinytc_optflag_t` + * Functions - * :ref:`tinytc_prog_compile_to_opencl` + * :ref:`tinytc_run_function_pass` + + * :ref:`tinytc_list_function_passes` + + * :ref:`tinytc_prog_compile_to_spirv` + + * :ref:`tinytc_prog_compile_to_spirv_and_assemble` + + * :ref:`tinytc_spirv_assemble` Compiler Enumerations --------------------- +.. _tinytc_bundle_format_t: + tinytc_bundle_format_t ...................... .. doxygenenum:: tinytc_bundle_format_t +.. _tinytc_optflag_t: + +tinytc_optflag_t +................ + +.. doxygenenum:: tinytc_optflag_t + Compiler Functions ------------------ -tinytc_prog_compile_to_opencl -............................. +.. _tinytc_run_function_pass: + +tinytc_run_function_pass +........................ + +.. doxygenfunction:: tinytc_run_function_pass + +.. _tinytc_list_function_passes: + +tinytc_list_function_passes +........................... + +.. doxygenfunction:: tinytc_list_function_passes + +.. _tinytc_prog_compile_to_spirv: + +tinytc_prog_compile_to_spirv +............................ -.. doxygenfunction:: tinytc_prog_compile_to_opencl +.. doxygenfunction:: tinytc_prog_compile_to_spirv + +.. _tinytc_prog_compile_to_spirv_and_assemble: + +tinytc_prog_compile_to_spirv_and_assemble +......................................... + +.. doxygenfunction:: tinytc_prog_compile_to_spirv_and_assemble + +.. _tinytc_spirv_assemble: + +tinytc_spirv_assemble +..................... + +.. doxygenfunction:: tinytc_spirv_assemble + +Compiler Context +================ + +* Functions + + * :ref:`tinytc_compiler_context_create` + + * :ref:`tinytc_compiler_context_add_source` + + * :ref:`tinytc_compiler_context_set_error_reporter` + + * :ref:`tinytc_compiler_context_set_optimization_flag` + + * :ref:`tinytc_compiler_context_set_optimization_level` + + * :ref:`tinytc_compiler_context_report_error` + + * :ref:`tinytc_compiler_context_release` + + * :ref:`tinytc_compiler_context_retain` + +Compiler Context Functions +-------------------------- + +.. _tinytc_compiler_context_create: + +tinytc_compiler_context_create +.............................. + +.. doxygenfunction:: tinytc_compiler_context_create + +.. _tinytc_compiler_context_add_source: + +tinytc_compiler_context_add_source +.................................. + +.. doxygenfunction:: tinytc_compiler_context_add_source + +.. _tinytc_compiler_context_set_error_reporter: + +tinytc_compiler_context_set_error_reporter +.......................................... + +.. doxygenfunction:: tinytc_compiler_context_set_error_reporter + +.. _tinytc_compiler_context_set_optimization_flag: + +tinytc_compiler_context_set_optimization_flag +............................................. + +.. doxygenfunction:: tinytc_compiler_context_set_optimization_flag + +.. _tinytc_compiler_context_set_optimization_level: + +tinytc_compiler_context_set_optimization_level +.............................................. + +.. doxygenfunction:: tinytc_compiler_context_set_optimization_level + +.. _tinytc_compiler_context_report_error: + +tinytc_compiler_context_report_error +.................................... + +.. doxygenfunction:: tinytc_compiler_context_report_error + +.. _tinytc_compiler_context_release: + +tinytc_compiler_context_release +............................... + +.. doxygenfunction:: tinytc_compiler_context_release + +.. _tinytc_compiler_context_retain: + +tinytc_compiler_context_retain +.............................. + +.. doxygenfunction:: tinytc_compiler_context_retain Device Info =========== @@ -268,26 +473,40 @@ Device Info * :ref:`tinytc_intel_gpu_architecture_t` + * :ref:`tinytc_spirv_feature_t` + * Functions + * :ref:`tinytc_core_info_generic_create` + * :ref:`tinytc_core_info_get_core_features` + * :ref:`tinytc_core_info_get_default_alignment` + * :ref:`tinytc_core_info_get_register_space` * :ref:`tinytc_core_info_get_subgroup_sizes` - * :ref:`tinytc_core_info_set_core_features` - - * :ref:`tinytc_core_info_generic_create` + * :ref:`tinytc_core_info_have_spirv_feature` * :ref:`tinytc_core_info_intel_create` * :ref:`tinytc_core_info_intel_create_from_arch` + * :ref:`tinytc_core_info_intel_create_from_name` + * :ref:`tinytc_core_info_release` * :ref:`tinytc_core_info_retain` + * :ref:`tinytc_core_info_set_core_features` + + * :ref:`tinytc_core_info_set_default_alignment` + + * :ref:`tinytc_core_info_set_spirv_feature` + + * :ref:`tinytc_spirv_feature_to_string` + * Typedefs * :ref:`tinytc_core_feature_flags_t` @@ -295,72 +514,189 @@ Device Info Device Info Enumerations ------------------------ +.. _tinytc_core_feature_flag_t: + tinytc_core_feature_flag_t .......................... .. doxygenenum:: tinytc_core_feature_flag_t +.. _tinytc_intel_gpu_architecture_t: + tinytc_intel_gpu_architecture_t ............................... .. doxygenenum:: tinytc_intel_gpu_architecture_t +.. _tinytc_spirv_feature_t: + +tinytc_spirv_feature_t +...................... + +.. doxygenenum:: tinytc_spirv_feature_t + Device Info Functions --------------------- +.. _tinytc_core_info_generic_create: + +tinytc_core_info_generic_create +............................... + +.. doxygenfunction:: tinytc_core_info_generic_create + +.. _tinytc_core_info_get_core_features: + tinytc_core_info_get_core_features .................................. .. doxygenfunction:: tinytc_core_info_get_core_features +.. _tinytc_core_info_get_default_alignment: + +tinytc_core_info_get_default_alignment +...................................... + +.. doxygenfunction:: tinytc_core_info_get_default_alignment + +.. _tinytc_core_info_get_register_space: + tinytc_core_info_get_register_space ................................... .. doxygenfunction:: tinytc_core_info_get_register_space +.. _tinytc_core_info_get_subgroup_sizes: + tinytc_core_info_get_subgroup_sizes ................................... .. doxygenfunction:: tinytc_core_info_get_subgroup_sizes -tinytc_core_info_set_core_features -.................................. +.. _tinytc_core_info_have_spirv_feature: -.. doxygenfunction:: tinytc_core_info_set_core_features +tinytc_core_info_have_spirv_feature +................................... -tinytc_core_info_generic_create -............................... +.. doxygenfunction:: tinytc_core_info_have_spirv_feature -.. doxygenfunction:: tinytc_core_info_generic_create +.. _tinytc_core_info_intel_create: tinytc_core_info_intel_create ............................. .. doxygenfunction:: tinytc_core_info_intel_create +.. _tinytc_core_info_intel_create_from_arch: + tinytc_core_info_intel_create_from_arch ....................................... .. doxygenfunction:: tinytc_core_info_intel_create_from_arch +.. _tinytc_core_info_intel_create_from_name: + +tinytc_core_info_intel_create_from_name +....................................... + +.. doxygenfunction:: tinytc_core_info_intel_create_from_name + +.. _tinytc_core_info_release: + tinytc_core_info_release ........................ .. doxygenfunction:: tinytc_core_info_release +.. _tinytc_core_info_retain: + tinytc_core_info_retain ....................... .. doxygenfunction:: tinytc_core_info_retain +.. _tinytc_core_info_set_core_features: + +tinytc_core_info_set_core_features +.................................. + +.. doxygenfunction:: tinytc_core_info_set_core_features + +.. _tinytc_core_info_set_default_alignment: + +tinytc_core_info_set_default_alignment +...................................... + +.. doxygenfunction:: tinytc_core_info_set_default_alignment + +.. _tinytc_core_info_set_spirv_feature: + +tinytc_core_info_set_spirv_feature +.................................. + +.. doxygenfunction:: tinytc_core_info_set_spirv_feature + +.. _tinytc_spirv_feature_to_string: + +tinytc_spirv_feature_to_string +.............................. + +.. doxygenfunction:: tinytc_spirv_feature_to_string + Device Info Typedefs -------------------- +.. _tinytc_core_feature_flags_t: + tinytc_core_feature_flags_t ........................... .. doxygentypedef:: tinytc_core_feature_flags_t +FP math +======= + +* Functions + + * :ref:`tinytc_f32_to_bf16_as_ui16` + + * :ref:`tinytc_f32_to_f16_as_ui16` + + * :ref:`tinytc_f16_as_ui16_to_f32` + + * :ref:`tinytc_bf16_as_ui16_to_f32` + +FP math Functions +----------------- + +.. _tinytc_f32_to_bf16_as_ui16: + +tinytc_f32_to_bf16_as_ui16 +.......................... + +.. doxygenfunction:: tinytc_f32_to_bf16_as_ui16 + +.. _tinytc_f32_to_f16_as_ui16: + +tinytc_f32_to_f16_as_ui16 +......................... + +.. doxygenfunction:: tinytc_f32_to_f16_as_ui16 + +.. _tinytc_f16_as_ui16_to_f32: + +tinytc_f16_as_ui16_to_f32 +......................... + +.. doxygenfunction:: tinytc_f16_as_ui16_to_f32 + +.. _tinytc_bf16_as_ui16_to_f32: + +tinytc_bf16_as_ui16_to_f32 +.......................... + +.. doxygenfunction:: tinytc_bf16_as_ui16_to_f32 + Parser ====== @@ -375,16 +711,22 @@ Parser Parser Functions ---------------- +.. _tinytc_parse_file: + tinytc_parse_file ................. .. doxygenfunction:: tinytc_parse_file +.. _tinytc_parse_stdin: + tinytc_parse_stdin .................. .. doxygenfunction:: tinytc_parse_stdin +.. _tinytc_parse_string: + tinytc_parse_string ................... @@ -399,9 +741,9 @@ Recipe * Functions - * :ref:`tinytc_recipe_get_prog` + * :ref:`tinytc_recipe_get_binary` - * :ref:`tinytc_recipe_get_source` + * :ref:`tinytc_recipe_get_prog` * :ref:`tinytc_recipe_handler_get_recipe` @@ -428,6 +770,8 @@ Recipe Recipe Enumerations ------------------- +.. _tinytc_mem_type_t: + tinytc_mem_type_t ................. @@ -436,168 +780,147 @@ tinytc_mem_type_t Recipe Functions ---------------- +.. _tinytc_recipe_get_binary: + +tinytc_recipe_get_binary +........................ + +.. doxygenfunction:: tinytc_recipe_get_binary + +.. _tinytc_recipe_get_prog: + tinytc_recipe_get_prog ...................... .. doxygenfunction:: tinytc_recipe_get_prog -tinytc_recipe_get_source -........................ - -.. doxygenfunction:: tinytc_recipe_get_source +.. _tinytc_recipe_handler_get_recipe: tinytc_recipe_handler_get_recipe ................................ .. doxygenfunction:: tinytc_recipe_handler_get_recipe +.. _tinytc_recipe_small_gemm_batched_create: + tinytc_recipe_small_gemm_batched_create ....................................... .. doxygenfunction:: tinytc_recipe_small_gemm_batched_create +.. _tinytc_recipe_small_gemm_batched_set_args: + tinytc_recipe_small_gemm_batched_set_args ......................................... .. doxygenfunction:: tinytc_recipe_small_gemm_batched_set_args +.. _tinytc_recipe_tall_and_skinny_create: + tinytc_recipe_tall_and_skinny_create .................................... .. doxygenfunction:: tinytc_recipe_tall_and_skinny_create +.. _tinytc_recipe_tall_and_skinny_create_specialized: + tinytc_recipe_tall_and_skinny_create_specialized ................................................ .. doxygenfunction:: tinytc_recipe_tall_and_skinny_create_specialized +.. _tinytc_recipe_tall_and_skinny_set_args: + tinytc_recipe_tall_and_skinny_set_args ...................................... .. doxygenfunction:: tinytc_recipe_tall_and_skinny_set_args +.. _tinytc_recipe_tall_and_skinny_suggest_block_size: + tinytc_recipe_tall_and_skinny_suggest_block_size ................................................ .. doxygenfunction:: tinytc_recipe_tall_and_skinny_suggest_block_size +.. _tinytc_recipe_release: + tinytc_recipe_release ..................... .. doxygenfunction:: tinytc_recipe_release +.. _tinytc_recipe_retain: + tinytc_recipe_retain .................... .. doxygenfunction:: tinytc_recipe_retain +.. _tinytc_recipe_handler_release: + tinytc_recipe_handler_release ............................. .. doxygenfunction:: tinytc_recipe_handler_release +.. _tinytc_recipe_handler_retain: + tinytc_recipe_handler_retain ............................ .. doxygenfunction:: tinytc_recipe_handler_retain -Source -====== +SPIR-V module +============= * Functions - * :ref:`tinytc_source_get_code` - - * :ref:`tinytc_source_get_core_features` - - * :ref:`tinytc_source_get_location` - - * :ref:`tinytc_source_get_extensions` - - * :ref:`tinytc_source_release` + * :ref:`tinytc_spv_mod_dump` - * :ref:`tinytc_source_retain` + * :ref:`tinytc_spv_mod_print_to_file` -Source Functions ----------------- - -tinytc_source_get_code -...................... - -.. doxygenfunction:: tinytc_source_get_code - -tinytc_source_get_core_features -............................... - -.. doxygenfunction:: tinytc_source_get_core_features - -tinytc_source_get_location -.......................... - -.. doxygenfunction:: tinytc_source_get_location - -tinytc_source_get_extensions -............................ - -.. doxygenfunction:: tinytc_source_get_extensions - -tinytc_source_release -..................... - -.. doxygenfunction:: tinytc_source_release + * :ref:`tinytc_spv_mod_print_to_string` -tinytc_source_retain -.................... - -.. doxygenfunction:: tinytc_source_retain - -Source Context -============== - -* Functions - - * :ref:`tinytc_source_context_create` + * :ref:`tinytc_spv_mod_release` - * :ref:`tinytc_source_context_add_source` + * :ref:`tinytc_spv_mod_retain` - * :ref:`tinytc_source_context_get_error_log` +SPIR-V module Functions +----------------------- - * :ref:`tinytc_source_context_report_error` +.. _tinytc_spv_mod_dump: - * :ref:`tinytc_source_context_release` +tinytc_spv_mod_dump +................... - * :ref:`tinytc_source_context_retain` +.. doxygenfunction:: tinytc_spv_mod_dump -Source Context Functions ------------------------- +.. _tinytc_spv_mod_print_to_file: -tinytc_source_context_create +tinytc_spv_mod_print_to_file ............................ -.. doxygenfunction:: tinytc_source_context_create +.. doxygenfunction:: tinytc_spv_mod_print_to_file -tinytc_source_context_add_source -................................ - -.. doxygenfunction:: tinytc_source_context_add_source +.. _tinytc_spv_mod_print_to_string: -tinytc_source_context_get_error_log -................................... +tinytc_spv_mod_print_to_string +.............................. -.. doxygenfunction:: tinytc_source_context_get_error_log +.. doxygenfunction:: tinytc_spv_mod_print_to_string -tinytc_source_context_report_error -.................................. +.. _tinytc_spv_mod_release: -.. doxygenfunction:: tinytc_source_context_report_error +tinytc_spv_mod_release +...................... -tinytc_source_context_release -............................. +.. doxygenfunction:: tinytc_spv_mod_release -.. doxygenfunction:: tinytc_source_context_release +.. _tinytc_spv_mod_retain: -tinytc_source_context_retain -............................ +tinytc_spv_mod_retain +..................... -.. doxygenfunction:: tinytc_source_context_retain +.. doxygenfunction:: tinytc_spv_mod_retain diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index 0a2e97c6..bb426775 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -21,17 +21,19 @@ Core C-API: - tinytc_core_info_t - tinytc_recipe_t - tinytc_recipe_handler_t - - tinytc_source_t - - tinytc_source_context_t + - tinytc_spv_mod_t + - tinytc_compiler_context_t - const_tinytc_binary_t - const_tinytc_core_info_t - const_tinytc_recipe_t - const_tinytc_recipe_handler_t - - const_tinytc_source_t - - const_tinytc_source_context_t + - const_tinytc_spv_mod_t + - const_tinytc_compiler_context_t + - tinytc_error_reporter_t Binary: function: - tinytc_binary_create + - tinytc_binary_get_compiler_context - tinytc_binary_get_core_features - tinytc_binary_get_raw - tinytc_binary_release @@ -39,24 +41,52 @@ Core C-API: Compiler: enum: - tinytc_bundle_format_t + - tinytc_optflag_t function: - - tinytc_prog_compile_to_opencl + - tinytc_run_function_pass + - tinytc_list_function_passes + - tinytc_prog_compile_to_spirv + - tinytc_prog_compile_to_spirv_and_assemble + - tinytc_spirv_assemble + Compiler Context: + function: + - tinytc_compiler_context_create + - tinytc_compiler_context_add_source + - tinytc_compiler_context_set_error_reporter + - tinytc_compiler_context_set_optimization_flag + - tinytc_compiler_context_set_optimization_level + - tinytc_compiler_context_report_error + - tinytc_compiler_context_release + - tinytc_compiler_context_retain Device Info: enum: - tinytc_core_feature_flag_t - tinytc_intel_gpu_architecture_t + - tinytc_spirv_feature_t function: + - tinytc_core_info_generic_create - tinytc_core_info_get_core_features + - tinytc_core_info_get_default_alignment - tinytc_core_info_get_register_space - tinytc_core_info_get_subgroup_sizes - - tinytc_core_info_set_core_features - - tinytc_core_info_generic_create + - tinytc_core_info_have_spirv_feature - tinytc_core_info_intel_create - tinytc_core_info_intel_create_from_arch + - tinytc_core_info_intel_create_from_name - tinytc_core_info_release - tinytc_core_info_retain + - tinytc_core_info_set_core_features + - tinytc_core_info_set_default_alignment + - tinytc_core_info_set_spirv_feature + - tinytc_spirv_feature_to_string typedef: - tinytc_core_feature_flags_t + FP math: + function: + - tinytc_f32_to_bf16_as_ui16 + - tinytc_f32_to_f16_as_ui16 + - tinytc_f16_as_ui16_to_f32 + - tinytc_bf16_as_ui16_to_f32 Parser: function: - tinytc_parse_file @@ -66,8 +96,8 @@ Core C-API: enum: - tinytc_mem_type_t function: + - tinytc_recipe_get_binary - tinytc_recipe_get_prog - - tinytc_recipe_get_source - tinytc_recipe_handler_get_recipe - tinytc_recipe_small_gemm_batched_create - tinytc_recipe_small_gemm_batched_set_args @@ -79,19 +109,10 @@ Core C-API: - tinytc_recipe_retain - tinytc_recipe_handler_release - tinytc_recipe_handler_retain - Source: - function: - - tinytc_source_get_code - - tinytc_source_get_core_features - - tinytc_source_get_location - - tinytc_source_get_extensions - - tinytc_source_release - - tinytc_source_retain - Source Context: + SPIR-V module: function: - - tinytc_source_context_create - - tinytc_source_context_add_source - - tinytc_source_context_get_error_log - - tinytc_source_context_report_error - - tinytc_source_context_release - - tinytc_source_context_retain + - tinytc_spv_mod_dump + - tinytc_spv_mod_print_to_file + - tinytc_spv_mod_print_to_string + - tinytc_spv_mod_release + - tinytc_spv_mod_retain diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index ba68c04e..fb3db23f 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Core C++-API: + ============ Core C++-API ============ @@ -10,32 +12,48 @@ Common * Enumerations - * :ref:`status` + * :ref:`tinytc::status` - * :ref:`support_level` + * :ref:`tinytc::support_level` * Functions - * :ref:`error_string` + * :ref:`tinytc::error_string` - * :ref:`CHECK_STATUS` + * :ref:`tinytc::CHECK_STATUS` - * :ref:`CHECK_STATUS_LOC` + * :ref:`tinytc::CHECK_STATUS_LOC` * Classes - * :ref:`shared_handle` + * :ref:`tinytc::array_view_base` + + * :ref:`tinytc::array_view` + + * :ref:`tinytc::mutable_array_view` + + * :ref:`tinytc::handle` + + * :ref:`tinytc::shared_handle` + + * :ref:`tinytc::unique_handle` - * :ref:`unique_handle` +* Typedefs + + * :ref:`tinytc::error_reporter_t` Common Enumerations ------------------- +.. _tinytc::status: + status ...... .. doxygenenum:: tinytc::status +.. _tinytc::support_level: + support_level ............. @@ -44,16 +62,22 @@ support_level Common Functions ---------------- +.. _tinytc::error_string: + error_string ............ .. doxygenfunction:: tinytc::error_string +.. _tinytc::CHECK_STATUS: + CHECK_STATUS ............ .. doxygenfunction:: tinytc::CHECK_STATUS +.. _tinytc::CHECK_STATUS_LOC: + CHECK_STATUS_LOC ................ @@ -62,34 +86,78 @@ CHECK_STATUS_LOC Common Classes -------------- +.. _tinytc::array_view_base: + +array_view_base +............... + +.. doxygenclass:: tinytc::array_view_base + +.. _tinytc::array_view: + +array_view +.......... + +.. doxygenclass:: tinytc::array_view + +.. _tinytc::mutable_array_view: + +mutable_array_view +.................. + +.. doxygenclass:: tinytc::mutable_array_view + +.. _tinytc::handle: + +handle +...... + +.. doxygenclass:: tinytc::handle + +.. _tinytc::shared_handle: + shared_handle ............. .. doxygenclass:: tinytc::shared_handle +.. _tinytc::unique_handle: + unique_handle ............. .. doxygenclass:: tinytc::unique_handle +Common Typedefs +--------------- + +.. _tinytc::error_reporter_t: + +error_reporter_t +................ + +.. doxygentypedef:: tinytc::error_reporter_t + Binary ====== * Enumerations - * :ref:`bundle_format` + * :ref:`tinytc::bundle_format` * Functions - * :ref:`make_binary` + * :ref:`tinytc::make_binary` * Classes - * :ref:`binary` + * :ref:`tinytc::binary` Binary Enumerations ------------------- +.. _tinytc::bundle_format: + bundle_format ............. @@ -98,6 +166,8 @@ bundle_format Binary Functions ---------------- +.. _tinytc::make_binary: + make_binary ........... @@ -106,6 +176,8 @@ make_binary Binary Classes -------------- +.. _tinytc::binary: + binary ...... @@ -116,100 +188,318 @@ Compiler * Functions - * :ref:`compile_to_opencl` + * :ref:`tinytc::run_function_pass` + + * :ref:`tinytc::list_function_passes` + + * :ref:`tinytc::compile_to_spirv` + + * :ref:`tinytc::compile_to_spirv_and_assemble` + + * :ref:`tinytc::spirv_assemble` Compiler Functions ------------------ -compile_to_opencl +.. _tinytc::run_function_pass: + +run_function_pass ................. -.. doxygenfunction:: tinytc::compile_to_opencl +.. doxygenfunction:: tinytc::run_function_pass + +.. _tinytc::list_function_passes: + +list_function_passes +.................... + +.. doxygenfunction:: tinytc::list_function_passes + +.. _tinytc::compile_to_spirv: + +compile_to_spirv +................ + +.. doxygenfunction:: tinytc::compile_to_spirv + +.. _tinytc::compile_to_spirv_and_assemble: + +compile_to_spirv_and_assemble +............................. + +.. doxygenfunction:: tinytc::compile_to_spirv_and_assemble + +.. _tinytc::spirv_assemble: + +spirv_assemble +.............. + +.. doxygenfunction:: tinytc::spirv_assemble + +Compiler Context +================ + +* Functions + + * :ref:`tinytc::make_compiler_context` + +* Classes + + * :ref:`tinytc::compiler_context` + +Compiler Context Functions +-------------------------- + +.. _tinytc::make_compiler_context: + +make_compiler_context +..................... + +.. doxygenfunction:: tinytc::make_compiler_context + +Compiler Context Classes +------------------------ + +.. _tinytc::compiler_context: + +compiler_context +................ + +.. doxygenclass:: tinytc::compiler_context Device Info =========== * Enumerations - * :ref:`core_feature_flag` + * :ref:`tinytc::core_feature_flag` + + * :ref:`tinytc::intel_gpu_architecture` - * :ref:`intel_gpu_architecture` + * :ref:`tinytc::spirv_feature` * Functions - * :ref:`make_core_info_generic` + * :ref:`tinytc::make_core_info_generic` + + * :ref:`tinytc::make_core_info_intel` - * :ref:`make_core_info_intel` + * :ref:`tinytc::make_core_info_intel_from_arch` - * :ref:`make_core_info_intel_from_arch` + * :ref:`tinytc::make_core_info_intel_from_name` + + * :ref:`tinytc::to_string(spirv_feature)` * Classes - * :ref:`core_info` + * :ref:`tinytc::core_info` Device Info Enumerations ------------------------ +.. _tinytc::core_feature_flag: + core_feature_flag ................. .. doxygenenum:: tinytc::core_feature_flag +.. _tinytc::intel_gpu_architecture: + intel_gpu_architecture ...................... .. doxygenenum:: tinytc::intel_gpu_architecture +.. _tinytc::spirv_feature: + +spirv_feature +............. + +.. doxygenenum:: tinytc::spirv_feature + Device Info Functions --------------------- +.. _tinytc::make_core_info_generic: + make_core_info_generic ...................... .. doxygenfunction:: tinytc::make_core_info_generic +.. _tinytc::make_core_info_intel: + make_core_info_intel .................... .. doxygenfunction:: tinytc::make_core_info_intel +.. _tinytc::make_core_info_intel_from_arch: + make_core_info_intel_from_arch .............................. .. doxygenfunction:: tinytc::make_core_info_intel_from_arch +.. _tinytc::make_core_info_intel_from_name: + +make_core_info_intel_from_name +.............................. + +.. doxygenfunction:: tinytc::make_core_info_intel_from_name + +.. _tinytc::to_string(spirv_feature): + +to_string(spirv_feature) +........................ + +.. doxygenfunction:: tinytc::to_string(spirv_feature) + Device Info Classes ------------------- +.. _tinytc::core_info: + core_info ......... .. doxygenclass:: tinytc::core_info +FP math +======= + +* Functions + + * :ref:`tinytc::ieee754_extend` + + * :ref:`tinytc::ieee754_truncate` + +* Classes + + * :ref:`tinytc::lp_float` + +* Structures + + * :ref:`tinytc::ieee754_format` + +* Typedefs + + * :ref:`tinytc::bf16_format` + + * :ref:`tinytc::bfloat16` + + * :ref:`tinytc::f16_format` + + * :ref:`tinytc::f32_format` + + * :ref:`tinytc::half` + +FP math Functions +----------------- + +.. _tinytc::ieee754_extend: + +ieee754_extend +.............. + +.. doxygenfunction:: tinytc::ieee754_extend + +.. _tinytc::ieee754_truncate: + +ieee754_truncate +................ + +.. doxygenfunction:: tinytc::ieee754_truncate + +FP math Classes +--------------- + +.. _tinytc::lp_float: + +lp_float +........ + +.. doxygenclass:: tinytc::lp_float + +FP math Structures +------------------ + +.. _tinytc::ieee754_format: + +ieee754_format +.............. + +.. doxygenstruct:: tinytc::ieee754_format + +FP math Typedefs +---------------- + +.. _tinytc::bf16_format: + +bf16_format +........... + +.. doxygentypedef:: tinytc::bf16_format + +.. _tinytc::bfloat16: + +bfloat16 +........ + +.. doxygentypedef:: tinytc::bfloat16 + +.. _tinytc::f16_format: + +f16_format +.......... + +.. doxygentypedef:: tinytc::f16_format + +.. _tinytc::f32_format: + +f32_format +.......... + +.. doxygentypedef:: tinytc::f32_format + +.. _tinytc::half: + +half +.... + +.. doxygentypedef:: tinytc::half + Parser ====== * Functions - * :ref:`parse_file` + * :ref:`tinytc::parse_file` - * :ref:`parse_stdin` + * :ref:`tinytc::parse_stdin` - * :ref:`parse_string` + * :ref:`tinytc::parse_string` Parser Functions ---------------- +.. _tinytc::parse_file: + parse_file .......... .. doxygenfunction:: tinytc::parse_file +.. _tinytc::parse_stdin: + parse_stdin ........... .. doxygenfunction:: tinytc::parse_stdin +.. _tinytc::parse_string: + parse_string ............ @@ -220,43 +510,47 @@ Recipe * Enumerations - * :ref:`mem_type` + * :ref:`tinytc::mem_type` * Functions - * :ref:`make_small_gemm_batched` + * :ref:`tinytc::make_small_gemm_batched` - * :ref:`make_tall_and_skinny` + * :ref:`tinytc::make_tall_and_skinny` - * :ref:`make_tall_and_skinny_specialized` + * :ref:`tinytc::make_tall_and_skinny_specialized` * Classes - * :ref:`recipe` + * :ref:`tinytc::recipe` - * :ref:`recipe_handler` + * :ref:`tinytc::recipe_handler` - * :ref:`small_gemm_batched` + * :ref:`tinytc::small_gemm_batched` - * :ref:`tall_and_skinny` + * :ref:`tinytc::tall_and_skinny` * Structures - * :ref:`auto_mem_type` + * :ref:`tinytc::auto_mem_type` - * :ref:`auto_mem_type\\>\>` + * :ref:`tinytc::auto_mem_type\< T, std::enable_if_t\< is_usm_pointer_type\< T \> \> \>` - * :ref:`mem` + * :ref:`tinytc::mem` * Variables - * :ref:`auto_mem_type_v` + * :ref:`tinytc::auto_mem_type_v` + + * :ref:`tinytc::is_supported_scalar_type` - * :ref:`usm_pointer_type` + * :ref:`tinytc::is_usm_pointer_type` Recipe Enumerations ------------------- +.. _tinytc::mem_type: + mem_type ........ @@ -265,16 +559,22 @@ mem_type Recipe Functions ---------------- +.. _tinytc::make_small_gemm_batched: + make_small_gemm_batched ....................... .. doxygenfunction:: tinytc::make_small_gemm_batched +.. _tinytc::make_tall_and_skinny: + make_tall_and_skinny .................... .. doxygenfunction:: tinytc::make_tall_and_skinny +.. _tinytc::make_tall_and_skinny_specialized: + make_tall_and_skinny_specialized ................................ @@ -283,21 +583,29 @@ make_tall_and_skinny_specialized Recipe Classes -------------- +.. _tinytc::recipe: + recipe ...... .. doxygenclass:: tinytc::recipe +.. _tinytc::recipe_handler: + recipe_handler .............. .. doxygenclass:: tinytc::recipe_handler +.. _tinytc::small_gemm_batched: + small_gemm_batched .................. .. doxygenclass:: tinytc::small_gemm_batched +.. _tinytc::tall_and_skinny: + tall_and_skinny ............... @@ -306,15 +614,21 @@ tall_and_skinny Recipe Structures ----------------- +.. _tinytc::auto_mem_type: + auto_mem_type ............. .. doxygenstruct:: tinytc::auto_mem_type -auto_mem_type>> -....................................................... +.. _tinytc::auto_mem_type\< T, std::enable_if_t\< is_usm_pointer_type\< T \> \> \>: -.. doxygenstruct:: tinytc::auto_mem_type< T, std::enable_if_t< usm_pointer_type< T > > > +auto_mem_type>> +.......................................................... + +.. doxygenstruct:: tinytc::auto_mem_type< T, std::enable_if_t< is_usm_pointer_type< T > > > + +.. _tinytc::mem: mem ... @@ -324,55 +638,41 @@ mem Recipe Variables ---------------- +.. _tinytc::auto_mem_type_v: + auto_mem_type_v ............... .. doxygenvariable:: tinytc::auto_mem_type_v -usm_pointer_type -................ - -.. doxygenvariable:: tinytc::usm_pointer_type - -Source -====== +.. _tinytc::is_supported_scalar_type: -* Classes +is_supported_scalar_type +........................ - * :ref:`source` +.. doxygenvariable:: tinytc::is_supported_scalar_type -Source Classes --------------- +.. _tinytc::is_usm_pointer_type: -source -...... - -.. doxygenclass:: tinytc::source - -Source Context -============== +is_usm_pointer_type +................... -* Functions +.. doxygenvariable:: tinytc::is_usm_pointer_type - * :ref:`make_source_context` +SPIR-V module +============= * Classes - * :ref:`source_context` + * :ref:`tinytc::spv_mod` -Source Context Functions ------------------------- - -make_source_context -................... - -.. doxygenfunction:: tinytc::make_source_context +SPIR-V module Classes +--------------------- -Source Context Classes ----------------------- +.. _tinytc::spv_mod: -source_context -.............. +spv_mod +....... -.. doxygenclass:: tinytc::source_context +.. doxygenclass:: tinytc::spv_mod diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index 990f9cf6..72772a44 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -10,8 +10,14 @@ Core C++-API: - tinytc::CHECK_STATUS - tinytc::CHECK_STATUS_LOC class: + - tinytc::array_view_base + - tinytc::array_view + - tinytc::mutable_array_view + - tinytc::handle - tinytc::shared_handle - tinytc::unique_handle + typedef: + - tinytc::error_reporter_t Binary: enum: - tinytc::bundle_format @@ -21,17 +27,43 @@ Core C++-API: - tinytc::binary Compiler: function: - - tinytc::compile_to_opencl + - tinytc::run_function_pass + - tinytc::list_function_passes + - tinytc::compile_to_spirv + - tinytc::compile_to_spirv_and_assemble + - tinytc::spirv_assemble + Compiler Context: + function: + - tinytc::make_compiler_context + class: + - tinytc::compiler_context Device Info: enum: - tinytc::core_feature_flag - tinytc::intel_gpu_architecture + - tinytc::spirv_feature function: - tinytc::make_core_info_generic - tinytc::make_core_info_intel - tinytc::make_core_info_intel_from_arch + - tinytc::make_core_info_intel_from_name + - tinytc::to_string(spirv_feature) class: - tinytc::core_info + FP math: + function: + - tinytc::ieee754_extend + - tinytc::ieee754_truncate + class: + - tinytc::lp_float + struct: + - tinytc::ieee754_format + typedef: + - tinytc::bf16_format + - tinytc::bfloat16 + - tinytc::f16_format + - tinytc::f32_format + - tinytc::half Parser: function: - tinytc::parse_file @@ -51,16 +83,12 @@ Core C++-API: - tinytc::tall_and_skinny struct: - tinytc::auto_mem_type - - tinytc::auto_mem_type< T, std::enable_if_t< usm_pointer_type< T > > > + - tinytc::auto_mem_type< T, std::enable_if_t< is_usm_pointer_type< T > > > - tinytc::mem variable: - tinytc::auto_mem_type_v - - tinytc::usm_pointer_type - Source: - class: - - tinytc::source - Source Context: - function: - - tinytc::make_source_context + - tinytc::is_supported_scalar_type + - tinytc::is_usm_pointer_type + SPIR-V module: class: - - tinytc::source_context + - tinytc::spv_mod diff --git a/docs/api/sycl/cxxapi.rst b/docs/api/sycl/cxxapi.rst index 82417e6c..90e0f894 100644 --- a/docs/api/sycl/cxxapi.rst +++ b/docs/api/sycl/cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _SYCL C++-API: + ======= C++-API ======= @@ -10,18 +12,22 @@ Device Info * Functions - * :ref:`get_support_level(sycl::device const&)` + * :ref:`tinytc::get_support_level(sycl::device const&)` - * :ref:`make_core_info(sycl::device const&)` + * :ref:`tinytc::make_core_info(sycl::device const&)` Device Info Functions --------------------- +.. _tinytc::get_support_level(sycl::device const&): + get_support_level(sycl::device const&) ...................................... .. doxygenfunction:: tinytc::get_support_level(sycl::device const&) +.. _tinytc::make_core_info(sycl::device const&): + make_core_info(sycl::device const&) ................................... @@ -32,87 +38,98 @@ Kernel * Functions - * :ref:`get_execution_range` - - * :ref:`get_global_size(std::int64_t,sycl::range\<3u\> const &)` + * :ref:`tinytc::get_execution_range` - * :ref:`get_group_size(sycl::kernel const &)` + * :ref:`tinytc::get_global_size(sycl::range\<3u\> const &,sycl::range\<3u\> const &)` - * :ref:`make_kernel(sycl::kernel_bundle\ const &,char const \\*)` + * :ref:`tinytc::get_group_size(sycl::kernel const &)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context)` + * :ref:`tinytc::make_kernel(sycl::kernel_bundle\ const &,char const \*)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context)` + * :ref:`tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t)` Kernel Functions ---------------- +.. _tinytc::get_execution_range: + get_execution_range ................... .. doxygenfunction:: tinytc::get_execution_range -get_global_size(std::int64_t,sycl::range<3u> const &) -..................................................... +.. _tinytc::get_global_size(sycl::range\<3u\> const &,sycl::range\<3u\> const &): + +get_global_size(sycl::range<3u> const &,sycl::range<3u> const &) +................................................................ -.. doxygenfunction:: tinytc::get_global_size(std::int64_t,sycl::range<3u> const &) +.. doxygenfunction:: tinytc::get_global_size(sycl::range<3u> const &,sycl::range<3u> const &) + +.. _tinytc::get_group_size(sycl::kernel const &): get_group_size(sycl::kernel const &) .................................... .. doxygenfunction:: tinytc::get_group_size(sycl::kernel const &) +.. _tinytc::make_kernel(sycl::kernel_bundle\ const &,char const \*): + make_kernel(sycl::kernel_bundle const &,char const \*) ...................................................................................... .. doxygenfunction:: tinytc::make_kernel(sycl::kernel_bundle const &,char const *) -make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) -............................................................................................ +.. _tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &): -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) +make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &) +............................................................................. -make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) -.............................................................................................................. +.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &) -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) +.. _tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t): -make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) -............................................................................................ +make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) +............................................................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) Recipe ====== * Functions - * :ref:`make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context)` + * :ref:`tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &)` - * :ref:`make_recipe_handler(sycl::queue const&,recipe const&,source_context)` + * :ref:`tinytc::make_recipe_handler(sycl::queue const&,recipe const&)` * Classes - * :ref:`sycl_recipe_handler` + * :ref:`tinytc::sycl_recipe_handler` Recipe Functions ---------------- -make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) -............................................................................................. +.. _tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &): + +make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &) +.............................................................................. -.. doxygenfunction:: tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &) -make_recipe_handler(sycl::queue const&,recipe const&,source_context) -.................................................................... +.. _tinytc::make_recipe_handler(sycl::queue const&,recipe const&): -.. doxygenfunction:: tinytc::make_recipe_handler(sycl::queue const&,recipe const&,source_context) +make_recipe_handler(sycl::queue const&,recipe const&) +..................................................... + +.. doxygenfunction:: tinytc::make_recipe_handler(sycl::queue const&,recipe const&) Recipe Classes -------------- +.. _tinytc::sycl_recipe_handler: + sycl_recipe_handler ................... diff --git a/docs/api/sycl/cxxapi.yaml b/docs/api/sycl/cxxapi.yaml index 2d3416f6..d1a0cd66 100644 --- a/docs/api/sycl/cxxapi.yaml +++ b/docs/api/sycl/cxxapi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C++-API: +C++-API : Device Info: function: - tinytc::get_support_level(sycl::device const&) @@ -8,15 +8,14 @@ C++-API: Kernel: function: - tinytc::get_execution_range - - tinytc::get_global_size(std::int64_t,sycl::range<3u> const &) + - tinytc::get_global_size(sycl::range<3u> const &,sycl::range<3u> const &) - tinytc::get_group_size(sycl::kernel const &) - tinytc::make_kernel(sycl::kernel_bundle const &,char const *) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) + - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &) + - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t) Recipe: function: - - tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) - - tinytc::make_recipe_handler(sycl::queue const&,recipe const&,source_context) + - tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &) + - tinytc::make_recipe_handler(sycl::queue const&,recipe const&) class: - tinytc::sycl_recipe_handler diff --git a/docs/api/ze/capi.rst b/docs/api/ze/capi.rst index d0c0e6ca..d3922d44 100644 --- a/docs/api/ze/capi.rst +++ b/docs/api/ze/capi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Level Zero C-API: + ===== C-API ===== @@ -15,6 +17,8 @@ Common Common Functions ---------------- +.. _tinytc_ze_convert_status: + tinytc_ze_convert_status ........................ @@ -32,11 +36,15 @@ Device Info Device Info Functions --------------------- +.. _tinytc_ze_core_info_create: + tinytc_ze_core_info_create .......................... .. doxygenfunction:: tinytc_ze_core_info_create +.. _tinytc_ze_get_support_level: + tinytc_ze_get_support_level ........................... @@ -47,8 +55,6 @@ Kernel * Functions - * :ref:`tinytc_ze_get_group_count` - * :ref:`tinytc_ze_get_group_size` * :ref:`tinytc_ze_kernel_create` @@ -57,48 +63,37 @@ Kernel * :ref:`tinytc_ze_kernel_bundle_create_with_program` - * :ref:`tinytc_ze_kernel_bundle_create_with_source` - - * :ref:`tinytc_ze_source_compile_to_binary` - Kernel Functions ---------------- -tinytc_ze_get_group_count -......................... - -.. doxygenfunction:: tinytc_ze_get_group_count +.. _tinytc_ze_get_group_size: tinytc_ze_get_group_size ........................ .. doxygenfunction:: tinytc_ze_get_group_size +.. _tinytc_ze_kernel_create: + tinytc_ze_kernel_create ....................... .. doxygenfunction:: tinytc_ze_kernel_create +.. _tinytc_ze_kernel_bundle_create_with_binary: + tinytc_ze_kernel_bundle_create_with_binary .......................................... .. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_binary +.. _tinytc_ze_kernel_bundle_create_with_program: + tinytc_ze_kernel_bundle_create_with_program ........................................... .. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_program -tinytc_ze_kernel_bundle_create_with_source -.......................................... - -.. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_source - -tinytc_ze_source_compile_to_binary -.................................. - -.. doxygenfunction:: tinytc_ze_source_compile_to_binary - Recipe ====== @@ -111,11 +106,15 @@ Recipe Recipe Functions ---------------- +.. _tinytc_ze_recipe_handler_create: + tinytc_ze_recipe_handler_create ............................... .. doxygenfunction:: tinytc_ze_recipe_handler_create +.. _tinytc_ze_recipe_handler_submit: + tinytc_ze_recipe_handler_submit ............................... diff --git a/docs/api/ze/capi.yaml b/docs/api/ze/capi.yaml index acfde096..ce7fe382 100644 --- a/docs/api/ze/capi.yaml +++ b/docs/api/ze/capi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C-API: +C-API : Common: function: - tinytc_ze_convert_status @@ -10,13 +10,10 @@ C-API: - tinytc_ze_get_support_level Kernel: function: - - tinytc_ze_get_group_count - tinytc_ze_get_group_size - tinytc_ze_kernel_create - tinytc_ze_kernel_bundle_create_with_binary - tinytc_ze_kernel_bundle_create_with_program - - tinytc_ze_kernel_bundle_create_with_source - - tinytc_ze_source_compile_to_binary Recipe: function: - tinytc_ze_recipe_handler_create diff --git a/docs/api/ze/cxxapi.rst b/docs/api/ze/cxxapi.rst index d11b47b5..fc18b4d9 100644 --- a/docs/api/ze/cxxapi.rst +++ b/docs/api/ze/cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Level Zero C++-API: + ======= C++-API ======= @@ -10,11 +12,13 @@ Common * Functions - * :ref:`ZE_CHECK_STATUS` + * :ref:`tinytc::ZE_CHECK_STATUS` Common Functions ---------------- +.. _tinytc::ZE_CHECK_STATUS: + ZE_CHECK_STATUS ............... @@ -25,18 +29,22 @@ Device Info * Functions - * :ref:`get_support_level(ze_device_handle_t)` + * :ref:`tinytc::get_support_level(ze_device_handle_t)` - * :ref:`make_core_info(ze_device_handle_t)` + * :ref:`tinytc::make_core_info(ze_device_handle_t)` Device Info Functions --------------------- +.. _tinytc::get_support_level(ze_device_handle_t): + get_support_level(ze_device_handle_t) ..................................... .. doxygenfunction:: tinytc::get_support_level(ze_device_handle_t) +.. _tinytc::make_core_info(ze_device_handle_t): + make_core_info(ze_device_handle_t) .................................. @@ -47,80 +55,71 @@ Kernel * Functions - * :ref:`compile_to_binary` - - * :ref:`get_group_count` - - * :ref:`get_group_size(ze_kernel_handle_t)` - - * :ref:`make_kernel(ze_module_handle_t,char const \\*)` + * :ref:`tinytc::get_group_size(ze_kernel_handle_t)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context)` + * :ref:`tinytc::make_kernel(ze_module_handle_t,char const \*)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context)` + * :ref:`tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t)` Kernel Functions ---------------- -compile_to_binary -................. - -.. doxygenfunction:: tinytc::compile_to_binary - -get_group_count -............... - -.. doxygenfunction:: tinytc::get_group_count +.. _tinytc::get_group_size(ze_kernel_handle_t): get_group_size(ze_kernel_handle_t) .................................. .. doxygenfunction:: tinytc::get_group_size(ze_kernel_handle_t) +.. _tinytc::make_kernel(ze_module_handle_t,char const \*): + make_kernel(ze_module_handle_t,char const \*) ............................................. .. doxygenfunction:: tinytc::make_kernel(ze_module_handle_t,char const *) -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) -....................................................................................... +.. _tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&): -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) +make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&) +........................................................................ -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) -.......................................................................................................... +.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&) -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) +.. _tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t): -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) -....................................................................................... +make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) +........................................................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) Recipe ====== * Functions - * :ref:`make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context)` + * :ref:`tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&)` * Classes - * :ref:`level_zero_recipe_handler` + * :ref:`tinytc::level_zero_recipe_handler` Recipe Functions ---------------- -make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) -........................................................................................ +.. _tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&): + +make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&) +......................................................................... -.. doxygenfunction:: tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&) Recipe Classes -------------- +.. _tinytc::level_zero_recipe_handler: + level_zero_recipe_handler ......................... diff --git a/docs/api/ze/cxxapi.yaml b/docs/api/ze/cxxapi.yaml index 4308b3ed..a4bd97b9 100644 --- a/docs/api/ze/cxxapi.yaml +++ b/docs/api/ze/cxxapi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C++-API: +C++-API : Common: function: - tinytc::ZE_CHECK_STATUS @@ -10,15 +10,12 @@ C++-API: - tinytc::make_core_info(ze_device_handle_t) Kernel: function: - - tinytc::compile_to_binary - - tinytc::get_group_count - tinytc::get_group_size(ze_kernel_handle_t) - tinytc::make_kernel(ze_module_handle_t,char const *) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) + - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&) + - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t) Recipe: function: - - tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) + - tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&) class: - tinytc::level_zero_recipe_handler diff --git a/docs/conf.py b/docs/conf.py index 3d0d7c8e..8f6376bc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,7 +5,7 @@ copyright = '2024, Intel Corporation' author = 'Intel' -extensions = ['breathe', 'sphinx.ext.autosectionlabel', 'sphinx.ext.mathjax', 'sphinx_tabs.tabs'] +extensions = ['breathe', 'sphinx.ext.mathjax', 'sphinx_tabs.tabs'] templates_path = ['_templates'] exclude_patterns = [] diff --git a/docs/dev/coopmatrix_layout.rst b/docs/dev/coopmatrix_layout.rst new file mode 100644 index 00000000..3b4f23e1 --- /dev/null +++ b/docs/dev/coopmatrix_layout.rst @@ -0,0 +1,219 @@ +.. Copyright (C) 2025 Intel Corporation + SPDX-License-Identifier: BSD-3-Clause + +.. _coopmatrix layout: + +================= +Coopmatrix layout +================= + +A cooperative matrix is distributed to work-items such that each work-item holds an equal +share of the matrix, potentially with zero-padding if the matrix size is not divisible by +the subgroup size. + +General layout +============== + +Let a coopmatrix :math:`A` of size :math:`M\times N` be given and let the subgroup size +be given by :math:`S`. +We require :math:`M,S` to be powers of two and to be greater or equal than 1. +Internally, we represent the :math:`A` matrix as the :math:`I \times K_1\times J \times K_2` tensor :math:`A^*`. +The mapping of :math:`A` to :math:`A^*` is + +.. math:: + + A^*_{i,k_1,j,k_2} = \left\{\begin{array}{rcl} + A_{i+k_1I+k_2IK_1,j} & \text{ if } & i+k_1I+k_2IK_1 < M \wedge j < N, \\ + 0 & \text{ else.} + \end{array}\right. + +The shape of :math:`A^*` is given by :math:`(I,K_1,J,K_2)`, where + +.. math:: + + \begin{aligned} + I &:= \min(M, S),\\ + J &:= \min(\{n\in\mathbb N : n \geq N \wedge (In) \bmod S = 0\})\\ + K &:= K_1K_2 = M/I.\\ + \end{aligned} + +As both :math:`S` and :math:`I` are powers of two, an explicit formula for :math:`J` is given by +:math:`J = (\lceil IN/S\rceil S) / I`. + +Work-item mapping +----------------- + +We linearize the index of the :math:`A^*` tensor canonically: + +.. math:: + + L(i,j,k) := i + k_1I + j IK_1 + k_2 IK_1J + +Every work-item stores a vector :math:`v` with :math:`V:=IKJ/S` components. +We define the per work-item vector as + +.. math:: + + W^p := (v \in [V] : A^*[L^{-1}(p+vS)]), + +where :math:`p=0,\dots,S-1`. + +An example is helpful at this point. +Say we have a :math:`4\times 15` matrix and subgroup size :math:`S=16`, then the following table +shows how the work-item id maps to the 2d matrix index (the per work-item vectors are given by the +columns): + +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +p 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +.x 0,0 1,0 2,0 3,0 0,1 1,1 2,1 3,1 0,2 1,2 2,2 3,2 0,3 1,3 2,3 3,3 +.y 0,4 1,4 2,4 3,4 0,5 1,5 2,5 3,5 0,6 1,6 2,6 3,6 0,7 1,7 2,7 3,7 +.z 0,8 1,8 2,8 3,8 0,9 1,9 2,9 3,9 0,10 1,10 2,10 3,10 0,11 1,11 2,11 3,11 +.w 0,12 1,12 2,12 3,12 0,13 1,13 2,13 3,13 0,14 1,14 2,14 3,14 -, - -, - -, - -, - +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + +For a :math:`1\times 17` coopmatrix we have + +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +p 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +.x 0,0 0,1 0,2 0,3 0,4 0,5 0,6 0,7 0,8 0,9 0,10 0,11 0,12 0,13 0,14 0,15 +.y 0,16 -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + +Mapping properties +------------------ + +The inverse of :math:`L` is + +.. math:: + + \begin{aligned} + i &= L \bmod I, \\ + k_1 &= \lfloor L / I \rfloor \bmod K_1, \\ + j &= \lfloor L / (IK_1)\rfloor \bmod J, \\ + k_2 &= \lfloor L / (IK_1J)\rfloor. \\ + \end{aligned} + +Let :math:`v=w_1 + uK_1 + w_2K_1(V/K)`, with :math:`u=0,\dots,V/K-1`, +:math:`w_1=0,\dots,K_1-1`, and :math:`w=0,\dots,K_2-1`. +(Note that :math:`V/K=IJ/S`.) + +We first assume that :math:`I=S`. Then + +.. math:: + + \begin{aligned} + i &= (p + (w_1 + uK_1 + w_2K_1J)S) \bmod S = p, \\ + k_1 &= \lfloor (p + (w_1 + uK_1 + wK_1J)S) / S \rfloor \bmod K_1 = w_1, \\ + j &= \lfloor (p + (w_1 + uK_1 + wK_1J)S) / (SK_1) \rfloor \bmod J = u, \\ + k_2 &= \lfloor (p + (w_1 + uK_1 + wK_1J)S) / (SK_1J)\rfloor = w_2, \\ + \end{aligned} + +Now we assume :math:`I + vector load_coopmatrix(RealT* B, int pos0, int pos1, int shape0, int shape1, + int stride0, int stride1) { + constexpr int S = get_sub_group_size(); + constexpr int I = min(M,S); + constexpr int J = ceil(I*N/S)*S/I; + constexpr int K = M/I; + static_assert(K%K1 == 0); + constexpr int K2 = K/K1; + constexpr bool needs_mask = J*S/I > N; + + if (Transpose) { + std::swap(shape0, shape1); + std::swap(stride0, stride1); + } + + constexpr int V = I*K*J/S; + array R; + int p = get_sub_group_local_id(); + int i0 = p % I; + int j0 = p / I; + for (int w1 = 0; w1 < K1; ++w1) { + for (int w2 = 0; w2 < K2; ++w2) { + int k1 = w1, k2 = w2; + int row = pos0 + i0 + (k1 + k2*K1)*I; + bool row_ok = !RowsChecked || (row >= 0 && row < shape0); + if (row_ok) { + for (int u = 0; u < V/K; ++u) { + int j = j0 + u*(S/I); + int col = pos1 + j; + bool col_ok = !ColsChecked || (col >= 0 && col < shape1); + bool mask_ok = !needs_mask || j < N; + R[w1 + u*K1 + w2*K1*(V/K)] = mask_ok && col_ok ? A[row * stride0 + col * stride1] : 0; + } + } else { + for (int u = 0; u < V/K; ++u) { + R[w1 + u*K1 + w2*K1*(V/K)] = 0; + } + } + } + } + return R; + } + +Matrix Accumulator +================== + +For cooperative matrices with use *matrix_acc* we always have :math:`K_1=1`. + +Matrix A +======== + +For cooperative matrices with use *matrix_a* we always have :math:`K_1=1`. +Moreover, low precision matrices are VNNI transformed if :math:`N` is a multiple of :math:`\omega`, +where :math:`\omega=\max(1, \max(1,4/\text{size}(\text{ty}))` is the number of operands per channel. +We store :math:`A^*` tensors internally +as the :math:`\omega\times I \times K_1\times \lceil J/\omega\rceil \times K_2` tensor :math:`A^{**}`, +The mapping from :math:`A^*` to :math:`A^{**}` is given by + +.. math:: + + A^{**}_{c,i,k_1,j,k_2} := A^{*}_{i,k_1,c+j\omega,k_2} + +and the inverse mapping is given by + +.. math:: + + A^{*}_{i,k_1,j,k_2} = A^{**}_{j\bmod\omega,i,k_1,\lfloor j/\omega\rfloor,k_2}. + +Moreover, all channels of an entry are packed. E.g. for half precision floats we have two channels +and we pack :math:`A^{**}_{0,i,k_1,j,k_2}` in the lower 16 bits of an i32 and :math:`A^{**}_{1,i,k_1,j,k_2}` +in the higher 16 bits of the i32. +We store an i32 per work-item, so from the SIMT point of view one work-item owns all channels of an entry. + +Matrix B +======== + +Let :math:`\omega_b = \max(1,2/\text{size}(\text{ty}))`. +We choose :math:`K_1 = \omega_b \text{ if } M/S > 1 \text{ else } 1`. diff --git a/docs/index.rst b/docs/index.rst index 6a1226af..b37c8c47 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,6 +40,12 @@ Table of contents api/sycl/index api/ze/index +.. toctree:: + :maxdepth: 2 + :caption: Developer guide + + dev/coopmatrix_layout + Index ----- diff --git a/docs/manual/build.rst b/docs/manual/build.rst index a1fd4f0b..8001aef4 100644 --- a/docs/manual/build.rst +++ b/docs/manual/build.rst @@ -8,20 +8,19 @@ Building and linking Dependencies ============ -- `CMake `_ >= 3.23 -- `Intel oneAPI Base Toolkit `_ -- OpenCL -- Level Zero -- `ocloc `_ (OpenCL offline compiler from the Intel Compute Runtime) -- `Double-Batched FFT Library `_ >= 0.5.1 -- `re2c `_ >= 3.0 -- `bison `_ >= 3.8.2 +- Mandatory + - `CMake `_ >= 3.23 + - `re2c `_ >= 3.0 + - `bison `_ >= 3.8.2 +- Optional + - `OpenCL ICD loader `_ [OpenCL, SYCL] + - `Level Zero Loader `_ [Level Zero, SYCL] + - `Intel oneAPI Base Toolkit `_ [SYCL] Build from source using oneAPI ============================== -Install CMake, the oneAPI Base Toolkit, the Intel compute runtime, re2c, and bison using your system -package manager. +Install the dependencies via your favourite package manager. Initialize the oneAPI environment. @@ -29,51 +28,38 @@ Initialize the oneAPI environment. . /opt/intel/oneapi/setvars.sh -Clone the Double-Batched FFT library to your filesystem. -.. code:: console - - git clone https://github.com/intel/double-batched-fft-library.git - -We only need libclir.so that we build and install using the following steps. - -.. code:: console - - cd double-batched-fft-library/clir/ - cmake -Bbuild -GNinja -DCMAKE_CXX_COMPILER=icpx -DCMAKE_INSTALL_PREFIX=$(pwd)/../../install/ \ - -DCMAKE_CXX_FLAGS="-ffp-model=precise" -DBUILD_SHARED_LIBS=YES - cmake --build build - cmake --install build - cd ../.. - -Then, build and install Tiny Tensor Compiler with the following steps +Build and install Tiny Tensor Compiler with the following steps .. code:: console git clone https://github.com/intel/tiny-tensor-compiler.git tinytc cd tinytc cmake -Bbuild -GNinja -DCMAKE_CXX_COMPILER=icpx -DCMAKE_INSTALL_PREFIX=$(pwd)/../install/ \ - -DCMAKE_PREFIX_PATH=$(pwd)/../install -DBUILD_SHARED_LIBS=YES + -DBUILD_SHARED_LIBS=YES cmake --build build cmake --install build cd .. -If you need a static library, set `-DBUILD_SHARED_LIBS=NO` when compiling libclir and the Tiny Tensor Compiler. +If you need a static library, set `-DBUILD_SHARED_LIBS=NO` when compiling the Tiny Tensor Compiler. +Adjust `CMAKE_INSTALL_PREFIX` to control the installation directory. +(Can be left empty for a system install; needs sudo.) Build options ============= The following CMake option options are supported. -====================== ============ -Option Description -====================== ============ -BUILD_DOCUMENTATION Build the documentation -BUILD_TESTING Build unit tests -BUILD_LEVEL_ZERO Build libtinytc_ze for Level Zero support (enforced if BUILD_SYCL=ON) -BUILD_OPENCL Build libtinytc_cl for OpenCL support (enforced if BUILD_SYCL=ON) -BUILD_SYCL Build libtinytc_sycl for SYCL support -====================== ============ +============================ ========================================================================= +Option Description +============================ ========================================================================= +BUILD_DOCUMENTATION Build the documentation +BUILD_DOUBLE_PRECISION_TESTS Build unit tests for double precision (e.g. for iGPUs without DP support) +BUILD_TESTING Build unit tests +BUILD_LEVEL_ZERO Build libtinytc_ze for Level Zero support (enforced if BUILD_SYCL=ON) +BUILD_OPENCL Build libtinytc_cl for OpenCL support (enforced if BUILD_SYCL=ON) +BUILD_SYCL Build libtinytc_sycl for SYCL support +============================ ========================================================================= Linking in a CMake project ========================== @@ -91,7 +77,7 @@ in your CMakeLists.txt to find the Tiny Tensor Compiler. For non-standard installation directories add -DCMAKE_PREFIX_PATH=/path/to/installation when invoking cmake. -Runtime support is split in the three library libtinytc_ze, libtinytc_cl, and libtinytc_sycl. +Runtime support is split in the three libraries libtinytc_ze, libtinytc_cl, and libtinytc_sycl. The BUILD_(LEVEL_ZERO, OPENCL, SYCL) options control which libraries are built, respectively. For example, when using OpenCL only, you can set BUILD_SYCL=OFF such that you do not need a C++ compiler with SYCL support. diff --git a/docs/manual/builder.rst b/docs/manual/builder.rst index f05a67f4..03f59d9b 100644 --- a/docs/manual/builder.rst +++ b/docs/manual/builder.rst @@ -21,7 +21,9 @@ Consider the following simple copy kernel .. code-block:: func @copy(%A: memref<${type}x${M}x${N}>, %B: memref<${type}x${M}x${M}>) { - axpby.n 1.0, %A, 0.0, %B : ${type}, memref<${type}x${M}x${N}>, ${type}, memref<${type}x${M}x${N}> + %c0 = constant 0.0 : ${type} + %c1 = constant 1.0 : ${type} + axpby.n %c1, %A, %c0, %B } In the following example we build the above code programmatically and replace the place-holders (${.}) @@ -33,63 +35,91 @@ by actual values: .. code:: C - tinytc_scalar_type_t type = ...; + tinytc_scalar_type_t sty = ...; int64_t M = ...; int64_t N = ...; - tinytc_data_type_t dt; + char const *copy_fun_name = "copy"; + uint32_t num_results; + uint32_t num_params; + tinytc_compiler_context_t ctx; + tinytc_prog_t program; + tinytc_data_type_t void_ty, element_ty, ty; + tinytc_func_t copy_fun; + tinytc_region_t copy_body; + tinytc_inst_t tmp; + tinytc_value_t params[2]; + tinytc_value_t alpha, beta; + + tinytc_compiler_context_create(&ctx); + + // Create program + tinytc_prog_create(&program, ctx, NULL); + + // Get types + tinytc_scalar_type_get(&element_ty, ctx, sty); int64_t shape[2] = {M, N}; - tinytc_memref_type_create(&dt, type, 2, shape, 0, NULL, NULL); - - tinytc_value_t A, B, alpha, beta; - tinytc_value_create(&A, dt, NULL); - tinytc_value_create(&B, dt, NULL); - tinytc_float_imm_create(&alpha, 1.0, type, NULL); - tinytc_float_imm_create(&beta, 0.0, type, NULL); - tinytc_data_type_release(dt); - - tinytc_inst_t copy_inst; - tinytc_axpby_inst_create(©_inst, tinytc_transpose_N, 0, alpha, A, beta, B, NULL); - tinytc_value_release(alpha); - tinytc_value_release(beta); - - tinytc_func_t copy_proto; - tinytc_value_t args[2] = {A, B}; - tinytc_function_prototype_create(©_proto, "copy", 2, args, NULL); - tinytc_value_release(A); - tinytc_value_release(B); + tinytc_memref_type_get(&ty, element_ty, 2, shape, 0, NULL, tinytc_address_space_global, NULL); - tinytc_region_t copy_body; - tinytc_region_create(©_body, 1, ©_inst, NULL); - tinytc_inst_release(copy_inst); + // Get void type + tinytc_void_type_get(&void_ty, ctx); - tinytc_func_t copy_fun; - tinytc_function_create(©_fun, copy_proto, copy_body, NULL); - tinytc_func_release(copy_proto); - tinytc_region_release(copy_body); + // Create function + tinytc_data_type_t param_types[2] = {ty, ty}; + tinytc_func_create(©_fun, sizeof(copy_fun_name) - 1, copy_fun_name, 2, param_types, void_ty, + NULL); + tinytc_prog_add_function(program, copy_fun); - tinytc_prog_t program; - tinytc_program_create(&program, 1, ©_fun, NULL); - tinytc_func_release(copy_fun); + // Get body + tinytc_func_get_body(copy_fun, ©_body); + num_params = 2; + tinytc_region_get_parameters(copy_body, &num_params, params); + + // Create instructions + tinytc_constant_inst_create_one(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &alpha); + tinytc_region_append(copy_body, tmp); + + tinytc_constant_inst_create_zero(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &beta); + tinytc_region_append(copy_body, tmp); + + tinytc_axpby_inst_create(&tmp, tinytc_transpose_N, 0, alpha, params[0], beta, params[1], NULL); + tinytc_region_append(copy_body, tmp); + + // Dump program + tinytc_prog_dump(program); + + // Clean-up + tinytc_prog_release(program); + tinytc_compiler_context_release(ctx); .. tab:: C++ .. code:: C++ - scalar_type type = ...; + scalar_type sty = ...; int64_t M = ...; int64_t N = ...; - auto pb = program_builder{}; - pb.create("copy", [&](function_builder &fb) { - auto dt = make_memref(type, {M, N}); - auto A = fb.argument(dt); - auto B = fb.argument(dt); - fb.body([&](region_builder &bb) { - auto alpha = make_imm(1.0, type); - auto beta = make_imm(0.0, type); - bb.add(make_axpby(transpose::N, false, alpha, A, beta, B)); - }); - }); - auto program = pb.get_product(); + auto ctx = make_compiler_context(); + auto element_ty = get_scalar(ctx, sty); + auto ty = get_memref(element_ty, {M, N}); + + auto f = make_func("copy", {ty, ty}, get_void(ctx)); + + auto body = f.get_body(); + std::array params; + body.get_parameters(params); + + auto bb = region_builder{body}; + auto alpha = bb.add(make_constant_one(element_ty)); + auto beta = bb.add(make_constant_zero(element_ty)); + bb.add(make_axpby(transpose::N, false, alpha, params[0], beta, params[1])); + + auto p = make_prog(ctx); + p.add_function(std::move(f)); + p.dump(); diff --git a/docs/manual/calling_convention.rst b/docs/manual/calling_convention.rst index f023e5fb..106130cc 100644 --- a/docs/manual/calling_convention.rst +++ b/docs/manual/calling_convention.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _calling convention: + ================== Calling convention ================== @@ -70,6 +72,17 @@ leads to Note that `memref_example3` and `memref_example4` have the same signature, because `memref` has the canonical stride `strided<1,5,?>`. +.. _memref alignment requirements: + +**Memory alignment:** The base pointer must be sufficiently aligned. +The required alignment depends on the core info and may be queried with +:ref:`tinytc_core_info_get_default_alignment`. +Using :ref:`tinytc_core_info_set_default_alignment` the alignment requirements may be overriden. +The alignment requirement may also be overriden per memref using the +:ref:`"alignment" attribute `. +When the core info object is created using :ref:`tinytc_cl_core_info_create` the default alignment +is queried from `CL_DEVICE_MEM_BASE_ADDR_ALIGN `_. + Group types =========== @@ -77,23 +90,28 @@ A group argument might require multiple arguments in the OpenCL-C code. The rule is that the first argument in the OpenCL kernel is a global pointer to a global pointer to the underlying scalar type of the memref. Then a global pointer argument follows for every '?' in the memref's shape or stride, ordered from left-to-right. -If an dynamic offset is given, the offset is the last argument. +Afterwards, the dynamic group size and the dynamic offset follow if necessary. .. code:: - func @group_example1(%a: group) {} - func @group_example2(%a: group>) {} - func @group_example3(%a: group, offset: ?>) {} + func @group_example1(%a: groupx42) {} + func @group_example2(%a: groupx?>) {} + func @group_example3(%a: groupx?, offset: ?>) {} + func @group_example4(%a: groupx42, offset: ?>) {} leads to .. code:: c kernel void group_example1(global short*global* a) {} - kernel void group_example2(global int*global* a, global long* a_shape1, global long* a_stride2) {} - kernel void group_example3(global float*global* a, global long* a_shape0, long a_offset) {} + kernel void group_example2(global int*global* a, global long* a_shape1, global long* a_stride2, long a_size) {} + kernel void group_example3(global float*global* a, global long* a_shape0, long a_size, long a_offset) {} + kernel void group_example4(global float*global* a, long a_offset) {} -Note that `a_shape_0`, `a_shape1`, and `a_stride2` must contain at least as many values as the group size. -That is, if a is accessed with `load %a[%id] : group>`, then +Note that `a_shape_0`, `a_shape1`, and `a_stride2` must contain at least as many values as the group size (`a_size`). +That is, if a is accessed with `load %a[%id] : memref`, then `*(a_shape0 + id)`, `*(a_shape1 + id)`, and `*(a_stride2 + id)` must not lead to out-of-bounds memory access. + +**Memory alignment:** The memrefs the group points to are subject to the same alignment requirements as a +:ref:`regular memref argument (see above) `. diff --git a/docs/manual/core.rst b/docs/manual/core.rst index 716f752a..727c65ba 100644 --- a/docs/manual/core.rst +++ b/docs/manual/core.rst @@ -5,8 +5,14 @@ Core programming guide ====================== -Memory -====== +Memory management +================= + +Objects are either shared, or unique, or managed. +We detail the memory policies in the following. + +Shared objects +-------------- Objects are created, retained, and released. At creation, objects are constructed and the reference count is set to 1. @@ -30,15 +36,61 @@ does not need the object anymore, or it increased the reference count by one. .. tab:: C++ - C handles are wrapped in the :ref:`shared_handle` class. + C handles are wrapped in the :ref:`tinytc::shared_handle` class. The wrapper implements copy constructor, move constructor, copy operator, and move operator, correctly handling the reference count. Objects of *type* are created using the make\_\ *type* functions. - **Important:** The default constructor of a :ref:`shared_handle` or any of its derivatives - always gives an invalid object, wrapping a nullptr. + **Important:** The default constructor of a :ref:`tinytc::shared_handle` or any of its + derivatives always gives an invalid object, wrapping a nullptr. Always use the make\_\ *type* function unless you know what you are doing. +Unique objects +-------------- + +Unique objects always have a single owner. In the C-API, when the object is passed to a function, the ownership +may be passed to another object when the function's documentation says so. +For example, when adding an instruction to a region, the region takes ownership of the instruction and the user must not +destroy the instruction as that would lead to a double free. +In C++, the copy constructor is deleted and unique objects must be moved when a function transfers ownership. + +.. tabs:: + + .. tab:: C + + An object is created and deleted with + + * tinytc\_\ *type*\ _create + + * tinytc\_\ *type*\ _destroy + + .. tab:: C++ + + C handles are wrapped in the :ref:`tinytc::unique_handle` class. + The wrapper deletes the copy constructor and copy operator, and implements the move constructor and move operator. + Objects of *type* are created using the make\_\ *type* functions. + +Managed ojects +-------------- + +Some objects are never created or deleted but looked up in a parent object. +In that case the user never needs to destroy the object but only the parent object. +Care must be taken that the parent object is not deleted while the managed object is still in use. + +.. tabs:: + + .. tab:: C + + The object is obtained with + + * tinytc\_\ *type*\ _get + + .. tab:: C++ + + The object is obtained with + + * tinytc::get\_\ *type* + Error ===== @@ -53,7 +105,7 @@ The C-API returns error codes, the C++-API throws exceptions. .. tab:: C++ - Functions throw the :ref:`status` enum. + Functions throw the :ref:`tinytc::status` enum. The following minimum error handling code is recommended: .. code:: C++ @@ -66,25 +118,33 @@ The C-API returns error codes, the C++-API throws exceptions. std::cerr << e.what() << std::endl; } - **Hint:** The IR builder API throws the :ref:`builder_error` (deriving from std::exception) - instead of the status enum for better source code location tracking. + **Hint:** The IR builder API throws the :ref:`tinytc::builder_error` + (deriving from std::exception) instead of the status enum for better + source code location tracking. Parser ====== Programs written in the :ref:`tensor language ` are parsed from a file, stdin, or a string. -A :ref:`tinytc_source_context_t` (:ref:`source_context`) object can be attached any of the parse functions. -The source context stores the file name and source text and enhances error messages with source code context. -For example, if a parse error occurs, -then the error log of the source context contains the following error message: +The :ref:`tinytc_compiler_context_t` (:ref:`tinytc::compiler_context`) object controls optimization level, optimization flags, +and error logging. (The default compiler context does not print or log errors.) +When an error reporter is installed via :ref:`tinytc_compiler_context_set_error_reporter`, +then errors are printed along with source code locations and source context. +For example: .. code-block:: + test/lit/opt/check-ir/type_mismatch0.ir:6.8-23: Type of operand must match return type + func @kernel(%K0: memref) { - %0 = load %K0[] : memref - ~~~~~~~~~~~~~~~~~~~ - test/codegen/type_mismatch0.ir:6.13-31: Type of SSA value does not match operand type + + %0 = load %K0[] : f64 + ~~~~~~~~~~~~~~~~ + test/lit/opt/check-ir/type_mismatch0.ir:5.14-16: value defined here + + func @kernel(%K0: memref) { + ~~~ .. tabs:: @@ -96,21 +156,20 @@ then the error log of the source context contains the following error message: .. code:: C tinytc_status_t status; - tinytc_source_context_t source_ctx = NULL; + tinytc_compiler_context_t ctx = NULL; tinytc_prog_t program = NULL; - status = tinytc_source_context_create(&source_ctx); + status = tinytc_compiler_context_create(&ctx); // ... check status ... - status = tinytc_parse_file(&program, "test/codegen/type_mismatch0.ir"), source_ctx) + status = tinytc_compiler_context_set_error_reporter(ctx, error_callback, NULL); + // ... check status ... + status = tinytc_parse_file(&program, "test/lit/opt/check-ir/type_mismatch0.ir", ctx) if (status != tinytc_status_success) { printf("Error: %d\n", status); - char const* error_log; - status = tinytc_source_context_get_error_log(source_ctx, &error_log); - // ... check status ... - printf("Error log:\n%s\n", error_log); } // ... + err: tinytc_prog_release(program); - tinytc_source_context_release(source_ctx); + tinytc_compiler_context_release(ctx); .. tab:: C++ @@ -119,29 +178,31 @@ then the error log of the source context contains the following error message: .. code:: C++ try { - auto source_ctx = tinytc::make_source_context(); - auto program = tinytc::parse_file("test/codegen/type_mismatch0.ir", source_ctx); + auto ctx = tinytc::make_compiler_context(); + ctx.set_error_reporter([](char const *what, const tinytc_location_t *, + void *) { std::cerr << what << std::endl; }, + nullptr); + auto program = tinytc::parse_file("test/lit/opt/check-ir/type_mismatch0.ir", ctx); } catch (tinytc::status const& st) { std::cerr << "Error: " << tinytc::error_string(st) << std::endl; - std::cerr << "Error log: " << source_ctx.get_error_log() << std::endl; + } catch (std::exception const &e) { + std::cerr << e.what() << std::endl; } Compiler ======== -Program objects (:ref:`tinytc_prog_t`, :ref:`prog`) are online-compiled -using the :ref:`tinytc_prog_compile_to_opencl` (:ref:`compile_to_opencl`) function. +Program objects (:ref:`tinytc_prog_t`, :ref:`tinytc::prog`) are online-compiled +using the :ref:`tinytc_prog_compile_to_spirv_and_assemble` (:ref:`tinytc::compile_to_spirv_and_assemble`) function. The program object is hereby modified as compiler passes are necessary. -A source object is returned that contains OpenCL-C source text. +A binary object is returned that contains the SPIR-V binary. Some compiler passes specialize the code based on properties of the GPU device. -Therefore, a :ref:`tinytc_core_info_t` (:ref:`core_info`) object is required. +Therefore, a :ref:`tinytc_core_info_t` (:ref:`tinytc::core_info`) object is required. It is recommend to query the core info from the runtime using any of the tinytc\_\ *runtime*\ _core_info_create functions (make_core_info in C++), but one may also look up the core info from a table, as done in the example code below. -A source context can be added to capture potential errors in the optimizer. - .. tabs:: .. tab:: C @@ -152,12 +213,12 @@ A source context can be added to capture potential errors in the optimizer. tinytc_status_t status; tinytc_core_info_t info = NULL; - tinytc_source_t source = NULL; + tinytc_binary_t bin = NULL; status = tinytc_core_info_intel_create_from_arch(&info, tinytc_intel_gpu_architecture_pvc); // ... check status ... - status = tinytc_prog_compile_to_opencl(&source, program, info, source_ctx); + status = tinytc_prog_compile_to_spirv_and_assemble(&bin, program, info); // ... - tinytc_source_release(source); + tinytc_binary_release(source); tinytc_core_info_release(info); .. tab:: C++ @@ -168,18 +229,27 @@ A source context can be added to capture potential errors in the optimizer. try { auto info = tinytc::make_core_info_intel_from_arch(tinytc::intel_gpu_architecture::pvc); - auto source = tinytc::compile_to_opencl(program, info, source_ctx); + auto source = tinytc::compile_to_spirv_and_assemble(program, info); } catch (tinytc::status const& st) { ... } .. note:: - Code generation targets OpenCL-C. Currently, the library requires the - `cl_intel_required_subgroup_size `_ extension, - the `cl_intel_subgroups `_ extension, - the `cl_intel_subgroups_long `_ extension, - and the `cl_intel_subgroups_short `_ extension. + Code generation targets SPIR-V. + As a minimum, the Addresses, SubgroupDispatch, and Int64 capability must be supported by the runtime. + + + Further capabilites are required for specific functionality: + + * Int(8|16) for i8, i16 ints + * Float(16|64) for f16, f64 floats + * Int64Atomics for atomics on i64 + * Groups for work group operations (e.g. broadcast) + * AtomicFloat(16|32|64)AddExt for atomics on f16, f32, f64 (SPV_EXT_shader_atomic_float[16]_add extensions) + * BFloat16ConversionINTEL for bf16 support (SPV_INTEL_bfloat16_conversion extension) + * SubgroupBufferBlockIOINTEL for efficient block loads and stores (SPV_INTEL_subgroups extension) + Device info =========== @@ -229,11 +299,11 @@ run-time. Runtime ======= -The JIT compiler compiles tensor programs into OpenCL-C code. +The JIT compiler compiles tensor programs into SPIR-V binaries. The libray provides functions to create the runtime's kernel bundle object -(cl_program, sycl::kernel_bundle, ze_module_handle_t) from a source object. +(cl_program, sycl::kernel_bundle, ze_module_handle_t) from a binary object. The runtime's kernel objects are obtained using the native API or the Tiny Tensor Compiler API (if applicable). -Setting the kernel arguments should following the :ref:`calling convention`. +Setting the kernel arguments should follow the :ref:`calling convention `. The Tiny Tensor Compiler should be used to translate the 2D work-group size of the tensor language to a 3D work-group size, and to translate the group size to the global size that is passed to the runtime. @@ -245,29 +315,29 @@ Example for "func @foo(%a: i32, ...) { ... }" (without error handling code): .. code:: C - ze_module_handle_t module = NULL; + ze_module_handle_t bundle = NULL; ze_kernel_handle_t kernel = NULL; int a = 42; - tinytc_ze_kernel_bundle_create_with_source(&module, context, device, source, source_ctx); - tinytc_ze_kernel_create(&kernel, module, "foo"); // Sets the work-group size + tinytc_ze_kernel_bundle_create_with_binary(&bundle, context, device, bin); + tinytc_ze_kernel_create(&kernel, bundle, "foo"); // Sets the work-group size zeKernelSetArgumentValue(kernel, 0, sizeof(a), &a); // ... ze_group_count_t group_count = tinytc_ze_get_group_count(howmany); zeCommandListAppendLaunchKernel(command_list, kernel, &group_count, NULL, 0, NULL); // ... zeKernelDestroy(kernel); - zeModuleDestroy(module); + zeModuleDestroy(bundle); .. tab:: OpenCL (C) .. code:: C - cl_program program = NULL; + cl_program bundle = NULL; cl_kernel kernel; cl_int err; int a = 42; - tinytc_cl_kernel_bundle_create_with_source(&program, context, device, binary, source_ctx); - kernel = clCreateKernel(program, "foo", &err); + tinytc_cl_kernel_bundle_create_with_binary(&bundle, context, device, bin); + kernel = clCreateKernel(bundle, "foo", &err); clSetKernelArg(kernel, 0, sizeof(a), &a); // ... size_t ls[3], gs[3]; @@ -276,13 +346,13 @@ Example for "func @foo(%a: i32, ...) { ... }" (without error handling code): clEnqueueNDRangeKernel(command_list, kernel, 3u, NULL, gs, ls, 0, NULL, NULL); // ... clReleaseKernel(kernel); - clReleaseProgram(program); + clReleaseProgram(bundle); .. tab:: SYCL (C++) .. code:: C++ - auto bundle = tinytc::make_kernel_bundle(context, device, source, source_ctx); + auto bundle = tinytc::make_kernel_bundle(context, device, bin); auto kernel = tinytc::make_kernel(bundle, "foo"); auto exe_range = tinytc::get_execution_range(kernel, howmany); queue.submit([&](sycl::handler &h) { @@ -314,8 +384,8 @@ The general usage of a recipe is as following: tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; - tinytc_recipe__create(&recipe, info, , source_ctx); - tinytc_ze_recipe_handler_create(&handler, context, device, recipe, source_ctx); + tinytc_recipe__create(&recipe, info, , ctx); + tinytc_ze_recipe_handler_create(&handler, context, device, recipe, ctx); tinytc_recipe__set_args(handler, ); tinytc_ze_recipe_handler_submit(handler, command_list, NULL, 0, NULL); // ... @@ -328,8 +398,8 @@ The general usage of a recipe is as following: tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; - tinytc_recipe__create(&recipe, info, , source_ctx); - tinytc_cl_recipe_handler_create(&handler, context, device, recipe, source_ctx); + tinytc_recipe__create(&recipe, info, , ctx); + tinytc_cl_recipe_handler_create(&handler, context, device, recipe, ctx); tinytc_recipe__set_args(handler, ); tinytc_cl_recipe_handler_submit(handler, queue, 0, NULL, NULL); // ... @@ -341,7 +411,7 @@ The general usage of a recipe is as following: .. code:: C++ auto handler = tinytc::make_recipe_handler(queue, - tinytc::make_(info, , source_ctx), source_ctx); + tinytc::make_(info, , ctx), ctx); ::set_args(handler, ); handler.submit(queue); @@ -362,7 +432,7 @@ For example: tinytc_recipe__set_args(..., A, tinytc_mem_type_usm_pointer, ...); In C++, one only needs to pass the memory object. -The memory object is implicitly converted to the :ref:`mem` type that +The memory object is implicitly converted to the :ref:`tinytc::mem` type that automatically determines whether a pointer or a cl_mem object is given. A pointer maps to tinytc_mem_type_usm_pointer and a cl_mem object maps to tinytc_mem_type_buffer. diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 3f7bcdeb..9e4aa1c9 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -15,23 +15,26 @@ Execution model The unit of execution described by a function written in the tensor language is called a **kernel**. Kernels are launched in batches, where each instance of the kernel is called a work-group. -The kernel has access to its group id that is used to select the work done in the work group. -Each work group consists of a fixed number of work-items that execute concurrently. -The language distinguishes between two kinds of instructions: *replicated* and *collective* instructions. -It is distinguished between *mixed* and *spmd* regions. -Mixed regions may contain replicated and collective instructions whereas spmd regions -may only contain replicated instructions. - -A collective instruction distributes the work among the work-items. -The instruction is responsible to distribute the work in a sensible manner. - -A replicated instruction replicates the work across all work-items. -In a mixed region, the replicated instructions always operate on the same data. -In spmd regions, the replicated instructions can operate on multiple data, -but in these regions collective instructions are prohibited. - -Mixed regions can be nested whereas spmd regions must not be nested. -A mixed region may be nested in a spmd region. +The kernel has access to a three dimensional group id that is used to select the work done in the work group. +Each work group consists of a fixed number of subgroups that execute concurrently. +Subgroups can be further divided into work-items, where the number of work-items per subgroup +is given by the subgroup size. + +The language distinguishes between *collective*, *SPMD*, and *mixed* instructions. +A collective instruction distributes the work among the work-items in an implementation-defined manner. +Local variables passed to or returned from a collective instruction are always uniform, meaning +that each work-item holds the same value. +An SPMD instruction follows the OpenCL execution model, where local variables may have a different value +for each work-item. +Mixed instructions accept both varying and uniform local variables. + +In an SPMD region, we call an argument *dynamically uniform* if all work-items in a subgroup have +the same value. + +Regions come in two different kinds: collective and SPMD. +A collective instructions must only appear in a collective region, and an SPMD instruction +must only appear in a SPMD region. Mixed instructions might appear in both kinds of regions. +SPMD regions may be nested in collective regions but collective regions must not be nested in SPMD regions. Core rules ========== @@ -54,7 +57,9 @@ are prefixed with ``@``. .. code:: abnf - identifier = 1*DIGIT / (ALPHA *(ALPHA / DIGIT / "_")) + identifier = unnamed-identifier / named-identifier + unnamed-identifier = 1*DIGIT + named-identifier = ALPHA *(ALPHA / DIGIT / "_") local-identifier = "%" identifier global-identifier = "@" identifier @@ -63,15 +68,18 @@ Constants .. code:: abnf + constant = boolean-constant / integer-constant / floating-constant / complex-constant + boolean-constant = "true" / "false" + integer-constant = [sign] 1*DIGIT sign = "-" / "+" - integer-constant = "true" / "false" / [sign] 1*DIGIT - floating-constant = [sign] *DIGIT "." 1*DIGIT ["e" [sign] 1*DIGIT] + floating-constant = [sign] (*DIGIT "." 1*DIGIT ["e" [sign] 1*DIGIT] / "inf" / "nan") mantissa-dec = *DIGIT "." 1*DIGIT / 1*DIGIT "." mantissa-hex = *HEXDIG "." 1*HEXDIG / 1*HEXDIG "." exponent = [sign] 1*DIGIT floating-constant-dec = [sign] (mantissa-dec ["e" exponent] / 1*DIGIT "e" exponent) floating-constant-hex = [sign] "0x" (mantissa-hex ["p" exponent] / 1*HEXDIG "p" exponent) floating-constant = floating-constant-dec / floating-constant-hex + complex-constant = "[" floating-constant "," floating-constant "]" Integer constants must lie in the range :math:`-2^{63}+1,\dots,2^{63}-1`. @@ -80,6 +88,32 @@ The hexadecimal floating point syntax is supported, too. `strtod `_ can be used for parsing floating point numbers. +Attributes +========== + +.. code:: abnf + + attribute = array-attribute / + boolean-attribute / + dictionary-attribute / + integer-attribute / + string-attribute + array-attribute = "[" [attribute *(", " attribute)] "]" + boolean-attribute = boolean-constant + dictionary-attribute = "{" [named-attribute *("," named-attribute)] "}" + named-attribute = attribute-name "=" attribute + attribute-name = "alignment" / + "shape_gcd" / + "stride_gcd" / + "subgroup_size" / + "unroll" / + "work_group_size" / + string-attribute + integer-attribute = integer-constant + string-attribute = %x22 *(%x20-21 / %x23-7E) %x22 + +Attributes add information about an operation, for example to assert properties or to direct the compiler. + .. _tensor language functions: Functions @@ -87,16 +121,30 @@ Functions .. code:: abnf - function-definition = "func" global-identifier "(" [argument-list] ")" *attribute region + function-definition = "func" global-identifier "(" [argument-list] ")" + ["attributes" dictionary-attribute] region argument-list = argument *("," argument) - argument = local-identifier ":" type - attribute = work-group-size-attribute / subgroup-size-attribute - work-group-size-attribute = "work_group_size" "(" 1*DIGIT "," 1*DIGIT ")" - subgroup-size-attribute = "subgroup_size" "(" 1*DIGIT ")" + argument = local-identifier ":" type [dictionary-attribute] Defines a function that is callable from the host. -Attributes are optional and autoatically determined if omitted. +Attributes +---------- + +Subgroup size and work-group size are determined automatically by the compiler, but can be overriden +using the function's attribute dictionary: + +.. list-table:: + + * - Name + - Type + - Description + * - subgroup_size + - integer-attribute + - Subgroup size; valid values depend on the target device (typically 16 or 32) + * - work_group_size + - array-attribute with 2 integer-attribute entries + - Two dimensional work-group size in number of work-items The work-group size attribute defines the size of the local work group. Due to the focus on matrix operations, the work-group size is always two-dimensional, @@ -108,11 +156,35 @@ the subgroup sizes supported by the device. The product of the work-group size modes must be smaller or equal than the maximum work-group size of device. -The work-group is divided into full subgroups, therefore the work-group size -is always a multiple of the subgroup size. -The subgroup size attribute enforces a particular subgroup device supported by +The subgroup size attribute enforces a particular subgroup size that must be supported by the device. +Parameter attributes +-------------------- + +Parameters with memref or group type accept the following named attributes: + +.. list-table:: + + * - Name + - Type + - Description + * - alignment + - integer-attribute + - Minimum pointer alignment + * - shape_gcd + - array-attribute of integer-attribute + - Greatest common divisors of shape + * - stride_gcd + - array-attribute of integer-attribute + - Greatest common divisors of stride + +Cf. the documentation of the :ref:`memref type ` and the :ref:`group type `. + +Restrictions +------------ + +* Arguments must not have coopmatrix type. Regions ======= @@ -131,31 +203,88 @@ Types .. code:: abnf - type = void-type / scalar-type / memref-type / group-type + type = void-type / boolean-type / scalar-type / memref-type / group-type void-type = "void" +Boolean type +------------ + +.. code:: abnf + + boolean-type = "bool" + +Boolean type that only has two states (true or false). + Scalar types ------------ .. code:: abnf - scalar-type = integer-type / floating-type - integer-type = "i" ("1" / "8" / "16" / "32" / "64") / "index" - floating-type = "f" ("32" / "64") + scalar-type = integer-type / floating-type / complex-type + integer-type = "i8" / "i16" / "i32" / "i64" / "index" + floating-type = "bf16" / "f16" / "f32" / "f64" + complex-type = "c32" / "c64" -Scalar types are either signless integer ("i") or floating point ("f"). +Scalar types are either signless integer ("i"), floating point ("f"), +or complex floating point ("c"). The number behind the scalar type prefix denotes the number of bits, e.g. "f64" are double precision floating point numbers. +The "bf16" type encodes bfloat16 floating point numbers. The "index" type is an integer type whose width is platform-specific. +Type sizes in bytes are given by + +=========================== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +:math:`\alpha` i8 i16 i32 i64 bf16 f16 f32 f64 c32 c64 +=========================== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +:math:`\text{size}(\alpha)` 1 2 4 8 2 2 4 8 8 16 +=========================== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + + +Mixed precision operands might be allowed in instructions if the operands' types are *promotable*. +The scalar type :math:`\alpha` may be promoted to the scalar type :math:`\beta` if all values an operand +of type :math:`\alpha` may take can be exactly represented in type :math:`\beta`. +Formally, :math:`\alpha` is promotable to :math:`\beta` if :math:`\alpha \preceq \beta`, +where the partial order :math:`\preceq` is defined by the following relation matrix: + +=============== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +:math:`\preceq` i8 i16 i32 i64 bf16 f16 f32 f64 c32 c64 +=============== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +i8 1 1 1 1 1 1 1 1 1 1 +i16 1 1 1 1 1 1 1 +i32 1 1 1 1 1 +i64 1 +bf16 1 1 1 1 1 +f16 1 1 1 1 1 +f32 1 1 1 1 +f64 1 1 +c32 1 1 +c64 1 +=============== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + +Moreover, for scalar types :math:`\alpha,\beta` we define + +.. math:: + + \text{promote}(\alpha, \beta) = \left\{\begin{array}{rcl} + \beta & \text{ if } & \alpha \preceq \beta, \\ + \alpha & \text{ if } & \beta \preceq \alpha, \\ + \text{fail} & \text{ else.} + \end{array}\right. + +Here, "fail" means that the promotion is not allowed and the compiler should throw an error. + + + Memref type ----------- .. code:: abnf - memref-type = "memref<" scalar-type tensor-shape ["," memory-layout] ">" + memref-type = "memref<" scalar-type tensor-shape ["," memory-layout] ["," address-space] ">" constant-or-dynamic = integer-constant / "?" tensor-shape = *("x" constant-or-dynamic) + address-space = "global" / "local" A memref is a reference to a region of memory. In analogy to the C/C++-language, the memref can be thought of as a pointer, @@ -182,6 +311,31 @@ E.g. the memory layout of ``memref`` is ``strided<1,5,30>``. We note that ``memref`` and ``memref>`` are the same type. +Memrefs have an optional address space attribute. +The global address space referse to memory objects allocated from the global memory pool +that is shared by all work groups. +The local memory space is shared by all work-items of the work-group but inaccessible to another work-group. +The default address space is "global", memrefs with "local" address space are returned by +the alloca instruction. + +Definitions +........... + +Let V be a value of memref type. +The :math:`\text{order}(V)` operation returns the memref's order. +The :math:`\text{shape}(V)` returns the tensor shape as tuple. +:math:`\text{rows}(V)` and :math:`\text{columns}(V)` return the size of the first +and second mode, respectively. +The :math:`\text{element_type}(V)` operation gives the underlying scalar type. + +For example, let B be a value of memref type, then + +* :math:`\text{order}(B) = 3` +* :math:`\text{shape}(B) = (8,16,4)` +* :math:`\text{rows}(B) = 8` +* :math:`\text{columns}(B) = 16` +* :math:`\text{element_type}(B) = \text{f32}` + Memory layout ............. @@ -190,6 +344,8 @@ Memory layout memory-layout = strided-layout +.. _strided layout: + Strided layout ~~~~~~~~~~~~~~ @@ -219,825 +375,1794 @@ The default packed dense layout is given by Stride modes might be dynamic as well, indicated by a question mark. +.. _memref attributes: + +Alignment attribute +................... + +The *alignment=X* attribute gives the alignment X of the memref's base pointer in bytes. +That is, for the pointer P pointing to the first element of the memref we must have :math:`P = 0 \pmod{X}`. + +**Restriction:** The alignment must be a multiple of the size of the memref's element type. + + +Greatest common divisor (GCD) attributes +........................................ + +The *shape_gcd=[d_1,...,d_k]* attribute asserts that :math:`s_i = 0 \pmod{d_i}, i=1,\dots,k`, where k is +smaller or equal than the order of the tensor n and :math:`s_i` is the i-th entry of the shape vector. +The divisors are understood to be the greatest common divisors for the set of shapes that the kernel is used for. +For example, if we know that :math:`s_1` is always a multiple of 4 then we can set *shape_gcd=[4]*. + +The *stride_gcd=[D_1,...,D_m]* attribute asserts that :math:`S_i = 0 \pmod{D_i}, i=1,\dots,m`, where m is +smaller or equal than the order of the tensor n and :math:`S_i` is the i-th entry of the stride vector. +The divisors are understood to be the greatest common divisors for the set of strides that the kernel is used for. +For example, if we know that :math:`S_2` is always a multiple of 4 then we can set *stride_gcd=[1,4]*. + Group type ---------- .. code:: abnf - group-type = "group<" memref-type ["," "offset" ":" constant-or-dynamic] ">" + group-type = "group<" memref-type "x" constant-or-dynamic ["," "offset" ":" constant-or-dynamic] ">" The group type collects unstructured pointers to memref's with potentially different dynamic mode sizes. The C-analogy of a group is a pointer-to-a-pointer. -For example, the C-analogue of a ``group>`` is a ``float**``. +For example, the C-analogue of a ``groupx?>`` is a ``float**``. + +The group shape is always one-dimensional and may be queried using the +:ref:`size instruction `. The optional offset parameter is used to offset each pointer by the given number of elements. Given the C-analogue ``float** group``, loading element ``i`` with offset ``off`` gives the pointer ``float* tmp = group[i] + off``. The default offset is 0. -Dynamic values ('?') may appear in the memref-type and in the offset. +Dynamic values ('?') may appear in the memref-type, in the group shape, and in the offset. These values are stored in the dope vector; the calling convention for groups is implementation-defined. -Instructions -============ +.. _group attributes: -.. code:: abnf +Attributes +.......... - value-instruction = local-identifier "=" (alloca-instruction - / arith-binary-instruction - / arith-unary-instruction - / cast-instruction - / comparison-instruction - / expand-instruction - / fuse-instruction - / group-id-instruction - / group-size-instruction - / load-instruction - / size-instruction - / subview-instruction) - multi-value-instruction = [local-identifier-list "="] if-instruction - local-identifier-list = local-identifier *("," local-identifier) - instruction = value-instruction - / multi-value-instruction - / axpby-instruction - / barrier-instruction - / for-instruction - / foreach-instruction - / lifetime-stop-instruction - / gemm-instruction - / gemv-instruction - / ger-instruction - / hadamard-product-instruction - / store-instruction - / sum-instruction - / yield-instruction +Attributes applied on a group type are passed through to the memrefs. +That is, when a memref is loaded from the group then the :ref:`memref attributes ` +are equal to the attributes of the group. -Alloca ------- +Cooperative matrix type +----------------------- .. code:: abnf - alloca-instruction = "alloca" "->" memref-type - -Overview -........ - -*Collective instruction.* -The alloca instruction allocates temporary memory that is freed automatically at the end of the block that contains the alloca. - -Returns -....... + coopmatrix-type = "coopmatrix<" scalar-type 2*2("x" integer-constant) "," matrix-use ">" + matrix-use = "matrix_a" / "matrix_b" / "matrix_acc" -A memref of the memref-type. +The coopmatrix represents a matrix distributed across a subgroup, where each work-item in a subgroup +stores a part of the matrix. +The scalar-type specifies the matrix element type, the first integer-constant the number of rows, +and the second integer-constant the number of columns. +The matrix-use may affect the distribution of the matrix in the subgroup, and the name refers to the +position of the matrix in a matrix multiplication. -Restrictions -............ +Not all matrix shapes need to be supported in the implementation. +The supported matrix shapes may depend on data type, matrix use, and target hardware. -The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. +An argument to any instruction that has coopmatrix type **must** be dynamically uniform. -Arithmetic (binary) -------------------- +Definitions +........... -.. code:: abnf +Let V be a value of coopmatrix type. +The :math:`\text{rows}(V)` and :math:`\text{columns}(V)` functions return the size of the first +and second mode, respectively, and :math:`\text{shape}(V)` returns rows and cols as tuple. +The :math:`\text{component_type}(V)` operation gives the underlying scalar type +and :math:`\text{use}(V)` returns the use. - identifier-or-constant = local-identifier / integer-constant / floating-constant - arith-binary-type = ".add" / - ".sub" / - ".mul" / - ".div" / - ".rem" / - ".shl" / - ".shr" / - ".and" / - ".or" / - ".xor" - arith-binary-instruction = "arith" arith-binary-type - identifier-or-constant "," identifier-or-constant ":" scalar-type +For example, let B be a value of coopmatrix type, then -Overview -........ +* :math:`\text{shape}(B) = (8,16)` +* :math:`\text{rows}(B) = 8` +* :math:`\text{columns}(B) = 16` +* :math:`\text{component_type}(B) = \text{f32}` +* :math:`\text{use}(B) = \text{matrix_acc}` -*Replicated instruction.* -Binary arithmetic operation on scalars. -Both operands, as well as the returned type, have the same scalar type. - -==== ============ ============================================================================== -Op Allowed type Description -==== ============ ============================================================================== -.add scalar-type Sum of operands -.sub scalar-type Difference of operands -.mul scalar-type Product of operands -.div scalar-type Quotient of operands -.rem scalar-type Remainder from the division of operands -.shl integer-type Left shift first operand by number of bits given by second operand -.shr integer-type Arithmetic right shift first operand by number of bits given by second operand -.and integer-type Bitwise and -.or integer-type Bitwise or -.xor integer-type Bitwise xor -==== ============ ============================================================================== +Instructions +============ -Arithmetic (unary) ------------------- +Instructions may return zero, one, or multiple values, and follow the following format: .. code:: abnf - arith-unary-type = ".neg" / ".not" - arith-unary-instruction = "arith" arith-unary-type identifier-or-constant ":" scalar-type + value-instruction-assignment = local-identifier "=" value-instruction + multi-value-instruction-assignment = [local-identifier-list "="] multi-value-instruction + local-identifier-list = local-identifier *("," local-identifier) + instruction = value-instruction-assignment + / multi-value-instruction-assignment -Overview -........ +That is, on the left-hand side we have list of values that are produced by the instruction followed by an equals sign, +or an empty string, if the instruction does not produce values. +On the right-hand side, after the equals sign or empty string, the name of the instruction is written, e.g. "ger", optionally followed by instruction modifiers, e.g. "ger.atomic". +Then, a list of operands follows that is usually comma-seperated but might also be printed in a custom format +(e.g. for "load", "store", "subview", etc.). +If the instruction produces values, then the types of the returned values must be annotated after a colon. -*Replicated instruction.* -Unary arithmetic operation on scalars. -The returned value has the same type as the operand. -==== ============ ============================================================================== -Op Allowed type Description -==== ============ ============================================================================== -.neg scalar-type Negation -.not integer-type Bitwise not -==== ============ ============================================================================== -Cast ----- +Collective instructions +----------------------- + +Alloca +...... .. code:: abnf - cast-instruction = "cast" identifier-or-constant ":" scalar-type "->" scalar-type + value-instruction = "alloca" [dictionary-attribute] ":" memref-type Overview -........ +~~~~~~~~ -*Replicated instruction.* -Cast scalar values. +The alloca instruction allocates temporary memory that is freed automatically at the end of the block that contains the alloca. -Comparison ----------- +Attributes +~~~~~~~~~~ -.. code:: abnf +Alloca accepts the following named attributes: - comparison-instruction = "cmp" (".eq" / ".ne" / ".gt" / ".ge" / ".lt" / ".le") - identifier-or-constant "," identifier-or-constant ":" scalar-type +.. list-table:: -Overview -........ + * - Name + - Type + - Description + * - alignment + - integer-attribute + - Base pointer alignment; must not be larger than the :ref:`default alignment `. -*Replicated instruction.* -Scalar comparison. -Both operands must have the same scalar type and the returned value is boolean. - -==== ===================== -Cond Description -==== ===================== -.eq Equal -.ne Not equal -.gt Greater than -.ge Greater than or equal -.lt Less than -.le Less than or equal -==== ===================== +Restrictions +~~~~~~~~~~~~ -Expand ------- +* The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. +* The address space must be "local". + +Axpby +..... .. code:: abnf - expand-instruction = "expand" local-identifier "[" integer-constant "->" expand-shape "]" ":" memref-type - expand-shape = constant-or-dynamic-or-identifier 1*("x" constant-or-dynamic-or-identifier) - constant-or-dynamic-or-identifier = integer-constant / "?" / local-identifier + transpose = ".t" / ".n" + instruction =/ "axpby" transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier Overview -........ - -*Replicated instruction.* -The expand instruction returns a view on a tensor with a mode viewed as higher-order mode. - -Arguments -......... - -The first argument must point to a value of memref type. -The integer constant in square brackets gives the mode that shall be expanded. -The expand shape gives the new shape of the mode. -Values in the expand shape must have index type. +~~~~~~~~ -The output type is a memref type according to the following rules: - -#. **Shape:** The mode size is replaced with the expand shape. If one entry in expand shape is dynamic, - then either its size is inferred automatically if the mode size is known, or it determined automatically - at run-time if the mode size is dynamic. +Axpby implements - .. code:: +.. math:: - expand %0[1 -> 2x8] : memref ; -> memref - expand %0[1 -> 2x?] : memref ; -> memref - expand %0[1 -> ?x8] : memref ; -> memref - expand %0[1 -> 2x?] : memref ; -> memref - expand %0[1 -> ?x8] : memref ; -> memref + B := \alpha \text{op}(A) + \beta B -#. **Identifiers:** Local identifiers in the expand shape are dynamic in the resulting memref type. +for vectors and matrices, where :math:`\text{op}(X)` is defined as - .. code:: +.. math:: - expand %0[1 -> %1 x ?] : memref ; -> memref - expand %0[1 -> %1 x ?] : memref ; -> memref - expand %0[1 -> %1 x %2] : memref ; -> memref - expand %0[1 -> 4 x %1] : memref ; -> memref + \text{op}(X) := \left\{ + \begin{array}{rcl} + X^T & \text{ if } & \text{transpose} = \text{".t"} \wedge \text{order}(X) = 2,\\ + X & \text{ else. } + \end{array} + \right. -#. **Stride:** A new stride entry is entered that follows the canonical stride computation. +If the atomic flag is set, B is updated atomically. - .. code:: +Operands +~~~~~~~~ - expand %0[0->4x8] : memref> ; -> memref> - expand %0[0->4x?] : memref> ; -> memref> - expand %0[0->?x4] : memref> ; -> memref> - expand %0[0->4x?] : memref> ; -> memref> +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 scalar-type :math:`\beta` +4 memref-type B +======= =========== ============== Restrictions -............ - -At most one mode in expand-shape must be dynamic. +~~~~~~~~~~~~ -The product of the expand shape must be the same as the mode size. -If one entry in the expand shape is dynamic then the other must evenly divide the mode size. +* :math:`\text{shape}(B) = \text{shape}(\text{op}(A))` +* :math:`\text{order}(B) = 0 \lor \text{order}(B) = 1 \lor \text{order}(B) = 2` +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A) \preceq \text{element_type}(B)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(B)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -Fuse ----- +Cumulative sum +.............. .. code:: abnf - fuse-instruction = "fuse" local-identifier "[" integer-constant "," integer-constant "]" ":" memref-type + instruction =/ "cumsum" [".atomic"] local-identifier "," local-identifier "," integer-constant "," + local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Replicated instruction.* -The fuse instruction returns a view on a tensor with two or more adjacent modes viewed as a single mode. +Computes the n-mode cumulative sum -Arguments -......... +.. math:: -The first argument must point to a value of memref type. -The fused modes are specified as the interval [from, to], where from is given -by the first integer and to is given by the second integer. -Counting starts from 0 so we have + B := \alpha A \times_{n} L_{s_n} + \beta B, + +where :math:`L_{s_n}` is the lower triangular matrix of ones of size :math:`s_n\times s_n` and +:math:`s_n` is the n-th entry of the shape vector of A. +In index notation, we have equivalently .. math:: - - 0 \leq from < to < order(memref) -The local identifier must have the memref type specified last. -The output type is a memref type according to the following rules: + B_{i_1\dots i_{n-1}ji_{n+1}\dots i_M} + := \alpha \sum_{i_n=1}^{j}A_{i_1\dots i_{n-1}i_ni_{n+1}\dots i_M} + + \beta B_{i_1\dots i_{n-1}ji_{n+1}\dots i_M}, -#. **Shape:** The mode size of the fused modes is the product of the mode sizes. If one mode is dynamic the fused mode size is dynamic. +If the atomic flag is set, B is updated atomically. - .. code:: - fuse %0[1,3] : memref ; -> memref - fuse %0[1,3] : memref> ; -> memref> +Operands +~~~~~~~~ -#. **Stride:** Strides remain unchanged. +======= ================ ================== +Op.-No. Type Description +======= ================ ================== +1 scalar-type :math:`\alpha` +2 memref-type A +3 integer-constant n (summation mode) +4 scalar-type :math:`\beta` +5 memref-type B +======= ================ ================== - .. code:: +Restrictions +~~~~~~~~~~~~ - fuse %0[1,2] : memref> ; -> memref> - fuse %0[0,1] : memref> ; -> memref> +* :math:`\text{order}(A) \geq 1` +* :math:`\text{shape}(A) = \text{shape}(B)` +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A) \preceq \text{element_type}(B)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(B)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -Restrictions -............ +Foreach +....... -Let i be the first mode and j the last mode. -The stride vector S and the shape vector s must satisify the following compatibility condition: +.. code:: abnf -:math:`\forall k \in [i,j): S_{k}s_{k} = S_{k+1}` + instruction =/ "foreach" "(" local-identifier-list ")" [":" integer-type] "=" + "(" local-identifier-list ")" "," "(" local-identifier-list ")" region -If S(i:j) and s(i:j) are known at compile time, the fuse instruction is illegal if the compatibility -condition is not satisfied. -If a single entry in S(i:j) or s(i:j) is dynamic, then fusing modes that violate the compatbility condition -is undefined beheaviour. +Overview +~~~~~~~~ -.. code:: +A foreach loop that executes the loop's range without any sequence guarantee. +The region of a foreach is a *spmd region*. + +The three local identifier lists define the loop range and the local identifiers that +make the trip count available within the loop body. +All three lists must have the same length and have the following format: - fuse %0[0,1] : memref> ; Illegal, modes cannot be fused - fuse %0[0,1] : memref> ; Undefined behaviour if dynamic stride != 8 +.. math:: + (\text{var}_1, \dots, \text{var}_N) = (\text{from}_1, \dots, \text{from}_N), + (\text{to}_1, \dots, \text{to}_N), -Group id --------- +where :math:`N` is the common length of each of the three lists. +The loop range is defined as the cartesian product of the half-open intervals +:math:`[\text{from}_i; \text{to}_i)` such that the trip count take the values -.. code:: abnf +.. math:: - group-id-instruction = "group_id" + (\text{var}_1, \dots, \text{var}_N) \in [\text{from}_1; \text{to}_1) \times \dots \times + [\text{from}_N; \text{to}_N) -Overview -........ +The integer type of the loop variable and the loop bounds can be optionally set after the colon. +The default integer type is ``index``. -*Replicated instruction.* -Returns the group id, an integer of type "index" inbetween 0 and the group size - 1. +The mapping of trip count to work-item is implementation-defined. -Group size ----------- +GEMM +.... .. code:: abnf - group-size-instruction = "group_size" + instruction =/ "gemm" transpose transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Replicated instruction.* -Returns the group size, an integer of type "index". +GEMM implements the well-known GEMM BLAS-3 operation. +.. math:: -Load ----- + C := \alpha \text{op}_1(A) \text{op}_2(B) + \beta C -.. code:: abnf +The functions :math:`\text{op}_1` and :math:`\text{op}_2` are defined as - load-instruction = "load" local-identifier "[" [index-list] "]" ":" memref-or-group-type - index-list = identifier-or-int-constant *("," identifier-or-int-constant) - identifier-or-int-constant = integer-constant / local-identifier - memref-or-group-type = memref-type / group-type +.. math:: -Overview -........ + \text{op}_i(X) := \left\{ + \begin{array}{rcl} + X^T & \text{ if } & \text{transpose}_i = \text{".t"},\\ + X & \text{ if } & \text{transpose}_i = \text{".n"}. + \end{array} + \right. -Load the element given by the index list from a memref or group. -The number of indices must match the order of the memref -and a single index must be given for a group. +where transpose\ :sub:`1` and transpose\ :sub:`2` refer to the first and second transpose modifier, respectively. -Arguments -......... +If the atomic flag is set, C is updated atomically. -The first operand must have memref or group type. -The indices must be of ``index`` type. +Operands +~~~~~~~~ -Returns -....... +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 memref-type B +4 scalar-type :math:`\beta` +5 memref-type C +======= =========== ============== -A value of the memref's element type or the group's memref type. -Examples: +Restrictions +~~~~~~~~~~~~ -#. ``load %0[] : memref`` returns a ``f32`` value. -#. ``load %0[5, %1] : memref`` returns a ``f32`` value. -#. ``load %0[%1] : group>`` returns a ``memref`` value. -#. ``load %0[%1] : group, offset: ?>`` returns a ``memref`` value. +* :math:`\text{order}(A) = \text{order}(B) = \text{order}(C) = 2` +* :math:`\text{colums}(\text{op}_1(A)) = \text{rows}(\text{op}_2(B))` +* :math:`\text{rows}(C) = \text{rows}(\text{op}_1(A))` +* :math:`\text{columns}(C) = \text{columns}(\text{op}_2(B))` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(A), \text{element_type}(B)) \preceq \text{element_type}(C)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -Size ----- +GEMV +.... .. code:: abnf - size-instruction = "size" local-identifier "[" integer-constant "]" ":" memref-type + instruction =/ "gemv" transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Replicated instruction.* -The size instruction returns the i-th entry of the tensor's shape, where "i" is given by the integer -constant in square brackets. +GEMV implements the well-known GEMM BLAS-2 operation. -Arguments -......... +.. math:: -The first argument must point to a value of memref type. -The integer constant i gives the mode for which the size shall be returned. -It is required that + c := \alpha \text{op}_1(A) b + \beta c -.. math:: - - 0 \leq i < order(memref) +where :math:`\text{op}_1` is defined as in GEMM. -The local identifier must have the memref type specified last. -The instruction returns an integer of index type. +If the atomic flag is set, c is updated atomically. -Subview -------- +Operands +~~~~~~~~ -.. code:: abnf +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 memref-type b +4 scalar-type :math:`\beta` +5 memref-type c +======= =========== ============== - subview-instruction = "subview" local-identifier "[" [index-or-slice-list] "]" ":" memref-type - index-or-slice-list = index-or-slice *("," index-or-slice) - index-or-slice = identifier-or-int-constant [":" (identifier-or-int-constant / "?")] / ":" +Restrictions +~~~~~~~~~~~~ -Overview -........ +* :math:`\text{order}(A) = 2` +* :math:`\text{order}(b) = \text{order}(c) = 1` +* :math:`\text{colums}(\text{op}_1(A)) = \text{rows}(b)` +* :math:`\text{rows}(c) = \text{rows}(\text{op}_1(A))` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(A), \text{element_type}(b)) \preceq \text{element_type}(C)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -*Replicated instruction.* -The subview instruction returns a view on a tensor. +GER +... -Arguments -......... +.. code:: abnf -The first argument must point to a value of memref type. -The number of indices in square brackets must match the order of the memref. -The indices are either given as single index or as a slice, where -slices are given in offset plus size notation ("%offset : %size"). -E.g. the slice "%0 : %1" extracts a block of %1 elements beginning from %0, which is equivalent -to the index interval [%0, %0 + %1). + instruction =/ "ger" [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier -.. admonition:: Note +Overview +~~~~~~~~ - A slice is often defined as "%0 : %1" being the index interval [%0, %1). - However, then the compiler needs to figure out whether %1 - %0 is constant or not in order - to determine whether the mode size is known at compile-time or not. - Therefore, we prefer the offset plus size notation. +Computes the general rank-1 update: -A dynamic size ("?") means that the size is the mode size inferred from the memref type -minus the offset. -A plain colon is syntactic sugar for "0:?". +.. math:: -There is no run-time check whether the indices are within bounds. -Offset and size must be of index type. -Offset must be non-negative and size must be positive. + C := \alpha a b^T + \beta C -The local identifier must have the memref type specified last. -The output type is a memref type according to the following rules: +If the atomic flag is set, C is updated atomically. -#. **Invariant-stride:** The stride is not changed. +Operands +~~~~~~~~ - .. code:: +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type a +3 memref-type b +4 scalar-type :math:`\beta` +5 memref-type C +======= =========== ============== - subview %0[4:8,8:4] : memref ; Returns memref> +Restrictions +~~~~~~~~~~~~ +* :math:`\text{order}(a) = \text{order}(b) = 1` +* :math:`\text{order}(C) = 2` +* :math:`\text{rows}(C) = \text{rows}(a)` +* :math:`\text{columns}(C) = \text{rows}(b)` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(A), \text{element_type}(b)) \preceq \text{element_type}(C)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -#. **Rank-reduction:** A mode accessed by a single constant or value is removed from the output tensor. - .. code:: +Hadamard product +................ - subview %0[2:4, %1] : memref ; Returns memref> - subview %0[2:4, %1:1] : memref ; Returns memref> +.. code:: abnf -#. **Output-mode size:** The size of the output mode is determined by the size field of a slice - and may be dynamic. + instruction =/ "hadamard_product" [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier - .. code:: +Overview +~~~~~~~~ - subview %0[%1:4] : memref ; Returns memref - subview %0[%2:%2] : memref ; Returns memref - subview %0[2:4, %2:%2, 6:7] : memref ; Returns memref - subview %0[2:4, %2:%2, 6:7] : memref> ; Returns memref +Computes the Hadamard product of two vectors or two matrices. +That is, in index notation we have -#. **Dynamic size:** +.. math:: - .. code:: + c_{i} := \alpha a_{i} b_{i} + \beta c_{i} - subview %0[:] : memref ; Returns memref - subview %0[:] : memref ; Returns memref - subview %0[5:?] : memref ; Returns memref - subview %0[%2:?] : memref ; Returns memref +for vectors and -If --- +.. math:: -.. code:: abnf + C_{ij} := \alpha A_{ij} B_{ij} + \beta C_{ij} - if-instruction = "if" identifier-or-int-constant ["->" "(" scalar-type-list ")"] - region ["else" region] - type-list = scalar-type *("," scalar-type) +for matrices. If the atomic flag is set, c/C is updated atomically. -Overview -........ +Operands +~~~~~~~~ -An if statement. -Both regions are *mixed regions*. +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type a/A +3 memref-type b/B +4 scalar-type :math:`\beta` +5 memref-type c/C +======= =========== ============== -The condition must be of bool type. +Restrictions +~~~~~~~~~~~~ -Arguments -......... +* :math:`\text{order}(a) = \text{order}(b) = \text{order}(c) = o` with :math:`o\in\{1,2\}` +* :math:`\text{shape}(a) = \text{shape}(b) = \text{shape}(c)` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(A), \text{element_type}(b)) \preceq \text{element_type}(C)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -The if instruction may return multiple values, where the number of values and the value types -are given by the scalar-type-list. -If values are returned, the last instruction in both the "then"-region and the "else"-region must -be a yield instruction (the "else"-region cannot be omitted). +Parallel +........ -Example: +.. code:: abnf - .. code:: + instruction =/ "parallel" region - %1 = cmp.lt %0, 16 : i32 - %x = if %1 -> (i32) { - yield %0 : i32 - } else { - yield 16 : i32 - } +Overview +~~~~~~~~ -Axpby ------ +Opens an *spmd region*. + +Sum +... .. code:: abnf - transpose = ".t" / ".n" - const-or-val = floating-constant / local-identifier - axpby-instruction = "axpby" transpose [".atomic"] - const-or-val "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type + instruction =/ "sum" transpose [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Collective instruction.* -Axpby implements +Computes the matrix-vector product or the dot product of A with a vector of ones. +That is, if the result is a vector we have .. math:: - B := \alpha \text{op}(A) + \beta B + b := \alpha \text{op}(A) \vec{1} + \beta b, -for vectors and matrices. -If the atomic flag is set, B is updated atomically. +where :math:`\text{op}(A)` is defined as in the axpby instruction, +and if the result is a scalar we have -Arguments -......... +.. math:: -The first argument gives :math:`\alpha`, and the third argument gives :math:`\beta`. -The second and the fourth argument must have memref type and give A and B, respectively. + b := \alpha \left + \beta b -The transpose modifier defines :math:`\text{op}` as following: +If the atomic flag is set, b is updated atomically. -.. math:: - \text{op}_i(X) := \left\{ - \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i= t \wedge \text{order}(X) = 2,\\ - X & \text{ else. } - \end{array} - \right. +Operands +~~~~~~~~ -(Note that ".t" has no effect on vectors.) +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 scalar-type :math:`\beta` +4 memref-type b +======= =========== ============== -The shape of :math:`\text{op}(A)` and B must be identical and the order of A and B needs to be 1 (vector) -or 2 (matrix). +Restrictions +~~~~~~~~~~~~ +* :math:`\text{order}(b) = 1 \lor \text{order}(b) = 0` +* :math:`\text{order}(A) = \text{order}(b)+1` +* :math:`\text{rows}(b) = \text{rows}(\text{op}(A)) \text{ if } \text{order}(b) = 1` +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A) \preceq \text{element_type}(B)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(B)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -For ---- + +Mixed instructions +------------------ + +Arithmetic (binary) +................... .. code:: abnf - for-instruction = "for" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant - ["," identifier-or-int-constant] [":" integer-type] region + arith-binary-type = "arith.add" / + "arith.sub" / + "arith.mul" / + "arith.div" / + "arith.rem" / + "arith.min" / + "arith.max" / + "arith.shl" / + "arith.shr" / + "arith.and" / + "arith.or" / + "arith.xor" + value-instruction =/ arith-binary-type local-identifier "," local-identifier + ":" (boolean-type / scalar-type / coopmatrix-type) Overview -........ +~~~~~~~~ + +Binary arithmetic operation on scalars and cooperative matrices. +Both operands, as well as the returned type, have the same scalar or component type. +Arithmetic on cooperative matrices is done component-wise. + +The following table shows the operations' description and the types that are allowed for the operation. +The backslash "\\" is used to exclude types from the list of allowed types. + +=== ============================= ====================================================== +Op Allowed type Description +=== ============================= ====================================================== +add scalar-type Sum of operands +sub scalar-type Difference of operands +mul scalar-type Product of operands +div scalar-type Quotient of operands +rem scalar-type \\ complex-type Remainder from the division of operands +shl integer-type Left shift first operand by second operand +shr integer-type Arithmetic right shift first operand by second operand +and boolean-type / integer-type Bitwise and +or boolean-type / integer-type Bitwise or +xor boolean-type / integer-type Bitwise xor +min scalar-type \\ complex-type Minimum of operands +max scalar-type \\ complex-type Maximum of operands +=== ============================= ====================================================== -A for loop. -Instructions in the for loop execute sequentially and its region is a *mixed region*. +Arithmetic (unary) +.................. -The loop's range [from; to) is given by the first integer constant and second integer constant, -and the trip count is stored in the local identifier. -A step size can be given with the third integer constant. -The step size defaults to 1 if omitted. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. +.. code:: abnf -Foreach -------- + arith-unary-type = "arith.abs" / + "arith.neg" / + "arith.not" / + "arith.conj" / + "arith.im" / + "arith.re" + value-instruction =/ arith-unary-type local-identifier + ":" (scalar-type / coopmatrix-type) + +Overview +~~~~~~~~ + +Unary arithmetic operation on scalars and cooperative matrices. +For integer and floating point input, the operand must have the same type as the returned value. +For complex input, the returned value has the component floating point type +for ".abs", ".im", and ".re", and the returned value has the same type as the operand +for ".neg" and ".conj". +Arithmetic on cooperative matrices is done component-wise. + +The following table shows the operations' description and the types that are allowed for the operation. + +==== ============================= ============================= +Op Allowed type Description +==== ============================= ============================= +abs scalar-type Compute absolute value +neg scalar-type Negation +not boolean-type / integer-type Bitwise not +conj complex-type Complex conjugate +im complex-type Extract imaginary part +re complex-type Extract real part +==== ============================= ============================= + +Barrier +....... .. code:: abnf - foreach-instruction = "foreach" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant - [":" integer-type] region + instruction =/ "barrier" [".global"] [".local"] Overview -........ +~~~~~~~~ -A foreach loop that executes the loop's range [from; to) without any sequence guarantee. -The region of a foreach is a *spmd region*. +**Note:** Barriers are inserted automatically in collective regions, but not in SPMD regions. +Manual barrier insertion should only be only necessesary in SPMD regions. -The loop's range [from; to) is given by the first integer constant and second integer constant, -and the trip count is stored in the local identifier. -The integer type of the loop variable is given after the colon. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. -GEMM ----- +Control barrier. +The barrier must be encountered by all work-items. +A work-item in a work-group is not allowed to continue until all work-items in the work-group +have reached the barrier. + +Aditional memory fences are controlled by the following attributes: + +========= ====================================================================================== +Attribute Description +========= ====================================================================================== +.global Ensure that global memory accesses become visible to the work-group. +.local Ensure that local memory accesses become visible to the work-group. +========= ====================================================================================== + +Builtin (mixed) +............... .. code:: abnf - gemm-instruction = "gemm" transpose transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + mixed-builtin-type = "builtin.group_id.x" / + "builtin.group_id.y" / + "builtin.group_id.z" / + "builtin.num_groups.x" / + "builtin.num_groups.y" / + "builtin.num_groups.z" / + "builtin.num_subgroups.x" / + "builtin.num_subgroups.y" / + "builtin.subgroup_size" + value-instruction =/ mixed-builtin-type ":" integer-type Overview -........ +~~~~~~~~ -*Collective instruction.* -GEMM implements the well-known GEMM BLAS-3 operation. +Returns a builtin value. + +The group id is three dimensional; the mode is selected with the .x, .y, and .z suffix. +Each mode starts with zero and is limited by the corresponding num_groups mode. That is, .. math:: - C := \alpha \text{op}_1(A) \text{op}_2(B) + \beta C + \forall d \in \{x,y,z\} : 0 \leq \text{group_id}_d < \text{num_groups}_d -If the atomic flag is set, C is updated atomically. +The number of subgroups is two dimensional and is related to the work-group size as following: -Arguments -......... +.. math:: -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, B, and C, respectively. + \begin{aligned} + \text{num_subgroups}_x &= \frac{\text{work_group_size[0]}}{\text{subgroup_size}} \\ + \text{num_subgroups}_y &= \text{work_group_size[1]} + \end{aligned} -The first transpose modifier defines :math:`\text{op}_1` and the second transpose modifier -defines :math:`\text{op}_2` as following: +The following table shows the builtins' description and the types that are returned. -.. math:: +=================== ===== ====================== ==================================================== +Builtin Type OpenCL analogue Description +=================== ===== ====================== ==================================================== +group_id.(x/y/z) index get_group_id Returns the x, y, or z mode of the group id +num_groups.(x/y/z) index get_num_groups Returns number of groups in the x, y, or z mode +num_subgroups.(x/y) i32 N/A Returns the number of subgroups in the x or y mode +subgroup_size i32 get_max_sub_group_size Returns the subgroup size +=================== ===== ====================== ==================================================== - \text{op}_i(X) := \left\{ - \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i = t,\\ - X & \text{ if } & \text{modifier}_i = n. - \end{array} - \right. +Cast +.... +.. code:: abnf -If :math:`\text{op}_1(A)` has the shape MxK and -:math:`\text{op}_2(B)` has the shape KxN then C must have the shape MxN. + value-instruction =/ "cast" local-identifier ":" scalar-type + value-instruction =/ "cast" local-identifier ":" coopmatrix-type -GEMV ----- +Overview +~~~~~~~~ + +Cast scalar values or cooperative matrices to type indicated after the colon. + +The source type must be a coopmatrix type if the destination type is a coopmatrix type, +and the shapes must match. +The coopmatrix use must either match, or +the use of the source type must be matrix_acc and the use of the destination type +must be matrix_a or matrix_b. + +Casts from complex types to non-complex types are forbidden. +The following table summarizes the casts and the mapping to SPIR-V +(the casts are done component-wise for coopmatrix types): + +============= ============= ================================================== +Operand type Result type SPIR-V Op +============= ============= ================================================== +integer-type integer-type OpSConvert +floating-type floating-type OpFConvert +complex-type complex-type OpFConvert (on vector2) +integer-type floating-type OpConvertSToF +floating-type integer-type OpConvertFToS +floating-type complex-type OpFConvert on real part, imaginary part is zero +integer-type complex-type OpConvertSToF on real part, imaginary part is zero +complex-type integer-type Forbidden +complex-type floating-type Forbidden +============= ============= ================================================== + +Comparison +.......... .. code:: abnf - gemv-instruction = "gemm" transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + cmp-type = "cmp.eq" / + "cmp.ne" / + "cmp.gt" / + "cmp.ge" / + "cmp.lt" / + "cmp.le" + value-instruction =/ cmp-type local-identifier "," local-identifier ":" "bool" Overview +~~~~~~~~ + +Scalar comparison. +Both operands must have the same scalar type and the returned value has boolean type. + +The following table shows the comparisons' description and the types that are allowed for the comparison. +The backslash "\\" is used to exclude types from the list of allowed types. + +==== =========================== ===================== +Cond Allowed type Description +==== =========================== ===================== +eq scalar-type Equal +ne scalar-type Not equal +gt scalar-type \\ complex-type Greater than +ge scalar-type \\ complex-type Greater than or equal +lt scalar-type \\ complex-type Less than +le scalar-type \\ complex-type Less than or equal +==== =========================== ===================== + +Constant ........ -*Collective instruction.* -GEMV implements the well-known GEMM BLAS-2 operation. +.. code:: abnf -.. math:: + value-instruction =/ "constant" constant ":" (boolean-type / scalar-type / coopmatrix-type) - c := \alpha \text{op}_1(A) b + \beta C +Overview +~~~~~~~~ -If the atomic flag is set, c is updated atomically. +Sets the result value to a constant value. +The type of the constant must match the scalar or component type +(e.g. an integer type requires an integer-constant and a floating type requires a floating-constant). + +When the result is a cooperative matrix, all entries are set to the same constant value. + +Expand +...... + +.. code:: abnf + + value-instruction =/ "expand" local-identifier "[" integer-constant "->" expand-shape "]" ":" memref-type + expand-shape = integer-constant-or-identifier 1*("x" integer-constant-or-identifier) + integer-constant-or-identifier = integer-constant / local-identifier + +Overview +~~~~~~~~ + +The expand instruction returns a view on a tensor with a mode viewed as higher-order mode. + +Operands +~~~~~~~~ + +The first argument must point to a value of memref type. +The first integer constant before "->" gives the mode that shall be expanded. +The expand shape coming after "->" gives the new shape of the mode. +Dynamic values in the expand shape must have `index` type. + +Restrictions +~~~~~~~~~~~~ + +The memref type of the result must conform with the following rules: + +#. Element type and address space must match the operand's memref type. +#. **Shape:** The mode size is replaced with the expand shape. + The product of the expand shape must equal the size of the expanded mode. + + .. code:: + + expand %0[1 -> 2x8] : memref ; %0: memref + expand %0[1 -> 2x2x2x2] : memref ; %0: memref + +#. **Identifiers:** Local identifiers in the expand shape are dynamic in the resulting memref type. + The product of the dynamic expand shape must equal the size of the expanded mode. + + .. code:: + + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> 2 x %1] : memref ; %0: memref + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> %1 x %2 x 2] : memref ; %0: memref + expand %0[1 -> %2 x 2 x %1] : memref ; %0: memref + expand %0[1 -> %1 x %2] : memref ; %0: memref + expand %0[1 -> %1 x %2] : memref ; %0: memref + + *Note:* In the third example above, %1 must be equal to 8. + The output mode corresponding to %1 is still dynamic. + +#. **Stride:** A new stride entry is entered that follows the canonical stride computation. + It is also permissible to put '?' for a stride instead of the constant value. + + .. code:: + + expand %0[0->4 x 8] : memref> ; %0: memref> + expand %0[0->4 x 8] : memref> ; %0: memref> + expand %0[0->%1 x 4] : memref> ; %0: memref> + expand %0[0->4 x %1] : memref> ; %0: memref> + expand %0[0->4 x %1] : memref> ; %0: memref> + +Further restrictions: + +* The product of the expand shape must be the same as the mode size. +* If the product of the expand shape is only known at runtime, then it is undefined behaviour + if the dynamic product does not match the mode size. + +For +... + +.. code:: abnf + + multi-value-instruction = "for" local-identifier [":" integer-type] "=" + local-identifier "," local-identifier ["," local-identifier] + ["init" "(" init-value-list ")" "->" "(" return-type-list ")" ] region + [dictionary-attribute] + init-value-list = init-value *("," init-value) + init-value = local-identifier "=" local-identifier + return-type-list = return-type *("," return-type) + return-type = boolean-type / scalar-type / coopmatrix-type + + +Overview +~~~~~~~~ + +A for loop. +Instructions in the for loop execute sequentially and its region is a *mixed region*. Arguments -......... +~~~~~~~~~ + +The trip count is stored in the first local identifier and is accessible within the loop body. +The loop's range [from; to) is given by the first and the second local identifier after the equals sign, +and a step size may be given with the third local identifier after the equals sign. +The step size defaults to 1 if omitted. +The integer type of the loop variable and the loop bounds is given after the colon and +the default integer type is ``index``. + +Values that are given in the init-value-list may be carried from one iteration to the next. +The local identifier gives the name of the loop-carried value as it is accessible in the loop body. +The local identifier given on the right-hand side of the init-value expression determines +the initial value of the loop-carried value, and its type must coincide with the scalar-type-list. +When loop-carried values are present, the loop's last instruction must be a yield instruction that +updates the loop-carried values for the next iteration. +The number and types of the yielded values must correspond the scalar-type-list. -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, b, and c, respectively. +Returns +~~~~~~~ -The transpose modifier for A as in GEMM. +The final value of the loop-carried values are returned by the for instruction. -:math:`\text{op}_1(A)` has the shape MxK and :math:`B` has the shape K then c must have the shape M. -GER ---- +Example: + +.. code:: + + %from = constant 2 : i32 + %to = constant 6 : i32 + %f0 = constant 0 : i64 + %f1 = constant 1 : i64 + %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { + %fn = arith.add %fn_2, %fn_1 : i64 + yield (%fn_1, %fn) + } + ; %fn_1 contains the fourth Fibonacci number and %fn the fifth Fibonacci number + +Attributes +~~~~~~~~~~ + +The following named attributes may be passed in the attribute dictionary: + +.. list-table:: + + * - Name + - Type + - Description + * - unroll + - boolean-attribute or integer-attribute + - true: request to unroll loop, false: request to not unroll loop, integer: partial unroll count + +Fuse +.... .. code:: abnf - ger-instruction = "ger" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + value-instruction =/ "fuse" local-identifier "[" integer-constant "," integer-constant "]" + ":" memref-type Overview -........ +~~~~~~~~ -Computes the general rank-1 update: +The fuse instruction returns a view on a tensor with two or more adjacent modes viewed as a single mode. -.. math:: +Fused modes are specified as the interval [from, to], where counting starts from 0. +From and to must refer to existing modes, that is, we require :math:`0 \leq \text{from} < \text{to} < \text{order}(\text{tensor})`. +Moreover, the stride vector S and the shape vector s must satisify the following compatibility condition: - C := \alpha a b^T + \beta C +:math:`\forall k \in [\text{from},\text{to}): S_{k}s_{k} = S_{k+1}` -If the atomic flag is set, C is updated atomically. +If S(i:j) and s(i:j) are known at compile time, the fuse instruction is illegal if the compatibility +condition is not satisfied. +If a single entry in S(i:j) or s(i:j) is dynamic, then fusing modes that violate the compatbility condition +is undefined beheaviour, e.g. -Arguments -......... +.. code:: -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and C, respectively. + ; Illegal, modes cannot be fused + fuse %0[0,1] : memref ; %0: memref> + ; Undefined behaviour if dynamic stride != 8 + fuse %0[0,1] : memref> ; %0: memref> -a and b must be vectors. If the size of a is M and the size of b is N the shape of C must be :math:`M\times N`. +Operands +~~~~~~~~ +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 memref-type tensor +2 integer-constant from +3 integer-constant to +======= ================ =========== -Hadamard product ----------------- +Restrictions +~~~~~~~~~~~~ + +The memref type of the result must conform with the following rules: + +#. Element type and address space must match the operand's memref type. +#. **Shape:** The mode size of the fused modes is the product of the mode sizes. If one mode is dynamic the fused mode size is dynamic. + + .. code:: + + fuse %0[1,3] : memref ; %0: memref + fuse %0[1,3] : memref> ; %0: memref> + +#. **Stride:** Strides remain unchanged or are replaced by '?'. + + .. code:: + + fuse %0[1,2] : memref> ; %0: memref> + fuse %0[1,2] : memref> ; %0: memref> + fuse %0[0,1] : memref> ; %0: memref> + +If +.. .. code:: abnf - hadamard-product-instruction = "hadamard_product" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + multi-value-instruction =/ "if" local-identifier ["->" "(" return-type-list ")"] + region ["else" region] Overview -........ +~~~~~~~~ -*Collective instruction.* -Computes the Hadamard product of two tensors. -That is, in index notation we have +An if statement. +Both regions are *mixed regions*. -.. math:: +The condition (first operand) must have boolean type. - c_{i} := \alpha a_{i} b_{i} + \beta c_{i} +Returns +~~~~~~~ -If the atomic flag is set, c is updated atomically. +The if instruction may return multiple values, where the number of values and the value types +are given by the return-type-list. +If values are returned, the last instruction in both the "then"-region and the "else"-region must +be a yield instruction (the "else"-region cannot be omitted). -Arguments -......... +Example: + + .. code:: + + %1 = cmp.lt %0, 16 : i32 + %x = if %1 -> (i32) { + yield (%0) + } else { + %c16 = constant 16 : i32 + yield (%c16) + } + + +Load +.... + +.. code:: abnf + + value-instruction =/ "load" local-identifier "[" [local-identifier-list] "]" + ":" scalar-or-memref-type + scalar-or-memref-type = scalar-type / memref-type + +Overview +~~~~~~~~ + +Load the element given by the index list from a memref or group. +The number of indices must match the order of the memref +and a single index must be given for a group. + +Operands +~~~~~~~~~ + +======= ======================== =========== +Op.-No. Type Description +======= ======================== =========== +1 memref-type / group-type tensor +2... index index list +======= ======================== =========== + +Returns +~~~~~~~ + +A value of the memref's element type or the group's memref type. +Examples: + +#. ``load %0[] : f32 ; %0: memref`` +#. ``load %0[5, %1] : f32 ; %0: memref`` +#. ``load %0[%1] : memref ; %0: groupx?>`` +#. ``load %0[%1] : memref ; %0: groupx?, offset: ?>`` + +Math (unary) +............ + +.. code:: abnf + + math-unary-type = "math.exp" / + "math.exp2" / + "math.native_exp" / + "math.native_exp2" + value-instruction =/ math-unary-type local-identifier ":" scalar-type + +Overview +~~~~~~~~ + +Unary math operation on scalars. +The operand must have the same type as the returned value. + +The following table shows the operations' description and the types that are allowed for the operation. + +=========== ============================= ===================================================================== +Op Allowed type Description +=========== ============================= ===================================================================== +cos floating-type Compute cosine function +sin floating-type Compute sine function +exp floating-type / complex-type Compute base-e exponential function +exp2 floating-type / complex-type Compute base-2 exponential function +native_cos floating-type Compute cosine function with implementation-defined error +native_sin floating-type Compute sine function with implementation-defined error +native_exp floating-type / complex-type Compute base-e exponential function with implementation-defined error +native_exp2 floating-type / complex-type Compute base-2 exponential function with implementation-defined error +=========== ============================= ===================================================================== + +.. _size instruction: + +Size +.... + +.. code:: abnf + + value-instruction =/ "size" local-identifier "[" integer-constant "]" ":" "index" + +Overview +~~~~~~~~ + +The size instruction returns the i-th entry of the tensor's shape, where "i" is given by the integer +constant in square brackets. +"i" must be in bounds, i.e. :math:`0 \leq i < \text{order}(tensor)`. + +For group types, the group size is returned and "i" must be 0. + +Operands +~~~~~~~~~ + +======= ======================== =========== +Op.-No. Type Description +======= ======================== =========== +1 memref-type / group-type tensor +2 integer-constant mode index +======= ======================== =========== + +Subview +....... + +.. code:: abnf + + value-instruction =/ "subview" local-identifier "[" [index-or-slice-list] "]" + ":" memref-type + index-or-slice-list = index-or-slice *("," index-or-slice) + index-or-slice = integer-constant-or-identifier [":" integer-constant-or-identifier] + +Overview +~~~~~~~~ + +The subview instruction returns a view on a tensor. + +The first argument must point to a value of memref type. +The number of indices in square brackets must match the order of the memref type. +The indices are either given as single index or as a slice, where +slices are given in offset plus size notation ("%offset : %size"). +E.g. the slice "%0 : %1" extracts a block of %1 elements beginning from %0, which is equivalent +to the index interval [%0, %0 + %1). + +.. admonition:: Note + + A slice is often defined as "%0 : %1" being the index interval [%0, %1). + However, then the compiler needs to figure out whether %1 - %0 is constant or not in order + to determine whether the mode size is known at compile-time or not. + Therefore, we prefer the offset plus size notation. -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and c, respectively. +Zero sizes are used to encode that a rank-reduction is required, that is, +the rank of size 0 is removed from the output memref type. +A single index is syntactic sugar for offset plus size 0, e.g. %0 is syntactic sugar for %0:0. +(Note that a zero-size rank, e.g. in memref, is non-sense, because any multi-index passed +to the memref would be out-of-bounds. However, a one-sized rank, e.g. memref, might be desirable.) +A dynamic size of zero is undefined behaviour. -a, b, and c must be vectors and have equal shape. +There is no run-time check whether the indices are within bounds. +Offset and size must be of index type. +Offset must be non-negative and size must be positive. +Restrictions +~~~~~~~~~~~~ + +The memref type of the result must conform with the following rules: + +#. Element type and address space must match the operand's memref type. +#. **Invariant-stride:** The stride is not changed or replaced with '?'. + + .. code:: + + subview %0[4:8,8:4] : memref> ; %0: memref + subview %0[4:8,8:4] : memref> ; %0: memref + + +#. **Rank-reduction:** A mode accessed by offset only or a mode with size statically known to be 0 is removed from the output tensor. + + .. code:: + + subview %0[2:4, %1] : memref ; %0: memref + subview %0[2:4, %1:0] : memref ; %0: memref + subview %0[2:4, %1:1] : memref> ; %0: memref + +#. **Output-mode size:** The size of the output mode is determined by the size field of a slice + and may be dynamic. + + .. code:: + + subview %0[%1:4] : memref ; %0: memref + subview %0[%2:%2] : memref ; %0: memref + subview %0[2:4, %2:%2, 6:7] : memref ; %0: memref + subview %0[2:4, %2:%2, 6:7] : memref ; %0: memref> Store ------ +..... .. code:: abnf - store-instruction = "store" local-identifier "," local-identifier "[" [index-list] "]" ":" memref-type + instruction =/ "store" [store-flag] local-identifier "," + local-identifier "[" [local-identifier-list] "]" + store-flag = ".atomic" / ".atomic_add" / ".atomic_max" / ".atomic_min" Overview -........ +~~~~~~~~ -*Replicated instruction.* -Store a scalar value in a memref at the position given by the index list. +Store a scalar value (first operand) in a memref (second operand) at the position given by the index list. The number of indices must match the order of the memref. +The store is atomic when the atomic flag is set with relaxed memory ordering. +When the atomic_add/max/min flag is set, the following steps are done atomically: +The value at the memory location is fetched, the scalar value is added to the fetched value, +and the resulting value is stored at the memory location. + +When storing a complex value the update may be pseudo-atomic, meaning that an atomic store is used +for the the real and imaginary separately. + *Note:* Store should only be used in SPMD regions as otherwise the same memory location is written from all work-items. -Arguments -......... +Operands +~~~~~~~~ -The first operand must have the same scalar type as the memref type. -The indices must be of ``index`` type. +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 scalar-type value +2 memref-type tensor +3... index index list +======= ================ =========== -Sum ---- +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{type}(value) = \text{element_type}(tensor)` + +Yield +..... .. code:: abnf - sum-instruction = "sum" transpose [".atomic"] - "," const-or-val "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type + instruction =/ "yield" "(" [local-identifier-list] ")" Overview -........ +~~~~~~~~ -*Collective instruction.* -Computes the matrix-vector product or the dot product of A with a vector of ones. -That is, for matrices we have +Yield returns values from an if or for instruction. + +Operands +~~~~~~~~ + +======= ============================================ =========== +Op.-No. Type Description +======= ============================================ =========== +1... boolean-type / scalar-type / coopmatrix-type value +======= ============================================ =========== + +Additional instructions +....................... + +.. code:: abnf + + instruction =/ "lifetime_stop" local-identifier + +SPMD instructions +----------------- + +Builtin (SPMD) +.............. + +.. code:: abnf + + spmd-builtin-type = "builtin.subgroup_id.x" / + "builtin.subgroup_id.y" / + "builtin.subgroup_linear_id" / + "builtin.subgroup_local_id" + value-instruction =/ spmd-builtin-type ":" integer-type + +Overview +~~~~~~~~ + +Returns a builtin value. + +The subgroup id is two dimensional; the mode is selected with the .x and .y suffix. +Each mode starts with zero and is limited by the corresponding num_subgroups mode. That is, .. math:: - B := \alpha \text{op}(A) \vec{1} + \beta B + \forall d \in \{x,y\} : 0 \leq \text{subgroup_id}_d < \text{num_subgroups}_d -and for vectors we have +The subgroup linear id combines the x and y modes of the subgroup id as following: .. math:: - b := \alpha \left + \beta b + \text{subgroup_linear_id} = \text{subgroup_id}_x + \text{subgroup_id}_y\cdot \text{num_subgroups}_x -If the atomic flag is set, B is updated atomically. +The subgroup local id is the invocation id within the subgroup and ranges from 0 to subgroup_size-1. + +The following table shows the builtins' description and the types that are returned. + +=================== ===== ====================== ==================================================== +Builtin Type OpenCL analogue Description +=================== ===== ====================== ==================================================== +subgroup_id.(x/y) i32 N/A Returns the x or y mode of the subgroup id +subgroup_linear_id i32 get_sub_group_id Returns linear subgroup id +subgroup_local_id i32 get_sub_group_local_id Returns the local invocation id in the subgroup +=================== ===== ====================== ==================================================== + +Cooperative matrix apply +........................ + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_apply" + "(" local-identifier "," local-identifier "," local-identifier ")" + "=" local-identifier + "->" coopmatrix-type region + +Overview +~~~~~~~~ +Apply an action on every component of a coopmatrix and update the component with the result of the action. +The action is described in the *parallel region* of the instruction. Arguments -......... +~~~~~~~~~ -The first argument gives :math:`\alpha` and the third argument gives :math:`\beta`. -The second and the fourth argument must have memref type and give A and B, respectively. -If A is a matrix then B must be a vector. -The first mode size of :math:`\text{op}(A)` must match the size of B. -If A is a vector, then B must be a scalar memref. +The first three local identifier introduce SSA values for the row index, column index, and component value. +The row and columns values have i32 type and the component value has the same component type as the resulting +coopmatrix type. +The fourth identifer, after "in", gives the input coopmatrix, and its type must match the result type. -The transpose op is defined as in the axpby instruction. +The region must yield exactly one value whose scalar type is identical to the component type of the coopmatrix. -Yield ------ +Example: + +.. code:: + + %0 = ... ; contains a coopmatrix of type coopmatrix + %1 = cooperative_matrix_apply (%i,%j,%v)=%0 -> coopmatrix { + %mask = cmp.le %i, %j : bool + %exp_v_masked = if %mask -> (f32) { + %exp_v = math.native_exp %v : f32 + yield (%exp_v) + } else { + %zero = constant 0.0 : f32 + yield (%zero) + } + yield (%exp_v_masked) + } + ; The entries of %1 are given by %1[i,j] = exp(%0[i,j]) if i <= j else 0 + +Cooperative matrix extract +.......................... .. code:: abnf - yield-instruction = "yield" [local-identifier-list] ":" [scalar-type-list] - identifier-or-constant-list = identifier-or-constant *("," identifier-or-constant) + value-instruction =/ "cooperative_matrix_extract" + local-identifier "[" integer-constant "]" ":" scalar-type Overview -........ +~~~~~~~~ -Yield returns values from an if or for instruction. +Return an element of the coopmatrix's work-item vector. +The index is supplied in square brackets and must be greater or equal than zero +and smaller than the length of the work-item vector, cf. :ref:`coopmatrix layout`. -Arguments -......... +The scalar type of the returned value must match the component type of the coopmatrix. -The length of the local identifier list must equal the length of the scalar type list. +Operands +~~~~~~~~~ +======= ================ =========================== +Op.-No. Type Description +======= ================ =========================== +1 coopmatrix-type Cooperative matrix +2 integer-constant Index into work-item vector +======= ================ =========================== -Additional instructions ------------------------ +Cooperative matrix insert +......................... .. code:: abnf - barrier-instruction = "barrier" - lifetime-stop-instruction = "lifetime_stop" local-identifier + value-instruction =/ "cooperative_matrix_insert" local-identifier "," + local-identifier "[" integer-constant "]" ":" coopmatrix-type + +Overview +~~~~~~~~ + +Return a copy the coopmatrix, while modifying one entry of the coopmatrix. +The index is supplied in square brackets and must be greater or equal than zero +and smaller than the length of the work-item vector, cf. :ref:`coopmatrix layout`. + +The coopmatrix type of the returned value must match the coopmatrix type of the incoming matrix. +The scalar type of the inserted scalar must match the component type of the coopmatrix. + +Operands +~~~~~~~~~ + +======= ================ =========================== +Op.-No. Type Description +======= ================ =========================== +1 scalar-type Inserted scalar +2 coopmatrix-type Cooperative matrix +3 integer-constant Index into work-item vector +======= ================ =========================== + +Cooperative matrix load +....................... + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_load" transpose [checked-flag] + local-identifier "[" local-identifier "," local-identifier "]" + ":" coopmatrix-type + checked-flag = ".rows_checked" / ".cols_checked" / ".both_checked" + +Overview +~~~~~~~~ + +Load a cooperative matrix from a 2d-memref at the position given by the indices in square brackets. +The position gives the starting row and column index, that is, +when a coopmatrix of size :math:`X\times Y` is loaded from memref :math:`M` at +position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are given by + +.. math:: + + \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + i) S_1 + (y + j) S_2], + +where :math:`S_1` and :math:`S_2` are the entries of the memref's stride array. +When the transpose modifier ".t" is given, we have + +.. math:: + + \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + j) S_1 + (y + i) S_2] + +When the checked flag is set, the following out-of-bound checks are added +(with memref shape :math:`s_1\times s_2`): + +=============== =================================================================== +Flag Description +=============== =================================================================== +.n.rows_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq x+i < s_1 \text{ else } 0` +.t.rows_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq y+i < s_2 \text{ else } 0` +.n.cols_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq y+j < s_2 \text{ else } 0` +.t.cols_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq x+j < s_1 \text{ else } 0` +.n.both_checked .n.rows_checked.n and .n.cols_checked +.t.both_checked .t.rows_checked.t and .t.cols_checked +=============== =================================================================== + +Operands +~~~~~~~~ + +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 memref-type M +2 index x +3 index y +======= =============== =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{order}(M) = 2` +* :math:`\text{component_type}(A) = \text{element_type}(M)` +* All arguments **must** be dynamically uniform. + +Cooperative matrix mul add +.......................... + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_mul_add" local-identifier "," + local-identifier "," local-identifier ":" coopmatrix-type + +Overview +~~~~~~~~ + +Matrix mul add returns the value of + +.. math:: + + D := AB + C, + +where A, B, and C are matrices given by the three operands. + +The number of rows of matrix A,C, and D must be a multiple of the subgroup size. + +Operands +~~~~~~~~ + +======= =============== ========== =========== +Op.-No. Type Use Description +======= =============== ========== =========== +1 coopmatrix-type matrix_a A +2 coopmatrix-type matrix_b B +3 coopmatrix-type matrix_acc C +======= =============== ========== =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\forall X\in\{A,C,D\}: \text{rows}(X) \bmod \text{subgroup_size} = 0` +* :math:`\text{columns}(A) = \text{rows}(B)` +* :math:`\text{rows}(C) = \text{rows}(A) \land \text{columns}(C) = \text{columns}(B)` +* :math:`\text{shape}(D) = \text{shape}(C)` +* :math:`\text{use}(D) = \text{matrix_acc}` +* :math:`\text{promote}(\text{component_type}(A), \text{component_type}(B)) \preceq \text{component_type}(C)` +* Cast of :math:`\text{component_type}(C)` to :math:`\text{component_type}(D)` must be allowed + +Cooperative matrix prefetch +........................... + +.. code:: abnf + + instruction =/ "cooperative_matrix_prefetch" integer-constant "," + local-identifier "[" local-identifier "," local-identifier "]" "," + integer-constant "," integer-constant + +Overview +~~~~~~~~ + +Cooperatively prefetch memory into device cache. +The cache level is given by the first non-negative integer constant, where "0" is the cache closest the core +and core distance increases with increasing cache level. +The prefetch instruction is ignored if the cache level does not exist in the target device. +The position in square brackets gives the starting row and column index. +The last two positive integer constants give the size of the memory region to fetch (in rows by columns). +The following memory locations are prefetched: + +.. math:: + + \{\forall i \in [0,X), j \in [0,Y): M[(x + i) S_1 + (y + j) S_2]\} + +Prefetch is an optimization hint and may be disregarded by the compiler. + +Operands +~~~~~~~~ + +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 integer-constant Cache-level +2 memref-type M +3 index x +4 index y +5 integer-constant X +6 integer-constant Y +======= ================ =========== + +Restrictions +~~~~~~~~~~~~ + +* All arguments **must** be dynamically uniform. + +Cooperative matrix reduce +......................... + +.. code:: abnf + + coopmatrix-reduce-type = "cooperative_matrix_reduce.add.row" / + "cooperative_matrix_reduce.add.column" / + "cooperative_matrix_reduce.max.row" / + "cooperative_matrix_reduce.max.column" / + "cooperative_matrix_reduce.min.row" / + "cooperative_matrix_reduce.min.column" + value-instruction =/ coopmatrix-reduce-type local-identifier ":" coopmatrix-type + +Overview +~~~~~~~~ + +Computes the sum, maximum, or minimum over either the rows or columns of a coopmatrix. + +The component type and use of the the returned value's coopmatrix type +must match the component type and use of the incoming matrix. + +For a row reduction the resulting shape must be :math:`M\times 1` and for a column reduction +the resulting shape must be :math:`1\times N`, where the shape of the incoming matrix is :math:`M\times N`. + +Operands +~~~~~~~~~ + +======= ================ =========================== +Op.-No. Type Description +======= ================ =========================== +1 coopmatrix-type Incoming cooperative matrix +======= ================ =========================== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{rows}(A) \bmod \text{subgroup_size} = 0` + +Cooperative matrix scale +........................ + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_scale" local-identifier "," local-identifier + ":" coopmatrix-type + +Overview +~~~~~~~~ + +Scale a coopmatrix by a scalar. +The scalar type of the scalar and the component type of the coopmatrix must match, +and the returned must have the same coopmatrix type as the matrix operand. + +Operands +~~~~~~~~ + +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 scalar-type scalar +2 coopmatrix-type matrix +======= =============== =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{type}(scalar) = \text{component_type}(matrix)` +* :math:`\text{type}(result) = \text{type}(matrix)` + +Cooperative matrix store +........................ + +.. code:: abnf + + instruction =/ "cooperative_matrix_store" [checked-flag] [store-flag] local-identifier "," + local-identifier "[" local-identifier "," local-identifier "]" + +Overview +~~~~~~~~ + +Store a cooperative matrix value in a 2d-memref at the position given by the indices in square brackets. +The position gives the starting row and column index, that is, +when a coopmatrix of size :math:`X\times Y` is written to memref :math:`M` at +position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are written to + +.. math:: + + \forall i \in [0,X), j \in [0,Y): M[(x + i) S_1 + (y + j) S_2] := A_{ij}, + +where :math:`S_1` and :math:`S_2` are the entries of the memref's stride array. +When the checked flag is set, the following out-of-bound checks are added +(with memref shape :math:`s_1\times s_2`): + +============= ======================================================================================================= +Flag Description +============= ======================================================================================================= +.rows_checked Only execute store if :math:`0 \leq x+i < s_1` +.cols_checked Only execute store if :math:`0 \leq y+j < s_2` +.both_checked .rows_checked + .cols_checked +============= ======================================================================================================= + +The store is atomic when the atomic flag is set with relaxed memory ordering. +When the atomic_add flag is set, the coopmatrix is added to the memref atomically. + +When storing a complex value the update may be pseudo-atomic, meaning that an atomic store is used +for the the real and imaginary separately. + +Operands +~~~~~~~~ + +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 coopmatrix-type A +2 memref-type M +3 index x +4 index y +======= =============== =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{component_type}(A) = \text{element_type}(B)` +* All arguments **must** be dynamically uniform. + +Subgroup broadcast +.................. + +.. code:: abnf + + value-instruction =/ "subgroup_broadcast" local-identifier "," local-identifier ":" scalar-type + +Overview +~~~~~~~~ + +Broadcast a scalar to all work-items in the subgroup. +The scalar type of the first operand and the type of the result must match. +The second identifier must have i32 type. + +Operands +~~~~~~~~ + +======= =============== ================================================================================================== +Op.-No. Type Description +======= =============== ================================================================================================== +1 scalar-type Value that is to be distributed to all work-items of the sub-group +2 i32 Subgroup local index that identifies the work-item whose value is returned to all other work-items +======= =============== ================================================================================================== + +Restrictions +~~~~~~~~~~~~ + +* The second operand **must** be dynamically uniform. + +Subgroup operation +.................. + +.. code:: abnf + + subgroup-operation-type = "subgroup_operation.add.exclusive_scan" / + "subgroup_operation.add.inclusive_scan" / + "subgroup_operation.add.reduce" / + "subgroup_operation.max.exclusive_scan" / + "subgroup_operation.max.inclusive_scan" / + "subgroup_operation.max.reduce" / + "subgroup_operation.min.exclusive_scan" / + "subgroup_operation.min.inclusive_scan" / + "subgroup_operation.min.reduce" + value-instruction =/ subgroup-operation-type local-identifier ":" scalar-type + +Overview +~~~~~~~~ + +Let :math:`[x_0,x_1,\dots,x_{n-1}]` be the input vector contributed by a subgroup of size *n*. +(The work-item with subgroup local id *i* contributes :math:`x_i`.) +Let :math:`\diamond` be the binary operator and *I* the identity. +We define the output vector of size *n* for the group operations in the following table: + +============== ============================================================================================= +Operation type Result +============== ============================================================================================= +exclusive_scan :math:`[I, x_0, (x_0 \diamond x_1), \dots, x_0 \diamond x_1 \diamond \dots \diamond x_{n-2}]` +inclusive_scan :math:`[x_0, (x_0 \diamond x_1), \dots, x_0 \diamond x_1 \diamond \dots \diamond x_{n-1}]` +reduce :math:`[s,s,\dots,s] \text{ with } s := x_0 \diamond \dots \diamond x_{n-1}` +============== ============================================================================================= + +Add +~~~ + +Computes the subgroup operation with :math:`\diamond:=+` and :math:`I:=0`. + +Max +~~~ + +Computes the subgroup operation with :math:`\diamond:=\max` and identity as given in the following table: + +============= ============================================== +Identity Value +============= ============================================== +integer-type Smallest integer representable by integer type +floating-type :math:`-\infty` +complex type Forbidden +============= ============================================== + +Min +~~~ + +Computes the subgroup operation with :math:`\diamond:=\min` and identity as given in the following table: + +============= ============================================= +Identity Value +============= ============================================= +integer-type Largest integer representable by integer type +floating-type :math:`+\infty` +complex type Forbidden +============= ============================================= + Sample code =========== @@ -1057,16 +2182,16 @@ where B and C are constant matrices and A and D are matrix batches. .. code:: func @fused_kernel(%alpha: f32, - %A: group>, + %A: groupx?>, %B: memref, %C: memref, %D: memref) { - %0 = group_id - %1 = load %A[%0] : group> ; Returns memref - %2 = subview %D[:,:,%0] : memref ; Returns memref - %tmp0 = alloca -> memref - gemm.n.t 1.0, %1, %B, 0.0, %tmp0 - : f32, memref, memref, f32, memref - gemm.n.n %alpha, %tmp0, %C, 1.0, %2 - : f32, memref, memref, f32, memref + %0 = group_id : index + %1 = load %A[%0] : memref + %2 = subview %D[:,:,%0] : memref + %tmp0 = alloca : memref + %zero = constant 0.0 : f32 + %one = constant 1.0 : f32 + gemm.n.t %one, %1, %B, %zero, %tmp0 + gemm.n.n %alpha, %tmp0, %C, %one, %2 } diff --git a/docs/manual/tutorial_matrix_chain.rst b/docs/manual/tutorial_matrix_chain.rst index e9c552d6..768f8837 100644 --- a/docs/manual/tutorial_matrix_chain.rst +++ b/docs/manual/tutorial_matrix_chain.rst @@ -27,50 +27,57 @@ In the :ref:`tensor language ` we can implement the kernel as f func @fused_kernel(%K: memref, %P: memref, - %A: group>, + %A: groupx?>, %Q: memref) { - %gid = group_id ; Get our index e - - %p = subview %P[:,:,%gid] : memref ; %p has type memref - %a = load %A[%gid] : group> ; %a has type memref - %q = subview %Q[:,:,%gid] : memref ; %q has type memref - - %tmp = alloca -> memref ; Reserve temporary memory - - gemm.n.n 1.0, %K, %p, 0.0, %tmp ; Compute tmp <- K P(:,:,e) - : f32, memref, memref, f32, memref - gemm.n.n 1.0, %tmp, %a, 1.0, %q ; Update Q(:,:,e) <- Q(:,:,e) + tmp A_e - : f32, memref, memref, f32, memref + %gid = builtin.group_id : index ; Get our index e + + %p = subview %P[0:56,0:9,%gid] : memref ; Get view on submatrix + %a = load %A[%gid] : memref ; Load matrix from group + %q = subview %Q[0:56,0:9,%gid] : memref ; Get view on submatrix + + %tmp = alloca : memref ; Reserve temporary memory + ; in the Shared Local Memory + %c0 = constant 0.0 : f32 + %c1 = constant 1.0 : f32 + gemm.n.n %c1, %K, %p, %c0, %tmp ; Compute tmp <- K P(:,:,e) + gemm.n.n %c1, %tmp, %a, %c1, %q ; Update Q(:,:,e) <- Q(:,:,e) + tmp A_e } -Compilation with the Tiny Tensor Compiler generates the following OpenCL-C code - -.. code-block:: c - - kernel - __attribute__((reqd_work_group_size(64,1,1))) - __attribute__((intel_reqd_sub_group_size(32))) - fused_kernel(global float *K, global float *P, uint P_shape2, global float *global *A, - global float *Q, uint Q_shape2) { - local uchar stack[2016] __attribute__((aligned(64))); - uint gid = get_global_id(2); - global float *p = P + 0ll * 1 + 0ll * 56 + gid * 504; - global float *a = *(A + gid); - global float *q = Q + 0ll * 1 + 0ll * 56 + gid * 504; - local float *tmp = (local float *)(stack + 0); - gemm_f32f32f32f32f32_An_Bn_M56_N9_K56_Astride1_56_Bstride1_56_Cstride1_56_alpha3ff0000000000000_beta0( - 56, 9, 56, 0x1p+0f, K, 1, 56, p, 1, 56, 0x0p+0f, tmp, 1, 56); - barrier(CLK_LOCAL_MEM_FENCE); - gemm_f32f32f32f32f32_An_Bn_M56_N9_K9_Astride1_56_Bstride1_9_Cstride1_56_alpha3ff0000000000000_beta3ff0000000000000( - 56, 9, 9, 0x1p+0f, tmp, 1, 56, a, 1, 9, 0x1p+0f, q, 1, 56); +Using the *tinytc-opt* tool we can run compiler passes on the code to get insight on what is happening under the hood. +For example, running the insert-lifetime-stop, insert-barrier, and work-group-size pass, + +.. code-block:: bash + + tinytc-opt -pinsert-lifetime-stop -pinsert-barrier -pwork-group-size test.ir + +we get + +.. code-block:: + :emphasize-lines: 4, 13, 15 + + func @fused_kernel(%K: memref, + %P: memref, + %A: groupx?>, + %Q: memref) attributes{subgroup_size=32, work_group_size=[64,1]} { + %gid = builtin.group_id : index + %p = subview %P[0:56,0:9,%gid] : memref + %a = load %A[%gid] : memref + %q = subview %Q[0:56,0:9,%gid] : memref + %tmp = alloca : memref + %c0 = constant 0x0p+0 : f32 + %c1 = constant 0x1p+0 : f32 + gemm.n.n %c1, %K, %p, %c0, %tmp + barrier.local + gemm.n.n %c1, %tmp, %a, %c1, %q + lifetime_stop %tmp } -where the definition of the generated GEMM functions have been omitted for brevity. We observe that -* a GEMM is processed in parallel by a work-group with 64 threads, -* temporary memory is mapped to shared local memory (local uchar stack), -* load and subview calls translate to simple pointer manipulation, +* the kernel is executed concurrently by 64 work-items, +* temporary memory is only needed until after the lifetime_stop instruction after the GEMM + (if multiple alloca's are present that do not overlap, that is, lifetime_stop for alloca #1 appears before alloca #2, + then Shared Local Memory is reused, reducing the total amount needed), * and that a barrier has been introduced between the GEMM calls to avoid data races. When using SYCL, we can run the kernel using the following pseudo-code: @@ -81,16 +88,15 @@ When using SYCL, we can run the kernel using the following pseudo-code: #include #include - auto source_ctx = tinytc::make_source_context(); + #include + + auto ctx = tinytc::make_compiler_context(); + ctx.set_error_reporter([](char const *what, const tinytc_location_t *, + void *) { std::cerr << what << std::endl; }, + nullptr); try { // Parse tensor program - auto prog = tinytc::parse_file("fused_kernel.ir", source_ctx); - - // JIT compile program - auto q = sycl::queue{}; - auto info = tinytc::make_core_info(q.get_device()); - auto bin = tinytc::compile_to_binary(std::move(prog), info, tinytc::bundle_format::native, - source_ctx); + auto prog = tinytc::parse_file("fused_kernel.ir", ctx); // Initialize tensors float* K = ...; @@ -98,20 +104,21 @@ When using SYCL, we can run the kernel using the following pseudo-code: float** A = ...; float* Q = ...; - auto bundle = tinytc::make_kernel_bundle(q.get_context(), q.get_device(), bin); + // JIT compile program + auto q = sycl::queue{}; + auto bundle = tinytc::make_kernel_bundle(q.get_context(), q.get_device(), prog); + auto kernel = tinytc::make_kernel(bundle, "fused_kernel"); auto exe_range = tinytc::get_execution_range(kernel, howmany); for (int timestep = 0; timestep < num_timesteps; ++timestep) { q.submit([&](sycl::handler &h) { - h.set_args(K, P, howmany, A, Q, howmany); - h.parallel_for(exec_range, kernel); + h.set_args(K, P, howmany, A, howmany, Q, howmany); + h.parallel_for(exe_range, kernel); }).wait(); } } catch (tinytc::status const& st) { std::cerr << "Error (" << static_cast(st) << "): " << tinytc::error_string(st) << std::endl; - std::cerr << "Error log:" << std::endl - << source_ctx.get_error_log() << std::endl; } catch (std::exception const &e) { std::cerr << e.what() << std::endl; } diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index d3f351a4..81cc0224 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) find_package(SYCL REQUIRED) -add_executable(tinytc-bench main.cpp args.cpp) +add_executable(tinytc-bench main.cpp) add_sycl_to_target(TARGET tinytc-bench SOURCES main.cpp) -target_link_libraries(tinytc-bench PRIVATE tinytc tinytc_sycl) +target_link_libraries(tinytc-bench PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(tinytc-bench) diff --git a/examples/benchmark/args.cpp b/examples/benchmark/args.cpp deleted file mode 100644 index 61030e6b..00000000 --- a/examples/benchmark/args.cpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" - -#include -#include -#include -#include -#include - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.internal_repetitions = 1; - a.transA = tinytc::transpose::N; - a.transB = tinytc::transpose::N; - a.beta = 0.0; - auto num = std::vector(3); - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error("==> Error: unrecognized argument " + - std::string(argv[i])); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else if (std::strcmp(argv[i], "--trans-a") == 0) { - a.transA = tinytc::transpose::T; - } else if (std::strcmp(argv[i], "--trans-b") == 0) { - a.transB = tinytc::transpose::T; - } else if (std::strcmp(argv[i], "-v") == 0 || std::strcmp(argv[i], "--verify") == 0) { - a.verify = true; - } else if (std::strcmp(argv[i], "-a") == 0 || std::strcmp(argv[i], "--atomic") == 0) { - a.atomic = true; - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-i") == 0 || - std::strcmp(argv[i], "--internal-reps") == 0) { - a.internal_repetitions = atoi(argv[++i]); - } else if (std::strcmp(argv[i], "-b") == 0 || std::strcmp(argv[i], "--beta") == 0) { - ++i; - a.beta = atof(argv[i]); - } else if (std::strcmp(argv[i], "-p") == 0 || - std::strcmp(argv[i], "--precision") == 0) { - ++i; - if (argv[i][0] == 'd') { - a.double_precision = true; - } else if (argv[i][0] == 's') { - a.double_precision = false; - } else { - fail(); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - num.clear(); - char const *delim = "x"; - auto arg = std::string(argv[i]); - char *token = std::strtok(argv[i], delim); - while (token) { - num.emplace_back(atoi(token)); - token = std::strtok(nullptr, delim); - } - if (num.size() != 3) { - throw std::runtime_error("==> Could not parse test case: " + arg); - } - a.tc.push_back({num[0], num[1], num[2]}); - } - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tinytcbench test-case1 test-case2 ..." << std::endl - << R"HELP( -positional arguments: - test-caseN MxNxK triplet (e.g. 64x64x64) - -optional arguments: - -h, --help Show help and quit - -i, --internal-reps Number of GEMM repetitions inside kernel (default: 1) - -p, --precision Precision (single = s, double = d) - --trans-a Transpose A matrix - --trans-b Transpose B matrix - -v, --verify Verify optimized implementation - -a, --atomic Update C atomically -)HELP"; -} diff --git a/examples/benchmark/args.hpp b/examples/benchmark/args.hpp deleted file mode 100644 index 23f69c65..00000000 --- a/examples/benchmark/args.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20230417_HPP -#define ARGS_20230417_HPP - -#include "tinytc/types.hpp" - -#include -#include -#include - -struct test_case { - std::int64_t m; - std::int64_t n; - std::int64_t k; -}; - -struct args { - std::vector tc; - int internal_repetitions; - bool double_precision; - bool help; - tinytc::transpose transA; - tinytc::transpose transB; - double beta; - bool verify; - bool atomic; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20230417_HPP diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index f62101a7..ecedb24c 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -1,15 +1,19 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" +#include "../gemm_common.hpp" +#include #include #include #include #include +#include #include +#include #include +#include #include #include #include @@ -21,6 +25,19 @@ using namespace sycl; using namespace tinytc; +struct args { + std::int32_t alignment = 0; + bool atomic = false; + bool dump = false; + std::int32_t internal_repetitions = 1; + bool trans_a = false; + bool trans_b = false; + scalar_type ty = scalar_type::f32; + bool update = false; + bool verify = false; + std::vector tc; +}; + template double bench(F f, int nrepeat = 10) { f(); double min_exec_time_ns = std::numeric_limits::max(); @@ -38,10 +55,13 @@ template double bench(F f, int nrepeat = 10) { auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose tB, bool atomic, std::int64_t M, std::int64_t N, std::int64_t K, std::array A_stride, - std::array B_stride, double beta, - std::array C_stride, - std::int32_t repetitions, queue q) -> source { - auto ctx = make_source_context(); + std::array B_stride, bool update, + std::array C_stride, std::int32_t alignment, + std::int32_t repetitions, bool dump, queue q) -> binary { + auto ctx = make_compiler_context(); + ctx.set_error_reporter( + [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; }, + nullptr); char const *file_name = std::source_location::current().file_name(); auto const source_id = ctx.add_source(file_name, ""); @@ -54,88 +74,112 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t ++l.end.column; return l; }; + auto const make_memref = [](data_type element_ty, transpose t, int64_t A, std::int64_t B, + std::array const &stride, location const &loc) { + auto s = std::array{A, B}; + if (t == transpose::T) { + std::swap(s[0], s[1]); + } + return get_memref(element_ty, s, stride, address_space::global, loc); + }; - auto kernel = [&](function_builder &fb) { - auto A = fb.argument( - make_group(make_memref( - ty, {M, K}, std::vector(A_stride.begin(), A_stride.end()), my_loc())), - "A", my_loc()); - auto B = fb.argument( - make_group(make_memref( - ty, {K, N}, std::vector(B_stride.begin(), B_stride.end()), my_loc())), - "B", my_loc()); - auto C = fb.argument( - make_group(make_memref( - ty, {M, N}, std::vector(C_stride.begin(), C_stride.end()), my_loc())), - "C", my_loc()); - fb.body( - [&](region_builder &bb) { - auto gid = bb.add(make_group_id(my_loc())); - auto a = bb.add(make_load(A, {gid}, my_loc())); - auto b = bb.add(make_load(B, {gid}, my_loc())); - auto c = bb.add(make_load(C, {gid}, my_loc())); - bb.for_loop( - scalar_type::index, make_index(0, my_loc()), make_index(repetitions, my_loc()), - [&](region_builder &bb) { - bb.add(make_gemm(tA, tB, atomic, make_imm(1.0, ty, my_loc()), a, b, - make_imm(beta, ty, my_loc()), c, my_loc())); - }, - "r", my_loc()); + auto kernel = [&](compiler_context const &ctx) { + auto index_ty = get_scalar(ctx, scalar_type::index); + auto element_ty = get_scalar(ctx, ty); + auto A_ty = make_memref(element_ty, tA, M, K, A_stride, my_loc()); + auto B_ty = make_memref(element_ty, tB, K, N, B_stride, my_loc()); + auto C_ty = make_memref(element_ty, transpose::N, M, N, C_stride, my_loc()); + auto f = + make_func("gemm", + {get_group(A_ty, dynamic, 0, my_loc()), get_group(B_ty, dynamic, 0, my_loc()), + get_group(C_ty, dynamic, 0, my_loc())}, + get_void(ctx), my_loc()); + if (alignment > 0) { + auto align_attr = get_dictionary_attr_with_sorted( + ctx, named_attr{get_string_attr(ctx, "align"), get_integer_attr(ctx, alignment)}); + f.set_parameter_attr(0, align_attr); + f.set_parameter_attr(1, align_attr); + f.set_parameter_attr(2, align_attr); + } + auto fn_body = f.get_body(); + auto params = std::array{}; + fn_body.get_parameters(params); + + auto bb = region_builder{fn_body}; + auto gid = bb.add(make_builtin(builtin::group_id_x, index_ty, my_loc())); + auto from = bb.add(make_constant_zero(index_ty, my_loc())); + auto to = bb.add(make_constant(repetitions, index_ty, my_loc())); + auto calpha = bb.add(make_constant_one(element_ty, my_loc())); + auto cbeta = bb.add(update ? make_constant_one(element_ty, my_loc()) + : make_constant_zero(element_ty, my_loc())); + auto a = bb.add(make_load(params[0], {gid}, A_ty, my_loc())); + auto b = bb.add(make_load(params[1], {gid}, B_ty, my_loc())); + auto c = bb.add(make_load(params[2], {gid}, C_ty, my_loc())); + bb.for_loop( + index_ty, from, to, + [&](region_builder &bb, value const &) { + bb.add(make_gemm(tA, tB, atomic, calpha, a, b, cbeta, c, my_loc())); }, - my_loc()); + nullptr, my_loc()); + + return f; }; try { - auto pb = program_builder{}; - pb.create("gemm", kernel, my_loc()); + auto p = make_prog(ctx, my_loc()); + p.add_function(kernel(ctx)); + if (dump) { + p.dump(); + } auto info = make_core_info(q.get_device()); info.set_core_features(tinytc_core_feature_flag_large_register_file); - return compile_to_opencl(pb.get_product(my_loc()), info, ctx); + return compile_to_spirv_and_assemble(std::move(p), info); } catch (builder_error const &e) { ctx.report_error(e.loc(), e.what()); - std::cerr << "Error (" << static_cast(e.code()) << "): " << std::endl - << ctx.get_error_log() << std::endl; + std::cerr << "Error (" << static_cast(e.code()) << "): " << error_string(e.code()) + << std::endl; } catch (status const &st) { - std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl - << "Error log:" << std::endl - << ctx.get_error_log() << std::endl; + std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; } - return source{nullptr}; + return binary{nullptr}; } template void test(queue q, args &a) { - auto const fill = [](T *ptr, std::size_t n) { - for (std::size_t i = 0; i < n; ++i) { - ptr[i] = i % 101; - } - }; auto total_reals = 1024 * 1024 * 1024 / sizeof(T); T *A_host = new T[total_reals]; T *B_host = new T[total_reals]; T *C_host = new T[total_reals]; - T *C_ref_host = new T[total_reals]; - T *C_ref = malloc_device(total_reals, q); - T *A = malloc_device(total_reals, q); - T *B = malloc_device(total_reals, q); - T *C = malloc_device(total_reals, q); - fill(A_host, total_reals); - fill(B_host, total_reals); - q.copy(A_host, A, total_reals).wait(); - q.copy(B_host, B, total_reals).wait(); + auto const alloc_device = [&a, &q](std::size_t num_bytes) { + if (a.alignment == 0) { + return malloc_device(num_bytes, q); + } else { + return aligned_alloc_device(a.alignment, num_bytes, q); + } + }; + T *A = alloc_device(total_reals); + T *B = alloc_device(total_reals); + T *C = alloc_device(total_reals); - auto const check = [&](std::int64_t M, std::int64_t N, std::size_t howmany) { - q.copy(C_ref, C_ref_host, total_reals).wait(); + auto const check = [&](std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t howmany) { q.copy(C, C_host, total_reals).wait(); std::size_t num_err = 0; - for (std::size_t i = 0; i < M * N * howmany; ++i) { - auto err = std::abs(C_host[i] - C_ref_host[i]); - if (err > 10.0 * std::numeric_limits::epsilon()) { - if (num_err < 10) { - std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] - << std::endl; + const auto error_bound = examples::test_gemm_error_bound(K); + for (std::int64_t b = 0; b < howmany; ++b) { + auto C_host_b = C_host + b * M * N; + for (std::int64_t j = 0; j < N; ++j) { + for (std::int64_t i = 0; i < M; ++i) { + const auto relerr = examples::test_gemm_rel_error(C_host_b, i, j, M); + if (relerr > error_bound) { + if (num_err < 10) { + std::cout + << "C_{" << i << "," << j << "," << b << "}=" << C_host_b[i + j * M] + << ", relative_error=" << relerr << ", error_bound=" << error_bound + << std::endl; + } + ++num_err; + } } - ++num_err; } } if (num_err > 10) { @@ -143,7 +187,6 @@ template void test(queue q, args &a) { } }; - auto const &type = typeid(T); for (auto &c : a.tc) { auto na = c.m * c.k; auto nb = c.k * c.n; @@ -151,29 +194,13 @@ template void test(queue q, args &a) { auto max_reals = std::max(std::max(na, nb), nc); auto howmany = total_reals / max_reals; - if (a.verify && a.internal_repetitions == 1) { - q.submit([&](auto &h) { - bool transa = a.transA == transpose::T; - bool transb = a.transB == transpose::T; - h.parallel_for(range{howmany, 32}, [=](id<2> it) { - auto batch = it[0]; - auto m = it[1]; - auto a = A + batch * na; - auto b = B + batch * nb; - auto c_ref = C_ref + batch * nc; - for (std::int64_t mb = m; mb < c.m; mb += 32) { - for (std::int64_t n = 0; n < c.n; ++n) { - auto c_acc = 0.0f; - for (std::int64_t k = 0; k < c.k; ++k) { - c_acc += a[transa ? k + mb * c.k : mb + k * c.m] * - b[transb ? n + k * c.n : k + n * c.k]; - } - c_ref[mb + n * c.m] = c_acc; - } - } - }); - }); + for (std::size_t i = 0; i < howmany; ++i) { + examples::test_gemm_matrix(A_host + i * na, c.m, c.k, a.trans_a); + examples::test_gemm_matrix(B_host + i * nb, c.k, c.n, a.trans_b); } + q.copy(A_host, A, total_reals).wait(); + q.copy(B_host, B, total_reals).wait(); + q.memset(C, 0, total_reals * sizeof(T)).wait(); T const **AA = malloc_shared(howmany, q); T const **BB = malloc_shared(howmany, q); @@ -187,36 +214,46 @@ template void test(queue q, args &a) { double min_exec_time_ns = 0.0; try { auto src = gemm_kernel_with_inner_repetition( - to_scalar_type_v, a.transA, a.transB, a.atomic, c.m, c.n, c.k, - {1, a.transA == transpose::T ? c.k : c.m}, - {1, a.transB == transpose::T ? c.n : c.k}, a.beta, {1, c.m}, a.internal_repetitions, - q); + a.ty, a.trans_a ? transpose::T : transpose::N, + a.trans_b ? transpose::T : transpose::N, a.atomic, c.m, c.n, c.k, + {1, a.trans_a ? c.k : c.m}, {1, a.trans_b ? c.n : c.k}, a.update, {1, c.m}, + a.alignment, a.internal_repetitions, a.dump, q); if (src) { auto bundle = make_kernel_bundle(q.get_context(), q.get_device(), src); auto kernel = make_kernel(bundle, "gemm"); - auto exe_range = get_execution_range(kernel, howmany); + auto exe_range = get_execution_range(kernel, sycl::range<3u>{1u, 1u, howmany}); q.submit([&](handler &h) { - h.set_args(AA, BB, CC); + h.set_args(AA, howmany, BB, howmany, CC, howmany); h.parallel_for(exe_range, kernel); }).wait(); if (a.internal_repetitions == 1 && a.verify) { - check(c.m, c.n, howmany); + check(c.m, c.n, c.k, howmany); } min_exec_time_ns = bench([&]() { q.submit([&](handler &h) { - h.set_args(AA, BB, CC); + h.set_args(AA, howmany, BB, howmany, CC, howmany); h.parallel_for(exe_range, kernel); }).wait(); }); - auto gflops = - a.internal_repetitions * 2 * c.m * c.n * c.k * howmany / min_exec_time_ns; + const auto ops_per_mnk = [&] { + switch (a.ty) { + case scalar_type::c32: + case scalar_type::c64: + return 8; + default: + return 2; + } + }(); + + auto gflops = a.internal_repetitions * ops_per_mnk * c.m * c.n * c.k * howmany / + min_exec_time_ns; auto roofline_gflops = std::min(512 * 32 * 1.6e9, a.internal_repetitions * 2 * c.m * c.n * c.k / (sizeof(T) * (na + nb + nc) / 1.1e12)) / 1e9; - std::cout << type.name() << "," << c.m << "," << c.n << "," << c.k << "," << howmany - << "," << min_exec_time_ns / 1e9 << "," << gflops << "," + std::cout << to_string(a.ty) << "," << c.m << "," << c.n << "," << c.k << "," + << howmany << "," << min_exec_time_ns / 1e9 << "," << gflops << "," << roofline_gflops << "," << std::round(gflops / roofline_gflops * 100) << "%," << a.internal_repetitions << std::endl; } @@ -234,24 +271,45 @@ template void test(queue q, args &a) { free(A, q); free(B, q); free(C, q); - free(C_ref, q); delete[] A_host; delete[] B_host; delete[] C_host; - delete[] C_ref_host; }; int main(int argc, char **argv) { auto a = args{}; + bool help = false; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); - } catch (std::runtime_error const &e) { + parser.set_short_opt('a', &a.atomic, "Update C atomically"); + parser.set_short_opt('d', &a.dump, "Dump IR to stdout"); + parser.set_short_opt('f', &a.ty, "Data type (f32, f64, c32, c64)") + .converter(examples::convert_data_type); + parser + .set_short_opt('i', &a.internal_repetitions, + "Number of GEMM repetitions inside kernel (default: 1)") + .validator([](std::int32_t rep) { return 0 <= rep; }); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('u', &a.update, + "Add A*B to C (beta=1) instead of overwriting C (beta=0)"); + parser.set_short_opt('v', &a.verify, "Verify optimized implementation"); + parser.set_long_opt("help", &help, "Show help"); + parser.set_long_opt("alignment", &a.alignment, "Memory alignment"); + parser.set_long_opt("transpose-a", &a.trans_a, "Transpose A matrix"); + parser.set_long_opt("transpose-b", &a.trans_b, "Transpose B matrix"); + parser.add_positional_arg("test-case", &a.tc, "MxNxK triplet (e.g. 64x64x64)") + .converter(examples::convert_test_case) + .validator(examples::validate_test_case); + + parser.parse(argc, argv); + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } - if (a.help || a.tc.empty()) { - arg_parser::show_help(std::cout); - return 0; + if (help || a.tc.empty()) { + parser.print_help(std::cout, "tinytc-bench", ""); + return !help ? -1 : 0; } auto q = queue{}; @@ -260,10 +318,27 @@ int main(int argc, char **argv) { "repetitions" << std::endl; try { - if (a.double_precision) { - test(std::move(q), a); - } else { + switch (a.ty) { + case scalar_type::bf16: + test(std::move(q), a); + break; + case scalar_type::f16: + test(std::move(q), a); + break; + case scalar_type::f32: test(std::move(q), a); + break; + case scalar_type::f64: + test(std::move(q), a); + break; + case scalar_type::c32: + test>(std::move(q), a); + break; + case scalar_type::c64: + test>(std::move(q), a); + break; + default: + return -1; } } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/examples/builder/main.c b/examples/builder/main.c index 49ad78d4..f0923d14 100644 --- a/examples/builder/main.c +++ b/examples/builder/main.c @@ -6,48 +6,66 @@ #include int main(void) { - tinytc_scalar_type_t type = tinytc_scalar_type_f32; + tinytc_scalar_type_t sty = tinytc_scalar_type_f32; int64_t M = 64; int64_t N = 32; - tinytc_data_type_t dt; + char const *copy_fun_name = "copy"; + uint32_t num_results; + uint32_t num_params; + tinytc_compiler_context_t ctx; + tinytc_prog_t program; + tinytc_data_type_t void_ty, element_ty, ty; + tinytc_func_t copy_fun; + tinytc_region_t copy_body; + tinytc_inst_t tmp; + tinytc_value_t params[2]; + tinytc_value_t alpha, beta; + + tinytc_compiler_context_create(&ctx); + + // Create program + tinytc_prog_create(&program, ctx, NULL); + + // Get types + tinytc_scalar_type_get(&element_ty, ctx, sty); int64_t shape[2] = {M, N}; - tinytc_memref_type_create(&dt, type, 2, shape, 0, NULL, NULL); - - tinytc_value_t A, B, alpha, beta; - tinytc_value_create(&A, dt, NULL); - tinytc_value_create(&B, dt, NULL); - tinytc_float_imm_create(&alpha, 1.0, type, NULL); - tinytc_float_imm_create(&beta, 0.0, type, NULL); - tinytc_data_type_release(dt); - - tinytc_inst_t copy_inst; - tinytc_axpby_inst_create(©_inst, tinytc_transpose_N, 0, alpha, A, beta, B, NULL); - tinytc_value_release(alpha); - tinytc_value_release(beta); - - tinytc_func_t copy_proto; - tinytc_value_t args[2] = {A, B}; - tinytc_function_prototype_create(©_proto, "copy", 2, args, NULL); - tinytc_value_release(A); - tinytc_value_release(B); + tinytc_memref_type_get(&ty, element_ty, 2, shape, 0, NULL, tinytc_address_space_global, NULL); - tinytc_region_t copy_body; - tinytc_region_create(©_body, 1, ©_inst, NULL); - tinytc_inst_release(copy_inst); + // Get void type + tinytc_void_type_get(&void_ty, ctx); - tinytc_func_t copy_fun; - tinytc_function_create(©_fun, copy_proto, copy_body, NULL); - tinytc_func_release(copy_proto); - tinytc_region_release(copy_body); + // Create function + tinytc_data_type_t param_types[2] = {ty, ty}; + tinytc_func_create(©_fun, sizeof(copy_fun_name) - 1, copy_fun_name, 2, param_types, void_ty, + NULL); + tinytc_prog_add_function(program, copy_fun); - tinytc_prog_t program; - tinytc_program_create(&program, 1, ©_fun, NULL); - tinytc_func_release(copy_fun); + // Get body + tinytc_func_get_body(copy_fun, ©_body); + num_params = 2; + tinytc_region_get_parameters(copy_body, &num_params, params); + + // Create instructions + tinytc_constant_inst_create_one(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &alpha); + tinytc_region_append(copy_body, tmp); + + tinytc_constant_inst_create_zero(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &beta); + tinytc_region_append(copy_body, tmp); + + tinytc_axpby_inst_create(&tmp, tinytc_transpose_N, 0, alpha, params[0], beta, params[1], NULL); + tinytc_region_append(copy_body, tmp); + // Dump program tinytc_prog_dump(program); + // Clean-up tinytc_prog_release(program); + tinytc_compiler_context_release(ctx); return 0; } diff --git a/examples/builder/main.cpp b/examples/builder/main.cpp index f506aeeb..5eb01e48 100644 --- a/examples/builder/main.cpp +++ b/examples/builder/main.cpp @@ -4,31 +4,38 @@ #include "tinytc/tinytc.hpp" #include "tinytc/types.hpp" +#include #include #include +#include using namespace tinytc; int main() { - scalar_type type = scalar_type::f32; + scalar_type sty = scalar_type::f32; int64_t M = 64; int64_t N = 32; try { - auto pb = program_builder{}; - pb.create("copy", [&](function_builder &fb) { - auto dt = make_memref(type, {M, N}); - auto A = fb.argument(dt); - auto B = fb.argument(dt); - fb.body([&](region_builder &bb) { - auto alpha = make_imm(1.0, type); - auto beta = make_imm(0.0, type); - bb.add(make_axpby(transpose::N, false, alpha, A, beta, B)); - }); - }); - auto program = pb.get_product(); - - program.dump(); + auto ctx = make_compiler_context(); + auto element_ty = get_scalar(ctx, sty); + auto ty = get_memref(element_ty, {M, N}); + + auto f = make_func("copy", {ty, ty}, get_void(ctx)); + + auto body = f.get_body(); + std::array params; + body.get_parameters(params); + + auto bb = region_builder{body}; + auto alpha = bb.add(make_constant_one(element_ty)); + auto beta = bb.add(make_constant_zero(element_ty)); + bb.add(make_axpby(transpose::N, false, alpha, params[0], beta, params[1])); + + auto p = make_prog(ctx); + p.add_function(std::move(f)); + + p.dump(); } catch (builder_error const &e) { std::cerr << "Error " << static_cast(e.code()) << std::endl; } catch (status const &st) { diff --git a/examples/gemm_common.hpp b/examples/gemm_common.hpp new file mode 100644 index 00000000..95e698e1 --- /dev/null +++ b/examples/gemm_common.hpp @@ -0,0 +1,139 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GEMM_COMMON_20241014_HPP +#define GEMM_COMMON_20241014_HPP + +#include "argparser.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::examples { + +struct test_case { + std::int64_t m; + std::int64_t n; + std::int64_t k; +}; + +inline auto convert_data_type(char const *str, scalar_type &val) -> cmd::parser_status { + if (std::strcmp(str, "bf16") == 0) { + val = scalar_type::bf16; + } else if (std::strcmp(str, "f16") == 0) { + val = scalar_type::f16; + } else if (std::strcmp(str, "f32") == 0) { + val = scalar_type::f32; + } else if (std::strcmp(str, "f64") == 0) { + val = scalar_type::f64; + } else if (std::strcmp(str, "c32") == 0) { + val = scalar_type::c32; + } else if (std::strcmp(str, "c64") == 0) { + val = scalar_type::c64; + } else { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; +}; +inline auto convert_test_case(char const *str, test_case &tc) -> cmd::parser_status { + auto const parse = [](std::int64_t *v, char const *str, char **end, char sep) { + *v = strtol(str, end, 10); + if (*v == 0 || **end != sep) { + throw cmd::parser_status::invalid_argument; + } + if (errno == ERANGE) { + throw cmd::parser_status::argument_out_of_range; + } + }; + char *end = nullptr; + try { + parse(&tc.m, str, &end, 'x'); + parse(&tc.n, end + 1, &end, 'x'); + parse(&tc.k, end + 1, &end, 0); + } catch (cmd::parser_status st) { + return st; + } + return cmd::parser_status::success; +} +inline auto validate_test_case(test_case const &tc) -> bool { + return tc.m > 0 && tc.n > 0 && tc.k > 0; +} + +template auto fabs(T x) { + if constexpr (std::is_same_v) { + return sycl::fabs(x); + } else { + return std::abs(x); + } +} + +template auto compute_error(T x, T x_ref) { + auto err = examples::fabs(x - x_ref); + const auto scale = examples::fabs(x_ref); + return scale > std::numeric_limits::epsilon() ? err / scale : err; +} + +// Increment values in bf16 epsilons +constexpr double test_gemm_smallest_eps = 0.0078125; + +template +void test_gemm_matrix(T *data, std::size_t M, std::size_t N, bool transposed = false) { + for (std::size_t j = 0; j < N; ++j) { + for (std::size_t i = 0; i < M; ++i) { + const auto idx = transposed ? j + i * N : i + j * M; + if constexpr (Use == matrix_use::a) { + data[idx] = (1.0 + i * test_gemm_smallest_eps) * (j + 1) / N; + } else if constexpr (Use == matrix_use::b) { + data[idx] = 1.0 / ((i + 1) * (1 + j * test_gemm_smallest_eps)); + } else { + data[idx] = 0; + } + } + } +} + +template struct is_complex : public std::false_type {}; +template struct is_complex> : public std::true_type {}; +template inline constexpr bool is_complex_v = is_complex::value; + +template struct is_lp_float : public std::false_type {}; +template +struct is_lp_float> : public std::true_type {}; +template inline constexpr bool is_lp_float_v = is_lp_float::value; + +template +auto test_gemm_rel_error(T *data, std::size_t i, std::size_t j, std::size_t M) -> double { + const double ref = (1 + i * test_gemm_smallest_eps) / (1 + j * test_gemm_smallest_eps); + if constexpr (is_complex_v) { + return std::abs(static_cast>(data[i + j * M]) - + std::complex{ref}) / + ref; + } else { + return std::abs(static_cast(data[i + j * M]) - ref) / ref; + } +} + +template auto test_gemm_error_bound(std::size_t K) { + const auto gamma = [](std::size_t K, double u) { return K * u / (1.0 - K * u); }; + if constexpr (is_lp_float_v) { + const double u = std::pow(2.0, -static_cast(T::lp_format::mantissa_bits)); + const double u_f32 = std::numeric_limits::epsilon(); + // Accumulation is done in single precision + return 2.0 * u + u * u + gamma(K, u_f32) * (1 + u) * (1 + u); + } else if constexpr (is_complex_v) { + return gamma(K, std::numeric_limits::epsilon()); + } else { + return gamma(K, std::numeric_limits::epsilon()); + } +} + +} // namespace tinytc::examples + +#endif // GEMM_COMMON_20241014_HPP diff --git a/examples/jit/main.cpp b/examples/jit/main.cpp index 14aee5a3..887c2535 100644 --- a/examples/jit/main.cpp +++ b/examples/jit/main.cpp @@ -15,18 +15,15 @@ int main(int argc, char **argv) { return -1; } - auto ctx = source_context{}; try { - ctx = make_source_context(); auto info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); - auto prog = parse_file(argv[1], ctx); + auto prog = parse_file(argv[1]); if (!prog) { return -1; } - compile_to_opencl(std::move(prog), info, ctx); + compile_to_spirv_and_assemble(std::move(prog), info); } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; - std::cerr << "Error log: " << std::endl << ctx.get_error_log() << std::endl; return 1; } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/examples/matrix_chain/CMakeLists.txt b/examples/matrix_chain/CMakeLists.txt index 7fa1358c..9f5e9eb0 100644 --- a/examples/matrix_chain/CMakeLists.txt +++ b/examples/matrix_chain/CMakeLists.txt @@ -14,5 +14,5 @@ set(SOURCES add_executable(matrix_chain ${SOURCES}) add_sycl_to_target(TARGET matrix_chain SOURCES ${SOURCES}) -target_link_libraries(matrix_chain PRIVATE tinytc tinytc_sycl) +target_link_libraries(matrix_chain PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(matrix_chain) diff --git a/examples/matrix_chain/main.cpp b/examples/matrix_chain/main.cpp index e8062846..1355b4e1 100644 --- a/examples/matrix_chain/main.cpp +++ b/examples/matrix_chain/main.cpp @@ -4,6 +4,7 @@ #include "test.hpp" #include "test_multi.hpp" +#include #include #include @@ -14,8 +15,58 @@ #include using namespace sycl; +using namespace tinytc; int main(int argc, char **argv) { + bool dump = false; + std::int64_t N = 5, P = 9, howmany; + std::size_t alignment = 0; + char precision = 's'; + test_case tc = test_case::volume; + bool help = false; + + auto parser = cmd::arg_parser{}; + try { + parser.set_short_opt('a', &alignment, "Alignment (in number of bytes)"); + parser.set_short_opt('d', &dump, "Dump IR to stdout"); + parser.set_short_opt('f', &precision, "Data type (s or d)").validator([](char f) { + return f == 's' || f == 'd'; + }); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('N', &N, "Polynomial degree").validator([](std::int64_t p) { + return p > 0; + }); + parser.set_short_opt('P', &P, "Number of quantities").validator([](std::int64_t n) { + return n > 0; + }); + parser.set_long_opt("help", &help, "Show help"); + parser.add_positional_arg("test_case", &tc, "Test case (volume or ader)", true) + .converter([](char const *str, test_case &val) -> cmd::parser_status { + if (strcmp(str, "volume") == 0) { + val = test_case::volume; + } else if (strcmp(str, "ader") == 0) { + val = test_case::ader; + } else { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; + }); + parser.add_positional_arg("howmany", &howmany, "Batch size", true) + .validator([](std::int64_t h) { return h > 0; }); + + parser.parse(argc, argv); + } catch (std::exception const &e) { + if (!help) { + std::cerr << e.what() << std::endl; + } + parser.print_help(std::cout, "matrix_chain", ""); + return help ? 0 : -1; + } + if (help) { + parser.print_help(std::cout, "matrix_chain", ""); + return 0; + } + auto devices = platform{}.get_devices(); auto sub_devices = std::vector{}; for (auto &device : devices) { @@ -33,41 +84,9 @@ int main(int argc, char **argv) { q.emplace_back(queue(device)); } - std::int64_t N = 5, P = 9, howmany; - std::size_t alignment = 0; - char precision = 's'; - test_case tc = test_case::volume; - - if (argc < 5) { - std::cerr << "Usage: matrix_chain

[alignment] [s/d]" - << std::endl; - return -1; - } - if (strcmp(argv[1], "volume") == 0) { - tc = test_case::volume; - } else if (strcmp(argv[1], "ader") == 0) { - tc = test_case::ader; - } else { - std::cerr << "Unknown test case " << argv[1] << ". Available are: ader, volume." - << std::endl; - return -1; - } - N = static_cast(std::atol(argv[2])); - P = static_cast(std::atol(argv[3])); - howmany = static_cast(std::atol(argv[4])); - if (argc >= 6) { - alignment = static_cast(std::atol(argv[5])); - } - if (argc >= 7) { - precision = argv[6][0]; - if (precision != 's' && precision != 'd') { - std::cerr << "Precision must be single (s) or double (d)" << std::endl; - return -1; - } - } auto run_test_multi = [&](auto precision) { using T = decltype(precision); - auto t = test_multi(N, P, howmany, alignment, tc, q); + auto t = test_multi(N, P, howmany, alignment, tc, q, dump); if (!t.check()) { std::cerr << "Result mismatch between reference and optimized!" << std::endl; // return; diff --git a/examples/matrix_chain/matrix_batch.hpp b/examples/matrix_chain/matrix_batch.hpp index f882b5b7..721f54e3 100644 --- a/examples/matrix_chain/matrix_batch.hpp +++ b/examples/matrix_chain/matrix_batch.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -16,30 +17,35 @@ template class matrix_batch { public: matrix_batch(std::int64_t nrows, std::int64_t ncols, std::int64_t ld, std::int64_t howmany, sycl::queue q) - : nrows_(nrows), ncols_(ncols), ld_(ld), howmany_(howmany), - data_(ld_ * ncols_ * howmany, std::move(q)) {} + : shape_{nrows, ncols}, ld_{ld}, howmany_{howmany}, + data_(stride() * howmany_, std::move(q)) {} inline T *get() { return data_.get(); } inline T const *get() const { return data_.get(); } - inline std::int64_t nrows() const { return nrows_; } - inline std::int64_t ncols() const { return ncols_; } - inline std::int64_t ld() const { return ld_; } + inline auto shape() const -> tinytc::array_view { return shape_; } + inline std::int64_t nrows() const { return shape_[0]; } + inline std::int64_t ncols() const { return shape_[1]; } inline std::int64_t howmany() const { return howmany_; } - inline std::int64_t stride() const { return ld_ * ncols_; } + inline std::int64_t ld() const { return ld_; } + inline std::int64_t stride() const { return ld_ * ncols(); } inline std::size_t size() const { return data_.size(); } inline void fill(T const &v) { data_.fill(v); } inline void random() { data_.random(); } - inline tinytc::data_type type(bool include_batch_dim = true) { - constexpr auto real_t = tinytc::to_scalar_type_v; - if (include_batch_dim && howmany() > 1) { - return tinytc::make_memref(real_t, {nrows(), ncols(), tinytc::dynamic}, - {1, ld(), stride()}); + inline auto type(tinytc::data_type element_ty) -> tinytc::data_type { + if (howmany_ == 1) { + return tinytc::get_memref(element_ty, {nrows(), ncols()}, {1, ld()}); } - return tinytc::make_memref(real_t, {nrows(), ncols()}, {1, ld()}); + return tinytc::get_memref(element_ty, {nrows(), ncols(), tinytc::dynamic}, + {1, ld(), stride()}); + } + inline auto local_type(tinytc::data_type element_ty) -> tinytc::data_type { + return tinytc::get_memref(element_ty, {nrows(), ncols()}, {1, ld()}, + tinytc::address_space::local); } private: - std::int64_t nrows_, ncols_, ld_, howmany_; + std::array shape_; + std::int64_t ld_, howmany_; device_array data_; }; diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 78192ff9..a80acdfd 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -5,8 +5,8 @@ #include #include -#include #include +#include #include using namespace sycl; @@ -14,14 +14,15 @@ using namespace tinytc; template test_ader::test_ader(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - queue q) + queue q, bool dump) : N_(N), P_(P), howmany_(howmany), alignment_(alignment), q_(std::move(q)), dev_info_(make_core_info(q_.get_device())), I_ref_(Bd(), P_, Bd_aligned(), howmany_, q_), I_opt_(Bd(), P_, Bd_aligned(), howmany_, q_), tmp_(Bd(), P_, Bd_aligned(N_ - 1), howmany_, q_), A_(dim, matrix_batch(P_, P_, P_, howmany_, q_)), K_(dim, matrix_batch(Bd(), Bd(), Bd_aligned(N_ - 1), 1, q_)), dQ_(make_dQ()), - opt_bundle_(make_optimized_kernel()), opt_kernel_(make_kernel(opt_bundle_, "ader_kernel")) { + opt_bundle_(make_optimized_kernel(dump)), + opt_kernel_(make_kernel(opt_bundle_, "ader_kernel")) { I_ref_.random(); I_opt_.random(); for (auto &a : A_) { @@ -58,63 +59,108 @@ template std::vector> test_ader::make_dQ() { } template -auto test_ader::make_optimized_kernel() -> sycl::kernel_bundle { +auto test_ader::make_optimized_kernel(bool dump) + -> sycl::kernel_bundle { constexpr auto real_t = to_scalar_type_v; - auto opt_kernel = [&](function_builder &fb) { - T dt = 1.01; - T num = T(1.0); - int denom = 1; - std::array A; - std::array K; + auto opt_kernel = [&](compiler_context const &ctx) { + auto element_ty = get_scalar(ctx, real_t); + std::array param_types; + param_types[0] = element_ty; for (std::size_t i = 0; i < dim; ++i) { - A[i] = fb.argument(A_[i].type(), "A"); + param_types[1 + i] = A_[i].type(element_ty); } for (std::size_t i = 0; i < dim; ++i) { - K[i] = fb.argument(K_[i].type(), "K"); + param_types[1 + dim + i] = K_[i].type(element_ty); + } + param_types[1 + 2 * dim + 0] = dQ_[0].type(element_ty); + param_types[1 + 2 * dim + 1] = I_opt_.type(element_ty); + + auto f = make_func("ader_kernel", param_types, get_void(ctx)); + auto fn_body = f.get_body(); + + std::array params; + fn_body.get_parameters(params); + + auto dt = params[0]; + dt.set_name("dt"); + auto A = [¶ms](std::size_t i) -> value & { return params[1 + i]; }; + auto K = [¶ms](std::size_t i) -> value & { return params[1 + dim + i]; }; + auto Q = params[1 + 2 * dim + 0]; + auto I = params[1 + 2 * dim + 1]; + for (std::size_t i = 0; i < dim; ++i) { + A(i).set_name((std::ostringstream{} << 'A' << i).str()); + K(i).set_name((std::ostringstream{} << 'K' << i).str()); + } + Q.set_name("Q"); + I.set_name("I"); + + auto bb = region_builder{fn_body}; + auto const c0 = bb.add(make_constant_zero(element_ty)); + auto const c1 = bb.add(make_constant_one(element_ty)); + auto const gid = + bb.add(make_builtin(builtin::group_id_x, get_scalar(ctx, scalar_type::index))); + auto const static_offsets3 = std::array{0, 0, dynamic}; + auto const static_sizes3 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols(), 0}; + }; + auto const static_sizes2 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols()}; + }; + auto const offsets3 = array_view(gid); + auto dqt = get_memref(element_ty, static_sizes2(dQ_[0]), {1, dynamic}); + auto dq = + bb.add(make_subview(Q, static_offsets3, static_sizes3(dQ_[0]), offsets3, {}, dqt)); + for (std::size_t d = 0; d < dim; ++d) { + auto At = get_memref(element_ty, static_sizes2(A_[d])); + A(d) = + bb.add(make_subview(A(d), static_offsets3, static_sizes3(A_[d]), offsets3, {}, At)); } - auto Q = fb.argument(dQ_[0].type(), "dQ"); - auto I = fb.argument(I_opt_.type(), "I"); - fb.body([&](region_builder &bb) { - auto const gid = bb.add(make_group_id()); - auto const offsets3 = std::vector{make_index(0), make_index(0), gid}; - auto const size3 = std::vector{make_dynamic(), make_dynamic(), value{}}; - auto dq = bb.add(make_subview(Q, offsets3, size3)); + auto it = get_memref(element_ty, static_sizes2(I_opt_), {1, dynamic}); + auto i = bb.add(make_subview(I, static_offsets3, static_sizes3(I_opt_), offsets3, {}, it)); + bb.add(make_axpby(transpose::N, false, c1, dq, c1, i)); + + int denom = 1; + auto cnum = c1; + auto const static_offsets2 = std::array{0, 0}; + for (std::int64_t n = 1; n <= N_; ++n) { + cnum = bb.add(make_arith(arithmetic::mul, cnum, dt, dt.get_type())); + denom *= n + 1; + auto cdenom = bb.add(make_constant(static_cast(denom), element_ty)); + auto cfactor = bb.add(make_arith(arithmetic::div, cnum, cdenom, cnum.get_type())); + auto bn = Bd_aligned(N_ - n); + auto dq_next = bb.add(make_alloca(dQ_[n].local_type(element_ty))); + auto dq_nextvt = get_memref(element_ty, {bn, P_}, {1, dynamic}, address_space::local); + auto dq_nextv = + bb.add(make_subview(dq_next, static_offsets2, {bn, P_}, {}, {}, dq_nextvt)); + auto tmp = bb.add( + make_alloca(get_memref(element_ty, {bn, P_}, {1, dynamic}, address_space::local))); for (std::size_t d = 0; d < dim; ++d) { - A[d] = bb.add(make_subview(A[d], offsets3, size3)); + auto Kvt = get_memref(element_ty, {bn, Bd(N_ - n + 1)}, {1, dynamic}); + auto Kv = + bb.add(make_subview(K(d), static_offsets2, {bn, Bd(N_ - n + 1)}, {}, {}, Kvt)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, Kv, dq, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, tmp, A(d), d > 0 ? c1 : c0, + dq_nextv)); } - auto i = bb.add(make_subview(I, offsets3, size3)); - bb.add(make_axpby(transpose::N, false, make_imm(num / denom), dq, make_imm(T(1.0)), i)); - - auto const offsets2 = std::vector{make_index(0), make_index(0)}; - for (std::int64_t n = 1; n <= N_; ++n) { - num *= dt; - denom *= n + 1; - auto bn = Bd_aligned(N_ - n); - auto dq_next = bb.add(make_alloca(dQ_[n].type(false))); - auto dq_nextv = - bb.add(make_subview(dq_next, offsets2, {make_index(bn), make_index(P_)})); - auto tmp = bb.add(make_alloca(make_memref(real_t, {bn, P_}, {1, bn}))); - for (std::size_t d = 0; d < dim; ++d) { - auto Kv = bb.add( - make_subview(K[d], offsets2, {make_index(bn), make_index(Bd(N_ - n + 1))})); - bb.add(make_gemm(transpose::N, transpose::N, false, make_imm(T(1.0)), Kv, dq, - make_imm(T(0.0)), tmp)); - bb.add(make_gemm(transpose::N, transpose::N, false, make_imm(T(1.0)), tmp, A[d], - make_imm(T(d > 0 ? 1.0 : 0.0)), dq_nextv)); - } - auto iv = - bb.add(make_subview(i, offsets2, {make_index(Bd(N_ - n)), make_index(P_)})); - bb.add(make_axpby(transpose::N, false, make_imm(num / denom), dq_next, - make_imm(T(1.0)), iv)); - dq = dq_next; - } - }); - }; - auto pb = program_builder{}; - pb.create("ader_kernel", opt_kernel); + auto ivt = get_memref(element_ty, {Bd(N_ - n), P_}, {1, dynamic}); + auto iv = bb.add(make_subview(i, static_offsets2, {Bd(N_ - n), P_}, {}, {}, ivt)); + bb.add(make_axpby(transpose::N, false, cfactor, dq_next, c1, iv)); + dq = dq_next; + } + return f; + }; + auto ctx = make_compiler_context(); + ctx.set_error_reporter( + [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; }, + nullptr); + auto p = make_prog(ctx); + p.add_function(opt_kernel(ctx)); + if (dump) { + p.dump(); + } return make_kernel_bundle(q_.get_context(), q_.get_device(), - compile_to_opencl(pb.get_product(), dev_info_)); + compile_to_spirv_and_assemble(p, dev_info_)); } template @@ -157,10 +203,13 @@ template std::vector test_ader::reference() { } template std::vector test_ader::optimized() { - auto exe_range = get_execution_range(opt_kernel_, howmany_); + T dt = 1.01; + auto exe_range = get_execution_range( + opt_kernel_, sycl::range<3u>{1u, 1u, static_cast(howmany_)}); return {q_.submit([&](handler &h) { - h.set_args(A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, K_[0].get(), - K_[1].get(), K_[2].get(), dQ_[0].get(), howmany_, I_opt_.get(), howmany_); + h.set_args(dt, A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, + K_[0].get(), K_[1].get(), K_[2].get(), dQ_[0].get(), howmany_, I_opt_.get(), + howmany_); h.parallel_for(exe_range, opt_kernel_); })}; } diff --git a/examples/matrix_chain/test_ader.hpp b/examples/matrix_chain/test_ader.hpp index 0195c975..84100f89 100644 --- a/examples/matrix_chain/test_ader.hpp +++ b/examples/matrix_chain/test_ader.hpp @@ -20,7 +20,7 @@ template class test_ader : public test { public: test_ader(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - sycl::queue q); + sycl::queue q, bool dump = false); ~test_ader() = default; test_ader(test_ader const &other) = delete; test_ader(test_ader &&other) = default; @@ -41,7 +41,7 @@ template class test_ader : public test { inline std::int64_t Bd_aligned() { return aligned(Bd(N_), alignment_); } inline std::int64_t Bd_aligned(std::int64_t N) { return aligned(Bd(N), alignment_); } std::vector> make_dQ(); - auto make_optimized_kernel() -> sycl::kernel_bundle; + auto make_optimized_kernel(bool dump) -> sycl::kernel_bundle; sycl::event taylor_sum(matrix_batch &I, matrix_batch &dQ, T factor, std::vector const &dep_events = {}); diff --git a/examples/matrix_chain/test_multi.cpp b/examples/matrix_chain/test_multi.cpp index c0d30308..34fbc50f 100644 --- a/examples/matrix_chain/test_multi.cpp +++ b/examples/matrix_chain/test_multi.cpp @@ -25,14 +25,17 @@ template double bench(F f, int nrepeat = 10) { template test_multi::test_multi(std::int64_t N, std::int64_t P, std::int64_t howmany, - std::size_t alignment, test_case tc, std::vector const &q) { + std::size_t alignment, test_case tc, std::vector const &q, + bool dump) { for (auto &qu : q) { switch (tc) { case test_case::ader: - instances_.emplace_back(std::make_unique>(N, P, howmany, alignment, qu)); + instances_.emplace_back( + std::make_unique>(N, P, howmany, alignment, qu, dump)); break; case test_case::volume: - instances_.emplace_back(std::make_unique>(N, P, howmany, alignment, qu)); + instances_.emplace_back( + std::make_unique>(N, P, howmany, alignment, qu, dump)); break; default: break; diff --git a/examples/matrix_chain/test_multi.hpp b/examples/matrix_chain/test_multi.hpp index 1bd4beac..b7b3027f 100644 --- a/examples/matrix_chain/test_multi.hpp +++ b/examples/matrix_chain/test_multi.hpp @@ -17,7 +17,7 @@ enum class test_case { volume, ader }; template class test_multi { public: test_multi(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - test_case tc, std::vector<::sycl::queue> const &q); + test_case tc, std::vector<::sycl::queue> const &q, bool dump = false); void reference(); void optimized(); diff --git a/examples/matrix_chain/test_volume.cpp b/examples/matrix_chain/test_volume.cpp index 479e5e03..8cce789b 100644 --- a/examples/matrix_chain/test_volume.cpp +++ b/examples/matrix_chain/test_volume.cpp @@ -6,8 +6,8 @@ #include #include -#include #include +#include #include using namespace sycl; @@ -15,14 +15,15 @@ using namespace tinytc; template test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany, - std::size_t alignment, queue q) + std::size_t alignment, queue q, bool dump) : B3_(num_basis(N, dim)), B2_(num_basis(N - 1, dim)), P_(P), howmany_(howmany), B3_aligned_(aligned(B3_, alignment)), B2_aligned_(aligned(B2_, alignment)), q_(std::move(q)), dev_info_(make_core_info(q_.get_device())), Q_ref_(B3_, P_, B3_aligned_, howmany_, q_), Q_opt_(B3_, P_, B3_aligned_, howmany_, q_), I_(B3_, P_, B3_aligned_, howmany_, q_), tmp_(B3_, P_, B2_aligned_, howmany_, q_), A_(dim, matrix_batch(P_, P_, P_, howmany_, q_)), - K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), opt_bundle_(make_optimized_kernel()), + K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), ctx_(make_compiler_context()), + opt_bundle_(make_optimized_kernel(dump)), opt_kernel_(make_kernel(opt_bundle_, "volume_kernel")) { Q_ref_.random(); Q_opt_.random(); @@ -38,92 +39,110 @@ test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany g_.emplace_back(make_recipe_handler( q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, B2_aligned_, P_, P_, B3_aligned_, B3_aligned_ * P_, P_, P_ * P_, - B2_aligned_, B2_aligned_ * P_))); + B2_aligned_, B2_aligned_ * P_, ctx_))); g_.emplace_back(make_recipe_handler( q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, B3_aligned_, P_, B2_, B3_aligned_, 0, B2_aligned_, - B2_aligned_ * P_, B3_aligned_, B3_aligned_ * P_))); + B2_aligned_ * P_, B3_aligned_, B3_aligned_ * P_, ctx_))); +} + +template auto test_volume::make_compiler_context() -> compiler_context { + auto ctx = ::tinytc::make_compiler_context(); + ctx.set_error_reporter( + [](char const *what, const tinytc_location_t *, void *) { std::cerr << what << std::endl; }, + nullptr); + return ctx; } template -auto test_volume::make_optimized_kernel() +auto test_volume::make_optimized_kernel(bool dump) -> sycl::kernel_bundle { constexpr auto real_t = to_scalar_type_v; - /** - * With B3_ = 56, B3_aligned_ = 64, B2_ = 35, B2_aligned_ = 48, P_ = 9 - * - * func chain(A0: batch, distance<81>>, - * A1: batch, distance<81>>, - * A2: batch, distance<81>>, - * K0: memref, - * K1: memref, - * K2: memref, - * Q: batch, distance<576>>, - * I: batch, distance<576>>) { - * a0 = get_work_item A0 - * a1 = get_work_item A1 - * a2 = get_work_item A2 - * q = get_work_item Q - * i = get_work_item i - * tmp = alloca matrix; - * K0v = submatrix K0[0:64,0:35] - * K1v = submatrix K1[0:64,0:35] - * K2v = submatrix K2[0:64,0:35] - * qv = submatrix Q[0:64,0:9] - * iv = submatrix I[0:48,0:9] - * tmpv = submatrix tmp[0:48,0:9] - * matmul(iv, a0, tmpv, 1.0, 0.0); - * matmul(K0v, tmp, qv, 1.0, 0.0); - * matmul(iv, a1, tmpv, 1.0, 0.0); - * matmul(K1v, tmp, qv, 1.0, 0.0); - * matmul(iv, a2, tmpv, 1.0, 0.0); - * matmul(K2v, tmp, qv, 1.0, 0.0); - * } - */ // Optimized kernel - auto opt_kernel = [&](function_builder &fb) { - auto A0 = fb.argument(A_[0].type(), "A0"); - auto A1 = fb.argument(A_[1].type(), "A1"); - auto A2 = fb.argument(A_[2].type(), "A2"); - auto K0 = fb.argument(K_[0].type(), "K0"); - auto K1 = fb.argument(K_[1].type(), "K1"); - auto K2 = fb.argument(K_[2].type(), "K2"); - auto Q = fb.argument(Q_opt_.type(), "Q"); - auto I = fb.argument(I_.type(), "I"); - fb.body([&](region_builder &bb) { - auto gid = bb.add(make_group_id()); - auto const offsets2 = std::vector{make_index(0), make_index(0)}; - auto const offsets3 = std::vector{make_index(0), make_index(0), gid}; - auto const size3 = std::vector{make_dynamic(), make_dynamic(), value{}}; - auto const sizeK2 = std::vector{make_index(B3_aligned_), make_index(B2_)}; - auto tmp = bb.add(make_alloca(make_memref(real_t, {B2_, P_}, {1, B2_aligned_}))); - auto a0 = bb.add(make_subview(A0, offsets3, size3)); - auto a1 = bb.add(make_subview(A1, offsets3, size3)); - auto a2 = bb.add(make_subview(A2, offsets3, size3)); - auto K0v = bb.add(make_subview(K0, offsets2, sizeK2)); - auto K1v = bb.add(make_subview(K1, offsets2, sizeK2)); - auto K2v = bb.add(make_subview(K2, offsets2, sizeK2)); - auto qv = bb.add( - make_subview(Q, offsets3, {make_index(B3_aligned_), make_index(P_), value{}})); - auto iv = bb.add( - make_subview(I, offsets3, {make_index(B2_aligned_), make_index(P_), value{}})); - auto tmpv = - bb.add(make_subview(tmp, offsets2, {make_index(B2_aligned_), make_index(P_)})); - auto const s0 = make_imm(T(0.0)); - auto const s1 = make_imm(T(1.0)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a0, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K0v, tmp, s1, qv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a1, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K1v, tmp, s1, qv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a2, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K2v, tmp, s1, qv)); - }); - }; - auto pb = program_builder{}; - pb.create("volume_kernel", opt_kernel); + auto opt_kernel = [&](compiler_context const &ctx) { + auto element_ty = get_scalar(ctx, real_t); + std::array param_types; + for (std::size_t i = 0; i < dim; ++i) { + param_types[i] = A_[i].type(element_ty); + } + for (std::size_t i = 0; i < dim; ++i) { + param_types[dim + i] = K_[i].type(element_ty); + } + param_types[2 * dim + 0] = Q_opt_.type(element_ty); + param_types[2 * dim + 1] = I_.type(element_ty); + + auto f = make_func("volume_kernel", param_types, get_void(ctx)); + auto fn_body = f.get_body(); + + std::array params; + fn_body.get_parameters(params); + + auto A = [¶ms](std::size_t i) -> value & { return params[i]; }; + auto K = [¶ms](std::size_t i) -> value & { return params[dim + i]; }; + auto Q = params[2 * dim + 0]; + auto I = params[2 * dim + 1]; + for (std::size_t i = 0; i < dim; ++i) { + A(i).set_name((std::ostringstream{} << 'A' << i).str()); + K(i).set_name((std::ostringstream{} << 'K' << i).str()); + } + Q.set_name("Q"); + I.set_name("I"); + + auto bb = region_builder{fn_body}; + auto gid = bb.add(make_builtin(builtin::group_id_x, get_scalar(ctx, scalar_type::index))); + auto const static_offsets2 = std::array{0, 0}; + auto const static_offsets3 = std::array{0, 0, dynamic}; + auto const static_sizes2 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols()}; + }; + auto const static_sizes3 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols(), 0}; + }; + auto const offsets3 = array_view(gid); + auto const sizeK2 = std::array{B3_aligned_, B2_}; + auto tmp = bb.add( + make_alloca(get_memref(element_ty, {B2_aligned_, P_}, {}, address_space::local))); + + auto a0t = get_memref(element_ty, static_sizes2(A_[0])); + auto a1t = get_memref(element_ty, static_sizes2(A_[1])); + auto a2t = get_memref(element_ty, static_sizes2(A_[2])); + auto k0t = get_memref(element_ty, sizeK2); + auto k1t = get_memref(element_ty, sizeK2); + auto k2t = get_memref(element_ty, sizeK2); + auto qvt = get_memref(element_ty, {B3_aligned_, P_}); + auto ivt = get_memref(element_ty, {B2_aligned_, P_}, {1, dynamic}); + auto tmpvt = get_memref(element_ty, {B2_, P_}, {}, address_space::local); + auto a0 = + bb.add(make_subview(A(0), static_offsets3, static_sizes3(A_[0]), offsets3, {}, a0t)); + auto a1 = + bb.add(make_subview(A(1), static_offsets3, static_sizes3(A_[1]), offsets3, {}, a1t)); + auto a2 = + bb.add(make_subview(A(2), static_offsets3, static_sizes3(A_[2]), offsets3, {}, a2t)); + auto k0 = bb.add(make_subview(K(0), static_offsets2, sizeK2, {}, {}, k0t)); + auto k1 = bb.add(make_subview(K(1), static_offsets2, sizeK2, {}, {}, k1t)); + auto k2 = bb.add(make_subview(K(2), static_offsets2, sizeK2, {}, {}, k2t)); + auto qv = bb.add(make_subview(Q, static_offsets3, {B3_aligned_, P_, 0}, offsets3, {}, qvt)); + auto iv = bb.add(make_subview(I, static_offsets3, {B2_aligned_, P_, 0}, offsets3, {}, ivt)); + auto tmpv = bb.add(make_subview(tmp, static_offsets2, {B2_, P_}, {}, {}, tmpvt)); + auto const c0 = bb.add(make_constant_zero(element_ty)); + auto const c1 = bb.add(make_constant_one(element_ty)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, iv, a0, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, k0, tmpv, c1, qv)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, iv, a1, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, k1, tmpv, c1, qv)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, iv, a2, c0, tmp)); + bb.add(make_gemm(transpose::N, transpose::N, false, c1, k2, tmpv, c1, qv)); + + return f; + }; + auto p = make_prog(ctx_); + p.add_function(opt_kernel(ctx_)); + if (dump) { + p.dump(); + } return make_kernel_bundle(q_.get_context(), q_.get_device(), - compile_to_opencl(pb.get_product(), dev_info_)); + compile_to_spirv_and_assemble(p, dev_info_)); } template std::vector test_volume::reference() { @@ -142,7 +161,8 @@ template std::vector test_volume::reference() { } template std::vector test_volume::optimized() { - auto exe_range = get_execution_range(opt_kernel_, howmany_); + auto exe_range = get_execution_range( + opt_kernel_, sycl::range<3u>{1u, 1u, static_cast(howmany_)}); return {q_.submit([&](handler &h) { h.set_args(A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, K_[0].get(), K_[1].get(), K_[2].get(), Q_opt_.get(), howmany_, I_.get(), howmany_); diff --git a/examples/matrix_chain/test_volume.hpp b/examples/matrix_chain/test_volume.hpp index ebd64f55..d0d07d40 100644 --- a/examples/matrix_chain/test_volume.hpp +++ b/examples/matrix_chain/test_volume.hpp @@ -19,7 +19,7 @@ template class test_volume : public test { public: test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - sycl::queue q); + sycl::queue q, bool dump = false); ~test_volume() = default; test_volume(test_volume const &other) = delete; test_volume(test_volume &&other) = default; @@ -35,13 +35,15 @@ template class test_volume : public test { private: constexpr static std::size_t dim = 3; - auto make_optimized_kernel() -> sycl::kernel_bundle; + auto make_optimized_kernel(bool dump) -> sycl::kernel_bundle; + auto make_compiler_context() -> tinytc::compiler_context; std::int64_t B3_, B2_, P_, howmany_, B3_aligned_, B2_aligned_; sycl::queue q_; tinytc::core_info dev_info_; matrix_batch Q_ref_, Q_opt_, I_, tmp_; std::vector> A_, K_; + tinytc::compiler_context ctx_; sycl::kernel_bundle opt_bundle_; sycl::kernel opt_kernel_; std::vector g_; diff --git a/examples/simple_cl/main.c b/examples/simple_cl/main.c index 0ae5c520..c12a1378 100644 --- a/examples/simple_cl/main.c +++ b/examples/simple_cl/main.c @@ -39,7 +39,6 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue queue) { tinytc_status_t status = tinytc_status_success; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; cl_mem A = NULL, B = NULL, C = NULL; @@ -49,11 +48,10 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue q CHECK(tinytc_cl_core_info_create(&info, device)); const uint32_t M = 64, N = 64, K = 64, howmany = 1000; - CHECK(tinytc_source_context_create(&source_ctx)); CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, tinytc_scalar_type_f32, tinytc_transpose_N, tinytc_transpose_N, M, N, K, - M, M * K, K, K * N, M, M * N, source_ctx)); - CHECK(tinytc_cl_recipe_handler_create(&handler, context, device, recipe, source_ctx)); + M, M * K, K, K * N, M, M * N, NULL)); + CHECK(tinytc_cl_recipe_handler_create(&handler, context, device, recipe)); const size_t Abytes = M * K * howmany * sizeof(float); const size_t Bbytes = K * N * howmany * sizeof(float); @@ -114,14 +112,6 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue q } tinytc_recipe_handler_release(handler); tinytc_recipe_release(recipe); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); return status; @@ -132,7 +122,6 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman int32_t *host = NULL; cl_mem A = NULL, B = NULL; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_prog_t program = NULL; cl_program module = NULL; cl_kernel kernel = NULL; @@ -158,17 +147,16 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman static const char source_text[] = "func @copy(%A: memref, %B: memref) {\n" - " %gid = group_id\n" - " %a = subview %A[:,%gid] : memref\n" - " %b = subview %B[:,%gid] : memref\n" - " axpby.n 1, %a, 0, %b\n" - " : i32, memref, i32, memref\n" + " %gid = builtin.group_id.x : index\n" + " %a = subview %A[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %b = subview %B[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %c0 = constant 0 : i32\n" + " %c1 = constant 1 : i32\n" + " axpby.n %c1, %a, %c0, %b\n" "}\n"; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, source_ctx)); - CHECK(tinytc_cl_kernel_bundle_create_with_program(&module, context, device, program, 0u, - source_ctx)); + CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, NULL)); + CHECK(tinytc_cl_kernel_bundle_create_with_program(&module, context, device, program, 0u)); kernel = clCreateKernel(module, "copy", &err); CL_CHECK(err); @@ -176,9 +164,10 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman CL_CHECK(clSetKernelArg(kernel, 1, sizeof(howmany), &howmany)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(B), &B)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(howmany), &howmany)); + size_t ng[3] = {howmany, 1, 1}; size_t ls[3], gs[3]; CHECK(tinytc_cl_get_group_size(kernel, ls)); - tinytc_cl_get_global_size(howmany, ls, gs); + tinytc_cl_get_global_size(ng, ls, gs); struct timespec start_time, end_time; clock_gettime(CLOCK_MONOTONIC, &start_time); @@ -211,14 +200,6 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman clReleaseProgram(module); } tinytc_prog_release(program); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); if (B) { clReleaseMemObject(B); diff --git a/examples/simple_ze/main.c b/examples/simple_ze/main.c index d4a6a713..0911b2bb 100644 --- a/examples/simple_ze/main.c +++ b/examples/simple_ze/main.c @@ -41,7 +41,6 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, ze_command_list_handle_t list) { tinytc_status_t status = tinytc_status_success; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; void *A = NULL, *B = NULL, *C = NULL; @@ -50,11 +49,10 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, CHECK(tinytc_ze_core_info_create(&info, device)); const uint32_t M = 64, N = 64, K = 64, howmany = 1000; - CHECK(tinytc_source_context_create(&source_ctx)); CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, tinytc_scalar_type_f32, tinytc_transpose_N, tinytc_transpose_N, M, N, K, - M, M * K, K, K * N, M, M * N, source_ctx)); - CHECK(tinytc_ze_recipe_handler_create(&handler, context, device, recipe, source_ctx)); + M, M * K, K, K * N, M, M * N, NULL)); + CHECK(tinytc_ze_recipe_handler_create(&handler, context, device, recipe)); const size_t Abytes = M * K * howmany * sizeof(float); const size_t Bbytes = K * N * howmany * sizeof(float); @@ -110,14 +108,6 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, } tinytc_recipe_handler_release(handler); tinytc_recipe_release(recipe); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); return status; @@ -126,18 +116,17 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t device, ze_command_list_handle_t list) { tinytc_status_t status = tinytc_status_success; - int32_t *host = NULL; + int16_t *host = NULL; void *A = NULL, *B = NULL; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_prog_t program = NULL; ze_module_handle_t module = NULL; ze_kernel_handle_t kernel = NULL; const uint32_t howmany = 1000; const int32_t elements = CHUNK_SIZE * howmany; - const size_t bytes = elements * sizeof(float); - host = (int32_t *)malloc(bytes); + const size_t bytes = elements * sizeof(int16_t); + host = (int16_t *)malloc(bytes); if (!host) { goto err; } @@ -156,24 +145,21 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de static const char source_text[] = "func @copy(%A: memref, %B: memref) {\n" - " %gid = group_id\n" - " %a = subview %A[:,%gid] : memref\n" - " %b = subview %B[:,%gid] : memref\n" - " axpby.n 1, %a, 0, %b\n" - " : i32, memref, i32, memref\n" + " %gid = builtin.group_id.x : index\n" + " %a = subview %A[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %b = subview %B[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %c0 = constant 0 : i32\n" + " %c1 = constant 1 : i32\n" + " axpby.n %c1, %a, %c0, %b\n" "}\n"; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, source_ctx)); - CHECK(tinytc_ze_kernel_bundle_create_with_program(&module, context, device, program, 0u, - source_ctx)); + CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, NULL)); + CHECK(tinytc_ze_kernel_bundle_create_with_program(&module, context, device, program, 0u)); CHECK(tinytc_ze_kernel_create(&kernel, module, "copy")); ZE_CHECK(zeKernelSetArgumentValue(kernel, 0, sizeof(A), &A)); - ZE_CHECK(zeKernelSetArgumentValue(kernel, 1, sizeof(howmany), &howmany)); - ZE_CHECK(zeKernelSetArgumentValue(kernel, 2, sizeof(B), &B)); - ZE_CHECK(zeKernelSetArgumentValue(kernel, 3, sizeof(howmany), &howmany)); - ze_group_count_t group_count = tinytc_ze_get_group_count(howmany); + ZE_CHECK(zeKernelSetArgumentValue(kernel, 1, sizeof(B), &B)); + ze_group_count_t group_count = {howmany, 1u, 1u}; ZE_CHECK(zeCommandListAppendLaunchKernel(list, kernel, &group_count, NULL, 0, NULL)); ZE_CHECK(zeCommandListHostSynchronize(list, TIMEOUT)); @@ -182,10 +168,19 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de uint32_t ok = 0; for (int32_t i = 0; i < elements; ++i) { + int32_t c = i / 64; + int32_t r = i % 64; + if (r < 16 && c < 8) { + printf("%d ", host[i]); + } + if (r == 63 && c < 8) { + printf("\n"); + } if (host[i] == i) { ++ok; } } + printf("\n"); if (ok == (uint32_t)elements) { printf("Custom kernel was successful\n"); } else { @@ -200,14 +195,6 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de zeModuleDestroy(module); } tinytc_prog_release(program); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); if (B) { zeMemFree(context, B); diff --git a/examples/tall_and_skinny/CMakeLists.txt b/examples/tall_and_skinny/CMakeLists.txt index f65f8efb..9defb5c6 100644 --- a/examples/tall_and_skinny/CMakeLists.txt +++ b/examples/tall_and_skinny/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) find_package(SYCL REQUIRED) -add_executable(tall_and_skinny main.cpp args.cpp) +add_executable(tall_and_skinny main.cpp) add_sycl_to_target(TARGET tall_and_skinny SOURCES main.cpp) -target_link_libraries(tall_and_skinny PRIVATE tinytc tinytc_sycl) +target_link_libraries(tall_and_skinny PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(tall_and_skinny) diff --git a/examples/tall_and_skinny/args.cpp b/examples/tall_and_skinny/args.cpp deleted file mode 100644 index 6aa1268a..00000000 --- a/examples/tall_and_skinny/args.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" - -#include -#include -#include -#include -#include - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.beta = 0.0; - a.specialize_M = false; - a.specialize_ld = false; - auto num = std::vector(3); - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error("==> Error: unrecognized argument " + - std::string(argv[i])); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else if (std::strcmp(argv[i], "-v") == 0 || std::strcmp(argv[i], "--verify") == 0) { - a.verify = true; - } else if (std::strcmp(argv[i], "--specialize-M") == 0) { - a.specialize_M = true; - } else if (std::strcmp(argv[i], "--specialize-ld") == 0) { - a.specialize_ld = true; - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-b") == 0 || std::strcmp(argv[i], "--beta") == 0) { - ++i; - a.beta = atof(argv[i]); - } else if (std::strcmp(argv[i], "-p") == 0 || - std::strcmp(argv[i], "--precision") == 0) { - ++i; - if (argv[i][0] == 'd') { - a.double_precision = true; - } else if (argv[i][0] == 's') { - a.double_precision = false; - } else { - fail(); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - num.clear(); - char const *delim = "x"; - auto arg = std::string(argv[i]); - char *token = std::strtok(argv[i], delim); - while (token) { - num.emplace_back(atoi(token)); - token = std::strtok(nullptr, delim); - } - if (num.size() != 3) { - throw std::runtime_error("==> Could not parse test case: " + arg); - } - a.tc.push_back({num[0], num[1], num[2]}); - } - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tall_and_skinny test-case1 test-case2 ..." << std::endl - << R"HELP( -positional arguments: - test-caseN MxNxK triplet (e.g. 300000x64x64) - -optional arguments: - -h, --help Show help and quit - -p, --precision Precision (single = s, double = d) - -v, --verify Verify optimized implementation - --specialize-M Specialize M instead of using dynamic value - --specialize-ld Specialize ldA, ldB, ldC instead of using dynamic value -)HELP"; -} diff --git a/examples/tall_and_skinny/args.hpp b/examples/tall_and_skinny/args.hpp deleted file mode 100644 index 701710fe..00000000 --- a/examples/tall_and_skinny/args.hpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20240215_HPP -#define ARGS_20240215_HPP - -#include -#include -#include - -struct test_case { - std::int64_t m; - std::int64_t n; - std::int64_t k; -}; - -struct args { - std::vector tc; - bool double_precision; - bool help; - bool verify; - double beta; - bool specialize_M; - bool specialize_ld; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20240215_HPP diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index 8538cf0e..daf4abe4 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -1,14 +1,16 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" +#include "../gemm_common.hpp" +#include #include #include #include #include #include +#include #include #include #include @@ -20,6 +22,18 @@ using namespace sycl; using namespace tinytc; +struct args { + bool dump = false; + bool specialize_M = false; + bool specialize_ld = false; + scalar_type ty = scalar_type::f32; + bool update = false; + bool verify = false; + std::int32_t alignment = 0; + std::int32_t M_block_size = 0; + std::vector tc; +}; + template double bench(F f, int nrepeat = 10) { f(); double min_exec_time_ns = std::numeric_limits::max(); @@ -35,11 +49,6 @@ template double bench(F f, int nrepeat = 10) { } template void test(queue q, args &a) { - auto const fill = [](std::vector &x) { - for (std::size_t i = 0; i < x.size(); ++i) { - x[i] = i % 101; - } - }; std::int64_t na_max = 0; std::int64_t nb_max = 0; std::int64_t nc_max = 0; @@ -51,28 +60,32 @@ template void test(queue q, args &a) { auto A_host = std::vector(na_max); auto B_host = std::vector(nb_max); auto C_host = std::vector(nc_max); - auto C_ref_host = std::vector(nc_max); - T *C_ref = malloc_device(nc_max, q); - T *A = malloc_device(na_max, q); - T *B = malloc_device(nb_max, q); - T *C = malloc_device(nc_max, q); - fill(A_host); - fill(B_host); - q.copy(A_host.data(), A, na_max).wait(); - q.copy(B_host.data(), B, nb_max).wait(); - - auto const check = [&](std::int64_t M, std::int64_t N) { - q.copy(C_ref, C_ref_host.data(), M * N).wait(); + auto const alloc_device = [&a, &q](std::size_t num_bytes) { + if (a.alignment == 0) { + return malloc_device(num_bytes, q); + } else { + return aligned_alloc_device(a.alignment, num_bytes, q); + } + }; + T *A = alloc_device(na_max); + T *B = alloc_device(nb_max); + T *C = alloc_device(nc_max); + + auto const check = [&](std::int64_t M, std::int64_t N, std::int64_t K) { q.copy(C, C_host.data(), M * N).wait(); std::size_t num_err = 0; - for (std::int64_t i = 0; i < M * N; ++i) { - auto err = std::abs(C_host[i] - C_ref_host[i]); - if (err > 10.0 * std::numeric_limits::epsilon()) { - if (num_err < 10) { - std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] - << std::endl; + const auto error_bound = examples::test_gemm_error_bound(K); + for (std::int64_t j = 0; j < N; ++j) { + for (std::int64_t i = 0; i < M; ++i) { + const auto relerr = examples::test_gemm_rel_error(C_host.data(), i, j, M); + if (relerr > error_bound) { + if (num_err < 10) { + std::cout << "C_{" << i << "," << j << "}=" << C_host[i + j * M] + << ", relative_error=" << relerr + << ", error_bound=" << error_bound << std::endl; + } + ++num_err; } - ++num_err; } } if (num_err > 10) { @@ -80,29 +93,15 @@ template void test(queue q, args &a) { } }; - auto const &type = typeid(T); for (auto &c : a.tc) { - if (a.verify) { - q.memset(C, 0, c.m * c.n * sizeof(T)).wait(); - q.memset(C_ref, 0, c.m * c.n * sizeof(T)).wait(); - q.submit([&](auto &h) { - auto beta = a.beta; - h.parallel_for(range{static_cast(c.n), static_cast(c.m)}, - [=](id<2> it) { - auto m = it[1]; - auto n = it[0]; - auto c_acc = T(0.0); - for (std::int64_t k = 0; k < c.k; ++k) { - c_acc += A[m + k * c.m] * B[k + n * c.k]; - } - C_ref[m + n * c.m] = c_acc + T(beta) * C_ref[m + n * c.m]; - }); - }).wait(); - } + examples::test_gemm_matrix(A_host.data(), c.m, c.k); + examples::test_gemm_matrix(B_host.data(), c.k, c.n); + q.copy(A_host.data(), A, c.m * c.k).wait(); + q.copy(B_host.data(), B, c.k * c.n).wait(); + q.memset(C, 0, c.m * c.n * sizeof(T)).wait(); - auto source_ctx = source_context{}; + auto beta = a.update ? T{1} : T{0}; try { - source_ctx = make_source_context(); auto info = make_core_info(q.get_device()); info.set_core_features(tinytc_core_feature_flag_large_register_file); @@ -113,31 +112,47 @@ template void test(queue q, args &a) { ldB = c.k; ldC = c.m; } - auto tas = make_recipe_handler( - q, - make_tall_and_skinny_specialized(info, to_scalar_type_v, M, c.n, c.k, ldA, ldB, - ldC, 0, source_ctx), - source_ctx); + auto ctx = make_compiler_context(); + ctx.set_error_reporter([](char const *what, const tinytc_location_t *, + void *) { std::cerr << what << std::endl; }, + nullptr); + auto r = make_tall_and_skinny_specialized(info, a.ty, M, c.n, c.k, ldA, ldB, ldC, + a.alignment, a.alignment, a.alignment, + a.M_block_size, ctx); + if (a.dump) { + r.get_prog().dump(); + } + auto tas = make_recipe_handler(q, r); - tall_and_skinny::set_args(tas, c.m, T(1.0), A, c.m, B, c.k, T(a.beta), C, c.m); + tall_and_skinny::set_args(tas, c.m, T{1}, mem(A, mem_type::usm_pointer), c.m, + mem(B, mem_type::usm_pointer), c.k, beta, + mem(C, mem_type::usm_pointer), c.m); tas.submit(q).wait(); if (a.verify) { - check(c.m, c.n); + check(c.m, c.n, c.k); } double min_exec_time_ns = bench([&]() { tas.submit(q).wait(); }); - auto bw_C_factor = a.beta != 0.0 ? 2 : 1; + const auto ops_per_mnk = [&] { + switch (a.ty) { + case scalar_type::c32: + case scalar_type::c64: + return 8; + default: + return 2; + } + }(); + + auto bw_C_factor = a.update ? 2 : 1; auto bw = sizeof(T) * (c.m * c.n * bw_C_factor + c.m * c.k + c.k * c.n) / min_exec_time_ns; - auto gflops = 2 * c.m * c.n * c.k / min_exec_time_ns; - std::cout << type.name() << "," << c.m << "," << c.n << "," << c.k << "," << a.beta - << "," << min_exec_time_ns / 1e9 << "," << bw << "," << gflops << std::endl; + auto gflops = ops_per_mnk * c.m * c.n * c.k / min_exec_time_ns; + std::cout << to_string(a.ty) << "," << c.m << "," << c.n << "," << c.k << "," + << a.update << "," << min_exec_time_ns / 1e9 << "," << bw << "," << gflops + << std::endl; } catch (status const &st) { std::cerr << "Error (" << static_cast(st) << "): " << tinytc::error_string(st) << std::endl; - if (source_ctx.get_error_log()[0] != '\0') { - std::cerr << "Error log: " << std::endl << source_ctx.get_error_log() << std::endl; - } } catch (std::exception const &e) { std::cerr << "Error: " << e.what() << std::endl; } @@ -146,30 +161,68 @@ template void test(queue q, args &a) { free(A, q); free(B, q); free(C, q); - free(C_ref, q); }; int main(int argc, char **argv) { auto a = args{}; + bool help = false; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); - } catch (std::runtime_error const &e) { + parser.set_short_opt('a', &a.alignment, "Override memory alignment"); + parser.set_short_opt('d', &a.dump, "Dump IR to stdout"); + parser.set_short_opt('f', &a.ty, "Data type (f32, f64, c32, c64)") + .converter(examples::convert_data_type); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('u', &a.update, + "Add A*B to C (beta=1) instead of overwriting C (beta=0)"); + parser.set_short_opt('v', &a.verify, "Verify optimized implementation"); + parser.set_long_opt("help", &help, "Show help"); + parser.set_long_opt("m-block-size", &a.M_block_size, + "Set block size for M mode (one work-group per block)"); + parser.set_long_opt("specialize-m", &a.specialize_M, + "Specialize M instead of using dynamic value"); + parser.set_long_opt("specialize-ld", &a.specialize_ld, + "Specialize ldA, ldB, ldC instead of using dynamic value"); + parser.add_positional_arg("test-case", &a.tc, "MxNxK triplet (e.g. 300000x64x64)") + .converter(examples::convert_test_case) + .validator(examples::validate_test_case); + + parser.parse(argc, argv); + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } - if (a.help || a.tc.empty()) { - arg_parser::show_help(std::cout); - return 0; + if (help || a.tc.empty()) { + parser.print_help(std::cout, "tall_and_skinny", ""); + return !help ? -1 : 0; } auto q = queue{}; - std::cout << "precision,m,n,k,beta,time,bandwidth,gflops" << std::endl; + std::cout << "precision,m,n,k,update,time,bandwidth,gflops" << std::endl; try { - if (a.double_precision) { - test(std::move(q), a); - } else { + switch (a.ty) { + case scalar_type::bf16: + test(std::move(q), a); + break; + case scalar_type::f16: + test(std::move(q), a); + break; + case scalar_type::f32: test(std::move(q), a); + break; + case scalar_type::f64: + test(std::move(q), a); + break; + case scalar_type::c32: + test>(std::move(q), a); + break; + case scalar_type::c64: + test>(std::move(q), a); + break; + default: + return -1; } } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/external/doctest/doctest.h b/external/doctest/doctest.h index 8f442842..5c754cde 100644 --- a/external/doctest/doctest.h +++ b/external/doctest/doctest.h @@ -1,7019 +1,7106 @@ -// ====================================================================== lgtm [cpp/missing-header-guard] -// == DO NOT MODIFY THIS FILE BY HAND - IT IS AUTO GENERATED BY CMAKE! == -// ====================================================================== -// -// doctest.h - the lightest feature-rich C++ single-header testing framework for unit tests and TDD -// -// Copyright (c) 2016-2021 Viktor Kirilov -// -// Distributed under the MIT Software License -// See accompanying file LICENSE.txt or copy at -// https://opensource.org/licenses/MIT -// -// The documentation can be found at the library's page: -// https://github.com/doctest/doctest/blob/master/doc/markdown/readme.md -// -// ================================================================================================= -// ================================================================================================= -// ================================================================================================= -// -// The library is heavily influenced by Catch - https://github.com/catchorg/Catch2 -// which uses the Boost Software License - Version 1.0 -// see here - https://github.com/catchorg/Catch2/blob/master/LICENSE.txt -// -// The concept of subcases (sections in Catch) and expression decomposition are from there. -// Some parts of the code are taken directly: -// - stringification - the detection of "ostream& operator<<(ostream&, const T&)" and StringMaker<> -// - the Approx() helper class for floating point comparison -// - colors in the console -// - breaking into a debugger -// - signal / SEH handling -// - timer -// - XmlWriter class - thanks to Phil Nash for allowing the direct reuse (AKA copy/paste) -// -// The expression decomposing templates are taken from lest - https://github.com/martinmoene/lest -// which uses the Boost Software License - Version 1.0 -// see here - https://github.com/martinmoene/lest/blob/master/LICENSE.txt -// -// ================================================================================================= -// ================================================================================================= -// ================================================================================================= - -#ifndef DOCTEST_LIBRARY_INCLUDED -#define DOCTEST_LIBRARY_INCLUDED - -// ================================================================================================= -// == VERSION ====================================================================================== -// ================================================================================================= - -#define DOCTEST_VERSION_MAJOR 2 -#define DOCTEST_VERSION_MINOR 4 -#define DOCTEST_VERSION_PATCH 9 - -// util we need here -#define DOCTEST_TOSTR_IMPL(x) #x -#define DOCTEST_TOSTR(x) DOCTEST_TOSTR_IMPL(x) - -#define DOCTEST_VERSION_STR \ - DOCTEST_TOSTR(DOCTEST_VERSION_MAJOR) "." \ - DOCTEST_TOSTR(DOCTEST_VERSION_MINOR) "." \ - DOCTEST_TOSTR(DOCTEST_VERSION_PATCH) - -#define DOCTEST_VERSION \ - (DOCTEST_VERSION_MAJOR * 10000 + DOCTEST_VERSION_MINOR * 100 + DOCTEST_VERSION_PATCH) - -// ================================================================================================= -// == COMPILER VERSION ============================================================================= -// ================================================================================================= - -// ideas for the version stuff are taken from here: https://github.com/cxxstuff/cxx_detect - -#ifdef _MSC_VER -#define DOCTEST_CPLUSPLUS _MSVC_LANG -#else -#define DOCTEST_CPLUSPLUS __cplusplus -#endif - -#define DOCTEST_COMPILER(MAJOR, MINOR, PATCH) ((MAJOR)*10000000 + (MINOR)*100000 + (PATCH)) - -// GCC/Clang and GCC/MSVC are mutually exclusive, but Clang/MSVC are not because of clang-cl... -#if defined(_MSC_VER) && defined(_MSC_FULL_VER) -#if _MSC_VER == _MSC_FULL_VER / 10000 -#define DOCTEST_MSVC DOCTEST_COMPILER(_MSC_VER / 100, _MSC_VER % 100, _MSC_FULL_VER % 10000) -#else // MSVC -#define DOCTEST_MSVC \ - DOCTEST_COMPILER(_MSC_VER / 100, (_MSC_FULL_VER / 100000) % 100, _MSC_FULL_VER % 100000) -#endif // MSVC -#endif // MSVC -#if defined(__clang__) && defined(__clang_minor__) -#define DOCTEST_CLANG DOCTEST_COMPILER(__clang_major__, __clang_minor__, __clang_patchlevel__) -#elif defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) && \ - !defined(__INTEL_COMPILER) -#define DOCTEST_GCC DOCTEST_COMPILER(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#endif // GCC - -#ifndef DOCTEST_MSVC -#define DOCTEST_MSVC 0 -#endif // DOCTEST_MSVC -#ifndef DOCTEST_CLANG -#define DOCTEST_CLANG 0 -#endif // DOCTEST_CLANG -#ifndef DOCTEST_GCC -#define DOCTEST_GCC 0 -#endif // DOCTEST_GCC - -// ================================================================================================= -// == COMPILER WARNINGS HELPERS ==================================================================== -// ================================================================================================= - -#if DOCTEST_CLANG -#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) -#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") -#define DOCTEST_CLANG_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(clang diagnostic ignored w) -#define DOCTEST_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") -#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ - DOCTEST_CLANG_SUPPRESS_WARNING_PUSH DOCTEST_CLANG_SUPPRESS_WARNING(w) -#else // DOCTEST_CLANG -#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -#define DOCTEST_CLANG_SUPPRESS_WARNING(w) -#define DOCTEST_CLANG_SUPPRESS_WARNING_POP -#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) -#endif // DOCTEST_CLANG - -#if DOCTEST_GCC -#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) -#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH _Pragma("GCC diagnostic push") -#define DOCTEST_GCC_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(GCC diagnostic ignored w) -#define DOCTEST_GCC_SUPPRESS_WARNING_POP _Pragma("GCC diagnostic pop") -#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) \ - DOCTEST_GCC_SUPPRESS_WARNING_PUSH DOCTEST_GCC_SUPPRESS_WARNING(w) -#else // DOCTEST_GCC -#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH -#define DOCTEST_GCC_SUPPRESS_WARNING(w) -#define DOCTEST_GCC_SUPPRESS_WARNING_POP -#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) -#endif // DOCTEST_GCC - -#if DOCTEST_MSVC -#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH __pragma(warning(push)) -#define DOCTEST_MSVC_SUPPRESS_WARNING(w) __pragma(warning(disable : w)) -#define DOCTEST_MSVC_SUPPRESS_WARNING_POP __pragma(warning(pop)) -#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) \ - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH DOCTEST_MSVC_SUPPRESS_WARNING(w) -#else // DOCTEST_MSVC -#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -#define DOCTEST_MSVC_SUPPRESS_WARNING(w) -#define DOCTEST_MSVC_SUPPRESS_WARNING_POP -#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) -#endif // DOCTEST_MSVC - -// ================================================================================================= -// == COMPILER WARNINGS ============================================================================ -// ================================================================================================= - -// both the header and the implementation suppress all of these, -// so it only makes sense to aggregrate them like so -#define DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH \ - DOCTEST_CLANG_SUPPRESS_WARNING_PUSH \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") \ - \ - DOCTEST_GCC_SUPPRESS_WARNING_PUSH \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") \ - \ - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ - /* these 4 also disabled globally via cmake: */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4514) /* unreferenced inline function has been removed */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4571) /* SEH related */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4710) /* function not inlined */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4711) /* function selected for inline expansion*/ \ - /* */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4616) /* invalid compiler warning */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4619) /* invalid compiler warning */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4996) /* The compiler encountered a deprecated declaration */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4706) /* assignment within conditional expression */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4512) /* 'class' : assignment operator could not be generated */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4127) /* conditional expression is constant */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4640) /* construction of local static object not thread-safe */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ - /* static analysis */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26439) /* Function may not throw. Declare it 'noexcept' */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26495) /* Always initialize a member variable */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26451) /* Arithmetic overflow ... */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26444) /* Avoid unnamed objects with custom ctor and dtor... */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26812) /* Prefer 'enum class' over 'enum' */ - -#define DOCTEST_SUPPRESS_COMMON_WARNINGS_POP \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP \ - DOCTEST_GCC_SUPPRESS_WARNING_POP \ - DOCTEST_MSVC_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH - -DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -DOCTEST_CLANG_SUPPRESS_WARNING("-Wnon-virtual-dtor") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wdeprecated") - -DOCTEST_GCC_SUPPRESS_WARNING_PUSH -DOCTEST_GCC_SUPPRESS_WARNING("-Wctor-dtor-privacy") -DOCTEST_GCC_SUPPRESS_WARNING("-Wnon-virtual-dtor") -DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-promo") - -DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -DOCTEST_MSVC_SUPPRESS_WARNING(4623) // default constructor was implicitly defined as deleted - -#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN \ - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ - DOCTEST_MSVC_SUPPRESS_WARNING(4548) /* before comma no effect; expected side - effect */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4265) /* virtual functions, but destructor is not virtual */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4986) /* exception specification does not match previous */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4350) /* 'member1' called instead of 'member2' */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4668) /* not defined as a preprocessor macro */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4365) /* signed/unsigned mismatch */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4774) /* format string not a string literal */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4623) /* default constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5039) /* pointer to pot. throwing function passed to extern C */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5105) /* macro producing 'defined' has undefined behavior */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4738) /* storing float result in memory, loss of performance */ - -#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END DOCTEST_MSVC_SUPPRESS_WARNING_POP - -// ================================================================================================= -// == FEATURE DETECTION ============================================================================ -// ================================================================================================= - -// general compiler feature support table: https://en.cppreference.com/w/cpp/compiler_support -// MSVC C++11 feature support table: https://msdn.microsoft.com/en-us/library/hh567368.aspx -// GCC C++11 feature support table: https://gcc.gnu.org/projects/cxx-status.html -// MSVC version table: -// https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B#Internal_version_numbering -// MSVC++ 14.3 (17) _MSC_VER == 1930 (Visual Studio 2022) -// MSVC++ 14.2 (16) _MSC_VER == 1920 (Visual Studio 2019) -// MSVC++ 14.1 (15) _MSC_VER == 1910 (Visual Studio 2017) -// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) -// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013) -// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012) -// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010) -// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) -// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) - -// Universal Windows Platform support -#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) -#define DOCTEST_CONFIG_NO_WINDOWS_SEH -#endif // WINAPI_FAMILY -#if DOCTEST_MSVC && !defined(DOCTEST_CONFIG_WINDOWS_SEH) -#define DOCTEST_CONFIG_WINDOWS_SEH -#endif // MSVC -#if defined(DOCTEST_CONFIG_NO_WINDOWS_SEH) && defined(DOCTEST_CONFIG_WINDOWS_SEH) -#undef DOCTEST_CONFIG_WINDOWS_SEH -#endif // DOCTEST_CONFIG_NO_WINDOWS_SEH - -#if !defined(_WIN32) && !defined(__QNX__) && !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && \ - !defined(__EMSCRIPTEN__) && !defined(__wasi__) -#define DOCTEST_CONFIG_POSIX_SIGNALS -#endif // _WIN32 -#if defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) && defined(DOCTEST_CONFIG_POSIX_SIGNALS) -#undef DOCTEST_CONFIG_POSIX_SIGNALS -#endif // DOCTEST_CONFIG_NO_POSIX_SIGNALS - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS -#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) \ - || defined(__wasi__) -#define DOCTEST_CONFIG_NO_EXCEPTIONS -#endif // no exceptions -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS -#define DOCTEST_CONFIG_NO_EXCEPTIONS -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS - -#if defined(DOCTEST_CONFIG_NO_EXCEPTIONS) && !defined(DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS) -#define DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS && !DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS - -#ifdef __wasi__ -#define DOCTEST_CONFIG_NO_MULTITHREADING -#endif - -#if defined(DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN) && !defined(DOCTEST_CONFIG_IMPLEMENT) -#define DOCTEST_CONFIG_IMPLEMENT -#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN - -#if defined(_WIN32) || defined(__CYGWIN__) -#if DOCTEST_MSVC -#define DOCTEST_SYMBOL_EXPORT __declspec(dllexport) -#define DOCTEST_SYMBOL_IMPORT __declspec(dllimport) -#else // MSVC -#define DOCTEST_SYMBOL_EXPORT __attribute__((dllexport)) -#define DOCTEST_SYMBOL_IMPORT __attribute__((dllimport)) -#endif // MSVC -#else // _WIN32 -#define DOCTEST_SYMBOL_EXPORT __attribute__((visibility("default"))) -#define DOCTEST_SYMBOL_IMPORT -#endif // _WIN32 - -#ifdef DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL -#ifdef DOCTEST_CONFIG_IMPLEMENT -#define DOCTEST_INTERFACE DOCTEST_SYMBOL_EXPORT -#else // DOCTEST_CONFIG_IMPLEMENT -#define DOCTEST_INTERFACE DOCTEST_SYMBOL_IMPORT -#endif // DOCTEST_CONFIG_IMPLEMENT -#else // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL -#define DOCTEST_INTERFACE -#endif // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL - -// needed for extern template instantiations -// see https://github.com/fmtlib/fmt/issues/2228 -#if DOCTEST_MSVC -#define DOCTEST_INTERFACE_DECL -#define DOCTEST_INTERFACE_DEF DOCTEST_INTERFACE -#else // DOCTEST_MSVC -#define DOCTEST_INTERFACE_DECL DOCTEST_INTERFACE -#define DOCTEST_INTERFACE_DEF -#endif // DOCTEST_MSVC - -#define DOCTEST_EMPTY - -#if DOCTEST_MSVC -#define DOCTEST_NOINLINE __declspec(noinline) -#define DOCTEST_UNUSED -#define DOCTEST_ALIGNMENT(x) -#elif DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 5, 0) -#define DOCTEST_NOINLINE -#define DOCTEST_UNUSED -#define DOCTEST_ALIGNMENT(x) -#else -#define DOCTEST_NOINLINE __attribute__((noinline)) -#define DOCTEST_UNUSED __attribute__((unused)) -#define DOCTEST_ALIGNMENT(x) __attribute__((aligned(x))) -#endif - -#ifndef DOCTEST_NORETURN -#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_NORETURN -#else // DOCTEST_MSVC -#define DOCTEST_NORETURN [[noreturn]] -#endif // DOCTEST_MSVC -#endif // DOCTEST_NORETURN - -#ifndef DOCTEST_NOEXCEPT -#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_NOEXCEPT -#else // DOCTEST_MSVC -#define DOCTEST_NOEXCEPT noexcept -#endif // DOCTEST_MSVC -#endif // DOCTEST_NOEXCEPT - -#ifndef DOCTEST_CONSTEXPR -#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_CONSTEXPR const -#define DOCTEST_CONSTEXPR_FUNC inline -#else // DOCTEST_MSVC -#define DOCTEST_CONSTEXPR constexpr -#define DOCTEST_CONSTEXPR_FUNC constexpr -#endif // DOCTEST_MSVC -#endif // DOCTEST_CONSTEXPR - -// ================================================================================================= -// == FEATURE DETECTION END ======================================================================== -// ================================================================================================= - -#define DOCTEST_DECLARE_INTERFACE(name) \ - virtual ~name(); \ - name() = default; \ - name(const name&) = delete; \ - name(name&&) = delete; \ - name& operator=(const name&) = delete; \ - name& operator=(name&&) = delete; - -#define DOCTEST_DEFINE_INTERFACE(name) \ - name::~name() = default; - -// internal macros for string concatenation and anonymous variable name generation -#define DOCTEST_CAT_IMPL(s1, s2) s1##s2 -#define DOCTEST_CAT(s1, s2) DOCTEST_CAT_IMPL(s1, s2) -#ifdef __COUNTER__ // not standard and may be missing for some compilers -#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __COUNTER__) -#else // __COUNTER__ -#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __LINE__) -#endif // __COUNTER__ - -#ifndef DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE -#define DOCTEST_REF_WRAP(x) x& -#else // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE -#define DOCTEST_REF_WRAP(x) x -#endif // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE - -// not using __APPLE__ because... this is how Catch does it -#ifdef __MAC_OS_X_VERSION_MIN_REQUIRED -#define DOCTEST_PLATFORM_MAC -#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) -#define DOCTEST_PLATFORM_IPHONE -#elif defined(_WIN32) -#define DOCTEST_PLATFORM_WINDOWS -#elif defined(__wasi__) -#define DOCTEST_PLATFORM_WASI -#else // DOCTEST_PLATFORM -#define DOCTEST_PLATFORM_LINUX -#endif // DOCTEST_PLATFORM - -namespace doctest { namespace detail { - static DOCTEST_CONSTEXPR int consume(const int*, int) noexcept { return 0; } -}} - -#define DOCTEST_GLOBAL_NO_WARNINGS(var, ...) \ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ - static const int var = doctest::detail::consume(&var, __VA_ARGS__); \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP - -#ifndef DOCTEST_BREAK_INTO_DEBUGGER -// should probably take a look at https://github.com/scottt/debugbreak -#ifdef DOCTEST_PLATFORM_LINUX -#if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) -// Break at the location of the failing check if possible -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) -#else -#include -#define DOCTEST_BREAK_INTO_DEBUGGER() raise(SIGTRAP) -#endif -#elif defined(DOCTEST_PLATFORM_MAC) -#if defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || defined(__i386) -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) -#elif defined(__ppc__) || defined(__ppc64__) -// https://www.cocoawithlove.com/2008/03/break-into-debugger.html -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n": : : "memory","r0","r3","r4") // NOLINT(hicpp-no-assembler) -#else -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("brk #0"); // NOLINT(hicpp-no-assembler) -#endif -#elif DOCTEST_MSVC -#define DOCTEST_BREAK_INTO_DEBUGGER() __debugbreak() -#elif defined(__MINGW32__) -DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wredundant-decls") -extern "C" __declspec(dllimport) void __stdcall DebugBreak(); -DOCTEST_GCC_SUPPRESS_WARNING_POP -#define DOCTEST_BREAK_INTO_DEBUGGER() ::DebugBreak() -#else // linux -#define DOCTEST_BREAK_INTO_DEBUGGER() (static_cast(0)) -#endif // linux -#endif // DOCTEST_BREAK_INTO_DEBUGGER - -// this is kept here for backwards compatibility since the config option was changed -#ifdef DOCTEST_CONFIG_USE_IOSFWD -#ifndef DOCTEST_CONFIG_USE_STD_HEADERS -#define DOCTEST_CONFIG_USE_STD_HEADERS -#endif -#endif // DOCTEST_CONFIG_USE_IOSFWD - -// for clang - always include ciso646 (which drags some std stuff) because -// we want to check if we are using libc++ with the _LIBCPP_VERSION macro in -// which case we don't want to forward declare stuff from std - for reference: -// https://github.com/doctest/doctest/issues/126 -// https://github.com/doctest/doctest/issues/356 -#if DOCTEST_CLANG -#include -#ifdef _LIBCPP_VERSION -#ifndef DOCTEST_CONFIG_USE_STD_HEADERS -#define DOCTEST_CONFIG_USE_STD_HEADERS -#endif -#endif // _LIBCPP_VERSION -#endif // clang - -#ifdef DOCTEST_CONFIG_USE_STD_HEADERS -#ifndef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#define DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN -#include -#include -#include -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END -#else // DOCTEST_CONFIG_USE_STD_HEADERS - -// Forward declaring 'X' in namespace std is not permitted by the C++ Standard. -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4643) - -namespace std { // NOLINT(cert-dcl58-cpp) -typedef decltype(nullptr) nullptr_t; // NOLINT(modernize-use-using) -typedef decltype(sizeof(void*)) size_t; // NOLINT(modernize-use-using) -template -struct char_traits; -template <> -struct char_traits; -template -class basic_ostream; // NOLINT(fuchsia-virtual-inheritance) -typedef basic_ostream> ostream; // NOLINT(modernize-use-using) -template -// NOLINTNEXTLINE -basic_ostream& operator<<(basic_ostream&, const char*); -template -class basic_istream; -typedef basic_istream> istream; // NOLINT(modernize-use-using) -template -class tuple; -#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 -template -class allocator; -template -class basic_string; -using string = basic_string, allocator>; -#endif // VS 2019 -} // namespace std - -DOCTEST_MSVC_SUPPRESS_WARNING_POP - -#endif // DOCTEST_CONFIG_USE_STD_HEADERS - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#include -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - -namespace doctest { - -using std::size_t; - -DOCTEST_INTERFACE extern bool is_running_in_test; - -#ifndef DOCTEST_CONFIG_STRING_SIZE_TYPE -#define DOCTEST_CONFIG_STRING_SIZE_TYPE unsigned -#endif - -// A 24 byte string class (can be as small as 17 for x64 and 13 for x86) that can hold strings with length -// of up to 23 chars on the stack before going on the heap - the last byte of the buffer is used for: -// - "is small" bit - the highest bit - if "0" then it is small - otherwise its "1" (128) -// - if small - capacity left before going on the heap - using the lowest 5 bits -// - if small - 2 bits are left unused - the second and third highest ones -// - if small - acts as a null terminator if strlen() is 23 (24 including the null terminator) -// and the "is small" bit remains "0" ("as well as the capacity left") so its OK -// Idea taken from this lecture about the string implementation of facebook/folly - fbstring -// https://www.youtube.com/watch?v=kPR8h4-qZdk -// TODO: -// - optimizations - like not deleting memory unnecessarily in operator= and etc. -// - resize/reserve/clear -// - replace -// - back/front -// - iterator stuff -// - find & friends -// - push_back/pop_back -// - assign/insert/erase -// - relational operators as free functions - taking const char* as one of the params -class DOCTEST_INTERFACE String -{ -public: - using size_type = DOCTEST_CONFIG_STRING_SIZE_TYPE; - -private: - static DOCTEST_CONSTEXPR size_type len = 24; //!OCLINT avoid private static members - static DOCTEST_CONSTEXPR size_type last = len - 1; //!OCLINT avoid private static members - - struct view // len should be more than sizeof(view) - because of the final byte for flags - { - char* ptr; - size_type size; - size_type capacity; - }; - - union - { - char buf[len]; // NOLINT(*-avoid-c-arrays) - view data; - }; - - char* allocate(size_type sz); - - bool isOnStack() const noexcept { return (buf[last] & 128) == 0; } - void setOnHeap() noexcept; - void setLast(size_type in = last) noexcept; - void setSize(size_type sz) noexcept; - - void copy(const String& other); - -public: - static DOCTEST_CONSTEXPR size_type npos = static_cast(-1); - - String() noexcept; - ~String(); - - // cppcheck-suppress noExplicitConstructor - String(const char* in); - String(const char* in, size_type in_size); - - String(std::istream& in, size_type in_size); - - String(const String& other); - String& operator=(const String& other); - - String& operator+=(const String& other); - - String(String&& other) noexcept; - String& operator=(String&& other) noexcept; - - char operator[](size_type i) const; - char& operator[](size_type i); - - // the only functions I'm willing to leave in the interface - available for inlining - const char* c_str() const { return const_cast(this)->c_str(); } // NOLINT - char* c_str() { - if (isOnStack()) { - return reinterpret_cast(buf); - } - return data.ptr; - } - - size_type size() const; - size_type capacity() const; - - String substr(size_type pos, size_type cnt = npos) &&; - String substr(size_type pos, size_type cnt = npos) const &; - - size_type find(char ch, size_type pos = 0) const; - size_type rfind(char ch, size_type pos = npos) const; - - int compare(const char* other, bool no_case = false) const; - int compare(const String& other, bool no_case = false) const; - -friend DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); -}; - -DOCTEST_INTERFACE String operator+(const String& lhs, const String& rhs); - -DOCTEST_INTERFACE bool operator==(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator!=(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator<(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator>(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator<=(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator>=(const String& lhs, const String& rhs); - -class DOCTEST_INTERFACE Contains { -public: - explicit Contains(const String& string); - - bool checkWith(const String& other) const; - - String string; -}; - -DOCTEST_INTERFACE String toString(const Contains& in); - -DOCTEST_INTERFACE bool operator==(const String& lhs, const Contains& rhs); -DOCTEST_INTERFACE bool operator==(const Contains& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator!=(const String& lhs, const Contains& rhs); -DOCTEST_INTERFACE bool operator!=(const Contains& lhs, const String& rhs); - -namespace Color { - enum Enum - { - None = 0, - White, - Red, - Green, - Blue, - Cyan, - Yellow, - Grey, - - Bright = 0x10, - - BrightRed = Bright | Red, - BrightGreen = Bright | Green, - LightGrey = Bright | Grey, - BrightWhite = Bright | White - }; - - DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, Color::Enum code); -} // namespace Color - -namespace assertType { - enum Enum - { - // macro traits - - is_warn = 1, - is_check = 2 * is_warn, - is_require = 2 * is_check, - - is_normal = 2 * is_require, - is_throws = 2 * is_normal, - is_throws_as = 2 * is_throws, - is_throws_with = 2 * is_throws_as, - is_nothrow = 2 * is_throws_with, - - is_false = 2 * is_nothrow, - is_unary = 2 * is_false, // not checked anywhere - used just to distinguish the types - - is_eq = 2 * is_unary, - is_ne = 2 * is_eq, - - is_lt = 2 * is_ne, - is_gt = 2 * is_lt, - - is_ge = 2 * is_gt, - is_le = 2 * is_ge, - - // macro types - - DT_WARN = is_normal | is_warn, - DT_CHECK = is_normal | is_check, - DT_REQUIRE = is_normal | is_require, - - DT_WARN_FALSE = is_normal | is_false | is_warn, - DT_CHECK_FALSE = is_normal | is_false | is_check, - DT_REQUIRE_FALSE = is_normal | is_false | is_require, - - DT_WARN_THROWS = is_throws | is_warn, - DT_CHECK_THROWS = is_throws | is_check, - DT_REQUIRE_THROWS = is_throws | is_require, - - DT_WARN_THROWS_AS = is_throws_as | is_warn, - DT_CHECK_THROWS_AS = is_throws_as | is_check, - DT_REQUIRE_THROWS_AS = is_throws_as | is_require, - - DT_WARN_THROWS_WITH = is_throws_with | is_warn, - DT_CHECK_THROWS_WITH = is_throws_with | is_check, - DT_REQUIRE_THROWS_WITH = is_throws_with | is_require, - - DT_WARN_THROWS_WITH_AS = is_throws_with | is_throws_as | is_warn, - DT_CHECK_THROWS_WITH_AS = is_throws_with | is_throws_as | is_check, - DT_REQUIRE_THROWS_WITH_AS = is_throws_with | is_throws_as | is_require, - - DT_WARN_NOTHROW = is_nothrow | is_warn, - DT_CHECK_NOTHROW = is_nothrow | is_check, - DT_REQUIRE_NOTHROW = is_nothrow | is_require, - - DT_WARN_EQ = is_normal | is_eq | is_warn, - DT_CHECK_EQ = is_normal | is_eq | is_check, - DT_REQUIRE_EQ = is_normal | is_eq | is_require, - - DT_WARN_NE = is_normal | is_ne | is_warn, - DT_CHECK_NE = is_normal | is_ne | is_check, - DT_REQUIRE_NE = is_normal | is_ne | is_require, - - DT_WARN_GT = is_normal | is_gt | is_warn, - DT_CHECK_GT = is_normal | is_gt | is_check, - DT_REQUIRE_GT = is_normal | is_gt | is_require, - - DT_WARN_LT = is_normal | is_lt | is_warn, - DT_CHECK_LT = is_normal | is_lt | is_check, - DT_REQUIRE_LT = is_normal | is_lt | is_require, - - DT_WARN_GE = is_normal | is_ge | is_warn, - DT_CHECK_GE = is_normal | is_ge | is_check, - DT_REQUIRE_GE = is_normal | is_ge | is_require, - - DT_WARN_LE = is_normal | is_le | is_warn, - DT_CHECK_LE = is_normal | is_le | is_check, - DT_REQUIRE_LE = is_normal | is_le | is_require, - - DT_WARN_UNARY = is_normal | is_unary | is_warn, - DT_CHECK_UNARY = is_normal | is_unary | is_check, - DT_REQUIRE_UNARY = is_normal | is_unary | is_require, - - DT_WARN_UNARY_FALSE = is_normal | is_false | is_unary | is_warn, - DT_CHECK_UNARY_FALSE = is_normal | is_false | is_unary | is_check, - DT_REQUIRE_UNARY_FALSE = is_normal | is_false | is_unary | is_require, - }; -} // namespace assertType - -DOCTEST_INTERFACE const char* assertString(assertType::Enum at); -DOCTEST_INTERFACE const char* failureString(assertType::Enum at); -DOCTEST_INTERFACE const char* skipPathFromFilename(const char* file); - -struct DOCTEST_INTERFACE TestCaseData -{ - String m_file; // the file in which the test was registered (using String - see #350) - unsigned m_line; // the line where the test was registered - const char* m_name; // name of the test case - const char* m_test_suite; // the test suite in which the test was added - const char* m_description; - bool m_skip; - bool m_no_breaks; - bool m_no_output; - bool m_may_fail; - bool m_should_fail; - int m_expected_failures; - double m_timeout; -}; - -struct DOCTEST_INTERFACE AssertData -{ - // common - for all asserts - const TestCaseData* m_test_case; - assertType::Enum m_at; - const char* m_file; - int m_line; - const char* m_expr; - bool m_failed; - - // exception-related - for all asserts - bool m_threw; - String m_exception; - - // for normal asserts - String m_decomp; - - // for specific exception-related asserts - bool m_threw_as; - const char* m_exception_type; - - class DOCTEST_INTERFACE StringContains { - private: - Contains content; - bool isContains; - - public: - StringContains(const String& str) : content(str), isContains(false) { } - StringContains(Contains cntn) : content(static_cast(cntn)), isContains(true) { } - - bool check(const String& str) { return isContains ? (content == str) : (content.string == str); } - - operator const String&() const { return content.string; } - - const char* c_str() const { return content.string.c_str(); } - } m_exception_string; - - AssertData(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const StringContains& exception_string); -}; - -struct DOCTEST_INTERFACE MessageData -{ - String m_string; - const char* m_file; - int m_line; - assertType::Enum m_severity; -}; - -struct DOCTEST_INTERFACE SubcaseSignature -{ - String m_name; - const char* m_file; - int m_line; - - bool operator==(const SubcaseSignature& other) const; - bool operator<(const SubcaseSignature& other) const; -}; - -struct DOCTEST_INTERFACE IContextScope -{ - DOCTEST_DECLARE_INTERFACE(IContextScope) - virtual void stringify(std::ostream*) const = 0; -}; - -namespace detail { - struct DOCTEST_INTERFACE TestCase; -} // namespace detail - -struct ContextOptions //!OCLINT too many fields -{ - std::ostream* cout = nullptr; // stdout stream - String binary_name; // the test binary name - - const detail::TestCase* currentTest = nullptr; - - // == parameters from the command line - String out; // output filename - String order_by; // how tests should be ordered - unsigned rand_seed; // the seed for rand ordering - - unsigned first; // the first (matching) test to be executed - unsigned last; // the last (matching) test to be executed - - int abort_after; // stop tests after this many failed assertions - int subcase_filter_levels; // apply the subcase filters for the first N levels - - bool success; // include successful assertions in output - bool case_sensitive; // if filtering should be case sensitive - bool exit; // if the program should be exited after the tests are ran/whatever - bool duration; // print the time duration of each test case - bool minimal; // minimal console output (only test failures) - bool quiet; // no console output - bool no_throw; // to skip exceptions-related assertion macros - bool no_exitcode; // if the framework should return 0 as the exitcode - bool no_run; // to not run the tests at all (can be done with an "*" exclude) - bool no_intro; // to not print the intro of the framework - bool no_version; // to not print the version of the framework - bool no_colors; // if output to the console should be colorized - bool force_colors; // forces the use of colors even when a tty cannot be detected - bool no_breaks; // to not break into the debugger - bool no_skip; // don't skip test cases which are marked to be skipped - bool gnu_file_line; // if line numbers should be surrounded with :x: and not (x): - bool no_path_in_filenames; // if the path to files should be removed from the output - bool no_line_numbers; // if source code line numbers should be omitted from the output - bool no_debug_output; // no output in the debug console when a debugger is attached - bool no_skipped_summary; // don't print "skipped" in the summary !!! UNDOCUMENTED !!! - bool no_time_in_output; // omit any time/timestamps from output !!! UNDOCUMENTED !!! - - bool help; // to print the help - bool version; // to print the version - bool count; // if only the count of matching tests is to be retrieved - bool list_test_cases; // to list all tests matching the filters - bool list_test_suites; // to list all suites matching the filters - bool list_reporters; // lists all registered reporters -}; - -namespace detail { - namespace types { -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - using namespace std; -#else - template - struct enable_if { }; - - template - struct enable_if { using type = T; }; - - struct true_type { static DOCTEST_CONSTEXPR bool value = true; }; - struct false_type { static DOCTEST_CONSTEXPR bool value = false; }; - - template struct remove_reference { using type = T; }; - template struct remove_reference { using type = T; }; - template struct remove_reference { using type = T; }; - - template struct is_rvalue_reference : false_type { }; - template struct is_rvalue_reference : true_type { }; - - template struct remove_const { using type = T; }; - template struct remove_const { using type = T; }; - - // Compiler intrinsics - template struct is_enum { static DOCTEST_CONSTEXPR bool value = __is_enum(T); }; - template struct underlying_type { using type = __underlying_type(T); }; - - template struct is_pointer : false_type { }; - template struct is_pointer : true_type { }; - - template struct is_array : false_type { }; - // NOLINTNEXTLINE(*-avoid-c-arrays) - template struct is_array : true_type { }; -#endif - } - - // - template - T&& declval(); - - template - DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type& t) DOCTEST_NOEXCEPT { - return static_cast(t); - } - - template - DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type&& t) DOCTEST_NOEXCEPT { - return static_cast(t); - } - - template - struct deferred_false : types::false_type { }; - -// MSVS 2015 :( -#if defined(_MSC_VER) && _MSC_VER <= 1900 - template - struct has_global_insertion_operator : types::false_type { }; - - template - struct has_global_insertion_operator(), declval()), void())> : types::true_type { }; - - template - struct has_insertion_operator { static DOCTEST_CONSTEXPR bool value = has_global_insertion_operator::value; }; - - template - struct insert_hack; - - template - struct insert_hack { - static void insert(std::ostream& os, const T& t) { ::operator<<(os, t); } - }; - - template - struct insert_hack { - static void insert(std::ostream& os, const T& t) { operator<<(os, t); } - }; - - template - using insert_hack_t = insert_hack::value>; -#else - template - struct has_insertion_operator : types::false_type { }; -#endif - -template -struct has_insertion_operator(), declval()), void())> : types::true_type { }; - - DOCTEST_INTERFACE std::ostream* tlssPush(); - DOCTEST_INTERFACE String tlssPop(); - - template - struct StringMakerBase { - template - static String convert(const DOCTEST_REF_WRAP(T)) { -#ifdef DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES - static_assert(deferred_false::value, "No stringification detected for type T. See string conversion manual"); -#endif - return "{?}"; - } - }; - - template - struct filldata; - - template - void filloss(std::ostream* stream, const T& in) { - filldata::fill(stream, in); - } - - template - void filloss(std::ostream* stream, const T (&in)[N]) { // NOLINT(*-avoid-c-arrays) - // T[N], T(&)[N], T(&&)[N] have same behaviour. - // Hence remove reference. - filloss::type>(stream, in); - } - - template - String toStream(const T& in) { - std::ostream* stream = tlssPush(); - filloss(stream, in); - return tlssPop(); - } - - template <> - struct StringMakerBase { - template - static String convert(const DOCTEST_REF_WRAP(T) in) { - return toStream(in); - } - }; -} // namespace detail - -template -struct StringMaker : public detail::StringMakerBase< - detail::has_insertion_operator::value || detail::types::is_pointer::value || detail::types::is_array::value> -{}; - -#ifndef DOCTEST_STRINGIFY -#ifdef DOCTEST_CONFIG_DOUBLE_STRINGIFY -#define DOCTEST_STRINGIFY(...) toString(toString(__VA_ARGS__)) -#else -#define DOCTEST_STRINGIFY(...) toString(__VA_ARGS__) -#endif -#endif - -template -String toString() { -#if DOCTEST_MSVC >= 0 && DOCTEST_CLANG == 0 && DOCTEST_GCC == 0 - String ret = __FUNCSIG__; // class doctest::String __cdecl doctest::toString(void) - String::size_type beginPos = ret.find('<'); - return ret.substr(beginPos + 1, ret.size() - beginPos - static_cast(sizeof(">(void)"))); -#else - String ret = __PRETTY_FUNCTION__; // doctest::String toString() [with T = TYPE] - String::size_type begin = ret.find('=') + 2; - return ret.substr(begin, ret.size() - begin - 1); -#endif -} - -template ::value, bool>::type = true> -String toString(const DOCTEST_REF_WRAP(T) value) { - return StringMaker::convert(value); -} - -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -DOCTEST_INTERFACE String toString(const char* in); -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - -#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 -DOCTEST_INTERFACE String toString(const std::string& in); -#endif // VS 2019 - -DOCTEST_INTERFACE String toString(String in); - -DOCTEST_INTERFACE String toString(std::nullptr_t); - -DOCTEST_INTERFACE String toString(bool in); - -DOCTEST_INTERFACE String toString(float in); -DOCTEST_INTERFACE String toString(double in); -DOCTEST_INTERFACE String toString(double long in); - -DOCTEST_INTERFACE String toString(char in); -DOCTEST_INTERFACE String toString(char signed in); -DOCTEST_INTERFACE String toString(char unsigned in); -DOCTEST_INTERFACE String toString(short in); -DOCTEST_INTERFACE String toString(short unsigned in); -DOCTEST_INTERFACE String toString(signed in); -DOCTEST_INTERFACE String toString(unsigned in); -DOCTEST_INTERFACE String toString(long in); -DOCTEST_INTERFACE String toString(long unsigned in); -DOCTEST_INTERFACE String toString(long long in); -DOCTEST_INTERFACE String toString(long long unsigned in); - -template ::value, bool>::type = true> -String toString(const DOCTEST_REF_WRAP(T) value) { - using UT = typename detail::types::underlying_type::type; - return (DOCTEST_STRINGIFY(static_cast(value))); -} - -namespace detail { - template - struct filldata - { - static void fill(std::ostream* stream, const T& in) { -#if defined(_MSC_VER) && _MSC_VER <= 1900 - insert_hack_t::insert(*stream, in); -#else - operator<<(*stream, in); -#endif - } - }; - -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) -// NOLINTBEGIN(*-avoid-c-arrays) - template - struct filldata { - static void fill(std::ostream* stream, const T(&in)[N]) { - *stream << "["; - for (size_t i = 0; i < N; i++) { - if (i != 0) { *stream << ", "; } - *stream << (DOCTEST_STRINGIFY(in[i])); - } - *stream << "]"; - } - }; -// NOLINTEND(*-avoid-c-arrays) -DOCTEST_MSVC_SUPPRESS_WARNING_POP - - // Specialized since we don't want the terminating null byte! -// NOLINTBEGIN(*-avoid-c-arrays) - template - struct filldata { - static void fill(std::ostream* stream, const char (&in)[N]) { - *stream << String(in, in[N - 1] ? N : N - 1); - } // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - }; -// NOLINTEND(*-avoid-c-arrays) - - template <> - struct filldata { - static void fill(std::ostream* stream, const void* in); - }; - - template - struct filldata { - static void fill(std::ostream* stream, const T* in) { - filldata::fill(stream, in); - } - }; -} - -struct DOCTEST_INTERFACE Approx -{ - Approx(double value); - - Approx operator()(double value) const; - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - template - explicit Approx(const T& value, - typename detail::types::enable_if::value>::type* = - static_cast(nullptr)) { - *this = static_cast(value); - } -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - Approx& epsilon(double newEpsilon); - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - template - typename std::enable_if::value, Approx&>::type epsilon( - const T& newEpsilon) { - m_epsilon = static_cast(newEpsilon); - return *this; - } -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - Approx& scale(double newScale); - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - template - typename std::enable_if::value, Approx&>::type scale( - const T& newScale) { - m_scale = static_cast(newScale); - return *this; - } -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - // clang-format off - DOCTEST_INTERFACE friend bool operator==(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator==(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator!=(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator!=(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator<=(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator<=(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator>=(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator>=(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator< (double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator< (const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator> (double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator> (const Approx & lhs, double rhs); - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#define DOCTEST_APPROX_PREFIX \ - template friend typename std::enable_if::value, bool>::type - - DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(static_cast(lhs), rhs); } - DOCTEST_APPROX_PREFIX operator==(const Approx& lhs, const T& rhs) { return operator==(rhs, lhs); } - DOCTEST_APPROX_PREFIX operator!=(const T& lhs, const Approx& rhs) { return !operator==(lhs, rhs); } - DOCTEST_APPROX_PREFIX operator!=(const Approx& lhs, const T& rhs) { return !operator==(rhs, lhs); } - DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) && lhs != rhs; } -#undef DOCTEST_APPROX_PREFIX -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - // clang-format on - - double m_epsilon; - double m_scale; - double m_value; -}; - -DOCTEST_INTERFACE String toString(const Approx& in); - -DOCTEST_INTERFACE const ContextOptions* getContextOptions(); - -template -struct DOCTEST_INTERFACE_DECL IsNaN -{ - F value; bool flipped; - IsNaN(F f, bool flip = false) : value(f), flipped(flip) { } - IsNaN operator!() const { return { value, !flipped }; } - operator bool() const; -}; -#ifndef __MINGW32__ -extern template struct DOCTEST_INTERFACE_DECL IsNaN; -extern template struct DOCTEST_INTERFACE_DECL IsNaN; -extern template struct DOCTEST_INTERFACE_DECL IsNaN; -#endif -DOCTEST_INTERFACE String toString(IsNaN in); -DOCTEST_INTERFACE String toString(IsNaN in); -DOCTEST_INTERFACE String toString(IsNaN in); - -#ifndef DOCTEST_CONFIG_DISABLE - -namespace detail { - // clang-format off -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - template struct decay_array { using type = T; }; - template struct decay_array { using type = T*; }; - template struct decay_array { using type = T*; }; - - template struct not_char_pointer { static DOCTEST_CONSTEXPR value = 1; }; - template<> struct not_char_pointer { static DOCTEST_CONSTEXPR value = 0; }; - template<> struct not_char_pointer { static DOCTEST_CONSTEXPR value = 0; }; - - template struct can_use_op : public not_char_pointer::type> {}; -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - // clang-format on - - struct DOCTEST_INTERFACE TestFailureException - { - }; - - DOCTEST_INTERFACE bool checkIfShouldThrow(assertType::Enum at); - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - DOCTEST_NORETURN -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - DOCTEST_INTERFACE void throwException(); - - struct DOCTEST_INTERFACE Subcase - { - SubcaseSignature m_signature; - bool m_entered = false; - - Subcase(const String& name, const char* file, int line); - Subcase(const Subcase&) = delete; - Subcase(Subcase&&) = delete; - Subcase& operator=(const Subcase&) = delete; - Subcase& operator=(Subcase&&) = delete; - ~Subcase(); - - operator bool() const; - - private: - bool checkFilters(); - }; - - template - String stringifyBinaryExpr(const DOCTEST_REF_WRAP(L) lhs, const char* op, - const DOCTEST_REF_WRAP(R) rhs) { - return (DOCTEST_STRINGIFY(lhs)) + op + (DOCTEST_STRINGIFY(rhs)); - } - -#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) -DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") -#endif - -// This will check if there is any way it could find a operator like member or friend and uses it. -// If not it doesn't find the operator or if the operator at global scope is defined after -// this template, the template won't be instantiated due to SFINAE. Once the template is not -// instantiated it can look for global operator using normal conversions. -#define SFINAE_OP(ret,op) decltype((void)(doctest::detail::declval() op doctest::detail::declval()),ret{}) - -#define DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(op, op_str, op_macro) \ - template \ - DOCTEST_NOINLINE SFINAE_OP(Result,op) operator op(R&& rhs) { \ - bool res = op_macro(doctest::detail::forward(lhs), doctest::detail::forward(rhs)); \ - if(m_at & assertType::is_false) \ - res = !res; \ - if(!res || doctest::getContextOptions()->success) \ - return Result(res, stringifyBinaryExpr(lhs, op_str, rhs)); \ - return Result(res); \ - } - - // more checks could be added - like in Catch: - // https://github.com/catchorg/Catch2/pull/1480/files - // https://github.com/catchorg/Catch2/pull/1481/files -#define DOCTEST_FORBIT_EXPRESSION(rt, op) \ - template \ - rt& operator op(const R&) { \ - static_assert(deferred_false::value, \ - "Expression Too Complex Please Rewrite As Binary Comparison!"); \ - return *this; \ - } - - struct DOCTEST_INTERFACE Result // NOLINT(*-member-init) - { - bool m_passed; - String m_decomp; - - Result() = default; // TODO: Why do we need this? (To remove NOLINT) - Result(bool passed, const String& decomposition = String()); - - // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence - DOCTEST_FORBIT_EXPRESSION(Result, &) - DOCTEST_FORBIT_EXPRESSION(Result, ^) - DOCTEST_FORBIT_EXPRESSION(Result, |) - DOCTEST_FORBIT_EXPRESSION(Result, &&) - DOCTEST_FORBIT_EXPRESSION(Result, ||) - DOCTEST_FORBIT_EXPRESSION(Result, ==) - DOCTEST_FORBIT_EXPRESSION(Result, !=) - DOCTEST_FORBIT_EXPRESSION(Result, <) - DOCTEST_FORBIT_EXPRESSION(Result, >) - DOCTEST_FORBIT_EXPRESSION(Result, <=) - DOCTEST_FORBIT_EXPRESSION(Result, >=) - DOCTEST_FORBIT_EXPRESSION(Result, =) - DOCTEST_FORBIT_EXPRESSION(Result, +=) - DOCTEST_FORBIT_EXPRESSION(Result, -=) - DOCTEST_FORBIT_EXPRESSION(Result, *=) - DOCTEST_FORBIT_EXPRESSION(Result, /=) - DOCTEST_FORBIT_EXPRESSION(Result, %=) - DOCTEST_FORBIT_EXPRESSION(Result, <<=) - DOCTEST_FORBIT_EXPRESSION(Result, >>=) - DOCTEST_FORBIT_EXPRESSION(Result, &=) - DOCTEST_FORBIT_EXPRESSION(Result, ^=) - DOCTEST_FORBIT_EXPRESSION(Result, |=) - }; - -#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - - DOCTEST_CLANG_SUPPRESS_WARNING_PUSH - DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") - DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-compare") - //DOCTEST_CLANG_SUPPRESS_WARNING("-Wdouble-promotion") - //DOCTEST_CLANG_SUPPRESS_WARNING("-Wconversion") - //DOCTEST_CLANG_SUPPRESS_WARNING("-Wfloat-equal") - - DOCTEST_GCC_SUPPRESS_WARNING_PUSH - DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") - DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-compare") - //DOCTEST_GCC_SUPPRESS_WARNING("-Wdouble-promotion") - //DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") - //DOCTEST_GCC_SUPPRESS_WARNING("-Wfloat-equal") - - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH - // https://stackoverflow.com/questions/39479163 what's the difference between 4018 and 4389 - DOCTEST_MSVC_SUPPRESS_WARNING(4388) // signed/unsigned mismatch - DOCTEST_MSVC_SUPPRESS_WARNING(4389) // 'operator' : signed/unsigned mismatch - DOCTEST_MSVC_SUPPRESS_WARNING(4018) // 'expression' : signed/unsigned mismatch - //DOCTEST_MSVC_SUPPRESS_WARNING(4805) // 'operation' : unsafe mix of type 'type' and type 'type' in operation - -#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - - // clang-format off -#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_COMPARISON_RETURN_TYPE bool -#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_COMPARISON_RETURN_TYPE typename types::enable_if::value || can_use_op::value, bool>::type - inline bool eq(const char* lhs, const char* rhs) { return String(lhs) == String(rhs); } - inline bool ne(const char* lhs, const char* rhs) { return String(lhs) != String(rhs); } - inline bool lt(const char* lhs, const char* rhs) { return String(lhs) < String(rhs); } - inline bool gt(const char* lhs, const char* rhs) { return String(lhs) > String(rhs); } - inline bool le(const char* lhs, const char* rhs) { return String(lhs) <= String(rhs); } - inline bool ge(const char* lhs, const char* rhs) { return String(lhs) >= String(rhs); } -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - // clang-format on - -#define DOCTEST_RELATIONAL_OP(name, op) \ - template \ - DOCTEST_COMPARISON_RETURN_TYPE name(const DOCTEST_REF_WRAP(L) lhs, \ - const DOCTEST_REF_WRAP(R) rhs) { \ - return lhs op rhs; \ - } - - DOCTEST_RELATIONAL_OP(eq, ==) - DOCTEST_RELATIONAL_OP(ne, !=) - DOCTEST_RELATIONAL_OP(lt, <) - DOCTEST_RELATIONAL_OP(gt, >) - DOCTEST_RELATIONAL_OP(le, <=) - DOCTEST_RELATIONAL_OP(ge, >=) - -#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_CMP_EQ(l, r) l == r -#define DOCTEST_CMP_NE(l, r) l != r -#define DOCTEST_CMP_GT(l, r) l > r -#define DOCTEST_CMP_LT(l, r) l < r -#define DOCTEST_CMP_GE(l, r) l >= r -#define DOCTEST_CMP_LE(l, r) l <= r -#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_CMP_EQ(l, r) eq(l, r) -#define DOCTEST_CMP_NE(l, r) ne(l, r) -#define DOCTEST_CMP_GT(l, r) gt(l, r) -#define DOCTEST_CMP_LT(l, r) lt(l, r) -#define DOCTEST_CMP_GE(l, r) ge(l, r) -#define DOCTEST_CMP_LE(l, r) le(l, r) -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - - template - // cppcheck-suppress copyCtorAndEqOperator - struct Expression_lhs - { - L lhs; - assertType::Enum m_at; - - explicit Expression_lhs(L&& in, assertType::Enum at) - : lhs(static_cast(in)) - , m_at(at) {} - - DOCTEST_NOINLINE operator Result() { -// this is needed only for MSVC 2015 -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4800) // 'int': forcing value to bool - bool res = static_cast(lhs); -DOCTEST_MSVC_SUPPRESS_WARNING_POP - if(m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional - res = !res; - } - - if(!res || getContextOptions()->success) { - return { res, (DOCTEST_STRINGIFY(lhs)) }; - } - return { res }; - } - - /* This is required for user-defined conversions from Expression_lhs to L */ - operator L() const { return lhs; } - - // clang-format off - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(==, " == ", DOCTEST_CMP_EQ) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(!=, " != ", DOCTEST_CMP_NE) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>, " > ", DOCTEST_CMP_GT) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<, " < ", DOCTEST_CMP_LT) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>=, " >= ", DOCTEST_CMP_GE) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<=, " <= ", DOCTEST_CMP_LE) //!OCLINT bitwise operator in conditional - // clang-format on - - // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &&) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ||) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, =) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, +=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, -=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, *=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, /=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, %=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |=) - // these 2 are unfortunate because they should be allowed - they have higher precedence over the comparisons, but the - // ExpressionDecomposer class uses the left shift operator to capture the left operand of the binary expression... - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>) - }; - -#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_MSVC_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP - -#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - -#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) -DOCTEST_CLANG_SUPPRESS_WARNING_POP -#endif - - struct DOCTEST_INTERFACE ExpressionDecomposer - { - assertType::Enum m_at; - - ExpressionDecomposer(assertType::Enum at); - - // The right operator for capturing expressions is "<=" instead of "<<" (based on the operator precedence table) - // but then there will be warnings from GCC about "-Wparentheses" and since "_Pragma()" is problematic this will stay for now... - // https://github.com/catchorg/Catch2/issues/870 - // https://github.com/catchorg/Catch2/issues/565 - template - Expression_lhs operator<<(L&& operand) { - return Expression_lhs(static_cast(operand), m_at); - } - - template ::value,void >::type* = nullptr> - Expression_lhs operator<<(const L &operand) { - return Expression_lhs(operand, m_at); - } - }; - - struct DOCTEST_INTERFACE TestSuite - { - const char* m_test_suite = nullptr; - const char* m_description = nullptr; - bool m_skip = false; - bool m_no_breaks = false; - bool m_no_output = false; - bool m_may_fail = false; - bool m_should_fail = false; - int m_expected_failures = 0; - double m_timeout = 0; - - TestSuite& operator*(const char* in); - - template - TestSuite& operator*(const T& in) { - in.fill(*this); - return *this; - } - }; - - using funcType = void (*)(); - - struct DOCTEST_INTERFACE TestCase : public TestCaseData - { - funcType m_test; // a function pointer to the test case - - String m_type; // for templated test cases - gets appended to the real name - int m_template_id; // an ID used to distinguish between the different versions of a templated test case - String m_full_name; // contains the name (only for templated test cases!) + the template type - - TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, - const String& type = String(), int template_id = -1); - - TestCase(const TestCase& other); - TestCase(TestCase&&) = delete; - - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function - TestCase& operator=(const TestCase& other); - DOCTEST_MSVC_SUPPRESS_WARNING_POP - - TestCase& operator=(TestCase&&) = delete; - - TestCase& operator*(const char* in); - - template - TestCase& operator*(const T& in) { - in.fill(*this); - return *this; - } - - bool operator<(const TestCase& other) const; - - ~TestCase() = default; - }; - - // forward declarations of functions used by the macros - DOCTEST_INTERFACE int regTest(const TestCase& tc); - DOCTEST_INTERFACE int setTestSuite(const TestSuite& ts); - DOCTEST_INTERFACE bool isDebuggerActive(); - - template - int instantiationHelper(const T&) { return 0; } - - namespace binaryAssertComparison { - enum Enum - { - eq = 0, - ne, - gt, - lt, - ge, - le - }; - } // namespace binaryAssertComparison - - // clang-format off - template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L), const DOCTEST_REF_WRAP(R) ) const { return false; } }; - -#define DOCTEST_BINARY_RELATIONAL_OP(n, op) \ - template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) const { return op(lhs, rhs); } }; - // clang-format on - - DOCTEST_BINARY_RELATIONAL_OP(0, doctest::detail::eq) - DOCTEST_BINARY_RELATIONAL_OP(1, doctest::detail::ne) - DOCTEST_BINARY_RELATIONAL_OP(2, doctest::detail::gt) - DOCTEST_BINARY_RELATIONAL_OP(3, doctest::detail::lt) - DOCTEST_BINARY_RELATIONAL_OP(4, doctest::detail::ge) - DOCTEST_BINARY_RELATIONAL_OP(5, doctest::detail::le) - - struct DOCTEST_INTERFACE ResultBuilder : public AssertData - { - ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type = "", const String& exception_string = ""); - - ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const Contains& exception_string); - - void setResult(const Result& res); - - template - DOCTEST_NOINLINE bool binary_assert(const DOCTEST_REF_WRAP(L) lhs, - const DOCTEST_REF_WRAP(R) rhs) { - m_failed = !RelationalComparator()(lhs, rhs); - if (m_failed || getContextOptions()->success) { - m_decomp = stringifyBinaryExpr(lhs, ", ", rhs); - } - return !m_failed; - } - - template - DOCTEST_NOINLINE bool unary_assert(const DOCTEST_REF_WRAP(L) val) { - m_failed = !val; - - if (m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional - m_failed = !m_failed; - } - - if (m_failed || getContextOptions()->success) { - m_decomp = (DOCTEST_STRINGIFY(val)); - } - - return !m_failed; - } - - void translateException(); - - bool log(); - void react() const; - }; - - namespace assertAction { - enum Enum - { - nothing = 0, - dbgbreak = 1, - shouldthrow = 2 - }; - } // namespace assertAction - - DOCTEST_INTERFACE void failed_out_of_a_testing_context(const AssertData& ad); - - DOCTEST_INTERFACE bool decomp_assert(assertType::Enum at, const char* file, int line, - const char* expr, const Result& result); - -#define DOCTEST_ASSERT_OUT_OF_TESTS(decomp) \ - do { \ - if(!is_running_in_test) { \ - if(failed) { \ - ResultBuilder rb(at, file, line, expr); \ - rb.m_failed = failed; \ - rb.m_decomp = decomp; \ - failed_out_of_a_testing_context(rb); \ - if(isDebuggerActive() && !getContextOptions()->no_breaks) \ - DOCTEST_BREAK_INTO_DEBUGGER(); \ - if(checkIfShouldThrow(at)) \ - throwException(); \ - } \ - return !failed; \ - } \ - } while(false) - -#define DOCTEST_ASSERT_IN_TESTS(decomp) \ - ResultBuilder rb(at, file, line, expr); \ - rb.m_failed = failed; \ - if(rb.m_failed || getContextOptions()->success) \ - rb.m_decomp = decomp; \ - if(rb.log()) \ - DOCTEST_BREAK_INTO_DEBUGGER(); \ - if(rb.m_failed && checkIfShouldThrow(at)) \ - throwException() - - template - DOCTEST_NOINLINE bool binary_assert(assertType::Enum at, const char* file, int line, - const char* expr, const DOCTEST_REF_WRAP(L) lhs, - const DOCTEST_REF_WRAP(R) rhs) { - bool failed = !RelationalComparator()(lhs, rhs); - - // ################################################################################### - // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT - // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED - // ################################################################################### - DOCTEST_ASSERT_OUT_OF_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); - DOCTEST_ASSERT_IN_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); - return !failed; - } - - template - DOCTEST_NOINLINE bool unary_assert(assertType::Enum at, const char* file, int line, - const char* expr, const DOCTEST_REF_WRAP(L) val) { - bool failed = !val; - - if(at & assertType::is_false) //!OCLINT bitwise operator in conditional - failed = !failed; - - // ################################################################################### - // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT - // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED - // ################################################################################### - DOCTEST_ASSERT_OUT_OF_TESTS((DOCTEST_STRINGIFY(val))); - DOCTEST_ASSERT_IN_TESTS((DOCTEST_STRINGIFY(val))); - return !failed; - } - - struct DOCTEST_INTERFACE IExceptionTranslator - { - DOCTEST_DECLARE_INTERFACE(IExceptionTranslator) - virtual bool translate(String&) const = 0; - }; - - template - class ExceptionTranslator : public IExceptionTranslator //!OCLINT destructor of virtual class - { - public: - explicit ExceptionTranslator(String (*translateFunction)(T)) - : m_translateFunction(translateFunction) {} - - bool translate(String& res) const override { -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - try { - throw; // lgtm [cpp/rethrow-no-exception] - // cppcheck-suppress catchExceptionByValue - } catch(const T& ex) { - res = m_translateFunction(ex); //!OCLINT parameter reassignment - return true; - } catch(...) {} //!OCLINT - empty catch statement -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - static_cast(res); // to silence -Wunused-parameter - return false; - } - - private: - String (*m_translateFunction)(T); - }; - - DOCTEST_INTERFACE void registerExceptionTranslatorImpl(const IExceptionTranslator* et); - - // ContextScope base class used to allow implementing methods of ContextScope - // that don't depend on the template parameter in doctest.cpp. - struct DOCTEST_INTERFACE ContextScopeBase : public IContextScope { - ContextScopeBase(const ContextScopeBase&) = delete; - - ContextScopeBase& operator=(const ContextScopeBase&) = delete; - ContextScopeBase& operator=(ContextScopeBase&&) = delete; - - ~ContextScopeBase() override = default; - - protected: - ContextScopeBase(); - ContextScopeBase(ContextScopeBase&& other) noexcept; - - void destroy(); - bool need_to_destroy{true}; - }; - - template class ContextScope : public ContextScopeBase - { - L lambda_; - - public: - explicit ContextScope(const L &lambda) : lambda_(lambda) {} - explicit ContextScope(L&& lambda) : lambda_(static_cast(lambda)) { } - - ContextScope(const ContextScope&) = delete; - ContextScope(ContextScope&&) noexcept = default; - - ContextScope& operator=(const ContextScope&) = delete; - ContextScope& operator=(ContextScope&&) = delete; - - void stringify(std::ostream* s) const override { lambda_(s); } - - ~ContextScope() override { - if (need_to_destroy) { - destroy(); - } - } - }; - - struct DOCTEST_INTERFACE MessageBuilder : public MessageData - { - std::ostream* m_stream; - bool logged = false; - - MessageBuilder(const char* file, int line, assertType::Enum severity); - - MessageBuilder(const MessageBuilder&) = delete; - MessageBuilder(MessageBuilder&&) = delete; - - MessageBuilder& operator=(const MessageBuilder&) = delete; - MessageBuilder& operator=(MessageBuilder&&) = delete; - - ~MessageBuilder(); - - // the preferred way of chaining parameters for stringification -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) - template - MessageBuilder& operator,(const T& in) { - *m_stream << (DOCTEST_STRINGIFY(in)); - return *this; - } -DOCTEST_MSVC_SUPPRESS_WARNING_POP - - // kept here just for backwards-compatibility - the comma operator should be preferred now - template - MessageBuilder& operator<<(const T& in) { return this->operator,(in); } - - // the `,` operator has the lowest operator precedence - if `<<` is used by the user then - // the `,` operator will be called last which is not what we want and thus the `*` operator - // is used first (has higher operator precedence compared to `<<`) so that we guarantee that - // an operator of the MessageBuilder class is called first before the rest of the parameters - template - MessageBuilder& operator*(const T& in) { return this->operator,(in); } - - bool log(); - void react(); - }; - - template - ContextScope MakeContextScope(const L &lambda) { - return ContextScope(lambda); - } -} // namespace detail - -#define DOCTEST_DEFINE_DECORATOR(name, type, def) \ - struct name \ - { \ - type data; \ - name(type in = def) \ - : data(in) {} \ - void fill(detail::TestCase& state) const { state.DOCTEST_CAT(m_, name) = data; } \ - void fill(detail::TestSuite& state) const { state.DOCTEST_CAT(m_, name) = data; } \ - } - -DOCTEST_DEFINE_DECORATOR(test_suite, const char*, ""); -DOCTEST_DEFINE_DECORATOR(description, const char*, ""); -DOCTEST_DEFINE_DECORATOR(skip, bool, true); -DOCTEST_DEFINE_DECORATOR(no_breaks, bool, true); -DOCTEST_DEFINE_DECORATOR(no_output, bool, true); -DOCTEST_DEFINE_DECORATOR(timeout, double, 0); -DOCTEST_DEFINE_DECORATOR(may_fail, bool, true); -DOCTEST_DEFINE_DECORATOR(should_fail, bool, true); -DOCTEST_DEFINE_DECORATOR(expected_failures, int, 0); - -template -int registerExceptionTranslator(String (*translateFunction)(T)) { - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") - static detail::ExceptionTranslator exceptionTranslator(translateFunction); - DOCTEST_CLANG_SUPPRESS_WARNING_POP - detail::registerExceptionTranslatorImpl(&exceptionTranslator); - return 0; -} - -} // namespace doctest - -// in a separate namespace outside of doctest because the DOCTEST_TEST_SUITE macro -// introduces an anonymous namespace in which getCurrentTestSuite gets overridden -namespace doctest_detail_test_suite_ns { -DOCTEST_INTERFACE doctest::detail::TestSuite& getCurrentTestSuite(); -} // namespace doctest_detail_test_suite_ns - -namespace doctest { -#else // DOCTEST_CONFIG_DISABLE -template -int registerExceptionTranslator(String (*)(T)) { - return 0; -} -#endif // DOCTEST_CONFIG_DISABLE - -namespace detail { - using assert_handler = void (*)(const AssertData&); - struct ContextState; -} // namespace detail - -class DOCTEST_INTERFACE Context -{ - detail::ContextState* p; - - void parseArgs(int argc, const char* const* argv, bool withDefaults = false); - -public: - explicit Context(int argc = 0, const char* const* argv = nullptr); - - Context(const Context&) = delete; - Context(Context&&) = delete; - - Context& operator=(const Context&) = delete; - Context& operator=(Context&&) = delete; - - ~Context(); // NOLINT(performance-trivially-destructible) - - void applyCommandLine(int argc, const char* const* argv); - - void addFilter(const char* filter, const char* value); - void clearFilters(); - void setOption(const char* option, bool value); - void setOption(const char* option, int value); - void setOption(const char* option, const char* value); - - bool shouldExit(); - - void setAsDefaultForAssertsOutOfTestCases(); - - void setAssertHandler(detail::assert_handler ah); - - void setCout(std::ostream* out); - - int run(); -}; - -namespace TestCaseFailureReason { - enum Enum - { - None = 0, - AssertFailure = 1, // an assertion has failed in the test case - Exception = 2, // test case threw an exception - Crash = 4, // a crash... - TooManyFailedAsserts = 8, // the abort-after option - Timeout = 16, // see the timeout decorator - ShouldHaveFailedButDidnt = 32, // see the should_fail decorator - ShouldHaveFailedAndDid = 64, // see the should_fail decorator - DidntFailExactlyNumTimes = 128, // see the expected_failures decorator - FailedExactlyNumTimes = 256, // see the expected_failures decorator - CouldHaveFailedAndDid = 512 // see the may_fail decorator - }; -} // namespace TestCaseFailureReason - -struct DOCTEST_INTERFACE CurrentTestCaseStats -{ - int numAssertsCurrentTest; - int numAssertsFailedCurrentTest; - double seconds; - int failure_flags; // use TestCaseFailureReason::Enum - bool testCaseSuccess; -}; - -struct DOCTEST_INTERFACE TestCaseException -{ - String error_string; - bool is_crash; -}; - -struct DOCTEST_INTERFACE TestRunStats -{ - unsigned numTestCases; - unsigned numTestCasesPassingFilters; - unsigned numTestSuitesPassingFilters; - unsigned numTestCasesFailed; - int numAsserts; - int numAssertsFailed; -}; - -struct QueryData -{ - const TestRunStats* run_stats = nullptr; - const TestCaseData** data = nullptr; - unsigned num_data = 0; -}; - -struct DOCTEST_INTERFACE IReporter -{ - // The constructor has to accept "const ContextOptions&" as a single argument - // which has most of the options for the run + a pointer to the stdout stream - // Reporter(const ContextOptions& in) - - // called when a query should be reported (listing test cases, printing the version, etc.) - virtual void report_query(const QueryData&) = 0; - - // called when the whole test run starts - virtual void test_run_start() = 0; - // called when the whole test run ends (caching a pointer to the input doesn't make sense here) - virtual void test_run_end(const TestRunStats&) = 0; - - // called when a test case is started (safe to cache a pointer to the input) - virtual void test_case_start(const TestCaseData&) = 0; - // called when a test case is reentered because of unfinished subcases (safe to cache a pointer to the input) - virtual void test_case_reenter(const TestCaseData&) = 0; - // called when a test case has ended - virtual void test_case_end(const CurrentTestCaseStats&) = 0; - - // called when an exception is thrown from the test case (or it crashes) - virtual void test_case_exception(const TestCaseException&) = 0; - - // called whenever a subcase is entered (don't cache pointers to the input) - virtual void subcase_start(const SubcaseSignature&) = 0; - // called whenever a subcase is exited (don't cache pointers to the input) - virtual void subcase_end() = 0; - - // called for each assert (don't cache pointers to the input) - virtual void log_assert(const AssertData&) = 0; - // called for each message (don't cache pointers to the input) - virtual void log_message(const MessageData&) = 0; - - // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator - // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) - virtual void test_case_skipped(const TestCaseData&) = 0; - - DOCTEST_DECLARE_INTERFACE(IReporter) - - // can obtain all currently active contexts and stringify them if one wishes to do so - static int get_num_active_contexts(); - static const IContextScope* const* get_active_contexts(); - - // can iterate through contexts which have been stringified automatically in their destructors when an exception has been thrown - static int get_num_stringified_contexts(); - static const String* get_stringified_contexts(); -}; - -namespace detail { - using reporterCreatorFunc = IReporter* (*)(const ContextOptions&); - - DOCTEST_INTERFACE void registerReporterImpl(const char* name, int prio, reporterCreatorFunc c, bool isReporter); - - template - IReporter* reporterCreator(const ContextOptions& o) { - return new Reporter(o); - } -} // namespace detail - -template -int registerReporter(const char* name, int priority, bool isReporter) { - detail::registerReporterImpl(name, priority, detail::reporterCreator, isReporter); - return 0; -} -} // namespace doctest - -#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES -#define DOCTEST_FUNC_EMPTY [] { return false; }() -#else -#define DOCTEST_FUNC_EMPTY (void)0 -#endif - -// if registering is not disabled -#ifndef DOCTEST_CONFIG_DISABLE - -#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES -#define DOCTEST_FUNC_SCOPE_BEGIN [&] -#define DOCTEST_FUNC_SCOPE_END () -#define DOCTEST_FUNC_SCOPE_RET(v) return v -#else -#define DOCTEST_FUNC_SCOPE_BEGIN do -#define DOCTEST_FUNC_SCOPE_END while(false) -#define DOCTEST_FUNC_SCOPE_RET(v) (void)0 -#endif - -// common code in asserts - for convenience -#define DOCTEST_ASSERT_LOG_REACT_RETURN(b) \ - if(b.log()) DOCTEST_BREAK_INTO_DEBUGGER(); \ - b.react(); \ - DOCTEST_FUNC_SCOPE_RET(!b.m_failed) - -#ifdef DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS -#define DOCTEST_WRAP_IN_TRY(x) x; -#else // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS -#define DOCTEST_WRAP_IN_TRY(x) \ - try { \ - x; \ - } catch(...) { DOCTEST_RB.translateException(); } -#endif // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS - -#ifdef DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS -#define DOCTEST_CAST_TO_VOID(...) \ - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wuseless-cast") \ - static_cast(__VA_ARGS__); \ - DOCTEST_GCC_SUPPRESS_WARNING_POP -#else // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS -#define DOCTEST_CAST_TO_VOID(...) __VA_ARGS__; -#endif // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS - -// registers the test by initializing a dummy var with a function -#define DOCTEST_REGISTER_FUNCTION(global_prefix, f, decorators) \ - global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT */ \ - doctest::detail::regTest( \ - doctest::detail::TestCase( \ - f, __FILE__, __LINE__, \ - doctest_detail_test_suite_ns::getCurrentTestSuite()) * \ - decorators)) - -#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, decorators) \ - namespace { /* NOLINT */ \ - struct der : public base \ - { \ - void f(); \ - }; \ - static inline DOCTEST_NOINLINE void func() { \ - der v; \ - v.f(); \ - } \ - DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, func, decorators) \ - } \ - inline DOCTEST_NOINLINE void der::f() // NOLINT(misc-definitions-in-headers) - -#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, decorators) \ - static void f(); \ - DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, f, decorators) \ - static void f() - -#define DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(f, proxy, decorators) \ - static doctest::detail::funcType proxy() { return f; } \ - DOCTEST_REGISTER_FUNCTION(inline, proxy(), decorators) \ - static void f() - -// for registering tests -#define DOCTEST_TEST_CASE(decorators) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) - -// for registering tests in classes - requires C++17 for inline variables! -#if DOCTEST_CPLUSPLUS >= 201703L -#define DOCTEST_TEST_CASE_CLASS(decorators) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_PROXY_), \ - decorators) -#else // DOCTEST_TEST_CASE_CLASS -#define DOCTEST_TEST_CASE_CLASS(...) \ - TEST_CASES_CAN_BE_REGISTERED_IN_CLASSES_ONLY_IN_CPP17_MODE_OR_WITH_VS_2017_OR_NEWER -#endif // DOCTEST_TEST_CASE_CLASS - -// for registering tests with a fixture -#define DOCTEST_TEST_CASE_FIXTURE(c, decorators) \ - DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), c, \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) - -// for converting types to strings without the header and demangling -#define DOCTEST_TYPE_TO_STRING_AS(str, ...) \ - namespace doctest { \ - template <> \ - inline String toString<__VA_ARGS__>() { \ - return str; \ - } \ - } \ - static_assert(true, "") - -#define DOCTEST_TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING_AS(#__VA_ARGS__, __VA_ARGS__) - -#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, iter, func) \ - template \ - static void func(); \ - namespace { /* NOLINT */ \ - template \ - struct iter; \ - template \ - struct iter> \ - { \ - iter(const char* file, unsigned line, int index) { \ - doctest::detail::regTest(doctest::detail::TestCase(func, file, line, \ - doctest_detail_test_suite_ns::getCurrentTestSuite(), \ - doctest::toString(), \ - int(line) * 1000 + index) \ - * dec); \ - iter>(file, line, index + 1); \ - } \ - }; \ - template <> \ - struct iter> \ - { \ - iter(const char*, unsigned, int) {} \ - }; \ - } \ - template \ - static void func() - -#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(dec, T, id) \ - DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(id, ITERATOR), \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)) - -#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, anon, ...) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY), /* NOLINT(cert-err58-cpp, fuchsia-statically-constructed-objects) */ \ - doctest::detail::instantiationHelper( \ - DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0))) - -#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ - static_assert(true, "") - -#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) \ - static_assert(true, "") - -#define DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, anon, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(anon, ITERATOR), anon); \ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(anon, anon, std::tuple<__VA_ARGS__>) \ - template \ - static void anon() - -#define DOCTEST_TEST_CASE_TEMPLATE(dec, T, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) - -// for subcases -#define DOCTEST_SUBCASE(name) \ - if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ - doctest::detail::Subcase(name, __FILE__, __LINE__)) - -// for grouping tests in test suites by using code blocks -#define DOCTEST_TEST_SUITE_IMPL(decorators, ns_name) \ - namespace ns_name { namespace doctest_detail_test_suite_ns { \ - static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() noexcept { \ - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4640) \ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") \ - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmissing-field-initializers") \ - static doctest::detail::TestSuite data{}; \ - static bool inited = false; \ - DOCTEST_MSVC_SUPPRESS_WARNING_POP \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP \ - DOCTEST_GCC_SUPPRESS_WARNING_POP \ - if(!inited) { \ - data* decorators; \ - inited = true; \ - } \ - return data; \ - } \ - } \ - } \ - namespace ns_name - -#define DOCTEST_TEST_SUITE(decorators) \ - DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(DOCTEST_ANON_SUITE_)) - -// for starting a testsuite block -#define DOCTEST_TEST_SUITE_BEGIN(decorators) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ - doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators)) \ - static_assert(true, "") - -// for ending a testsuite block -#define DOCTEST_TEST_SUITE_END \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ - doctest::detail::setTestSuite(doctest::detail::TestSuite() * "")) \ - using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int - -// for registering exception translators -#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(translatorName, signature) \ - inline doctest::String translatorName(signature); \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), /* NOLINT(cert-err58-cpp) */ \ - doctest::registerExceptionTranslator(translatorName)) \ - doctest::String translatorName(signature) - -#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ - DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), \ - signature) - -// for registering reporters -#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ - doctest::registerReporter(name, priority, true)) \ - static_assert(true, "") - -// for registering listeners -#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ - doctest::registerReporter(name, priority, false)) \ - static_assert(true, "") - -// clang-format off -// for logging - disabling formatting because it's important to have these on 2 separate lines - see PR #557 -#define DOCTEST_INFO(...) \ - DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_), \ - DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_OTHER_), \ - __VA_ARGS__) -// clang-format on - -#define DOCTEST_INFO_IMPL(mb_name, s_name, ...) \ - auto DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope( \ - [&](std::ostream* s_name) { \ - doctest::detail::MessageBuilder mb_name(__FILE__, __LINE__, doctest::assertType::is_warn); \ - mb_name.m_stream = s_name; \ - mb_name * __VA_ARGS__; \ - }) - -#define DOCTEST_CAPTURE(x) DOCTEST_INFO(#x " := ", x) - -#define DOCTEST_ADD_AT_IMPL(type, file, line, mb, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::MessageBuilder mb(file, line, doctest::assertType::type); \ - mb * __VA_ARGS__; \ - if(mb.log()) \ - DOCTEST_BREAK_INTO_DEBUGGER(); \ - mb.react(); \ - } DOCTEST_FUNC_SCOPE_END - -// clang-format off -#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) -#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) -#define DOCTEST_ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) -// clang-format on - -#define DOCTEST_MESSAGE(...) DOCTEST_ADD_MESSAGE_AT(__FILE__, __LINE__, __VA_ARGS__) -#define DOCTEST_FAIL_CHECK(...) DOCTEST_ADD_FAIL_CHECK_AT(__FILE__, __LINE__, __VA_ARGS__) -#define DOCTEST_FAIL(...) DOCTEST_ADD_FAIL_AT(__FILE__, __LINE__, __VA_ARGS__) - -#define DOCTEST_TO_LVALUE(...) __VA_ARGS__ // Not removed to keep backwards compatibility. - -#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -#define DOCTEST_ASSERT_IMPLEMENT_2(assert_type, ...) \ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ - /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY(DOCTEST_RB.setResult( \ - doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ - << __VA_ARGS__)) /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB) \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP - -#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - DOCTEST_ASSERT_IMPLEMENT_2(assert_type, __VA_ARGS__); \ - } DOCTEST_FUNC_SCOPE_END // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - -#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY( \ - DOCTEST_RB.binary_assert( \ - __VA_ARGS__)) \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } DOCTEST_FUNC_SCOPE_END - -#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY(DOCTEST_RB.unary_assert(__VA_ARGS__)) \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } DOCTEST_FUNC_SCOPE_END - -#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -// necessary for _MESSAGE -#define DOCTEST_ASSERT_IMPLEMENT_2 DOCTEST_ASSERT_IMPLEMENT_1 - -#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ - doctest::detail::decomp_assert( \ - doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, \ - doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ - << __VA_ARGS__) DOCTEST_CLANG_SUPPRESS_WARNING_POP - -#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ - doctest::detail::binary_assert( \ - doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) - -#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ - doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ - #__VA_ARGS__, __VA_ARGS__) - -#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -#define DOCTEST_WARN(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN, __VA_ARGS__) -#define DOCTEST_CHECK(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK, __VA_ARGS__) -#define DOCTEST_REQUIRE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE, __VA_ARGS__) -#define DOCTEST_WARN_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN_FALSE, __VA_ARGS__) -#define DOCTEST_CHECK_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK_FALSE, __VA_ARGS__) -#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE_FALSE, __VA_ARGS__) - -// clang-format off -#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } DOCTEST_FUNC_SCOPE_END -// clang-format on - -#define DOCTEST_WARN_EQ(...) DOCTEST_BINARY_ASSERT(DT_WARN_EQ, eq, __VA_ARGS__) -#define DOCTEST_CHECK_EQ(...) DOCTEST_BINARY_ASSERT(DT_CHECK_EQ, eq, __VA_ARGS__) -#define DOCTEST_REQUIRE_EQ(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_EQ, eq, __VA_ARGS__) -#define DOCTEST_WARN_NE(...) DOCTEST_BINARY_ASSERT(DT_WARN_NE, ne, __VA_ARGS__) -#define DOCTEST_CHECK_NE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_NE, ne, __VA_ARGS__) -#define DOCTEST_REQUIRE_NE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_NE, ne, __VA_ARGS__) -#define DOCTEST_WARN_GT(...) DOCTEST_BINARY_ASSERT(DT_WARN_GT, gt, __VA_ARGS__) -#define DOCTEST_CHECK_GT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GT, gt, __VA_ARGS__) -#define DOCTEST_REQUIRE_GT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GT, gt, __VA_ARGS__) -#define DOCTEST_WARN_LT(...) DOCTEST_BINARY_ASSERT(DT_WARN_LT, lt, __VA_ARGS__) -#define DOCTEST_CHECK_LT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LT, lt, __VA_ARGS__) -#define DOCTEST_REQUIRE_LT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LT, lt, __VA_ARGS__) -#define DOCTEST_WARN_GE(...) DOCTEST_BINARY_ASSERT(DT_WARN_GE, ge, __VA_ARGS__) -#define DOCTEST_CHECK_GE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GE, ge, __VA_ARGS__) -#define DOCTEST_REQUIRE_GE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GE, ge, __VA_ARGS__) -#define DOCTEST_WARN_LE(...) DOCTEST_BINARY_ASSERT(DT_WARN_LE, le, __VA_ARGS__) -#define DOCTEST_CHECK_LE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LE, le, __VA_ARGS__) -#define DOCTEST_REQUIRE_LE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LE, le, __VA_ARGS__) - -#define DOCTEST_WARN_UNARY(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY, __VA_ARGS__) -#define DOCTEST_CHECK_UNARY(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY, __VA_ARGS__) -#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY, __VA_ARGS__) -#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY_FALSE, __VA_ARGS__) -#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY_FALSE, __VA_ARGS__) -#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY_FALSE, __VA_ARGS__) - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - -#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - if(!doctest::getContextOptions()->no_throw) { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #expr, #__VA_ARGS__, message); \ - try { \ - DOCTEST_CAST_TO_VOID(expr) \ - } catch(const typename doctest::detail::types::remove_const< \ - typename doctest::detail::types::remove_reference<__VA_ARGS__>::type>::type&) {\ - DOCTEST_RB.translateException(); \ - DOCTEST_RB.m_threw_as = true; \ - } catch(...) { DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } else { /* NOLINT(*-else-after-return) */ \ - DOCTEST_FUNC_SCOPE_RET(false); \ - } \ - } DOCTEST_FUNC_SCOPE_END - -#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - if(!doctest::getContextOptions()->no_throw) { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, expr_str, "", __VA_ARGS__); \ - try { \ - DOCTEST_CAST_TO_VOID(expr) \ - } catch(...) { DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } else { /* NOLINT(*-else-after-return) */ \ - DOCTEST_FUNC_SCOPE_RET(false); \ - } \ - } DOCTEST_FUNC_SCOPE_END - -#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - try { \ - DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ - } catch(...) { DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } DOCTEST_FUNC_SCOPE_END - -// clang-format off -#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") -#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") -#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") - -#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) - -#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) -#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) -#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END -// clang-format on - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -// ================================================================================================= -// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == -// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! == -// ================================================================================================= -#else // DOCTEST_CONFIG_DISABLE - -#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ - namespace /* NOLINT */ { \ - template \ - struct der : public base \ - { void f(); }; \ - } \ - template \ - inline void der::f() - -#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ - template \ - static inline void f() - -// for registering tests -#define DOCTEST_TEST_CASE(name) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) - -// for registering tests in classes -#define DOCTEST_TEST_CASE_CLASS(name) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) - -// for registering tests with a fixture -#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ - DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), x, \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) - -// for converting types to strings without the header and demangling -#define DOCTEST_TYPE_TO_STRING_AS(str, ...) static_assert(true, "") -#define DOCTEST_TYPE_TO_STRING(...) static_assert(true, "") - -// for typed tests -#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ - template \ - inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() - -#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ - template \ - inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() - -#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) static_assert(true, "") -#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) static_assert(true, "") - -// for subcases -#define DOCTEST_SUBCASE(name) - -// for a testsuite block -#define DOCTEST_TEST_SUITE(name) namespace // NOLINT - -// for starting a testsuite block -#define DOCTEST_TEST_SUITE_BEGIN(name) static_assert(true, "") - -// for ending a testsuite block -#define DOCTEST_TEST_SUITE_END using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int - -#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ - template \ - static inline doctest::String DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_)(signature) - -#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) -#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) - -#define DOCTEST_INFO(...) (static_cast(0)) -#define DOCTEST_CAPTURE(x) (static_cast(0)) -#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_ADD_FAIL_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_MESSAGE(...) (static_cast(0)) -#define DOCTEST_FAIL_CHECK(...) (static_cast(0)) -#define DOCTEST_FAIL(...) (static_cast(0)) - -#if defined(DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED) \ - && defined(DOCTEST_CONFIG_ASSERTS_RETURN_VALUES) - -#define DOCTEST_WARN(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_CHECK(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_REQUIRE(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_WARN_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_CHECK_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_FALSE(...) [&] { return !(__VA_ARGS__); }() - -#define DOCTEST_WARN_MESSAGE(cond, ...) [&] { return cond; }() -#define DOCTEST_CHECK_MESSAGE(cond, ...) [&] { return cond; }() -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) [&] { return cond; }() -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() - -namespace doctest { -namespace detail { -#define DOCTEST_RELATIONAL_OP(name, op) \ - template \ - bool name(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) { return lhs op rhs; } - - DOCTEST_RELATIONAL_OP(eq, ==) - DOCTEST_RELATIONAL_OP(ne, !=) - DOCTEST_RELATIONAL_OP(lt, <) - DOCTEST_RELATIONAL_OP(gt, >) - DOCTEST_RELATIONAL_OP(le, <=) - DOCTEST_RELATIONAL_OP(ge, >=) -} // namespace detail -} // namespace doctest - -#define DOCTEST_WARN_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() -#define DOCTEST_CHECK_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() -#define DOCTEST_WARN_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() -#define DOCTEST_CHECK_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() -#define DOCTEST_WARN_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() -#define DOCTEST_CHECK_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() -#define DOCTEST_WARN_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() -#define DOCTEST_CHECK_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() -#define DOCTEST_WARN_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() -#define DOCTEST_CHECK_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() -#define DOCTEST_WARN_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() -#define DOCTEST_CHECK_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() -#define DOCTEST_WARN_UNARY(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_CHECK_UNARY(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_REQUIRE_UNARY(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_WARN_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_CHECK_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - -#define DOCTEST_WARN_THROWS_WITH(expr, with, ...) [] { static_assert(false, "Exception translation is not available when doctest is disabled."); return false; }() -#define DOCTEST_CHECK_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) - -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) - -#define DOCTEST_WARN_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_CHECK_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_REQUIRE_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_WARN_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_CHECK_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_WARN_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_CHECK_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_REQUIRE_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -#else // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED - -#define DOCTEST_WARN(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_FALSE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_FALSE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_EQ(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_EQ(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_EQ(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_NE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_NE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_NE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_GT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_GT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_GT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_LT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_LT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_LT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_GE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_GE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_GE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_LE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_LE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_LE(...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_UNARY(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_UNARY(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - -#define DOCTEST_WARN_THROWS(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_NOTHROW(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -#endif // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED - -#endif // DOCTEST_CONFIG_DISABLE - -#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS - -#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS -#define DOCTEST_EXCEPTION_EMPTY_FUNC DOCTEST_FUNC_EMPTY -#else // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS -#define DOCTEST_EXCEPTION_EMPTY_FUNC [] { static_assert(false, "Exceptions are disabled! " \ - "Use DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS if you want to compile with exceptions disabled."); return false; }() - -#undef DOCTEST_REQUIRE -#undef DOCTEST_REQUIRE_FALSE -#undef DOCTEST_REQUIRE_MESSAGE -#undef DOCTEST_REQUIRE_FALSE_MESSAGE -#undef DOCTEST_REQUIRE_EQ -#undef DOCTEST_REQUIRE_NE -#undef DOCTEST_REQUIRE_GT -#undef DOCTEST_REQUIRE_LT -#undef DOCTEST_REQUIRE_GE -#undef DOCTEST_REQUIRE_LE -#undef DOCTEST_REQUIRE_UNARY -#undef DOCTEST_REQUIRE_UNARY_FALSE - -#define DOCTEST_REQUIRE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_FALSE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_EQ DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_NE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_GT DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_LT DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_GE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_LE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_UNARY DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_UNARY_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS - -#define DOCTEST_WARN_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -// clang-format off -// KEPT FOR BACKWARDS COMPATIBILITY - FORWARDING TO THE RIGHT MACROS -#define DOCTEST_FAST_WARN_EQ DOCTEST_WARN_EQ -#define DOCTEST_FAST_CHECK_EQ DOCTEST_CHECK_EQ -#define DOCTEST_FAST_REQUIRE_EQ DOCTEST_REQUIRE_EQ -#define DOCTEST_FAST_WARN_NE DOCTEST_WARN_NE -#define DOCTEST_FAST_CHECK_NE DOCTEST_CHECK_NE -#define DOCTEST_FAST_REQUIRE_NE DOCTEST_REQUIRE_NE -#define DOCTEST_FAST_WARN_GT DOCTEST_WARN_GT -#define DOCTEST_FAST_CHECK_GT DOCTEST_CHECK_GT -#define DOCTEST_FAST_REQUIRE_GT DOCTEST_REQUIRE_GT -#define DOCTEST_FAST_WARN_LT DOCTEST_WARN_LT -#define DOCTEST_FAST_CHECK_LT DOCTEST_CHECK_LT -#define DOCTEST_FAST_REQUIRE_LT DOCTEST_REQUIRE_LT -#define DOCTEST_FAST_WARN_GE DOCTEST_WARN_GE -#define DOCTEST_FAST_CHECK_GE DOCTEST_CHECK_GE -#define DOCTEST_FAST_REQUIRE_GE DOCTEST_REQUIRE_GE -#define DOCTEST_FAST_WARN_LE DOCTEST_WARN_LE -#define DOCTEST_FAST_CHECK_LE DOCTEST_CHECK_LE -#define DOCTEST_FAST_REQUIRE_LE DOCTEST_REQUIRE_LE - -#define DOCTEST_FAST_WARN_UNARY DOCTEST_WARN_UNARY -#define DOCTEST_FAST_CHECK_UNARY DOCTEST_CHECK_UNARY -#define DOCTEST_FAST_REQUIRE_UNARY DOCTEST_REQUIRE_UNARY -#define DOCTEST_FAST_WARN_UNARY_FALSE DOCTEST_WARN_UNARY_FALSE -#define DOCTEST_FAST_CHECK_UNARY_FALSE DOCTEST_CHECK_UNARY_FALSE -#define DOCTEST_FAST_REQUIRE_UNARY_FALSE DOCTEST_REQUIRE_UNARY_FALSE - -#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id,__VA_ARGS__) -// clang-format on - -// BDD style macros -// clang-format off -#define DOCTEST_SCENARIO(name) DOCTEST_TEST_CASE(" Scenario: " name) -#define DOCTEST_SCENARIO_CLASS(name) DOCTEST_TEST_CASE_CLASS(" Scenario: " name) -#define DOCTEST_SCENARIO_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(" Scenario: " name, T, __VA_ARGS__) -#define DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(" Scenario: " name, T, id) - -#define DOCTEST_GIVEN(name) DOCTEST_SUBCASE(" Given: " name) -#define DOCTEST_WHEN(name) DOCTEST_SUBCASE(" When: " name) -#define DOCTEST_AND_WHEN(name) DOCTEST_SUBCASE("And when: " name) -#define DOCTEST_THEN(name) DOCTEST_SUBCASE(" Then: " name) -#define DOCTEST_AND_THEN(name) DOCTEST_SUBCASE(" And: " name) -// clang-format on - -// == SHORT VERSIONS OF THE MACROS -#ifndef DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES - -#define TEST_CASE(name) DOCTEST_TEST_CASE(name) -#define TEST_CASE_CLASS(name) DOCTEST_TEST_CASE_CLASS(name) -#define TEST_CASE_FIXTURE(x, name) DOCTEST_TEST_CASE_FIXTURE(x, name) -#define TYPE_TO_STRING_AS(str, ...) DOCTEST_TYPE_TO_STRING_AS(str, __VA_ARGS__) -#define TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING(__VA_ARGS__) -#define TEST_CASE_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(name, T, __VA_ARGS__) -#define TEST_CASE_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, T, id) -#define TEST_CASE_TEMPLATE_INVOKE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, __VA_ARGS__) -#define TEST_CASE_TEMPLATE_APPLY(id, ...) DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, __VA_ARGS__) -#define SUBCASE(name) DOCTEST_SUBCASE(name) -#define TEST_SUITE(decorators) DOCTEST_TEST_SUITE(decorators) -#define TEST_SUITE_BEGIN(name) DOCTEST_TEST_SUITE_BEGIN(name) -#define TEST_SUITE_END DOCTEST_TEST_SUITE_END -#define REGISTER_EXCEPTION_TRANSLATOR(signature) DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) -#define REGISTER_REPORTER(name, priority, reporter) DOCTEST_REGISTER_REPORTER(name, priority, reporter) -#define REGISTER_LISTENER(name, priority, reporter) DOCTEST_REGISTER_LISTENER(name, priority, reporter) -#define INFO(...) DOCTEST_INFO(__VA_ARGS__) -#define CAPTURE(x) DOCTEST_CAPTURE(x) -#define ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_MESSAGE_AT(file, line, __VA_ARGS__) -#define ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_FAIL_CHECK_AT(file, line, __VA_ARGS__) -#define ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_FAIL_AT(file, line, __VA_ARGS__) -#define MESSAGE(...) DOCTEST_MESSAGE(__VA_ARGS__) -#define FAIL_CHECK(...) DOCTEST_FAIL_CHECK(__VA_ARGS__) -#define FAIL(...) DOCTEST_FAIL(__VA_ARGS__) -#define TO_LVALUE(...) DOCTEST_TO_LVALUE(__VA_ARGS__) - -#define WARN(...) DOCTEST_WARN(__VA_ARGS__) -#define WARN_FALSE(...) DOCTEST_WARN_FALSE(__VA_ARGS__) -#define WARN_THROWS(...) DOCTEST_WARN_THROWS(__VA_ARGS__) -#define WARN_THROWS_AS(expr, ...) DOCTEST_WARN_THROWS_AS(expr, __VA_ARGS__) -#define WARN_THROWS_WITH(expr, ...) DOCTEST_WARN_THROWS_WITH(expr, __VA_ARGS__) -#define WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_WARN_THROWS_WITH_AS(expr, with, __VA_ARGS__) -#define WARN_NOTHROW(...) DOCTEST_WARN_NOTHROW(__VA_ARGS__) -#define CHECK(...) DOCTEST_CHECK(__VA_ARGS__) -#define CHECK_FALSE(...) DOCTEST_CHECK_FALSE(__VA_ARGS__) -#define CHECK_THROWS(...) DOCTEST_CHECK_THROWS(__VA_ARGS__) -#define CHECK_THROWS_AS(expr, ...) DOCTEST_CHECK_THROWS_AS(expr, __VA_ARGS__) -#define CHECK_THROWS_WITH(expr, ...) DOCTEST_CHECK_THROWS_WITH(expr, __VA_ARGS__) -#define CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_AS(expr, with, __VA_ARGS__) -#define CHECK_NOTHROW(...) DOCTEST_CHECK_NOTHROW(__VA_ARGS__) -#define REQUIRE(...) DOCTEST_REQUIRE(__VA_ARGS__) -#define REQUIRE_FALSE(...) DOCTEST_REQUIRE_FALSE(__VA_ARGS__) -#define REQUIRE_THROWS(...) DOCTEST_REQUIRE_THROWS(__VA_ARGS__) -#define REQUIRE_THROWS_AS(expr, ...) DOCTEST_REQUIRE_THROWS_AS(expr, __VA_ARGS__) -#define REQUIRE_THROWS_WITH(expr, ...) DOCTEST_REQUIRE_THROWS_WITH(expr, __VA_ARGS__) -#define REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, __VA_ARGS__) -#define REQUIRE_NOTHROW(...) DOCTEST_REQUIRE_NOTHROW(__VA_ARGS__) - -#define WARN_MESSAGE(cond, ...) DOCTEST_WARN_MESSAGE(cond, __VA_ARGS__) -#define WARN_FALSE_MESSAGE(cond, ...) DOCTEST_WARN_FALSE_MESSAGE(cond, __VA_ARGS__) -#define WARN_THROWS_MESSAGE(expr, ...) DOCTEST_WARN_THROWS_MESSAGE(expr, __VA_ARGS__) -#define WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) -#define WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) -#define WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) -#define WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_WARN_NOTHROW_MESSAGE(expr, __VA_ARGS__) -#define CHECK_MESSAGE(cond, ...) DOCTEST_CHECK_MESSAGE(cond, __VA_ARGS__) -#define CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_CHECK_FALSE_MESSAGE(cond, __VA_ARGS__) -#define CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_CHECK_THROWS_MESSAGE(expr, __VA_ARGS__) -#define CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) -#define CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) -#define CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) -#define CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_CHECK_NOTHROW_MESSAGE(expr, __VA_ARGS__) -#define REQUIRE_MESSAGE(cond, ...) DOCTEST_REQUIRE_MESSAGE(cond, __VA_ARGS__) -#define REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_REQUIRE_FALSE_MESSAGE(cond, __VA_ARGS__) -#define REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_REQUIRE_THROWS_MESSAGE(expr, __VA_ARGS__) -#define REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) -#define REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) -#define REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) -#define REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, __VA_ARGS__) - -#define SCENARIO(name) DOCTEST_SCENARIO(name) -#define SCENARIO_CLASS(name) DOCTEST_SCENARIO_CLASS(name) -#define SCENARIO_TEMPLATE(name, T, ...) DOCTEST_SCENARIO_TEMPLATE(name, T, __VA_ARGS__) -#define SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) -#define GIVEN(name) DOCTEST_GIVEN(name) -#define WHEN(name) DOCTEST_WHEN(name) -#define AND_WHEN(name) DOCTEST_AND_WHEN(name) -#define THEN(name) DOCTEST_THEN(name) -#define AND_THEN(name) DOCTEST_AND_THEN(name) - -#define WARN_EQ(...) DOCTEST_WARN_EQ(__VA_ARGS__) -#define CHECK_EQ(...) DOCTEST_CHECK_EQ(__VA_ARGS__) -#define REQUIRE_EQ(...) DOCTEST_REQUIRE_EQ(__VA_ARGS__) -#define WARN_NE(...) DOCTEST_WARN_NE(__VA_ARGS__) -#define CHECK_NE(...) DOCTEST_CHECK_NE(__VA_ARGS__) -#define REQUIRE_NE(...) DOCTEST_REQUIRE_NE(__VA_ARGS__) -#define WARN_GT(...) DOCTEST_WARN_GT(__VA_ARGS__) -#define CHECK_GT(...) DOCTEST_CHECK_GT(__VA_ARGS__) -#define REQUIRE_GT(...) DOCTEST_REQUIRE_GT(__VA_ARGS__) -#define WARN_LT(...) DOCTEST_WARN_LT(__VA_ARGS__) -#define CHECK_LT(...) DOCTEST_CHECK_LT(__VA_ARGS__) -#define REQUIRE_LT(...) DOCTEST_REQUIRE_LT(__VA_ARGS__) -#define WARN_GE(...) DOCTEST_WARN_GE(__VA_ARGS__) -#define CHECK_GE(...) DOCTEST_CHECK_GE(__VA_ARGS__) -#define REQUIRE_GE(...) DOCTEST_REQUIRE_GE(__VA_ARGS__) -#define WARN_LE(...) DOCTEST_WARN_LE(__VA_ARGS__) -#define CHECK_LE(...) DOCTEST_CHECK_LE(__VA_ARGS__) -#define REQUIRE_LE(...) DOCTEST_REQUIRE_LE(__VA_ARGS__) -#define WARN_UNARY(...) DOCTEST_WARN_UNARY(__VA_ARGS__) -#define CHECK_UNARY(...) DOCTEST_CHECK_UNARY(__VA_ARGS__) -#define REQUIRE_UNARY(...) DOCTEST_REQUIRE_UNARY(__VA_ARGS__) -#define WARN_UNARY_FALSE(...) DOCTEST_WARN_UNARY_FALSE(__VA_ARGS__) -#define CHECK_UNARY_FALSE(...) DOCTEST_CHECK_UNARY_FALSE(__VA_ARGS__) -#define REQUIRE_UNARY_FALSE(...) DOCTEST_REQUIRE_UNARY_FALSE(__VA_ARGS__) - -// KEPT FOR BACKWARDS COMPATIBILITY -#define FAST_WARN_EQ(...) DOCTEST_FAST_WARN_EQ(__VA_ARGS__) -#define FAST_CHECK_EQ(...) DOCTEST_FAST_CHECK_EQ(__VA_ARGS__) -#define FAST_REQUIRE_EQ(...) DOCTEST_FAST_REQUIRE_EQ(__VA_ARGS__) -#define FAST_WARN_NE(...) DOCTEST_FAST_WARN_NE(__VA_ARGS__) -#define FAST_CHECK_NE(...) DOCTEST_FAST_CHECK_NE(__VA_ARGS__) -#define FAST_REQUIRE_NE(...) DOCTEST_FAST_REQUIRE_NE(__VA_ARGS__) -#define FAST_WARN_GT(...) DOCTEST_FAST_WARN_GT(__VA_ARGS__) -#define FAST_CHECK_GT(...) DOCTEST_FAST_CHECK_GT(__VA_ARGS__) -#define FAST_REQUIRE_GT(...) DOCTEST_FAST_REQUIRE_GT(__VA_ARGS__) -#define FAST_WARN_LT(...) DOCTEST_FAST_WARN_LT(__VA_ARGS__) -#define FAST_CHECK_LT(...) DOCTEST_FAST_CHECK_LT(__VA_ARGS__) -#define FAST_REQUIRE_LT(...) DOCTEST_FAST_REQUIRE_LT(__VA_ARGS__) -#define FAST_WARN_GE(...) DOCTEST_FAST_WARN_GE(__VA_ARGS__) -#define FAST_CHECK_GE(...) DOCTEST_FAST_CHECK_GE(__VA_ARGS__) -#define FAST_REQUIRE_GE(...) DOCTEST_FAST_REQUIRE_GE(__VA_ARGS__) -#define FAST_WARN_LE(...) DOCTEST_FAST_WARN_LE(__VA_ARGS__) -#define FAST_CHECK_LE(...) DOCTEST_FAST_CHECK_LE(__VA_ARGS__) -#define FAST_REQUIRE_LE(...) DOCTEST_FAST_REQUIRE_LE(__VA_ARGS__) - -#define FAST_WARN_UNARY(...) DOCTEST_FAST_WARN_UNARY(__VA_ARGS__) -#define FAST_CHECK_UNARY(...) DOCTEST_FAST_CHECK_UNARY(__VA_ARGS__) -#define FAST_REQUIRE_UNARY(...) DOCTEST_FAST_REQUIRE_UNARY(__VA_ARGS__) -#define FAST_WARN_UNARY_FALSE(...) DOCTEST_FAST_WARN_UNARY_FALSE(__VA_ARGS__) -#define FAST_CHECK_UNARY_FALSE(...) DOCTEST_FAST_CHECK_UNARY_FALSE(__VA_ARGS__) -#define FAST_REQUIRE_UNARY_FALSE(...) DOCTEST_FAST_REQUIRE_UNARY_FALSE(__VA_ARGS__) - -#define TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, __VA_ARGS__) - -#endif // DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES - -#ifndef DOCTEST_CONFIG_DISABLE - -// this is here to clear the 'current test suite' for the current translation unit - at the top -DOCTEST_TEST_SUITE_END(); - -#endif // DOCTEST_CONFIG_DISABLE - -DOCTEST_CLANG_SUPPRESS_WARNING_POP -DOCTEST_MSVC_SUPPRESS_WARNING_POP -DOCTEST_GCC_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_POP - -#endif // DOCTEST_LIBRARY_INCLUDED - -#ifndef DOCTEST_SINGLE_HEADER -#define DOCTEST_SINGLE_HEADER -#endif // DOCTEST_SINGLE_HEADER - -#if defined(DOCTEST_CONFIG_IMPLEMENT) || !defined(DOCTEST_SINGLE_HEADER) - -#ifndef DOCTEST_SINGLE_HEADER -#include "doctest_fwd.h" -#endif // DOCTEST_SINGLE_HEADER - -DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-macros") - -#ifndef DOCTEST_LIBRARY_IMPLEMENTATION -#define DOCTEST_LIBRARY_IMPLEMENTATION - -DOCTEST_CLANG_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH - -DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -DOCTEST_CLANG_SUPPRESS_WARNING("-Wglobal-constructors") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wexit-time-destructors") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wshorten-64-to-32") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-variable-declarations") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch-enum") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wcovered-switch-default") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-noreturn") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wdisabled-macro-expansion") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-braces") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-field-initializers") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-member-function") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wnonportable-system-include-path") - -DOCTEST_GCC_SUPPRESS_WARNING_PUSH -DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") -DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-field-initializers") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-braces") -DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch") -DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-enum") -DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-default") -DOCTEST_GCC_SUPPRESS_WARNING("-Wunsafe-loop-optimizations") -DOCTEST_GCC_SUPPRESS_WARNING("-Wold-style-cast") -DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-function") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmultiple-inheritance") -DOCTEST_GCC_SUPPRESS_WARNING("-Wsuggest-attribute") - -DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -DOCTEST_MSVC_SUPPRESS_WARNING(4267) // 'var' : conversion from 'x' to 'y', possible loss of data -DOCTEST_MSVC_SUPPRESS_WARNING(4530) // C++ exception handler used, but unwind semantics not enabled -DOCTEST_MSVC_SUPPRESS_WARNING(4577) // 'noexcept' used with no exception handling mode specified -DOCTEST_MSVC_SUPPRESS_WARNING(4774) // format string expected in argument is not a string literal -DOCTEST_MSVC_SUPPRESS_WARNING(4365) // conversion from 'int' to 'unsigned', signed/unsigned mismatch -DOCTEST_MSVC_SUPPRESS_WARNING(5039) // pointer to potentially throwing function passed to extern C -DOCTEST_MSVC_SUPPRESS_WARNING(4800) // forcing value to bool 'true' or 'false' (performance warning) -DOCTEST_MSVC_SUPPRESS_WARNING(5245) // unreferenced function with internal linkage has been removed - -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN - -// required includes - will go only in one translation unit! -#include -#include -#include -// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/doctest/doctest/pull/37 -#ifdef __BORLANDC__ -#include -#endif // __BORLANDC__ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifndef DOCTEST_CONFIG_NO_MULTITHREADING -#include -#include -#define DOCTEST_DECLARE_MUTEX(name) std::mutex name; -#define DOCTEST_DECLARE_STATIC_MUTEX(name) static DOCTEST_DECLARE_MUTEX(name) -#define DOCTEST_LOCK_MUTEX(name) std::lock_guard DOCTEST_ANONYMOUS(DOCTEST_ANON_LOCK_)(name); -#else // DOCTEST_CONFIG_NO_MULTITHREADING -#define DOCTEST_DECLARE_MUTEX(name) -#define DOCTEST_DECLARE_STATIC_MUTEX(name) -#define DOCTEST_LOCK_MUTEX(name) -#endif // DOCTEST_CONFIG_NO_MULTITHREADING -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef DOCTEST_PLATFORM_MAC -#include -#include -#include -#endif // DOCTEST_PLATFORM_MAC - -#ifdef DOCTEST_PLATFORM_WINDOWS - -// defines for a leaner windows.h -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif // WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -#define NOMINMAX -#endif // NOMINMAX - -// not sure what AfxWin.h is for - here I do what Catch does -#ifdef __AFXDLL -#include -#else -#include -#endif -#include - -#else // DOCTEST_PLATFORM_WINDOWS - -#include -#include - -#endif // DOCTEST_PLATFORM_WINDOWS - -// this is a fix for https://github.com/doctest/doctest/issues/348 -// https://mail.gnome.org/archives/xml/2012-January/msg00000.html -#if !defined(HAVE_UNISTD_H) && !defined(STDOUT_FILENO) -#define STDOUT_FILENO fileno(stdout) -#endif // HAVE_UNISTD_H - -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END - -// counts the number of elements in a C array -#define DOCTEST_COUNTOF(x) (sizeof(x) / sizeof(x[0])) - -#ifdef DOCTEST_CONFIG_DISABLE -#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_disabled -#else // DOCTEST_CONFIG_DISABLE -#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_not_disabled -#endif // DOCTEST_CONFIG_DISABLE - -#ifndef DOCTEST_CONFIG_OPTIONS_PREFIX -#define DOCTEST_CONFIG_OPTIONS_PREFIX "dt-" -#endif - -#ifndef DOCTEST_THREAD_LOCAL -#if defined(DOCTEST_CONFIG_NO_MULTITHREADING) || DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_THREAD_LOCAL -#else // DOCTEST_MSVC -#define DOCTEST_THREAD_LOCAL thread_local -#endif // DOCTEST_MSVC -#endif // DOCTEST_THREAD_LOCAL - -#ifndef DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES -#define DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES 32 -#endif - -#ifndef DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE -#define DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE 64 -#endif - -#ifdef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS -#define DOCTEST_OPTIONS_PREFIX_DISPLAY DOCTEST_CONFIG_OPTIONS_PREFIX -#else -#define DOCTEST_OPTIONS_PREFIX_DISPLAY "" -#endif - -#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) -#define DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS -#endif - -#ifndef DOCTEST_CDECL -#define DOCTEST_CDECL __cdecl -#endif - -namespace doctest { - -bool is_running_in_test = false; - -namespace { - using namespace detail; - - template - DOCTEST_NORETURN void throw_exception(Ex const& e) { -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - throw e; -#else // DOCTEST_CONFIG_NO_EXCEPTIONS - std::cerr << "doctest will terminate because it needed to throw an exception.\n" - << "The message was: " << e.what() << '\n'; - std::terminate(); -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - } - -#ifndef DOCTEST_INTERNAL_ERROR -#define DOCTEST_INTERNAL_ERROR(msg) \ - throw_exception(std::logic_error( \ - __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) -#endif // DOCTEST_INTERNAL_ERROR - - // case insensitive strcmp - int stricmp(const char* a, const char* b) { - for(;; a++, b++) { - const int d = tolower(*a) - tolower(*b); - if(d != 0 || !*a) - return d; - } - } - - struct Endianness - { - enum Arch - { - Big, - Little - }; - - static Arch which() { - int x = 1; - // casting any data pointer to char* is allowed - auto ptr = reinterpret_cast(&x); - if(*ptr) - return Little; - return Big; - } - }; -} // namespace - -namespace detail { - DOCTEST_THREAD_LOCAL class - { - std::vector stack; - std::stringstream ss; - - public: - std::ostream* push() { - stack.push_back(ss.tellp()); - return &ss; - } - - String pop() { - if (stack.empty()) - DOCTEST_INTERNAL_ERROR("TLSS was empty when trying to pop!"); - - std::streampos pos = stack.back(); - stack.pop_back(); - unsigned sz = static_cast(ss.tellp() - pos); - ss.rdbuf()->pubseekpos(pos, std::ios::in | std::ios::out); - return String(ss, sz); - } - } g_oss; - - std::ostream* tlssPush() { - return g_oss.push(); - } - - String tlssPop() { - return g_oss.pop(); - } - -#ifndef DOCTEST_CONFIG_DISABLE - -namespace timer_large_integer -{ - -#if defined(DOCTEST_PLATFORM_WINDOWS) - using type = ULONGLONG; -#else // DOCTEST_PLATFORM_WINDOWS - using type = std::uint64_t; -#endif // DOCTEST_PLATFORM_WINDOWS -} - -using ticks_t = timer_large_integer::type; - -#ifdef DOCTEST_CONFIG_GETCURRENTTICKS - ticks_t getCurrentTicks() { return DOCTEST_CONFIG_GETCURRENTTICKS(); } -#elif defined(DOCTEST_PLATFORM_WINDOWS) - ticks_t getCurrentTicks() { - static LARGE_INTEGER hz = { {0} }, hzo = { {0} }; - if(!hz.QuadPart) { - QueryPerformanceFrequency(&hz); - QueryPerformanceCounter(&hzo); - } - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return ((t.QuadPart - hzo.QuadPart) * LONGLONG(1000000)) / hz.QuadPart; - } -#else // DOCTEST_PLATFORM_WINDOWS - ticks_t getCurrentTicks() { - timeval t; - gettimeofday(&t, nullptr); - return static_cast(t.tv_sec) * 1000000 + static_cast(t.tv_usec); - } -#endif // DOCTEST_PLATFORM_WINDOWS - - struct Timer - { - void start() { m_ticks = getCurrentTicks(); } - unsigned int getElapsedMicroseconds() const { - return static_cast(getCurrentTicks() - m_ticks); - } - //unsigned int getElapsedMilliseconds() const { - // return static_cast(getElapsedMicroseconds() / 1000); - //} - double getElapsedSeconds() const { return static_cast(getCurrentTicks() - m_ticks) / 1000000.0; } - - private: - ticks_t m_ticks = 0; - }; - -#ifdef DOCTEST_CONFIG_NO_MULTITHREADING - template - using Atomic = T; -#else // DOCTEST_CONFIG_NO_MULTITHREADING - template - using Atomic = std::atomic; -#endif // DOCTEST_CONFIG_NO_MULTITHREADING - -#if defined(DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS) || defined(DOCTEST_CONFIG_NO_MULTITHREADING) - template - using MultiLaneAtomic = Atomic; -#else // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS - // Provides a multilane implementation of an atomic variable that supports add, sub, load, - // store. Instead of using a single atomic variable, this splits up into multiple ones, - // each sitting on a separate cache line. The goal is to provide a speedup when most - // operations are modifying. It achieves this with two properties: - // - // * Multiple atomics are used, so chance of congestion from the same atomic is reduced. - // * Each atomic sits on a separate cache line, so false sharing is reduced. - // - // The disadvantage is that there is a small overhead due to the use of TLS, and load/store - // is slower because all atomics have to be accessed. - template - class MultiLaneAtomic - { - struct CacheLineAlignedAtomic - { - Atomic atomic{}; - char padding[DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE - sizeof(Atomic)]; - }; - CacheLineAlignedAtomic m_atomics[DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES]; - - static_assert(sizeof(CacheLineAlignedAtomic) == DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE, - "guarantee one atomic takes exactly one cache line"); - - public: - T operator++() DOCTEST_NOEXCEPT { return fetch_add(1) + 1; } - - T operator++(int) DOCTEST_NOEXCEPT { return fetch_add(1); } - - T fetch_add(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { - return myAtomic().fetch_add(arg, order); - } - - T fetch_sub(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { - return myAtomic().fetch_sub(arg, order); - } - - operator T() const DOCTEST_NOEXCEPT { return load(); } - - T load(std::memory_order order = std::memory_order_seq_cst) const DOCTEST_NOEXCEPT { - auto result = T(); - for(auto const& c : m_atomics) { - result += c.atomic.load(order); - } - return result; - } - - T operator=(T desired) DOCTEST_NOEXCEPT { // lgtm [cpp/assignment-does-not-return-this] - store(desired); - return desired; - } - - void store(T desired, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { - // first value becomes desired", all others become 0. - for(auto& c : m_atomics) { - c.atomic.store(desired, order); - desired = {}; - } - } - - private: - // Each thread has a different atomic that it operates on. If more than NumLanes threads - // use this, some will use the same atomic. So performance will degrade a bit, but still - // everything will work. - // - // The logic here is a bit tricky. The call should be as fast as possible, so that there - // is minimal to no overhead in determining the correct atomic for the current thread. - // - // 1. A global static counter laneCounter counts continuously up. - // 2. Each successive thread will use modulo operation of that counter so it gets an atomic - // assigned in a round-robin fashion. - // 3. This tlsLaneIdx is stored in the thread local data, so it is directly available with - // little overhead. - Atomic& myAtomic() DOCTEST_NOEXCEPT { - static Atomic laneCounter; - DOCTEST_THREAD_LOCAL size_t tlsLaneIdx = - laneCounter++ % DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES; - - return m_atomics[tlsLaneIdx].atomic; - } - }; -#endif // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS - - // this holds both parameters from the command line and runtime data for tests - struct ContextState : ContextOptions, TestRunStats, CurrentTestCaseStats - { - MultiLaneAtomic numAssertsCurrentTest_atomic; - MultiLaneAtomic numAssertsFailedCurrentTest_atomic; - - std::vector> filters = decltype(filters)(9); // 9 different filters - - std::vector reporters_currently_used; - - assert_handler ah = nullptr; - - Timer timer; - - std::vector stringifiedContexts; // logging from INFO() due to an exception - - // stuff for subcases - bool reachedLeaf; - std::vector subcaseStack; - std::vector nextSubcaseStack; - std::unordered_set fullyTraversedSubcases; - size_t currentSubcaseDepth; - Atomic shouldLogCurrentException; - - void resetRunData() { - numTestCases = 0; - numTestCasesPassingFilters = 0; - numTestSuitesPassingFilters = 0; - numTestCasesFailed = 0; - numAsserts = 0; - numAssertsFailed = 0; - numAssertsCurrentTest = 0; - numAssertsFailedCurrentTest = 0; - } - - void finalizeTestCaseData() { - seconds = timer.getElapsedSeconds(); - - // update the non-atomic counters - numAsserts += numAssertsCurrentTest_atomic; - numAssertsFailed += numAssertsFailedCurrentTest_atomic; - numAssertsCurrentTest = numAssertsCurrentTest_atomic; - numAssertsFailedCurrentTest = numAssertsFailedCurrentTest_atomic; - - if(numAssertsFailedCurrentTest) - failure_flags |= TestCaseFailureReason::AssertFailure; - - if(Approx(currentTest->m_timeout).epsilon(DBL_EPSILON) != 0 && - Approx(seconds).epsilon(DBL_EPSILON) > currentTest->m_timeout) - failure_flags |= TestCaseFailureReason::Timeout; - - if(currentTest->m_should_fail) { - if(failure_flags) { - failure_flags |= TestCaseFailureReason::ShouldHaveFailedAndDid; - } else { - failure_flags |= TestCaseFailureReason::ShouldHaveFailedButDidnt; - } - } else if(failure_flags && currentTest->m_may_fail) { - failure_flags |= TestCaseFailureReason::CouldHaveFailedAndDid; - } else if(currentTest->m_expected_failures > 0) { - if(numAssertsFailedCurrentTest == currentTest->m_expected_failures) { - failure_flags |= TestCaseFailureReason::FailedExactlyNumTimes; - } else { - failure_flags |= TestCaseFailureReason::DidntFailExactlyNumTimes; - } - } - - bool ok_to_fail = (TestCaseFailureReason::ShouldHaveFailedAndDid & failure_flags) || - (TestCaseFailureReason::CouldHaveFailedAndDid & failure_flags) || - (TestCaseFailureReason::FailedExactlyNumTimes & failure_flags); - - // if any subcase has failed - the whole test case has failed - testCaseSuccess = !(failure_flags && !ok_to_fail); - if(!testCaseSuccess) - numTestCasesFailed++; - } - }; - - ContextState* g_cs = nullptr; - - // used to avoid locks for the debug output - // TODO: figure out if this is indeed necessary/correct - seems like either there still - // could be a race or that there wouldn't be a race even if using the context directly - DOCTEST_THREAD_LOCAL bool g_no_colors; - -#endif // DOCTEST_CONFIG_DISABLE -} // namespace detail - -char* String::allocate(size_type sz) { - if (sz <= last) { - buf[sz] = '\0'; - setLast(last - sz); - return buf; - } else { - setOnHeap(); - data.size = sz; - data.capacity = data.size + 1; - data.ptr = new char[data.capacity]; - data.ptr[sz] = '\0'; - return data.ptr; - } -} - -void String::setOnHeap() noexcept { *reinterpret_cast(&buf[last]) = 128; } -void String::setLast(size_type in) noexcept { buf[last] = char(in); } -void String::setSize(size_type sz) noexcept { - if (isOnStack()) { buf[sz] = '\0'; setLast(last - sz); } - else { data.ptr[sz] = '\0'; data.size = sz; } -} - -void String::copy(const String& other) { - if(other.isOnStack()) { - memcpy(buf, other.buf, len); - } else { - memcpy(allocate(other.data.size), other.data.ptr, other.data.size); - } -} - -String::String() noexcept { - buf[0] = '\0'; - setLast(); -} - -String::~String() { - if(!isOnStack()) - delete[] data.ptr; -} // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - -String::String(const char* in) - : String(in, strlen(in)) {} - -String::String(const char* in, size_type in_size) { - memcpy(allocate(in_size), in, in_size); -} - -String::String(std::istream& in, size_type in_size) { - in.read(allocate(in_size), in_size); -} - -String::String(const String& other) { copy(other); } - -String& String::operator=(const String& other) { - if(this != &other) { - if(!isOnStack()) - delete[] data.ptr; - - copy(other); - } - - return *this; -} - -String& String::operator+=(const String& other) { - const size_type my_old_size = size(); - const size_type other_size = other.size(); - const size_type total_size = my_old_size + other_size; - if(isOnStack()) { - if(total_size < len) { - // append to the current stack space - memcpy(buf + my_old_size, other.c_str(), other_size + 1); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - setLast(last - total_size); - } else { - // alloc new chunk - char* temp = new char[total_size + 1]; - // copy current data to new location before writing in the union - memcpy(temp, buf, my_old_size); // skip the +1 ('\0') for speed - // update data in union - setOnHeap(); - data.size = total_size; - data.capacity = data.size + 1; - data.ptr = temp; - // transfer the rest of the data - memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); - } - } else { - if(data.capacity > total_size) { - // append to the current heap block - data.size = total_size; - memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); - } else { - // resize - data.capacity *= 2; - if(data.capacity <= total_size) - data.capacity = total_size + 1; - // alloc new chunk - char* temp = new char[data.capacity]; - // copy current data to new location before releasing it - memcpy(temp, data.ptr, my_old_size); // skip the +1 ('\0') for speed - // release old chunk - delete[] data.ptr; - // update the rest of the union members - data.size = total_size; - data.ptr = temp; - // transfer the rest of the data - memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); - } - } - - return *this; -} - -String::String(String&& other) noexcept { - memcpy(buf, other.buf, len); - other.buf[0] = '\0'; - other.setLast(); -} - -String& String::operator=(String&& other) noexcept { - if(this != &other) { - if(!isOnStack()) - delete[] data.ptr; - memcpy(buf, other.buf, len); - other.buf[0] = '\0'; - other.setLast(); - } - return *this; -} - -char String::operator[](size_type i) const { - return const_cast(this)->operator[](i); -} - -char& String::operator[](size_type i) { - if(isOnStack()) - return reinterpret_cast(buf)[i]; - return data.ptr[i]; -} - -DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmaybe-uninitialized") -String::size_type String::size() const { - if(isOnStack()) - return last - (size_type(buf[last]) & 31); // using "last" would work only if "len" is 32 - return data.size; -} -DOCTEST_GCC_SUPPRESS_WARNING_POP - -String::size_type String::capacity() const { - if(isOnStack()) - return len; - return data.capacity; -} - -String String::substr(size_type pos, size_type cnt) && { - cnt = std::min(cnt, size() - 1 - pos); - char* cptr = c_str(); - memmove(cptr, cptr + pos, cnt); - setSize(cnt); - return std::move(*this); -} - -String String::substr(size_type pos, size_type cnt) const & { - cnt = std::min(cnt, size() - 1 - pos); - return String{ c_str() + pos, cnt }; -} - -String::size_type String::find(char ch, size_type pos) const { - const char* begin = c_str(); - const char* end = begin + size(); - const char* it = begin + pos; - for (; it < end && *it != ch; it++); - if (it < end) { return static_cast(it - begin); } - else { return npos; } -} - -String::size_type String::rfind(char ch, size_type pos) const { - const char* begin = c_str(); - const char* it = begin + std::min(pos, size() - 1); - for (; it >= begin && *it != ch; it--); - if (it >= begin) { return static_cast(it - begin); } - else { return npos; } -} - -int String::compare(const char* other, bool no_case) const { - if(no_case) - return doctest::stricmp(c_str(), other); - return std::strcmp(c_str(), other); -} - -int String::compare(const String& other, bool no_case) const { - return compare(other.c_str(), no_case); -} - -String operator+(const String& lhs, const String& rhs) { return String(lhs) += rhs; } - -bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } -bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } -bool operator< (const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } -bool operator> (const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } -bool operator<=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) < 0 : true; } -bool operator>=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) > 0 : true; } - -std::ostream& operator<<(std::ostream& s, const String& in) { return s << in.c_str(); } - -Contains::Contains(const String& str) : string(str) { } - -bool Contains::checkWith(const String& other) const { - return strstr(other.c_str(), string.c_str()) != nullptr; -} - -String toString(const Contains& in) { - return "Contains( " + in.string + " )"; -} - -bool operator==(const String& lhs, const Contains& rhs) { return rhs.checkWith(lhs); } -bool operator==(const Contains& lhs, const String& rhs) { return lhs.checkWith(rhs); } -bool operator!=(const String& lhs, const Contains& rhs) { return !rhs.checkWith(lhs); } -bool operator!=(const Contains& lhs, const String& rhs) { return !lhs.checkWith(rhs); } - -namespace { - void color_to_stream(std::ostream&, Color::Enum) DOCTEST_BRANCH_ON_DISABLED({}, ;) -} // namespace - -namespace Color { - std::ostream& operator<<(std::ostream& s, Color::Enum code) { - color_to_stream(s, code); - return s; - } -} // namespace Color - -// clang-format off -const char* assertString(assertType::Enum at) { - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4061) // enum 'x' in switch of enum 'y' is not explicitely handled - #define DOCTEST_GENERATE_ASSERT_TYPE_CASE(assert_type) case assertType::DT_ ## assert_type: return #assert_type - #define DOCTEST_GENERATE_ASSERT_TYPE_CASES(assert_type) \ - DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN_ ## assert_type); \ - DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK_ ## assert_type); \ - DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE_ ## assert_type) - switch(at) { - DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN); - DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK); - DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(FALSE); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_AS); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH_AS); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(NOTHROW); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(EQ); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(NE); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(GT); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(LT); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(GE); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(LE); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY_FALSE); - - default: DOCTEST_INTERNAL_ERROR("Tried stringifying invalid assert type!"); - } - DOCTEST_MSVC_SUPPRESS_WARNING_POP -} -// clang-format on - -const char* failureString(assertType::Enum at) { - if(at & assertType::is_warn) //!OCLINT bitwise operator in conditional - return "WARNING"; - if(at & assertType::is_check) //!OCLINT bitwise operator in conditional - return "ERROR"; - if(at & assertType::is_require) //!OCLINT bitwise operator in conditional - return "FATAL ERROR"; - return ""; -} - -DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") -DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") -// depending on the current options this will remove the path of filenames -const char* skipPathFromFilename(const char* file) { -#ifndef DOCTEST_CONFIG_DISABLE - if(getContextOptions()->no_path_in_filenames) { - auto back = std::strrchr(file, '\\'); - auto forward = std::strrchr(file, '/'); - if(back || forward) { - if(back > forward) - forward = back; - return forward + 1; - } - } -#endif // DOCTEST_CONFIG_DISABLE - return file; -} -DOCTEST_CLANG_SUPPRESS_WARNING_POP -DOCTEST_GCC_SUPPRESS_WARNING_POP - -bool SubcaseSignature::operator==(const SubcaseSignature& other) const { - return m_line == other.m_line - && std::strcmp(m_file, other.m_file) == 0 - && m_name == other.m_name; -} - -bool SubcaseSignature::operator<(const SubcaseSignature& other) const { - if(m_line != other.m_line) - return m_line < other.m_line; - if(std::strcmp(m_file, other.m_file) != 0) - return std::strcmp(m_file, other.m_file) < 0; - return m_name.compare(other.m_name) < 0; -} - -DOCTEST_DEFINE_INTERFACE(IContextScope) - -namespace detail { - void filldata::fill(std::ostream* stream, const void* in) { - if (in) { *stream << in; } - else { *stream << "nullptr"; } - } - - template - String toStreamLit(T t) { - std::ostream* os = tlssPush(); - os->operator<<(t); - return tlssPop(); - } -} - -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - -#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 -String toString(const std::string& in) { return in.c_str(); } -#endif // VS 2019 - -String toString(String in) { return in; } - -String toString(std::nullptr_t) { return "nullptr"; } - -String toString(bool in) { return in ? "true" : "false"; } - -String toString(float in) { return toStreamLit(in); } -String toString(double in) { return toStreamLit(in); } -String toString(double long in) { return toStreamLit(in); } - -String toString(char in) { return toStreamLit(static_cast(in)); } -String toString(char signed in) { return toStreamLit(static_cast(in)); } -String toString(char unsigned in) { return toStreamLit(static_cast(in)); } -String toString(short in) { return toStreamLit(in); } -String toString(short unsigned in) { return toStreamLit(in); } -String toString(signed in) { return toStreamLit(in); } -String toString(unsigned in) { return toStreamLit(in); } -String toString(long in) { return toStreamLit(in); } -String toString(long unsigned in) { return toStreamLit(in); } -String toString(long long in) { return toStreamLit(in); } -String toString(long long unsigned in) { return toStreamLit(in); } - -Approx::Approx(double value) - : m_epsilon(static_cast(std::numeric_limits::epsilon()) * 100) - , m_scale(1.0) - , m_value(value) {} - -Approx Approx::operator()(double value) const { - Approx approx(value); - approx.epsilon(m_epsilon); - approx.scale(m_scale); - return approx; -} - -Approx& Approx::epsilon(double newEpsilon) { - m_epsilon = newEpsilon; - return *this; -} -Approx& Approx::scale(double newScale) { - m_scale = newScale; - return *this; -} - -bool operator==(double lhs, const Approx& rhs) { - // Thanks to Richard Harris for his help refining this formula - return std::fabs(lhs - rhs.m_value) < - rhs.m_epsilon * (rhs.m_scale + std::max(std::fabs(lhs), std::fabs(rhs.m_value))); -} -bool operator==(const Approx& lhs, double rhs) { return operator==(rhs, lhs); } -bool operator!=(double lhs, const Approx& rhs) { return !operator==(lhs, rhs); } -bool operator!=(const Approx& lhs, double rhs) { return !operator==(rhs, lhs); } -bool operator<=(double lhs, const Approx& rhs) { return lhs < rhs.m_value || lhs == rhs; } -bool operator<=(const Approx& lhs, double rhs) { return lhs.m_value < rhs || lhs == rhs; } -bool operator>=(double lhs, const Approx& rhs) { return lhs > rhs.m_value || lhs == rhs; } -bool operator>=(const Approx& lhs, double rhs) { return lhs.m_value > rhs || lhs == rhs; } -bool operator<(double lhs, const Approx& rhs) { return lhs < rhs.m_value && lhs != rhs; } -bool operator<(const Approx& lhs, double rhs) { return lhs.m_value < rhs && lhs != rhs; } -bool operator>(double lhs, const Approx& rhs) { return lhs > rhs.m_value && lhs != rhs; } -bool operator>(const Approx& lhs, double rhs) { return lhs.m_value > rhs && lhs != rhs; } - -String toString(const Approx& in) { - return "Approx( " + doctest::toString(in.m_value) + " )"; -} -const ContextOptions* getContextOptions() { return DOCTEST_BRANCH_ON_DISABLED(nullptr, g_cs); } - -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4738) -template -IsNaN::operator bool() const { - return std::isnan(value) ^ flipped; -} -DOCTEST_MSVC_SUPPRESS_WARNING_POP -template struct DOCTEST_INTERFACE_DEF IsNaN; -template struct DOCTEST_INTERFACE_DEF IsNaN; -template struct DOCTEST_INTERFACE_DEF IsNaN; -template -String toString(IsNaN in) { return String(in.flipped ? "! " : "") + "IsNaN( " + doctest::toString(in.value) + " )"; } -String toString(IsNaN in) { return toString(in); } -String toString(IsNaN in) { return toString(in); } -String toString(IsNaN in) { return toString(in); } - -} // namespace doctest - -#ifdef DOCTEST_CONFIG_DISABLE -namespace doctest { -Context::Context(int, const char* const*) {} -Context::~Context() = default; -void Context::applyCommandLine(int, const char* const*) {} -void Context::addFilter(const char*, const char*) {} -void Context::clearFilters() {} -void Context::setOption(const char*, bool) {} -void Context::setOption(const char*, int) {} -void Context::setOption(const char*, const char*) {} -bool Context::shouldExit() { return false; } -void Context::setAsDefaultForAssertsOutOfTestCases() {} -void Context::setAssertHandler(detail::assert_handler) {} -void Context::setCout(std::ostream*) {} -int Context::run() { return 0; } - -int IReporter::get_num_active_contexts() { return 0; } -const IContextScope* const* IReporter::get_active_contexts() { return nullptr; } -int IReporter::get_num_stringified_contexts() { return 0; } -const String* IReporter::get_stringified_contexts() { return nullptr; } - -int registerReporter(const char*, int, IReporter*) { return 0; } - -} // namespace doctest -#else // DOCTEST_CONFIG_DISABLE - -#if !defined(DOCTEST_CONFIG_COLORS_NONE) -#if !defined(DOCTEST_CONFIG_COLORS_WINDOWS) && !defined(DOCTEST_CONFIG_COLORS_ANSI) -#ifdef DOCTEST_PLATFORM_WINDOWS -#define DOCTEST_CONFIG_COLORS_WINDOWS -#else // linux -#define DOCTEST_CONFIG_COLORS_ANSI -#endif // platform -#endif // DOCTEST_CONFIG_COLORS_WINDOWS && DOCTEST_CONFIG_COLORS_ANSI -#endif // DOCTEST_CONFIG_COLORS_NONE - -namespace doctest_detail_test_suite_ns { -// holds the current test suite -doctest::detail::TestSuite& getCurrentTestSuite() { - static doctest::detail::TestSuite data{}; - return data; -} -} // namespace doctest_detail_test_suite_ns - -namespace doctest { -namespace { - // the int (priority) is part of the key for automatic sorting - sadly one can register a - // reporter with a duplicate name and a different priority but hopefully that won't happen often :| - using reporterMap = std::map, reporterCreatorFunc>; - - reporterMap& getReporters() { - static reporterMap data; - return data; - } - reporterMap& getListeners() { - static reporterMap data; - return data; - } -} // namespace -namespace detail { -#define DOCTEST_ITERATE_THROUGH_REPORTERS(function, ...) \ - for(auto& curr_rep : g_cs->reporters_currently_used) \ - curr_rep->function(__VA_ARGS__) - - bool checkIfShouldThrow(assertType::Enum at) { - if(at & assertType::is_require) //!OCLINT bitwise operator in conditional - return true; - - if((at & assertType::is_check) //!OCLINT bitwise operator in conditional - && getContextOptions()->abort_after > 0 && - (g_cs->numAssertsFailed + g_cs->numAssertsFailedCurrentTest_atomic) >= - getContextOptions()->abort_after) - return true; - - return false; - } - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - DOCTEST_NORETURN void throwException() { - g_cs->shouldLogCurrentException = false; - throw TestFailureException(); // NOLINT(hicpp-exception-baseclass) - } -#else // DOCTEST_CONFIG_NO_EXCEPTIONS - void throwException() {} -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS -} // namespace detail - -namespace { - using namespace detail; - // matching of a string against a wildcard mask (case sensitivity configurable) taken from - // https://www.codeproject.com/Articles/1088/Wildcard-string-compare-globbing - int wildcmp(const char* str, const char* wild, bool caseSensitive) { - const char* cp = str; - const char* mp = wild; - - while((*str) && (*wild != '*')) { - if((caseSensitive ? (*wild != *str) : (tolower(*wild) != tolower(*str))) && - (*wild != '?')) { - return 0; - } - wild++; - str++; - } - - while(*str) { - if(*wild == '*') { - if(!*++wild) { - return 1; - } - mp = wild; - cp = str + 1; - } else if((caseSensitive ? (*wild == *str) : (tolower(*wild) == tolower(*str))) || - (*wild == '?')) { - wild++; - str++; - } else { - wild = mp; //!OCLINT parameter reassignment - str = cp++; //!OCLINT parameter reassignment - } - } - - while(*wild == '*') { - wild++; - } - return !*wild; - } - - // checks if the name matches any of the filters (and can be configured what to do when empty) - bool matchesAny(const char* name, const std::vector& filters, bool matchEmpty, - bool caseSensitive) { - if (filters.empty() && matchEmpty) - return true; - for (auto& curr : filters) - if (wildcmp(name, curr.c_str(), caseSensitive)) - return true; - return false; - } - - unsigned long long hash(unsigned long long a, unsigned long long b) { - return (a << 5) + b; - } - - // C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html - unsigned long long hash(const char* str) { - unsigned long long hash = 5381; - char c; - while ((c = *str++)) - hash = ((hash << 5) + hash) + c; // hash * 33 + c - return hash; - } - - unsigned long long hash(const SubcaseSignature& sig) { - return hash(hash(hash(sig.m_file), hash(sig.m_name.c_str())), sig.m_line); - } - - unsigned long long hash(const std::vector& sigs, size_t count) { - unsigned long long running = 0; - auto end = sigs.begin() + count; - for (auto it = sigs.begin(); it != end; it++) { - running = hash(running, hash(*it)); - } - return running; - } - - unsigned long long hash(const std::vector& sigs) { - unsigned long long running = 0; - for (const SubcaseSignature& sig : sigs) { - running = hash(running, hash(sig)); - } - return running; - } -} // namespace -namespace detail { - bool Subcase::checkFilters() { - if (g_cs->subcaseStack.size() < size_t(g_cs->subcase_filter_levels)) { - if (!matchesAny(m_signature.m_name.c_str(), g_cs->filters[6], true, g_cs->case_sensitive)) - return true; - if (matchesAny(m_signature.m_name.c_str(), g_cs->filters[7], false, g_cs->case_sensitive)) - return true; - } - return false; - } - - Subcase::Subcase(const String& name, const char* file, int line) - : m_signature({name, file, line}) { - if (!g_cs->reachedLeaf) { - if (g_cs->nextSubcaseStack.size() <= g_cs->subcaseStack.size() - || g_cs->nextSubcaseStack[g_cs->subcaseStack.size()] == m_signature) { - // Going down. - if (checkFilters()) { return; } - - g_cs->subcaseStack.push_back(m_signature); - g_cs->currentSubcaseDepth++; - m_entered = true; - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); - } - } else { - if (g_cs->subcaseStack[g_cs->currentSubcaseDepth] == m_signature) { - // This subcase is reentered via control flow. - g_cs->currentSubcaseDepth++; - m_entered = true; - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); - } else if (g_cs->nextSubcaseStack.size() <= g_cs->currentSubcaseDepth - && g_cs->fullyTraversedSubcases.find(hash(hash(g_cs->subcaseStack, g_cs->currentSubcaseDepth), hash(m_signature))) - == g_cs->fullyTraversedSubcases.end()) { - if (checkFilters()) { return; } - // This subcase is part of the one to be executed next. - g_cs->nextSubcaseStack.clear(); - g_cs->nextSubcaseStack.insert(g_cs->nextSubcaseStack.end(), - g_cs->subcaseStack.begin(), g_cs->subcaseStack.begin() + g_cs->currentSubcaseDepth); - g_cs->nextSubcaseStack.push_back(m_signature); - } - } - } - - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - - Subcase::~Subcase() { - if (m_entered) { - g_cs->currentSubcaseDepth--; - - if (!g_cs->reachedLeaf) { - // Leaf. - g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); - g_cs->nextSubcaseStack.clear(); - g_cs->reachedLeaf = true; - } else if (g_cs->nextSubcaseStack.empty()) { - // All children are finished. - g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); - } - -#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) - if(std::uncaught_exceptions() > 0 -#else - if(std::uncaught_exception() -#endif - && g_cs->shouldLogCurrentException) { - DOCTEST_ITERATE_THROUGH_REPORTERS( - test_case_exception, {"exception thrown in subcase - will translate later " - "when the whole test case has been exited (cannot " - "translate while there is an active exception)", - false}); - g_cs->shouldLogCurrentException = false; - } - - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); - } - } - - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP - DOCTEST_MSVC_SUPPRESS_WARNING_POP - - Subcase::operator bool() const { return m_entered; } - - Result::Result(bool passed, const String& decomposition) - : m_passed(passed) - , m_decomp(decomposition) {} - - ExpressionDecomposer::ExpressionDecomposer(assertType::Enum at) - : m_at(at) {} - - TestSuite& TestSuite::operator*(const char* in) { - m_test_suite = in; - return *this; - } - - TestCase::TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, - const String& type, int template_id) { - m_file = file; - m_line = line; - m_name = nullptr; // will be later overridden in operator* - m_test_suite = test_suite.m_test_suite; - m_description = test_suite.m_description; - m_skip = test_suite.m_skip; - m_no_breaks = test_suite.m_no_breaks; - m_no_output = test_suite.m_no_output; - m_may_fail = test_suite.m_may_fail; - m_should_fail = test_suite.m_should_fail; - m_expected_failures = test_suite.m_expected_failures; - m_timeout = test_suite.m_timeout; - - m_test = test; - m_type = type; - m_template_id = template_id; - } - - TestCase::TestCase(const TestCase& other) - : TestCaseData() { - *this = other; - } - - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function - TestCase& TestCase::operator=(const TestCase& other) { - TestCaseData::operator=(other); - m_test = other.m_test; - m_type = other.m_type; - m_template_id = other.m_template_id; - m_full_name = other.m_full_name; - - if(m_template_id != -1) - m_name = m_full_name.c_str(); - return *this; - } - DOCTEST_MSVC_SUPPRESS_WARNING_POP - - TestCase& TestCase::operator*(const char* in) { - m_name = in; - // make a new name with an appended type for templated test case - if(m_template_id != -1) { - m_full_name = String(m_name) + "<" + m_type + ">"; - // redirect the name to point to the newly constructed full name - m_name = m_full_name.c_str(); - } - return *this; - } - - bool TestCase::operator<(const TestCase& other) const { - // this will be used only to differentiate between test cases - not relevant for sorting - if(m_line != other.m_line) - return m_line < other.m_line; - const int name_cmp = strcmp(m_name, other.m_name); - if(name_cmp != 0) - return name_cmp < 0; - const int file_cmp = m_file.compare(other.m_file); - if(file_cmp != 0) - return file_cmp < 0; - return m_template_id < other.m_template_id; - } - - // all the registered tests - std::set& getRegisteredTests() { - static std::set data; - return data; - } -} // namespace detail -namespace { - using namespace detail; - // for sorting tests by file/line - bool fileOrderComparator(const TestCase* lhs, const TestCase* rhs) { - // this is needed because MSVC gives different case for drive letters - // for __FILE__ when evaluated in a header and a source file - const int res = lhs->m_file.compare(rhs->m_file, bool(DOCTEST_MSVC)); - if(res != 0) - return res < 0; - if(lhs->m_line != rhs->m_line) - return lhs->m_line < rhs->m_line; - return lhs->m_template_id < rhs->m_template_id; - } - - // for sorting tests by suite/file/line - bool suiteOrderComparator(const TestCase* lhs, const TestCase* rhs) { - const int res = std::strcmp(lhs->m_test_suite, rhs->m_test_suite); - if(res != 0) - return res < 0; - return fileOrderComparator(lhs, rhs); - } - - // for sorting tests by name/suite/file/line - bool nameOrderComparator(const TestCase* lhs, const TestCase* rhs) { - const int res = std::strcmp(lhs->m_name, rhs->m_name); - if(res != 0) - return res < 0; - return suiteOrderComparator(lhs, rhs); - } - - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - void color_to_stream(std::ostream& s, Color::Enum code) { - static_cast(s); // for DOCTEST_CONFIG_COLORS_NONE or DOCTEST_CONFIG_COLORS_WINDOWS - static_cast(code); // for DOCTEST_CONFIG_COLORS_NONE -#ifdef DOCTEST_CONFIG_COLORS_ANSI - if(g_no_colors || - (isatty(STDOUT_FILENO) == false && getContextOptions()->force_colors == false)) - return; - - auto col = ""; - // clang-format off - switch(code) { //!OCLINT missing break in switch statement / unnecessary default statement in covered switch statement - case Color::Red: col = "[0;31m"; break; - case Color::Green: col = "[0;32m"; break; - case Color::Blue: col = "[0;34m"; break; - case Color::Cyan: col = "[0;36m"; break; - case Color::Yellow: col = "[0;33m"; break; - case Color::Grey: col = "[1;30m"; break; - case Color::LightGrey: col = "[0;37m"; break; - case Color::BrightRed: col = "[1;31m"; break; - case Color::BrightGreen: col = "[1;32m"; break; - case Color::BrightWhite: col = "[1;37m"; break; - case Color::Bright: // invalid - case Color::None: - case Color::White: - default: col = "[0m"; - } - // clang-format on - s << "\033" << col; -#endif // DOCTEST_CONFIG_COLORS_ANSI - -#ifdef DOCTEST_CONFIG_COLORS_WINDOWS - if(g_no_colors || - (_isatty(_fileno(stdout)) == false && getContextOptions()->force_colors == false)) - return; - - static struct ConsoleHelper { - HANDLE stdoutHandle; - WORD origFgAttrs; - WORD origBgAttrs; - - ConsoleHelper() { - stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); - CONSOLE_SCREEN_BUFFER_INFO csbiInfo; - GetConsoleScreenBufferInfo(stdoutHandle, &csbiInfo); - origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | - BACKGROUND_BLUE | BACKGROUND_INTENSITY); - origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | - FOREGROUND_BLUE | FOREGROUND_INTENSITY); - } - } ch; - -#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(ch.stdoutHandle, x | ch.origBgAttrs) - - // clang-format off - switch (code) { - case Color::White: DOCTEST_SET_ATTR(FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; - case Color::Red: DOCTEST_SET_ATTR(FOREGROUND_RED); break; - case Color::Green: DOCTEST_SET_ATTR(FOREGROUND_GREEN); break; - case Color::Blue: DOCTEST_SET_ATTR(FOREGROUND_BLUE); break; - case Color::Cyan: DOCTEST_SET_ATTR(FOREGROUND_BLUE | FOREGROUND_GREEN); break; - case Color::Yellow: DOCTEST_SET_ATTR(FOREGROUND_RED | FOREGROUND_GREEN); break; - case Color::Grey: DOCTEST_SET_ATTR(0); break; - case Color::LightGrey: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY); break; - case Color::BrightRed: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_RED); break; - case Color::BrightGreen: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN); break; - case Color::BrightWhite: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; - case Color::None: - case Color::Bright: // invalid - default: DOCTEST_SET_ATTR(ch.origFgAttrs); - } - // clang-format on -#endif // DOCTEST_CONFIG_COLORS_WINDOWS - } - DOCTEST_CLANG_SUPPRESS_WARNING_POP - - std::vector& getExceptionTranslators() { - static std::vector data; - return data; - } - - String translateActiveException() { -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - String res; - auto& translators = getExceptionTranslators(); - for(auto& curr : translators) - if(curr->translate(res)) - return res; - // clang-format off - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wcatch-value") - try { - throw; - } catch(std::exception& ex) { - return ex.what(); - } catch(std::string& msg) { - return msg.c_str(); - } catch(const char* msg) { - return msg; - } catch(...) { - return "unknown exception"; - } - DOCTEST_GCC_SUPPRESS_WARNING_POP -// clang-format on -#else // DOCTEST_CONFIG_NO_EXCEPTIONS - return ""; -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - } -} // namespace - -namespace detail { - // used by the macros for registering tests - int regTest(const TestCase& tc) { - getRegisteredTests().insert(tc); - return 0; - } - - // sets the current test suite - int setTestSuite(const TestSuite& ts) { - doctest_detail_test_suite_ns::getCurrentTestSuite() = ts; - return 0; - } - -#ifdef DOCTEST_IS_DEBUGGER_ACTIVE - bool isDebuggerActive() { return DOCTEST_IS_DEBUGGER_ACTIVE(); } -#else // DOCTEST_IS_DEBUGGER_ACTIVE -#ifdef DOCTEST_PLATFORM_LINUX - class ErrnoGuard { - public: - ErrnoGuard() : m_oldErrno(errno) {} - ~ErrnoGuard() { errno = m_oldErrno; } - private: - int m_oldErrno; - }; - // See the comments in Catch2 for the reasoning behind this implementation: - // https://github.com/catchorg/Catch2/blob/v2.13.1/include/internal/catch_debugger.cpp#L79-L102 - bool isDebuggerActive() { - ErrnoGuard guard; - std::ifstream in("/proc/self/status"); - for(std::string line; std::getline(in, line);) { - static const int PREFIX_LEN = 11; - if(line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0) { - return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; - } - } - return false; - } -#elif defined(DOCTEST_PLATFORM_MAC) - // The following function is taken directly from the following technical note: - // https://developer.apple.com/library/archive/qa/qa1361/_index.html - // Returns true if the current process is being debugged (either - // running under the debugger or has a debugger attached post facto). - bool isDebuggerActive() { - int mib[4]; - kinfo_proc info; - size_t size; - // Initialize the flags so that, if sysctl fails for some bizarre - // reason, we get a predictable result. - info.kp_proc.p_flag = 0; - // Initialize mib, which tells sysctl the info we want, in this case - // we're looking for information about a specific process ID. - mib[0] = CTL_KERN; - mib[1] = KERN_PROC; - mib[2] = KERN_PROC_PID; - mib[3] = getpid(); - // Call sysctl. - size = sizeof(info); - if(sysctl(mib, DOCTEST_COUNTOF(mib), &info, &size, 0, 0) != 0) { - std::cerr << "\nCall to sysctl failed - unable to determine if debugger is active **\n"; - return false; - } - // We're being debugged if the P_TRACED flag is set. - return ((info.kp_proc.p_flag & P_TRACED) != 0); - } -#elif DOCTEST_MSVC || defined(__MINGW32__) || defined(__MINGW64__) - bool isDebuggerActive() { return ::IsDebuggerPresent() != 0; } -#else - bool isDebuggerActive() { return false; } -#endif // Platform -#endif // DOCTEST_IS_DEBUGGER_ACTIVE - - void registerExceptionTranslatorImpl(const IExceptionTranslator* et) { - if(std::find(getExceptionTranslators().begin(), getExceptionTranslators().end(), et) == - getExceptionTranslators().end()) - getExceptionTranslators().push_back(et); - } - - DOCTEST_THREAD_LOCAL std::vector g_infoContexts; // for logging with INFO() - - ContextScopeBase::ContextScopeBase() { - g_infoContexts.push_back(this); - } - - ContextScopeBase::ContextScopeBase(ContextScopeBase&& other) noexcept { - if (other.need_to_destroy) { - other.destroy(); - } - other.need_to_destroy = false; - g_infoContexts.push_back(this); - } - - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - - // destroy cannot be inlined into the destructor because that would mean calling stringify after - // ContextScope has been destroyed (base class destructors run after derived class destructors). - // Instead, ContextScope calls this method directly from its destructor. - void ContextScopeBase::destroy() { -#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) - if(std::uncaught_exceptions() > 0) { -#else - if(std::uncaught_exception()) { -#endif - std::ostringstream s; - this->stringify(&s); - g_cs->stringifiedContexts.push_back(s.str().c_str()); - } - g_infoContexts.pop_back(); - } - - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP - DOCTEST_MSVC_SUPPRESS_WARNING_POP -} // namespace detail -namespace { - using namespace detail; - -#if !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && !defined(DOCTEST_CONFIG_WINDOWS_SEH) - struct FatalConditionHandler - { - static void reset() {} - static void allocateAltStackMem() {} - static void freeAltStackMem() {} - }; -#else // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH - - void reportFatal(const std::string&); - -#ifdef DOCTEST_PLATFORM_WINDOWS - - struct SignalDefs - { - DWORD id; - const char* name; - }; - // There is no 1-1 mapping between signals and windows exceptions. - // Windows can easily distinguish between SO and SigSegV, - // but SigInt, SigTerm, etc are handled differently. - SignalDefs signalDefs[] = { - {static_cast(EXCEPTION_ILLEGAL_INSTRUCTION), - "SIGILL - Illegal instruction signal"}, - {static_cast(EXCEPTION_STACK_OVERFLOW), "SIGSEGV - Stack overflow"}, - {static_cast(EXCEPTION_ACCESS_VIOLATION), - "SIGSEGV - Segmentation violation signal"}, - {static_cast(EXCEPTION_INT_DIVIDE_BY_ZERO), "Divide by zero error"}, - }; - - struct FatalConditionHandler - { - static LONG CALLBACK handleException(PEXCEPTION_POINTERS ExceptionInfo) { - // Multiple threads may enter this filter/handler at once. We want the error message to be printed on the - // console just once no matter how many threads have crashed. - DOCTEST_DECLARE_STATIC_MUTEX(mutex) - static bool execute = true; - { - DOCTEST_LOCK_MUTEX(mutex) - if(execute) { - bool reported = false; - for(size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - if(ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) { - reportFatal(signalDefs[i].name); - reported = true; - break; - } - } - if(reported == false) - reportFatal("Unhandled SEH exception caught"); - if(isDebuggerActive() && !g_cs->no_breaks) - DOCTEST_BREAK_INTO_DEBUGGER(); - } - execute = false; - } - std::exit(EXIT_FAILURE); - } - - static void allocateAltStackMem() {} - static void freeAltStackMem() {} - - FatalConditionHandler() { - isSet = true; - // 32k seems enough for doctest to handle stack overflow, - // but the value was found experimentally, so there is no strong guarantee - guaranteeSize = 32 * 1024; - // Register an unhandled exception filter - previousTop = SetUnhandledExceptionFilter(handleException); - // Pass in guarantee size to be filled - SetThreadStackGuarantee(&guaranteeSize); - - // On Windows uncaught exceptions from another thread, exceptions from - // destructors, or calls to std::terminate are not a SEH exception - - // The terminal handler gets called when: - // - std::terminate is called FROM THE TEST RUNNER THREAD - // - an exception is thrown from a destructor FROM THE TEST RUNNER THREAD - original_terminate_handler = std::get_terminate(); - std::set_terminate([]() DOCTEST_NOEXCEPT { - reportFatal("Terminate handler called"); - if(isDebuggerActive() && !g_cs->no_breaks) - DOCTEST_BREAK_INTO_DEBUGGER(); - std::exit(EXIT_FAILURE); // explicitly exit - otherwise the SIGABRT handler may be called as well - }); - - // SIGABRT is raised when: - // - std::terminate is called FROM A DIFFERENT THREAD - // - an exception is thrown from a destructor FROM A DIFFERENT THREAD - // - an uncaught exception is thrown FROM A DIFFERENT THREAD - prev_sigabrt_handler = std::signal(SIGABRT, [](int signal) DOCTEST_NOEXCEPT { - if(signal == SIGABRT) { - reportFatal("SIGABRT - Abort (abnormal termination) signal"); - if(isDebuggerActive() && !g_cs->no_breaks) - DOCTEST_BREAK_INTO_DEBUGGER(); - std::exit(EXIT_FAILURE); - } - }); - - // The following settings are taken from google test, and more - // specifically from UnitTest::Run() inside of gtest.cc - - // the user does not want to see pop-up dialogs about crashes - prev_error_mode_1 = SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | - SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); - // This forces the abort message to go to stderr in all circumstances. - prev_error_mode_2 = _set_error_mode(_OUT_TO_STDERR); - // In the debug version, Visual Studio pops up a separate dialog - // offering a choice to debug the aborted program - we want to disable that. - prev_abort_behavior = _set_abort_behavior(0x0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); - // In debug mode, the Windows CRT can crash with an assertion over invalid - // input (e.g. passing an invalid file descriptor). The default handling - // for these assertions is to pop up a dialog and wait for user input. - // Instead ask the CRT to dump such assertions to stderr non-interactively. - prev_report_mode = _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); - prev_report_file = _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); - } - - static void reset() { - if(isSet) { - // Unregister handler and restore the old guarantee - SetUnhandledExceptionFilter(previousTop); - SetThreadStackGuarantee(&guaranteeSize); - std::set_terminate(original_terminate_handler); - std::signal(SIGABRT, prev_sigabrt_handler); - SetErrorMode(prev_error_mode_1); - _set_error_mode(prev_error_mode_2); - _set_abort_behavior(prev_abort_behavior, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); - static_cast(_CrtSetReportMode(_CRT_ASSERT, prev_report_mode)); - static_cast(_CrtSetReportFile(_CRT_ASSERT, prev_report_file)); - isSet = false; - } - } - - ~FatalConditionHandler() { reset(); } - - private: - static UINT prev_error_mode_1; - static int prev_error_mode_2; - static unsigned int prev_abort_behavior; - static int prev_report_mode; - static _HFILE prev_report_file; - static void (DOCTEST_CDECL *prev_sigabrt_handler)(int); - static std::terminate_handler original_terminate_handler; - static bool isSet; - static ULONG guaranteeSize; - static LPTOP_LEVEL_EXCEPTION_FILTER previousTop; - }; - - UINT FatalConditionHandler::prev_error_mode_1; - int FatalConditionHandler::prev_error_mode_2; - unsigned int FatalConditionHandler::prev_abort_behavior; - int FatalConditionHandler::prev_report_mode; - _HFILE FatalConditionHandler::prev_report_file; - void (DOCTEST_CDECL *FatalConditionHandler::prev_sigabrt_handler)(int); - std::terminate_handler FatalConditionHandler::original_terminate_handler; - bool FatalConditionHandler::isSet = false; - ULONG FatalConditionHandler::guaranteeSize = 0; - LPTOP_LEVEL_EXCEPTION_FILTER FatalConditionHandler::previousTop = nullptr; - -#else // DOCTEST_PLATFORM_WINDOWS - - struct SignalDefs - { - int id; - const char* name; - }; - SignalDefs signalDefs[] = {{SIGINT, "SIGINT - Terminal interrupt signal"}, - {SIGILL, "SIGILL - Illegal instruction signal"}, - {SIGFPE, "SIGFPE - Floating point error signal"}, - {SIGSEGV, "SIGSEGV - Segmentation violation signal"}, - {SIGTERM, "SIGTERM - Termination request signal"}, - {SIGABRT, "SIGABRT - Abort (abnormal termination) signal"}}; - - struct FatalConditionHandler - { - static bool isSet; - static struct sigaction oldSigActions[DOCTEST_COUNTOF(signalDefs)]; - static stack_t oldSigStack; - static size_t altStackSize; - static char* altStackMem; - - static void handleSignal(int sig) { - const char* name = ""; - for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - SignalDefs& def = signalDefs[i]; - if(sig == def.id) { - name = def.name; - break; - } - } - reset(); - reportFatal(name); - raise(sig); - } - - static void allocateAltStackMem() { - altStackMem = new char[altStackSize]; - } - - static void freeAltStackMem() { - delete[] altStackMem; - } - - FatalConditionHandler() { - isSet = true; - stack_t sigStack; - sigStack.ss_sp = altStackMem; - sigStack.ss_size = altStackSize; - sigStack.ss_flags = 0; - sigaltstack(&sigStack, &oldSigStack); - struct sigaction sa = {}; - sa.sa_handler = handleSignal; - sa.sa_flags = SA_ONSTACK; - for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); - } - } - - ~FatalConditionHandler() { reset(); } - static void reset() { - if(isSet) { - // Set signals back to previous values -- hopefully nobody overwrote them in the meantime - for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); - } - // Return the old stack - sigaltstack(&oldSigStack, nullptr); - isSet = false; - } - } - }; - - bool FatalConditionHandler::isSet = false; - struct sigaction FatalConditionHandler::oldSigActions[DOCTEST_COUNTOF(signalDefs)] = {}; - stack_t FatalConditionHandler::oldSigStack = {}; - size_t FatalConditionHandler::altStackSize = 4 * SIGSTKSZ; - char* FatalConditionHandler::altStackMem = nullptr; - -#endif // DOCTEST_PLATFORM_WINDOWS -#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH - -} // namespace - -namespace { - using namespace detail; - -#ifdef DOCTEST_PLATFORM_WINDOWS -#define DOCTEST_OUTPUT_DEBUG_STRING(text) ::OutputDebugStringA(text) -#else - // TODO: integration with XCode and other IDEs -#define DOCTEST_OUTPUT_DEBUG_STRING(text) -#endif // Platform - - void addAssert(assertType::Enum at) { - if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional - g_cs->numAssertsCurrentTest_atomic++; - } - - void addFailedAssert(assertType::Enum at) { - if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional - g_cs->numAssertsFailedCurrentTest_atomic++; - } - -#if defined(DOCTEST_CONFIG_POSIX_SIGNALS) || defined(DOCTEST_CONFIG_WINDOWS_SEH) - void reportFatal(const std::string& message) { - g_cs->failure_flags |= TestCaseFailureReason::Crash; - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, {message.c_str(), true}); - - while (g_cs->subcaseStack.size()) { - g_cs->subcaseStack.pop_back(); - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); - } - - g_cs->finalizeTestCaseData(); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); - } -#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH -} // namespace - -AssertData::AssertData(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const StringContains& exception_string) - : m_test_case(g_cs->currentTest), m_at(at), m_file(file), m_line(line), m_expr(expr), - m_failed(true), m_threw(false), m_threw_as(false), m_exception_type(exception_type), - m_exception_string(exception_string) { -#if DOCTEST_MSVC - if (m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC - ++m_expr; -#endif // MSVC -} - -namespace detail { - ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const String& exception_string) - : AssertData(at, file, line, expr, exception_type, exception_string) { } - - ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const Contains& exception_string) - : AssertData(at, file, line, expr, exception_type, exception_string) { } - - void ResultBuilder::setResult(const Result& res) { - m_decomp = res.m_decomp; - m_failed = !res.m_passed; - } - - void ResultBuilder::translateException() { - m_threw = true; - m_exception = translateActiveException(); - } - - bool ResultBuilder::log() { - if(m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional - m_failed = !m_threw; - } else if((m_at & assertType::is_throws_as) && (m_at & assertType::is_throws_with)) { //!OCLINT - m_failed = !m_threw_as || !m_exception_string.check(m_exception); - } else if(m_at & assertType::is_throws_as) { //!OCLINT bitwise operator in conditional - m_failed = !m_threw_as; - } else if(m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional - m_failed = !m_exception_string.check(m_exception); - } else if(m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional - m_failed = m_threw; - } - - if(m_exception.size()) - m_exception = "\"" + m_exception + "\""; - - if(is_running_in_test) { - addAssert(m_at); - DOCTEST_ITERATE_THROUGH_REPORTERS(log_assert, *this); - - if(m_failed) - addFailedAssert(m_at); - } else if(m_failed) { - failed_out_of_a_testing_context(*this); - } - - return m_failed && isDebuggerActive() && !getContextOptions()->no_breaks && - (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger - } - - void ResultBuilder::react() const { - if(m_failed && checkIfShouldThrow(m_at)) - throwException(); - } - - void failed_out_of_a_testing_context(const AssertData& ad) { - if(g_cs->ah) - g_cs->ah(ad); - else - std::abort(); - } - - bool decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, - const Result& result) { - bool failed = !result.m_passed; - - // ################################################################################### - // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT - // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED - // ################################################################################### - DOCTEST_ASSERT_OUT_OF_TESTS(result.m_decomp); - DOCTEST_ASSERT_IN_TESTS(result.m_decomp); - return !failed; - } - - MessageBuilder::MessageBuilder(const char* file, int line, assertType::Enum severity) { - m_stream = tlssPush(); - m_file = file; - m_line = line; - m_severity = severity; - } - - MessageBuilder::~MessageBuilder() { - if (!logged) - tlssPop(); - } - - DOCTEST_DEFINE_INTERFACE(IExceptionTranslator) - - bool MessageBuilder::log() { - if (!logged) { - m_string = tlssPop(); - logged = true; - } - - DOCTEST_ITERATE_THROUGH_REPORTERS(log_message, *this); - - const bool isWarn = m_severity & assertType::is_warn; - - // warn is just a message in this context so we don't treat it as an assert - if(!isWarn) { - addAssert(m_severity); - addFailedAssert(m_severity); - } - - return isDebuggerActive() && !getContextOptions()->no_breaks && !isWarn && - (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger - } - - void MessageBuilder::react() { - if(m_severity & assertType::is_require) //!OCLINT bitwise operator in conditional - throwException(); - } -} // namespace detail -namespace { - using namespace detail; - - // clang-format off - -// ================================================================================================= -// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp -// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. -// ================================================================================================= - - class XmlEncode { - public: - enum ForWhat { ForTextNodes, ForAttributes }; - - XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); - - void encodeTo( std::ostream& os ) const; - - friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); - - private: - std::string m_str; - ForWhat m_forWhat; - }; - - class XmlWriter { - public: - - class ScopedElement { - public: - ScopedElement( XmlWriter* writer ); - - ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT; - ScopedElement& operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT; - - ~ScopedElement(); - - ScopedElement& writeText( std::string const& text, bool indent = true ); - - template - ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { - m_writer->writeAttribute( name, attribute ); - return *this; - } - - private: - mutable XmlWriter* m_writer = nullptr; - }; - - XmlWriter( std::ostream& os = std::cout ); - ~XmlWriter(); - - XmlWriter( XmlWriter const& ) = delete; - XmlWriter& operator=( XmlWriter const& ) = delete; - - XmlWriter& startElement( std::string const& name ); - - ScopedElement scopedElement( std::string const& name ); - - XmlWriter& endElement(); - - XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); - - XmlWriter& writeAttribute( std::string const& name, const char* attribute ); - - XmlWriter& writeAttribute( std::string const& name, bool attribute ); - - template - XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { - std::stringstream rss; - rss << attribute; - return writeAttribute( name, rss.str() ); - } - - XmlWriter& writeText( std::string const& text, bool indent = true ); - - //XmlWriter& writeComment( std::string const& text ); - - //void writeStylesheetRef( std::string const& url ); - - //XmlWriter& writeBlankLine(); - - void ensureTagClosed(); - - void writeDeclaration(); - - private: - - void newlineIfNecessary(); - - bool m_tagIsOpen = false; - bool m_needsNewline = false; - std::vector m_tags; - std::string m_indent; - std::ostream& m_os; - }; - -// ================================================================================================= -// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp -// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. -// ================================================================================================= - -using uchar = unsigned char; - -namespace { - - size_t trailingBytes(unsigned char c) { - if ((c & 0xE0) == 0xC0) { - return 2; - } - if ((c & 0xF0) == 0xE0) { - return 3; - } - if ((c & 0xF8) == 0xF0) { - return 4; - } - DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); - } - - uint32_t headerValue(unsigned char c) { - if ((c & 0xE0) == 0xC0) { - return c & 0x1F; - } - if ((c & 0xF0) == 0xE0) { - return c & 0x0F; - } - if ((c & 0xF8) == 0xF0) { - return c & 0x07; - } - DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); - } - - void hexEscapeChar(std::ostream& os, unsigned char c) { - std::ios_base::fmtflags f(os.flags()); - os << "\\x" - << std::uppercase << std::hex << std::setfill('0') << std::setw(2) - << static_cast(c); - os.flags(f); - } - -} // anonymous namespace - - XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) - : m_str( str ), - m_forWhat( forWhat ) - {} - - void XmlEncode::encodeTo( std::ostream& os ) const { - // Apostrophe escaping not necessary if we always use " to write attributes - // (see: https://www.w3.org/TR/xml/#syntax) - - for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { - uchar c = m_str[idx]; - switch (c) { - case '<': os << "<"; break; - case '&': os << "&"; break; - - case '>': - // See: https://www.w3.org/TR/xml/#syntax - if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']') - os << ">"; - else - os << c; - break; - - case '\"': - if (m_forWhat == ForAttributes) - os << """; - else - os << c; - break; - - default: - // Check for control characters and invalid utf-8 - - // Escape control characters in standard ascii - // see https://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 - if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { - hexEscapeChar(os, c); - break; - } - - // Plain ASCII: Write it to stream - if (c < 0x7F) { - os << c; - break; - } - - // UTF-8 territory - // Check if the encoding is valid and if it is not, hex escape bytes. - // Important: We do not check the exact decoded values for validity, only the encoding format - // First check that this bytes is a valid lead byte: - // This means that it is not encoded as 1111 1XXX - // Or as 10XX XXXX - if (c < 0xC0 || - c >= 0xF8) { - hexEscapeChar(os, c); - break; - } - - auto encBytes = trailingBytes(c); - // Are there enough bytes left to avoid accessing out-of-bounds memory? - if (idx + encBytes - 1 >= m_str.size()) { - hexEscapeChar(os, c); - break; - } - // The header is valid, check data - // The next encBytes bytes must together be a valid utf-8 - // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) - bool valid = true; - uint32_t value = headerValue(c); - for (std::size_t n = 1; n < encBytes; ++n) { - uchar nc = m_str[idx + n]; - valid &= ((nc & 0xC0) == 0x80); - value = (value << 6) | (nc & 0x3F); - } - - if ( - // Wrong bit pattern of following bytes - (!valid) || - // Overlong encodings - (value < 0x80) || - ( value < 0x800 && encBytes > 2) || // removed "0x80 <= value &&" because redundant - (0x800 < value && value < 0x10000 && encBytes > 3) || - // Encoded value out of range - (value >= 0x110000) - ) { - hexEscapeChar(os, c); - break; - } - - // If we got here, this is in fact a valid(ish) utf-8 sequence - for (std::size_t n = 0; n < encBytes; ++n) { - os << m_str[idx + n]; - } - idx += encBytes - 1; - break; - } - } - } - - std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { - xmlEncode.encodeTo( os ); - return os; - } - - XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) - : m_writer( writer ) - {} - - XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT - : m_writer( other.m_writer ){ - other.m_writer = nullptr; - } - XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT { - if ( m_writer ) { - m_writer->endElement(); - } - m_writer = other.m_writer; - other.m_writer = nullptr; - return *this; - } - - - XmlWriter::ScopedElement::~ScopedElement() { - if( m_writer ) - m_writer->endElement(); - } - - XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { - m_writer->writeText( text, indent ); - return *this; - } - - XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) - { - // writeDeclaration(); // called explicitly by the reporters that use the writer class - see issue #627 - } - - XmlWriter::~XmlWriter() { - while( !m_tags.empty() ) - endElement(); - } - - XmlWriter& XmlWriter::startElement( std::string const& name ) { - ensureTagClosed(); - newlineIfNecessary(); - m_os << m_indent << '<' << name; - m_tags.push_back( name ); - m_indent += " "; - m_tagIsOpen = true; - return *this; - } - - XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { - ScopedElement scoped( this ); - startElement( name ); - return scoped; - } - - XmlWriter& XmlWriter::endElement() { - newlineIfNecessary(); - m_indent = m_indent.substr( 0, m_indent.size()-2 ); - if( m_tagIsOpen ) { - m_os << "/>"; - m_tagIsOpen = false; - } - else { - m_os << m_indent << ""; - } - m_os << std::endl; - m_tags.pop_back(); - return *this; - } - - XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { - if( !name.empty() && !attribute.empty() ) - m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; - return *this; - } - - XmlWriter& XmlWriter::writeAttribute( std::string const& name, const char* attribute ) { - if( !name.empty() && attribute && attribute[0] != '\0' ) - m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; - return *this; - } - - XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { - m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"'; - return *this; - } - - XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { - if( !text.empty() ){ - bool tagWasOpen = m_tagIsOpen; - ensureTagClosed(); - if( tagWasOpen && indent ) - m_os << m_indent; - m_os << XmlEncode( text ); - m_needsNewline = true; - } - return *this; - } - - //XmlWriter& XmlWriter::writeComment( std::string const& text ) { - // ensureTagClosed(); - // m_os << m_indent << ""; - // m_needsNewline = true; - // return *this; - //} - - //void XmlWriter::writeStylesheetRef( std::string const& url ) { - // m_os << "\n"; - //} - - //XmlWriter& XmlWriter::writeBlankLine() { - // ensureTagClosed(); - // m_os << '\n'; - // return *this; - //} - - void XmlWriter::ensureTagClosed() { - if( m_tagIsOpen ) { - m_os << ">" << std::endl; - m_tagIsOpen = false; - } - } - - void XmlWriter::writeDeclaration() { - m_os << "\n"; - } - - void XmlWriter::newlineIfNecessary() { - if( m_needsNewline ) { - m_os << std::endl; - m_needsNewline = false; - } - } - -// ================================================================================================= -// End of copy-pasted code from Catch -// ================================================================================================= - - // clang-format on - - struct XmlReporter : public IReporter - { - XmlWriter xml; - DOCTEST_DECLARE_MUTEX(mutex) - - // caching pointers/references to objects of these types - safe to do - const ContextOptions& opt; - const TestCaseData* tc = nullptr; - - XmlReporter(const ContextOptions& co) - : xml(*co.cout) - , opt(co) {} - - void log_contexts() { - int num_contexts = get_num_active_contexts(); - if(num_contexts) { - auto contexts = get_active_contexts(); - std::stringstream ss; - for(int i = 0; i < num_contexts; ++i) { - contexts[i]->stringify(&ss); - xml.scopedElement("Info").writeText(ss.str()); - ss.str(""); - } - } - } - - unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } - - void test_case_start_impl(const TestCaseData& in) { - bool open_ts_tag = false; - if(tc != nullptr) { // we have already opened a test suite - if(std::strcmp(tc->m_test_suite, in.m_test_suite) != 0) { - xml.endElement(); - open_ts_tag = true; - } - } - else { - open_ts_tag = true; // first test case ==> first test suite - } - - if(open_ts_tag) { - xml.startElement("TestSuite"); - xml.writeAttribute("name", in.m_test_suite); - } - - tc = ∈ - xml.startElement("TestCase") - .writeAttribute("name", in.m_name) - .writeAttribute("filename", skipPathFromFilename(in.m_file.c_str())) - .writeAttribute("line", line(in.m_line)) - .writeAttribute("description", in.m_description); - - if(Approx(in.m_timeout) != 0) - xml.writeAttribute("timeout", in.m_timeout); - if(in.m_may_fail) - xml.writeAttribute("may_fail", true); - if(in.m_should_fail) - xml.writeAttribute("should_fail", true); - } - - // ========================================================================================= - // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE - // ========================================================================================= - - void report_query(const QueryData& in) override { - test_run_start(); - if(opt.list_reporters) { - for(auto& curr : getListeners()) - xml.scopedElement("Listener") - .writeAttribute("priority", curr.first.first) - .writeAttribute("name", curr.first.second); - for(auto& curr : getReporters()) - xml.scopedElement("Reporter") - .writeAttribute("priority", curr.first.first) - .writeAttribute("name", curr.first.second); - } else if(opt.count || opt.list_test_cases) { - for(unsigned i = 0; i < in.num_data; ++i) { - xml.scopedElement("TestCase").writeAttribute("name", in.data[i]->m_name) - .writeAttribute("testsuite", in.data[i]->m_test_suite) - .writeAttribute("filename", skipPathFromFilename(in.data[i]->m_file.c_str())) - .writeAttribute("line", line(in.data[i]->m_line)) - .writeAttribute("skipped", in.data[i]->m_skip); - } - xml.scopedElement("OverallResultsTestCases") - .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); - } else if(opt.list_test_suites) { - for(unsigned i = 0; i < in.num_data; ++i) - xml.scopedElement("TestSuite").writeAttribute("name", in.data[i]->m_test_suite); - xml.scopedElement("OverallResultsTestCases") - .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); - xml.scopedElement("OverallResultsTestSuites") - .writeAttribute("unskipped", in.run_stats->numTestSuitesPassingFilters); - } - xml.endElement(); - } - - void test_run_start() override { - xml.writeDeclaration(); - - // remove .exe extension - mainly to have the same output on UNIX and Windows - std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); -#ifdef DOCTEST_PLATFORM_WINDOWS - if(binary_name.rfind(".exe") != std::string::npos) - binary_name = binary_name.substr(0, binary_name.length() - 4); -#endif // DOCTEST_PLATFORM_WINDOWS - - xml.startElement("doctest").writeAttribute("binary", binary_name); - if(opt.no_version == false) - xml.writeAttribute("version", DOCTEST_VERSION_STR); - - // only the consequential ones (TODO: filters) - xml.scopedElement("Options") - .writeAttribute("order_by", opt.order_by.c_str()) - .writeAttribute("rand_seed", opt.rand_seed) - .writeAttribute("first", opt.first) - .writeAttribute("last", opt.last) - .writeAttribute("abort_after", opt.abort_after) - .writeAttribute("subcase_filter_levels", opt.subcase_filter_levels) - .writeAttribute("case_sensitive", opt.case_sensitive) - .writeAttribute("no_throw", opt.no_throw) - .writeAttribute("no_skip", opt.no_skip); - } - - void test_run_end(const TestRunStats& p) override { - if(tc) // the TestSuite tag - only if there has been at least 1 test case - xml.endElement(); - - xml.scopedElement("OverallResultsAsserts") - .writeAttribute("successes", p.numAsserts - p.numAssertsFailed) - .writeAttribute("failures", p.numAssertsFailed); - - xml.startElement("OverallResultsTestCases") - .writeAttribute("successes", - p.numTestCasesPassingFilters - p.numTestCasesFailed) - .writeAttribute("failures", p.numTestCasesFailed); - if(opt.no_skipped_summary == false) - xml.writeAttribute("skipped", p.numTestCases - p.numTestCasesPassingFilters); - xml.endElement(); - - xml.endElement(); - } - - void test_case_start(const TestCaseData& in) override { - test_case_start_impl(in); - xml.ensureTagClosed(); - } - - void test_case_reenter(const TestCaseData&) override {} - - void test_case_end(const CurrentTestCaseStats& st) override { - xml.startElement("OverallResultsAsserts") - .writeAttribute("successes", - st.numAssertsCurrentTest - st.numAssertsFailedCurrentTest) - .writeAttribute("failures", st.numAssertsFailedCurrentTest) - .writeAttribute("test_case_success", st.testCaseSuccess); - if(opt.duration) - xml.writeAttribute("duration", st.seconds); - if(tc->m_expected_failures) - xml.writeAttribute("expected_failures", tc->m_expected_failures); - xml.endElement(); - - xml.endElement(); - } - - void test_case_exception(const TestCaseException& e) override { - DOCTEST_LOCK_MUTEX(mutex) - - xml.scopedElement("Exception") - .writeAttribute("crash", e.is_crash) - .writeText(e.error_string.c_str()); - } - - void subcase_start(const SubcaseSignature& in) override { - xml.startElement("SubCase") - .writeAttribute("name", in.m_name) - .writeAttribute("filename", skipPathFromFilename(in.m_file)) - .writeAttribute("line", line(in.m_line)); - xml.ensureTagClosed(); - } - - void subcase_end() override { xml.endElement(); } - - void log_assert(const AssertData& rb) override { - if(!rb.m_failed && !opt.success) - return; - - DOCTEST_LOCK_MUTEX(mutex) - - xml.startElement("Expression") - .writeAttribute("success", !rb.m_failed) - .writeAttribute("type", assertString(rb.m_at)) - .writeAttribute("filename", skipPathFromFilename(rb.m_file)) - .writeAttribute("line", line(rb.m_line)); - - xml.scopedElement("Original").writeText(rb.m_expr); - - if(rb.m_threw) - xml.scopedElement("Exception").writeText(rb.m_exception.c_str()); - - if(rb.m_at & assertType::is_throws_as) - xml.scopedElement("ExpectedException").writeText(rb.m_exception_type); - if(rb.m_at & assertType::is_throws_with) - xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string.c_str()); - if((rb.m_at & assertType::is_normal) && !rb.m_threw) - xml.scopedElement("Expanded").writeText(rb.m_decomp.c_str()); - - log_contexts(); - - xml.endElement(); - } - - void log_message(const MessageData& mb) override { - DOCTEST_LOCK_MUTEX(mutex) - - xml.startElement("Message") - .writeAttribute("type", failureString(mb.m_severity)) - .writeAttribute("filename", skipPathFromFilename(mb.m_file)) - .writeAttribute("line", line(mb.m_line)); - - xml.scopedElement("Text").writeText(mb.m_string.c_str()); - - log_contexts(); - - xml.endElement(); - } - - void test_case_skipped(const TestCaseData& in) override { - if(opt.no_skipped_summary == false) { - test_case_start_impl(in); - xml.writeAttribute("skipped", "true"); - xml.endElement(); - } - } - }; - - DOCTEST_REGISTER_REPORTER("xml", 0, XmlReporter); - - void fulltext_log_assert_to_stream(std::ostream& s, const AssertData& rb) { - if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) == - 0) //!OCLINT bitwise operator in conditional - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << " ) " - << Color::None; - - if(rb.m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional - s << (rb.m_threw ? "threw as expected!" : "did NOT throw at all!") << "\n"; - } else if((rb.m_at & assertType::is_throws_as) && - (rb.m_at & assertType::is_throws_with)) { //!OCLINT - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" - << rb.m_exception_string.c_str() - << "\", " << rb.m_exception_type << " ) " << Color::None; - if(rb.m_threw) { - if(!rb.m_failed) { - s << "threw as expected!\n"; - } else { - s << "threw a DIFFERENT exception! (contents: " << rb.m_exception << ")\n"; - } - } else { - s << "did NOT throw at all!\n"; - } - } else if(rb.m_at & - assertType::is_throws_as) { //!OCLINT bitwise operator in conditional - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", " - << rb.m_exception_type << " ) " << Color::None - << (rb.m_threw ? (rb.m_threw_as ? "threw as expected!" : - "threw a DIFFERENT exception: ") : - "did NOT throw at all!") - << Color::Cyan << rb.m_exception << "\n"; - } else if(rb.m_at & - assertType::is_throws_with) { //!OCLINT bitwise operator in conditional - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" - << rb.m_exception_string.c_str() - << "\" ) " << Color::None - << (rb.m_threw ? (!rb.m_failed ? "threw as expected!" : - "threw a DIFFERENT exception: ") : - "did NOT throw at all!") - << Color::Cyan << rb.m_exception << "\n"; - } else if(rb.m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional - s << (rb.m_threw ? "THREW exception: " : "didn't throw!") << Color::Cyan - << rb.m_exception << "\n"; - } else { - s << (rb.m_threw ? "THREW exception: " : - (!rb.m_failed ? "is correct!\n" : "is NOT correct!\n")); - if(rb.m_threw) - s << rb.m_exception << "\n"; - else - s << " values: " << assertString(rb.m_at) << "( " << rb.m_decomp << " )\n"; - } - } - - // TODO: - // - log_message() - // - respond to queries - // - honor remaining options - // - more attributes in tags - struct JUnitReporter : public IReporter - { - XmlWriter xml; - DOCTEST_DECLARE_MUTEX(mutex) - Timer timer; - std::vector deepestSubcaseStackNames; - - struct JUnitTestCaseData - { - static std::string getCurrentTimestamp() { - // Beware, this is not reentrant because of backward compatibility issues - // Also, UTC only, again because of backward compatibility (%z is C++11) - time_t rawtime; - std::time(&rawtime); - auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); - - std::tm timeInfo; -#ifdef DOCTEST_PLATFORM_WINDOWS - gmtime_s(&timeInfo, &rawtime); -#else // DOCTEST_PLATFORM_WINDOWS - gmtime_r(&rawtime, &timeInfo); -#endif // DOCTEST_PLATFORM_WINDOWS - - char timeStamp[timeStampSize]; - const char* const fmt = "%Y-%m-%dT%H:%M:%SZ"; - - std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); - return std::string(timeStamp); - } - - struct JUnitTestMessage - { - JUnitTestMessage(const std::string& _message, const std::string& _type, const std::string& _details) - : message(_message), type(_type), details(_details) {} - - JUnitTestMessage(const std::string& _message, const std::string& _details) - : message(_message), type(), details(_details) {} - - std::string message, type, details; - }; - - struct JUnitTestCase - { - JUnitTestCase(const std::string& _classname, const std::string& _name) - : classname(_classname), name(_name), time(0), failures() {} - - std::string classname, name; - double time; - std::vector failures, errors; - }; - - void add(const std::string& classname, const std::string& name) { - testcases.emplace_back(classname, name); - } - - void appendSubcaseNamesToLastTestcase(std::vector nameStack) { - for(auto& curr: nameStack) - if(curr.size()) - testcases.back().name += std::string("/") + curr.c_str(); - } - - void addTime(double time) { - if(time < 1e-4) - time = 0; - testcases.back().time = time; - totalSeconds += time; - } - - void addFailure(const std::string& message, const std::string& type, const std::string& details) { - testcases.back().failures.emplace_back(message, type, details); - ++totalFailures; - } - - void addError(const std::string& message, const std::string& details) { - testcases.back().errors.emplace_back(message, details); - ++totalErrors; - } - - std::vector testcases; - double totalSeconds = 0; - int totalErrors = 0, totalFailures = 0; - }; - - JUnitTestCaseData testCaseData; - - // caching pointers/references to objects of these types - safe to do - const ContextOptions& opt; - const TestCaseData* tc = nullptr; - - JUnitReporter(const ContextOptions& co) - : xml(*co.cout) - , opt(co) {} - - unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } - - // ========================================================================================= - // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE - // ========================================================================================= - - void report_query(const QueryData&) override { - xml.writeDeclaration(); - } - - void test_run_start() override { - xml.writeDeclaration(); - } - - void test_run_end(const TestRunStats& p) override { - // remove .exe extension - mainly to have the same output on UNIX and Windows - std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); -#ifdef DOCTEST_PLATFORM_WINDOWS - if(binary_name.rfind(".exe") != std::string::npos) - binary_name = binary_name.substr(0, binary_name.length() - 4); -#endif // DOCTEST_PLATFORM_WINDOWS - xml.startElement("testsuites"); - xml.startElement("testsuite").writeAttribute("name", binary_name) - .writeAttribute("errors", testCaseData.totalErrors) - .writeAttribute("failures", testCaseData.totalFailures) - .writeAttribute("tests", p.numAsserts); - if(opt.no_time_in_output == false) { - xml.writeAttribute("time", testCaseData.totalSeconds); - xml.writeAttribute("timestamp", JUnitTestCaseData::getCurrentTimestamp()); - } - if(opt.no_version == false) - xml.writeAttribute("doctest_version", DOCTEST_VERSION_STR); - - for(const auto& testCase : testCaseData.testcases) { - xml.startElement("testcase") - .writeAttribute("classname", testCase.classname) - .writeAttribute("name", testCase.name); - if(opt.no_time_in_output == false) - xml.writeAttribute("time", testCase.time); - // This is not ideal, but it should be enough to mimic gtest's junit output. - xml.writeAttribute("status", "run"); - - for(const auto& failure : testCase.failures) { - xml.scopedElement("failure") - .writeAttribute("message", failure.message) - .writeAttribute("type", failure.type) - .writeText(failure.details, false); - } - - for(const auto& error : testCase.errors) { - xml.scopedElement("error") - .writeAttribute("message", error.message) - .writeText(error.details); - } - - xml.endElement(); - } - xml.endElement(); - xml.endElement(); - } - - void test_case_start(const TestCaseData& in) override { - testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); - timer.start(); - } - - void test_case_reenter(const TestCaseData& in) override { - testCaseData.addTime(timer.getElapsedSeconds()); - testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); - deepestSubcaseStackNames.clear(); - - timer.start(); - testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); - } - - void test_case_end(const CurrentTestCaseStats&) override { - testCaseData.addTime(timer.getElapsedSeconds()); - testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); - deepestSubcaseStackNames.clear(); - } - - void test_case_exception(const TestCaseException& e) override { - DOCTEST_LOCK_MUTEX(mutex) - testCaseData.addError("exception", e.error_string.c_str()); - } - - void subcase_start(const SubcaseSignature& in) override { - deepestSubcaseStackNames.push_back(in.m_name); - } - - void subcase_end() override {} - - void log_assert(const AssertData& rb) override { - if(!rb.m_failed) // report only failures & ignore the `success` option - return; - - DOCTEST_LOCK_MUTEX(mutex) - - std::ostringstream os; - os << skipPathFromFilename(rb.m_file) << (opt.gnu_file_line ? ":" : "(") - << line(rb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; - - fulltext_log_assert_to_stream(os, rb); - log_contexts(os); - testCaseData.addFailure(rb.m_decomp.c_str(), assertString(rb.m_at), os.str()); - } - - void log_message(const MessageData&) override {} - - void test_case_skipped(const TestCaseData&) override {} - - void log_contexts(std::ostringstream& s) { - int num_contexts = get_num_active_contexts(); - if(num_contexts) { - auto contexts = get_active_contexts(); - - s << " logged: "; - for(int i = 0; i < num_contexts; ++i) { - s << (i == 0 ? "" : " "); - contexts[i]->stringify(&s); - s << std::endl; - } - } - } - }; - - DOCTEST_REGISTER_REPORTER("junit", 0, JUnitReporter); - - struct Whitespace - { - int nrSpaces; - explicit Whitespace(int nr) - : nrSpaces(nr) {} - }; - - std::ostream& operator<<(std::ostream& out, const Whitespace& ws) { - if(ws.nrSpaces != 0) - out << std::setw(ws.nrSpaces) << ' '; - return out; - } - - struct ConsoleReporter : public IReporter - { - std::ostream& s; - bool hasLoggedCurrentTestStart; - std::vector subcasesStack; - size_t currentSubcaseLevel; - DOCTEST_DECLARE_MUTEX(mutex) - - // caching pointers/references to objects of these types - safe to do - const ContextOptions& opt; - const TestCaseData* tc; - - ConsoleReporter(const ContextOptions& co) - : s(*co.cout) - , opt(co) {} - - ConsoleReporter(const ContextOptions& co, std::ostream& ostr) - : s(ostr) - , opt(co) {} - - // ========================================================================================= - // WHAT FOLLOWS ARE HELPERS USED BY THE OVERRIDES OF THE VIRTUAL METHODS OF THE INTERFACE - // ========================================================================================= - - void separator_to_stream() { - s << Color::Yellow - << "===============================================================================" - "\n"; - } - - const char* getSuccessOrFailString(bool success, assertType::Enum at, - const char* success_str) { - if(success) - return success_str; - return failureString(at); - } - - Color::Enum getSuccessOrFailColor(bool success, assertType::Enum at) { - return success ? Color::BrightGreen : - (at & assertType::is_warn) ? Color::Yellow : Color::Red; - } - - void successOrFailColoredStringToStream(bool success, assertType::Enum at, - const char* success_str = "SUCCESS") { - s << getSuccessOrFailColor(success, at) - << getSuccessOrFailString(success, at, success_str) << ": "; - } - - void log_contexts() { - int num_contexts = get_num_active_contexts(); - if(num_contexts) { - auto contexts = get_active_contexts(); - - s << Color::None << " logged: "; - for(int i = 0; i < num_contexts; ++i) { - s << (i == 0 ? "" : " "); - contexts[i]->stringify(&s); - s << "\n"; - } - } - - s << "\n"; - } - - // this was requested to be made virtual so users could override it - virtual void file_line_to_stream(const char* file, int line, - const char* tail = "") { - s << Color::LightGrey << skipPathFromFilename(file) << (opt.gnu_file_line ? ":" : "(") - << (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option - << (opt.gnu_file_line ? ":" : "):") << tail; - } - - void logTestStart() { - if(hasLoggedCurrentTestStart) - return; - - separator_to_stream(); - file_line_to_stream(tc->m_file.c_str(), tc->m_line, "\n"); - if(tc->m_description) - s << Color::Yellow << "DESCRIPTION: " << Color::None << tc->m_description << "\n"; - if(tc->m_test_suite && tc->m_test_suite[0] != '\0') - s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n"; - if(strncmp(tc->m_name, " Scenario:", 11) != 0) - s << Color::Yellow << "TEST CASE: "; - s << Color::None << tc->m_name << "\n"; - - for(size_t i = 0; i < currentSubcaseLevel; ++i) { - if(subcasesStack[i].m_name[0] != '\0') - s << " " << subcasesStack[i].m_name << "\n"; - } - - if(currentSubcaseLevel != subcasesStack.size()) { - s << Color::Yellow << "\nDEEPEST SUBCASE STACK REACHED (DIFFERENT FROM THE CURRENT ONE):\n" << Color::None; - for(size_t i = 0; i < subcasesStack.size(); ++i) { - if(subcasesStack[i].m_name[0] != '\0') - s << " " << subcasesStack[i].m_name << "\n"; - } - } - - s << "\n"; - - hasLoggedCurrentTestStart = true; - } - - void printVersion() { - if(opt.no_version == false) - s << Color::Cyan << "[doctest] " << Color::None << "doctest version is \"" - << DOCTEST_VERSION_STR << "\"\n"; - } - - void printIntro() { - if(opt.no_intro == false) { - printVersion(); - s << Color::Cyan << "[doctest] " << Color::None - << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; - } - } - - void printHelp() { - int sizePrefixDisplay = static_cast(strlen(DOCTEST_OPTIONS_PREFIX_DISPLAY)); - printVersion(); - // clang-format off - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "boolean values: \"1/on/yes/true\" or \"0/off/no/false\"\n"; - s << Color::Cyan << "[doctest] " << Color::None; - s << "filter values: \"str1,str2,str3\" (comma separated strings)\n"; - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "filters use wildcards for matching strings\n"; - s << Color::Cyan << "[doctest] " << Color::None; - s << "something passes a filter if any of the strings in a filter matches\n"; -#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "ALL FLAGS, OPTIONS AND FILTERS ALSO AVAILABLE WITH A \"" DOCTEST_CONFIG_OPTIONS_PREFIX "\" PREFIX!!!\n"; -#endif - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "Query flags - the program quits after them. Available:\n\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "?, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "help, -" DOCTEST_OPTIONS_PREFIX_DISPLAY "h " - << Whitespace(sizePrefixDisplay*0) << "prints this message\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "v, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "version " - << Whitespace(sizePrefixDisplay*1) << "prints the version\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "c, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "count " - << Whitespace(sizePrefixDisplay*1) << "prints the number of matching tests\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ltc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-cases " - << Whitespace(sizePrefixDisplay*1) << "lists all matching tests by name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-suites " - << Whitespace(sizePrefixDisplay*1) << "lists all matching test suites\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-reporters " - << Whitespace(sizePrefixDisplay*1) << "lists all registered reporters\n\n"; - // ================================================================================== << 79 - s << Color::Cyan << "[doctest] " << Color::None; - s << "The available / options/filters are:\n\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case= " - << Whitespace(sizePrefixDisplay*1) << "filters tests by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file= " - << Whitespace(sizePrefixDisplay*1) << "filters tests by their file\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sfe, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their file\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite= " - << Whitespace(sizePrefixDisplay*1) << "filters tests by their test suite\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tse, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their test suite\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase= " - << Whitespace(sizePrefixDisplay*1) << "filters subcases by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT subcases by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "r, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "reporters= " - << Whitespace(sizePrefixDisplay*1) << "reporters to use (console is default)\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "o, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "out= " - << Whitespace(sizePrefixDisplay*1) << "output filename\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ob, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "order-by= " - << Whitespace(sizePrefixDisplay*1) << "how the tests should be ordered\n"; - s << Whitespace(sizePrefixDisplay*3) << " - [file/suite/name/rand/none]\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "rs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "rand-seed= " - << Whitespace(sizePrefixDisplay*1) << "seed for random ordering\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "f, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "first= " - << Whitespace(sizePrefixDisplay*1) << "the first test passing the filters to\n"; - s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "l, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "last= " - << Whitespace(sizePrefixDisplay*1) << "the last test passing the filters to\n"; - s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "aa, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "abort-after= " - << Whitespace(sizePrefixDisplay*1) << "stop after failed assertions\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "scfl,--" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-filter-levels= " - << Whitespace(sizePrefixDisplay*1) << "apply filters for the first levels\n"; - s << Color::Cyan << "\n[doctest] " << Color::None; - s << "Bool options - can be used like flags and true is assumed. Available:\n\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "s, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "success= " - << Whitespace(sizePrefixDisplay*1) << "include successful assertions in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "cs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "case-sensitive= " - << Whitespace(sizePrefixDisplay*1) << "filters being treated as case sensitive\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "e, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "exit= " - << Whitespace(sizePrefixDisplay*1) << "exits after the tests finish\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "d, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "duration= " - << Whitespace(sizePrefixDisplay*1) << "prints the time duration of each test\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "m, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "minimal= " - << Whitespace(sizePrefixDisplay*1) << "minimal console output (only failures)\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "q, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "quiet= " - << Whitespace(sizePrefixDisplay*1) << "no console output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nt, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-throw= " - << Whitespace(sizePrefixDisplay*1) << "skips exceptions-related assert checks\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ne, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-exitcode= " - << Whitespace(sizePrefixDisplay*1) << "returns (or exits) always with success\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-run= " - << Whitespace(sizePrefixDisplay*1) << "skips all runtime doctest operations\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ni, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-intro= " - << Whitespace(sizePrefixDisplay*1) << "omit the framework intro in the output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nv, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-version= " - << Whitespace(sizePrefixDisplay*1) << "omit the framework version in the output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-colors= " - << Whitespace(sizePrefixDisplay*1) << "disables colors in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "fc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "force-colors= " - << Whitespace(sizePrefixDisplay*1) << "use colors even when not in a tty\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nb, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-breaks= " - << Whitespace(sizePrefixDisplay*1) << "disables breakpoints in debuggers\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ns, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-skip= " - << Whitespace(sizePrefixDisplay*1) << "don't skip test cases marked as skip\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "gfl, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "gnu-file-line= " - << Whitespace(sizePrefixDisplay*1) << ":n: vs (n): for line numbers in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "npf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-path-filenames= " - << Whitespace(sizePrefixDisplay*1) << "only filenames and no paths in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nln, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-line-numbers= " - << Whitespace(sizePrefixDisplay*1) << "0 instead of real line numbers in output\n"; - // ================================================================================== << 79 - // clang-format on - - s << Color::Cyan << "\n[doctest] " << Color::None; - s << "for more information visit the project documentation\n\n"; - } - - void printRegisteredReporters() { - printVersion(); - auto printReporters = [this] (const reporterMap& reporters, const char* type) { - if(reporters.size()) { - s << Color::Cyan << "[doctest] " << Color::None << "listing all registered " << type << "\n"; - for(auto& curr : reporters) - s << "priority: " << std::setw(5) << curr.first.first - << " name: " << curr.first.second << "\n"; - } - }; - printReporters(getListeners(), "listeners"); - printReporters(getReporters(), "reporters"); - } - - // ========================================================================================= - // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE - // ========================================================================================= - - void report_query(const QueryData& in) override { - if(opt.version) { - printVersion(); - } else if(opt.help) { - printHelp(); - } else if(opt.list_reporters) { - printRegisteredReporters(); - } else if(opt.count || opt.list_test_cases) { - if(opt.list_test_cases) { - s << Color::Cyan << "[doctest] " << Color::None - << "listing all test case names\n"; - separator_to_stream(); - } - - for(unsigned i = 0; i < in.num_data; ++i) - s << Color::None << in.data[i]->m_name << "\n"; - - separator_to_stream(); - - s << Color::Cyan << "[doctest] " << Color::None - << "unskipped test cases passing the current filters: " - << g_cs->numTestCasesPassingFilters << "\n"; - - } else if(opt.list_test_suites) { - s << Color::Cyan << "[doctest] " << Color::None << "listing all test suites\n"; - separator_to_stream(); - - for(unsigned i = 0; i < in.num_data; ++i) - s << Color::None << in.data[i]->m_test_suite << "\n"; - - separator_to_stream(); - - s << Color::Cyan << "[doctest] " << Color::None - << "unskipped test cases passing the current filters: " - << g_cs->numTestCasesPassingFilters << "\n"; - s << Color::Cyan << "[doctest] " << Color::None - << "test suites with unskipped test cases passing the current filters: " - << g_cs->numTestSuitesPassingFilters << "\n"; - } - } - - void test_run_start() override { - if(!opt.minimal) - printIntro(); - } - - void test_run_end(const TestRunStats& p) override { - if(opt.minimal && p.numTestCasesFailed == 0) - return; - - separator_to_stream(); - s << std::dec; - - auto totwidth = int(std::ceil(log10((std::max(p.numTestCasesPassingFilters, static_cast(p.numAsserts))) + 1))); - auto passwidth = int(std::ceil(log10((std::max(p.numTestCasesPassingFilters - p.numTestCasesFailed, static_cast(p.numAsserts - p.numAssertsFailed))) + 1))); - auto failwidth = int(std::ceil(log10((std::max(p.numTestCasesFailed, static_cast(p.numAssertsFailed))) + 1))); - const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0; - s << Color::Cyan << "[doctest] " << Color::None << "test cases: " << std::setw(totwidth) - << p.numTestCasesPassingFilters << " | " - << ((p.numTestCasesPassingFilters == 0 || anythingFailed) ? Color::None : - Color::Green) - << std::setw(passwidth) << p.numTestCasesPassingFilters - p.numTestCasesFailed << " passed" - << Color::None << " | " << (p.numTestCasesFailed > 0 ? Color::Red : Color::None) - << std::setw(failwidth) << p.numTestCasesFailed << " failed" << Color::None << " |"; - if(opt.no_skipped_summary == false) { - const int numSkipped = p.numTestCases - p.numTestCasesPassingFilters; - s << " " << (numSkipped == 0 ? Color::None : Color::Yellow) << numSkipped - << " skipped" << Color::None; - } - s << "\n"; - s << Color::Cyan << "[doctest] " << Color::None << "assertions: " << std::setw(totwidth) - << p.numAsserts << " | " - << ((p.numAsserts == 0 || anythingFailed) ? Color::None : Color::Green) - << std::setw(passwidth) << (p.numAsserts - p.numAssertsFailed) << " passed" << Color::None - << " | " << (p.numAssertsFailed > 0 ? Color::Red : Color::None) << std::setw(failwidth) - << p.numAssertsFailed << " failed" << Color::None << " |\n"; - s << Color::Cyan << "[doctest] " << Color::None - << "Status: " << (p.numTestCasesFailed > 0 ? Color::Red : Color::Green) - << ((p.numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl; - } - - void test_case_start(const TestCaseData& in) override { - hasLoggedCurrentTestStart = false; - tc = ∈ - subcasesStack.clear(); - currentSubcaseLevel = 0; - } - - void test_case_reenter(const TestCaseData&) override { - subcasesStack.clear(); - } - - void test_case_end(const CurrentTestCaseStats& st) override { - if(tc->m_no_output) - return; - - // log the preamble of the test case only if there is something - // else to print - something other than that an assert has failed - if(opt.duration || - (st.failure_flags && st.failure_flags != static_cast(TestCaseFailureReason::AssertFailure))) - logTestStart(); - - if(opt.duration) - s << Color::None << std::setprecision(6) << std::fixed << st.seconds - << " s: " << tc->m_name << "\n"; - - if(st.failure_flags & TestCaseFailureReason::Timeout) - s << Color::Red << "Test case exceeded time limit of " << std::setprecision(6) - << std::fixed << tc->m_timeout << "!\n"; - - if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedButDidnt) { - s << Color::Red << "Should have failed but didn't! Marking it as failed!\n"; - } else if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedAndDid) { - s << Color::Yellow << "Failed as expected so marking it as not failed\n"; - } else if(st.failure_flags & TestCaseFailureReason::CouldHaveFailedAndDid) { - s << Color::Yellow << "Allowed to fail so marking it as not failed\n"; - } else if(st.failure_flags & TestCaseFailureReason::DidntFailExactlyNumTimes) { - s << Color::Red << "Didn't fail exactly " << tc->m_expected_failures - << " times so marking it as failed!\n"; - } else if(st.failure_flags & TestCaseFailureReason::FailedExactlyNumTimes) { - s << Color::Yellow << "Failed exactly " << tc->m_expected_failures - << " times as expected so marking it as not failed!\n"; - } - if(st.failure_flags & TestCaseFailureReason::TooManyFailedAsserts) { - s << Color::Red << "Aborting - too many failed asserts!\n"; - } - s << Color::None; // lgtm [cpp/useless-expression] - } - - void test_case_exception(const TestCaseException& e) override { - DOCTEST_LOCK_MUTEX(mutex) - if(tc->m_no_output) - return; - - logTestStart(); - - file_line_to_stream(tc->m_file.c_str(), tc->m_line, " "); - successOrFailColoredStringToStream(false, e.is_crash ? assertType::is_require : - assertType::is_check); - s << Color::Red << (e.is_crash ? "test case CRASHED: " : "test case THREW exception: ") - << Color::Cyan << e.error_string << "\n"; - - int num_stringified_contexts = get_num_stringified_contexts(); - if(num_stringified_contexts) { - auto stringified_contexts = get_stringified_contexts(); - s << Color::None << " logged: "; - for(int i = num_stringified_contexts; i > 0; --i) { - s << (i == num_stringified_contexts ? "" : " ") - << stringified_contexts[i - 1] << "\n"; - } - } - s << "\n" << Color::None; - } - - void subcase_start(const SubcaseSignature& subc) override { - subcasesStack.push_back(subc); - ++currentSubcaseLevel; - hasLoggedCurrentTestStart = false; - } - - void subcase_end() override { - --currentSubcaseLevel; - hasLoggedCurrentTestStart = false; - } - - void log_assert(const AssertData& rb) override { - if((!rb.m_failed && !opt.success) || tc->m_no_output) - return; - - DOCTEST_LOCK_MUTEX(mutex) - - logTestStart(); - - file_line_to_stream(rb.m_file, rb.m_line, " "); - successOrFailColoredStringToStream(!rb.m_failed, rb.m_at); - - fulltext_log_assert_to_stream(s, rb); - - log_contexts(); - } - - void log_message(const MessageData& mb) override { - if(tc->m_no_output) - return; - - DOCTEST_LOCK_MUTEX(mutex) - - logTestStart(); - - file_line_to_stream(mb.m_file, mb.m_line, " "); - s << getSuccessOrFailColor(false, mb.m_severity) - << getSuccessOrFailString(mb.m_severity & assertType::is_warn, mb.m_severity, - "MESSAGE") << ": "; - s << Color::None << mb.m_string << "\n"; - log_contexts(); - } - - void test_case_skipped(const TestCaseData&) override {} - }; - - DOCTEST_REGISTER_REPORTER("console", 0, ConsoleReporter); - -#ifdef DOCTEST_PLATFORM_WINDOWS - struct DebugOutputWindowReporter : public ConsoleReporter - { - DOCTEST_THREAD_LOCAL static std::ostringstream oss; - - DebugOutputWindowReporter(const ContextOptions& co) - : ConsoleReporter(co, oss) {} - -#define DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(func, type, arg) \ - void func(type arg) override { \ - bool with_col = g_no_colors; \ - g_no_colors = false; \ - ConsoleReporter::func(arg); \ - if(oss.tellp() != std::streampos{}) { \ - DOCTEST_OUTPUT_DEBUG_STRING(oss.str().c_str()); \ - oss.str(""); \ - } \ - g_no_colors = with_col; \ - } - - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_start, DOCTEST_EMPTY, DOCTEST_EMPTY) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_end, const TestRunStats&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_start, const TestCaseData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_reenter, const TestCaseData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_end, const CurrentTestCaseStats&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_exception, const TestCaseException&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_start, const SubcaseSignature&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_end, DOCTEST_EMPTY, DOCTEST_EMPTY) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_assert, const AssertData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_message, const MessageData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_skipped, const TestCaseData&, in) - }; - - DOCTEST_THREAD_LOCAL std::ostringstream DebugOutputWindowReporter::oss; -#endif // DOCTEST_PLATFORM_WINDOWS - - // the implementation of parseOption() - bool parseOptionImpl(int argc, const char* const* argv, const char* pattern, String* value) { - // going from the end to the beginning and stopping on the first occurrence from the end - for(int i = argc; i > 0; --i) { - auto index = i - 1; - auto temp = std::strstr(argv[index], pattern); - if(temp && (value || strlen(temp) == strlen(pattern))) { //!OCLINT prefer early exits and continue - // eliminate matches in which the chars before the option are not '-' - bool noBadCharsFound = true; - auto curr = argv[index]; - while(curr != temp) { - if(*curr++ != '-') { - noBadCharsFound = false; - break; - } - } - if(noBadCharsFound && argv[index][0] == '-') { - if(value) { - // parsing the value of an option - temp += strlen(pattern); - const unsigned len = strlen(temp); - if(len) { - *value = temp; - return true; - } - } else { - // just a flag - no value - return true; - } - } - } - } - return false; - } - - // parses an option and returns the string after the '=' character - bool parseOption(int argc, const char* const* argv, const char* pattern, String* value = nullptr, - const String& defaultVal = String()) { - if(value) - *value = defaultVal; -#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS - // offset (normally 3 for "dt-") to skip prefix - if(parseOptionImpl(argc, argv, pattern + strlen(DOCTEST_CONFIG_OPTIONS_PREFIX), value)) - return true; -#endif // DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS - return parseOptionImpl(argc, argv, pattern, value); - } - - // locates a flag on the command line - bool parseFlag(int argc, const char* const* argv, const char* pattern) { - return parseOption(argc, argv, pattern); - } - - // parses a comma separated list of words after a pattern in one of the arguments in argv - bool parseCommaSepArgs(int argc, const char* const* argv, const char* pattern, - std::vector& res) { - String filtersString; - if(parseOption(argc, argv, pattern, &filtersString)) { - // tokenize with "," as a separator, unless escaped with backslash - std::ostringstream s; - auto flush = [&s, &res]() { - auto string = s.str(); - if(string.size() > 0) { - res.push_back(string.c_str()); - } - s.str(""); - }; - - bool seenBackslash = false; - const char* current = filtersString.c_str(); - const char* end = current + strlen(current); - while(current != end) { - char character = *current++; - if(seenBackslash) { - seenBackslash = false; - if(character == ',' || character == '\\') { - s.put(character); - continue; - } - s.put('\\'); - } - if(character == '\\') { - seenBackslash = true; - } else if(character == ',') { - flush(); - } else { - s.put(character); - } - } - - if(seenBackslash) { - s.put('\\'); - } - flush(); - return true; - } - return false; - } - - enum optionType - { - option_bool, - option_int - }; - - // parses an int/bool option from the command line - bool parseIntOption(int argc, const char* const* argv, const char* pattern, optionType type, - int& res) { - String parsedValue; - if(!parseOption(argc, argv, pattern, &parsedValue)) - return false; - - if(type) { - // integer - // TODO: change this to use std::stoi or something else! currently it uses undefined behavior - assumes '0' on failed parse... - int theInt = std::atoi(parsedValue.c_str()); - if (theInt != 0) { - res = theInt; //!OCLINT parameter reassignment - return true; - } - } else { - // boolean - const char positive[][5] = { "1", "true", "on", "yes" }; // 5 - strlen("true") + 1 - const char negative[][6] = { "0", "false", "off", "no" }; // 6 - strlen("false") + 1 - - // if the value matches any of the positive/negative possibilities - for (unsigned i = 0; i < 4; i++) { - if (parsedValue.compare(positive[i], true) == 0) { - res = 1; //!OCLINT parameter reassignment - return true; - } - if (parsedValue.compare(negative[i], true) == 0) { - res = 0; //!OCLINT parameter reassignment - return true; - } - } - } - return false; - } -} // namespace - -Context::Context(int argc, const char* const* argv) - : p(new detail::ContextState) { - parseArgs(argc, argv, true); - if(argc) - p->binary_name = argv[0]; -} - -Context::~Context() { - if(g_cs == p) - g_cs = nullptr; - delete p; -} - -void Context::applyCommandLine(int argc, const char* const* argv) { - parseArgs(argc, argv); - if(argc) - p->binary_name = argv[0]; -} - -// parses args -void Context::parseArgs(int argc, const char* const* argv, bool withDefaults) { - using namespace detail; - - // clang-format off - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file=", p->filters[0]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sf=", p->filters[0]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file-exclude=",p->filters[1]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sfe=", p->filters[1]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite=", p->filters[2]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ts=", p->filters[2]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite-exclude=", p->filters[3]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tse=", p->filters[3]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case=", p->filters[4]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tc=", p->filters[4]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case-exclude=", p->filters[5]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tce=", p->filters[5]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase=", p->filters[6]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sc=", p->filters[6]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase-exclude=", p->filters[7]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sce=", p->filters[7]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "reporters=", p->filters[8]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "r=", p->filters[8]); - // clang-format on - - int intRes = 0; - String strRes; - -#define DOCTEST_PARSE_AS_BOOL_OR_FLAG(name, sname, var, default) \ - if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_bool, intRes) || \ - parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_bool, intRes)) \ - p->var = static_cast(intRes); \ - else if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name) || \ - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname)) \ - p->var = true; \ - else if(withDefaults) \ - p->var = default - -#define DOCTEST_PARSE_INT_OPTION(name, sname, var, default) \ - if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_int, intRes) || \ - parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_int, intRes)) \ - p->var = intRes; \ - else if(withDefaults) \ - p->var = default - -#define DOCTEST_PARSE_STR_OPTION(name, sname, var, default) \ - if(parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", &strRes, default) || \ - parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", &strRes, default) || \ - withDefaults) \ - p->var = strRes - - // clang-format off - DOCTEST_PARSE_STR_OPTION("out", "o", out, ""); - DOCTEST_PARSE_STR_OPTION("order-by", "ob", order_by, "file"); - DOCTEST_PARSE_INT_OPTION("rand-seed", "rs", rand_seed, 0); - - DOCTEST_PARSE_INT_OPTION("first", "f", first, 0); - DOCTEST_PARSE_INT_OPTION("last", "l", last, UINT_MAX); - - DOCTEST_PARSE_INT_OPTION("abort-after", "aa", abort_after, 0); - DOCTEST_PARSE_INT_OPTION("subcase-filter-levels", "scfl", subcase_filter_levels, INT_MAX); - - DOCTEST_PARSE_AS_BOOL_OR_FLAG("success", "s", success, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("case-sensitive", "cs", case_sensitive, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("exit", "e", exit, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("duration", "d", duration, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("minimal", "m", minimal, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("quiet", "q", quiet, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-throw", "nt", no_throw, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-exitcode", "ne", no_exitcode, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-run", "nr", no_run, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-intro", "ni", no_intro, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-version", "nv", no_version, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-colors", "nc", no_colors, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("force-colors", "fc", force_colors, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-breaks", "nb", no_breaks, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skip", "ns", no_skip, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("gnu-file-line", "gfl", gnu_file_line, !bool(DOCTEST_MSVC)); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-path-filenames", "npf", no_path_in_filenames, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-line-numbers", "nln", no_line_numbers, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-debug-output", "ndo", no_debug_output, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skipped-summary", "nss", no_skipped_summary, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-time-in-output", "ntio", no_time_in_output, false); - // clang-format on - - if(withDefaults) { - p->help = false; - p->version = false; - p->count = false; - p->list_test_cases = false; - p->list_test_suites = false; - p->list_reporters = false; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "help") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "h") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "?")) { - p->help = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "version") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "v")) { - p->version = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "count") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "c")) { - p->count = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-cases") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ltc")) { - p->list_test_cases = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-suites") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lts")) { - p->list_test_suites = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-reporters") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lr")) { - p->list_reporters = true; - p->exit = true; - } -} - -// allows the user to add procedurally to the filters from the command line -void Context::addFilter(const char* filter, const char* value) { setOption(filter, value); } - -// allows the user to clear all filters from the command line -void Context::clearFilters() { - for(auto& curr : p->filters) - curr.clear(); -} - -// allows the user to override procedurally the bool options from the command line -void Context::setOption(const char* option, bool value) { - setOption(option, value ? "true" : "false"); -} - -// allows the user to override procedurally the int options from the command line -void Context::setOption(const char* option, int value) { - setOption(option, toString(value).c_str()); -} - -// allows the user to override procedurally the string options from the command line -void Context::setOption(const char* option, const char* value) { - auto argv = String("-") + option + "=" + value; - auto lvalue = argv.c_str(); - parseArgs(1, &lvalue); -} - -// users should query this in their main() and exit the program if true -bool Context::shouldExit() { return p->exit; } - -void Context::setAsDefaultForAssertsOutOfTestCases() { g_cs = p; } - -void Context::setAssertHandler(detail::assert_handler ah) { p->ah = ah; } - -void Context::setCout(std::ostream* out) { p->cout = out; } - -static class DiscardOStream : public std::ostream -{ -private: - class : public std::streambuf - { - private: - // allowing some buffering decreases the amount of calls to overflow - char buf[1024]; - - protected: - std::streamsize xsputn(const char_type*, std::streamsize count) override { return count; } - - int_type overflow(int_type ch) override { - setp(std::begin(buf), std::end(buf)); - return traits_type::not_eof(ch); - } - } discardBuf; - -public: - DiscardOStream() - : std::ostream(&discardBuf) {} -} discardOut; - -// the main function that does all the filtering and test running -int Context::run() { - using namespace detail; - - // save the old context state in case such was setup - for using asserts out of a testing context - auto old_cs = g_cs; - // this is the current contest - g_cs = p; - is_running_in_test = true; - - g_no_colors = p->no_colors; - p->resetRunData(); - - std::fstream fstr; - if(p->cout == nullptr) { - if(p->quiet) { - p->cout = &discardOut; - } else if(p->out.size()) { - // to a file if specified - fstr.open(p->out.c_str(), std::fstream::out); - p->cout = &fstr; - } else { - // stdout by default - p->cout = &std::cout; - } - } - - FatalConditionHandler::allocateAltStackMem(); - - auto cleanup_and_return = [&]() { - FatalConditionHandler::freeAltStackMem(); - - if(fstr.is_open()) - fstr.close(); - - // restore context - g_cs = old_cs; - is_running_in_test = false; - - // we have to free the reporters which were allocated when the run started - for(auto& curr : p->reporters_currently_used) - delete curr; - p->reporters_currently_used.clear(); - - if(p->numTestCasesFailed && !p->no_exitcode) - return EXIT_FAILURE; - return EXIT_SUCCESS; - }; - - // setup default reporter if none is given through the command line - if(p->filters[8].empty()) - p->filters[8].push_back("console"); - - // check to see if any of the registered reporters has been selected - for(auto& curr : getReporters()) { - if(matchesAny(curr.first.second.c_str(), p->filters[8], false, p->case_sensitive)) - p->reporters_currently_used.push_back(curr.second(*g_cs)); - } - - // TODO: check if there is nothing in reporters_currently_used - - // prepend all listeners - for(auto& curr : getListeners()) - p->reporters_currently_used.insert(p->reporters_currently_used.begin(), curr.second(*g_cs)); - -#ifdef DOCTEST_PLATFORM_WINDOWS - if(isDebuggerActive() && p->no_debug_output == false) - p->reporters_currently_used.push_back(new DebugOutputWindowReporter(*g_cs)); -#endif // DOCTEST_PLATFORM_WINDOWS - - // handle version, help and no_run - if(p->no_run || p->version || p->help || p->list_reporters) { - DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, QueryData()); - - return cleanup_and_return(); - } - - std::vector testArray; - for(auto& curr : getRegisteredTests()) - testArray.push_back(&curr); - p->numTestCases = testArray.size(); - - // sort the collected records - if(!testArray.empty()) { - if(p->order_by.compare("file", true) == 0) { - std::sort(testArray.begin(), testArray.end(), fileOrderComparator); - } else if(p->order_by.compare("suite", true) == 0) { - std::sort(testArray.begin(), testArray.end(), suiteOrderComparator); - } else if(p->order_by.compare("name", true) == 0) { - std::sort(testArray.begin(), testArray.end(), nameOrderComparator); - } else if(p->order_by.compare("rand", true) == 0) { - std::srand(p->rand_seed); - - // random_shuffle implementation - const auto first = &testArray[0]; - for(size_t i = testArray.size() - 1; i > 0; --i) { - int idxToSwap = std::rand() % (i + 1); - - const auto temp = first[i]; - - first[i] = first[idxToSwap]; - first[idxToSwap] = temp; - } - } else if(p->order_by.compare("none", true) == 0) { - // means no sorting - beneficial for death tests which call into the executable - // with a specific test case in mind - we don't want to slow down the startup times - } - } - - std::set testSuitesPassingFilt; - - bool query_mode = p->count || p->list_test_cases || p->list_test_suites; - std::vector queryResults; - - if(!query_mode) - DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_start, DOCTEST_EMPTY); - - // invoke the registered functions if they match the filter criteria (or just count them) - for(auto& curr : testArray) { - const auto& tc = *curr; - - bool skip_me = false; - if(tc.m_skip && !p->no_skip) - skip_me = true; - - if(!matchesAny(tc.m_file.c_str(), p->filters[0], true, p->case_sensitive)) - skip_me = true; - if(matchesAny(tc.m_file.c_str(), p->filters[1], false, p->case_sensitive)) - skip_me = true; - if(!matchesAny(tc.m_test_suite, p->filters[2], true, p->case_sensitive)) - skip_me = true; - if(matchesAny(tc.m_test_suite, p->filters[3], false, p->case_sensitive)) - skip_me = true; - if(!matchesAny(tc.m_name, p->filters[4], true, p->case_sensitive)) - skip_me = true; - if(matchesAny(tc.m_name, p->filters[5], false, p->case_sensitive)) - skip_me = true; - - if(!skip_me) - p->numTestCasesPassingFilters++; - - // skip the test if it is not in the execution range - if((p->last < p->numTestCasesPassingFilters && p->first <= p->last) || - (p->first > p->numTestCasesPassingFilters)) - skip_me = true; - - if(skip_me) { - if(!query_mode) - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_skipped, tc); - continue; - } - - // do not execute the test if we are to only count the number of filter passing tests - if(p->count) - continue; - - // print the name of the test and don't execute it - if(p->list_test_cases) { - queryResults.push_back(&tc); - continue; - } - - // print the name of the test suite if not done already and don't execute it - if(p->list_test_suites) { - if((testSuitesPassingFilt.count(tc.m_test_suite) == 0) && tc.m_test_suite[0] != '\0') { - queryResults.push_back(&tc); - testSuitesPassingFilt.insert(tc.m_test_suite); - p->numTestSuitesPassingFilters++; - } - continue; - } - - // execute the test if it passes all the filtering - { - p->currentTest = &tc; - - p->failure_flags = TestCaseFailureReason::None; - p->seconds = 0; - - // reset atomic counters - p->numAssertsFailedCurrentTest_atomic = 0; - p->numAssertsCurrentTest_atomic = 0; - - p->fullyTraversedSubcases.clear(); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_start, tc); - - p->timer.start(); - - bool run_test = true; - - do { - // reset some of the fields for subcases (except for the set of fully passed ones) - p->reachedLeaf = false; - // May not be empty if previous subcase exited via exception. - p->subcaseStack.clear(); - p->currentSubcaseDepth = 0; - - p->shouldLogCurrentException = true; - - // reset stuff for logging with INFO() - p->stringifiedContexts.clear(); - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - try { -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS -// MSVC 2015 diagnoses fatalConditionHandler as unused (because reset() is a static method) -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4101) // unreferenced local variable - FatalConditionHandler fatalConditionHandler; // Handle signals - // execute the test - tc.m_test(); - fatalConditionHandler.reset(); -DOCTEST_MSVC_SUPPRESS_WARNING_POP -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - } catch(const TestFailureException&) { - p->failure_flags |= TestCaseFailureReason::AssertFailure; - } catch(...) { - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, - {translateActiveException(), false}); - p->failure_flags |= TestCaseFailureReason::Exception; - } -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - - // exit this loop if enough assertions have failed - even if there are more subcases - if(p->abort_after > 0 && - p->numAssertsFailed + p->numAssertsFailedCurrentTest_atomic >= p->abort_after) { - run_test = false; - p->failure_flags |= TestCaseFailureReason::TooManyFailedAsserts; - } - - if(!p->nextSubcaseStack.empty() && run_test) - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_reenter, tc); - if(p->nextSubcaseStack.empty()) - run_test = false; - } while(run_test); - - p->finalizeTestCaseData(); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); - - p->currentTest = nullptr; - - // stop executing tests if enough assertions have failed - if(p->abort_after > 0 && p->numAssertsFailed >= p->abort_after) - break; - } - } - - if(!query_mode) { - DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); - } else { - QueryData qdata; - qdata.run_stats = g_cs; - qdata.data = queryResults.data(); - qdata.num_data = unsigned(queryResults.size()); - DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, qdata); - } - - return cleanup_and_return(); -} - -DOCTEST_DEFINE_INTERFACE(IReporter) - -int IReporter::get_num_active_contexts() { return detail::g_infoContexts.size(); } -const IContextScope* const* IReporter::get_active_contexts() { - return get_num_active_contexts() ? &detail::g_infoContexts[0] : nullptr; -} - -int IReporter::get_num_stringified_contexts() { return detail::g_cs->stringifiedContexts.size(); } -const String* IReporter::get_stringified_contexts() { - return get_num_stringified_contexts() ? &detail::g_cs->stringifiedContexts[0] : nullptr; -} - -namespace detail { - void registerReporterImpl(const char* name, int priority, reporterCreatorFunc c, bool isReporter) { - if(isReporter) - getReporters().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); - else - getListeners().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); - } -} // namespace detail - -} // namespace doctest - -#endif // DOCTEST_CONFIG_DISABLE - -#ifdef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007) // 'function' : must be 'attribute' - see issue #182 -int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); } -DOCTEST_MSVC_SUPPRESS_WARNING_POP -#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN - -DOCTEST_CLANG_SUPPRESS_WARNING_POP -DOCTEST_MSVC_SUPPRESS_WARNING_POP -DOCTEST_GCC_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_POP - -#endif // DOCTEST_LIBRARY_IMPLEMENTATION -#endif // DOCTEST_CONFIG_IMPLEMENT +// ====================================================================== lgtm [cpp/missing-header-guard] +// == DO NOT MODIFY THIS FILE BY HAND - IT IS AUTO GENERATED BY CMAKE! == +// ====================================================================== +// +// doctest.h - the lightest feature-rich C++ single-header testing framework for unit tests and TDD +// +// Copyright (c) 2016-2023 Viktor Kirilov +// +// Distributed under the MIT Software License +// See accompanying file LICENSE.txt or copy at +// https://opensource.org/licenses/MIT +// +// The documentation can be found at the library's page: +// https://github.com/doctest/doctest/blob/master/doc/markdown/readme.md +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= +// +// The library is heavily influenced by Catch - https://github.com/catchorg/Catch2 +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/catchorg/Catch2/blob/master/LICENSE.txt +// +// The concept of subcases (sections in Catch) and expression decomposition are from there. +// Some parts of the code are taken directly: +// - stringification - the detection of "ostream& operator<<(ostream&, const T&)" and StringMaker<> +// - the Approx() helper class for floating point comparison +// - colors in the console +// - breaking into a debugger +// - signal / SEH handling +// - timer +// - XmlWriter class - thanks to Phil Nash for allowing the direct reuse (AKA copy/paste) +// +// The expression decomposing templates are taken from lest - https://github.com/martinmoene/lest +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/martinmoene/lest/blob/master/LICENSE.txt +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= + +#ifndef DOCTEST_LIBRARY_INCLUDED +#define DOCTEST_LIBRARY_INCLUDED + +// ================================================================================================= +// == VERSION ====================================================================================== +// ================================================================================================= + +#define DOCTEST_VERSION_MAJOR 2 +#define DOCTEST_VERSION_MINOR 4 +#define DOCTEST_VERSION_PATCH 11 + +// util we need here +#define DOCTEST_TOSTR_IMPL(x) #x +#define DOCTEST_TOSTR(x) DOCTEST_TOSTR_IMPL(x) + +#define DOCTEST_VERSION_STR \ + DOCTEST_TOSTR(DOCTEST_VERSION_MAJOR) "." \ + DOCTEST_TOSTR(DOCTEST_VERSION_MINOR) "." \ + DOCTEST_TOSTR(DOCTEST_VERSION_PATCH) + +#define DOCTEST_VERSION \ + (DOCTEST_VERSION_MAJOR * 10000 + DOCTEST_VERSION_MINOR * 100 + DOCTEST_VERSION_PATCH) + +// ================================================================================================= +// == COMPILER VERSION ============================================================================= +// ================================================================================================= + +// ideas for the version stuff are taken from here: https://github.com/cxxstuff/cxx_detect + +#ifdef _MSC_VER +#define DOCTEST_CPLUSPLUS _MSVC_LANG +#else +#define DOCTEST_CPLUSPLUS __cplusplus +#endif + +#define DOCTEST_COMPILER(MAJOR, MINOR, PATCH) ((MAJOR)*10000000 + (MINOR)*100000 + (PATCH)) + +// GCC/Clang and GCC/MSVC are mutually exclusive, but Clang/MSVC are not because of clang-cl... +#if defined(_MSC_VER) && defined(_MSC_FULL_VER) +#if _MSC_VER == _MSC_FULL_VER / 10000 +#define DOCTEST_MSVC DOCTEST_COMPILER(_MSC_VER / 100, _MSC_VER % 100, _MSC_FULL_VER % 10000) +#else // MSVC +#define DOCTEST_MSVC \ + DOCTEST_COMPILER(_MSC_VER / 100, (_MSC_FULL_VER / 100000) % 100, _MSC_FULL_VER % 100000) +#endif // MSVC +#endif // MSVC +#if defined(__clang__) && defined(__clang_minor__) && defined(__clang_patchlevel__) +#define DOCTEST_CLANG DOCTEST_COMPILER(__clang_major__, __clang_minor__, __clang_patchlevel__) +#elif defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) && \ + !defined(__INTEL_COMPILER) +#define DOCTEST_GCC DOCTEST_COMPILER(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#endif // GCC +#if defined(__INTEL_COMPILER) +#define DOCTEST_ICC DOCTEST_COMPILER(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) +#endif // ICC + +#ifndef DOCTEST_MSVC +#define DOCTEST_MSVC 0 +#endif // DOCTEST_MSVC +#ifndef DOCTEST_CLANG +#define DOCTEST_CLANG 0 +#endif // DOCTEST_CLANG +#ifndef DOCTEST_GCC +#define DOCTEST_GCC 0 +#endif // DOCTEST_GCC +#ifndef DOCTEST_ICC +#define DOCTEST_ICC 0 +#endif // DOCTEST_ICC + +// ================================================================================================= +// == COMPILER WARNINGS HELPERS ==================================================================== +// ================================================================================================= + +#if DOCTEST_CLANG && !DOCTEST_ICC +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(clang diagnostic ignored w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH DOCTEST_CLANG_SUPPRESS_WARNING(w) +#else // DOCTEST_CLANG +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_CLANG + +#if DOCTEST_GCC +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH _Pragma("GCC diagnostic push") +#define DOCTEST_GCC_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(GCC diagnostic ignored w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP _Pragma("GCC diagnostic pop") +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_GCC_SUPPRESS_WARNING_PUSH DOCTEST_GCC_SUPPRESS_WARNING(w) +#else // DOCTEST_GCC +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH +#define DOCTEST_GCC_SUPPRESS_WARNING(w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_GCC + +#if DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH __pragma(warning(push)) +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) __pragma(warning(disable : w)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP __pragma(warning(pop)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH DOCTEST_MSVC_SUPPRESS_WARNING(w) +#else // DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_MSVC + +// ================================================================================================= +// == COMPILER WARNINGS ============================================================================ +// ================================================================================================= + +// both the header and the implementation suppress all of these, +// so it only makes sense to aggregate them like so +#define DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH \ + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") \ + \ + DOCTEST_GCC_SUPPRESS_WARNING_PUSH \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") \ + \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ + /* these 4 also disabled globally via cmake: */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4514) /* unreferenced inline function has been removed */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4571) /* SEH related */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4710) /* function not inlined */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4711) /* function selected for inline expansion*/ \ + /* common ones */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4616) /* invalid compiler warning */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4619) /* invalid compiler warning */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4996) /* The compiler encountered a deprecated declaration */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4706) /* assignment within conditional expression */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4512) /* 'class' : assignment operator could not be generated */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4127) /* conditional expression is constant */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4640) /* construction of local static object not thread-safe */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5264) /* 'variable-name': 'const' variable is not used */ \ + /* static analysis */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26439) /* Function may not throw. Declare it 'noexcept' */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26495) /* Always initialize a member variable */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26451) /* Arithmetic overflow ... */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26444) /* Avoid unnamed objects with custom ctor and dtor... */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26812) /* Prefer 'enum class' over 'enum' */ + +#define DOCTEST_SUPPRESS_COMMON_WARNINGS_POP \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP \ + DOCTEST_GCC_SUPPRESS_WARNING_POP \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdeprecated") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wctor-dtor-privacy") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-promo") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4623) // default constructor was implicitly defined as deleted + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ + DOCTEST_MSVC_SUPPRESS_WARNING(4548) /* before comma no effect; expected side - effect */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4265) /* virtual functions, but destructor is not virtual */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4986) /* exception specification does not match previous */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4350) /* 'member1' called instead of 'member2' */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4668) /* not defined as a preprocessor macro */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4365) /* signed/unsigned mismatch */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4774) /* format string not a string literal */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4623) /* default constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5039) /* pointer to pot. throwing function passed to extern C */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5105) /* macro producing 'defined' has undefined behavior */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4738) /* storing float result in memory, loss of performance */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5262) /* implicit fall-through */ + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END DOCTEST_MSVC_SUPPRESS_WARNING_POP + +// ================================================================================================= +// == FEATURE DETECTION ============================================================================ +// ================================================================================================= + +// general compiler feature support table: https://en.cppreference.com/w/cpp/compiler_support +// MSVC C++11 feature support table: https://msdn.microsoft.com/en-us/library/hh567368.aspx +// GCC C++11 feature support table: https://gcc.gnu.org/projects/cxx-status.html +// MSVC version table: +// https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B#Internal_version_numbering +// MSVC++ 14.3 (17) _MSC_VER == 1930 (Visual Studio 2022) +// MSVC++ 14.2 (16) _MSC_VER == 1920 (Visual Studio 2019) +// MSVC++ 14.1 (15) _MSC_VER == 1910 (Visual Studio 2017) +// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) +// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013) +// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012) +// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010) +// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) +// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) + +// Universal Windows Platform support +#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +#define DOCTEST_CONFIG_NO_WINDOWS_SEH +#endif // WINAPI_FAMILY +#if DOCTEST_MSVC && !defined(DOCTEST_CONFIG_WINDOWS_SEH) +#define DOCTEST_CONFIG_WINDOWS_SEH +#endif // MSVC +#if defined(DOCTEST_CONFIG_NO_WINDOWS_SEH) && defined(DOCTEST_CONFIG_WINDOWS_SEH) +#undef DOCTEST_CONFIG_WINDOWS_SEH +#endif // DOCTEST_CONFIG_NO_WINDOWS_SEH + +#if !defined(_WIN32) && !defined(__QNX__) && !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && \ + !defined(__EMSCRIPTEN__) && !defined(__wasi__) +#define DOCTEST_CONFIG_POSIX_SIGNALS +#endif // _WIN32 +#if defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) && defined(DOCTEST_CONFIG_POSIX_SIGNALS) +#undef DOCTEST_CONFIG_POSIX_SIGNALS +#endif // DOCTEST_CONFIG_NO_POSIX_SIGNALS + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) \ + || defined(__wasi__) +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // no exceptions +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#if defined(DOCTEST_CONFIG_NO_EXCEPTIONS) && !defined(DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS) +#define DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS && !DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#ifdef __wasi__ +#define DOCTEST_CONFIG_NO_MULTITHREADING +#endif + +#if defined(DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN) && !defined(DOCTEST_CONFIG_IMPLEMENT) +#define DOCTEST_CONFIG_IMPLEMENT +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +#if defined(_WIN32) || defined(__CYGWIN__) +#if DOCTEST_MSVC +#define DOCTEST_SYMBOL_EXPORT __declspec(dllexport) +#define DOCTEST_SYMBOL_IMPORT __declspec(dllimport) +#else // MSVC +#define DOCTEST_SYMBOL_EXPORT __attribute__((dllexport)) +#define DOCTEST_SYMBOL_IMPORT __attribute__((dllimport)) +#endif // MSVC +#else // _WIN32 +#define DOCTEST_SYMBOL_EXPORT __attribute__((visibility("default"))) +#define DOCTEST_SYMBOL_IMPORT +#endif // _WIN32 + +#ifdef DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#ifdef DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_EXPORT +#else // DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_IMPORT +#endif // DOCTEST_CONFIG_IMPLEMENT +#else // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#define DOCTEST_INTERFACE +#endif // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL + +// needed for extern template instantiations +// see https://github.com/fmtlib/fmt/issues/2228 +#if DOCTEST_MSVC +#define DOCTEST_INTERFACE_DECL +#define DOCTEST_INTERFACE_DEF DOCTEST_INTERFACE +#else // DOCTEST_MSVC +#define DOCTEST_INTERFACE_DECL DOCTEST_INTERFACE +#define DOCTEST_INTERFACE_DEF +#endif // DOCTEST_MSVC + +#define DOCTEST_EMPTY + +#if DOCTEST_MSVC +#define DOCTEST_NOINLINE __declspec(noinline) +#define DOCTEST_UNUSED +#define DOCTEST_ALIGNMENT(x) +#elif DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 5, 0) +#define DOCTEST_NOINLINE +#define DOCTEST_UNUSED +#define DOCTEST_ALIGNMENT(x) +#else +#define DOCTEST_NOINLINE __attribute__((noinline)) +#define DOCTEST_UNUSED __attribute__((unused)) +#define DOCTEST_ALIGNMENT(x) __attribute__((aligned(x))) +#endif + +#ifdef DOCTEST_CONFIG_NO_CONTRADICTING_INLINE +#define DOCTEST_INLINE_NOINLINE inline +#else +#define DOCTEST_INLINE_NOINLINE inline DOCTEST_NOINLINE +#endif + +#ifndef DOCTEST_NORETURN +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_NORETURN +#else // DOCTEST_MSVC +#define DOCTEST_NORETURN [[noreturn]] +#endif // DOCTEST_MSVC +#endif // DOCTEST_NORETURN + +#ifndef DOCTEST_NOEXCEPT +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_NOEXCEPT +#else // DOCTEST_MSVC +#define DOCTEST_NOEXCEPT noexcept +#endif // DOCTEST_MSVC +#endif // DOCTEST_NOEXCEPT + +#ifndef DOCTEST_CONSTEXPR +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_CONSTEXPR const +#define DOCTEST_CONSTEXPR_FUNC inline +#else // DOCTEST_MSVC +#define DOCTEST_CONSTEXPR constexpr +#define DOCTEST_CONSTEXPR_FUNC constexpr +#endif // DOCTEST_MSVC +#endif // DOCTEST_CONSTEXPR + +#ifndef DOCTEST_NO_SANITIZE_INTEGER +#if DOCTEST_CLANG >= DOCTEST_COMPILER(3, 7, 0) +#define DOCTEST_NO_SANITIZE_INTEGER __attribute__((no_sanitize("integer"))) +#else +#define DOCTEST_NO_SANITIZE_INTEGER +#endif +#endif // DOCTEST_NO_SANITIZE_INTEGER + +// ================================================================================================= +// == FEATURE DETECTION END ======================================================================== +// ================================================================================================= + +#define DOCTEST_DECLARE_INTERFACE(name) \ + virtual ~name(); \ + name() = default; \ + name(const name&) = delete; \ + name(name&&) = delete; \ + name& operator=(const name&) = delete; \ + name& operator=(name&&) = delete; + +#define DOCTEST_DEFINE_INTERFACE(name) \ + name::~name() = default; + +// internal macros for string concatenation and anonymous variable name generation +#define DOCTEST_CAT_IMPL(s1, s2) s1##s2 +#define DOCTEST_CAT(s1, s2) DOCTEST_CAT_IMPL(s1, s2) +#ifdef __COUNTER__ // not standard and may be missing for some compilers +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __COUNTER__) +#else // __COUNTER__ +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __LINE__) +#endif // __COUNTER__ + +#ifndef DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x& +#else // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x +#endif // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE + +// not using __APPLE__ because... this is how Catch does it +#ifdef __MAC_OS_X_VERSION_MIN_REQUIRED +#define DOCTEST_PLATFORM_MAC +#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) +#define DOCTEST_PLATFORM_IPHONE +#elif defined(_WIN32) +#define DOCTEST_PLATFORM_WINDOWS +#elif defined(__wasi__) +#define DOCTEST_PLATFORM_WASI +#else // DOCTEST_PLATFORM +#define DOCTEST_PLATFORM_LINUX +#endif // DOCTEST_PLATFORM + +namespace doctest { namespace detail { + static DOCTEST_CONSTEXPR int consume(const int*, int) noexcept { return 0; } +}} + +#define DOCTEST_GLOBAL_NO_WARNINGS(var, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ + static const int var = doctest::detail::consume(&var, __VA_ARGS__); \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#ifndef DOCTEST_BREAK_INTO_DEBUGGER +// should probably take a look at https://github.com/scottt/debugbreak +#ifdef DOCTEST_PLATFORM_LINUX +#if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) +// Break at the location of the failing check if possible +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) +#else +#include +#define DOCTEST_BREAK_INTO_DEBUGGER() raise(SIGTRAP) +#endif +#elif defined(DOCTEST_PLATFORM_MAC) +#if defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || defined(__i386) +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) +#elif defined(__ppc__) || defined(__ppc64__) +// https://www.cocoawithlove.com/2008/03/break-into-debugger.html +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n": : : "memory","r0","r3","r4") // NOLINT(hicpp-no-assembler) +#else +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("brk #0"); // NOLINT(hicpp-no-assembler) +#endif +#elif DOCTEST_MSVC +#define DOCTEST_BREAK_INTO_DEBUGGER() __debugbreak() +#elif defined(__MINGW32__) +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wredundant-decls") +extern "C" __declspec(dllimport) void __stdcall DebugBreak(); +DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_BREAK_INTO_DEBUGGER() ::DebugBreak() +#else // linux +#define DOCTEST_BREAK_INTO_DEBUGGER() (static_cast(0)) +#endif // linux +#endif // DOCTEST_BREAK_INTO_DEBUGGER + +// this is kept here for backwards compatibility since the config option was changed +#ifdef DOCTEST_CONFIG_USE_IOSFWD +#ifndef DOCTEST_CONFIG_USE_STD_HEADERS +#define DOCTEST_CONFIG_USE_STD_HEADERS +#endif +#endif // DOCTEST_CONFIG_USE_IOSFWD + +// for clang - always include ciso646 (which drags some std stuff) because +// we want to check if we are using libc++ with the _LIBCPP_VERSION macro in +// which case we don't want to forward declare stuff from std - for reference: +// https://github.com/doctest/doctest/issues/126 +// https://github.com/doctest/doctest/issues/356 +#if DOCTEST_CLANG +#include +#endif // clang + +#ifdef _LIBCPP_VERSION +#ifndef DOCTEST_CONFIG_USE_STD_HEADERS +#define DOCTEST_CONFIG_USE_STD_HEADERS +#endif +#endif // _LIBCPP_VERSION + +#ifdef DOCTEST_CONFIG_USE_STD_HEADERS +#ifndef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#define DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN +#include +#include +#include +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END +#else // DOCTEST_CONFIG_USE_STD_HEADERS + +// Forward declaring 'X' in namespace std is not permitted by the C++ Standard. +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4643) + +namespace std { // NOLINT(cert-dcl58-cpp) +typedef decltype(nullptr) nullptr_t; // NOLINT(modernize-use-using) +typedef decltype(sizeof(void*)) size_t; // NOLINT(modernize-use-using) +template +struct char_traits; +template <> +struct char_traits; +template +class basic_ostream; // NOLINT(fuchsia-virtual-inheritance) +typedef basic_ostream> ostream; // NOLINT(modernize-use-using) +template +// NOLINTNEXTLINE +basic_ostream& operator<<(basic_ostream&, const char*); +template +class basic_istream; +typedef basic_istream> istream; // NOLINT(modernize-use-using) +template +class tuple; +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +template +class allocator; +template +class basic_string; +using string = basic_string, allocator>; +#endif // VS 2019 +} // namespace std + +DOCTEST_MSVC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_USE_STD_HEADERS + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#include +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + +namespace doctest { + +using std::size_t; + +DOCTEST_INTERFACE extern bool is_running_in_test; + +#ifndef DOCTEST_CONFIG_STRING_SIZE_TYPE +#define DOCTEST_CONFIG_STRING_SIZE_TYPE unsigned +#endif + +// A 24 byte string class (can be as small as 17 for x64 and 13 for x86) that can hold strings with length +// of up to 23 chars on the stack before going on the heap - the last byte of the buffer is used for: +// - "is small" bit - the highest bit - if "0" then it is small - otherwise its "1" (128) +// - if small - capacity left before going on the heap - using the lowest 5 bits +// - if small - 2 bits are left unused - the second and third highest ones +// - if small - acts as a null terminator if strlen() is 23 (24 including the null terminator) +// and the "is small" bit remains "0" ("as well as the capacity left") so its OK +// Idea taken from this lecture about the string implementation of facebook/folly - fbstring +// https://www.youtube.com/watch?v=kPR8h4-qZdk +// TODO: +// - optimizations - like not deleting memory unnecessarily in operator= and etc. +// - resize/reserve/clear +// - replace +// - back/front +// - iterator stuff +// - find & friends +// - push_back/pop_back +// - assign/insert/erase +// - relational operators as free functions - taking const char* as one of the params +class DOCTEST_INTERFACE String +{ +public: + using size_type = DOCTEST_CONFIG_STRING_SIZE_TYPE; + +private: + static DOCTEST_CONSTEXPR size_type len = 24; //!OCLINT avoid private static members + static DOCTEST_CONSTEXPR size_type last = len - 1; //!OCLINT avoid private static members + + struct view // len should be more than sizeof(view) - because of the final byte for flags + { + char* ptr; + size_type size; + size_type capacity; + }; + + union + { + char buf[len]; // NOLINT(*-avoid-c-arrays) + view data; + }; + + char* allocate(size_type sz); + + bool isOnStack() const noexcept { return (buf[last] & 128) == 0; } + void setOnHeap() noexcept; + void setLast(size_type in = last) noexcept; + void setSize(size_type sz) noexcept; + + void copy(const String& other); + +public: + static DOCTEST_CONSTEXPR size_type npos = static_cast(-1); + + String() noexcept; + ~String(); + + // cppcheck-suppress noExplicitConstructor + String(const char* in); + String(const char* in, size_type in_size); + + String(std::istream& in, size_type in_size); + + String(const String& other); + String& operator=(const String& other); + + String& operator+=(const String& other); + + String(String&& other) noexcept; + String& operator=(String&& other) noexcept; + + char operator[](size_type i) const; + char& operator[](size_type i); + + // the only functions I'm willing to leave in the interface - available for inlining + const char* c_str() const { return const_cast(this)->c_str(); } // NOLINT + char* c_str() { + if (isOnStack()) { + return reinterpret_cast(buf); + } + return data.ptr; + } + + size_type size() const; + size_type capacity() const; + + String substr(size_type pos, size_type cnt = npos) &&; + String substr(size_type pos, size_type cnt = npos) const &; + + size_type find(char ch, size_type pos = 0) const; + size_type rfind(char ch, size_type pos = npos) const; + + int compare(const char* other, bool no_case = false) const; + int compare(const String& other, bool no_case = false) const; + +friend DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); +}; + +DOCTEST_INTERFACE String operator+(const String& lhs, const String& rhs); + +DOCTEST_INTERFACE bool operator==(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator!=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>=(const String& lhs, const String& rhs); + +class DOCTEST_INTERFACE Contains { +public: + explicit Contains(const String& string); + + bool checkWith(const String& other) const; + + String string; +}; + +DOCTEST_INTERFACE String toString(const Contains& in); + +DOCTEST_INTERFACE bool operator==(const String& lhs, const Contains& rhs); +DOCTEST_INTERFACE bool operator==(const Contains& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator!=(const String& lhs, const Contains& rhs); +DOCTEST_INTERFACE bool operator!=(const Contains& lhs, const String& rhs); + +namespace Color { + enum Enum + { + None = 0, + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White + }; + + DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, Color::Enum code); +} // namespace Color + +namespace assertType { + enum Enum + { + // macro traits + + is_warn = 1, + is_check = 2 * is_warn, + is_require = 2 * is_check, + + is_normal = 2 * is_require, + is_throws = 2 * is_normal, + is_throws_as = 2 * is_throws, + is_throws_with = 2 * is_throws_as, + is_nothrow = 2 * is_throws_with, + + is_false = 2 * is_nothrow, + is_unary = 2 * is_false, // not checked anywhere - used just to distinguish the types + + is_eq = 2 * is_unary, + is_ne = 2 * is_eq, + + is_lt = 2 * is_ne, + is_gt = 2 * is_lt, + + is_ge = 2 * is_gt, + is_le = 2 * is_ge, + + // macro types + + DT_WARN = is_normal | is_warn, + DT_CHECK = is_normal | is_check, + DT_REQUIRE = is_normal | is_require, + + DT_WARN_FALSE = is_normal | is_false | is_warn, + DT_CHECK_FALSE = is_normal | is_false | is_check, + DT_REQUIRE_FALSE = is_normal | is_false | is_require, + + DT_WARN_THROWS = is_throws | is_warn, + DT_CHECK_THROWS = is_throws | is_check, + DT_REQUIRE_THROWS = is_throws | is_require, + + DT_WARN_THROWS_AS = is_throws_as | is_warn, + DT_CHECK_THROWS_AS = is_throws_as | is_check, + DT_REQUIRE_THROWS_AS = is_throws_as | is_require, + + DT_WARN_THROWS_WITH = is_throws_with | is_warn, + DT_CHECK_THROWS_WITH = is_throws_with | is_check, + DT_REQUIRE_THROWS_WITH = is_throws_with | is_require, + + DT_WARN_THROWS_WITH_AS = is_throws_with | is_throws_as | is_warn, + DT_CHECK_THROWS_WITH_AS = is_throws_with | is_throws_as | is_check, + DT_REQUIRE_THROWS_WITH_AS = is_throws_with | is_throws_as | is_require, + + DT_WARN_NOTHROW = is_nothrow | is_warn, + DT_CHECK_NOTHROW = is_nothrow | is_check, + DT_REQUIRE_NOTHROW = is_nothrow | is_require, + + DT_WARN_EQ = is_normal | is_eq | is_warn, + DT_CHECK_EQ = is_normal | is_eq | is_check, + DT_REQUIRE_EQ = is_normal | is_eq | is_require, + + DT_WARN_NE = is_normal | is_ne | is_warn, + DT_CHECK_NE = is_normal | is_ne | is_check, + DT_REQUIRE_NE = is_normal | is_ne | is_require, + + DT_WARN_GT = is_normal | is_gt | is_warn, + DT_CHECK_GT = is_normal | is_gt | is_check, + DT_REQUIRE_GT = is_normal | is_gt | is_require, + + DT_WARN_LT = is_normal | is_lt | is_warn, + DT_CHECK_LT = is_normal | is_lt | is_check, + DT_REQUIRE_LT = is_normal | is_lt | is_require, + + DT_WARN_GE = is_normal | is_ge | is_warn, + DT_CHECK_GE = is_normal | is_ge | is_check, + DT_REQUIRE_GE = is_normal | is_ge | is_require, + + DT_WARN_LE = is_normal | is_le | is_warn, + DT_CHECK_LE = is_normal | is_le | is_check, + DT_REQUIRE_LE = is_normal | is_le | is_require, + + DT_WARN_UNARY = is_normal | is_unary | is_warn, + DT_CHECK_UNARY = is_normal | is_unary | is_check, + DT_REQUIRE_UNARY = is_normal | is_unary | is_require, + + DT_WARN_UNARY_FALSE = is_normal | is_false | is_unary | is_warn, + DT_CHECK_UNARY_FALSE = is_normal | is_false | is_unary | is_check, + DT_REQUIRE_UNARY_FALSE = is_normal | is_false | is_unary | is_require, + }; +} // namespace assertType + +DOCTEST_INTERFACE const char* assertString(assertType::Enum at); +DOCTEST_INTERFACE const char* failureString(assertType::Enum at); +DOCTEST_INTERFACE const char* skipPathFromFilename(const char* file); + +struct DOCTEST_INTERFACE TestCaseData +{ + String m_file; // the file in which the test was registered (using String - see #350) + unsigned m_line; // the line where the test was registered + const char* m_name; // name of the test case + const char* m_test_suite; // the test suite in which the test was added + const char* m_description; + bool m_skip; + bool m_no_breaks; + bool m_no_output; + bool m_may_fail; + bool m_should_fail; + int m_expected_failures; + double m_timeout; +}; + +struct DOCTEST_INTERFACE AssertData +{ + // common - for all asserts + const TestCaseData* m_test_case; + assertType::Enum m_at; + const char* m_file; + int m_line; + const char* m_expr; + bool m_failed; + + // exception-related - for all asserts + bool m_threw; + String m_exception; + + // for normal asserts + String m_decomp; + + // for specific exception-related asserts + bool m_threw_as; + const char* m_exception_type; + + class DOCTEST_INTERFACE StringContains { + private: + Contains content; + bool isContains; + + public: + StringContains(const String& str) : content(str), isContains(false) { } + StringContains(Contains cntn) : content(static_cast(cntn)), isContains(true) { } + + bool check(const String& str) { return isContains ? (content == str) : (content.string == str); } + + operator const String&() const { return content.string; } + + const char* c_str() const { return content.string.c_str(); } + } m_exception_string; + + AssertData(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const StringContains& exception_string); +}; + +struct DOCTEST_INTERFACE MessageData +{ + String m_string; + const char* m_file; + int m_line; + assertType::Enum m_severity; +}; + +struct DOCTEST_INTERFACE SubcaseSignature +{ + String m_name; + const char* m_file; + int m_line; + + bool operator==(const SubcaseSignature& other) const; + bool operator<(const SubcaseSignature& other) const; +}; + +struct DOCTEST_INTERFACE IContextScope +{ + DOCTEST_DECLARE_INTERFACE(IContextScope) + virtual void stringify(std::ostream*) const = 0; +}; + +namespace detail { + struct DOCTEST_INTERFACE TestCase; +} // namespace detail + +struct ContextOptions //!OCLINT too many fields +{ + std::ostream* cout = nullptr; // stdout stream + String binary_name; // the test binary name + + const detail::TestCase* currentTest = nullptr; + + // == parameters from the command line + String out; // output filename + String order_by; // how tests should be ordered + unsigned rand_seed; // the seed for rand ordering + + unsigned first; // the first (matching) test to be executed + unsigned last; // the last (matching) test to be executed + + int abort_after; // stop tests after this many failed assertions + int subcase_filter_levels; // apply the subcase filters for the first N levels + + bool success; // include successful assertions in output + bool case_sensitive; // if filtering should be case sensitive + bool exit; // if the program should be exited after the tests are ran/whatever + bool duration; // print the time duration of each test case + bool minimal; // minimal console output (only test failures) + bool quiet; // no console output + bool no_throw; // to skip exceptions-related assertion macros + bool no_exitcode; // if the framework should return 0 as the exitcode + bool no_run; // to not run the tests at all (can be done with an "*" exclude) + bool no_intro; // to not print the intro of the framework + bool no_version; // to not print the version of the framework + bool no_colors; // if output to the console should be colorized + bool force_colors; // forces the use of colors even when a tty cannot be detected + bool no_breaks; // to not break into the debugger + bool no_skip; // don't skip test cases which are marked to be skipped + bool gnu_file_line; // if line numbers should be surrounded with :x: and not (x): + bool no_path_in_filenames; // if the path to files should be removed from the output + bool no_line_numbers; // if source code line numbers should be omitted from the output + bool no_debug_output; // no output in the debug console when a debugger is attached + bool no_skipped_summary; // don't print "skipped" in the summary !!! UNDOCUMENTED !!! + bool no_time_in_output; // omit any time/timestamps from output !!! UNDOCUMENTED !!! + + bool help; // to print the help + bool version; // to print the version + bool count; // if only the count of matching tests is to be retrieved + bool list_test_cases; // to list all tests matching the filters + bool list_test_suites; // to list all suites matching the filters + bool list_reporters; // lists all registered reporters +}; + +namespace detail { + namespace types { +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + using namespace std; +#else + template + struct enable_if { }; + + template + struct enable_if { using type = T; }; + + struct true_type { static DOCTEST_CONSTEXPR bool value = true; }; + struct false_type { static DOCTEST_CONSTEXPR bool value = false; }; + + template struct remove_reference { using type = T; }; + template struct remove_reference { using type = T; }; + template struct remove_reference { using type = T; }; + + template struct is_rvalue_reference : false_type { }; + template struct is_rvalue_reference : true_type { }; + + template struct remove_const { using type = T; }; + template struct remove_const { using type = T; }; + + // Compiler intrinsics + template struct is_enum { static DOCTEST_CONSTEXPR bool value = __is_enum(T); }; + template struct underlying_type { using type = __underlying_type(T); }; + + template struct is_pointer : false_type { }; + template struct is_pointer : true_type { }; + + template struct is_array : false_type { }; + // NOLINTNEXTLINE(*-avoid-c-arrays) + template struct is_array : true_type { }; +#endif + } + + // + template + T&& declval(); + + template + DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type& t) DOCTEST_NOEXCEPT { + return static_cast(t); + } + + template + DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type&& t) DOCTEST_NOEXCEPT { + return static_cast(t); + } + + template + struct deferred_false : types::false_type { }; + +// MSVS 2015 :( +#if !DOCTEST_CLANG && defined(_MSC_VER) && _MSC_VER <= 1900 + template + struct has_global_insertion_operator : types::false_type { }; + + template + struct has_global_insertion_operator(), declval()), void())> : types::true_type { }; + + template + struct has_insertion_operator { static DOCTEST_CONSTEXPR bool value = has_global_insertion_operator::value; }; + + template + struct insert_hack; + + template + struct insert_hack { + static void insert(std::ostream& os, const T& t) { ::operator<<(os, t); } + }; + + template + struct insert_hack { + static void insert(std::ostream& os, const T& t) { operator<<(os, t); } + }; + + template + using insert_hack_t = insert_hack::value>; +#else + template + struct has_insertion_operator : types::false_type { }; +#endif + + template + struct has_insertion_operator(), declval()), void())> : types::true_type { }; + + template + struct should_stringify_as_underlying_type { + static DOCTEST_CONSTEXPR bool value = detail::types::is_enum::value && !doctest::detail::has_insertion_operator::value; + }; + + DOCTEST_INTERFACE std::ostream* tlssPush(); + DOCTEST_INTERFACE String tlssPop(); + + template + struct StringMakerBase { + template + static String convert(const DOCTEST_REF_WRAP(T)) { +#ifdef DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES + static_assert(deferred_false::value, "No stringification detected for type T. See string conversion manual"); +#endif + return "{?}"; + } + }; + + template + struct filldata; + + template + void filloss(std::ostream* stream, const T& in) { + filldata::fill(stream, in); + } + + template + void filloss(std::ostream* stream, const T (&in)[N]) { // NOLINT(*-avoid-c-arrays) + // T[N], T(&)[N], T(&&)[N] have same behaviour. + // Hence remove reference. + filloss::type>(stream, in); + } + + template + String toStream(const T& in) { + std::ostream* stream = tlssPush(); + filloss(stream, in); + return tlssPop(); + } + + template <> + struct StringMakerBase { + template + static String convert(const DOCTEST_REF_WRAP(T) in) { + return toStream(in); + } + }; +} // namespace detail + +template +struct StringMaker : public detail::StringMakerBase< + detail::has_insertion_operator::value || detail::types::is_pointer::value || detail::types::is_array::value> +{}; + +#ifndef DOCTEST_STRINGIFY +#ifdef DOCTEST_CONFIG_DOUBLE_STRINGIFY +#define DOCTEST_STRINGIFY(...) toString(toString(__VA_ARGS__)) +#else +#define DOCTEST_STRINGIFY(...) toString(__VA_ARGS__) +#endif +#endif + +template +String toString() { +#if DOCTEST_CLANG == 0 && DOCTEST_GCC == 0 && DOCTEST_ICC == 0 + String ret = __FUNCSIG__; // class doctest::String __cdecl doctest::toString(void) + String::size_type beginPos = ret.find('<'); + return ret.substr(beginPos + 1, ret.size() - beginPos - static_cast(sizeof(">(void)"))); +#else + String ret = __PRETTY_FUNCTION__; // doctest::String toString() [with T = TYPE] + String::size_type begin = ret.find('=') + 2; + return ret.substr(begin, ret.size() - begin - 1); +#endif +} + +template ::value, bool>::type = true> +String toString(const DOCTEST_REF_WRAP(T) value) { + return StringMaker::convert(value); +} + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +DOCTEST_INTERFACE String toString(const char* in); +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +DOCTEST_INTERFACE String toString(const std::string& in); +#endif // VS 2019 + +DOCTEST_INTERFACE String toString(String in); + +DOCTEST_INTERFACE String toString(std::nullptr_t); + +DOCTEST_INTERFACE String toString(bool in); + +DOCTEST_INTERFACE String toString(float in); +DOCTEST_INTERFACE String toString(double in); +DOCTEST_INTERFACE String toString(double long in); + +DOCTEST_INTERFACE String toString(char in); +DOCTEST_INTERFACE String toString(char signed in); +DOCTEST_INTERFACE String toString(char unsigned in); +DOCTEST_INTERFACE String toString(short in); +DOCTEST_INTERFACE String toString(short unsigned in); +DOCTEST_INTERFACE String toString(signed in); +DOCTEST_INTERFACE String toString(unsigned in); +DOCTEST_INTERFACE String toString(long in); +DOCTEST_INTERFACE String toString(long unsigned in); +DOCTEST_INTERFACE String toString(long long in); +DOCTEST_INTERFACE String toString(long long unsigned in); + +template ::value, bool>::type = true> +String toString(const DOCTEST_REF_WRAP(T) value) { + using UT = typename detail::types::underlying_type::type; + return (DOCTEST_STRINGIFY(static_cast(value))); +} + +namespace detail { + template + struct filldata + { + static void fill(std::ostream* stream, const T& in) { +#if defined(_MSC_VER) && _MSC_VER <= 1900 + insert_hack_t::insert(*stream, in); +#else + operator<<(*stream, in); +#endif + } + }; + +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) +// NOLINTBEGIN(*-avoid-c-arrays) + template + struct filldata { + static void fill(std::ostream* stream, const T(&in)[N]) { + *stream << "["; + for (size_t i = 0; i < N; i++) { + if (i != 0) { *stream << ", "; } + *stream << (DOCTEST_STRINGIFY(in[i])); + } + *stream << "]"; + } + }; +// NOLINTEND(*-avoid-c-arrays) +DOCTEST_MSVC_SUPPRESS_WARNING_POP + + // Specialized since we don't want the terminating null byte! +// NOLINTBEGIN(*-avoid-c-arrays) + template + struct filldata { + static void fill(std::ostream* stream, const char (&in)[N]) { + *stream << String(in, in[N - 1] ? N : N - 1); + } // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + }; +// NOLINTEND(*-avoid-c-arrays) + + template <> + struct filldata { + static void fill(std::ostream* stream, const void* in); + }; + + template + struct filldata { +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4180) + static void fill(std::ostream* stream, const T* in) { +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wmicrosoft-cast") + filldata::fill(stream, +#if DOCTEST_GCC == 0 || DOCTEST_GCC >= DOCTEST_COMPILER(4, 9, 0) + reinterpret_cast(in) +#else + *reinterpret_cast(&in) +#endif + ); +DOCTEST_CLANG_SUPPRESS_WARNING_POP + } + }; +} + +struct DOCTEST_INTERFACE Approx +{ + Approx(double value); + + Approx operator()(double value) const; + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + explicit Approx(const T& value, + typename detail::types::enable_if::value>::type* = + static_cast(nullptr)) { + *this = static_cast(value); + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& epsilon(double newEpsilon); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename std::enable_if::value, Approx&>::type epsilon( + const T& newEpsilon) { + m_epsilon = static_cast(newEpsilon); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& scale(double newScale); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename std::enable_if::value, Approx&>::type scale( + const T& newScale) { + m_scale = static_cast(newScale); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format off + DOCTEST_INTERFACE friend bool operator==(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator==(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator!=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator!=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator<=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator<=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator>=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator>=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator< (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator< (const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator> (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator> (const Approx & lhs, double rhs); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#define DOCTEST_APPROX_PREFIX \ + template friend typename std::enable_if::value, bool>::type + + DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(static_cast(lhs), rhs); } + DOCTEST_APPROX_PREFIX operator==(const Approx& lhs, const T& rhs) { return operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator!=(const T& lhs, const Approx& rhs) { return !operator==(lhs, rhs); } + DOCTEST_APPROX_PREFIX operator!=(const Approx& lhs, const T& rhs) { return !operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) && lhs != rhs; } +#undef DOCTEST_APPROX_PREFIX +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format on + + double m_epsilon; + double m_scale; + double m_value; +}; + +DOCTEST_INTERFACE String toString(const Approx& in); + +DOCTEST_INTERFACE const ContextOptions* getContextOptions(); + +template +struct DOCTEST_INTERFACE_DECL IsNaN +{ + F value; bool flipped; + IsNaN(F f, bool flip = false) : value(f), flipped(flip) { } + IsNaN operator!() const { return { value, !flipped }; } + operator bool() const; +}; +#ifndef __MINGW32__ +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +#endif +DOCTEST_INTERFACE String toString(IsNaN in); +DOCTEST_INTERFACE String toString(IsNaN in); +DOCTEST_INTERFACE String toString(IsNaN in); + +#ifndef DOCTEST_CONFIG_DISABLE + +namespace detail { + // clang-format off +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + template struct decay_array { using type = T; }; + template struct decay_array { using type = T*; }; + template struct decay_array { using type = T*; }; + + template struct not_char_pointer { static DOCTEST_CONSTEXPR int value = 1; }; + template<> struct not_char_pointer { static DOCTEST_CONSTEXPR int value = 0; }; + template<> struct not_char_pointer { static DOCTEST_CONSTEXPR int value = 0; }; + + template struct can_use_op : public not_char_pointer::type> {}; +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + + struct DOCTEST_INTERFACE TestFailureException + { + }; + + DOCTEST_INTERFACE bool checkIfShouldThrow(assertType::Enum at); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_INTERFACE void throwException(); + + struct DOCTEST_INTERFACE Subcase + { + SubcaseSignature m_signature; + bool m_entered = false; + + Subcase(const String& name, const char* file, int line); + Subcase(const Subcase&) = delete; + Subcase(Subcase&&) = delete; + Subcase& operator=(const Subcase&) = delete; + Subcase& operator=(Subcase&&) = delete; + ~Subcase(); + + operator bool() const; + + private: + bool checkFilters(); + }; + + template + String stringifyBinaryExpr(const DOCTEST_REF_WRAP(L) lhs, const char* op, + const DOCTEST_REF_WRAP(R) rhs) { + return (DOCTEST_STRINGIFY(lhs)) + op + (DOCTEST_STRINGIFY(rhs)); + } + +#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") +#endif + +// This will check if there is any way it could find a operator like member or friend and uses it. +// If not it doesn't find the operator or if the operator at global scope is defined after +// this template, the template won't be instantiated due to SFINAE. Once the template is not +// instantiated it can look for global operator using normal conversions. +#ifdef __NVCC__ +#define SFINAE_OP(ret,op) ret +#else +#define SFINAE_OP(ret,op) decltype((void)(doctest::detail::declval() op doctest::detail::declval()),ret{}) +#endif + +#define DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(op, op_str, op_macro) \ + template \ + DOCTEST_NOINLINE SFINAE_OP(Result,op) operator op(R&& rhs) { \ + bool res = op_macro(doctest::detail::forward(lhs), doctest::detail::forward(rhs)); \ + if(m_at & assertType::is_false) \ + res = !res; \ + if(!res || doctest::getContextOptions()->success) \ + return Result(res, stringifyBinaryExpr(lhs, op_str, rhs)); \ + return Result(res); \ + } + + // more checks could be added - like in Catch: + // https://github.com/catchorg/Catch2/pull/1480/files + // https://github.com/catchorg/Catch2/pull/1481/files +#define DOCTEST_FORBIT_EXPRESSION(rt, op) \ + template \ + rt& operator op(const R&) { \ + static_assert(deferred_false::value, \ + "Expression Too Complex Please Rewrite As Binary Comparison!"); \ + return *this; \ + } + + struct DOCTEST_INTERFACE Result // NOLINT(*-member-init) + { + bool m_passed; + String m_decomp; + + Result() = default; // TODO: Why do we need this? (To remove NOLINT) + Result(bool passed, const String& decomposition = String()); + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Result, &) + DOCTEST_FORBIT_EXPRESSION(Result, ^) + DOCTEST_FORBIT_EXPRESSION(Result, |) + DOCTEST_FORBIT_EXPRESSION(Result, &&) + DOCTEST_FORBIT_EXPRESSION(Result, ||) + DOCTEST_FORBIT_EXPRESSION(Result, ==) + DOCTEST_FORBIT_EXPRESSION(Result, !=) + DOCTEST_FORBIT_EXPRESSION(Result, <) + DOCTEST_FORBIT_EXPRESSION(Result, >) + DOCTEST_FORBIT_EXPRESSION(Result, <=) + DOCTEST_FORBIT_EXPRESSION(Result, >=) + DOCTEST_FORBIT_EXPRESSION(Result, =) + DOCTEST_FORBIT_EXPRESSION(Result, +=) + DOCTEST_FORBIT_EXPRESSION(Result, -=) + DOCTEST_FORBIT_EXPRESSION(Result, *=) + DOCTEST_FORBIT_EXPRESSION(Result, /=) + DOCTEST_FORBIT_EXPRESSION(Result, %=) + DOCTEST_FORBIT_EXPRESSION(Result, <<=) + DOCTEST_FORBIT_EXPRESSION(Result, >>=) + DOCTEST_FORBIT_EXPRESSION(Result, &=) + DOCTEST_FORBIT_EXPRESSION(Result, ^=) + DOCTEST_FORBIT_EXPRESSION(Result, |=) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_GCC_SUPPRESS_WARNING_PUSH + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH + // https://stackoverflow.com/questions/39479163 what's the difference between 4018 and 4389 + DOCTEST_MSVC_SUPPRESS_WARNING(4388) // signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4389) // 'operator' : signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4018) // 'expression' : signed/unsigned mismatch + //DOCTEST_MSVC_SUPPRESS_WARNING(4805) // 'operation' : unsafe mix of type 'type' and type 'type' in operation + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + // clang-format off +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE bool +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE typename types::enable_if::value || can_use_op::value, bool>::type + inline bool eq(const char* lhs, const char* rhs) { return String(lhs) == String(rhs); } + inline bool ne(const char* lhs, const char* rhs) { return String(lhs) != String(rhs); } + inline bool lt(const char* lhs, const char* rhs) { return String(lhs) < String(rhs); } + inline bool gt(const char* lhs, const char* rhs) { return String(lhs) > String(rhs); } + inline bool le(const char* lhs, const char* rhs) { return String(lhs) <= String(rhs); } + inline bool ge(const char* lhs, const char* rhs) { return String(lhs) >= String(rhs); } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + +#define DOCTEST_RELATIONAL_OP(name, op) \ + template \ + DOCTEST_COMPARISON_RETURN_TYPE name(const DOCTEST_REF_WRAP(L) lhs, \ + const DOCTEST_REF_WRAP(R) rhs) { \ + return lhs op rhs; \ + } + + DOCTEST_RELATIONAL_OP(eq, ==) + DOCTEST_RELATIONAL_OP(ne, !=) + DOCTEST_RELATIONAL_OP(lt, <) + DOCTEST_RELATIONAL_OP(gt, >) + DOCTEST_RELATIONAL_OP(le, <=) + DOCTEST_RELATIONAL_OP(ge, >=) + +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) l == r +#define DOCTEST_CMP_NE(l, r) l != r +#define DOCTEST_CMP_GT(l, r) l > r +#define DOCTEST_CMP_LT(l, r) l < r +#define DOCTEST_CMP_GE(l, r) l >= r +#define DOCTEST_CMP_LE(l, r) l <= r +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) eq(l, r) +#define DOCTEST_CMP_NE(l, r) ne(l, r) +#define DOCTEST_CMP_GT(l, r) gt(l, r) +#define DOCTEST_CMP_LT(l, r) lt(l, r) +#define DOCTEST_CMP_GE(l, r) ge(l, r) +#define DOCTEST_CMP_LE(l, r) le(l, r) +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + + template + // cppcheck-suppress copyCtorAndEqOperator + struct Expression_lhs + { + L lhs; + assertType::Enum m_at; + + explicit Expression_lhs(L&& in, assertType::Enum at) + : lhs(static_cast(in)) + , m_at(at) {} + + DOCTEST_NOINLINE operator Result() { +// this is needed only for MSVC 2015 +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4800) // 'int': forcing value to bool + bool res = static_cast(lhs); +DOCTEST_MSVC_SUPPRESS_WARNING_POP + if(m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional + res = !res; + } + + if(!res || getContextOptions()->success) { + return { res, (DOCTEST_STRINGIFY(lhs)) }; + } + return { res }; + } + + /* This is required for user-defined conversions from Expression_lhs to L */ + operator L() const { return lhs; } + + // clang-format off + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(==, " == ", DOCTEST_CMP_EQ) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(!=, " != ", DOCTEST_CMP_NE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>, " > ", DOCTEST_CMP_GT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<, " < ", DOCTEST_CMP_LT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>=, " >= ", DOCTEST_CMP_GE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<=, " <= ", DOCTEST_CMP_LE) //!OCLINT bitwise operator in conditional + // clang-format on + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &&) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ||) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, =) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, +=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, -=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, *=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, /=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, %=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |=) + // these 2 are unfortunate because they should be allowed - they have higher precedence over the comparisons, but the + // ExpressionDecomposer class uses the left shift operator to capture the left operand of the binary expression... + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + +#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) +DOCTEST_CLANG_SUPPRESS_WARNING_POP +#endif + + struct DOCTEST_INTERFACE ExpressionDecomposer + { + assertType::Enum m_at; + + ExpressionDecomposer(assertType::Enum at); + + // The right operator for capturing expressions is "<=" instead of "<<" (based on the operator precedence table) + // but then there will be warnings from GCC about "-Wparentheses" and since "_Pragma()" is problematic this will stay for now... + // https://github.com/catchorg/Catch2/issues/870 + // https://github.com/catchorg/Catch2/issues/565 + template + Expression_lhs operator<<(L&& operand) { + return Expression_lhs(static_cast(operand), m_at); + } + + template ::value,void >::type* = nullptr> + Expression_lhs operator<<(const L &operand) { + return Expression_lhs(operand, m_at); + } + }; + + struct DOCTEST_INTERFACE TestSuite + { + const char* m_test_suite = nullptr; + const char* m_description = nullptr; + bool m_skip = false; + bool m_no_breaks = false; + bool m_no_output = false; + bool m_may_fail = false; + bool m_should_fail = false; + int m_expected_failures = 0; + double m_timeout = 0; + + TestSuite& operator*(const char* in); + + template + TestSuite& operator*(const T& in) { + in.fill(*this); + return *this; + } + }; + + using funcType = void (*)(); + + struct DOCTEST_INTERFACE TestCase : public TestCaseData + { + funcType m_test; // a function pointer to the test case + + String m_type; // for templated test cases - gets appended to the real name + int m_template_id; // an ID used to distinguish between the different versions of a templated test case + String m_full_name; // contains the name (only for templated test cases!) + the template type + + TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const String& type = String(), int template_id = -1); + + TestCase(const TestCase& other); + TestCase(TestCase&&) = delete; + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + TestCase& operator=(const TestCase& other); + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& operator=(TestCase&&) = delete; + + TestCase& operator*(const char* in); + + template + TestCase& operator*(const T& in) { + in.fill(*this); + return *this; + } + + bool operator<(const TestCase& other) const; + + ~TestCase() = default; + }; + + // forward declarations of functions used by the macros + DOCTEST_INTERFACE int regTest(const TestCase& tc); + DOCTEST_INTERFACE int setTestSuite(const TestSuite& ts); + DOCTEST_INTERFACE bool isDebuggerActive(); + + template + int instantiationHelper(const T&) { return 0; } + + namespace binaryAssertComparison { + enum Enum + { + eq = 0, + ne, + gt, + lt, + ge, + le + }; + } // namespace binaryAssertComparison + + // clang-format off + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L), const DOCTEST_REF_WRAP(R) ) const { return false; } }; + +#define DOCTEST_BINARY_RELATIONAL_OP(n, op) \ + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) const { return op(lhs, rhs); } }; + // clang-format on + + DOCTEST_BINARY_RELATIONAL_OP(0, doctest::detail::eq) + DOCTEST_BINARY_RELATIONAL_OP(1, doctest::detail::ne) + DOCTEST_BINARY_RELATIONAL_OP(2, doctest::detail::gt) + DOCTEST_BINARY_RELATIONAL_OP(3, doctest::detail::lt) + DOCTEST_BINARY_RELATIONAL_OP(4, doctest::detail::ge) + DOCTEST_BINARY_RELATIONAL_OP(5, doctest::detail::le) + + struct DOCTEST_INTERFACE ResultBuilder : public AssertData + { + ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type = "", const String& exception_string = ""); + + ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const Contains& exception_string); + + void setResult(const Result& res); + + template + DOCTEST_NOINLINE bool binary_assert(const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + m_failed = !RelationalComparator()(lhs, rhs); + if (m_failed || getContextOptions()->success) { + m_decomp = stringifyBinaryExpr(lhs, ", ", rhs); + } + return !m_failed; + } + + template + DOCTEST_NOINLINE bool unary_assert(const DOCTEST_REF_WRAP(L) val) { + m_failed = !val; + + if (m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional + m_failed = !m_failed; + } + + if (m_failed || getContextOptions()->success) { + m_decomp = (DOCTEST_STRINGIFY(val)); + } + + return !m_failed; + } + + void translateException(); + + bool log(); + void react() const; + }; + + namespace assertAction { + enum Enum + { + nothing = 0, + dbgbreak = 1, + shouldthrow = 2 + }; + } // namespace assertAction + + DOCTEST_INTERFACE void failed_out_of_a_testing_context(const AssertData& ad); + + DOCTEST_INTERFACE bool decomp_assert(assertType::Enum at, const char* file, int line, + const char* expr, const Result& result); + +#define DOCTEST_ASSERT_OUT_OF_TESTS(decomp) \ + do { \ + if(!is_running_in_test) { \ + if(failed) { \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + rb.m_decomp = decomp; \ + failed_out_of_a_testing_context(rb); \ + if(isDebuggerActive() && !getContextOptions()->no_breaks) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(checkIfShouldThrow(at)) \ + throwException(); \ + } \ + return !failed; \ + } \ + } while(false) + +#define DOCTEST_ASSERT_IN_TESTS(decomp) \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + if(rb.m_failed || getContextOptions()->success) \ + rb.m_decomp = decomp; \ + if(rb.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(rb.m_failed && checkIfShouldThrow(at)) \ + throwException() + + template + DOCTEST_NOINLINE bool binary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + bool failed = !RelationalComparator()(lhs, rhs); + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + DOCTEST_ASSERT_IN_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + return !failed; + } + + template + DOCTEST_NOINLINE bool unary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) val) { + bool failed = !val; + + if(at & assertType::is_false) //!OCLINT bitwise operator in conditional + failed = !failed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS((DOCTEST_STRINGIFY(val))); + DOCTEST_ASSERT_IN_TESTS((DOCTEST_STRINGIFY(val))); + return !failed; + } + + struct DOCTEST_INTERFACE IExceptionTranslator + { + DOCTEST_DECLARE_INTERFACE(IExceptionTranslator) + virtual bool translate(String&) const = 0; + }; + + template + class ExceptionTranslator : public IExceptionTranslator //!OCLINT destructor of virtual class + { + public: + explicit ExceptionTranslator(String (*translateFunction)(T)) + : m_translateFunction(translateFunction) {} + + bool translate(String& res) const override { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { + throw; // lgtm [cpp/rethrow-no-exception] + // cppcheck-suppress catchExceptionByValue + } catch(const T& ex) { + res = m_translateFunction(ex); //!OCLINT parameter reassignment + return true; + } catch(...) {} //!OCLINT - empty catch statement +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + static_cast(res); // to silence -Wunused-parameter + return false; + } + + private: + String (*m_translateFunction)(T); + }; + + DOCTEST_INTERFACE void registerExceptionTranslatorImpl(const IExceptionTranslator* et); + + // ContextScope base class used to allow implementing methods of ContextScope + // that don't depend on the template parameter in doctest.cpp. + struct DOCTEST_INTERFACE ContextScopeBase : public IContextScope { + ContextScopeBase(const ContextScopeBase&) = delete; + + ContextScopeBase& operator=(const ContextScopeBase&) = delete; + ContextScopeBase& operator=(ContextScopeBase&&) = delete; + + ~ContextScopeBase() override = default; + + protected: + ContextScopeBase(); + ContextScopeBase(ContextScopeBase&& other) noexcept; + + void destroy(); + bool need_to_destroy{true}; + }; + + template class ContextScope : public ContextScopeBase + { + L lambda_; + + public: + explicit ContextScope(const L &lambda) : lambda_(lambda) {} + explicit ContextScope(L&& lambda) : lambda_(static_cast(lambda)) { } + + ContextScope(const ContextScope&) = delete; + ContextScope(ContextScope&&) noexcept = default; + + ContextScope& operator=(const ContextScope&) = delete; + ContextScope& operator=(ContextScope&&) = delete; + + void stringify(std::ostream* s) const override { lambda_(s); } + + ~ContextScope() override { + if (need_to_destroy) { + destroy(); + } + } + }; + + struct DOCTEST_INTERFACE MessageBuilder : public MessageData + { + std::ostream* m_stream; + bool logged = false; + + MessageBuilder(const char* file, int line, assertType::Enum severity); + + MessageBuilder(const MessageBuilder&) = delete; + MessageBuilder(MessageBuilder&&) = delete; + + MessageBuilder& operator=(const MessageBuilder&) = delete; + MessageBuilder& operator=(MessageBuilder&&) = delete; + + ~MessageBuilder(); + + // the preferred way of chaining parameters for stringification +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) + template + MessageBuilder& operator,(const T& in) { + *m_stream << (DOCTEST_STRINGIFY(in)); + return *this; + } +DOCTEST_MSVC_SUPPRESS_WARNING_POP + + // kept here just for backwards-compatibility - the comma operator should be preferred now + template + MessageBuilder& operator<<(const T& in) { return this->operator,(in); } + + // the `,` operator has the lowest operator precedence - if `<<` is used by the user then + // the `,` operator will be called last which is not what we want and thus the `*` operator + // is used first (has higher operator precedence compared to `<<`) so that we guarantee that + // an operator of the MessageBuilder class is called first before the rest of the parameters + template + MessageBuilder& operator*(const T& in) { return this->operator,(in); } + + bool log(); + void react(); + }; + + template + ContextScope MakeContextScope(const L &lambda) { + return ContextScope(lambda); + } +} // namespace detail + +#define DOCTEST_DEFINE_DECORATOR(name, type, def) \ + struct name \ + { \ + type data; \ + name(type in = def) \ + : data(in) {} \ + void fill(detail::TestCase& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + void fill(detail::TestSuite& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + } + +DOCTEST_DEFINE_DECORATOR(test_suite, const char*, ""); +DOCTEST_DEFINE_DECORATOR(description, const char*, ""); +DOCTEST_DEFINE_DECORATOR(skip, bool, true); +DOCTEST_DEFINE_DECORATOR(no_breaks, bool, true); +DOCTEST_DEFINE_DECORATOR(no_output, bool, true); +DOCTEST_DEFINE_DECORATOR(timeout, double, 0); +DOCTEST_DEFINE_DECORATOR(may_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(should_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(expected_failures, int, 0); + +template +int registerExceptionTranslator(String (*translateFunction)(T)) { + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") + static detail::ExceptionTranslator exceptionTranslator(translateFunction); + DOCTEST_CLANG_SUPPRESS_WARNING_POP + detail::registerExceptionTranslatorImpl(&exceptionTranslator); + return 0; +} + +} // namespace doctest + +// in a separate namespace outside of doctest because the DOCTEST_TEST_SUITE macro +// introduces an anonymous namespace in which getCurrentTestSuite gets overridden +namespace doctest_detail_test_suite_ns { +DOCTEST_INTERFACE doctest::detail::TestSuite& getCurrentTestSuite(); +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +#else // DOCTEST_CONFIG_DISABLE +template +int registerExceptionTranslator(String (*)(T)) { + return 0; +} +#endif // DOCTEST_CONFIG_DISABLE + +namespace detail { + using assert_handler = void (*)(const AssertData&); + struct ContextState; +} // namespace detail + +class DOCTEST_INTERFACE Context +{ + detail::ContextState* p; + + void parseArgs(int argc, const char* const* argv, bool withDefaults = false); + +public: + explicit Context(int argc = 0, const char* const* argv = nullptr); + + Context(const Context&) = delete; + Context(Context&&) = delete; + + Context& operator=(const Context&) = delete; + Context& operator=(Context&&) = delete; + + ~Context(); // NOLINT(performance-trivially-destructible) + + void applyCommandLine(int argc, const char* const* argv); + + void addFilter(const char* filter, const char* value); + void clearFilters(); + void setOption(const char* option, bool value); + void setOption(const char* option, int value); + void setOption(const char* option, const char* value); + + bool shouldExit(); + + void setAsDefaultForAssertsOutOfTestCases(); + + void setAssertHandler(detail::assert_handler ah); + + void setCout(std::ostream* out); + + int run(); +}; + +namespace TestCaseFailureReason { + enum Enum + { + None = 0, + AssertFailure = 1, // an assertion has failed in the test case + Exception = 2, // test case threw an exception + Crash = 4, // a crash... + TooManyFailedAsserts = 8, // the abort-after option + Timeout = 16, // see the timeout decorator + ShouldHaveFailedButDidnt = 32, // see the should_fail decorator + ShouldHaveFailedAndDid = 64, // see the should_fail decorator + DidntFailExactlyNumTimes = 128, // see the expected_failures decorator + FailedExactlyNumTimes = 256, // see the expected_failures decorator + CouldHaveFailedAndDid = 512 // see the may_fail decorator + }; +} // namespace TestCaseFailureReason + +struct DOCTEST_INTERFACE CurrentTestCaseStats +{ + int numAssertsCurrentTest; + int numAssertsFailedCurrentTest; + double seconds; + int failure_flags; // use TestCaseFailureReason::Enum + bool testCaseSuccess; +}; + +struct DOCTEST_INTERFACE TestCaseException +{ + String error_string; + bool is_crash; +}; + +struct DOCTEST_INTERFACE TestRunStats +{ + unsigned numTestCases; + unsigned numTestCasesPassingFilters; + unsigned numTestSuitesPassingFilters; + unsigned numTestCasesFailed; + int numAsserts; + int numAssertsFailed; +}; + +struct QueryData +{ + const TestRunStats* run_stats = nullptr; + const TestCaseData** data = nullptr; + unsigned num_data = 0; +}; + +struct DOCTEST_INTERFACE IReporter +{ + // The constructor has to accept "const ContextOptions&" as a single argument + // which has most of the options for the run + a pointer to the stdout stream + // Reporter(const ContextOptions& in) + + // called when a query should be reported (listing test cases, printing the version, etc.) + virtual void report_query(const QueryData&) = 0; + + // called when the whole test run starts + virtual void test_run_start() = 0; + // called when the whole test run ends (caching a pointer to the input doesn't make sense here) + virtual void test_run_end(const TestRunStats&) = 0; + + // called when a test case is started (safe to cache a pointer to the input) + virtual void test_case_start(const TestCaseData&) = 0; + // called when a test case is reentered because of unfinished subcases (safe to cache a pointer to the input) + virtual void test_case_reenter(const TestCaseData&) = 0; + // called when a test case has ended + virtual void test_case_end(const CurrentTestCaseStats&) = 0; + + // called when an exception is thrown from the test case (or it crashes) + virtual void test_case_exception(const TestCaseException&) = 0; + + // called whenever a subcase is entered (don't cache pointers to the input) + virtual void subcase_start(const SubcaseSignature&) = 0; + // called whenever a subcase is exited (don't cache pointers to the input) + virtual void subcase_end() = 0; + + // called for each assert (don't cache pointers to the input) + virtual void log_assert(const AssertData&) = 0; + // called for each message (don't cache pointers to the input) + virtual void log_message(const MessageData&) = 0; + + // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator + // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) + virtual void test_case_skipped(const TestCaseData&) = 0; + + DOCTEST_DECLARE_INTERFACE(IReporter) + + // can obtain all currently active contexts and stringify them if one wishes to do so + static int get_num_active_contexts(); + static const IContextScope* const* get_active_contexts(); + + // can iterate through contexts which have been stringified automatically in their destructors when an exception has been thrown + static int get_num_stringified_contexts(); + static const String* get_stringified_contexts(); +}; + +namespace detail { + using reporterCreatorFunc = IReporter* (*)(const ContextOptions&); + + DOCTEST_INTERFACE void registerReporterImpl(const char* name, int prio, reporterCreatorFunc c, bool isReporter); + + template + IReporter* reporterCreator(const ContextOptions& o) { + return new Reporter(o); + } +} // namespace detail + +template +int registerReporter(const char* name, int priority, bool isReporter) { + detail::registerReporterImpl(name, priority, detail::reporterCreator, isReporter); + return 0; +} +} // namespace doctest + +#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES +#define DOCTEST_FUNC_EMPTY [] { return false; }() +#else +#define DOCTEST_FUNC_EMPTY (void)0 +#endif + +// if registering is not disabled +#ifndef DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES +#define DOCTEST_FUNC_SCOPE_BEGIN [&] +#define DOCTEST_FUNC_SCOPE_END () +#define DOCTEST_FUNC_SCOPE_RET(v) return v +#else +#define DOCTEST_FUNC_SCOPE_BEGIN do +#define DOCTEST_FUNC_SCOPE_END while(false) +#define DOCTEST_FUNC_SCOPE_RET(v) (void)0 +#endif + +// common code in asserts - for convenience +#define DOCTEST_ASSERT_LOG_REACT_RETURN(b) \ + if(b.log()) DOCTEST_BREAK_INTO_DEBUGGER(); \ + b.react(); \ + DOCTEST_FUNC_SCOPE_RET(!b.m_failed) + +#ifdef DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) x; +#else // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) \ + try { \ + x; \ + } catch(...) { DOCTEST_RB.translateException(); } +#endif // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#ifdef DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) \ + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wuseless-cast") \ + static_cast(__VA_ARGS__); \ + DOCTEST_GCC_SUPPRESS_WARNING_POP +#else // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) __VA_ARGS__; +#endif // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS + +// registers the test by initializing a dummy var with a function +#define DOCTEST_REGISTER_FUNCTION(global_prefix, f, decorators) \ + global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT */ \ + doctest::detail::regTest( \ + doctest::detail::TestCase( \ + f, __FILE__, __LINE__, \ + doctest_detail_test_suite_ns::getCurrentTestSuite()) * \ + decorators)) + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, decorators) \ + namespace { /* NOLINT */ \ + struct der : public base \ + { \ + void f(); \ + }; \ + static DOCTEST_INLINE_NOINLINE void func() { \ + der v; \ + v.f(); \ + } \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, func, decorators) \ + } \ + DOCTEST_INLINE_NOINLINE void der::f() // NOLINT(misc-definitions-in-headers) + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, decorators) \ + static void f(); \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, f, decorators) \ + static void f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(f, proxy, decorators) \ + static doctest::detail::funcType proxy() { return f; } \ + DOCTEST_REGISTER_FUNCTION(inline, proxy(), decorators) \ + static void f() + +// for registering tests +#define DOCTEST_TEST_CASE(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) + +// for registering tests in classes - requires C++17 for inline variables! +#if DOCTEST_CPLUSPLUS >= 201703L +#define DOCTEST_TEST_CASE_CLASS(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_PROXY_), \ + decorators) +#else // DOCTEST_TEST_CASE_CLASS +#define DOCTEST_TEST_CASE_CLASS(...) \ + TEST_CASES_CAN_BE_REGISTERED_IN_CLASSES_ONLY_IN_CPP17_MODE_OR_WITH_VS_2017_OR_NEWER +#endif // DOCTEST_TEST_CASE_CLASS + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(c, decorators) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), c, \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING_AS(str, ...) \ + namespace doctest { \ + template <> \ + inline String toString<__VA_ARGS__>() { \ + return str; \ + } \ + } \ + static_assert(true, "") + +#define DOCTEST_TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING_AS(#__VA_ARGS__, __VA_ARGS__) + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, iter, func) \ + template \ + static void func(); \ + namespace { /* NOLINT */ \ + template \ + struct iter; \ + template \ + struct iter> \ + { \ + iter(const char* file, unsigned line, int index) { \ + doctest::detail::regTest(doctest::detail::TestCase(func, file, line, \ + doctest_detail_test_suite_ns::getCurrentTestSuite(), \ + doctest::toString(), \ + int(line) * 1000 + index) \ + * dec); \ + iter>(file, line, index + 1); \ + } \ + }; \ + template <> \ + struct iter> \ + { \ + iter(const char*, unsigned, int) {} \ + }; \ + } \ + template \ + static void func() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(dec, T, id) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(id, ITERATOR), \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)) + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, anon, ...) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY), /* NOLINT(cert-err58-cpp, fuchsia-statically-constructed-objects) */ \ + doctest::detail::instantiationHelper( \ + DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0))) + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ + static_assert(true, "") + +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) \ + static_assert(true, "") + +#define DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, anon, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(anon, ITERATOR), anon); \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(anon, anon, std::tuple<__VA_ARGS__>) \ + template \ + static void anon() + +#define DOCTEST_TEST_CASE_TEMPLATE(dec, T, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) + +// for subcases +#define DOCTEST_SUBCASE(name) \ + if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ + doctest::detail::Subcase(name, __FILE__, __LINE__)) + +// for grouping tests in test suites by using code blocks +#define DOCTEST_TEST_SUITE_IMPL(decorators, ns_name) \ + namespace ns_name { namespace doctest_detail_test_suite_ns { \ + static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() noexcept { \ + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4640) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") \ + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmissing-field-initializers") \ + static doctest::detail::TestSuite data{}; \ + static bool inited = false; \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP \ + DOCTEST_GCC_SUPPRESS_WARNING_POP \ + if(!inited) { \ + data* decorators; \ + inited = true; \ + } \ + return data; \ + } \ + } \ + } \ + namespace ns_name + +#define DOCTEST_TEST_SUITE(decorators) \ + DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(DOCTEST_ANON_SUITE_)) + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(decorators) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators)) \ + static_assert(true, "") + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * "")) \ + using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int + +// for registering exception translators +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(translatorName, signature) \ + inline doctest::String translatorName(signature); \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerExceptionTranslator(translatorName)) \ + doctest::String translatorName(signature) + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), \ + signature) + +// for registering reporters +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerReporter(name, priority, true)) \ + static_assert(true, "") + +// for registering listeners +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerReporter(name, priority, false)) \ + static_assert(true, "") + +// clang-format off +// for logging - disabling formatting because it's important to have these on 2 separate lines - see PR #557 +#define DOCTEST_INFO(...) \ + DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_), \ + DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_OTHER_), \ + __VA_ARGS__) +// clang-format on + +#define DOCTEST_INFO_IMPL(mb_name, s_name, ...) \ + auto DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope( \ + [&](std::ostream* s_name) { \ + doctest::detail::MessageBuilder mb_name(__FILE__, __LINE__, doctest::assertType::is_warn); \ + mb_name.m_stream = s_name; \ + mb_name * __VA_ARGS__; \ + }) + +#define DOCTEST_CAPTURE(x) DOCTEST_INFO(#x " := ", x) + +#define DOCTEST_ADD_AT_IMPL(type, file, line, mb, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::MessageBuilder mb(file, line, doctest::assertType::type); \ + mb * __VA_ARGS__; \ + if(mb.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + mb.react(); \ + } DOCTEST_FUNC_SCOPE_END + +// clang-format off +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +// clang-format on + +#define DOCTEST_MESSAGE(...) DOCTEST_ADD_MESSAGE_AT(__FILE__, __LINE__, __VA_ARGS__) +#define DOCTEST_FAIL_CHECK(...) DOCTEST_ADD_FAIL_CHECK_AT(__FILE__, __LINE__, __VA_ARGS__) +#define DOCTEST_FAIL(...) DOCTEST_ADD_FAIL_AT(__FILE__, __LINE__, __VA_ARGS__) + +#define DOCTEST_TO_LVALUE(...) __VA_ARGS__ // Not removed to keep backwards compatibility. + +#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_ASSERT_IMPLEMENT_2(assert_type, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(DOCTEST_RB.setResult( \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__)) /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB) \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + DOCTEST_ASSERT_IMPLEMENT_2(assert_type, __VA_ARGS__); \ + } DOCTEST_FUNC_SCOPE_END // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + +#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY( \ + DOCTEST_RB.binary_assert( \ + __VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(DOCTEST_RB.unary_assert(__VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +// necessary for _MESSAGE +#define DOCTEST_ASSERT_IMPLEMENT_2 DOCTEST_ASSERT_IMPLEMENT_1 + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + doctest::detail::decomp_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__) DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ + doctest::detail::binary_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ + #__VA_ARGS__, __VA_ARGS__) + +#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_WARN(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN, __VA_ARGS__) +#define DOCTEST_CHECK(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK, __VA_ARGS__) +#define DOCTEST_REQUIRE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE, __VA_ARGS__) +#define DOCTEST_WARN_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE_FALSE, __VA_ARGS__) + +// clang-format off +#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +// clang-format on + +#define DOCTEST_WARN_EQ(...) DOCTEST_BINARY_ASSERT(DT_WARN_EQ, eq, __VA_ARGS__) +#define DOCTEST_CHECK_EQ(...) DOCTEST_BINARY_ASSERT(DT_CHECK_EQ, eq, __VA_ARGS__) +#define DOCTEST_REQUIRE_EQ(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_EQ, eq, __VA_ARGS__) +#define DOCTEST_WARN_NE(...) DOCTEST_BINARY_ASSERT(DT_WARN_NE, ne, __VA_ARGS__) +#define DOCTEST_CHECK_NE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_NE, ne, __VA_ARGS__) +#define DOCTEST_REQUIRE_NE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_NE, ne, __VA_ARGS__) +#define DOCTEST_WARN_GT(...) DOCTEST_BINARY_ASSERT(DT_WARN_GT, gt, __VA_ARGS__) +#define DOCTEST_CHECK_GT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GT, gt, __VA_ARGS__) +#define DOCTEST_REQUIRE_GT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GT, gt, __VA_ARGS__) +#define DOCTEST_WARN_LT(...) DOCTEST_BINARY_ASSERT(DT_WARN_LT, lt, __VA_ARGS__) +#define DOCTEST_CHECK_LT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LT, lt, __VA_ARGS__) +#define DOCTEST_REQUIRE_LT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LT, lt, __VA_ARGS__) +#define DOCTEST_WARN_GE(...) DOCTEST_BINARY_ASSERT(DT_WARN_GE, ge, __VA_ARGS__) +#define DOCTEST_CHECK_GE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GE, ge, __VA_ARGS__) +#define DOCTEST_REQUIRE_GE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GE, ge, __VA_ARGS__) +#define DOCTEST_WARN_LE(...) DOCTEST_BINARY_ASSERT(DT_WARN_LE, le, __VA_ARGS__) +#define DOCTEST_CHECK_LE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LE, le, __VA_ARGS__) +#define DOCTEST_REQUIRE_LE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LE, le, __VA_ARGS__) + +#define DOCTEST_WARN_UNARY(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY, __VA_ARGS__) +#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY_FALSE, __VA_ARGS__) + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #expr, #__VA_ARGS__, message); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(const typename doctest::detail::types::remove_const< \ + typename doctest::detail::types::remove_reference<__VA_ARGS__>::type>::type&) {\ + DOCTEST_RB.translateException(); \ + DOCTEST_RB.m_threw_as = true; \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } else { /* NOLINT(*-else-after-return) */ \ + DOCTEST_FUNC_SCOPE_RET(false); \ + } \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, expr_str, "", __VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } else { /* NOLINT(*-else-after-return) */ \ + DOCTEST_FUNC_SCOPE_RET(false); \ + } \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +// clang-format off +#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") +#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") + +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) + +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +// clang-format on + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +// ================================================================================================= +// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == +// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! == +// ================================================================================================= +#else // DOCTEST_CONFIG_DISABLE + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ + namespace /* NOLINT */ { \ + template \ + struct der : public base \ + { void f(); }; \ + } \ + template \ + inline void der::f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ + template \ + static inline void f() + +// for registering tests +#define DOCTEST_TEST_CASE(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for registering tests in classes +#define DOCTEST_TEST_CASE_CLASS(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), x, \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING_AS(str, ...) static_assert(true, "") +#define DOCTEST_TYPE_TO_STRING(...) static_assert(true, "") + +// for typed tests +#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ + template \ + inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ + template \ + inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) static_assert(true, "") +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) static_assert(true, "") + +// for subcases +#define DOCTEST_SUBCASE(name) + +// for a testsuite block +#define DOCTEST_TEST_SUITE(name) namespace // NOLINT + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(name) static_assert(true, "") + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + template \ + static inline doctest::String DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_)(signature) + +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) + +#define DOCTEST_INFO(...) (static_cast(0)) +#define DOCTEST_CAPTURE(x) (static_cast(0)) +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_MESSAGE(...) (static_cast(0)) +#define DOCTEST_FAIL_CHECK(...) (static_cast(0)) +#define DOCTEST_FAIL(...) (static_cast(0)) + +#if defined(DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED) \ + && defined(DOCTEST_CONFIG_ASSERTS_RETURN_VALUES) + +#define DOCTEST_WARN(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_CHECK(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_REQUIRE(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_WARN_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_CHECK_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_FALSE(...) [&] { return !(__VA_ARGS__); }() + +#define DOCTEST_WARN_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_CHECK_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() + +namespace doctest { +namespace detail { +#define DOCTEST_RELATIONAL_OP(name, op) \ + template \ + bool name(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) { return lhs op rhs; } + + DOCTEST_RELATIONAL_OP(eq, ==) + DOCTEST_RELATIONAL_OP(ne, !=) + DOCTEST_RELATIONAL_OP(lt, <) + DOCTEST_RELATIONAL_OP(gt, >) + DOCTEST_RELATIONAL_OP(le, <=) + DOCTEST_RELATIONAL_OP(ge, >=) +} // namespace detail +} // namespace doctest + +#define DOCTEST_WARN_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_CHECK_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_WARN_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_CHECK_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_WARN_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_CHECK_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_WARN_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_CHECK_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_WARN_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_CHECK_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_WARN_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_CHECK_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_WARN_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_CHECK_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_REQUIRE_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_WARN_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_CHECK_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_WARN_THROWS_WITH(expr, with, ...) [] { static_assert(false, "Exception translation is not available when doctest is disabled."); return false; }() +#define DOCTEST_CHECK_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) + +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) + +#define DOCTEST_WARN_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_CHECK_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_REQUIRE_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_WARN_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_CHECK_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_WARN_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_CHECK_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_REQUIRE_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#else // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED + +#define DOCTEST_WARN(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_LE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_LE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_LE(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_WARN_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#endif // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED + +#endif // DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#define DOCTEST_EXCEPTION_EMPTY_FUNC DOCTEST_FUNC_EMPTY +#else // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#define DOCTEST_EXCEPTION_EMPTY_FUNC [] { static_assert(false, "Exceptions are disabled! " \ + "Use DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS if you want to compile with exceptions disabled."); return false; }() + +#undef DOCTEST_REQUIRE +#undef DOCTEST_REQUIRE_FALSE +#undef DOCTEST_REQUIRE_MESSAGE +#undef DOCTEST_REQUIRE_FALSE_MESSAGE +#undef DOCTEST_REQUIRE_EQ +#undef DOCTEST_REQUIRE_NE +#undef DOCTEST_REQUIRE_GT +#undef DOCTEST_REQUIRE_LT +#undef DOCTEST_REQUIRE_GE +#undef DOCTEST_REQUIRE_LE +#undef DOCTEST_REQUIRE_UNARY +#undef DOCTEST_REQUIRE_UNARY_FALSE + +#define DOCTEST_REQUIRE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_FALSE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_EQ DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_GT DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_LT DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_GE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_LE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_UNARY DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_UNARY_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#define DOCTEST_WARN_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +// clang-format off +// KEPT FOR BACKWARDS COMPATIBILITY - FORWARDING TO THE RIGHT MACROS +#define DOCTEST_FAST_WARN_EQ DOCTEST_WARN_EQ +#define DOCTEST_FAST_CHECK_EQ DOCTEST_CHECK_EQ +#define DOCTEST_FAST_REQUIRE_EQ DOCTEST_REQUIRE_EQ +#define DOCTEST_FAST_WARN_NE DOCTEST_WARN_NE +#define DOCTEST_FAST_CHECK_NE DOCTEST_CHECK_NE +#define DOCTEST_FAST_REQUIRE_NE DOCTEST_REQUIRE_NE +#define DOCTEST_FAST_WARN_GT DOCTEST_WARN_GT +#define DOCTEST_FAST_CHECK_GT DOCTEST_CHECK_GT +#define DOCTEST_FAST_REQUIRE_GT DOCTEST_REQUIRE_GT +#define DOCTEST_FAST_WARN_LT DOCTEST_WARN_LT +#define DOCTEST_FAST_CHECK_LT DOCTEST_CHECK_LT +#define DOCTEST_FAST_REQUIRE_LT DOCTEST_REQUIRE_LT +#define DOCTEST_FAST_WARN_GE DOCTEST_WARN_GE +#define DOCTEST_FAST_CHECK_GE DOCTEST_CHECK_GE +#define DOCTEST_FAST_REQUIRE_GE DOCTEST_REQUIRE_GE +#define DOCTEST_FAST_WARN_LE DOCTEST_WARN_LE +#define DOCTEST_FAST_CHECK_LE DOCTEST_CHECK_LE +#define DOCTEST_FAST_REQUIRE_LE DOCTEST_REQUIRE_LE + +#define DOCTEST_FAST_WARN_UNARY DOCTEST_WARN_UNARY +#define DOCTEST_FAST_CHECK_UNARY DOCTEST_CHECK_UNARY +#define DOCTEST_FAST_REQUIRE_UNARY DOCTEST_REQUIRE_UNARY +#define DOCTEST_FAST_WARN_UNARY_FALSE DOCTEST_WARN_UNARY_FALSE +#define DOCTEST_FAST_CHECK_UNARY_FALSE DOCTEST_CHECK_UNARY_FALSE +#define DOCTEST_FAST_REQUIRE_UNARY_FALSE DOCTEST_REQUIRE_UNARY_FALSE + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id,__VA_ARGS__) +// clang-format on + +// BDD style macros +// clang-format off +#define DOCTEST_SCENARIO(name) DOCTEST_TEST_CASE(" Scenario: " name) +#define DOCTEST_SCENARIO_CLASS(name) DOCTEST_TEST_CASE_CLASS(" Scenario: " name) +#define DOCTEST_SCENARIO_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(" Scenario: " name, T, __VA_ARGS__) +#define DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(" Scenario: " name, T, id) + +#define DOCTEST_GIVEN(name) DOCTEST_SUBCASE(" Given: " name) +#define DOCTEST_WHEN(name) DOCTEST_SUBCASE(" When: " name) +#define DOCTEST_AND_WHEN(name) DOCTEST_SUBCASE("And when: " name) +#define DOCTEST_THEN(name) DOCTEST_SUBCASE(" Then: " name) +#define DOCTEST_AND_THEN(name) DOCTEST_SUBCASE(" And: " name) +// clang-format on + +// == SHORT VERSIONS OF THE MACROS +#ifndef DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES + +#define TEST_CASE(name) DOCTEST_TEST_CASE(name) +#define TEST_CASE_CLASS(name) DOCTEST_TEST_CASE_CLASS(name) +#define TEST_CASE_FIXTURE(x, name) DOCTEST_TEST_CASE_FIXTURE(x, name) +#define TYPE_TO_STRING_AS(str, ...) DOCTEST_TYPE_TO_STRING_AS(str, __VA_ARGS__) +#define TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING(__VA_ARGS__) +#define TEST_CASE_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(name, T, __VA_ARGS__) +#define TEST_CASE_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, T, id) +#define TEST_CASE_TEMPLATE_INVOKE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, __VA_ARGS__) +#define TEST_CASE_TEMPLATE_APPLY(id, ...) DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, __VA_ARGS__) +#define SUBCASE(name) DOCTEST_SUBCASE(name) +#define TEST_SUITE(decorators) DOCTEST_TEST_SUITE(decorators) +#define TEST_SUITE_BEGIN(name) DOCTEST_TEST_SUITE_BEGIN(name) +#define TEST_SUITE_END DOCTEST_TEST_SUITE_END +#define REGISTER_EXCEPTION_TRANSLATOR(signature) DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) +#define REGISTER_REPORTER(name, priority, reporter) DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define REGISTER_LISTENER(name, priority, reporter) DOCTEST_REGISTER_LISTENER(name, priority, reporter) +#define INFO(...) DOCTEST_INFO(__VA_ARGS__) +#define CAPTURE(x) DOCTEST_CAPTURE(x) +#define ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_MESSAGE_AT(file, line, __VA_ARGS__) +#define ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_FAIL_CHECK_AT(file, line, __VA_ARGS__) +#define ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_FAIL_AT(file, line, __VA_ARGS__) +#define MESSAGE(...) DOCTEST_MESSAGE(__VA_ARGS__) +#define FAIL_CHECK(...) DOCTEST_FAIL_CHECK(__VA_ARGS__) +#define FAIL(...) DOCTEST_FAIL(__VA_ARGS__) +#define TO_LVALUE(...) DOCTEST_TO_LVALUE(__VA_ARGS__) + +#define WARN(...) DOCTEST_WARN(__VA_ARGS__) +#define WARN_FALSE(...) DOCTEST_WARN_FALSE(__VA_ARGS__) +#define WARN_THROWS(...) DOCTEST_WARN_THROWS(__VA_ARGS__) +#define WARN_THROWS_AS(expr, ...) DOCTEST_WARN_THROWS_AS(expr, __VA_ARGS__) +#define WARN_THROWS_WITH(expr, ...) DOCTEST_WARN_THROWS_WITH(expr, __VA_ARGS__) +#define WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_WARN_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define WARN_NOTHROW(...) DOCTEST_WARN_NOTHROW(__VA_ARGS__) +#define CHECK(...) DOCTEST_CHECK(__VA_ARGS__) +#define CHECK_FALSE(...) DOCTEST_CHECK_FALSE(__VA_ARGS__) +#define CHECK_THROWS(...) DOCTEST_CHECK_THROWS(__VA_ARGS__) +#define CHECK_THROWS_AS(expr, ...) DOCTEST_CHECK_THROWS_AS(expr, __VA_ARGS__) +#define CHECK_THROWS_WITH(expr, ...) DOCTEST_CHECK_THROWS_WITH(expr, __VA_ARGS__) +#define CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define CHECK_NOTHROW(...) DOCTEST_CHECK_NOTHROW(__VA_ARGS__) +#define REQUIRE(...) DOCTEST_REQUIRE(__VA_ARGS__) +#define REQUIRE_FALSE(...) DOCTEST_REQUIRE_FALSE(__VA_ARGS__) +#define REQUIRE_THROWS(...) DOCTEST_REQUIRE_THROWS(__VA_ARGS__) +#define REQUIRE_THROWS_AS(expr, ...) DOCTEST_REQUIRE_THROWS_AS(expr, __VA_ARGS__) +#define REQUIRE_THROWS_WITH(expr, ...) DOCTEST_REQUIRE_THROWS_WITH(expr, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define REQUIRE_NOTHROW(...) DOCTEST_REQUIRE_NOTHROW(__VA_ARGS__) + +#define WARN_MESSAGE(cond, ...) DOCTEST_WARN_MESSAGE(cond, __VA_ARGS__) +#define WARN_FALSE_MESSAGE(cond, ...) DOCTEST_WARN_FALSE_MESSAGE(cond, __VA_ARGS__) +#define WARN_THROWS_MESSAGE(expr, ...) DOCTEST_WARN_THROWS_MESSAGE(expr, __VA_ARGS__) +#define WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_WARN_NOTHROW_MESSAGE(expr, __VA_ARGS__) +#define CHECK_MESSAGE(cond, ...) DOCTEST_CHECK_MESSAGE(cond, __VA_ARGS__) +#define CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_CHECK_FALSE_MESSAGE(cond, __VA_ARGS__) +#define CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_CHECK_THROWS_MESSAGE(expr, __VA_ARGS__) +#define CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_CHECK_NOTHROW_MESSAGE(expr, __VA_ARGS__) +#define REQUIRE_MESSAGE(cond, ...) DOCTEST_REQUIRE_MESSAGE(cond, __VA_ARGS__) +#define REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_REQUIRE_FALSE_MESSAGE(cond, __VA_ARGS__) +#define REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_REQUIRE_THROWS_MESSAGE(expr, __VA_ARGS__) +#define REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, __VA_ARGS__) + +#define SCENARIO(name) DOCTEST_SCENARIO(name) +#define SCENARIO_CLASS(name) DOCTEST_SCENARIO_CLASS(name) +#define SCENARIO_TEMPLATE(name, T, ...) DOCTEST_SCENARIO_TEMPLATE(name, T, __VA_ARGS__) +#define SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) +#define GIVEN(name) DOCTEST_GIVEN(name) +#define WHEN(name) DOCTEST_WHEN(name) +#define AND_WHEN(name) DOCTEST_AND_WHEN(name) +#define THEN(name) DOCTEST_THEN(name) +#define AND_THEN(name) DOCTEST_AND_THEN(name) + +#define WARN_EQ(...) DOCTEST_WARN_EQ(__VA_ARGS__) +#define CHECK_EQ(...) DOCTEST_CHECK_EQ(__VA_ARGS__) +#define REQUIRE_EQ(...) DOCTEST_REQUIRE_EQ(__VA_ARGS__) +#define WARN_NE(...) DOCTEST_WARN_NE(__VA_ARGS__) +#define CHECK_NE(...) DOCTEST_CHECK_NE(__VA_ARGS__) +#define REQUIRE_NE(...) DOCTEST_REQUIRE_NE(__VA_ARGS__) +#define WARN_GT(...) DOCTEST_WARN_GT(__VA_ARGS__) +#define CHECK_GT(...) DOCTEST_CHECK_GT(__VA_ARGS__) +#define REQUIRE_GT(...) DOCTEST_REQUIRE_GT(__VA_ARGS__) +#define WARN_LT(...) DOCTEST_WARN_LT(__VA_ARGS__) +#define CHECK_LT(...) DOCTEST_CHECK_LT(__VA_ARGS__) +#define REQUIRE_LT(...) DOCTEST_REQUIRE_LT(__VA_ARGS__) +#define WARN_GE(...) DOCTEST_WARN_GE(__VA_ARGS__) +#define CHECK_GE(...) DOCTEST_CHECK_GE(__VA_ARGS__) +#define REQUIRE_GE(...) DOCTEST_REQUIRE_GE(__VA_ARGS__) +#define WARN_LE(...) DOCTEST_WARN_LE(__VA_ARGS__) +#define CHECK_LE(...) DOCTEST_CHECK_LE(__VA_ARGS__) +#define REQUIRE_LE(...) DOCTEST_REQUIRE_LE(__VA_ARGS__) +#define WARN_UNARY(...) DOCTEST_WARN_UNARY(__VA_ARGS__) +#define CHECK_UNARY(...) DOCTEST_CHECK_UNARY(__VA_ARGS__) +#define REQUIRE_UNARY(...) DOCTEST_REQUIRE_UNARY(__VA_ARGS__) +#define WARN_UNARY_FALSE(...) DOCTEST_WARN_UNARY_FALSE(__VA_ARGS__) +#define CHECK_UNARY_FALSE(...) DOCTEST_CHECK_UNARY_FALSE(__VA_ARGS__) +#define REQUIRE_UNARY_FALSE(...) DOCTEST_REQUIRE_UNARY_FALSE(__VA_ARGS__) + +// KEPT FOR BACKWARDS COMPATIBILITY +#define FAST_WARN_EQ(...) DOCTEST_FAST_WARN_EQ(__VA_ARGS__) +#define FAST_CHECK_EQ(...) DOCTEST_FAST_CHECK_EQ(__VA_ARGS__) +#define FAST_REQUIRE_EQ(...) DOCTEST_FAST_REQUIRE_EQ(__VA_ARGS__) +#define FAST_WARN_NE(...) DOCTEST_FAST_WARN_NE(__VA_ARGS__) +#define FAST_CHECK_NE(...) DOCTEST_FAST_CHECK_NE(__VA_ARGS__) +#define FAST_REQUIRE_NE(...) DOCTEST_FAST_REQUIRE_NE(__VA_ARGS__) +#define FAST_WARN_GT(...) DOCTEST_FAST_WARN_GT(__VA_ARGS__) +#define FAST_CHECK_GT(...) DOCTEST_FAST_CHECK_GT(__VA_ARGS__) +#define FAST_REQUIRE_GT(...) DOCTEST_FAST_REQUIRE_GT(__VA_ARGS__) +#define FAST_WARN_LT(...) DOCTEST_FAST_WARN_LT(__VA_ARGS__) +#define FAST_CHECK_LT(...) DOCTEST_FAST_CHECK_LT(__VA_ARGS__) +#define FAST_REQUIRE_LT(...) DOCTEST_FAST_REQUIRE_LT(__VA_ARGS__) +#define FAST_WARN_GE(...) DOCTEST_FAST_WARN_GE(__VA_ARGS__) +#define FAST_CHECK_GE(...) DOCTEST_FAST_CHECK_GE(__VA_ARGS__) +#define FAST_REQUIRE_GE(...) DOCTEST_FAST_REQUIRE_GE(__VA_ARGS__) +#define FAST_WARN_LE(...) DOCTEST_FAST_WARN_LE(__VA_ARGS__) +#define FAST_CHECK_LE(...) DOCTEST_FAST_CHECK_LE(__VA_ARGS__) +#define FAST_REQUIRE_LE(...) DOCTEST_FAST_REQUIRE_LE(__VA_ARGS__) + +#define FAST_WARN_UNARY(...) DOCTEST_FAST_WARN_UNARY(__VA_ARGS__) +#define FAST_CHECK_UNARY(...) DOCTEST_FAST_CHECK_UNARY(__VA_ARGS__) +#define FAST_REQUIRE_UNARY(...) DOCTEST_FAST_REQUIRE_UNARY(__VA_ARGS__) +#define FAST_WARN_UNARY_FALSE(...) DOCTEST_FAST_WARN_UNARY_FALSE(__VA_ARGS__) +#define FAST_CHECK_UNARY_FALSE(...) DOCTEST_FAST_CHECK_UNARY_FALSE(__VA_ARGS__) +#define FAST_REQUIRE_UNARY_FALSE(...) DOCTEST_FAST_REQUIRE_UNARY_FALSE(__VA_ARGS__) + +#define TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, __VA_ARGS__) + +#endif // DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES + +#ifndef DOCTEST_CONFIG_DISABLE + +// this is here to clear the 'current test suite' for the current translation unit - at the top +DOCTEST_TEST_SUITE_END(); + +#endif // DOCTEST_CONFIG_DISABLE + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_POP + +#endif // DOCTEST_LIBRARY_INCLUDED + +#ifndef DOCTEST_SINGLE_HEADER +#define DOCTEST_SINGLE_HEADER +#endif // DOCTEST_SINGLE_HEADER + +#if defined(DOCTEST_CONFIG_IMPLEMENT) || !defined(DOCTEST_SINGLE_HEADER) + +#ifndef DOCTEST_SINGLE_HEADER +#include "doctest_fwd.h" +#endif // DOCTEST_SINGLE_HEADER + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-macros") + +#ifndef DOCTEST_LIBRARY_IMPLEMENTATION +#define DOCTEST_LIBRARY_IMPLEMENTATION + +DOCTEST_CLANG_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wglobal-constructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wexit-time-destructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wshorten-64-to-32") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-variable-declarations") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wcovered-switch-default") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-noreturn") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdisabled-macro-expansion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-braces") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-member-function") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wnonportable-system-include-path") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-braces") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-default") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunsafe-loop-optimizations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wold-style-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-function") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmultiple-inheritance") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsuggest-attribute") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4267) // 'var' : conversion from 'x' to 'y', possible loss of data +DOCTEST_MSVC_SUPPRESS_WARNING(4530) // C++ exception handler used, but unwind semantics not enabled +DOCTEST_MSVC_SUPPRESS_WARNING(4577) // 'noexcept' used with no exception handling mode specified +DOCTEST_MSVC_SUPPRESS_WARNING(4774) // format string expected in argument is not a string literal +DOCTEST_MSVC_SUPPRESS_WARNING(4365) // conversion from 'int' to 'unsigned', signed/unsigned mismatch +DOCTEST_MSVC_SUPPRESS_WARNING(5039) // pointer to potentially throwing function passed to extern C +DOCTEST_MSVC_SUPPRESS_WARNING(4800) // forcing value to bool 'true' or 'false' (performance warning) +DOCTEST_MSVC_SUPPRESS_WARNING(5245) // unreferenced function with internal linkage has been removed + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN + +// required includes - will go only in one translation unit! +#include +#include +#include +// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/doctest/doctest/pull/37 +#ifdef __BORLANDC__ +#include +#endif // __BORLANDC__ +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM +#include +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM +#include +#include +#include +#ifndef DOCTEST_CONFIG_NO_MULTITHREADING +#include +#include +#define DOCTEST_DECLARE_MUTEX(name) std::mutex name; +#define DOCTEST_DECLARE_STATIC_MUTEX(name) static DOCTEST_DECLARE_MUTEX(name) +#define DOCTEST_LOCK_MUTEX(name) std::lock_guard DOCTEST_ANONYMOUS(DOCTEST_ANON_LOCK_)(name); +#else // DOCTEST_CONFIG_NO_MULTITHREADING +#define DOCTEST_DECLARE_MUTEX(name) +#define DOCTEST_DECLARE_STATIC_MUTEX(name) +#define DOCTEST_LOCK_MUTEX(name) +#endif // DOCTEST_CONFIG_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef DOCTEST_PLATFORM_MAC +#include +#include +#include +#endif // DOCTEST_PLATFORM_MAC + +#ifdef DOCTEST_PLATFORM_WINDOWS + +// defines for a leaner windows.h +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#define DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#define DOCTEST_UNDEF_NOMINMAX +#endif // NOMINMAX + +// not sure what AfxWin.h is for - here I do what Catch does +#ifdef __AFXDLL +#include +#else +#include +#endif +#include + +#else // DOCTEST_PLATFORM_WINDOWS + +#include +#include + +#endif // DOCTEST_PLATFORM_WINDOWS + +// this is a fix for https://github.com/doctest/doctest/issues/348 +// https://mail.gnome.org/archives/xml/2012-January/msg00000.html +#if !defined(HAVE_UNISTD_H) && !defined(STDOUT_FILENO) +#define STDOUT_FILENO fileno(stdout) +#endif // HAVE_UNISTD_H + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END + +// counts the number of elements in a C array +#define DOCTEST_COUNTOF(x) (sizeof(x) / sizeof(x[0])) + +#ifdef DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_disabled +#else // DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_not_disabled +#endif // DOCTEST_CONFIG_DISABLE + +#ifndef DOCTEST_CONFIG_OPTIONS_PREFIX +#define DOCTEST_CONFIG_OPTIONS_PREFIX "dt-" +#endif + +#ifndef DOCTEST_THREAD_LOCAL +#if defined(DOCTEST_CONFIG_NO_MULTITHREADING) || DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_THREAD_LOCAL +#else // DOCTEST_MSVC +#define DOCTEST_THREAD_LOCAL thread_local +#endif // DOCTEST_MSVC +#endif // DOCTEST_THREAD_LOCAL + +#ifndef DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES +#define DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES 32 +#endif + +#ifndef DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE +#define DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE 64 +#endif + +#ifdef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS +#define DOCTEST_OPTIONS_PREFIX_DISPLAY DOCTEST_CONFIG_OPTIONS_PREFIX +#else +#define DOCTEST_OPTIONS_PREFIX_DISPLAY "" +#endif + +#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +#define DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS +#endif + +#ifndef DOCTEST_CDECL +#define DOCTEST_CDECL __cdecl +#endif + +namespace doctest { + +bool is_running_in_test = false; + +namespace { + using namespace detail; + + template + DOCTEST_NORETURN void throw_exception(Ex const& e) { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + throw e; +#else // DOCTEST_CONFIG_NO_EXCEPTIONS +#ifdef DOCTEST_CONFIG_HANDLE_EXCEPTION + DOCTEST_CONFIG_HANDLE_EXCEPTION(e); +#else // DOCTEST_CONFIG_HANDLE_EXCEPTION +#ifndef DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + std::cerr << "doctest will terminate because it needed to throw an exception.\n" + << "The message was: " << e.what() << '\n'; +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM +#endif // DOCTEST_CONFIG_HANDLE_EXCEPTION + std::terminate(); +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } + +#ifndef DOCTEST_INTERNAL_ERROR +#define DOCTEST_INTERNAL_ERROR(msg) \ + throw_exception(std::logic_error( \ + __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) +#endif // DOCTEST_INTERNAL_ERROR + + // case insensitive strcmp + int stricmp(const char* a, const char* b) { + for(;; a++, b++) { + const int d = tolower(*a) - tolower(*b); + if(d != 0 || !*a) + return d; + } + } + + struct Endianness + { + enum Arch + { + Big, + Little + }; + + static Arch which() { + int x = 1; + // casting any data pointer to char* is allowed + auto ptr = reinterpret_cast(&x); + if(*ptr) + return Little; + return Big; + } + }; +} // namespace + +namespace detail { + DOCTEST_THREAD_LOCAL class + { + std::vector stack; + std::stringstream ss; + + public: + std::ostream* push() { + stack.push_back(ss.tellp()); + return &ss; + } + + String pop() { + if (stack.empty()) + DOCTEST_INTERNAL_ERROR("TLSS was empty when trying to pop!"); + + std::streampos pos = stack.back(); + stack.pop_back(); + unsigned sz = static_cast(ss.tellp() - pos); + ss.rdbuf()->pubseekpos(pos, std::ios::in | std::ios::out); + return String(ss, sz); + } + } g_oss; + + std::ostream* tlssPush() { + return g_oss.push(); + } + + String tlssPop() { + return g_oss.pop(); + } + +#ifndef DOCTEST_CONFIG_DISABLE + +namespace timer_large_integer +{ + +#if defined(DOCTEST_PLATFORM_WINDOWS) + using type = ULONGLONG; +#else // DOCTEST_PLATFORM_WINDOWS + using type = std::uint64_t; +#endif // DOCTEST_PLATFORM_WINDOWS +} + +using ticks_t = timer_large_integer::type; + +#ifdef DOCTEST_CONFIG_GETCURRENTTICKS + ticks_t getCurrentTicks() { return DOCTEST_CONFIG_GETCURRENTTICKS(); } +#elif defined(DOCTEST_PLATFORM_WINDOWS) + ticks_t getCurrentTicks() { + static LARGE_INTEGER hz = { {0} }, hzo = { {0} }; + if(!hz.QuadPart) { + QueryPerformanceFrequency(&hz); + QueryPerformanceCounter(&hzo); + } + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart - hzo.QuadPart) * LONGLONG(1000000)) / hz.QuadPart; + } +#else // DOCTEST_PLATFORM_WINDOWS + ticks_t getCurrentTicks() { + timeval t; + gettimeofday(&t, nullptr); + return static_cast(t.tv_sec) * 1000000 + static_cast(t.tv_usec); + } +#endif // DOCTEST_PLATFORM_WINDOWS + + struct Timer + { + void start() { m_ticks = getCurrentTicks(); } + unsigned int getElapsedMicroseconds() const { + return static_cast(getCurrentTicks() - m_ticks); + } + //unsigned int getElapsedMilliseconds() const { + // return static_cast(getElapsedMicroseconds() / 1000); + //} + double getElapsedSeconds() const { return static_cast(getCurrentTicks() - m_ticks) / 1000000.0; } + + private: + ticks_t m_ticks = 0; + }; + +#ifdef DOCTEST_CONFIG_NO_MULTITHREADING + template + using Atomic = T; +#else // DOCTEST_CONFIG_NO_MULTITHREADING + template + using Atomic = std::atomic; +#endif // DOCTEST_CONFIG_NO_MULTITHREADING + +#if defined(DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS) || defined(DOCTEST_CONFIG_NO_MULTITHREADING) + template + using MultiLaneAtomic = Atomic; +#else // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS + // Provides a multilane implementation of an atomic variable that supports add, sub, load, + // store. Instead of using a single atomic variable, this splits up into multiple ones, + // each sitting on a separate cache line. The goal is to provide a speedup when most + // operations are modifying. It achieves this with two properties: + // + // * Multiple atomics are used, so chance of congestion from the same atomic is reduced. + // * Each atomic sits on a separate cache line, so false sharing is reduced. + // + // The disadvantage is that there is a small overhead due to the use of TLS, and load/store + // is slower because all atomics have to be accessed. + template + class MultiLaneAtomic + { + struct CacheLineAlignedAtomic + { + Atomic atomic{}; + char padding[DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE - sizeof(Atomic)]; + }; + CacheLineAlignedAtomic m_atomics[DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES]; + + static_assert(sizeof(CacheLineAlignedAtomic) == DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE, + "guarantee one atomic takes exactly one cache line"); + + public: + T operator++() DOCTEST_NOEXCEPT { return fetch_add(1) + 1; } + + T operator++(int) DOCTEST_NOEXCEPT { return fetch_add(1); } + + T fetch_add(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + return myAtomic().fetch_add(arg, order); + } + + T fetch_sub(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + return myAtomic().fetch_sub(arg, order); + } + + operator T() const DOCTEST_NOEXCEPT { return load(); } + + T load(std::memory_order order = std::memory_order_seq_cst) const DOCTEST_NOEXCEPT { + auto result = T(); + for(auto const& c : m_atomics) { + result += c.atomic.load(order); + } + return result; + } + + T operator=(T desired) DOCTEST_NOEXCEPT { // lgtm [cpp/assignment-does-not-return-this] + store(desired); + return desired; + } + + void store(T desired, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + // first value becomes desired", all others become 0. + for(auto& c : m_atomics) { + c.atomic.store(desired, order); + desired = {}; + } + } + + private: + // Each thread has a different atomic that it operates on. If more than NumLanes threads + // use this, some will use the same atomic. So performance will degrade a bit, but still + // everything will work. + // + // The logic here is a bit tricky. The call should be as fast as possible, so that there + // is minimal to no overhead in determining the correct atomic for the current thread. + // + // 1. A global static counter laneCounter counts continuously up. + // 2. Each successive thread will use modulo operation of that counter so it gets an atomic + // assigned in a round-robin fashion. + // 3. This tlsLaneIdx is stored in the thread local data, so it is directly available with + // little overhead. + Atomic& myAtomic() DOCTEST_NOEXCEPT { + static Atomic laneCounter; + DOCTEST_THREAD_LOCAL size_t tlsLaneIdx = + laneCounter++ % DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES; + + return m_atomics[tlsLaneIdx].atomic; + } + }; +#endif // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS + + // this holds both parameters from the command line and runtime data for tests + struct ContextState : ContextOptions, TestRunStats, CurrentTestCaseStats + { + MultiLaneAtomic numAssertsCurrentTest_atomic; + MultiLaneAtomic numAssertsFailedCurrentTest_atomic; + + std::vector> filters = decltype(filters)(9); // 9 different filters + + std::vector reporters_currently_used; + + assert_handler ah = nullptr; + + Timer timer; + + std::vector stringifiedContexts; // logging from INFO() due to an exception + + // stuff for subcases + bool reachedLeaf; + std::vector subcaseStack; + std::vector nextSubcaseStack; + std::unordered_set fullyTraversedSubcases; + size_t currentSubcaseDepth; + Atomic shouldLogCurrentException; + + void resetRunData() { + numTestCases = 0; + numTestCasesPassingFilters = 0; + numTestSuitesPassingFilters = 0; + numTestCasesFailed = 0; + numAsserts = 0; + numAssertsFailed = 0; + numAssertsCurrentTest = 0; + numAssertsFailedCurrentTest = 0; + } + + void finalizeTestCaseData() { + seconds = timer.getElapsedSeconds(); + + // update the non-atomic counters + numAsserts += numAssertsCurrentTest_atomic; + numAssertsFailed += numAssertsFailedCurrentTest_atomic; + numAssertsCurrentTest = numAssertsCurrentTest_atomic; + numAssertsFailedCurrentTest = numAssertsFailedCurrentTest_atomic; + + if(numAssertsFailedCurrentTest) + failure_flags |= TestCaseFailureReason::AssertFailure; + + if(Approx(currentTest->m_timeout).epsilon(DBL_EPSILON) != 0 && + Approx(seconds).epsilon(DBL_EPSILON) > currentTest->m_timeout) + failure_flags |= TestCaseFailureReason::Timeout; + + if(currentTest->m_should_fail) { + if(failure_flags) { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedAndDid; + } else { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedButDidnt; + } + } else if(failure_flags && currentTest->m_may_fail) { + failure_flags |= TestCaseFailureReason::CouldHaveFailedAndDid; + } else if(currentTest->m_expected_failures > 0) { + if(numAssertsFailedCurrentTest == currentTest->m_expected_failures) { + failure_flags |= TestCaseFailureReason::FailedExactlyNumTimes; + } else { + failure_flags |= TestCaseFailureReason::DidntFailExactlyNumTimes; + } + } + + bool ok_to_fail = (TestCaseFailureReason::ShouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::CouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::FailedExactlyNumTimes & failure_flags); + + // if any subcase has failed - the whole test case has failed + testCaseSuccess = !(failure_flags && !ok_to_fail); + if(!testCaseSuccess) + numTestCasesFailed++; + } + }; + + ContextState* g_cs = nullptr; + + // used to avoid locks for the debug output + // TODO: figure out if this is indeed necessary/correct - seems like either there still + // could be a race or that there wouldn't be a race even if using the context directly + DOCTEST_THREAD_LOCAL bool g_no_colors; + +#endif // DOCTEST_CONFIG_DISABLE +} // namespace detail + +char* String::allocate(size_type sz) { + if (sz <= last) { + buf[sz] = '\0'; + setLast(last - sz); + return buf; + } else { + setOnHeap(); + data.size = sz; + data.capacity = data.size + 1; + data.ptr = new char[data.capacity]; + data.ptr[sz] = '\0'; + return data.ptr; + } +} + +void String::setOnHeap() noexcept { *reinterpret_cast(&buf[last]) = 128; } +void String::setLast(size_type in) noexcept { buf[last] = char(in); } +void String::setSize(size_type sz) noexcept { + if (isOnStack()) { buf[sz] = '\0'; setLast(last - sz); } + else { data.ptr[sz] = '\0'; data.size = sz; } +} + +void String::copy(const String& other) { + if(other.isOnStack()) { + memcpy(buf, other.buf, len); + } else { + memcpy(allocate(other.data.size), other.data.ptr, other.data.size); + } +} + +String::String() noexcept { + buf[0] = '\0'; + setLast(); +} + +String::~String() { + if(!isOnStack()) + delete[] data.ptr; +} // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + +String::String(const char* in) + : String(in, strlen(in)) {} + +String::String(const char* in, size_type in_size) { + memcpy(allocate(in_size), in, in_size); +} + +String::String(std::istream& in, size_type in_size) { + in.read(allocate(in_size), in_size); +} + +String::String(const String& other) { copy(other); } + +String& String::operator=(const String& other) { + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + + copy(other); + } + + return *this; +} + +String& String::operator+=(const String& other) { + const size_type my_old_size = size(); + const size_type other_size = other.size(); + const size_type total_size = my_old_size + other_size; + if(isOnStack()) { + if(total_size < len) { + // append to the current stack space + memcpy(buf + my_old_size, other.c_str(), other_size + 1); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + setLast(last - total_size); + } else { + // alloc new chunk + char* temp = new char[total_size + 1]; + // copy current data to new location before writing in the union + memcpy(temp, buf, my_old_size); // skip the +1 ('\0') for speed + // update data in union + setOnHeap(); + data.size = total_size; + data.capacity = data.size + 1; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } else { + if(data.capacity > total_size) { + // append to the current heap block + data.size = total_size; + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } else { + // resize + data.capacity *= 2; + if(data.capacity <= total_size) + data.capacity = total_size + 1; + // alloc new chunk + char* temp = new char[data.capacity]; + // copy current data to new location before releasing it + memcpy(temp, data.ptr, my_old_size); // skip the +1 ('\0') for speed + // release old chunk + delete[] data.ptr; + // update the rest of the union members + data.size = total_size; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } + + return *this; +} + +String::String(String&& other) noexcept { + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); +} + +String& String::operator=(String&& other) noexcept { + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); + } + return *this; +} + +char String::operator[](size_type i) const { + return const_cast(this)->operator[](i); +} + +char& String::operator[](size_type i) { + if(isOnStack()) + return reinterpret_cast(buf)[i]; + return data.ptr[i]; +} + +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmaybe-uninitialized") +String::size_type String::size() const { + if(isOnStack()) + return last - (size_type(buf[last]) & 31); // using "last" would work only if "len" is 32 + return data.size; +} +DOCTEST_GCC_SUPPRESS_WARNING_POP + +String::size_type String::capacity() const { + if(isOnStack()) + return len; + return data.capacity; +} + +String String::substr(size_type pos, size_type cnt) && { + cnt = std::min(cnt, size() - 1 - pos); + char* cptr = c_str(); + memmove(cptr, cptr + pos, cnt); + setSize(cnt); + return std::move(*this); +} + +String String::substr(size_type pos, size_type cnt) const & { + cnt = std::min(cnt, size() - 1 - pos); + return String{ c_str() + pos, cnt }; +} + +String::size_type String::find(char ch, size_type pos) const { + const char* begin = c_str(); + const char* end = begin + size(); + const char* it = begin + pos; + for (; it < end && *it != ch; it++); + if (it < end) { return static_cast(it - begin); } + else { return npos; } +} + +String::size_type String::rfind(char ch, size_type pos) const { + const char* begin = c_str(); + const char* it = begin + std::min(pos, size() - 1); + for (; it >= begin && *it != ch; it--); + if (it >= begin) { return static_cast(it - begin); } + else { return npos; } +} + +int String::compare(const char* other, bool no_case) const { + if(no_case) + return doctest::stricmp(c_str(), other); + return std::strcmp(c_str(), other); +} + +int String::compare(const String& other, bool no_case) const { + return compare(other.c_str(), no_case); +} + +String operator+(const String& lhs, const String& rhs) { return String(lhs) += rhs; } + +bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } +bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } +bool operator< (const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } +bool operator> (const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } +bool operator<=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) < 0 : true; } +bool operator>=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) > 0 : true; } + +std::ostream& operator<<(std::ostream& s, const String& in) { return s << in.c_str(); } + +Contains::Contains(const String& str) : string(str) { } + +bool Contains::checkWith(const String& other) const { + return strstr(other.c_str(), string.c_str()) != nullptr; +} + +String toString(const Contains& in) { + return "Contains( " + in.string + " )"; +} + +bool operator==(const String& lhs, const Contains& rhs) { return rhs.checkWith(lhs); } +bool operator==(const Contains& lhs, const String& rhs) { return lhs.checkWith(rhs); } +bool operator!=(const String& lhs, const Contains& rhs) { return !rhs.checkWith(lhs); } +bool operator!=(const Contains& lhs, const String& rhs) { return !lhs.checkWith(rhs); } + +namespace { + void color_to_stream(std::ostream&, Color::Enum) DOCTEST_BRANCH_ON_DISABLED({}, ;) +} // namespace + +namespace Color { + std::ostream& operator<<(std::ostream& s, Color::Enum code) { + color_to_stream(s, code); + return s; + } +} // namespace Color + +// clang-format off +const char* assertString(assertType::Enum at) { + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4061) // enum 'x' in switch of enum 'y' is not explicitly handled + #define DOCTEST_GENERATE_ASSERT_TYPE_CASE(assert_type) case assertType::DT_ ## assert_type: return #assert_type + #define DOCTEST_GENERATE_ASSERT_TYPE_CASES(assert_type) \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN_ ## assert_type); \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK_ ## assert_type); \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE_ ## assert_type) + switch(at) { + DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN); + DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK); + DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(FALSE); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_AS); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH_AS); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(NOTHROW); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(EQ); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(NE); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(GT); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(LT); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(GE); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(LE); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY_FALSE); + + default: DOCTEST_INTERNAL_ERROR("Tried stringifying invalid assert type!"); + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP +} +// clang-format on + +const char* failureString(assertType::Enum at) { + if(at & assertType::is_warn) //!OCLINT bitwise operator in conditional + return "WARNING"; + if(at & assertType::is_check) //!OCLINT bitwise operator in conditional + return "ERROR"; + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return "FATAL ERROR"; + return ""; +} + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +// depending on the current options this will remove the path of filenames +const char* skipPathFromFilename(const char* file) { +#ifndef DOCTEST_CONFIG_DISABLE + if(getContextOptions()->no_path_in_filenames) { + auto back = std::strrchr(file, '\\'); + auto forward = std::strrchr(file, '/'); + if(back || forward) { + if(back > forward) + forward = back; + return forward + 1; + } + } +#endif // DOCTEST_CONFIG_DISABLE + return file; +} +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +bool SubcaseSignature::operator==(const SubcaseSignature& other) const { + return m_line == other.m_line + && std::strcmp(m_file, other.m_file) == 0 + && m_name == other.m_name; +} + +bool SubcaseSignature::operator<(const SubcaseSignature& other) const { + if(m_line != other.m_line) + return m_line < other.m_line; + if(std::strcmp(m_file, other.m_file) != 0) + return std::strcmp(m_file, other.m_file) < 0; + return m_name.compare(other.m_name) < 0; +} + +DOCTEST_DEFINE_INTERFACE(IContextScope) + +namespace detail { + void filldata::fill(std::ostream* stream, const void* in) { + if (in) { *stream << in; } + else { *stream << "nullptr"; } + } + + template + String toStreamLit(T t) { + std::ostream* os = tlssPush(); + os->operator<<(t); + return tlssPop(); + } +} + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +String toString(const std::string& in) { return in.c_str(); } +#endif // VS 2019 + +String toString(String in) { return in; } + +String toString(std::nullptr_t) { return "nullptr"; } + +String toString(bool in) { return in ? "true" : "false"; } + +String toString(float in) { return toStreamLit(in); } +String toString(double in) { return toStreamLit(in); } +String toString(double long in) { return toStreamLit(in); } + +String toString(char in) { return toStreamLit(static_cast(in)); } +String toString(char signed in) { return toStreamLit(static_cast(in)); } +String toString(char unsigned in) { return toStreamLit(static_cast(in)); } +String toString(short in) { return toStreamLit(in); } +String toString(short unsigned in) { return toStreamLit(in); } +String toString(signed in) { return toStreamLit(in); } +String toString(unsigned in) { return toStreamLit(in); } +String toString(long in) { return toStreamLit(in); } +String toString(long unsigned in) { return toStreamLit(in); } +String toString(long long in) { return toStreamLit(in); } +String toString(long long unsigned in) { return toStreamLit(in); } + +Approx::Approx(double value) + : m_epsilon(static_cast(std::numeric_limits::epsilon()) * 100) + , m_scale(1.0) + , m_value(value) {} + +Approx Approx::operator()(double value) const { + Approx approx(value); + approx.epsilon(m_epsilon); + approx.scale(m_scale); + return approx; +} + +Approx& Approx::epsilon(double newEpsilon) { + m_epsilon = newEpsilon; + return *this; +} +Approx& Approx::scale(double newScale) { + m_scale = newScale; + return *this; +} + +bool operator==(double lhs, const Approx& rhs) { + // Thanks to Richard Harris for his help refining this formula + return std::fabs(lhs - rhs.m_value) < + rhs.m_epsilon * (rhs.m_scale + std::max(std::fabs(lhs), std::fabs(rhs.m_value))); +} +bool operator==(const Approx& lhs, double rhs) { return operator==(rhs, lhs); } +bool operator!=(double lhs, const Approx& rhs) { return !operator==(lhs, rhs); } +bool operator!=(const Approx& lhs, double rhs) { return !operator==(rhs, lhs); } +bool operator<=(double lhs, const Approx& rhs) { return lhs < rhs.m_value || lhs == rhs; } +bool operator<=(const Approx& lhs, double rhs) { return lhs.m_value < rhs || lhs == rhs; } +bool operator>=(double lhs, const Approx& rhs) { return lhs > rhs.m_value || lhs == rhs; } +bool operator>=(const Approx& lhs, double rhs) { return lhs.m_value > rhs || lhs == rhs; } +bool operator<(double lhs, const Approx& rhs) { return lhs < rhs.m_value && lhs != rhs; } +bool operator<(const Approx& lhs, double rhs) { return lhs.m_value < rhs && lhs != rhs; } +bool operator>(double lhs, const Approx& rhs) { return lhs > rhs.m_value && lhs != rhs; } +bool operator>(const Approx& lhs, double rhs) { return lhs.m_value > rhs && lhs != rhs; } + +String toString(const Approx& in) { + return "Approx( " + doctest::toString(in.m_value) + " )"; +} +const ContextOptions* getContextOptions() { return DOCTEST_BRANCH_ON_DISABLED(nullptr, g_cs); } + +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4738) +template +IsNaN::operator bool() const { + return std::isnan(value) ^ flipped; +} +DOCTEST_MSVC_SUPPRESS_WARNING_POP +template struct DOCTEST_INTERFACE_DEF IsNaN; +template struct DOCTEST_INTERFACE_DEF IsNaN; +template struct DOCTEST_INTERFACE_DEF IsNaN; +template +String toString(IsNaN in) { return String(in.flipped ? "! " : "") + "IsNaN( " + doctest::toString(in.value) + " )"; } +String toString(IsNaN in) { return toString(in); } +String toString(IsNaN in) { return toString(in); } +String toString(IsNaN in) { return toString(in); } + +} // namespace doctest + +#ifdef DOCTEST_CONFIG_DISABLE +namespace doctest { +Context::Context(int, const char* const*) {} +Context::~Context() = default; +void Context::applyCommandLine(int, const char* const*) {} +void Context::addFilter(const char*, const char*) {} +void Context::clearFilters() {} +void Context::setOption(const char*, bool) {} +void Context::setOption(const char*, int) {} +void Context::setOption(const char*, const char*) {} +bool Context::shouldExit() { return false; } +void Context::setAsDefaultForAssertsOutOfTestCases() {} +void Context::setAssertHandler(detail::assert_handler) {} +void Context::setCout(std::ostream*) {} +int Context::run() { return 0; } + +int IReporter::get_num_active_contexts() { return 0; } +const IContextScope* const* IReporter::get_active_contexts() { return nullptr; } +int IReporter::get_num_stringified_contexts() { return 0; } +const String* IReporter::get_stringified_contexts() { return nullptr; } + +int registerReporter(const char*, int, IReporter*) { return 0; } + +} // namespace doctest +#else // DOCTEST_CONFIG_DISABLE + +#if !defined(DOCTEST_CONFIG_COLORS_NONE) +#if !defined(DOCTEST_CONFIG_COLORS_WINDOWS) && !defined(DOCTEST_CONFIG_COLORS_ANSI) +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_CONFIG_COLORS_WINDOWS +#else // linux +#define DOCTEST_CONFIG_COLORS_ANSI +#endif // platform +#endif // DOCTEST_CONFIG_COLORS_WINDOWS && DOCTEST_CONFIG_COLORS_ANSI +#endif // DOCTEST_CONFIG_COLORS_NONE + +namespace doctest_detail_test_suite_ns { +// holds the current test suite +doctest::detail::TestSuite& getCurrentTestSuite() { + static doctest::detail::TestSuite data{}; + return data; +} +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +namespace { + // the int (priority) is part of the key for automatic sorting - sadly one can register a + // reporter with a duplicate name and a different priority but hopefully that won't happen often :| + using reporterMap = std::map, reporterCreatorFunc>; + + reporterMap& getReporters() { + static reporterMap data; + return data; + } + reporterMap& getListeners() { + static reporterMap data; + return data; + } +} // namespace +namespace detail { +#define DOCTEST_ITERATE_THROUGH_REPORTERS(function, ...) \ + for(auto& curr_rep : g_cs->reporters_currently_used) \ + curr_rep->function(__VA_ARGS__) + + bool checkIfShouldThrow(assertType::Enum at) { + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return true; + + if((at & assertType::is_check) //!OCLINT bitwise operator in conditional + && getContextOptions()->abort_after > 0 && + (g_cs->numAssertsFailed + g_cs->numAssertsFailedCurrentTest_atomic) >= + getContextOptions()->abort_after) + return true; + + return false; + } + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN void throwException() { + g_cs->shouldLogCurrentException = false; + throw TestFailureException(); // NOLINT(hicpp-exception-baseclass) + } +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + void throwException() {} +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +} // namespace detail + +namespace { + using namespace detail; + // matching of a string against a wildcard mask (case sensitivity configurable) taken from + // https://www.codeproject.com/Articles/1088/Wildcard-string-compare-globbing + int wildcmp(const char* str, const char* wild, bool caseSensitive) { + const char* cp = str; + const char* mp = wild; + + while((*str) && (*wild != '*')) { + if((caseSensitive ? (*wild != *str) : (tolower(*wild) != tolower(*str))) && + (*wild != '?')) { + return 0; + } + wild++; + str++; + } + + while(*str) { + if(*wild == '*') { + if(!*++wild) { + return 1; + } + mp = wild; + cp = str + 1; + } else if((caseSensitive ? (*wild == *str) : (tolower(*wild) == tolower(*str))) || + (*wild == '?')) { + wild++; + str++; + } else { + wild = mp; //!OCLINT parameter reassignment + str = cp++; //!OCLINT parameter reassignment + } + } + + while(*wild == '*') { + wild++; + } + return !*wild; + } + + // checks if the name matches any of the filters (and can be configured what to do when empty) + bool matchesAny(const char* name, const std::vector& filters, bool matchEmpty, + bool caseSensitive) { + if (filters.empty() && matchEmpty) + return true; + for (auto& curr : filters) + if (wildcmp(name, curr.c_str(), caseSensitive)) + return true; + return false; + } + + DOCTEST_NO_SANITIZE_INTEGER + unsigned long long hash(unsigned long long a, unsigned long long b) { + return (a << 5) + b; + } + + // C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html + DOCTEST_NO_SANITIZE_INTEGER + unsigned long long hash(const char* str) { + unsigned long long hash = 5381; + char c; + while ((c = *str++)) + hash = ((hash << 5) + hash) + c; // hash * 33 + c + return hash; + } + + unsigned long long hash(const SubcaseSignature& sig) { + return hash(hash(hash(sig.m_file), hash(sig.m_name.c_str())), sig.m_line); + } + + unsigned long long hash(const std::vector& sigs, size_t count) { + unsigned long long running = 0; + auto end = sigs.begin() + count; + for (auto it = sigs.begin(); it != end; it++) { + running = hash(running, hash(*it)); + } + return running; + } + + unsigned long long hash(const std::vector& sigs) { + unsigned long long running = 0; + for (const SubcaseSignature& sig : sigs) { + running = hash(running, hash(sig)); + } + return running; + } +} // namespace +namespace detail { + bool Subcase::checkFilters() { + if (g_cs->subcaseStack.size() < size_t(g_cs->subcase_filter_levels)) { + if (!matchesAny(m_signature.m_name.c_str(), g_cs->filters[6], true, g_cs->case_sensitive)) + return true; + if (matchesAny(m_signature.m_name.c_str(), g_cs->filters[7], false, g_cs->case_sensitive)) + return true; + } + return false; + } + + Subcase::Subcase(const String& name, const char* file, int line) + : m_signature({name, file, line}) { + if (!g_cs->reachedLeaf) { + if (g_cs->nextSubcaseStack.size() <= g_cs->subcaseStack.size() + || g_cs->nextSubcaseStack[g_cs->subcaseStack.size()] == m_signature) { + // Going down. + if (checkFilters()) { return; } + + g_cs->subcaseStack.push_back(m_signature); + g_cs->currentSubcaseDepth++; + m_entered = true; + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } + } else { + if (g_cs->subcaseStack[g_cs->currentSubcaseDepth] == m_signature) { + // This subcase is reentered via control flow. + g_cs->currentSubcaseDepth++; + m_entered = true; + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } else if (g_cs->nextSubcaseStack.size() <= g_cs->currentSubcaseDepth + && g_cs->fullyTraversedSubcases.find(hash(hash(g_cs->subcaseStack, g_cs->currentSubcaseDepth), hash(m_signature))) + == g_cs->fullyTraversedSubcases.end()) { + if (checkFilters()) { return; } + // This subcase is part of the one to be executed next. + g_cs->nextSubcaseStack.clear(); + g_cs->nextSubcaseStack.insert(g_cs->nextSubcaseStack.end(), + g_cs->subcaseStack.begin(), g_cs->subcaseStack.begin() + g_cs->currentSubcaseDepth); + g_cs->nextSubcaseStack.push_back(m_signature); + } + } + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + Subcase::~Subcase() { + if (m_entered) { + g_cs->currentSubcaseDepth--; + + if (!g_cs->reachedLeaf) { + // Leaf. + g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); + g_cs->nextSubcaseStack.clear(); + g_cs->reachedLeaf = true; + } else if (g_cs->nextSubcaseStack.empty()) { + // All children are finished. + g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); + } + +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) + if(std::uncaught_exceptions() > 0 +#else + if(std::uncaught_exception() +#endif + && g_cs->shouldLogCurrentException) { + DOCTEST_ITERATE_THROUGH_REPORTERS( + test_case_exception, {"exception thrown in subcase - will translate later " + "when the whole test case has been exited (cannot " + "translate while there is an active exception)", + false}); + g_cs->shouldLogCurrentException = false; + } + + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + Subcase::operator bool() const { return m_entered; } + + Result::Result(bool passed, const String& decomposition) + : m_passed(passed) + , m_decomp(decomposition) {} + + ExpressionDecomposer::ExpressionDecomposer(assertType::Enum at) + : m_at(at) {} + + TestSuite& TestSuite::operator*(const char* in) { + m_test_suite = in; + return *this; + } + + TestCase::TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const String& type, int template_id) { + m_file = file; + m_line = line; + m_name = nullptr; // will be later overridden in operator* + m_test_suite = test_suite.m_test_suite; + m_description = test_suite.m_description; + m_skip = test_suite.m_skip; + m_no_breaks = test_suite.m_no_breaks; + m_no_output = test_suite.m_no_output; + m_may_fail = test_suite.m_may_fail; + m_should_fail = test_suite.m_should_fail; + m_expected_failures = test_suite.m_expected_failures; + m_timeout = test_suite.m_timeout; + + m_test = test; + m_type = type; + m_template_id = template_id; + } + + TestCase::TestCase(const TestCase& other) + : TestCaseData() { + *this = other; + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + TestCase& TestCase::operator=(const TestCase& other) { + TestCaseData::operator=(other); + m_test = other.m_test; + m_type = other.m_type; + m_template_id = other.m_template_id; + m_full_name = other.m_full_name; + + if(m_template_id != -1) + m_name = m_full_name.c_str(); + return *this; + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& TestCase::operator*(const char* in) { + m_name = in; + // make a new name with an appended type for templated test case + if(m_template_id != -1) { + m_full_name = String(m_name) + "<" + m_type + ">"; + // redirect the name to point to the newly constructed full name + m_name = m_full_name.c_str(); + } + return *this; + } + + bool TestCase::operator<(const TestCase& other) const { + // this will be used only to differentiate between test cases - not relevant for sorting + if(m_line != other.m_line) + return m_line < other.m_line; + const int name_cmp = strcmp(m_name, other.m_name); + if(name_cmp != 0) + return name_cmp < 0; + const int file_cmp = m_file.compare(other.m_file); + if(file_cmp != 0) + return file_cmp < 0; + return m_template_id < other.m_template_id; + } + + // all the registered tests + std::set& getRegisteredTests() { + static std::set data; + return data; + } +} // namespace detail +namespace { + using namespace detail; + // for sorting tests by file/line + bool fileOrderComparator(const TestCase* lhs, const TestCase* rhs) { + // this is needed because MSVC gives different case for drive letters + // for __FILE__ when evaluated in a header and a source file + const int res = lhs->m_file.compare(rhs->m_file, bool(DOCTEST_MSVC)); + if(res != 0) + return res < 0; + if(lhs->m_line != rhs->m_line) + return lhs->m_line < rhs->m_line; + return lhs->m_template_id < rhs->m_template_id; + } + + // for sorting tests by suite/file/line + bool suiteOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = std::strcmp(lhs->m_test_suite, rhs->m_test_suite); + if(res != 0) + return res < 0; + return fileOrderComparator(lhs, rhs); + } + + // for sorting tests by name/suite/file/line + bool nameOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = std::strcmp(lhs->m_name, rhs->m_name); + if(res != 0) + return res < 0; + return suiteOrderComparator(lhs, rhs); + } + + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + void color_to_stream(std::ostream& s, Color::Enum code) { + static_cast(s); // for DOCTEST_CONFIG_COLORS_NONE or DOCTEST_CONFIG_COLORS_WINDOWS + static_cast(code); // for DOCTEST_CONFIG_COLORS_NONE +#ifdef DOCTEST_CONFIG_COLORS_ANSI + if(g_no_colors || + (isatty(STDOUT_FILENO) == false && getContextOptions()->force_colors == false)) + return; + + auto col = ""; + // clang-format off + switch(code) { //!OCLINT missing break in switch statement / unnecessary default statement in covered switch statement + case Color::Red: col = "[0;31m"; break; + case Color::Green: col = "[0;32m"; break; + case Color::Blue: col = "[0;34m"; break; + case Color::Cyan: col = "[0;36m"; break; + case Color::Yellow: col = "[0;33m"; break; + case Color::Grey: col = "[1;30m"; break; + case Color::LightGrey: col = "[0;37m"; break; + case Color::BrightRed: col = "[1;31m"; break; + case Color::BrightGreen: col = "[1;32m"; break; + case Color::BrightWhite: col = "[1;37m"; break; + case Color::Bright: // invalid + case Color::None: + case Color::White: + default: col = "[0m"; + } + // clang-format on + s << "\033" << col; +#endif // DOCTEST_CONFIG_COLORS_ANSI + +#ifdef DOCTEST_CONFIG_COLORS_WINDOWS + if(g_no_colors || + (_isatty(_fileno(stdout)) == false && getContextOptions()->force_colors == false)) + return; + + static struct ConsoleHelper { + HANDLE stdoutHandle; + WORD origFgAttrs; + WORD origBgAttrs; + + ConsoleHelper() { + stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo(stdoutHandle, &csbiInfo); + origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | + BACKGROUND_BLUE | BACKGROUND_INTENSITY); + origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | + FOREGROUND_BLUE | FOREGROUND_INTENSITY); + } + } ch; + +#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(ch.stdoutHandle, x | ch.origBgAttrs) + + // clang-format off + switch (code) { + case Color::White: DOCTEST_SET_ATTR(FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::Red: DOCTEST_SET_ATTR(FOREGROUND_RED); break; + case Color::Green: DOCTEST_SET_ATTR(FOREGROUND_GREEN); break; + case Color::Blue: DOCTEST_SET_ATTR(FOREGROUND_BLUE); break; + case Color::Cyan: DOCTEST_SET_ATTR(FOREGROUND_BLUE | FOREGROUND_GREEN); break; + case Color::Yellow: DOCTEST_SET_ATTR(FOREGROUND_RED | FOREGROUND_GREEN); break; + case Color::Grey: DOCTEST_SET_ATTR(0); break; + case Color::LightGrey: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY); break; + case Color::BrightRed: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_RED); break; + case Color::BrightGreen: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN); break; + case Color::BrightWhite: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::None: + case Color::Bright: // invalid + default: DOCTEST_SET_ATTR(ch.origFgAttrs); + } + // clang-format on +#endif // DOCTEST_CONFIG_COLORS_WINDOWS + } + DOCTEST_CLANG_SUPPRESS_WARNING_POP + + std::vector& getExceptionTranslators() { + static std::vector data; + return data; + } + + String translateActiveException() { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + String res; + auto& translators = getExceptionTranslators(); + for(auto& curr : translators) + if(curr->translate(res)) + return res; + // clang-format off + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wcatch-value") + try { + throw; + } catch(std::exception& ex) { + return ex.what(); + } catch(std::string& msg) { + return msg.c_str(); + } catch(const char* msg) { + return msg; + } catch(...) { + return "unknown exception"; + } + DOCTEST_GCC_SUPPRESS_WARNING_POP +// clang-format on +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + return ""; +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } +} // namespace + +namespace detail { + // used by the macros for registering tests + int regTest(const TestCase& tc) { + getRegisteredTests().insert(tc); + return 0; + } + + // sets the current test suite + int setTestSuite(const TestSuite& ts) { + doctest_detail_test_suite_ns::getCurrentTestSuite() = ts; + return 0; + } + +#ifdef DOCTEST_IS_DEBUGGER_ACTIVE + bool isDebuggerActive() { return DOCTEST_IS_DEBUGGER_ACTIVE(); } +#else // DOCTEST_IS_DEBUGGER_ACTIVE +#ifdef DOCTEST_PLATFORM_LINUX + class ErrnoGuard { + public: + ErrnoGuard() : m_oldErrno(errno) {} + ~ErrnoGuard() { errno = m_oldErrno; } + private: + int m_oldErrno; + }; + // See the comments in Catch2 for the reasoning behind this implementation: + // https://github.com/catchorg/Catch2/blob/v2.13.1/include/internal/catch_debugger.cpp#L79-L102 + bool isDebuggerActive() { + ErrnoGuard guard; + std::ifstream in("/proc/self/status"); + for(std::string line; std::getline(in, line);) { + static const int PREFIX_LEN = 11; + if(line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0) { + return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; + } + } + return false; + } +#elif defined(DOCTEST_PLATFORM_MAC) + // The following function is taken directly from the following technical note: + // https://developer.apple.com/library/archive/qa/qa1361/_index.html + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive() { + int mib[4]; + kinfo_proc info; + size_t size; + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. + info.kp_proc.p_flag = 0; + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + // Call sysctl. + size = sizeof(info); + if(sysctl(mib, DOCTEST_COUNTOF(mib), &info, &size, 0, 0) != 0) { + std::cerr << "\nCall to sysctl failed - unable to determine if debugger is active **\n"; + return false; + } + // We're being debugged if the P_TRACED flag is set. + return ((info.kp_proc.p_flag & P_TRACED) != 0); + } +#elif DOCTEST_MSVC || defined(__MINGW32__) || defined(__MINGW64__) + bool isDebuggerActive() { return ::IsDebuggerPresent() != 0; } +#else + bool isDebuggerActive() { return false; } +#endif // Platform +#endif // DOCTEST_IS_DEBUGGER_ACTIVE + + void registerExceptionTranslatorImpl(const IExceptionTranslator* et) { + if(std::find(getExceptionTranslators().begin(), getExceptionTranslators().end(), et) == + getExceptionTranslators().end()) + getExceptionTranslators().push_back(et); + } + + DOCTEST_THREAD_LOCAL std::vector g_infoContexts; // for logging with INFO() + + ContextScopeBase::ContextScopeBase() { + g_infoContexts.push_back(this); + } + + ContextScopeBase::ContextScopeBase(ContextScopeBase&& other) noexcept { + if (other.need_to_destroy) { + other.destroy(); + } + other.need_to_destroy = false; + g_infoContexts.push_back(this); + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + // destroy cannot be inlined into the destructor because that would mean calling stringify after + // ContextScope has been destroyed (base class destructors run after derived class destructors). + // Instead, ContextScope calls this method directly from its destructor. + void ContextScopeBase::destroy() { +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) + if(std::uncaught_exceptions() > 0) { +#else + if(std::uncaught_exception()) { +#endif + std::ostringstream s; + this->stringify(&s); + g_cs->stringifiedContexts.push_back(s.str().c_str()); + } + g_infoContexts.pop_back(); + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP +} // namespace detail +namespace { + using namespace detail; + +#if !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && !defined(DOCTEST_CONFIG_WINDOWS_SEH) + struct FatalConditionHandler + { + static void reset() {} + static void allocateAltStackMem() {} + static void freeAltStackMem() {} + }; +#else // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + + void reportFatal(const std::string&); + +#ifdef DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + DWORD id; + const char* name; + }; + // There is no 1-1 mapping between signals and windows exceptions. + // Windows can easily distinguish between SO and SigSegV, + // but SigInt, SigTerm, etc are handled differently. + SignalDefs signalDefs[] = { + {static_cast(EXCEPTION_ILLEGAL_INSTRUCTION), + "SIGILL - Illegal instruction signal"}, + {static_cast(EXCEPTION_STACK_OVERFLOW), "SIGSEGV - Stack overflow"}, + {static_cast(EXCEPTION_ACCESS_VIOLATION), + "SIGSEGV - Segmentation violation signal"}, + {static_cast(EXCEPTION_INT_DIVIDE_BY_ZERO), "Divide by zero error"}, + }; + + struct FatalConditionHandler + { + static LONG CALLBACK handleException(PEXCEPTION_POINTERS ExceptionInfo) { + // Multiple threads may enter this filter/handler at once. We want the error message to be printed on the + // console just once no matter how many threads have crashed. + DOCTEST_DECLARE_STATIC_MUTEX(mutex) + static bool execute = true; + { + DOCTEST_LOCK_MUTEX(mutex) + if(execute) { + bool reported = false; + for(size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + if(ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) { + reportFatal(signalDefs[i].name); + reported = true; + break; + } + } + if(reported == false) + reportFatal("Unhandled SEH exception caught"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + } + execute = false; + } + std::exit(EXIT_FAILURE); + } + + static void allocateAltStackMem() {} + static void freeAltStackMem() {} + + FatalConditionHandler() { + isSet = true; + // 32k seems enough for doctest to handle stack overflow, + // but the value was found experimentally, so there is no strong guarantee + guaranteeSize = 32 * 1024; + // Register an unhandled exception filter + previousTop = SetUnhandledExceptionFilter(handleException); + // Pass in guarantee size to be filled + SetThreadStackGuarantee(&guaranteeSize); + + // On Windows uncaught exceptions from another thread, exceptions from + // destructors, or calls to std::terminate are not a SEH exception + + // The terminal handler gets called when: + // - std::terminate is called FROM THE TEST RUNNER THREAD + // - an exception is thrown from a destructor FROM THE TEST RUNNER THREAD + original_terminate_handler = std::get_terminate(); + std::set_terminate([]() DOCTEST_NOEXCEPT { + reportFatal("Terminate handler called"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + std::exit(EXIT_FAILURE); // explicitly exit - otherwise the SIGABRT handler may be called as well + }); + + // SIGABRT is raised when: + // - std::terminate is called FROM A DIFFERENT THREAD + // - an exception is thrown from a destructor FROM A DIFFERENT THREAD + // - an uncaught exception is thrown FROM A DIFFERENT THREAD + prev_sigabrt_handler = std::signal(SIGABRT, [](int signal) DOCTEST_NOEXCEPT { + if(signal == SIGABRT) { + reportFatal("SIGABRT - Abort (abnormal termination) signal"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + std::exit(EXIT_FAILURE); + } + }); + + // The following settings are taken from google test, and more + // specifically from UnitTest::Run() inside of gtest.cc + + // the user does not want to see pop-up dialogs about crashes + prev_error_mode_1 = SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | + SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); + // This forces the abort message to go to stderr in all circumstances. + prev_error_mode_2 = _set_error_mode(_OUT_TO_STDERR); + // In the debug version, Visual Studio pops up a separate dialog + // offering a choice to debug the aborted program - we want to disable that. + prev_abort_behavior = _set_abort_behavior(0x0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); + // In debug mode, the Windows CRT can crash with an assertion over invalid + // input (e.g. passing an invalid file descriptor). The default handling + // for these assertions is to pop up a dialog and wait for user input. + // Instead ask the CRT to dump such assertions to stderr non-interactively. + prev_report_mode = _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); + prev_report_file = _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); + } + + static void reset() { + if(isSet) { + // Unregister handler and restore the old guarantee + SetUnhandledExceptionFilter(previousTop); + SetThreadStackGuarantee(&guaranteeSize); + std::set_terminate(original_terminate_handler); + std::signal(SIGABRT, prev_sigabrt_handler); + SetErrorMode(prev_error_mode_1); + _set_error_mode(prev_error_mode_2); + _set_abort_behavior(prev_abort_behavior, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); + static_cast(_CrtSetReportMode(_CRT_ASSERT, prev_report_mode)); + static_cast(_CrtSetReportFile(_CRT_ASSERT, prev_report_file)); + isSet = false; + } + } + + ~FatalConditionHandler() { reset(); } + + private: + static UINT prev_error_mode_1; + static int prev_error_mode_2; + static unsigned int prev_abort_behavior; + static int prev_report_mode; + static _HFILE prev_report_file; + static void (DOCTEST_CDECL *prev_sigabrt_handler)(int); + static std::terminate_handler original_terminate_handler; + static bool isSet; + static ULONG guaranteeSize; + static LPTOP_LEVEL_EXCEPTION_FILTER previousTop; + }; + + UINT FatalConditionHandler::prev_error_mode_1; + int FatalConditionHandler::prev_error_mode_2; + unsigned int FatalConditionHandler::prev_abort_behavior; + int FatalConditionHandler::prev_report_mode; + _HFILE FatalConditionHandler::prev_report_file; + void (DOCTEST_CDECL *FatalConditionHandler::prev_sigabrt_handler)(int); + std::terminate_handler FatalConditionHandler::original_terminate_handler; + bool FatalConditionHandler::isSet = false; + ULONG FatalConditionHandler::guaranteeSize = 0; + LPTOP_LEVEL_EXCEPTION_FILTER FatalConditionHandler::previousTop = nullptr; + +#else // DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + int id; + const char* name; + }; + SignalDefs signalDefs[] = {{SIGINT, "SIGINT - Terminal interrupt signal"}, + {SIGILL, "SIGILL - Illegal instruction signal"}, + {SIGFPE, "SIGFPE - Floating point error signal"}, + {SIGSEGV, "SIGSEGV - Segmentation violation signal"}, + {SIGTERM, "SIGTERM - Termination request signal"}, + {SIGABRT, "SIGABRT - Abort (abnormal termination) signal"}}; + + struct FatalConditionHandler + { + static bool isSet; + static struct sigaction oldSigActions[DOCTEST_COUNTOF(signalDefs)]; + static stack_t oldSigStack; + static size_t altStackSize; + static char* altStackMem; + + static void handleSignal(int sig) { + const char* name = ""; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + SignalDefs& def = signalDefs[i]; + if(sig == def.id) { + name = def.name; + break; + } + } + reset(); + reportFatal(name); + raise(sig); + } + + static void allocateAltStackMem() { + altStackMem = new char[altStackSize]; + } + + static void freeAltStackMem() { + delete[] altStackMem; + } + + FatalConditionHandler() { + isSet = true; + stack_t sigStack; + sigStack.ss_sp = altStackMem; + sigStack.ss_size = altStackSize; + sigStack.ss_flags = 0; + sigaltstack(&sigStack, &oldSigStack); + struct sigaction sa = {}; + sa.sa_handler = handleSignal; + sa.sa_flags = SA_ONSTACK; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); + } + } + + ~FatalConditionHandler() { reset(); } + static void reset() { + if(isSet) { + // Set signals back to previous values -- hopefully nobody overwrote them in the meantime + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); + } + // Return the old stack + sigaltstack(&oldSigStack, nullptr); + isSet = false; + } + } + }; + + bool FatalConditionHandler::isSet = false; + struct sigaction FatalConditionHandler::oldSigActions[DOCTEST_COUNTOF(signalDefs)] = {}; + stack_t FatalConditionHandler::oldSigStack = {}; + size_t FatalConditionHandler::altStackSize = 4 * SIGSTKSZ; + char* FatalConditionHandler::altStackMem = nullptr; + +#endif // DOCTEST_PLATFORM_WINDOWS +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + +} // namespace + +namespace { + using namespace detail; + +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_OUTPUT_DEBUG_STRING(text) ::OutputDebugStringA(text) +#else + // TODO: integration with XCode and other IDEs +#define DOCTEST_OUTPUT_DEBUG_STRING(text) +#endif // Platform + + void addAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsCurrentTest_atomic++; + } + + void addFailedAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsFailedCurrentTest_atomic++; + } + +#if defined(DOCTEST_CONFIG_POSIX_SIGNALS) || defined(DOCTEST_CONFIG_WINDOWS_SEH) + void reportFatal(const std::string& message) { + g_cs->failure_flags |= TestCaseFailureReason::Crash; + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, {message.c_str(), true}); + + while (g_cs->subcaseStack.size()) { + g_cs->subcaseStack.pop_back(); + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + + g_cs->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH +} // namespace + +AssertData::AssertData(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const StringContains& exception_string) + : m_test_case(g_cs->currentTest), m_at(at), m_file(file), m_line(line), m_expr(expr), + m_failed(true), m_threw(false), m_threw_as(false), m_exception_type(exception_type), + m_exception_string(exception_string) { +#if DOCTEST_MSVC + if (m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC + ++m_expr; +#endif // MSVC +} + +namespace detail { + ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const String& exception_string) + : AssertData(at, file, line, expr, exception_type, exception_string) { } + + ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const Contains& exception_string) + : AssertData(at, file, line, expr, exception_type, exception_string) { } + + void ResultBuilder::setResult(const Result& res) { + m_decomp = res.m_decomp; + m_failed = !res.m_passed; + } + + void ResultBuilder::translateException() { + m_threw = true; + m_exception = translateActiveException(); + } + + bool ResultBuilder::log() { + if(m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw; + } else if((m_at & assertType::is_throws_as) && (m_at & assertType::is_throws_with)) { //!OCLINT + m_failed = !m_threw_as || !m_exception_string.check(m_exception); + } else if(m_at & assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw_as; + } else if(m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + m_failed = !m_exception_string.check(m_exception); + } else if(m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + m_failed = m_threw; + } + + if(m_exception.size()) + m_exception = "\"" + m_exception + "\""; + + if(is_running_in_test) { + addAssert(m_at); + DOCTEST_ITERATE_THROUGH_REPORTERS(log_assert, *this); + + if(m_failed) + addFailedAssert(m_at); + } else if(m_failed) { + failed_out_of_a_testing_context(*this); + } + + return m_failed && isDebuggerActive() && !getContextOptions()->no_breaks && + (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger + } + + void ResultBuilder::react() const { + if(m_failed && checkIfShouldThrow(m_at)) + throwException(); + } + + void failed_out_of_a_testing_context(const AssertData& ad) { + if(g_cs->ah) + g_cs->ah(ad); + else + std::abort(); + } + + bool decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, + const Result& result) { + bool failed = !result.m_passed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(result.m_decomp); + DOCTEST_ASSERT_IN_TESTS(result.m_decomp); + return !failed; + } + + MessageBuilder::MessageBuilder(const char* file, int line, assertType::Enum severity) { + m_stream = tlssPush(); + m_file = file; + m_line = line; + m_severity = severity; + } + + MessageBuilder::~MessageBuilder() { + if (!logged) + tlssPop(); + } + + DOCTEST_DEFINE_INTERFACE(IExceptionTranslator) + + bool MessageBuilder::log() { + if (!logged) { + m_string = tlssPop(); + logged = true; + } + + DOCTEST_ITERATE_THROUGH_REPORTERS(log_message, *this); + + const bool isWarn = m_severity & assertType::is_warn; + + // warn is just a message in this context so we don't treat it as an assert + if(!isWarn) { + addAssert(m_severity); + addFailedAssert(m_severity); + } + + return isDebuggerActive() && !getContextOptions()->no_breaks && !isWarn && + (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger + } + + void MessageBuilder::react() { + if(m_severity & assertType::is_require) //!OCLINT bitwise operator in conditional + throwException(); + } +} // namespace detail +namespace { + using namespace detail; + + // clang-format off + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. +// ================================================================================================= + + class XmlEncode { + public: + enum ForWhat { ForTextNodes, ForAttributes }; + + XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); + + void encodeTo( std::ostream& os ) const; + + friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); + + private: + std::string m_str; + ForWhat m_forWhat; + }; + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ); + + ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT; + ScopedElement& operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT; + + ~ScopedElement(); + + ScopedElement& writeText( std::string const& text, bool indent = true ); + + template + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer = nullptr; + }; + +#ifndef DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + XmlWriter( std::ostream& os = std::cout ); +#else // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + XmlWriter( std::ostream& os ); +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + ~XmlWriter(); + + XmlWriter( XmlWriter const& ) = delete; + XmlWriter& operator=( XmlWriter const& ) = delete; + + XmlWriter& startElement( std::string const& name ); + + ScopedElement scopedElement( std::string const& name ); + + XmlWriter& endElement(); + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); + + XmlWriter& writeAttribute( std::string const& name, const char* attribute ); + + XmlWriter& writeAttribute( std::string const& name, bool attribute ); + + template + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + std::stringstream rss; + rss << attribute; + return writeAttribute( name, rss.str() ); + } + + XmlWriter& writeText( std::string const& text, bool indent = true ); + + //XmlWriter& writeComment( std::string const& text ); + + //void writeStylesheetRef( std::string const& url ); + + //XmlWriter& writeBlankLine(); + + void ensureTagClosed(); + + void writeDeclaration(); + + private: + + void newlineIfNecessary(); + + bool m_tagIsOpen = false; + bool m_needsNewline = false; + std::vector m_tags; + std::string m_indent; + std::ostream& m_os; + }; + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. +// ================================================================================================= + +using uchar = unsigned char; + +namespace { + + size_t trailingBytes(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return 2; + } + if ((c & 0xF0) == 0xE0) { + return 3; + } + if ((c & 0xF8) == 0xF0) { + return 4; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + uint32_t headerValue(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return c & 0x1F; + } + if ((c & 0xF0) == 0xE0) { + return c & 0x0F; + } + if ((c & 0xF8) == 0xF0) { + return c & 0x07; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + void hexEscapeChar(std::ostream& os, unsigned char c) { + std::ios_base::fmtflags f(os.flags()); + os << "\\x" + << std::uppercase << std::hex << std::setfill('0') << std::setw(2) + << static_cast(c); + os.flags(f); + } + +} // anonymous namespace + + XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) + : m_str( str ), + m_forWhat( forWhat ) + {} + + void XmlEncode::encodeTo( std::ostream& os ) const { + // Apostrophe escaping not necessary if we always use " to write attributes + // (see: https://www.w3.org/TR/xml/#syntax) + + for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { + uchar c = m_str[idx]; + switch (c) { + case '<': os << "<"; break; + case '&': os << "&"; break; + + case '>': + // See: https://www.w3.org/TR/xml/#syntax + if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']') + os << ">"; + else + os << c; + break; + + case '\"': + if (m_forWhat == ForAttributes) + os << """; + else + os << c; + break; + + default: + // Check for control characters and invalid utf-8 + + // Escape control characters in standard ascii + // see https://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 + if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { + hexEscapeChar(os, c); + break; + } + + // Plain ASCII: Write it to stream + if (c < 0x7F) { + os << c; + break; + } + + // UTF-8 territory + // Check if the encoding is valid and if it is not, hex escape bytes. + // Important: We do not check the exact decoded values for validity, only the encoding format + // First check that this bytes is a valid lead byte: + // This means that it is not encoded as 1111 1XXX + // Or as 10XX XXXX + if (c < 0xC0 || + c >= 0xF8) { + hexEscapeChar(os, c); + break; + } + + auto encBytes = trailingBytes(c); + // Are there enough bytes left to avoid accessing out-of-bounds memory? + if (idx + encBytes - 1 >= m_str.size()) { + hexEscapeChar(os, c); + break; + } + // The header is valid, check data + // The next encBytes bytes must together be a valid utf-8 + // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) + bool valid = true; + uint32_t value = headerValue(c); + for (std::size_t n = 1; n < encBytes; ++n) { + uchar nc = m_str[idx + n]; + valid &= ((nc & 0xC0) == 0x80); + value = (value << 6) | (nc & 0x3F); + } + + if ( + // Wrong bit pattern of following bytes + (!valid) || + // Overlong encodings + (value < 0x80) || + ( value < 0x800 && encBytes > 2) || // removed "0x80 <= value &&" because redundant + (0x800 < value && value < 0x10000 && encBytes > 3) || + // Encoded value out of range + (value >= 0x110000) + ) { + hexEscapeChar(os, c); + break; + } + + // If we got here, this is in fact a valid(ish) utf-8 sequence + for (std::size_t n = 0; n < encBytes; ++n) { + os << m_str[idx + n]; + } + idx += encBytes - 1; + break; + } + } + } + + std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { + xmlEncode.encodeTo( os ); + return os; + } + + XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT + : m_writer( other.m_writer ){ + other.m_writer = nullptr; + } + XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT { + if ( m_writer ) { + m_writer->endElement(); + } + m_writer = other.m_writer; + other.m_writer = nullptr; + return *this; + } + + + XmlWriter::ScopedElement::~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { + m_writer->writeText( text, indent ); + return *this; + } + + XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) + { + // writeDeclaration(); // called explicitly by the reporters that use the writer class - see issue #627 + } + + XmlWriter::~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + + XmlWriter& XmlWriter::startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + m_os << m_indent << '<' << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& XmlWriter::endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + m_os << "/>"; + m_tagIsOpen = false; + } + else { + m_os << m_indent << ""; + } + m_os << std::endl; + m_tags.pop_back(); + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, const char* attribute ) { + if( !name.empty() && attribute && attribute[0] != '\0' ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { + m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + m_os << m_indent; + m_os << XmlEncode( text ); + m_needsNewline = true; + } + return *this; + } + + //XmlWriter& XmlWriter::writeComment( std::string const& text ) { + // ensureTagClosed(); + // m_os << m_indent << ""; + // m_needsNewline = true; + // return *this; + //} + + //void XmlWriter::writeStylesheetRef( std::string const& url ) { + // m_os << "\n"; + //} + + //XmlWriter& XmlWriter::writeBlankLine() { + // ensureTagClosed(); + // m_os << '\n'; + // return *this; + //} + + void XmlWriter::ensureTagClosed() { + if( m_tagIsOpen ) { + m_os << ">" << std::endl; + m_tagIsOpen = false; + } + } + + void XmlWriter::writeDeclaration() { + m_os << "\n"; + } + + void XmlWriter::newlineIfNecessary() { + if( m_needsNewline ) { + m_os << std::endl; + m_needsNewline = false; + } + } + +// ================================================================================================= +// End of copy-pasted code from Catch +// ================================================================================================= + + // clang-format on + + struct XmlReporter : public IReporter + { + XmlWriter xml; + DOCTEST_DECLARE_MUTEX(mutex) + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + XmlReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + std::stringstream ss; + for(int i = 0; i < num_contexts; ++i) { + contexts[i]->stringify(&ss); + xml.scopedElement("Info").writeText(ss.str()); + ss.str(""); + } + } + } + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } + + void test_case_start_impl(const TestCaseData& in) { + bool open_ts_tag = false; + if(tc != nullptr) { // we have already opened a test suite + if(std::strcmp(tc->m_test_suite, in.m_test_suite) != 0) { + xml.endElement(); + open_ts_tag = true; + } + } + else { + open_ts_tag = true; // first test case ==> first test suite + } + + if(open_ts_tag) { + xml.startElement("TestSuite"); + xml.writeAttribute("name", in.m_test_suite); + } + + tc = ∈ + xml.startElement("TestCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file.c_str())) + .writeAttribute("line", line(in.m_line)) + .writeAttribute("description", in.m_description); + + if(Approx(in.m_timeout) != 0) + xml.writeAttribute("timeout", in.m_timeout); + if(in.m_may_fail) + xml.writeAttribute("may_fail", true); + if(in.m_should_fail) + xml.writeAttribute("should_fail", true); + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + test_run_start(); + if(opt.list_reporters) { + for(auto& curr : getListeners()) + xml.scopedElement("Listener") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + for(auto& curr : getReporters()) + xml.scopedElement("Reporter") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + } else if(opt.count || opt.list_test_cases) { + for(unsigned i = 0; i < in.num_data; ++i) { + xml.scopedElement("TestCase").writeAttribute("name", in.data[i]->m_name) + .writeAttribute("testsuite", in.data[i]->m_test_suite) + .writeAttribute("filename", skipPathFromFilename(in.data[i]->m_file.c_str())) + .writeAttribute("line", line(in.data[i]->m_line)) + .writeAttribute("skipped", in.data[i]->m_skip); + } + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + } else if(opt.list_test_suites) { + for(unsigned i = 0; i < in.num_data; ++i) + xml.scopedElement("TestSuite").writeAttribute("name", in.data[i]->m_test_suite); + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + xml.scopedElement("OverallResultsTestSuites") + .writeAttribute("unskipped", in.run_stats->numTestSuitesPassingFilters); + } + xml.endElement(); + } + + void test_run_start() override { + xml.writeDeclaration(); + + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + + xml.startElement("doctest").writeAttribute("binary", binary_name); + if(opt.no_version == false) + xml.writeAttribute("version", DOCTEST_VERSION_STR); + + // only the consequential ones (TODO: filters) + xml.scopedElement("Options") + .writeAttribute("order_by", opt.order_by.c_str()) + .writeAttribute("rand_seed", opt.rand_seed) + .writeAttribute("first", opt.first) + .writeAttribute("last", opt.last) + .writeAttribute("abort_after", opt.abort_after) + .writeAttribute("subcase_filter_levels", opt.subcase_filter_levels) + .writeAttribute("case_sensitive", opt.case_sensitive) + .writeAttribute("no_throw", opt.no_throw) + .writeAttribute("no_skip", opt.no_skip); + } + + void test_run_end(const TestRunStats& p) override { + if(tc) // the TestSuite tag - only if there has been at least 1 test case + xml.endElement(); + + xml.scopedElement("OverallResultsAsserts") + .writeAttribute("successes", p.numAsserts - p.numAssertsFailed) + .writeAttribute("failures", p.numAssertsFailed); + + xml.startElement("OverallResultsTestCases") + .writeAttribute("successes", + p.numTestCasesPassingFilters - p.numTestCasesFailed) + .writeAttribute("failures", p.numTestCasesFailed); + if(opt.no_skipped_summary == false) + xml.writeAttribute("skipped", p.numTestCases - p.numTestCasesPassingFilters); + xml.endElement(); + + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + test_case_start_impl(in); + xml.ensureTagClosed(); + } + + void test_case_reenter(const TestCaseData&) override {} + + void test_case_end(const CurrentTestCaseStats& st) override { + xml.startElement("OverallResultsAsserts") + .writeAttribute("successes", + st.numAssertsCurrentTest - st.numAssertsFailedCurrentTest) + .writeAttribute("failures", st.numAssertsFailedCurrentTest) + .writeAttribute("test_case_success", st.testCaseSuccess); + if(opt.duration) + xml.writeAttribute("duration", st.seconds); + if(tc->m_expected_failures) + xml.writeAttribute("expected_failures", tc->m_expected_failures); + xml.endElement(); + + xml.endElement(); + } + + void test_case_exception(const TestCaseException& e) override { + DOCTEST_LOCK_MUTEX(mutex) + + xml.scopedElement("Exception") + .writeAttribute("crash", e.is_crash) + .writeText(e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + xml.startElement("SubCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file)) + .writeAttribute("line", line(in.m_line)); + xml.ensureTagClosed(); + } + + void subcase_end() override { xml.endElement(); } + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed && !opt.success) + return; + + DOCTEST_LOCK_MUTEX(mutex) + + xml.startElement("Expression") + .writeAttribute("success", !rb.m_failed) + .writeAttribute("type", assertString(rb.m_at)) + .writeAttribute("filename", skipPathFromFilename(rb.m_file)) + .writeAttribute("line", line(rb.m_line)); + + xml.scopedElement("Original").writeText(rb.m_expr); + + if(rb.m_threw) + xml.scopedElement("Exception").writeText(rb.m_exception.c_str()); + + if(rb.m_at & assertType::is_throws_as) + xml.scopedElement("ExpectedException").writeText(rb.m_exception_type); + if(rb.m_at & assertType::is_throws_with) + xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string.c_str()); + if((rb.m_at & assertType::is_normal) && !rb.m_threw) + xml.scopedElement("Expanded").writeText(rb.m_decomp.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void log_message(const MessageData& mb) override { + DOCTEST_LOCK_MUTEX(mutex) + + xml.startElement("Message") + .writeAttribute("type", failureString(mb.m_severity)) + .writeAttribute("filename", skipPathFromFilename(mb.m_file)) + .writeAttribute("line", line(mb.m_line)); + + xml.scopedElement("Text").writeText(mb.m_string.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void test_case_skipped(const TestCaseData& in) override { + if(opt.no_skipped_summary == false) { + test_case_start_impl(in); + xml.writeAttribute("skipped", "true"); + xml.endElement(); + } + } + }; + + DOCTEST_REGISTER_REPORTER("xml", 0, XmlReporter); + + void fulltext_log_assert_to_stream(std::ostream& s, const AssertData& rb) { + if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) == + 0) //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << " ) " + << Color::None; + + if(rb.m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "threw as expected!" : "did NOT throw at all!") << "\n"; + } else if((rb.m_at & assertType::is_throws_as) && + (rb.m_at & assertType::is_throws_with)) { //!OCLINT + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string.c_str() + << "\", " << rb.m_exception_type << " ) " << Color::None; + if(rb.m_threw) { + if(!rb.m_failed) { + s << "threw as expected!\n"; + } else { + s << "threw a DIFFERENT exception! (contents: " << rb.m_exception << ")\n"; + } + } else { + s << "did NOT throw at all!\n"; + } + } else if(rb.m_at & + assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", " + << rb.m_exception_type << " ) " << Color::None + << (rb.m_threw ? (rb.m_threw_as ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & + assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string.c_str() + << "\" ) " << Color::None + << (rb.m_threw ? (!rb.m_failed ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "THREW exception: " : "didn't throw!") << Color::Cyan + << rb.m_exception << "\n"; + } else { + s << (rb.m_threw ? "THREW exception: " : + (!rb.m_failed ? "is correct!\n" : "is NOT correct!\n")); + if(rb.m_threw) + s << rb.m_exception << "\n"; + else + s << " values: " << assertString(rb.m_at) << "( " << rb.m_decomp << " )\n"; + } + } + + // TODO: + // - log_message() + // - respond to queries + // - honor remaining options + // - more attributes in tags + struct JUnitReporter : public IReporter + { + XmlWriter xml; + DOCTEST_DECLARE_MUTEX(mutex) + Timer timer; + std::vector deepestSubcaseStackNames; + + struct JUnitTestCaseData + { + static std::string getCurrentTimestamp() { + // Beware, this is not reentrant because of backward compatibility issues + // Also, UTC only, again because of backward compatibility (%z is C++11) + time_t rawtime; + std::time(&rawtime); + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + + std::tm timeInfo; +#ifdef DOCTEST_PLATFORM_WINDOWS + gmtime_s(&timeInfo, &rawtime); +#else // DOCTEST_PLATFORM_WINDOWS + gmtime_r(&rawtime, &timeInfo); +#endif // DOCTEST_PLATFORM_WINDOWS + + char timeStamp[timeStampSize]; + const char* const fmt = "%Y-%m-%dT%H:%M:%SZ"; + + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); + return std::string(timeStamp); + } + + struct JUnitTestMessage + { + JUnitTestMessage(const std::string& _message, const std::string& _type, const std::string& _details) + : message(_message), type(_type), details(_details) {} + + JUnitTestMessage(const std::string& _message, const std::string& _details) + : message(_message), type(), details(_details) {} + + std::string message, type, details; + }; + + struct JUnitTestCase + { + JUnitTestCase(const std::string& _classname, const std::string& _name) + : classname(_classname), name(_name), time(0), failures() {} + + std::string classname, name; + double time; + std::vector failures, errors; + }; + + void add(const std::string& classname, const std::string& name) { + testcases.emplace_back(classname, name); + } + + void appendSubcaseNamesToLastTestcase(std::vector nameStack) { + for(auto& curr: nameStack) + if(curr.size()) + testcases.back().name += std::string("/") + curr.c_str(); + } + + void addTime(double time) { + if(time < 1e-4) + time = 0; + testcases.back().time = time; + totalSeconds += time; + } + + void addFailure(const std::string& message, const std::string& type, const std::string& details) { + testcases.back().failures.emplace_back(message, type, details); + ++totalFailures; + } + + void addError(const std::string& message, const std::string& details) { + testcases.back().errors.emplace_back(message, details); + ++totalErrors; + } + + std::vector testcases; + double totalSeconds = 0; + int totalErrors = 0, totalFailures = 0; + }; + + JUnitTestCaseData testCaseData; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + JUnitReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData&) override { + xml.writeDeclaration(); + } + + void test_run_start() override { + xml.writeDeclaration(); + } + + void test_run_end(const TestRunStats& p) override { + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + xml.startElement("testsuites"); + xml.startElement("testsuite").writeAttribute("name", binary_name) + .writeAttribute("errors", testCaseData.totalErrors) + .writeAttribute("failures", testCaseData.totalFailures) + .writeAttribute("tests", p.numAsserts); + if(opt.no_time_in_output == false) { + xml.writeAttribute("time", testCaseData.totalSeconds); + xml.writeAttribute("timestamp", JUnitTestCaseData::getCurrentTimestamp()); + } + if(opt.no_version == false) + xml.writeAttribute("doctest_version", DOCTEST_VERSION_STR); + + for(const auto& testCase : testCaseData.testcases) { + xml.startElement("testcase") + .writeAttribute("classname", testCase.classname) + .writeAttribute("name", testCase.name); + if(opt.no_time_in_output == false) + xml.writeAttribute("time", testCase.time); + // This is not ideal, but it should be enough to mimic gtest's junit output. + xml.writeAttribute("status", "run"); + + for(const auto& failure : testCase.failures) { + xml.scopedElement("failure") + .writeAttribute("message", failure.message) + .writeAttribute("type", failure.type) + .writeText(failure.details, false); + } + + for(const auto& error : testCase.errors) { + xml.scopedElement("error") + .writeAttribute("message", error.message) + .writeText(error.details); + } + + xml.endElement(); + } + xml.endElement(); + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + timer.start(); + } + + void test_case_reenter(const TestCaseData& in) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + + timer.start(); + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + } + + void test_case_end(const CurrentTestCaseStats&) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + } + + void test_case_exception(const TestCaseException& e) override { + DOCTEST_LOCK_MUTEX(mutex) + testCaseData.addError("exception", e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + deepestSubcaseStackNames.push_back(in.m_name); + } + + void subcase_end() override {} + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed) // report only failures & ignore the `success` option + return; + + DOCTEST_LOCK_MUTEX(mutex) + + std::ostringstream os; + os << skipPathFromFilename(rb.m_file) << (opt.gnu_file_line ? ":" : "(") + << line(rb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; + + fulltext_log_assert_to_stream(os, rb); + log_contexts(os); + testCaseData.addFailure(rb.m_decomp.c_str(), assertString(rb.m_at), os.str()); + } + + void log_message(const MessageData& mb) override { + if(mb.m_severity & assertType::is_warn) // report only failures + return; + + DOCTEST_LOCK_MUTEX(mutex) + + std::ostringstream os; + os << skipPathFromFilename(mb.m_file) << (opt.gnu_file_line ? ":" : "(") + << line(mb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; + + os << mb.m_string.c_str() << "\n"; + log_contexts(os); + + testCaseData.addFailure(mb.m_string.c_str(), + mb.m_severity & assertType::is_check ? "FAIL_CHECK" : "FAIL", os.str()); + } + + void test_case_skipped(const TestCaseData&) override {} + + void log_contexts(std::ostringstream& s) { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + + s << " logged: "; + for(int i = 0; i < num_contexts; ++i) { + s << (i == 0 ? "" : " "); + contexts[i]->stringify(&s); + s << std::endl; + } + } + } + }; + + DOCTEST_REGISTER_REPORTER("junit", 0, JUnitReporter); + + struct Whitespace + { + int nrSpaces; + explicit Whitespace(int nr) + : nrSpaces(nr) {} + }; + + std::ostream& operator<<(std::ostream& out, const Whitespace& ws) { + if(ws.nrSpaces != 0) + out << std::setw(ws.nrSpaces) << ' '; + return out; + } + + struct ConsoleReporter : public IReporter + { + std::ostream& s; + bool hasLoggedCurrentTestStart; + std::vector subcasesStack; + size_t currentSubcaseLevel; + DOCTEST_DECLARE_MUTEX(mutex) + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc; + + ConsoleReporter(const ContextOptions& co) + : s(*co.cout) + , opt(co) {} + + ConsoleReporter(const ContextOptions& co, std::ostream& ostr) + : s(ostr) + , opt(co) {} + + // ========================================================================================= + // WHAT FOLLOWS ARE HELPERS USED BY THE OVERRIDES OF THE VIRTUAL METHODS OF THE INTERFACE + // ========================================================================================= + + void separator_to_stream() { + s << Color::Yellow + << "===============================================================================" + "\n"; + } + + const char* getSuccessOrFailString(bool success, assertType::Enum at, + const char* success_str) { + if(success) + return success_str; + return failureString(at); + } + + Color::Enum getSuccessOrFailColor(bool success, assertType::Enum at) { + return success ? Color::BrightGreen : + (at & assertType::is_warn) ? Color::Yellow : Color::Red; + } + + void successOrFailColoredStringToStream(bool success, assertType::Enum at, + const char* success_str = "SUCCESS") { + s << getSuccessOrFailColor(success, at) + << getSuccessOrFailString(success, at, success_str) << ": "; + } + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + + s << Color::None << " logged: "; + for(int i = 0; i < num_contexts; ++i) { + s << (i == 0 ? "" : " "); + contexts[i]->stringify(&s); + s << "\n"; + } + } + + s << "\n"; + } + + // this was requested to be made virtual so users could override it + virtual void file_line_to_stream(const char* file, int line, + const char* tail = "") { + s << Color::LightGrey << skipPathFromFilename(file) << (opt.gnu_file_line ? ":" : "(") + << (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option + << (opt.gnu_file_line ? ":" : "):") << tail; + } + + void logTestStart() { + if(hasLoggedCurrentTestStart) + return; + + separator_to_stream(); + file_line_to_stream(tc->m_file.c_str(), tc->m_line, "\n"); + if(tc->m_description) + s << Color::Yellow << "DESCRIPTION: " << Color::None << tc->m_description << "\n"; + if(tc->m_test_suite && tc->m_test_suite[0] != '\0') + s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n"; + if(strncmp(tc->m_name, " Scenario:", 11) != 0) + s << Color::Yellow << "TEST CASE: "; + s << Color::None << tc->m_name << "\n"; + + for(size_t i = 0; i < currentSubcaseLevel; ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + + if(currentSubcaseLevel != subcasesStack.size()) { + s << Color::Yellow << "\nDEEPEST SUBCASE STACK REACHED (DIFFERENT FROM THE CURRENT ONE):\n" << Color::None; + for(size_t i = 0; i < subcasesStack.size(); ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + } + + s << "\n"; + + hasLoggedCurrentTestStart = true; + } + + void printVersion() { + if(opt.no_version == false) + s << Color::Cyan << "[doctest] " << Color::None << "doctest version is \"" + << DOCTEST_VERSION_STR << "\"\n"; + } + + void printIntro() { + if(opt.no_intro == false) { + printVersion(); + s << Color::Cyan << "[doctest] " << Color::None + << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; + } + } + + void printHelp() { + int sizePrefixDisplay = static_cast(strlen(DOCTEST_OPTIONS_PREFIX_DISPLAY)); + printVersion(); + // clang-format off + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "boolean values: \"1/on/yes/true\" or \"0/off/no/false\"\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filter values: \"str1,str2,str3\" (comma separated strings)\n"; + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filters use wildcards for matching strings\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "something passes a filter if any of the strings in a filter matches\n"; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "ALL FLAGS, OPTIONS AND FILTERS ALSO AVAILABLE WITH A \"" DOCTEST_CONFIG_OPTIONS_PREFIX "\" PREFIX!!!\n"; +#endif + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "Query flags - the program quits after them. Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "?, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "help, -" DOCTEST_OPTIONS_PREFIX_DISPLAY "h " + << Whitespace(sizePrefixDisplay*0) << "prints this message\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "v, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "version " + << Whitespace(sizePrefixDisplay*1) << "prints the version\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "c, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "count " + << Whitespace(sizePrefixDisplay*1) << "prints the number of matching tests\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ltc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-cases " + << Whitespace(sizePrefixDisplay*1) << "lists all matching tests by name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-suites " + << Whitespace(sizePrefixDisplay*1) << "lists all matching test suites\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-reporters " + << Whitespace(sizePrefixDisplay*1) << "lists all registered reporters\n\n"; + // ================================================================================== << 79 + s << Color::Cyan << "[doctest] " << Color::None; + s << "The available / options/filters are:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their file\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sfe, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their file\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tse, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase= " + << Whitespace(sizePrefixDisplay*1) << "filters subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "r, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "reporters= " + << Whitespace(sizePrefixDisplay*1) << "reporters to use (console is default)\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "o, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "out= " + << Whitespace(sizePrefixDisplay*1) << "output filename\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ob, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "order-by= " + << Whitespace(sizePrefixDisplay*1) << "how the tests should be ordered\n"; + s << Whitespace(sizePrefixDisplay*3) << " - [file/suite/name/rand/none]\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "rs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "rand-seed= " + << Whitespace(sizePrefixDisplay*1) << "seed for random ordering\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "f, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "first= " + << Whitespace(sizePrefixDisplay*1) << "the first test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "l, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "last= " + << Whitespace(sizePrefixDisplay*1) << "the last test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "aa, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "abort-after= " + << Whitespace(sizePrefixDisplay*1) << "stop after failed assertions\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "scfl,--" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-filter-levels= " + << Whitespace(sizePrefixDisplay*1) << "apply filters for the first levels\n"; + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "Bool options - can be used like flags and true is assumed. Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "s, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "success= " + << Whitespace(sizePrefixDisplay*1) << "include successful assertions in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "cs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "case-sensitive= " + << Whitespace(sizePrefixDisplay*1) << "filters being treated as case sensitive\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "e, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "exit= " + << Whitespace(sizePrefixDisplay*1) << "exits after the tests finish\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "d, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "duration= " + << Whitespace(sizePrefixDisplay*1) << "prints the time duration of each test\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "m, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "minimal= " + << Whitespace(sizePrefixDisplay*1) << "minimal console output (only failures)\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "q, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "quiet= " + << Whitespace(sizePrefixDisplay*1) << "no console output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nt, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-throw= " + << Whitespace(sizePrefixDisplay*1) << "skips exceptions-related assert checks\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ne, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-exitcode= " + << Whitespace(sizePrefixDisplay*1) << "returns (or exits) always with success\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-run= " + << Whitespace(sizePrefixDisplay*1) << "skips all runtime doctest operations\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ni, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-intro= " + << Whitespace(sizePrefixDisplay*1) << "omit the framework intro in the output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nv, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-version= " + << Whitespace(sizePrefixDisplay*1) << "omit the framework version in the output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-colors= " + << Whitespace(sizePrefixDisplay*1) << "disables colors in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "fc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "force-colors= " + << Whitespace(sizePrefixDisplay*1) << "use colors even when not in a tty\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nb, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-breaks= " + << Whitespace(sizePrefixDisplay*1) << "disables breakpoints in debuggers\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ns, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-skip= " + << Whitespace(sizePrefixDisplay*1) << "don't skip test cases marked as skip\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "gfl, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "gnu-file-line= " + << Whitespace(sizePrefixDisplay*1) << ":n: vs (n): for line numbers in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "npf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-path-filenames= " + << Whitespace(sizePrefixDisplay*1) << "only filenames and no paths in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nln, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-line-numbers= " + << Whitespace(sizePrefixDisplay*1) << "0 instead of real line numbers in output\n"; + // ================================================================================== << 79 + // clang-format on + + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "for more information visit the project documentation\n\n"; + } + + void printRegisteredReporters() { + printVersion(); + auto printReporters = [this] (const reporterMap& reporters, const char* type) { + if(reporters.size()) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all registered " << type << "\n"; + for(auto& curr : reporters) + s << "priority: " << std::setw(5) << curr.first.first + << " name: " << curr.first.second << "\n"; + } + }; + printReporters(getListeners(), "listeners"); + printReporters(getReporters(), "reporters"); + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + if(opt.version) { + printVersion(); + } else if(opt.help) { + printHelp(); + } else if(opt.list_reporters) { + printRegisteredReporters(); + } else if(opt.count || opt.list_test_cases) { + if(opt.list_test_cases) { + s << Color::Cyan << "[doctest] " << Color::None + << "listing all test case names\n"; + separator_to_stream(); + } + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_name << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + + } else if(opt.list_test_suites) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all test suites\n"; + separator_to_stream(); + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_test_suite << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "test suites with unskipped test cases passing the current filters: " + << g_cs->numTestSuitesPassingFilters << "\n"; + } + } + + void test_run_start() override { + if(!opt.minimal) + printIntro(); + } + + void test_run_end(const TestRunStats& p) override { + if(opt.minimal && p.numTestCasesFailed == 0) + return; + + separator_to_stream(); + s << std::dec; + + auto totwidth = int(std::ceil(log10(static_cast(std::max(p.numTestCasesPassingFilters, static_cast(p.numAsserts))) + 1))); + auto passwidth = int(std::ceil(log10(static_cast(std::max(p.numTestCasesPassingFilters - p.numTestCasesFailed, static_cast(p.numAsserts - p.numAssertsFailed))) + 1))); + auto failwidth = int(std::ceil(log10(static_cast(std::max(p.numTestCasesFailed, static_cast(p.numAssertsFailed))) + 1))); + const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0; + s << Color::Cyan << "[doctest] " << Color::None << "test cases: " << std::setw(totwidth) + << p.numTestCasesPassingFilters << " | " + << ((p.numTestCasesPassingFilters == 0 || anythingFailed) ? Color::None : + Color::Green) + << std::setw(passwidth) << p.numTestCasesPassingFilters - p.numTestCasesFailed << " passed" + << Color::None << " | " << (p.numTestCasesFailed > 0 ? Color::Red : Color::None) + << std::setw(failwidth) << p.numTestCasesFailed << " failed" << Color::None << " |"; + if(opt.no_skipped_summary == false) { + const int numSkipped = p.numTestCases - p.numTestCasesPassingFilters; + s << " " << (numSkipped == 0 ? Color::None : Color::Yellow) << numSkipped + << " skipped" << Color::None; + } + s << "\n"; + s << Color::Cyan << "[doctest] " << Color::None << "assertions: " << std::setw(totwidth) + << p.numAsserts << " | " + << ((p.numAsserts == 0 || anythingFailed) ? Color::None : Color::Green) + << std::setw(passwidth) << (p.numAsserts - p.numAssertsFailed) << " passed" << Color::None + << " | " << (p.numAssertsFailed > 0 ? Color::Red : Color::None) << std::setw(failwidth) + << p.numAssertsFailed << " failed" << Color::None << " |\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "Status: " << (p.numTestCasesFailed > 0 ? Color::Red : Color::Green) + << ((p.numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl; + } + + void test_case_start(const TestCaseData& in) override { + hasLoggedCurrentTestStart = false; + tc = ∈ + subcasesStack.clear(); + currentSubcaseLevel = 0; + } + + void test_case_reenter(const TestCaseData&) override { + subcasesStack.clear(); + } + + void test_case_end(const CurrentTestCaseStats& st) override { + if(tc->m_no_output) + return; + + // log the preamble of the test case only if there is something + // else to print - something other than that an assert has failed + if(opt.duration || + (st.failure_flags && st.failure_flags != static_cast(TestCaseFailureReason::AssertFailure))) + logTestStart(); + + if(opt.duration) + s << Color::None << std::setprecision(6) << std::fixed << st.seconds + << " s: " << tc->m_name << "\n"; + + if(st.failure_flags & TestCaseFailureReason::Timeout) + s << Color::Red << "Test case exceeded time limit of " << std::setprecision(6) + << std::fixed << tc->m_timeout << "!\n"; + + if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedButDidnt) { + s << Color::Red << "Should have failed but didn't! Marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedAndDid) { + s << Color::Yellow << "Failed as expected so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::CouldHaveFailedAndDid) { + s << Color::Yellow << "Allowed to fail so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::DidntFailExactlyNumTimes) { + s << Color::Red << "Didn't fail exactly " << tc->m_expected_failures + << " times so marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::FailedExactlyNumTimes) { + s << Color::Yellow << "Failed exactly " << tc->m_expected_failures + << " times as expected so marking it as not failed!\n"; + } + if(st.failure_flags & TestCaseFailureReason::TooManyFailedAsserts) { + s << Color::Red << "Aborting - too many failed asserts!\n"; + } + s << Color::None; // lgtm [cpp/useless-expression] + } + + void test_case_exception(const TestCaseException& e) override { + DOCTEST_LOCK_MUTEX(mutex) + if(tc->m_no_output) + return; + + logTestStart(); + + file_line_to_stream(tc->m_file.c_str(), tc->m_line, " "); + successOrFailColoredStringToStream(false, e.is_crash ? assertType::is_require : + assertType::is_check); + s << Color::Red << (e.is_crash ? "test case CRASHED: " : "test case THREW exception: ") + << Color::Cyan << e.error_string << "\n"; + + int num_stringified_contexts = get_num_stringified_contexts(); + if(num_stringified_contexts) { + auto stringified_contexts = get_stringified_contexts(); + s << Color::None << " logged: "; + for(int i = num_stringified_contexts; i > 0; --i) { + s << (i == num_stringified_contexts ? "" : " ") + << stringified_contexts[i - 1] << "\n"; + } + } + s << "\n" << Color::None; + } + + void subcase_start(const SubcaseSignature& subc) override { + subcasesStack.push_back(subc); + ++currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void subcase_end() override { + --currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void log_assert(const AssertData& rb) override { + if((!rb.m_failed && !opt.success) || tc->m_no_output) + return; + + DOCTEST_LOCK_MUTEX(mutex) + + logTestStart(); + + file_line_to_stream(rb.m_file, rb.m_line, " "); + successOrFailColoredStringToStream(!rb.m_failed, rb.m_at); + + fulltext_log_assert_to_stream(s, rb); + + log_contexts(); + } + + void log_message(const MessageData& mb) override { + if(tc->m_no_output) + return; + + DOCTEST_LOCK_MUTEX(mutex) + + logTestStart(); + + file_line_to_stream(mb.m_file, mb.m_line, " "); + s << getSuccessOrFailColor(false, mb.m_severity) + << getSuccessOrFailString(mb.m_severity & assertType::is_warn, mb.m_severity, + "MESSAGE") << ": "; + s << Color::None << mb.m_string << "\n"; + log_contexts(); + } + + void test_case_skipped(const TestCaseData&) override {} + }; + + DOCTEST_REGISTER_REPORTER("console", 0, ConsoleReporter); + +#ifdef DOCTEST_PLATFORM_WINDOWS + struct DebugOutputWindowReporter : public ConsoleReporter + { + DOCTEST_THREAD_LOCAL static std::ostringstream oss; + + DebugOutputWindowReporter(const ContextOptions& co) + : ConsoleReporter(co, oss) {} + +#define DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(func, type, arg) \ + void func(type arg) override { \ + bool with_col = g_no_colors; \ + g_no_colors = false; \ + ConsoleReporter::func(arg); \ + if(oss.tellp() != std::streampos{}) { \ + DOCTEST_OUTPUT_DEBUG_STRING(oss.str().c_str()); \ + oss.str(""); \ + } \ + g_no_colors = with_col; \ + } + + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_start, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_end, const TestRunStats&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_start, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_reenter, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_end, const CurrentTestCaseStats&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_exception, const TestCaseException&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_start, const SubcaseSignature&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_end, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_assert, const AssertData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_message, const MessageData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_skipped, const TestCaseData&, in) + }; + + DOCTEST_THREAD_LOCAL std::ostringstream DebugOutputWindowReporter::oss; +#endif // DOCTEST_PLATFORM_WINDOWS + + // the implementation of parseOption() + bool parseOptionImpl(int argc, const char* const* argv, const char* pattern, String* value) { + // going from the end to the beginning and stopping on the first occurrence from the end + for(int i = argc; i > 0; --i) { + auto index = i - 1; + auto temp = std::strstr(argv[index], pattern); + if(temp && (value || strlen(temp) == strlen(pattern))) { //!OCLINT prefer early exits and continue + // eliminate matches in which the chars before the option are not '-' + bool noBadCharsFound = true; + auto curr = argv[index]; + while(curr != temp) { + if(*curr++ != '-') { + noBadCharsFound = false; + break; + } + } + if(noBadCharsFound && argv[index][0] == '-') { + if(value) { + // parsing the value of an option + temp += strlen(pattern); + const unsigned len = strlen(temp); + if(len) { + *value = temp; + return true; + } + } else { + // just a flag - no value + return true; + } + } + } + } + return false; + } + + // parses an option and returns the string after the '=' character + bool parseOption(int argc, const char* const* argv, const char* pattern, String* value = nullptr, + const String& defaultVal = String()) { + if(value) + *value = defaultVal; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + // offset (normally 3 for "dt-") to skip prefix + if(parseOptionImpl(argc, argv, pattern + strlen(DOCTEST_CONFIG_OPTIONS_PREFIX), value)) + return true; +#endif // DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + return parseOptionImpl(argc, argv, pattern, value); + } + + // locates a flag on the command line + bool parseFlag(int argc, const char* const* argv, const char* pattern) { + return parseOption(argc, argv, pattern); + } + + // parses a comma separated list of words after a pattern in one of the arguments in argv + bool parseCommaSepArgs(int argc, const char* const* argv, const char* pattern, + std::vector& res) { + String filtersString; + if(parseOption(argc, argv, pattern, &filtersString)) { + // tokenize with "," as a separator, unless escaped with backslash + std::ostringstream s; + auto flush = [&s, &res]() { + auto string = s.str(); + if(string.size() > 0) { + res.push_back(string.c_str()); + } + s.str(""); + }; + + bool seenBackslash = false; + const char* current = filtersString.c_str(); + const char* end = current + strlen(current); + while(current != end) { + char character = *current++; + if(seenBackslash) { + seenBackslash = false; + if(character == ',' || character == '\\') { + s.put(character); + continue; + } + s.put('\\'); + } + if(character == '\\') { + seenBackslash = true; + } else if(character == ',') { + flush(); + } else { + s.put(character); + } + } + + if(seenBackslash) { + s.put('\\'); + } + flush(); + return true; + } + return false; + } + + enum optionType + { + option_bool, + option_int + }; + + // parses an int/bool option from the command line + bool parseIntOption(int argc, const char* const* argv, const char* pattern, optionType type, + int& res) { + String parsedValue; + if(!parseOption(argc, argv, pattern, &parsedValue)) + return false; + + if(type) { + // integer + // TODO: change this to use std::stoi or something else! currently it uses undefined behavior - assumes '0' on failed parse... + int theInt = std::atoi(parsedValue.c_str()); + if (theInt != 0) { + res = theInt; //!OCLINT parameter reassignment + return true; + } + } else { + // boolean + const char positive[][5] = { "1", "true", "on", "yes" }; // 5 - strlen("true") + 1 + const char negative[][6] = { "0", "false", "off", "no" }; // 6 - strlen("false") + 1 + + // if the value matches any of the positive/negative possibilities + for (unsigned i = 0; i < 4; i++) { + if (parsedValue.compare(positive[i], true) == 0) { + res = 1; //!OCLINT parameter reassignment + return true; + } + if (parsedValue.compare(negative[i], true) == 0) { + res = 0; //!OCLINT parameter reassignment + return true; + } + } + } + return false; + } +} // namespace + +Context::Context(int argc, const char* const* argv) + : p(new detail::ContextState) { + parseArgs(argc, argv, true); + if(argc) + p->binary_name = argv[0]; +} + +Context::~Context() { + if(g_cs == p) + g_cs = nullptr; + delete p; +} + +void Context::applyCommandLine(int argc, const char* const* argv) { + parseArgs(argc, argv); + if(argc) + p->binary_name = argv[0]; +} + +// parses args +void Context::parseArgs(int argc, const char* const* argv, bool withDefaults) { + using namespace detail; + + // clang-format off + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sf=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file-exclude=",p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sfe=", p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ts=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite-exclude=", p->filters[3]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tse=", p->filters[3]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tc=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case-exclude=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tce=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sc=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase-exclude=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sce=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "reporters=", p->filters[8]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "r=", p->filters[8]); + // clang-format on + + int intRes = 0; + String strRes; + +#define DOCTEST_PARSE_AS_BOOL_OR_FLAG(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_bool, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_bool, intRes)) \ + p->var = static_cast(intRes); \ + else if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name) || \ + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname)) \ + p->var = true; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_INT_OPTION(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_int, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_int, intRes)) \ + p->var = intRes; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_STR_OPTION(name, sname, var, default) \ + if(parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", &strRes, default) || \ + parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", &strRes, default) || \ + withDefaults) \ + p->var = strRes + + // clang-format off + DOCTEST_PARSE_STR_OPTION("out", "o", out, ""); + DOCTEST_PARSE_STR_OPTION("order-by", "ob", order_by, "file"); + DOCTEST_PARSE_INT_OPTION("rand-seed", "rs", rand_seed, 0); + + DOCTEST_PARSE_INT_OPTION("first", "f", first, 0); + DOCTEST_PARSE_INT_OPTION("last", "l", last, UINT_MAX); + + DOCTEST_PARSE_INT_OPTION("abort-after", "aa", abort_after, 0); + DOCTEST_PARSE_INT_OPTION("subcase-filter-levels", "scfl", subcase_filter_levels, INT_MAX); + + DOCTEST_PARSE_AS_BOOL_OR_FLAG("success", "s", success, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("case-sensitive", "cs", case_sensitive, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("exit", "e", exit, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("duration", "d", duration, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("minimal", "m", minimal, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("quiet", "q", quiet, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-throw", "nt", no_throw, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-exitcode", "ne", no_exitcode, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-run", "nr", no_run, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-intro", "ni", no_intro, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-version", "nv", no_version, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-colors", "nc", no_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("force-colors", "fc", force_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-breaks", "nb", no_breaks, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skip", "ns", no_skip, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("gnu-file-line", "gfl", gnu_file_line, !bool(DOCTEST_MSVC)); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-path-filenames", "npf", no_path_in_filenames, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-line-numbers", "nln", no_line_numbers, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-debug-output", "ndo", no_debug_output, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skipped-summary", "nss", no_skipped_summary, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-time-in-output", "ntio", no_time_in_output, false); + // clang-format on + + if(withDefaults) { + p->help = false; + p->version = false; + p->count = false; + p->list_test_cases = false; + p->list_test_suites = false; + p->list_reporters = false; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "help") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "h") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "?")) { + p->help = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "version") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "v")) { + p->version = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "count") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "c")) { + p->count = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-cases") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ltc")) { + p->list_test_cases = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-suites") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lts")) { + p->list_test_suites = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-reporters") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lr")) { + p->list_reporters = true; + p->exit = true; + } +} + +// allows the user to add procedurally to the filters from the command line +void Context::addFilter(const char* filter, const char* value) { setOption(filter, value); } + +// allows the user to clear all filters from the command line +void Context::clearFilters() { + for(auto& curr : p->filters) + curr.clear(); +} + +// allows the user to override procedurally the bool options from the command line +void Context::setOption(const char* option, bool value) { + setOption(option, value ? "true" : "false"); +} + +// allows the user to override procedurally the int options from the command line +void Context::setOption(const char* option, int value) { + setOption(option, toString(value).c_str()); +} + +// allows the user to override procedurally the string options from the command line +void Context::setOption(const char* option, const char* value) { + auto argv = String("-") + option + "=" + value; + auto lvalue = argv.c_str(); + parseArgs(1, &lvalue); +} + +// users should query this in their main() and exit the program if true +bool Context::shouldExit() { return p->exit; } + +void Context::setAsDefaultForAssertsOutOfTestCases() { g_cs = p; } + +void Context::setAssertHandler(detail::assert_handler ah) { p->ah = ah; } + +void Context::setCout(std::ostream* out) { p->cout = out; } + +static class DiscardOStream : public std::ostream +{ +private: + class : public std::streambuf + { + private: + // allowing some buffering decreases the amount of calls to overflow + char buf[1024]; + + protected: + std::streamsize xsputn(const char_type*, std::streamsize count) override { return count; } + + int_type overflow(int_type ch) override { + setp(std::begin(buf), std::end(buf)); + return traits_type::not_eof(ch); + } + } discardBuf; + +public: + DiscardOStream() + : std::ostream(&discardBuf) {} +} discardOut; + +// the main function that does all the filtering and test running +int Context::run() { + using namespace detail; + + // save the old context state in case such was setup - for using asserts out of a testing context + auto old_cs = g_cs; + // this is the current contest + g_cs = p; + is_running_in_test = true; + + g_no_colors = p->no_colors; + p->resetRunData(); + + std::fstream fstr; + if(p->cout == nullptr) { + if(p->quiet) { + p->cout = &discardOut; + } else if(p->out.size()) { + // to a file if specified + fstr.open(p->out.c_str(), std::fstream::out); + p->cout = &fstr; + } else { +#ifndef DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + // stdout by default + p->cout = &std::cout; +#else // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + return EXIT_FAILURE; +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + } + } + + FatalConditionHandler::allocateAltStackMem(); + + auto cleanup_and_return = [&]() { + FatalConditionHandler::freeAltStackMem(); + + if(fstr.is_open()) + fstr.close(); + + // restore context + g_cs = old_cs; + is_running_in_test = false; + + // we have to free the reporters which were allocated when the run started + for(auto& curr : p->reporters_currently_used) + delete curr; + p->reporters_currently_used.clear(); + + if(p->numTestCasesFailed && !p->no_exitcode) + return EXIT_FAILURE; + return EXIT_SUCCESS; + }; + + // setup default reporter if none is given through the command line + if(p->filters[8].empty()) + p->filters[8].push_back("console"); + + // check to see if any of the registered reporters has been selected + for(auto& curr : getReporters()) { + if(matchesAny(curr.first.second.c_str(), p->filters[8], false, p->case_sensitive)) + p->reporters_currently_used.push_back(curr.second(*g_cs)); + } + + // TODO: check if there is nothing in reporters_currently_used + + // prepend all listeners + for(auto& curr : getListeners()) + p->reporters_currently_used.insert(p->reporters_currently_used.begin(), curr.second(*g_cs)); + +#ifdef DOCTEST_PLATFORM_WINDOWS + if(isDebuggerActive() && p->no_debug_output == false) + p->reporters_currently_used.push_back(new DebugOutputWindowReporter(*g_cs)); +#endif // DOCTEST_PLATFORM_WINDOWS + + // handle version, help and no_run + if(p->no_run || p->version || p->help || p->list_reporters) { + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, QueryData()); + + return cleanup_and_return(); + } + + std::vector testArray; + for(auto& curr : getRegisteredTests()) + testArray.push_back(&curr); + p->numTestCases = testArray.size(); + + // sort the collected records + if(!testArray.empty()) { + if(p->order_by.compare("file", true) == 0) { + std::sort(testArray.begin(), testArray.end(), fileOrderComparator); + } else if(p->order_by.compare("suite", true) == 0) { + std::sort(testArray.begin(), testArray.end(), suiteOrderComparator); + } else if(p->order_by.compare("name", true) == 0) { + std::sort(testArray.begin(), testArray.end(), nameOrderComparator); + } else if(p->order_by.compare("rand", true) == 0) { + std::srand(p->rand_seed); + + // random_shuffle implementation + const auto first = &testArray[0]; + for(size_t i = testArray.size() - 1; i > 0; --i) { + int idxToSwap = std::rand() % (i + 1); + + const auto temp = first[i]; + + first[i] = first[idxToSwap]; + first[idxToSwap] = temp; + } + } else if(p->order_by.compare("none", true) == 0) { + // means no sorting - beneficial for death tests which call into the executable + // with a specific test case in mind - we don't want to slow down the startup times + } + } + + std::set testSuitesPassingFilt; + + bool query_mode = p->count || p->list_test_cases || p->list_test_suites; + std::vector queryResults; + + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_start, DOCTEST_EMPTY); + + // invoke the registered functions if they match the filter criteria (or just count them) + for(auto& curr : testArray) { + const auto& tc = *curr; + + bool skip_me = false; + if(tc.m_skip && !p->no_skip) + skip_me = true; + + if(!matchesAny(tc.m_file.c_str(), p->filters[0], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_file.c_str(), p->filters[1], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_test_suite, p->filters[2], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_test_suite, p->filters[3], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_name, p->filters[4], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_name, p->filters[5], false, p->case_sensitive)) + skip_me = true; + + if(!skip_me) + p->numTestCasesPassingFilters++; + + // skip the test if it is not in the execution range + if((p->last < p->numTestCasesPassingFilters && p->first <= p->last) || + (p->first > p->numTestCasesPassingFilters)) + skip_me = true; + + if(skip_me) { + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_skipped, tc); + continue; + } + + // do not execute the test if we are to only count the number of filter passing tests + if(p->count) + continue; + + // print the name of the test and don't execute it + if(p->list_test_cases) { + queryResults.push_back(&tc); + continue; + } + + // print the name of the test suite if not done already and don't execute it + if(p->list_test_suites) { + if((testSuitesPassingFilt.count(tc.m_test_suite) == 0) && tc.m_test_suite[0] != '\0') { + queryResults.push_back(&tc); + testSuitesPassingFilt.insert(tc.m_test_suite); + p->numTestSuitesPassingFilters++; + } + continue; + } + + // execute the test if it passes all the filtering + { + p->currentTest = &tc; + + p->failure_flags = TestCaseFailureReason::None; + p->seconds = 0; + + // reset atomic counters + p->numAssertsFailedCurrentTest_atomic = 0; + p->numAssertsCurrentTest_atomic = 0; + + p->fullyTraversedSubcases.clear(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_start, tc); + + p->timer.start(); + + bool run_test = true; + + do { + // reset some of the fields for subcases (except for the set of fully passed ones) + p->reachedLeaf = false; + // May not be empty if previous subcase exited via exception. + p->subcaseStack.clear(); + p->currentSubcaseDepth = 0; + + p->shouldLogCurrentException = true; + + // reset stuff for logging with INFO() + p->stringifiedContexts.clear(); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +// MSVC 2015 diagnoses fatalConditionHandler as unused (because reset() is a static method) +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4101) // unreferenced local variable + FatalConditionHandler fatalConditionHandler; // Handle signals + // execute the test + tc.m_test(); + fatalConditionHandler.reset(); +DOCTEST_MSVC_SUPPRESS_WARNING_POP +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + } catch(const TestFailureException&) { + p->failure_flags |= TestCaseFailureReason::AssertFailure; + } catch(...) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, + {translateActiveException(), false}); + p->failure_flags |= TestCaseFailureReason::Exception; + } +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + + // exit this loop if enough assertions have failed - even if there are more subcases + if(p->abort_after > 0 && + p->numAssertsFailed + p->numAssertsFailedCurrentTest_atomic >= p->abort_after) { + run_test = false; + p->failure_flags |= TestCaseFailureReason::TooManyFailedAsserts; + } + + if(!p->nextSubcaseStack.empty() && run_test) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_reenter, tc); + if(p->nextSubcaseStack.empty()) + run_test = false; + } while(run_test); + + p->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + p->currentTest = nullptr; + + // stop executing tests if enough assertions have failed + if(p->abort_after > 0 && p->numAssertsFailed >= p->abort_after) + break; + } + } + + if(!query_mode) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } else { + QueryData qdata; + qdata.run_stats = g_cs; + qdata.data = queryResults.data(); + qdata.num_data = unsigned(queryResults.size()); + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, qdata); + } + + return cleanup_and_return(); +} + +DOCTEST_DEFINE_INTERFACE(IReporter) + +int IReporter::get_num_active_contexts() { return detail::g_infoContexts.size(); } +const IContextScope* const* IReporter::get_active_contexts() { + return get_num_active_contexts() ? &detail::g_infoContexts[0] : nullptr; +} + +int IReporter::get_num_stringified_contexts() { return detail::g_cs->stringifiedContexts.size(); } +const String* IReporter::get_stringified_contexts() { + return get_num_stringified_contexts() ? &detail::g_cs->stringifiedContexts[0] : nullptr; +} + +namespace detail { + void registerReporterImpl(const char* name, int priority, reporterCreatorFunc c, bool isReporter) { + if(isReporter) + getReporters().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + else + getListeners().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + } +} // namespace detail + +} // namespace doctest + +#endif // DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007) // 'function' : must be 'attribute' - see issue #182 +int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); } +DOCTEST_MSVC_SUPPRESS_WARNING_POP +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_POP + +#endif // DOCTEST_LIBRARY_IMPLEMENTATION +#endif // DOCTEST_CONFIG_IMPLEMENT + +#ifdef DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN +#undef WIN32_LEAN_AND_MEAN +#undef DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN +#endif // DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN + +#ifdef DOCTEST_UNDEF_NOMINMAX +#undef NOMINMAX +#undef DOCTEST_UNDEF_NOMINMAX +#endif // DOCTEST_UNDEF_NOMINMAX diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index fb70475c..81352f36 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -37,148 +37,257 @@ extern "C" { TINYTC_EXPORT char const *tinytc_error_string(tinytc_status_t status); //////////////////////////// -//////// Scalar type /////// +////////// FP math ///////// //////////////////////////// -//! Convert scalar type to string -TINYTC_EXPORT char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty); -//! Size of scalar type in bytes -TINYTC_EXPORT size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty); +/** + * @brief Convert f32 number to bf16 number (represented as ushort) + * + * @param x f32 number + * + * @return bf16 number + */ +TINYTC_EXPORT uint16_t tinytc_f32_to_bf16_as_ui16(float x); + +/** + * @brief Convert bf16 number (represented as ushort) to f32 number + * + * @param x bf16 number + * + * @return f32 number + */ +TINYTC_EXPORT float tinytc_bf16_as_ui16_to_f32(uint16_t x); + +/** + * @brief Convert f32 number to f16 number (represented as ushort) + * + * @param x f32 number + * + * @return f16 number + */ +TINYTC_EXPORT uint16_t tinytc_f32_to_f16_as_ui16(float x); + +/** + * @brief Convert f16 number (represented as ushort) to f32 number + * + * @param x f16 number + * + * @return f32 number + */ +TINYTC_EXPORT float tinytc_f16_as_ui16_to_f32(uint16_t x); //////////////////////////// -///////// Data type //////// +///////// Attribute //////// //////////////////////////// /** - * @brief Create scalar data type + * @brief Get array attribute * - * @param dt [out] pointer to the data type object created - * @param type [in] scalar type - * @param loc [in][optional] Source code location; can be nullptr + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param array_size [in] number of elements in array, must be 0 if array == nullptr + * @param array [in][range(0, array_size)] attribute array * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, - tinytc_scalar_type_t type, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_array_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + uint32_t array_size, + const tinytc_attr_t *array); /** - * @brief Create memref data type + * @brief Get boolean attribute * - * @param dt [out] pointer to the data type object created - * @param scalar_ty [in] element type - * @param shape_size [in] tensor order; number of elements in shape array, must be 0 if shape == - * nullptr - * @param shape [in][range(0, shape_size)] array of mode sizes - * @param stride_size [in][optional] number of elements in stride array; must be either 0 for - * automatic stride calculation or must match shape_size; must be 0 if stride == nullptr - * @param stride [in][optional][range(0, stride_size)] stride array - * @param loc [in][optional] Source code location; can be nullptr + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param value [in] value of attribute * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, - tinytc_scalar_type_t scalar_ty, - uint32_t shape_size, const int64_t *shape, - uint32_t stride_size, const int64_t *stride, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_boolean_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + tinytc_bool_t value); /** - * @brief Create group data type + * @brief Get dictionary attribute * - * @param dt [out] pointer to the data type object created - * @param memref_ty [in] memref data type object - * @param offset [in][optional] offset parameter; pass 0 for default - * @param loc [in][optional] Source code location; can be nullptr + * Each name must only appear once. + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param items_size [in] number of elements in items array, must be 0 if items == nullptr + * @param items [in][range(0, items_size)] array of items * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_group_type_create(tinytc_data_type_t *dt, - tinytc_data_type_t memref_ty, int64_t offset, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_dictionary_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + uint32_t items_size, + tinytc_named_attr_t *items); /** - * @brief Release data type object + * @brief Get dictionary attribute with pre-sorted items * - * Decreases reference count by 1, free memory if reference count is 0. + * The list of items must be sorted by name and each name must only appear once. * - * @param dt [inout] data type object + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param items_size [in] number of elements in items array, must be 0 if items == nullptr + * @param items [in][range(0, items_size)] array of items * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_data_type_release(tinytc_data_type_t dt); +TINYTC_EXPORT tinytc_status_t +tinytc_dictionary_attr_get_with_sorted(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + uint32_t items_size, const tinytc_named_attr_t *items); /** - * @brief Increase reference count of data type object by 1 + * @brief Sort items array by name * - * @param dt [inout] data type object + * @param items_size [in] number of elements in items array, must be 0 if items == nullptr + * @param items [in][range(0, items_size)] array of items * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_data_type_retain(tinytc_data_type_t dt); +TINYTC_EXPORT tinytc_status_t tinytc_dictionary_attr_sort(uint32_t items_size, + tinytc_named_attr_t *items); + +/** + * @brief Get integer attribute + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param value [in] value of attribute + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_integer_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, int64_t value); + +/** + * @brief Get string attribute + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param str_length [in] number of characters (not including a null terminator) + * @param str [in] string; not necessarily null-terminated + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_string_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + uint32_t str_length, char const *str); //////////////////////////// -/////////// Value ////////// +//////// Scalar type /////// +//////////////////////////// + +//! Convert scalar type to string +TINYTC_EXPORT char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty); +//! Size of scalar type in bytes +TINYTC_EXPORT size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty); + +//////////////////////////// +///////// Data type //////// //////////////////////////// /** - * @brief Create value + * @brief Get boolean data type * - * @param vl [out] pointer to the value object created - * @param type [in] data type object - * @param loc [in][optional] Source code location; can be nullptr + * @param dt [out] pointer to the data type object created + * @param ctx [inout] compiler context * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_boolean_type_get(tinytc_data_type_t *dt, + tinytc_compiler_context_t ctx); /** - * @brief Create floating point immediate value + * @brief Get scalar data type * - * @param vl [out] pointer to the value object created - * @param imm [in] immediate value - * @param type [in] type of immediate value - * @param loc [in][optional] Source code location; can be nullptr + * @param dt [out] pointer to the data type object created + * @param ctx [inout] compiler context + * @param type [in] scalar type * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, - tinytc_scalar_type_t type, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, + tinytc_compiler_context_t ctx, + tinytc_scalar_type_t type); + /** - * @brief Create integer immediate value + * @brief Get memref data type + * + * Note: modifies compiler context * - * @param vl [out] pointer to the value object created - * @param imm [in] immediate value - * @param type [in] type of immediate value + * @param dt [out] pointer to the data type object created + * @param scalar_ty [in] element type + * @param shape_size [in] tensor order; number of elements in shape array, must be 0 if shape == + * nullptr + * @param shape [in][range(0, shape_size)] array of mode sizes + * @param stride_size [in][optional] number of elements in stride array; must be either 0 for + * automatic stride calculation or must match shape_size; must be 0 if stride == nullptr + * @param stride [in][optional][range(0, stride_size)] stride array + * @param addrspace [in][optional] Address space; default is tinytc_address_space_global * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_int_imm_create(tinytc_value_t *vl, int64_t imm, - tinytc_scalar_type_t type, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, + tinytc_data_type_t scalar_ty, + uint32_t shape_size, const int64_t *shape, + uint32_t stride_size, const int64_t *stride, + tinytc_address_space_t addrspace, + const tinytc_location_t *loc); /** - * @brief Release value object + * @brief Get group data type * - * Decreases reference count by 1, free memory if reference count is 0. + * Note: modifies compiler context * - * @param vl [inout] value object + * @param dt [out] pointer to the data type object created + * @param memref_ty [in] memref data type object + * @param size [in] group size; may be TINYTC_DYNAMIC + * @param offset [in][optional] offset parameter; pass 0 for default; may be TINYTC_DYNAMIC + * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_value_release(tinytc_value_t vl); - +TINYTC_EXPORT tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, + tinytc_data_type_t memref_ty, int64_t size, + int64_t offset, const tinytc_location_t *loc); /** - * @brief Increase reference count of value object by 1 + * @brief Get coopmatrix data type * - * @param vl [inout] value object + * Note: modifies compiler context + * + * @param dt [out] pointer to the data type object created + * @param scalar_ty [in] component type + * @param rows [in] number of rows + * @param cols [in] number of cols + * @param use [in] matrix use + * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_value_retain(tinytc_value_t vl); +TINYTC_EXPORT tinytc_status_t tinytc_coopmatrix_type_get(tinytc_data_type_t *dt, + tinytc_data_type_t scalar_ty, int64_t rows, + int64_t cols, tinytc_matrix_use_t use, + const tinytc_location_t *loc); +/** + * @brief Get void data type + * + * @param dt [out] pointer to the data type object created + * @param ctx [inout] compiler context + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_void_type_get(tinytc_data_type_t *dt, + tinytc_compiler_context_t ctx); + +//////////////////////////// +/////////// Value ////////// +//////////////////////////// /** * @brief Set name of value @@ -190,6 +299,18 @@ TINYTC_EXPORT tinytc_status_t tinytc_value_retain(tinytc_value_t vl); */ TINYTC_EXPORT tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char const *name); +/** + * @brief Set name of value with explicit number of characters + * + * @param vl [inout] value object + * @param name_length [in] number of characters + * @param name [in] name; not necessarily null-terminated + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_set_name_n(tinytc_value_t vl, uint32_t name_length, + char const *name); + /** * @brief Get name of value * @@ -203,44 +324,76 @@ TINYTC_EXPORT tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char cons */ TINYTC_EXPORT tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, char const **name); +/** + * @brief Get type of value + * + * @param vl [in] value object + * @param ty [out] pointer to data type + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_get_type(const_tinytc_value_t vl, + tinytc_data_type_t *ty); + //////////////////////////// /////// Instructions /////// //////////////////////////// +//! Convert address space to string +TINYTC_EXPORT char const *tinytc_address_space_to_string(tinytc_address_space_t as); //! Convert arithmetic operation type to string TINYTC_EXPORT char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op); //! Convert arithmetic operation type to string (unary) TINYTC_EXPORT char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op); +//! Convert builtin type to string +TINYTC_EXPORT char const *tinytc_builtin_to_string(tinytc_builtin_t b); +//! Convert checked flag to string +TINYTC_EXPORT char const *tinytc_checked_flag_to_string(tinytc_checked_flag_t flag); //! Convert cmp condition to string TINYTC_EXPORT char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond); +//! Convert subgroup arithmmetic to string +TINYTC_EXPORT char const *tinytc_group_arithmetic_to_string(tinytc_group_arithmetic_t op); +//! Convert subgroup op to string +TINYTC_EXPORT char const *tinytc_group_operation_to_string(tinytc_group_operation_t op); +//! Convert reduce mode to string +TINYTC_EXPORT char const *tinytc_reduce_mode_to_string(tinytc_reduce_mode_t m); +//! Convert math operation type to string (unary) +TINYTC_EXPORT char const *tinytc_math_unary_to_string(tinytc_math_unary_t op); +//! Convert matrix use to string +TINYTC_EXPORT char const *tinytc_matrix_use_to_string(tinytc_matrix_use_t u); +//! Convert store flag to string +TINYTC_EXPORT char const *tinytc_store_flag_to_string(tinytc_store_flag_t flag); //! Convert transpose to string TINYTC_EXPORT char const *tinytc_transpose_to_string(tinytc_transpose_t t); /** * @brief Create arithmetic instruction (binary) * - * @code %value = arith. %a, %b : type(%a) ; type(%a) == type(%b) @endcode + * @code %value = arith. %a, %b : ty ; ty == type(%a) and ty == type(%b) @endcode * * @param instr [out] pointer to the inst object created * @param op [in] arithmetic operation type * @param a [in] left-hand operand * @param b [in] right-hand operand + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_t op, tinytc_value_t a, tinytc_value_t b, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create arithmetic instruction (unary) * - * @code %value = arith. %a : type(%a) @endcode + * @code %value = arith. %a : ty @endcode * * @param instr [out] pointer to the inst object created * @param op [in] unary arithmetic operation type * @param a [in] operand + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise @@ -248,12 +401,42 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tin TINYTC_EXPORT tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_unary_t op, tinytc_value_t a, + tinytc_data_type_t ty, const tinytc_location_t *loc); +/** + * @brief Create barrier instruction + * + * @param instr [out] pointer to the inst object created + * @param fence_flags [in] address space(s) of memory fence; set to 0 for no fence + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_barrier_inst_create(tinytc_inst_t *instr, + tinytc_address_spaces_t fence_flags, + const tinytc_location_t *loc); +/** + * @brief Create builtin instruction + * + * @code %value = builtin. : %ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param btype [in] builtin type + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_builtin_inst_create(tinytc_inst_t *instr, + tinytc_builtin_t btype, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + /** * @brief Create cast instruction * - * @code %value = cast %a, %b : type(%a) -> %to_ty @endcode + * @code %value = cast %a, %b : %to_ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand @@ -263,31 +446,270 @@ TINYTC_EXPORT tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *inst * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - tinytc_scalar_type_t to_ty, + tinytc_data_type_t to_ty, const tinytc_location_t *loc); /** * @brief Create binary op instruction * - * @code %value = cmp. %a, %b : type(%a) ; type(%a) == type(%b) @endcode + * @code %value = cmp. %a, %b : ty ; type(%a) == type(%b) @endcode * * @param instr [out] pointer to the inst object created * @param cond [in] compare type * @param a [in] left-hand operand * @param b [in] right-hand operand + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_condition_t cond, tinytc_value_t a, - tinytc_value_t b, + tinytc_value_t b, tinytc_data_type_t ty, const tinytc_location_t *loc); +/** + * @brief Create boolean constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_boolean(tinytc_inst_t *instr, + tinytc_bool_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create complex constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value_re [in] constant value (real part) + * @param value_im [in] constant value (imaginary part) + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_complex(tinytc_inst_t *instr, + double value_re, double value_im, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create floating constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_float(tinytc_inst_t *instr, double value, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create integer constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_int(tinytc_inst_t *instr, int64_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Creates the multiplicative identity constant (i.e. "1") for the given data type + * + * @param instr [out] pointer to the inst object created + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Creates the additive identity constant (i.e. "0") for the given data type + * + * @param instr [out] pointer to the inst object created + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix apply instruction + * + * @code %value = cooperative_matrix_apply (%row,%column,%value)=%mat -> ty {} + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param mat [in] %mat + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_apply_inst_create( + tinytc_inst_t *instr, tinytc_value_t mat, tinytc_data_type_t ty, const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix extract instruction + * + * @code %value = cooperative_matrix_extract %mat[index] : ty + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param mat [in] %mat + * @param index [in] index + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_extract_inst_create( + tinytc_inst_t *instr, tinytc_value_t mat, int64_t index, tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix insert instruction + * + * @code %value = cooperative_matrix_insert %val, %mat[index] : ty + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param val [in] %val + * @param mat [in] %mat + * @param index [in] index + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_insert_inst_create( + tinytc_inst_t *instr, tinytc_value_t val, tinytc_value_t mat, int64_t index, + tinytc_data_type_t ty, const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix load instruction + * + * @code %value = cooperative_matrix_load.transpose.checked %op[%p0, %p1] : to_ty + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param transpose [in] transpose operation applied on load + * @param flag [in] out-of-bounds checks type + * @param op [in] %op + * @param p0 [in] %p0 + * @param p1 [in] %p1 + * @param to_ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_load_inst_create( + tinytc_inst_t *instr, tinytc_transpose_t transpose, tinytc_checked_flag_t flag, + tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, + const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix mul add instruction + * + * @code cooperative_matrix_mul_add %a, %b, %c : to_ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param a [in] %a + * @param b [in] %b + * @param c [in] %c + * @param to_ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_mul_add_inst_create( + tinytc_inst_t *instr, tinytc_value_t a, tinytc_value_t b, tinytc_value_t c, + tinytc_data_type_t to_ty, const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix prefetch instruction + * + * @code cooperative_matrix_store cache_level, %op[%p0, %p1], rows, cols + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param cache_level [in] Cache-level; "0" is closest to the core + * @param op [in] %op + * @param p0 [in] %p0 + * @param p1 [in] %p1 + * @param rows [in] Number of rows + * @param cols [in] Number of cols + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_prefetch_inst_create( + tinytc_inst_t *instr, int32_t cache_level, tinytc_value_t op, tinytc_value_t p0, + tinytc_value_t p1, int32_t rows, int32_t cols, const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix scale instruction + * + * @code cooperative_matrix_scale %a, %b : ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param a [in] %a + * @param b [in] %b + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_scale_inst_create( + tinytc_inst_t *instr, tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create cooperative matrix store instruction + * + * @code cooperative_matrix_store.checked.store_flag %val, %op[%p0, %p1] + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param cflag [in] out-of-bounds checks type + * @param sflag [in] store flag + * @param val [in] %val + * @param op [in] %op + * @param p0 [in] %p0 + * @param p1 [in] %p1 + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cooperative_matrix_store_inst_create( + tinytc_inst_t *instr, tinytc_checked_flag_t cflag, tinytc_store_flag_t sflag, + tinytc_value_t val, tinytc_value_t op, tinytc_value_t p0, tinytc_value_t p1, + const tinytc_location_t *loc); + /** * @brief Create alloca instruction * - * @code %value = alloca -> %ty @endcode + * @code %value = alloca : %ty @endcode * * @param instr [out] pointer to the inst object created * @param ty [in] type that is allocated @@ -302,7 +724,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, ti * @brief Create axpby instruction * * @code - * axpby.. %alpha, %A, %beta, %B : type(%alpha), type(%A), type(%beta), type(%B) + * axpby.. %alpha, %A, %beta, %B * @endcode * * @param instr [out] pointer to the inst object created @@ -322,92 +744,99 @@ TINYTC_EXPORT tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tin tinytc_value_t B, const tinytc_location_t *loc); +/** + * @brief Create cumsum instruction + * + * @code + * cumsum. %alpha, %A, mode, %beta, %B + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param atomic [in] true for atomic updates of B + * @param alpha [in] @f$\alpha@f$ + * @param A [in] A + * @param mode [in] n (summation mode) + * @param beta [in] @f$\beta@f$ + * @param B [in] B + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_cumsum_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomic, + tinytc_value_t alpha, tinytc_value_t A, + int64_t mode, tinytc_value_t beta, + tinytc_value_t B, + const tinytc_location_t *loc); + /** * @brief Create expand instruction * - * @code %value = expand %a[%mode -> %expand_shape] : type(%a) @endcode + * @code %value = expand %a[%mode -> %expand_shape] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand - * @param mode [in] expanded mode - * @param expand_shape_size [in] dimension of expand shape; must be at least 2 - * @param expand_shape [in][range(2, expand_shape_size)] expand shape array + * @param expanded_mode [in] expanded mode + * @param static_expand_shape_size [in] dimension of static expand shape; must be at least 2 + * @param static_expand_shape [in][range(2, static expand_shape_size)] static expand shape array + * @param expand_shape_size [in][optional] dimension of expand shape; must match number of entries + * equal to TINYTC_DYNAMIC in static_expand_shape array; can be 0 + * @param expand_shape [in][optional][range(0, expand_shape_size)] expand shape array + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - int64_t mode, uint32_t expand_shape_size, - tinytc_value_t *expand_shape, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t +tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t expanded_mode, + uint32_t static_expand_shape_size, const int64_t *static_expand_shape, + uint32_t expand_shape_size, const tinytc_value_t *expand_shape, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create fuse instruction * - * @code %value = fuse %a[%from, %to] : type(%a) @endcode + * @code %value = fuse %a[%from, %to] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand * @param from [in] first mode to fuse * @param to [in] last mode to fuse + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t from, int64_t to, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create load instruction * - * @code %value = load %a[%index_list] : type(%a) @endcode + * @code %value = load [flag] %a[%index_list] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand * @param index_list_size [in] number of indices * @param index_list [in][range(0, index_list_size)] indices array; may be nullptr if * index_list_size is 0 + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t index_list_size, - tinytc_value_t *index_list, + const tinytc_value_t *index_list, + tinytc_data_type_t ty, const tinytc_location_t *loc); -/** - * @brief Create group_id instruction - * - * @code %value = group_id @endcode - * - * @param instr [out] pointer to the inst object created - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, - const tinytc_location_t *loc); - -/** - * @brief Create group_size instruction - * - * @code %value = group_size @endcode - * - * @param instr [out] pointer to the inst object created - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, - const tinytc_location_t *loc); /** * @brief Create GEMM instruction * * @code * gemm... %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) * @endcode * * @param instr [out] pointer to the inst object created @@ -435,7 +864,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tiny * * @code * gemv.. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) * @endcode * * @param instr [out] pointer to the inst object created @@ -461,7 +889,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_gemv_inst_create(tinytc_inst_t *instr, tiny * * @code * ger. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) * @endcode * * @param instr [out] pointer to the inst object created @@ -475,72 +902,148 @@ TINYTC_EXPORT tinytc_status_t tinytc_gemv_inst_create(tinytc_inst_t *instr, tiny * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_ger_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomic, - tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, tinytc_value_t beta, - tinytc_value_t C, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_ger_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomic, + tinytc_value_t alpha, tinytc_value_t A, + tinytc_value_t B, tinytc_value_t beta, + tinytc_value_t C, + const tinytc_location_t *loc); + +/** + * @brief Create Hadamard instruction + * + * @code + * hadamard. %alpha, %A, %B, %beta, %C + * @endcode + * + * @param instr [out] pointer to the inst object created + * @param atomic [in] true for atomic updates of C + * @param alpha [in] @f$\alpha@f$ + * @param A [in] A + * @param B [in] B + * @param beta [in] @f$\beta@f$ + * @param C [in] C + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_hadamard_inst_create( + tinytc_inst_t *instr, tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, + tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, const tinytc_location_t *loc); + +/** + * @brief Create math instruction (unary) + * + * @code %value = math. %a : ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param op [in] unary math operation type + * @param a [in] operand + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_math_unary_inst_create(tinytc_inst_t *instr, + tinytc_math_unary_t op, + tinytc_value_t a, tinytc_data_type_t ty, + const tinytc_location_t *loc); /** - * @brief Create Hadamard instruction + * @brief Create parallel region * * @code - * hadamard. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) + * parallel { } * @endcode * * @param instr [out] pointer to the inst object created - * @param atomic [in] true for atomic updates of C - * @param alpha [in] @f$\alpha@f$ - * @param A [in] A - * @param B [in] B - * @param beta [in] @f$\beta@f$ - * @param C [in] C * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_hadamard_inst_create( - tinytc_inst_t *instr, tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, + const tinytc_location_t *loc); /** * @brief Create size instruction * - * @code %value = size %a[%mode] : type(%a) @endcode + * @code %value = size %a[%mode] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand * @param mode [in] mode for that the size is queried + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - int64_t mode, const tinytc_location_t *loc); + int64_t mode, tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create subgroup broadcast instruction + * + * @code %value = subgroup_broadcast %a, %idx : ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param a [in] operand + * @param idx [in] subgroup local index + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_subgroup_broadcast_inst_create(tinytc_inst_t *instr, + tinytc_value_t a, + tinytc_value_t idx, + tinytc_data_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create subgroup operation instruction + * + * @code %value = subgroup_operation.group_arithmetic.group_operation %a : ty @endcode + * + * @param instr [out] pointer to the inst object created + * @param arith [in] group arithmetic type + * @param operation [in] group operation type + * @param a [in] operand + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_subgroup_operation_inst_create( + tinytc_inst_t *instr, tinytc_group_arithmetic_t arith, tinytc_group_operation_t operation, + tinytc_value_t a, tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create subview instruction * - * @code %value = subview %a[%offset1:%size1,...,%offsetN:%sizeN] : type(%a) @endcode + * @code %value = subview %a[%offset1:%size1,...,%offsetN:%sizeN] : ty @endcode * * @param instr [out] pointer to the inst object created * @param a [in] operand - * @param slice_list_size [in] number of slices - * @param offset_list [in][range(0, slice_list_size)] offset array; may be nullptr if - * slice_list_size is 0 - * @param size_list [in][range(0, slice_list_size)] size array; may be nullptr if slice_list_size - * is 0; size_list[i] may be nullptr if a single offset shall be passed instead of a range for the - * i-th mode + * @param static_list_size [in] number of slices + * @param static_offset_list [in][range(0, static_list_size)] offsets (need to add value to + * offset_list if static_offset_list[i] == TINYTC_DYNAMIC); may be nullptr if static_offset_list = 0 + * @param static_size_list [in][range(0, static_list_size)] sizes (need to add value to size_list + * if static_size_list[i] == TINYTC_DYNAMIC); may be nullptr if static_offset_list = 0 + * @param offset_list_size [in] number of dynamic offsets + * @param offset_list [in][range(0, offset_list_size)] offset array; may be nullptr if + * offset_list_size is 0 + * @param size_list_size [in] number of dynamic sizes + * @param size_list [in][range(0, size_list_size)] size array; may be nullptr if size_list_size is 0 + * @param ty [in] result type * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t slice_list_size, - tinytc_value_t *offset_list, - tinytc_value_t *size_list, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create( + tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, + const int64_t *static_offset_list, const int64_t *static_size_list, uint32_t offset_list_size, + const tinytc_value_t *offset_list, uint32_t size_list_size, const tinytc_value_t *size_list, + tinytc_data_type_t ty, const tinytc_location_t *loc); /** * @brief Create store instruction @@ -548,6 +1051,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, t * @code store %val, %a[%index_list] : type(%a) @endcode * * @param instr [out] pointer to the inst object created + * @param flag [in] store flag * @param val [in] value to store * @param a [in] operand * @param index_list_size [in] number of indices @@ -557,16 +1061,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, t * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, +TINYTC_EXPORT tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, + tinytc_store_flag_t flag, tinytc_value_t val, tinytc_value_t a, uint32_t index_list_size, - tinytc_value_t *index_list, + const tinytc_value_t *index_list, const tinytc_location_t *loc); /** * @brief Create sum instruction * * @code - * sum.. %alpha, %A, %beta, %B : type(%alpha), type(%A), type(%beta), type(%B) + * sum.. %alpha, %A, %beta, %B * @endcode * * @param instr [out] pointer to the inst object created @@ -590,62 +1095,63 @@ TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinyt * @brief Create for loop * * @code - * for %loop_var = %from, %to, %step : type(%loop_var) { %body } - * ; type(%loop_var) == type(%from) - * ; type(%loop_var) == type(%to) - * ; type(%loop_var) == type(%step) + * for %loop_var : loop_var_type = %from, %to, %step + * init(initial_value_list) -> (types(initial_value_list)) { } + * ; loop_var_type == type(%from) + * ; loop_var_type == type(%to) + * ; loop_var_type == type(%step) * @endcode * * @param instr [out] pointer to the inst object created - * @param loop_var [in] loop variable + * @param loop_var_type [in] type of loop variable * @param from [in] loop begion * @param to [in] loop bound * @param step [in][optional] loop step; can be nullptr - * @param body [in] loop body + * @param init_return_list_size [in] length of init_value_list and return_type_list + * @param initial_value_list [in][range(0, init_return_list_size)] array of initial values; can be + * nullptr if init_return_list_size is 0 + * @param return_type_list [in][range(0, init_return_list_size)] array of return types; can be + * nullptr if init_return_list_size is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_value_t step, tinytc_region_t body, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create( + tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, tinytc_value_t from, tinytc_value_t to, + tinytc_value_t step, uint32_t init_return_list_size, const tinytc_value_t *initial_value_list, + const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc); /** * @brief Create foreach loop * * @code - * foreach %loop_var = %from, %to : type(%loop_var) { %body } - * ; type(%loop_var) == type(%from) - * ; type(%loop_var) == type(%to) + * foreach (loop_var_list) : loop_var_type = (from_list), (to_list) { } + * ; loop_var_type == type(%f) forall %f in from_list + * ; loop_var_type == type(%t) forall %t in to_list * @endcode * * @param instr [out] pointer to the inst object created - * @param loop_var [in] loop variable - * @param from [in] loop begion - * @param to [in] loop bound - * @param body [in] loop body + * @param loop_var_type [in] type of loop variable + * @param dim [in] length of from and to array; must be > 0 + * @param from_list [in][range(1, dim)] loop begion + * @param to_list [in][range(1, dim)] loop bound * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, - tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_region_t body, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create( + tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, uint32_t dim, + const tinytc_value_t *from_list, const tinytc_value_t *to_list, const tinytc_location_t *loc); /** * @brief Create if condition * * @code - * if %condition { %then } else { %otherwise } + * if %condition -> (return_type_list, ...) { } else { } * @endcode * * @param instr [out] pointer to the inst object created * @param condition [in] condition - * @param then [in] region taken if condition is true - * @param otherwise [in][optional] region taken if condition is false; can be nullptr * @param return_type_list_size [in] length of return type array * @param return_type_list [in][range(0, return_type_list_size)] return type array; can be nullptr * if return_type_list_size is 0 @@ -654,9 +1160,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, - tinytc_region_t then, tinytc_region_t otherwise, uint32_t return_type_list_size, - tinytc_scalar_type_t *return_type_list, + const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc); /** @@ -667,47 +1172,54 @@ TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc * @endcode * * @param instr [out] pointer to the inst object created - * @param yield_list_size [in] length of yielded values list; must be at least 1 - * @param yield_list [in][range(1, yield_list_size)] yielded values array + * @param yield_list_size [in] length of yielded values list + * @param yield_list [in][range(0, yield_list_size)] yielded values array; can be nullptr if + * yield_list_size is 0 * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_list_size, - tinytc_value_t *yield_list, + const tinytc_value_t *yield_list, const tinytc_location_t *loc); /** - * @brief Release inst object - * - * Decreases reference count by 1, free memory if reference count is 0. + * @brief Delete inst object * * @param instr [inout] inst object - * - * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_release(tinytc_inst_t instr); +TINYTC_EXPORT void tinytc_inst_destroy(tinytc_inst_t instr); /** - * @brief Increase reference count of inst object by 1 + * @brief Get parent region of instruction * - * @param instr [inout] inst object + * @param instr [in] inst object + * @param parent [out] parent region * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_retain(tinytc_inst_t instr); +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_parent_region(tinytc_inst_t instr, + tinytc_region_t *parent); /** - * @brief Get value produced by instruction + * @brief Get child regions of instruction + * + * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain + * the number of results * * @param instr [in] inst object - * @param result [out] result value; may be set to nullptr if instruction does not return a value + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greater than the number of + * results, the value is updated with the correct number of results + * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result + * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, - tinytc_value_t *result); +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, + uint32_t *result_list_size, + tinytc_region_t *result_list); /** * @brief Get values produced by instruction @@ -716,130 +1228,203 @@ TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, * the number of results * * @param instr [in] inst object - * @param result_list_size [inout] number of results to fetch; is updated with the actual value + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greater than the number of + * results, the value is updated with the correct number of results * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_list_size, tinytc_value_t *result_list); +/** + * @brief Set instruction attributes + * + * @param instr [inout] inst object + * @param a [in] attribute object (dictionary attribute) + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_inst_set_attr(tinytc_inst_t instr, tinytc_attr_t a); + //////////////////////////// ////////// Region ////////// //////////////////////////// /** - * @brief Create region + * @brief Append instruction to region * - * @param reg [out] pointer to the region object created - * @param instruction_list_size [in] length of instruction array - * @param instruction_list [in][range(0, instruction_list_size)] instruction array; can be nullptr - * if instruction_list_size is 0 - * @param loc [in][optional] Source code location; can be nullptr + * The region takes ownership of the instruction. + * An instruction must not be added to multiple regions. + * + * @param reg [inout] region object + * @param instr [in,pass_ownership] instruction * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_region_create(tinytc_region_t *reg, - uint32_t instruction_list_size, - tinytc_inst_t *instruction_list, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_region_append(tinytc_region_t reg, tinytc_inst_t instr); + /** - * @brief Release region object + * @brief Returns iterator pointing to begin of the region * - * Decreases reference count by 1, free memory if reference count is 0. + * @param reg [in] region + * @param iterator [out] inst iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_begin(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator); + +/** + * @brief Returns iterator pointing to the end of the region + * + * The end iterator must not be dereferenced. + * + * @param reg [in] region] + * @param iterator [out] inst iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_end(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator); + +/** + * @brief Erase instruction at the position of the iterator + * + * The iterator is updated to point to the instruction coming after the iterator or the end iterator * * @param reg [inout] region object + * @param iterator [inout] iterator * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_region_release(tinytc_region_t reg); +TINYTC_EXPORT tinytc_status_t tinytc_region_erase(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator); /** - * @brief Increase reference count of region object by 1 + * @brief Insert instruction at the position before the iterator + * + * The iterator is updated to point to the instruction that was just inserted. + * + * The region takes ownership of the instruction. + * An instruction must not be inserted into multiple regions. * * @param reg [inout] region object + * @param iterator [inout] + * @param instr [in,pass_ownership] * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_region_retain(tinytc_region_t reg); +TINYTC_EXPORT tinytc_status_t tinytc_region_insert(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator, + tinytc_inst_t instr); -//////////////////////////// -/////////// Func /////////// -//////////////////////////// +/** + * @brief Move iterator to the next instruction + * + * @param iterator [inout] iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_next_inst(tinytc_inst_iterator_t *iterator); /** - * @brief Create function prototype + * @brief Move iterator to the previous instruction * - * @param fun [out] pointer to the func object created - * @param name [in] function name - * @param arg_list_size [in] length of argument array - * @param arg_list [in][range(0, arg_list_size)] argument array; can be nullptr if arg_list_size is - * 0 - * @param loc [in][optional] Source code location; can be nullptr + * @param iterator [inout] iterator * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_prototype_create(tinytc_func_t *fun, char const *name, - uint32_t arg_list_size, - tinytc_value_t *arg_list, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_prev_inst(tinytc_inst_iterator_t *iterator); + +/** + * @brief Get region parameters + * + * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain + * the number of results + * + * @param reg [in] region object + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greather than the number of + * results, the value is updated with the correct number of results + * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result + * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, + uint32_t *result_list_size, + tinytc_value_t *result_list); + +//////////////////////////// +/////////// Func /////////// +//////////////////////////// /** * @brief Create function * + * Function takes ownership of region. + * * @param fun [out] pointer to the func object created - * @param prototype [in] function prototype - * @param body [in] function body + * @param name_length [in] length of function_name + * @param name [in] function name + * @param num_params [in] number of parameters + * @param param_type_list [in][range(0,num_params)] parameter data types; can be nullptr if + * num_params is 0 + * @param ty [in] result type (must be void for host-callable function) * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_create(tinytc_func_t *fun, tinytc_func_t prototype, - tinytc_region_t body, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_func_create(tinytc_func_t *fun, uint32_t name_length, + char const *name, uint32_t num_params, + const tinytc_data_type_t *param_type_list, + tinytc_data_type_t ty, + const tinytc_location_t *loc); /** - * @brief Set work-group size + * @brief Set function attributes * - * @param fun [out] func object (must be the function definition, not the function prototype) - * @param x [in] number of rows in parallel grid; must be a multiple of the subgroup size - * @param y [in] number of columns in parallel grid + * @param fun [inout] function object + * @param a [in] attribute dictionary * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, - int32_t y); +TINYTC_EXPORT tinytc_status_t tinytc_func_set_attr(tinytc_func_t fun, tinytc_attr_t a); + /** - * @brief Set subgroup size + * @brief Set parameter attributes * - * @param fun [out] func object (must be the function definition, not the function prototype) - * @param sgs [in] subgroup size; the supported values need to be queried from the compute device + * @param fun [in] function object + * @param param_no [in] paramater number (0 to num_parameters-1) + * @param a [in] attribute dictionary * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs); +TINYTC_EXPORT tinytc_status_t tinytc_func_set_parameter_attr(tinytc_func_t fun, int32_t param_no, + tinytc_attr_t a); /** - * @brief Release function object + * @brief Get function body * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param fun [inout] function object + * @param fun [in] function object + * @param body [out] pointer to body region * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_func_release(tinytc_func_t fun); +TINYTC_EXPORT tinytc_status_t tinytc_func_get_body(tinytc_func_t fun, tinytc_region_t *body); /** - * @brief Increase reference count of function object by 1 + * @brief Delete function object * * @param fun [inout] function object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_func_retain(tinytc_func_t fun); +TINYTC_EXPORT void tinytc_func_destroy(tinytc_func_t fun); //////////////////////////// /////////// Prog /////////// @@ -849,15 +1434,39 @@ TINYTC_EXPORT tinytc_status_t tinytc_func_retain(tinytc_func_t fun); * @brief Create program * * @param prg [out] pointer to the prog object created - * @param fun_list_size [in] length of func array - * @param fun_list [in][range(0, fun_list_size)] func array; can be nullptr if fun_list_size is 0 + * @param ctx [in] compiler context object * @param loc [in][optional] Source code location; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, uint32_t fun_list_size, - tinytc_func_t *fun_list, - const tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_prog_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc); + +/** + * @brief Append function to program + * + * The program takes ownership of the function. + * A function must not be added to multiple programs nor must the user destroy the function after + * adding it to the program. + * + * @param prg [inout] program object + * @param fun [in,pass_ownership] function object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun); + +/** + * @brief Get context object from program object + * + * @param prg [in] program object + * @param ctx [out] pointer to context object; reference count is increased so the user needs to + * call tinytc_compiler_context_release to clean up + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, + tinytc_compiler_context_t *ctx); /** * @brief Release program object @@ -879,6 +1488,30 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_release(tinytc_prog_t prg); */ TINYTC_EXPORT tinytc_status_t tinytc_prog_retain(tinytc_prog_t prg); +//////////////////////////// +/////// SPIR-V Module ////// +//////////////////////////// + +/** + * @brief Release SPIR-V module + * + * Decreases reference count by 1, free memory if reference count is 0. + * + * @param mod [inout] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_release(tinytc_spv_mod_t mod); + +/** + * @brief Increase reference count of SPIR-V module by 1 + * + * @param mod [inout] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_retain(tinytc_spv_mod_t mod); + //////////////////////////// // Visitors and transforms / //////////////////////////// @@ -915,6 +1548,39 @@ TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, */ TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_string(const_tinytc_prog_t prg, char **str); +/** + * @brief Dump SPIR-V module to stderr + * + * @param mod [in] module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_dump(const_tinytc_spv_mod_t mod); + +/** + * @brief Print SPIR-V module to file + * + * @param mod [in] module + * @param filename [in] filename + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_print_to_file(const_tinytc_spv_mod_t mod, + char const *filename); + +/** + * @brief Print SPIR-V module to string + * + * The user is responsible to dispose the string with tinytc_string_destroy. + * + * @param mod [in] module + * @param str [out] pointer to string + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_print_to_string(const_tinytc_spv_mod_t mod, + char **str); + /** * @brief Delete a (non-const) string returned from tinytc API * @@ -954,6 +1620,17 @@ TINYTC_EXPORT tinytc_status_t tinytc_core_info_generic_create(tinytc_core_info_t TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create_from_arch( tinytc_core_info_t *info, tinytc_intel_gpu_architecture_t arch); +/** + * @brief Look up core info for Intel GPU architecture + * + * @param info [out] pointer to the core_info object created + * @param name [in] architecture name + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create_from_name(tinytc_core_info_t *info, + char const *name); + /** * @brief Create core_info for Intel GPUs * @@ -1000,7 +1677,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_register_space(const_tinytc_c /** * @brief Set core features * - * @param info [in] core info object + * @param info [inout] core info object * @param flags [in] set core features; must be 0 or a combination of tinytc_core_feature_flag_t * * @return tinytc_status_success on success and error otherwise @@ -1016,8 +1693,56 @@ TINYTC_EXPORT tinytc_status_t tinytc_core_info_set_core_features(tinytc_core_inf * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t -tinytc_core_info_get_core_features(tinytc_core_info_t info, tinytc_core_feature_flags_t *flags); +TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_core_features( + const_tinytc_core_info_t info, tinytc_core_feature_flags_t *flags); + +/** + * @brief Set SPIR-V feature + * + * @param info [inout] core info object + * @param feature [in] SPIR-V feature + * @param available [in] Set to true if feature is available and false otherwise + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_set_spirv_feature(tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t available); + +/** + * @brief Get SPIR-V feature + * + * @param info [in] core info object + * @param feature [in] SPIR-V feature + * @param available [out] Writes true to available if feature is available and false otherwise + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_have_spirv_feature(const_tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t *available); + +/** + * @brief Get default memref alignment + * + * @param info [in] Core info + * @param alignment [out] pointer to alignment in bytes + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_default_alignment(const_tinytc_core_info_t info, + int32_t *alignment); + +/** + * @brief Set default memref alignment + * + * @param info [inout] Core info + * @param alignment [in] alignment in bytes + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_set_default_alignment(tinytc_core_info_t info, + int32_t alignment); /** * @brief Release core info object @@ -1039,6 +1764,9 @@ TINYTC_EXPORT tinytc_status_t tinytc_core_info_release(tinytc_core_info_t obj); */ TINYTC_EXPORT tinytc_status_t tinytc_core_info_retain(tinytc_core_info_t obj); +//! Convert SPIR-V feature to string +TINYTC_EXPORT char const *tinytc_spirv_feature_to_string(tinytc_spirv_feature_t f); + //////////////////////////// ////////// Parser ////////// //////////////////////////// @@ -1048,22 +1776,22 @@ TINYTC_EXPORT tinytc_status_t tinytc_core_info_retain(tinytc_core_info_t obj); * * @param prg [out] pointer to prog object created * @param filename [in] path to source file - * @param ctx [inout][optional] source context object; stores error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, - tinytc_source_context_t ctx); + tinytc_compiler_context_t ctx); /** * @brief Parser tensor language source from stdin and create prog * * @param prg [out] pointer to prog object created - * @param ctx [inout][optional] source context object; stores error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t ctx); +TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_compiler_context_t ctx); /** * @brief Parser tensor language source from string and create prog @@ -1071,176 +1799,208 @@ TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_sour * @param prg [out] pointer to prog object created * @param source_size [in] length of source string * @param source [in] source string - * @param ctx [inout][optional] source context object; stores error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_parse_string(tinytc_prog_t *prg, size_t source_size, - char const *source, tinytc_source_context_t ctx); + char const *source, + tinytc_compiler_context_t ctx); /** - * @brief Create source context + * @brief Create context * - * The source context stores the tensor language source and enhaces error messages with - * source code context. + * The context stores the tensor language source and reports enhaces error messages with + * source code context. Moreover, the context caches data such as types and constants. * - * @param ctx [out] pointer to the source context object created + * @param ctx [out] pointer to the context object created * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_create(tinytc_source_context_t *ctx); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_create(tinytc_compiler_context_t *ctx); /** * @brief Add source context * - * Manually add a source file to the source context that can be referenced in a tinytc_location. + * Manually add a source file to the context that can be referenced in a tinytc_location. * Useful to enhance error messages when using the builder methods and classes. * - * @param ctx [in] source context object + * @param ctx [in] context object * @param name [in] source name * @param text [in] source text * @param source_id [out] pointer to source id * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_add_source(tinytc_source_context_t ctx, - char const *name, char const *text, - int32_t *source_id); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler_context_t ctx, + char const *name, char const *text, + int32_t *source_id); + +/** + * @brief Set error reporter + * + * Error reporting function that is called whenever an error occurs in the parser or the builder. + * + * @param ctx [inout] context object + * @param reporter [in] error reporting callback; set to nullptr to disable reporting + * @param user_data [in][optional] pointer to user data that is passed to the callback; can be + * nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_error_reporter( + tinytc_compiler_context_t ctx, tinytc_error_reporter_t reporter, void *user_data); /** - * @brief Get error log + * @brief Sets an optimization flag + * + * The state can be 0 (disabled), 1 (enabled), or -1 (use default according to optimization level). + * + * @param ctx [inout] context object + * @param flag [in] optimization flag + * @param state [in] flag state * - * The string's memory is owned by source context. - * Note that the pointer may invalidated by any function call involving the source context object, - * so the string should be copied or printed right after a call to this function. + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_optimization_flag( + tinytc_compiler_context_t ctx, tinytc_optflag_t flag, int32_t state); + +/** + * @brief Set optimization level (from 0 to 2) * - * @param ctx [in] source context object - * @param log [out] pointer to string + * @param ctx [inout] context object + * @param level [in] optimization level * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_get_error_log(const_tinytc_source_context_t ctx, - char const **log); +TINYTC_EXPORT tinytc_status_t +tinytc_compiler_context_set_optimization_level(tinytc_compiler_context_t ctx, int32_t level); /** * @brief Report an error and augment the error with source context * - * @param ctx [in] source context object + * @param ctx [in] context object * @param location [in] source location * @param what [in] error description - * @param append [in] true: append to error log, false: clear error log * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_report_error(tinytc_source_context_t ctx, - const tinytc_location_t *location, - char const *what, - tinytc_bool_t append); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_report_error( + tinytc_compiler_context_t ctx, const tinytc_location_t *location, char const *what); /** - * @brief Release source context object + * @brief Release context object * * Decreases reference count by 1, free memory if reference count is 0. * - * @param obj [inout] source context object + * @param obj [inout] context object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_release(tinytc_source_context_t obj); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_release(tinytc_compiler_context_t obj); /** - * @brief Increase reference count of source context object by 1 + * @brief Increase reference count of context object by 1 * - * @param obj [inout] source context object + * @param obj [inout] context object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_retain(tinytc_source_context_t obj); +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_retain(tinytc_compiler_context_t obj); //////////////////////////// ///////// Compiler ///////// //////////////////////////// /** - * @brief Compile tensor language to OpenCL-C + * @brief Run a function pass on every function of a program * - * @param src [out] pointer to the source object created - * @param prg [inout] tensor program; modified as compiler passes are run - * @param info [in] core info object - * @param ctx [inout][optional] source context object to save extended error messages that are - * enhanced with source code context; can be nullptr + * @param pass_name [in] name of function pass; cf. tinytc_list_function_passes + * @param prg [inout] tensor program; modified as compiler pass is run + * @param info [in][optional] core info object; might be nullptr if core info is not required for + * pass * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx); +TINYTC_EXPORT tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, + const_tinytc_core_info_t info); /** - * @brief Get source text + * @brief List function passes * - * @param src [in] source object - * @param length [out] pointer to code length - * @param code [out] code contains a pointer to the source text; the pointer is only valid as long - * as the source object is alive + * @param names_size [out] pointer to number of function pass names + * @param names [out][range(0,names_size)] pointer to array of C-strings; array owned by tinytc * - * @return tinytc_status_success on success and error otherwise + * @return */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, - char const **code); +TINYTC_EXPORT tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, + char const *const **names); /** - * @brief Get source location + * @brief Compile tensor language to SPIR-V * - * @param src [in] source object - * @param loc [out] pointer to location + * @param mod [out] pointer to the SPIR-V module created + * @param prg [inout] tensor program; modified as compiler passes are run + * @param info [in] core info object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_location(const_tinytc_source_t src, - tinytc_location_t *loc); +TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_spv_mod_t *mod, tinytc_prog_t prg, + const_tinytc_core_info_t info); /** - * @brief Get core features + * @brief Compiler tensor language to SPIR-V and assemble * - * @param src [in] source object - * @param core_features [out] pointer to core features + * @param bin [out] pointer to the binary object created + * @param prg [inout] tensor program; modified as compiler passes are run + * @param info [in] core info object * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_core_features( - const_tinytc_source_t src, tinytc_core_feature_flags_t *core_features); +TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv_and_assemble( + tinytc_binary_t *bin, tinytc_prog_t prg, const_tinytc_core_info_t info); /** - * @brief Get required OpenCL extensions + * @brief Assemble SPIR-V module * - * @param src [in] source object - * @param extensions_size [out] pointer to number of extensions - * @param extensions [out][range(0,extensions_size)] pointer to array of C-strings; array owned by - * source object + * @param bin [out] pointer to the binary object created + * @param mod [in] SPIR-V module * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t src, - uint32_t *extensions_size, - char const *const **extensions); +TINYTC_EXPORT tinytc_status_t tinytc_spirv_assemble(tinytc_binary_t *bin, + const_tinytc_spv_mod_t mod); /** * @brief Create binary * * @param bin [out] pointer to binary object + * @param ctx [in] compiler context * @param format [in] Bundle format (SPIR-V or Native) * @param data_size [in] Size of data in bytes * @param data [in][range(0, data_size)] Binary data; data is copied - * @param core_features [in][optional] requested core features; must be 0 (default) or a combination - * of tinytc_core_feature_flag_t + * @param core_features [in][optional] requested core features; must be 0 (default) or a + * combination of tinytc_core_feature_flag_t * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, + tinytc_compiler_context_t ctx, tinytc_bundle_format_t format, size_t data_size, uint8_t const *data, tinytc_core_feature_flags_t core_features); +/** + * @brief Get context object from binary object + * + * @param bin [in] binary object + * @param ctx [out] pointer to context object; reference count is increased so the user needs to + * call tinytc_compiler_context_release to clean up + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_get_compiler_context(const_tinytc_binary_t bin, + tinytc_compiler_context_t *ctx); + /** * @brief Get raw binary data * @@ -1265,26 +2025,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_binary_get_raw(const_tinytc_binary_t bin, TINYTC_EXPORT tinytc_status_t tinytc_binary_get_core_features( const_tinytc_binary_t bin, tinytc_core_feature_flags_t *core_features); -/** - * @brief Release source object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param obj [inout] source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_release(tinytc_source_t obj); - -/** - * @brief Increase reference count of source object by 1 - * - * @param obj [inout] source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_retain(tinytc_source_t obj); - /** * @brief Release binary object * @@ -1312,8 +2052,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_binary_retain(tinytc_binary_t bin); /** * @brief Returns a small batched GEMM recipe * - * The program contains a kernel for @f$\beta=0@f$ called "gemm_beta0" and a kernel for @f$\beta\neq - * 0@f$ called "gemm". All matrix shapes and strides are known at compile-time. + * The program contains a kernel for @f$\beta=0@f$ called "gemm_beta0" and a kernel for + * @f$\beta\neq 0@f$ called "gemm". All matrix shapes and strides are known at compile-time. * * The signature of the generated kernels gemm and gemm_beta0 is (if A and B are not transposed) * @@ -1345,7 +2085,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_binary_retain(tinytc_binary_t bin); * @param strideB [in] Number of elements between B-matrices * @param ldC [in] Leading dimension of C * @param strideC [in] Number of elements between C-matrices - * @param ctx [inout][optional] source context object; saves error log; can be nullptr + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ @@ -1353,7 +2093,7 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_create( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, tinytc_transpose_t tA, tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, int64_t ldA, int64_t strideA, int64_t ldB, int64_t strideB, int64_t ldC, int64_t strideC, - tinytc_source_context_t ctx); + tinytc_compiler_context_t ctx); /** * @brief Set kernel arguments for small GEMM batched recipe @@ -1381,8 +2121,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( /** * @brief Returns a tall and skinny recipe * - * The program contains a kernel for beta = 0 called "gemm_beta0" and a kernel for beta != 0 called - * "gemm". M (= number of rows of A, C) and strides are dynamic. + * The program contains a kernel for beta = 0 called "gemm_beta0" and a kernel for beta != 0 + * called "gemm". M (= number of rows of A, C) and strides are dynamic. * * The signature of the generated kernels gemm and gemm_beta0 is * @@ -1407,24 +2147,27 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( * @param ty [in] Scalar type of alpha, A, B, beta, C * @param N [in] Number of columns of B, C * @param K [in] Number columns of A, number of rows of B - * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the - * parameter auto-selected - * @param ctx [inout][optional] source context object; saves error log; can be nullptr + * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have + * the parameter auto-selected + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t N, - int64_t K, int32_t M_block_size, tinytc_source_context_t ctx); + int64_t K, int32_t M_block_size, tinytc_compiler_context_t ctx); /** * @brief Returns a tall and skinny recipe with additional specialization constants * * Similar to tinytc_recipe_tall_and_skinny_create but with the additional specialization * constants M, ldA, ldB, and ldC. - * The specializtion constants may be either set to a fixed value or to TINYTC_DYNAMIC. - * Note that if a specialization constant is set to a fixed value then the parameter with the same - * name in tinytc_recipe_tall_and_skinny_set_args is ignored. + * The specialization constants may be either set to a fixed value or to TINYTC_DYNAMIC. + * Note that if a specialization constant is set to a fixed value then the parameter with the + * same name in tinytc_recipe_tall_and_skinny_set_args is ignored. + * + * Furthermore, the memory alignment may be passed with alignA, alignB, and alignC or set to 0 to + * use the default memory alignment (= size of scalar type). * * The generated kernels have the following signature: * @@ -1445,16 +2188,19 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create( * @param ldA [in] Leading dimension of A; can be TINYTC_DYNAMIC * @param ldB [in] Leading dimension of B; can be TINYTC_DYNAMIC * @param ldC [in] Leading dimension of C; can be TINYTC_DYNAMIC - * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the - * parameter auto-selected - * @param ctx [inout][optional] source context object; saves error log; can be nullptr + * @param alignA [in] Memory alignment of A; can be 0 + * @param alignB [in] Memory alignment of B; can be 0 + * @param alignC [in] Memory alignment of C; can be 0 + * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have + * the parameter auto-selected + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr * * @return */ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t M, - int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t M_block_size, - tinytc_source_context_t ctx); + int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t alignA, int32_t alignB, + int32_t alignC, int32_t M_block_size, tinytc_compiler_context_t ctx); /** * @brief Suggest an M block size for tall and skinny recipe @@ -1498,8 +2244,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( * @brief Get prog object * * @param recipe [in] recipe object - * @param prg [out] pointer to prog object; reference count is increased so the user needs to call - * tinytc_prog_release to clean up + * @param prg [out] pointer to prog object; reference count is increased so the user needs to + * call tinytc_prog_release to clean up * * @return tinytc_status_success on success and error otherwise */ @@ -1507,16 +2253,16 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recip tinytc_prog_t *prg); /** - * @brief Get source object + * @brief Get binary * * @param recipe [in] recipe object - * @param src [out] pointer to source object; reference count is increased so the user needs to call - * tinytc_source_release to clean up + * @param bin [out] pointer to binary; reference count is increased so the user needs to + * call tinytc_binary_release to clean up * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_source(const_tinytc_recipe_t recipe, - tinytc_source_t *src); +TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_binary(const_tinytc_recipe_t recipe, + tinytc_binary_t *bin); /** * @brief Release recipe object @@ -1542,8 +2288,8 @@ TINYTC_EXPORT tinytc_status_t tinytc_recipe_retain(tinytc_recipe_t obj); * @brief Get recipe object * * @param handler [in] recipe handler object - * @param recipe [out] pointer to recipe object; reference count is increased so the user needs to - * call tinytc_recipe_release to clean up + * @param recipe [out] pointer to recipe object; reference count is increased so the user needs + * to call tinytc_recipe_release to clean up * * @return tinytc_status_success on success and error otherwise */ diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 99b19b41..d5fafb09 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -7,16 +7,24 @@ #include "tinytc/tinytc.h" #include "tinytc/types.hpp" +#include +#include #include #include #include #include -#include #include #include #include #include +// For bit_cast, memcpy for C++ < 2020 +#if __cplusplus >= 202002L +#include +#else +#include +#endif + namespace tinytc { //////////////////////////// @@ -59,6 +67,311 @@ inline void CHECK_STATUS_LOC(tinytc_status_t code, location const &loc) { } } +//////////////////////////// +////////// FP math ///////// +//////////////////////////// + +/** + * @brief IEEE754 floating point format parameters + * + * @tparam ExponentBits Number of exponent bits + * @tparam MantissaBits Number of mantissa bits + */ +template struct ieee754_format { + constexpr static uint32_t exponent_bits = ExponentBits; ///< Number of exponent bits + constexpr static uint32_t mantissa_bits = MantissaBits; ///< Number of mantissa bits + //! Total number of bits + constexpr static uint32_t num_bits = 1 + exponent_bits + mantissa_bits; + //! Bias + constexpr static uint32_t bias = (1 << (exponent_bits - 1)) - 1; + //! Max exponent when encoded with bias added + constexpr static uint32_t max_biased_exponent = (1 << exponent_bits) - 1; + //! Bit mask for sign bit + constexpr static uint32_t sign_mask = 1 << (num_bits - 1); + //! Bit mask for exponent bits + constexpr static uint32_t exponent_mask = max_biased_exponent << mantissa_bits; + //! Bit mask for exponent mantissa bits + constexpr static uint32_t mantissa_mask = (1 << mantissa_bits) - 1; + //! Number of bytes + constexpr static uint32_t num_bytes = 1 + (num_bits - 1) / 8; + //! Unsigned integer type large enough to store bit pattern + using bits_type = std::conditional_t< + num_bytes == 1, std::uint8_t, + std::conditional_t< + num_bytes == 2, std::uint16_t, + std::conditional_t>>>; +}; + +//! Floating point format for bf16 (bfloat16) +using bf16_format = ieee754_format<8, 7>; +//! Floating point format for f16 (half) +using f16_format = ieee754_format<5, 10>; +//! Floating point format for f32 (float) +using f32_format = ieee754_format<8, 23>; + +/** + * @brief Truncate high precision floating point number and return low precision floating point + * number + * + * @tparam F16f low precision floating point format + * @tparam F32f high precision floating point format + * @param x bit pattern of high precision number + * + * @return bit pattern of low precision number + */ +template +constexpr auto ieee754_truncate(typename F32f::bits_type x) -> F16f::bits_type { + using UI = F32f::bits_type; + using UITrunc = F16f::bits_type; + constexpr UI num_shift_bits = F32f::mantissa_bits - F16f::mantissa_bits; + auto const round_nearest_even_and_truncate = [](UI mantissa32) { + constexpr UI midpoint = (1 << num_shift_bits) / 2; + const UI bias = ((mantissa32 >> num_shift_bits) & 0x1) + (midpoint - 1); + return (mantissa32 + bias) >> num_shift_bits; + }; + + const UITrunc sign = (x & F32f::sign_mask) >> (F32f::num_bits - F16f::num_bits); + const UI exponent32 = (x & F32f::exponent_mask) >> F32f::mantissa_bits; + const UI mantissa32 = x & F32f::mantissa_mask; + + UITrunc exponent16 = 0; + UITrunc mantissa16 = 0; + if (exponent32 > F32f::bias + F16f::bias) { + exponent16 = F16f::max_biased_exponent; + // Map numbers except NaN to inf + if (exponent32 < F32f::max_biased_exponent) { + mantissa16 = 0; + } else { + // Need to ceil to make sure that NaN is not truncated to inf + mantissa16 = 1 + ((mantissa32 - 1) >> num_shift_bits); + } + } else if (F32f::bias == F16f::bias || exponent32 > F32f::bias - F16f::bias) { + // convert bias + // E_{32} = e + F32f::bias + // E_{16} = e + F16f::bias + // = E_{32} - F32f::bias + F16f::bias + // = E_{32} - (F32f::bias - F16f::bias) + exponent16 = exponent32 - (F32f::bias - F16f::bias); + mantissa16 = round_nearest_even_and_truncate(mantissa32); + } else if (exponent32 >= F32f::bias + 1 - F16f::bias - F16f::mantissa_bits) { + exponent16 = 0; + mantissa16 = round_nearest_even_and_truncate((mantissa32 | (1 << F32f::mantissa_bits)) >> + ((F32f::bias + 1 - F16f::bias) - exponent32)); + } + + exponent16 <<= F16f::mantissa_bits; + + // Need to add mantissa as it might overflow during rounding and then we need to increase the + // exponent by 1 + return (sign | exponent16) + mantissa16; +} + +/** + * @brief Extend low precision floating point number and return high precision floating point + * number + * + * @tparam F32f high precision floating point format + * @tparam F16f low precision floating point format + * @param x bit pattern of low precision number + * + * @return bit pattern of high precision number + */ +template +constexpr auto ieee754_extend(typename F16f::bits_type x) -> F32f::bits_type { + using UIExt = F32f::bits_type; + const UIExt sign = (x & F16f::sign_mask) << (F32f::num_bits - F16f::num_bits); + const UIExt exponent16 = (x & F16f::exponent_mask) >> F16f::mantissa_bits; + const UIExt mantissa16 = x & F16f::mantissa_mask; + + UIExt exponent32 = exponent16; + UIExt mantissa32 = mantissa16; + if (F32f::exponent_bits != F16f::exponent_bits) { + if (exponent16 == F16f::max_biased_exponent) { + // Inf and NaN + exponent32 = F32f::max_biased_exponent; + } else if (exponent16 != 0) { + // convert bias + // E_{16} = e + F16f::bias + // E_{32} = e + F32f::bias + // = E_{16} - F16f::bias + F32f::bias + // = E_{16} + (F32f::bias - F16f::bias) + exponent32 += F32f::bias - F16f::bias; + } + + // Subnormal f16 numbers must be represented as f32 normal numbers + if (exponent16 == 0 && mantissa16 != 0) { + UIExt shift_count = 0; + do { + mantissa32 <<= 1; + ++shift_count; + } while ((mantissa32 & (1 << F16f::mantissa_bits)) != (1 << F16f::mantissa_bits)); + mantissa32 &= F16f::mantissa_mask; + exponent32 = F32f::bias + 1 - F16f::bias - shift_count; + } + } + + // shift mantissa + mantissa32 <<= F32f::mantissa_bits - F16f::mantissa_bits; + + // shift exponent + exponent32 <<= F32f::mantissa_bits; + + return sign | exponent32 | mantissa32; +} + +/** + * @brief Low precision float type + * + * For all operations, low precision floats are converted single precision, the operation is done in + * single precision, and then the result is stored in the low precision type + * + * @tparam T storage type + * @tparam F16f low precision floating point format + */ +template class lp_float { + public: + using lp_format = F16f; + + constexpr lp_float() = default; + + constexpr lp_float(lp_float const &) = default; + constexpr lp_float(lp_float &&) = default; + constexpr lp_float &operator=(lp_float const &) = default; + constexpr lp_float &operator=(lp_float &&) = default; + +#if __cplusplus >= 202002L +#define TINYTC_LPFLOAT_CONSTEXPR constexpr + //! construct from float + constexpr lp_float(float const &val) + : data_{ieee754_truncate(std::bit_cast(val))} {} + //! assign float + constexpr auto operator=(float const &rhs) -> lp_float & { return *this = lp_float{rhs}; } + //! implicit conversion to float + constexpr operator float() const { + auto bits = ieee754_extend(data_); + return std::bit_cast(bits); + } +#else +#define TINYTC_LPFLOAT_CONSTEXPR + //! construct from float + lp_float(float const &val) { + f32_format::bits_type bits; + memcpy(&bits, &val, sizeof(f32_format::bits_type)); + data_ = ieee754_truncate(bits); + } + //! assign float + auto operator=(float const &rhs) -> lp_float & { return *this = lp_float{rhs}; } + //! implicit conversion to float + operator float() const { + auto bits = ieee754_extend(data_); + float number; + memcpy(&number, &bits, sizeof(f32_format::bits_type)); + return number; + } +#endif + + //! Get bit representation + TINYTC_LPFLOAT_CONSTEXPR auto bits() const -> T { return data_; } + //! Construct lp_float from bit representation + constexpr static auto from_bits(T const &val) -> lp_float { + auto r = lp_float{}; + r.data_ = val; + return r; + } + + //! add + TINYTC_LPFLOAT_CONSTEXPR auto operator+(lp_float const &rhs) const -> lp_float { + return operator float() + static_cast(rhs); + } + //! add to + TINYTC_LPFLOAT_CONSTEXPR auto operator+=(lp_float const &rhs) -> lp_float & { + return *this = *this + rhs; + } + //! subtract + TINYTC_LPFLOAT_CONSTEXPR auto operator-(lp_float const &rhs) const -> lp_float { + return operator float() - static_cast(rhs); + } + //! subtract from + TINYTC_LPFLOAT_CONSTEXPR auto operator-=(lp_float const &rhs) -> lp_float & { + return *this = *this - rhs; + } + //! multiply + TINYTC_LPFLOAT_CONSTEXPR auto operator*(lp_float const &rhs) const -> lp_float { + return operator float() * static_cast(rhs); + } + //! multiply with + TINYTC_LPFLOAT_CONSTEXPR auto operator*=(lp_float const &rhs) -> lp_float & { + return *this = *this * rhs; + } + //! divide + TINYTC_LPFLOAT_CONSTEXPR auto operator/(lp_float const &rhs) const -> lp_float { + return operator float() / static_cast(rhs); + } + //! divide with + TINYTC_LPFLOAT_CONSTEXPR auto operator/=(lp_float const &rhs) -> lp_float & { + return *this = *this / rhs; + } + //! unary minus + TINYTC_LPFLOAT_CONSTEXPR auto operator-() -> lp_float { return -operator float(); } + //! pre-increase by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator++() -> lp_float & { + return *this = operator float() + 1.0f; + } + //! post-increase by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator++(int) -> lp_float { + lp_float tmp = *this; + operator++(); + return tmp; + } + //! pre-decrease by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator--() -> lp_float & { + return *this = operator float() - 1.0f; + } + //! post-decrease by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator--(int) -> lp_float { + lp_float tmp = *this; + operator--(); + return tmp; + } + //! equal + TINYTC_LPFLOAT_CONSTEXPR auto operator==(lp_float const &rhs) const -> bool { + return operator float() == static_cast(rhs); + } + //! not equal + TINYTC_LPFLOAT_CONSTEXPR auto operator!=(lp_float const &rhs) const -> bool { + return operator float() != static_cast(rhs); + } + //! greater than + TINYTC_LPFLOAT_CONSTEXPR auto operator>(lp_float const &rhs) const -> bool { + return operator float() > static_cast(rhs); + } + //! greater than or equal + TINYTC_LPFLOAT_CONSTEXPR auto operator>=(lp_float const &rhs) const -> bool { + return operator float() >= static_cast(rhs); + } + //! less than + TINYTC_LPFLOAT_CONSTEXPR auto operator<(lp_float const &rhs) const -> bool { + return operator float() < static_cast(rhs); + } + //! less than or equal + TINYTC_LPFLOAT_CONSTEXPR auto operator<=(lp_float const &rhs) const -> bool { + return operator float() <= static_cast(rhs); + } + + private: + T data_; +}; + +/** + * @brief bf16 host emulation type + */ +using bfloat16 = lp_float; +/** + * @brief fp16 host emulation type + */ +using half = lp_float; + //////////////////////////// //////// Scalar type /////// //////////////////////////// @@ -80,10 +393,6 @@ inline std::size_t size(scalar_type ty) { */ template struct to_scalar_type; //! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::i1; ///< value -}; -//! to_scalar_type specialization template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::i8; ///< value }; @@ -100,6 +409,14 @@ template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::i64; ///< value }; //! to_scalar_type specialization +template <> struct to_scalar_type { + static constexpr scalar_type value = scalar_type::bf16; ///< value +}; +//! to_scalar_type specialization +template <> struct to_scalar_type { + static constexpr scalar_type value = scalar_type::f16; ///< value +}; +//! to_scalar_type specialization template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::f32; ///< value }; @@ -107,6 +424,14 @@ template <> struct to_scalar_type { template <> struct to_scalar_type { static constexpr scalar_type value = scalar_type::f64; ///< value }; +//! to_scalar_type specialization +template <> struct to_scalar_type> { + static constexpr scalar_type value = scalar_type::c32; ///< value +}; +//! to_scalar_type specialization +template <> struct to_scalar_type> { + static constexpr scalar_type value = scalar_type::c64; ///< value +}; /** * Convenience variable for to_scalar_type. * @@ -118,6 +443,34 @@ template inline constexpr scalar_type to_scalar_type_v = to_scalar_ // Shared / unique handle // //////////////////////////// +template class handle { + public: + //! Create empty (invalid) handle + handle() : obj_{nullptr} {} + //! Create handle from C handle + handle(T obj) : obj_(obj) {} + + //! Dereference C handle and get reference to underlying type + auto operator*() const -> std::remove_pointer_t & { return *obj_; } + //! Convert handle to C handle + auto operator->() const -> T { return obj_; } + //! Returns C handle + auto get() const -> T { return obj_; } + + //! Check whether handle is non-empty (valid) + explicit operator bool() const noexcept { return obj_ != nullptr; } + + //! Check equality + bool operator==(handle const &other) const { return obj_ == other.obj_; } + //! Check inequality + bool operator!=(handle const &other) const { return !(*this == other); } + + operator T() const { return obj_; } + + protected: + T obj_; +}; + namespace internal { //! Wraps retain / release calls for type T template struct shared_handle_traits {}; @@ -273,303 +626,527 @@ template class unique_handle { }; //////////////////////////// -///////// Data type //////// +//////// Array view //////// //////////////////////////// -//! Check if mode i is dynamic ('?') -inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } +/** + * @brief Base implementation of array view + * + * @tparam T array element type + */ +template class array_view_base { + public: + using iterator = T *; -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_data_type_t handle) -> tinytc_status_t { - return tinytc_data_type_retain(handle); - } - static auto release(tinytc_data_type_t handle) -> tinytc_status_t { - return tinytc_data_type_release(handle); + /** + * @brief Empty array view + */ + array_view_base() = default; + + /** + * @brief Single element view + * + * @param single the single element + */ + array_view_base(T &single) : data_{&single}, size_{1} {} + + /** + * @brief ctor + * + * @param data base pointer + * @param size array size + */ + array_view_base(T *data, std::size_t size) : data_{data}, size_{size} {} + + /** + * @brief ctor + * + * @param begin begin pointer + * @param end end pointer (not included) + */ + array_view_base(T *begin, T *end) + : data_{begin}, size_{static_cast(end - begin)} {} + + //! Begin iterator + auto begin() const -> iterator { return data_; } + //! End iterator + auto end() const -> iterator { return data_ + size_; } + //! Returns true if view is empty + auto empty() const -> bool { return size_ == 0; } + //! Returns array size + auto size() const -> std::size_t { return size_; } + //! Access first element; must not call when array size is 0 + auto front() const -> T & { return data_[0]; } + //! Access last element; must not call when array size is 0 + auto back() const -> T & { return data_[size_ - 1]; } + //! Get data pointer + auto data() const -> T * { return data_; } + //! Access operator + auto operator[](std::size_t n) const -> T & { return data_[n]; } + //! Convert to vector + operator std::vector>() const { + return std::vector>(data_, data_ + size_); + } + //! Equals operator + auto operator==(array_view_base const &other) const -> bool { + bool eq = true; + for (std::size_t i = 0; i < size_; ++i) { + eq = eq && data_[i] == other.data_[i]; + } + return eq; } -}; -} // namespace internal + auto operator!=(array_view_base const &other) const -> bool { return !(*this == other); } -//! @brief Reference-counting wrapper for tinytc_data_type_t -class data_type : public shared_handle { - public: - using shared_handle::shared_handle; + private: + T *data_ = nullptr; + std::size_t size_ = 0; }; /** - * @brief Make a scalar data type - * - * Cf. \ref tinytc_scalar_type_create - * - * @param scalar_ty Scalar type - * @param loc Source code location + * @brief Stores an immutable view on an array (pointer + size) * - * @return Data type + * @tparam T array element type */ -inline data_type make_scalar(scalar_type scalar_ty, location const &loc = {}) { - tinytc_data_type_t st; - CHECK_STATUS_LOC( - tinytc_scalar_type_create(&st, static_cast(scalar_ty), &loc), loc); - return data_type{st}; -} +template class array_view : public array_view_base { + public: + using array_view_base::array_view_base; -/** - * @brief Make a memref data type - * - * Cf. \ref tinytc_memref_type_create - * - * @param scalar_ty Element type - * @param shape Tensor shape - * @param stride Tensor stride - * @param loc Source code location - * - * @return Data type - */ -inline data_type make_memref(scalar_type scalar_ty, std::vector const &shape, - std::vector const &stride = {}, - location const &loc = {}) { - tinytc_data_type_t mt; - CHECK_STATUS_LOC(tinytc_memref_type_create(&mt, static_cast(scalar_ty), - shape.size(), shape.data(), stride.size(), - stride.data(), &loc), - loc); - return data_type{mt}; -} + /** + * @brief Convert vector to array view + * + * @param vec standard vector + */ + array_view(std::vector const &vec) + : array_view_base{(!vec.empty() ? vec.data() : nullptr), vec.size()} {} + + /** + * @brief Convert std::array to array view + * + * @tparam N array size + * @param arr standard array + */ + template + array_view(std::array const &arr) : array_view_base{arr.data(), arr.size()} {} + + /** + * @brief Convert initializer list to array view (array_view must be rvalue) + * + * @param arr initializer list + */ + array_view(std::initializer_list const &arr) + : array_view_base{(arr.begin() != arr.end() ? arr.begin() : nullptr), arr.size()} { + } +}; + +template array_view(T const &) -> array_view; +template array_view(T const *, std::size_t) -> array_view; +template array_view(T const *, T const *) -> array_view; /** - * @brief Make a group data type + * @brief Stores a mutable view on an array (pointer + size) * - * @param memref_ty Memref data type - * @param offset Offset parameter - * @param loc Source code location - * - * @return Data type + * @tparam T array element type */ -inline data_type make_group(data_type const &memref_ty, std::int64_t offset = 0, - location const &loc = {}) { - tinytc_data_type_t gt; - CHECK_STATUS_LOC(tinytc_group_type_create(>, memref_ty.get(), offset, &loc), loc); - return data_type{gt}; -} +template class mutable_array_view : public array_view_base { + public: + using array_view_base::array_view_base; + + /** + * @brief Convert vector to array view + * + * @param vec standard vector + */ + mutable_array_view(std::vector &vec) + : array_view_base{(!vec.empty() ? vec.data() : nullptr), vec.size()} {} + + /** + * @brief Convert std::array to array view + * + * @tparam N array size + * @param arr standard array + */ + template + mutable_array_view(std::array &arr) : array_view_base{arr.data(), arr.size()} {} +}; + +template mutable_array_view(T &) -> mutable_array_view; +template mutable_array_view(T *, std::size_t) -> mutable_array_view; +template mutable_array_view(T *, T *) -> mutable_array_view; //////////////////////////// -/////////// Value ////////// +///// Compiler context ///// //////////////////////////// namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_value_t handle) -> tinytc_status_t { - return tinytc_value_retain(handle); +template <> struct shared_handle_traits { + static auto retain(tinytc_compiler_context_t handle) -> tinytc_status_t { + return tinytc_compiler_context_retain(handle); } - static auto release(tinytc_value_t handle) -> tinytc_status_t { - return tinytc_value_release(handle); + static auto release(tinytc_compiler_context_t handle) -> tinytc_status_t { + return tinytc_compiler_context_release(handle); } }; } // namespace internal -//! @brief Reference-counting wrapper for tinytc_value_t -class value : public shared_handle { +//! @brief Reference-counting wrapper for tinytc_compiler_context_t +class compiler_context : public shared_handle { public: using shared_handle::shared_handle; /** - * @brief Get name + * @brief Add compiler to context + * + * @param name File name + * @param text Source text * - * @return Name + * @return Source id (should be set in position.source_id) */ - inline auto get_name() const -> char const * { - char const *name; - CHECK_STATUS(tinytc_value_get_name(obj_, &name)); - return name; + inline auto add_source(char const *name, char const *text) -> std::int32_t { + std::int32_t source_id; + CHECK_STATUS(tinytc_compiler_context_add_source(obj_, name, text, &source_id)); + return source_id; } /** - * @brief Set name + * @brief Set error reporter * - * @param name Name + * Error reporting function that is called whenever an error occurs in the parser or the + * builder. + * + * @param reporter error reporting callback + * @param user_data pointer to user data that is passed to the callback + * + * @return tinytc_status_success on success and error otherwise */ - inline void name(std::string const &name) { - CHECK_STATUS(tinytc_value_set_name(obj_, name.c_str())); + inline void set_error_reporter(error_reporter_t reporter, void *user_data) { + CHECK_STATUS(tinytc_compiler_context_set_error_reporter(obj_, reporter, user_data)); } -}; -namespace internal { -//! Is reinterpret_cast(&v) allowed, where v has type value -constexpr bool value_reinterpret_allowed = - std::is_standard_layout_v && sizeof(value) == sizeof(tinytc_value_t); -} // namespace internal + /** + * @brief Sets an optimization flag + * + * The state can be 0 (disabled), 1 (enabled), or -1 (use default according to optimization + * level). + * + * @param flag optimization flag + * @param state flag state + */ + inline void set_optimization_flag(optflag flag, std::int32_t state) { + CHECK_STATUS(tinytc_compiler_context_set_optimization_flag( + obj_, static_cast(flag), state)); + } + /** + * @brief Set optimization level + * + * @param level optimization level + */ + inline void set_optimization_level(std::int32_t level) { + CHECK_STATUS(tinytc_compiler_context_set_optimization_level(obj_, level)); + } + /** + * @brief Enhance error message with compiler context; useful when builder is used + * + * @param loc Source location + * @param what Error description + */ + inline void report_error(location const &loc, char const *what) { + CHECK_STATUS(tinytc_compiler_context_report_error(obj_, &loc, what)); + } +}; /** - * @brief Make value - * - * @param ty Data type - * @param loc Source code location + * @brief Create compiler context * - * @return Value + * @return Compiler context */ -inline auto make_value(data_type const &ty, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_value_create(&val, ty.get(), &loc), loc); - return value{val}; +inline auto make_compiler_context() -> compiler_context { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_compiler_context_create(&ctx)); + return compiler_context{ctx}; } +//////////////////////////// +///////// Attribute //////// +//////////////////////////// + +//! Alias for tinytc_attr_t +using attr = tinytc_attr_t; +//! Alias for tinytc_named_attr_t +using named_attr = tinytc_named_attr_t; + /** - * @brief Make value + * @brief Get array attribute * - * @param scalar_ty Scalar type - * @param loc Source code location + * @param ctx compiler context + * @param array attribute array * - * @return Value + * @return Attribute */ -inline auto make_value(scalar_type scalar_ty, location const &loc = {}) -> value { - tinytc_value_t val; - auto ty = make_scalar(scalar_ty, loc); - CHECK_STATUS_LOC(tinytc_value_create(&val, ty.get(), &loc), loc); - return value{val}; +inline attr get_array_attr(compiler_context const &ctx, array_view array) { + attr a; + CHECK_STATUS(tinytc_array_attr_get(&a, ctx.get(), array.size(), array.data())); + return a; } /** - * @brief Make immediate value + * @brief Get boolean attribute * - * Type is f32. - * - * @param imm Float value - * @param loc Source code location + * @param ctx compiler context + * @param value boolean value * - * @return Value + * @return Attribute */ -inline auto make_imm(float imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_float_imm_create(&val, imm, tinytc_scalar_type_f32, &loc), loc); - return value{val}; +inline attr get_boolean_attr(compiler_context const &ctx, bool value) { + attr a; + CHECK_STATUS(tinytc_boolean_attr_get(&a, ctx.get(), value)); + return a; } /** - * @brief Make immediate value + * @brief Get dictionary attribute * - * @param imm Float value - * @param type Type of immediate value - * @param loc Source code location + * Each name must only appear once. + * + * @param ctx compiler context + * @param items named items array * - * @return Value + * @return Attribute */ -inline auto make_imm(double imm, scalar_type type = scalar_type::f64, location const &loc = {}) - -> value { - tinytc_value_t val; - CHECK_STATUS_LOC( - tinytc_float_imm_create(&val, imm, static_cast(type), &loc), loc); - return value{val}; +inline attr get_dictionary_attr(compiler_context const &ctx, mutable_array_view items) { + attr a; + CHECK_STATUS(tinytc_dictionary_attr_get(&a, ctx.get(), items.size(), items.data())); + return a; } /** - * @brief Make immediate value + * @brief Get dictionary attribute * - * Type is i8. + * The list of items must be sorted by name and each name must only appear once. * - * @param imm Int value - * @param loc Source code location + * @param ctx compiler context + * @param items named items array * - * @return Value + * @return Attribute */ -inline auto make_imm(std::int8_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i8, &loc), loc); - return value{val}; +inline attr get_dictionary_attr_with_sorted(compiler_context const &ctx, + array_view items) { + attr a; + CHECK_STATUS(tinytc_dictionary_attr_get_with_sorted(&a, ctx.get(), items.size(), items.data())); + return a; } /** - * @brief Make immediate value + * @brief Sort list of items * - * Type is i16. + * Each name must only appear once. * - * @param imm Int value - * @param loc Source code location - * - * @return Value + * @param items named items array */ -inline auto make_imm(std::int16_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i16, &loc), loc); - return value{val}; +inline void sort_items(mutable_array_view items) { + CHECK_STATUS(tinytc_dictionary_attr_sort(items.size(), items.data())); } /** - * @brief Make immediate value + * @brief Get integer attribute * - * Type is i32. - * - * @param imm Int value - * @param loc Source code location + * @param ctx compiler context + * @param value integer value * - * @return Value + * @return Attribute */ -inline auto make_imm(std::int32_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i32, &loc), loc); - return value{val}; +inline attr get_integer_attr(compiler_context const &ctx, std::int64_t value) { + attr a; + CHECK_STATUS(tinytc_integer_attr_get(&a, ctx.get(), value)); + return a; } /** - * @brief Make immediate value + * @brief Get string attribute * - * @param imm Int value - * @param type Type of immediate value - * @param loc Source code location + * @param ctx compiler context + * @param str string * - * @return Value + * @return Attribute */ -inline auto make_imm(std::int64_t imm, scalar_type type = scalar_type::i64, - location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC( - tinytc_int_imm_create(&val, imm, static_cast(type), &loc), loc); - return value{val}; +inline attr get_string_attr(compiler_context const &ctx, std::string_view str) { + attr a; + CHECK_STATUS(tinytc_string_attr_get(&a, ctx.get(), str.size(), str.data())); + return a; } +//////////////////////////// +///////// Data type //////// +//////////////////////////// + +//! Check if mode i is dynamic ('?') +inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } + +//! Alias for tinytc_data_type_t +using data_type = tinytc_data_type_t; + /** - * @brief Make immediate index value + * @brief Get the boolean data type * - * @param imm index value - * @param loc Source code location + * Cf. \ref tinytc_boolean_type_get + * + * @param ctx Compiler context * - * @return Value + * @return Data type */ -inline auto make_index(std::int32_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_index, &loc), loc); - return value{val}; +inline data_type get_boolean(compiler_context const &ctx) { + tinytc_data_type_t bt; + CHECK_STATUS(tinytc_boolean_type_get(&bt, ctx.get())); + return bt; } /** - * @brief Make immediate index value + * @brief Get a scalar data type * - * @param imm index value - * @param loc Source code location + * Cf. \ref tinytc_scalar_type_get + * + * @param ctx Compiler context + * @param scalar_ty Scalar type * - * @return Value + * @return Data type */ -inline auto make_index(std::int64_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_index, &loc), loc); - return value{val}; +inline data_type get_scalar(compiler_context const &ctx, scalar_type scalar_ty) { + tinytc_data_type_t st; + CHECK_STATUS( + tinytc_scalar_type_get(&st, ctx.get(), static_cast(scalar_ty))); + return st; } /** - * @brief Make dynamic ('?') + * @brief Get a memref data type * + * Cf. \ref tinytc_memref_type_get + * + * @param scalar_ty Element type + * @param shape Tensor shape + * @param stride Tensor stride + * @param addrspace Address space * @param loc Source code location * - * @return Value + * @return Data type */ -inline auto make_dynamic(location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, dynamic, tinytc_scalar_type_i64, &loc), loc); - return value{val}; -} +inline data_type get_memref(data_type scalar_ty, array_view shape, + array_view stride = {}, + address_space addrspace = address_space::global, + location const &loc = {}) { + tinytc_data_type_t mt; + CHECK_STATUS_LOC(tinytc_memref_type_get(&mt, scalar_ty, shape.size(), shape.data(), + stride.size(), stride.data(), + static_cast(addrspace), &loc), + loc); + return mt; +} + +/** + * @brief Get a group data type + * + * @param memref_ty Memref data type + * @param size Size parameter + * @param offset Offset parameter + * @param loc Source code location + * + * @return Data type + */ +inline data_type get_group(data_type memref_ty, std::int64_t size, std::int64_t offset = 0, + location const &loc = {}) { + tinytc_data_type_t gt; + CHECK_STATUS_LOC(tinytc_group_type_get(>, memref_ty, size, offset, &loc), loc); + return gt; +} + +/** + * @brief Get a coopmatrix data type + * + * @param scalar_ty Component type + * @param rows Number of rows + * @param cols Number of cols + * @param use Matrix use + * @param loc Source code location + * + * @return Data type + */ +inline data_type get_coopmatrix(data_type scalar_ty, std::int64_t rows, std::int64_t cols, + matrix_use use, location const &loc = {}) { + tinytc_data_type_t ct; + CHECK_STATUS_LOC(tinytc_coopmatrix_type_get(&ct, scalar_ty, rows, cols, + static_cast<::tinytc_matrix_use_t>(use), &loc), + loc); + return ct; +} + +/** + * @brief Get a void data type + * + * @param ctx Context + * + * @return Data type + */ +inline data_type get_void(compiler_context const &ctx) { + tinytc_data_type_t vt; + CHECK_STATUS(tinytc_void_type_get(&vt, ctx.get())); + return vt; +} + +//////////////////////////// +/////////// Value ////////// +//////////////////////////// + +//! @brief OO-wrapper for tinytc_value_t +class value : public handle { + public: + using handle::handle; + + /** + * @brief Get name + * + * @return Name as C-string + */ + inline auto get_name() -> char const * { + char const *name; + CHECK_STATUS(tinytc_value_get_name(obj_, &name)); + return name; + } + + /** + * @brief Set value name + * + * @param name Name + */ + inline void set_name(std::string_view name) { + CHECK_STATUS(tinytc_value_set_name_n(obj_, name.size(), name.data())); + } + + /** + * @brief Get type + * + * @return Data type + */ + inline auto get_type() -> data_type { + tinytc_data_type_t ty; + CHECK_STATUS(tinytc_value_get_type(obj_, &ty)); + return ty; + } +}; +static_assert(std::is_standard_layout_v && sizeof(value) == sizeof(tinytc_value_t)); //////////////////////////// /////////// Inst /////////// //////////////////////////// +/** + * @brief Convert address space to string + * + * @param as Address space + * + * @return C-string + */ +inline char const *to_string(address_space as) { + return ::tinytc_address_space_to_string(static_cast<::tinytc_address_space_t>(as)); +} + /** * @brief Convert arithmetic operation type to string * @@ -592,6 +1169,28 @@ inline char const *to_string(arithmetic_unary op) { return ::tinytc_arithmetic_unary_to_string(static_cast<::tinytc_arithmetic_unary_t>(op)); } +/** + * @brief Convert builtin type to string + * + * @param b Builtin type + * + * @return C-string + */ +inline char const *to_string(builtin b) { + return ::tinytc_builtin_to_string(static_cast<::tinytc_builtin_t>(b)); +} + +/** + * @brief Convert checked flag string + * + * @param flag Flag + * + * @return C-string + */ +inline char const *to_string(checked_flag flag) { + return ::tinytc_checked_flag_to_string(static_cast<::tinytc_checked_flag_t>(flag)); +} + /** * @brief Convert cmp condition to string * @@ -603,6 +1202,72 @@ inline char const *to_string(cmp_condition cond) { return ::tinytc_cmp_condition_to_string(static_cast<::tinytc_cmp_condition_t>(cond)); } +/** + * @brief Convert subgroup arithmetic to string + * + * @param op Operation type + * + * @return C-string + */ +inline char const *to_string(group_arithmetic op) { + return ::tinytc_group_arithmetic_to_string(static_cast<::tinytc_group_arithmetic_t>(op)); +} + +/** + * @brief Convert subgroup operation to string + * + * @param op Operation type + * + * @return C-string + */ +inline char const *to_string(group_operation op) { + return ::tinytc_group_operation_to_string(static_cast<::tinytc_group_operation_t>(op)); +} + +/** + * @brief Convert reduce mode to string + * + * @param m Reduce mode + * + * @return C-string + */ +inline char const *to_string(reduce_mode m) { + return ::tinytc_reduce_mode_to_string(static_cast<::tinytc_reduce_mode_t>(m)); +} + +/** + * @brief Convert math operation type to string (unary) + * + * @param op Math operation type + * + * @return C-string + */ +inline char const *to_string(math_unary op) { + return ::tinytc_math_unary_to_string(static_cast<::tinytc_math_unary_t>(op)); +} + +/** + * @brief Convert matrix use to string + * + * @param u Matrix use + * + * @return C-string + */ +inline char const *to_string(matrix_use u) { + return ::tinytc_matrix_use_to_string(static_cast<::tinytc_matrix_use_t>(u)); +} + +/** + * @brief Convert store flag to string + * + * @param flag Store flag + * + * @return C-string + */ +inline char const *to_string(store_flag flag) { + return ::tinytc_store_flag_to_string(static_cast(flag)); +} + /** * @brief Convert transpose to string * @@ -615,166 +1280,520 @@ inline char const *to_string(transpose t) { } namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_inst_t handle) -> tinytc_status_t { - return tinytc_inst_retain(handle); - } - static auto release(tinytc_inst_t handle) -> tinytc_status_t { - return tinytc_inst_release(handle); - } +template <> struct unique_handle_traits { + static void destroy(tinytc_inst_t handle) { return tinytc_inst_destroy(handle); } }; } // namespace internal +class region; + //! @brief Reference-counting wrapper for tinytc_inst_t -class inst : public shared_handle { +class inst : public unique_handle { public: - using shared_handle::shared_handle; + using unique_handle::unique_handle; /** - * @brief Get result value + * @brief Get result values + * + * May be called with empty view (vals = {}) to get the number of results. + * + * @param vals view on buffer that stores results * - * @return Value; may be empty + * @return Minimum of view size and actual number of result values */ - inline auto get_value() const -> value { - tinytc_value_t result; - CHECK_STATUS(tinytc_inst_get_value(obj_, &result)); - return value{result}; + inline auto get_values(mutable_array_view vals) const -> std::uint32_t { + std::uint32_t result_list_size = vals.size(); + tinytc_value_t *vs = reinterpret_cast(vals.data()); + CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, vs)); + return result_list_size; } /** - * @brief Get result values + * @brief Get child regions + * + * May be called with empty view (vals = {}) to get the number of child regions. * - * @return Vector of values + * @param regs view on buffer that stores results + * + * @return Minimum of view size and actual number of child regions */ - inline auto get_values() const -> std::vector { - static_assert(internal::value_reinterpret_allowed); - std::uint32_t result_list_size = 0; - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, nullptr)); - auto values = std::vector(result_list_size); - tinytc_value_t *result_list = reinterpret_cast(values.data()); - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, result_list)); - return values; + inline auto get_regions(mutable_array_view regs) const -> std::uint32_t { + std::uint32_t result_list_size = regs.size(); + tinytc_region_t *rl = reinterpret_cast(regs.data()); + CHECK_STATUS(tinytc_inst_get_regions(obj_, &result_list_size, rl)); + return result_list_size; } -}; -namespace internal { -//! Is reinterpret_cast(&i) allowed, where i has type inst -constexpr bool inst_reinterpret_allowed = - std::is_standard_layout_v && sizeof(inst) == sizeof(tinytc_inst_t); -} // namespace internal + /** + * @brief Set attribute + * + * @param a attribute + */ + inline void set_attr(attr a) { CHECK_STATUS(tinytc_inst_set_attr(obj_, a)); } +}; //////////////////////////// ////////// Region ////////// //////////////////////////// -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_region_t handle) -> tinytc_status_t { - return tinytc_region_retain(handle); - } - static auto release(tinytc_region_t handle) -> tinytc_status_t { - return tinytc_region_release(handle); - } -}; -} // namespace internal +//! @brief OO-wrapper for tinytc_region_t +class region : public handle { + public: + using handle::handle; + + /** + * @brief Append instruction to region + * + * @param instruction instruction object + */ + inline void append(inst instruction) { + CHECK_STATUS(tinytc_region_append(obj_, instruction.release())); + } + + /** + * @brief Get iterator pointing to the begin of the region + * + * @return iterator + */ + inline auto begin() -> tinytc_inst_iterator_t { + tinytc_inst_iterator_t it; + CHECK_STATUS(tinytc_region_begin(obj_, &it)); + return it; + } + + /** + * @brief Get iterator pointing to past the end of the region + * + * @return iterator + */ + inline auto end() -> tinytc_inst_iterator_t { + tinytc_inst_iterator_t it; + CHECK_STATUS(tinytc_region_end(obj_, &it)); + return it; + } + + /** + * @brief Erase instruction at iterator + * + * @param iterator Iterator + * + * @return Iterator pointing to the instruction after the one erased + */ + inline auto erase(tinytc_inst_iterator_t iterator) -> tinytc_inst_iterator_t { + auto it = iterator; + CHECK_STATUS(tinytc_region_erase(obj_, &it)); + return it; + } + + /** + * @brief Insert instruction into region before the iterator + * + * @param iterator Iterator + * @param instruction instruction object + * + * @return Iterator pointing to the newly inserted instruction + */ + inline auto insert(tinytc_inst_iterator_t iterator, inst instruction) + -> tinytc_inst_iterator_t { + auto it = iterator; + CHECK_STATUS(tinytc_region_insert(obj_, &it, instruction.release())); + return it; + } + + /** + * + * @brief Get region parameters + * + * May be called with empty view (vals = {}) to get the number of parameters. + * + * @param params view on buffer that stores parameters + * + * @return Minimum of view size and actual number of parameters + */ + inline auto get_parameters(mutable_array_view params) -> std::uint32_t { + std::uint32_t result_list_size = params.size(); + tinytc_value_t *ps = reinterpret_cast(params.data()); + CHECK_STATUS(tinytc_region_get_parameters(obj_, &result_list_size, ps)); + return result_list_size; + } +}; +static_assert(std::is_standard_layout_v && sizeof(region) == sizeof(tinytc_region_t)); + +/** + * @brief Move iterator to next instruction + * + * @param iterator + */ +inline void next(tinytc_inst_iterator_t &iterator) { CHECK_STATUS(tinytc_next_inst(&iterator)); } +/** + * @brief Move iterator to previous instruction + * + * @param iterator + */ +inline void prev(tinytc_inst_iterator_t &iterator) { CHECK_STATUS(tinytc_prev_inst(&iterator)); } + +//////////////////////////// +/////// Instructions /////// +//////////////////////////// + +/** + * @brief Make arithmetic instruction (binary) + * + * @param op Arithmetic operation type + * @param a First operand + * @param b Second operand + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_arith(arithmetic op, value a, value b, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_arith_inst_create(&instr, static_cast(op), a, b, ty, &loc), + loc); + return inst(instr); +} + +/** + * @brief Make arithmetic instruction (unary) + * + * @param op Arithmetic operation type + * @param a Operand + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_arith(arithmetic_unary op, value a, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_arith_unary_inst_create( + &instr, static_cast(op), a, ty, &loc), + loc); + return inst(instr); +} + +/** + * @brief Make barrier instruction + * + * @param fence_flags address space(s) of memory fence; set to 0 for no fence + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_barrier(tinytc_address_spaces_t fence_flags, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_barrier_inst_create(&instr, fence_flags, &loc), loc); + return inst(instr); +} + +/** + * @brief Make builtin instruction + * + * @param btype Builtin type + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_builtin(builtin btype, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_builtin_inst_create(&instr, static_cast(btype), ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make cast instruction + * + * @param a Operand + * @param to_ty Target type + * @param loc Source code lcoation + * + * @return Instruction + */ +inline inst make_cast(value a, data_type to_ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_cast_inst_create(&instr, a, to_ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make compare instruction + * + * @param cond Condition type + * @param a First operand + * @param b Second operand + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_cmp(cmp_condition cond, value a, value b, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_cmp_inst_create(&instr, static_cast(cond), a, b, ty, &loc), + loc); + return inst(instr); +} + +/** + * @brief Make boolean constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(bool value, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_boolean(&instr, value, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make complex constant + * + * @param value Complex constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(std::complex value, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_constant_inst_create_complex(&instr, value.real(), value.imag(), ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make floating constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(double value, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_float(&instr, value, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make integer constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(std::int32_t value, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make integer constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant(std::int64_t value, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make multiplicative identity constant ("1") for the given data type + * + * @param ty Scalar data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant_one(data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_one(&instr, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make additive identity constant ("0") for the given data type + * + * @param ty Scalar data type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_constant_zero(data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_zero(&instr, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Create cooperative matrix apply instruction + * + * @param mat %mat + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +inline inst make_cooperative_matrix_apply(value mat, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_cooperative_matrix_apply_inst_create(&instr, mat, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Create cooperative matrix extract instruction + * + * @param mat %mat + * @param index index + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +inline inst make_cooperative_matrix_extract(value mat, std::int64_t index, data_type ty, + location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_cooperative_matrix_extract_inst_create(&instr, mat, index, ty, &loc), + loc); + return inst(instr); +} -//! @brief Reference-counting wrapper for tinytc_region_t -class region : public shared_handle { - public: - using shared_handle::shared_handle; -}; +/** + * @brief Create cooperative matrix insert instruction + * + * @param val %val + * @param mat %mat + * @param index index + * @param ty [in] result type + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +inline inst make_cooperative_matrix_insert(value val, value mat, std::int64_t index, data_type ty, + location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_cooperative_matrix_insert_inst_create(&instr, val, mat, index, ty, &loc), loc); + return inst(instr); +} /** - * @brief Make region + * @brief Create cooperative matrix load instruction * - * @param instructions Vector of instructions + * @param trans transpose operation applied on load + * @param flag out-of-bounds checks type + * @param op %op + * @param p0 %p0 + * @param p1 %p1 + * @param to_ty result type * @param loc Source code location * - * @return Region + * @return Instruction */ -inline region make_region(std::vector &instructions, location const &loc = {}) { - tinytc_region_t reg; - static_assert(internal::inst_reinterpret_allowed); - if (instructions.size() > std::numeric_limits::max()) { - throw std::out_of_range("Instruction list too long"); - } - CHECK_STATUS_LOC(tinytc_region_create(®, instructions.size(), - reinterpret_cast(instructions.data()), - &loc), +inline inst make_cooperative_matrix_load(transpose trans, checked_flag flag, value op, value p0, + value p1, data_type to_ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_cooperative_matrix_load_inst_create( + &instr, static_cast(trans), + static_cast(flag), op, p0, p1, to_ty, &loc), loc); - return region{reg}; + return inst(instr); } -//////////////////////////// -/////// Instructions /////// -//////////////////////////// - /** - * @brief Make arithmetic instruction (binary) + * @brief Create cooperative matrix mul add instruction * - * @param op Arithmetic operation type - * @param a First operand - * @param b Second operand + * @param a %a + * @param b %b + * @param c %c + * @param to_ty result type * @param loc Source code location * * @return Instruction */ -inline inst make_arith(arithmetic op, value const &a, value const &b, location const &loc = {}) { +inline inst make_cooperative_matrix_mul_add(value a, value b, value c, data_type to_ty, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_arith_inst_create(&instr, static_cast(op), a.get(), - b.get(), &loc), + CHECK_STATUS_LOC(tinytc_cooperative_matrix_mul_add_inst_create(&instr, a, b, c, to_ty, &loc), loc); return inst(instr); } /** - * @brief Make arithmetic instruction (unary) + * @brief Create cooperative matrix prefetch instruction * - * @param op Arithmetic operation type - * @param a Operand + * @param cache_level Cache-level; "0" is closest to the core + * @param op %op + * @param p0 %p0 + * @param p1 %p1 + * @param rows Number of rows + * @param cols Number of cols * @param loc Source code location * * @return Instruction */ -inline inst make_arith(arithmetic_unary op, value const &a, location const &loc = {}) { +inline inst make_cooperative_matrix_prefetch(std::int32_t cache_level, value op, value p0, value p1, + std::int32_t rows, std::int32_t cols, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_arith_unary_inst_create( - &instr, static_cast(op), a.get(), &loc), + CHECK_STATUS_LOC(tinytc_cooperative_matrix_prefetch_inst_create(&instr, cache_level, op, p0, p1, + rows, cols, &loc), loc); return inst(instr); } /** - * @brief Make cast instruction + * @brief Create cooperative matrix scale instruction * - * @param a Operand - * @param to_ty Target type - * @param loc Source code lcoation + * @param a %a + * @param b %b + * @param ty Result type + * @param loc Source code location * * @return Instruction */ -inline inst make_cast(value const &a, scalar_type to_ty, location const &loc = {}) { +inline inst make_cooperative_matrix_scale(value a, value b, data_type ty, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_cast_inst_create(&instr, a.get(), static_cast(to_ty), &loc), - loc); + CHECK_STATUS_LOC(tinytc_cooperative_matrix_scale_inst_create(&instr, a, b, ty, &loc), loc); return inst(instr); } /** - * @brief Make compare instruction + * @brief Create cooperative matrix store instruction * - * @param cond Condition type - * @param a First operand - * @param b Second operand + * @param cflag out-of-bounds checks type + * @param sflag store flag + * @param val %val + * @param op %op + * @param p0 %p0 + * @param p1 %p1 * @param loc Source code location * * @return Instruction */ -inline inst make_cmp(cmp_condition cond, value const &a, value const &b, location const &loc = {}) { +inline inst make_cooperative_matrix_store(checked_flag cflag, store_flag sflag, value val, value op, + value p0, value p1, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_cmp_inst_create(&instr, static_cast(cond), - a.get(), b.get(), &loc), + CHECK_STATUS_LOC(tinytc_cooperative_matrix_store_inst_create( + &instr, static_cast(cflag), + static_cast(sflag), val, op, p0, p1, &loc), loc); return inst(instr); } @@ -787,9 +1806,9 @@ inline inst make_cmp(cmp_condition cond, value const &a, value const &b, locatio * * @return Instruction */ -inline inst make_alloca(data_type const &ty, location const &loc = {}) { +inline inst make_alloca(data_type ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_alloca_inst_create(&instr, ty.get(), &loc), loc); + CHECK_STATUS_LOC(tinytc_alloca_inst_create(&instr, ty, &loc), loc); return inst(instr); } @@ -806,36 +1825,63 @@ inline inst make_alloca(data_type const &ty, location const &loc = {}) { * * @return Instruction */ -inline inst make_axpby(transpose tA, bool atomic, value const &alpha, value const &A, - value const &beta, value const &B, location const &loc = {}) { +inline inst make_axpby(transpose tA, bool atomic, value alpha, value A, value beta, value B, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_axpby_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), beta.get(), B.get(), &loc), + alpha, A, beta, B, &loc), loc); return inst(instr); } +/** + * @brief Make cumsum instruction + * + * @param atomic true for atomic updates of B + * @param alpha @f$\alpha@f$ + * @param A A + * @param mode n (summation mode) + * @param beta @f$\beta@f$ + * @param B B + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_cumsum(bool atomic, value alpha, value A, std::int64_t mode, value beta, value B, + location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_cumsum_inst_create(&instr, atomic, alpha, A, mode, beta, B, &loc), loc); + return inst(instr); +} + /** * @brief Make expand instruction * * @param a Operand - * @param mode Expanded mode - * @param expand_shape New shape of mode + * @param expanded_mode Expanded mode + * @param static_expand_shape Static expand shape + * @param expand_shape Dynamic expand shape + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_expand(value const &a, std::int64_t mode, std::vector const &expand_shape, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); +inline inst make_expand(value a, std::int64_t expanded_mode, + array_view static_expand_shape, + array_view expand_shape, data_type ty, location const &loc = {}) { tinytc_inst_t instr; + auto static_len = static_expand_shape.size(); + if (static_len > std::numeric_limits::max()) { + throw std::out_of_range("static expand shape too large"); + } auto len = expand_shape.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("expand shape too large"); } - tinytc_value_t *eshape = - const_cast(reinterpret_cast(expand_shape.data())); - CHECK_STATUS_LOC(tinytc_expand_inst_create(&instr, a.get(), mode, len, eshape, &loc), loc); + const tinytc_value_t *es = reinterpret_cast(expand_shape.data()); + CHECK_STATUS_LOC(tinytc_expand_inst_create(&instr, a, expanded_mode, static_len, + static_expand_shape.data(), len, es, ty, &loc), + loc); return inst(instr); } @@ -845,14 +1891,15 @@ inline inst make_expand(value const &a, std::int64_t mode, std::vector co * @param a Operand * @param from First mode to fuse * @param to Last mode to fuse + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_fuse(value const &a, std::int64_t from, std::int64_t to, +inline inst make_fuse(value a, std::int64_t from, std::int64_t to, data_type ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a.get(), from, to, &loc), loc); + CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a, from, to, ty, &loc), loc); return inst(instr); } @@ -861,47 +1908,20 @@ inline inst make_fuse(value const &a, std::int64_t from, std::int64_t to, * * @param a Operand * @param index_list Vector of indices + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_load(value const &a, std::vector const &index_list, +inline inst make_load(value a, array_view index_list, tinytc_data_type_t ty, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } - tinytc_value_t *il = - const_cast(reinterpret_cast(index_list.data())); - CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a.get(), len, il, &loc), loc); - return inst(instr); -} - -/** - * @brief Make group id instruction - * - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_group_id(location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_id_inst_create(&instr, &loc), loc); - return inst(instr); -} - -/** - * @brief Make group size instruction - * - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_group_size(location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_size_inst_create(&instr, &loc), loc); + const tinytc_value_t *il = reinterpret_cast(index_list.data()); + CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a, len, il, ty, &loc), loc); return inst(instr); } @@ -920,13 +1940,12 @@ inline inst make_group_size(location const &loc = {}) { * * @return Instruction */ -inline inst make_gemm(transpose tA, transpose tB, bool atomic, value const &alpha, value const &A, - value const &B, value const &beta, value const &C, location const &loc = {}) { +inline inst make_gemm(transpose tA, transpose tB, bool atomic, value alpha, value A, value B, + value beta, value C, location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_gemm_inst_create(&instr, static_cast(tA), - static_cast(tB), atomic, - alpha.get(), A.get(), B.get(), beta.get(), C.get(), - &loc), + static_cast(tB), atomic, alpha, A, + B, beta, C, &loc), loc); return inst(instr); } @@ -945,12 +1964,11 @@ inline inst make_gemm(transpose tA, transpose tB, bool atomic, value const &alph * * @return Instruction */ -inline inst make_gemv(transpose tA, bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { +inline inst make_gemv(transpose tA, bool atomic, value alpha, value A, value B, value beta, value C, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_gemv_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), B.get(), beta.get(), C.get(), - &loc), + alpha, A, B, beta, C, &loc), loc); return inst(instr); } @@ -968,12 +1986,10 @@ inline inst make_gemv(transpose tA, bool atomic, value const &alpha, value const * * @return Instruction */ -inline inst make_ger(bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { +inline inst make_ger(bool atomic, value alpha, value A, value B, value beta, value C, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_ger_inst_create(&instr, atomic, alpha.get(), A.get(), B.get(), - beta.get(), C.get(), &loc), - loc); + CHECK_STATUS_LOC(tinytc_ger_inst_create(&instr, atomic, alpha, A, B, beta, C, &loc), loc); return inst(instr); } @@ -990,12 +2006,41 @@ inline inst make_ger(bool atomic, value const &alpha, value const &A, value cons * * @return Instruction */ -inline inst make_hadamard(bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { +inline inst make_hadamard(bool atomic, value alpha, value A, value B, value beta, value C, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_hadamard_inst_create(&instr, atomic, alpha.get(), A.get(), B.get(), - beta.get(), C.get(), &loc), - loc); + CHECK_STATUS_LOC(tinytc_hadamard_inst_create(&instr, atomic, alpha, A, B, beta, C, &loc), loc); + return inst(instr); +} + +/** + * @brief Make math instruction (unary) + * + * @param op Math operation type + * @param a Operand + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_math(math_unary op, value a, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_math_unary_inst_create(&instr, static_cast(op), a, ty, &loc), + loc); + return inst(instr); +} + +/** + * @brief Make parallel region + * + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_parallel(location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_parallel_inst_create(&instr, &loc), loc); return inst(instr); } @@ -1004,13 +2049,51 @@ inline inst make_hadamard(bool atomic, value const &alpha, value const &A, value * * @param a Operand * @param mode Mode + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_size(value a, std::int64_t mode, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a, mode, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make subgroup broadcast instruction + * + * @param a Operand + * @param idx Subgroup local index + * @param ty Result type + * @param loc Source code location + * + * @return Instruction + */ +inline inst make_subgroup_broadcast(value a, value idx, data_type ty, location const &loc = {}) { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_subgroup_broadcast_inst_create(&instr, a, idx, ty, &loc), loc); + return inst(instr); +} + +/** + * @brief Make subgroup operation instruction + * + * @param arith Group arithmetic + * @param op Group operation + * @param a Operand + * @param ty Result type * @param loc Source code location * * @return Instruction */ -inline inst make_size(value const &a, std::int64_t mode, location const &loc = {}) { +inline inst make_subgroup_operation(group_arithmetic arith, group_operation op, value a, + data_type ty, location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a.get(), mode, &loc), loc); + CHECK_STATUS_LOC(tinytc_subgroup_operation_inst_create( + &instr, static_cast(arith), + static_cast(op), a, ty, &loc), + loc); return inst(instr); } @@ -1018,34 +2101,50 @@ inline inst make_size(value const &a, std::int64_t mode, location const &loc = { * @brief Make subview instruction * * @param a Operand - * @param offset_list Vector of offsets - * @param size_list Vector of sizes; initialize with empty value if only offset is required + * @param static_offset_list Static offsets + * @param static_size_list Static sizes + * @param offset_list Vector of offsets; need to add dynamic offsets here if static_offset_list + * contains "dynamic" + * @param size_list Vector of sizes; need to add dynamic sizes here if static_size_list contains + * "dynamic" + * @param ty Return type * @param loc Source code location * * @return Instruction */ -inline inst make_subview(value const &a, std::vector const &offset_list, - std::vector const &size_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); +inline inst make_subview(value a, array_view static_offset_list, + array_view static_size_list, array_view offset_list, + array_view size_list, data_type ty, location const &loc = {}) { tinytc_inst_t instr; - if (offset_list.size() != size_list.size()) { - throw std::invalid_argument("offset list must have the same length as the size list"); - } - auto len = offset_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("slice list too long"); - } - tinytc_value_t *ol = - const_cast(reinterpret_cast(offset_list.data())); - tinytc_value_t *sl = - const_cast(reinterpret_cast(size_list.data())); - CHECK_STATUS_LOC(tinytc_subview_inst_create(&instr, a.get(), len, ol, sl, &loc), loc); + if (static_offset_list.size() != static_size_list.size()) { + throw std::invalid_argument( + "static offset list must have the same length as the static size list"); + } + auto static_len = static_offset_list.size(); + if (static_len > std::numeric_limits::max()) { + throw std::out_of_range("static slice list too long"); + } + auto offset_len = offset_list.size(); + if (offset_len > std::numeric_limits::max()) { + throw std::out_of_range("dynamic offset list too long"); + } + auto size_len = size_list.size(); + if (size_len > std::numeric_limits::max()) { + throw std::out_of_range("dynamic size list too long"); + } + const tinytc_value_t *ol = reinterpret_cast(offset_list.data()); + const tinytc_value_t *sl = reinterpret_cast(size_list.data()); + CHECK_STATUS_LOC(tinytc_subview_inst_create(&instr, a, static_len, static_offset_list.data(), + static_size_list.data(), offset_len, ol, size_len, + sl, ty, &loc), + loc); return inst(instr); } /** * @brief Make store instruction * + * @param flag store flag * @param val Value that is stored * @param a Target memref * @param index_list Vector of indices @@ -1053,17 +2152,17 @@ inline inst make_subview(value const &a, std::vector const &offset_list, * * @return Instruction */ -inline inst make_store(value const &val, value const &a, std::vector const &index_list, +inline inst make_store(store_flag flag, value val, value a, array_view index_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); tinytc_inst_t instr; auto len = index_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("index list too long"); } - tinytc_value_t *il = - const_cast(reinterpret_cast(index_list.data())); - CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, val.get(), a.get(), len, il, &loc), loc); + const tinytc_value_t *il = reinterpret_cast(index_list.data()); + CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, static_cast(flag), val, + a, len, il, &loc), + loc); return inst(instr); } @@ -1080,11 +2179,11 @@ inline inst make_store(value const &val, value const &a, std::vector cons * * @return Instruction */ -inline inst make_sum(transpose tA, bool atomic, value const &alpha, value const &A, - value const &beta, value const &B, location const &loc = {}) { +inline inst make_sum(transpose tA, bool atomic, value alpha, value A, value beta, value B, + location const &loc = {}) { tinytc_inst_t instr; CHECK_STATUS_LOC(tinytc_sum_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), beta.get(), B.get(), &loc), + alpha, A, beta, B, &loc), loc); return inst(instr); } @@ -1092,20 +2191,31 @@ inline inst make_sum(transpose tA, bool atomic, value const &alpha, value const /** * @brief Make for loop instruction * - * @param loop_var Loop variable + * @param loop_var_type Type of loop variable * @param from Loop variable start * @param to Loop variable bound - * @param step Loop variable step - * @param body Loop body + * @param step Loop variable step; can be {} + * @param initial_value_list Array of initial values; can be {} + * @param return_type_list Array of return types; can be {} * @param loc Source code location * * @return Instruction */ -inline inst make_for(value const &loop_var, value const &from, value const &to, value const &step, - region const &body, location const &loc = {}) { +inline inst make_for(data_type loop_var_type, value from, value to, value step, + array_view initial_value_list, array_view return_type_list, + location const &loc = {}) { tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, loop_var.get(), from.get(), to.get(), - step.get(), body.get(), &loc), + auto len = initial_value_list.size(); + if (len > std::numeric_limits::max()) { + throw std::out_of_range("initial value list too long"); + } + if (len != return_type_list.size()) { + throw std::invalid_argument( + "initial value list must have the same length as the return type list"); + } + const tinytc_value_t *il = reinterpret_cast(initial_value_list.data()); + CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, loop_var_type, from, to, step, len, il, + return_type_list.data(), &loc), loc); return inst(instr); } @@ -1113,20 +2223,32 @@ inline inst make_for(value const &loop_var, value const &from, value const &to, /** * @brief Make foreach loop instruction * - * @param loop_var Loop variable - * @param from Loop variable start - * @param to Loop variable bound - * @param body Loop body + * @param loop_var_type Type of loop variable + * @param from_list List of loop variable start + * @param to_list List of loop variable bound * @param loc Source code location * * @return Instruction */ -inline inst make_foreach(value const &loop_var, value const &from, value const &to, - region const &body, location const &loc = {}) { +inline inst make_foreach(data_type loop_var_type, array_view from_list, + array_view to_list, location const &loc = {}) { + tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_foreach_inst_create(&instr, loop_var.get(), from.get(), to.get(), body.get(), &loc), - loc); + if (from_list.size() != to_list.size()) { + throw std::invalid_argument("from list must have the same length as the to list"); + } + const auto from_len = from_list.size(); + if (from_len > std::numeric_limits::max()) { + throw std::out_of_range("from list too long"); + } + const auto to_len = to_list.size(); + if (to_len > std::numeric_limits::max()) { + throw std::out_of_range("to list too long"); + } + const tinytc_value_t *fl = reinterpret_cast(from_list.data()); + const tinytc_value_t *tl = reinterpret_cast(to_list.data()); + CHECK_STATUS_LOC(tinytc_foreach_inst_create(&instr, loop_var_type, from_len, fl, tl, &loc), + loc); return inst(instr); } @@ -1134,28 +2256,19 @@ inline inst make_foreach(value const &loop_var, value const &from, value const & * @brief Make if condition instruction * * @param condition Condition value (of type bool) - * @param then Then region - * @param otherwise Else region * @param return_type_list Types of returned values * @param loc Source code location * * @return Instruction */ -inline inst make_if(value const &condition, region const &then, region const &otherwise = region{}, - std::vector const &return_type_list = {}, +inline inst make_if(value condition, array_view return_type_list = {}, location const &loc = {}) { tinytc_inst_t instr; auto len = return_type_list.size(); if (len > std::numeric_limits::max()) { throw std::out_of_range("return type list too long"); } - auto rl_vec = std::vector(); - rl_vec.resize(len); - for (auto const &rt : return_type_list) { - rl_vec.emplace_back(static_cast(rt)); - } - CHECK_STATUS_LOC(tinytc_if_inst_create(&instr, condition.get(), then.get(), otherwise.get(), - len, rl_vec.data(), &loc), + CHECK_STATUS_LOC(tinytc_if_inst_create(&instr, condition, len, return_type_list.data(), &loc), loc); return inst(instr); } @@ -1168,15 +2281,13 @@ inline inst make_if(value const &condition, region const &then, region const &ot * * @return Instruction */ -inline inst make_yield(std::vector const &yield_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); +inline inst make_yield(array_view yield_list, location const &loc = {}) { tinytc_inst_t instr; auto len = yield_list.size(); if (len > std::numeric_limits::max()) { - throw std::out_of_range("slice list too long"); + throw std::out_of_range("yield list too long"); } - tinytc_value_t *yl = - const_cast(reinterpret_cast(yield_list.data())); + const tinytc_value_t *yl = reinterpret_cast(yield_list.data()); CHECK_STATUS_LOC(tinytc_yield_inst_create(&instr, len, yl, &loc), loc); return inst(instr); } @@ -1186,86 +2297,52 @@ inline inst make_yield(std::vector const &yield_list, location const &loc //////////////////////////// namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_func_t handle) -> tinytc_status_t { - return tinytc_func_retain(handle); - } - static auto release(tinytc_func_t handle) -> tinytc_status_t { - return tinytc_func_release(handle); - } +template <> struct unique_handle_traits { + static void destroy(tinytc_func_t handle) { return tinytc_func_destroy(handle); } }; } // namespace internal //! @brief Reference-counting wrapper for tinytc_func_t -class func : public shared_handle { +class func : public unique_handle { public: - using shared_handle::shared_handle; -}; + using unique_handle::unique_handle; -namespace internal { -//! Is reinterpret_cast(&f) allowed, where f has type func -constexpr bool func_reinterpret_allowed = - std::is_standard_layout_v && sizeof(func) == sizeof(tinytc_func_t); -} // namespace internal + auto get_body() -> region { + tinytc_region_t body; + CHECK_STATUS(tinytc_func_get_body(obj_, &body)); + return region{body}; + } -/** - * @brief Make function prototype - * - * @param name Function name - * @param arg_list Argument list - * @param loc Source code location - * - * @return Function - */ -inline func make_function_prototype(char const *name, std::vector &arg_list, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); - tinytc_func_t fun; - auto len = arg_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("argument list too long"); + void set_attr(attr a) { CHECK_STATUS(tinytc_func_set_attr(obj_, a)); } + + void set_parameter_attr(std::int32_t arg_no, attr a) { + CHECK_STATUS(tinytc_func_set_parameter_attr(obj_, arg_no, a)); } - tinytc_value_t *al = reinterpret_cast(arg_list.data()); - CHECK_STATUS_LOC(tinytc_function_prototype_create(&fun, name, len, al, &loc), loc); - return func(fun); -} +}; /** * @brief Make function * - * @param prototype Function prototype - * @param body Function body + * @param name Function name + * @param param_type_list List of parameter types + * @param ty Result type (must be void for host-callable function) * @param loc Source code location * * @return Function */ -inline func make_function(func const &prototype, region const &body, location const &loc = {}) { +inline func make_func(std::string_view name, array_view param_type_list, data_type ty, + location const &loc = {}) { tinytc_func_t fun; - CHECK_STATUS_LOC(tinytc_function_create(&fun, prototype.get(), body.get(), &loc), loc); + auto len = param_type_list.size(); + if (len > std::numeric_limits::max()) { + throw std::out_of_range("param list too long"); + } + CHECK_STATUS_LOC( + tinytc_func_create(&fun, name.size(), name.data(), len, param_type_list.data(), ty, &loc), + loc); return func(fun); } -/** - * @brief Set work-group size (x,y) - * - * @param fun Function object; must have been created with "make_function" - * @param x x - * @param y y - */ -inline void set_work_group_size(func &fun, std::int32_t x, std::int32_t y) { - CHECK_STATUS(tinytc_function_set_work_group_size(fun.get(), x, y)); -} - -/** - * @brief Set subgroup size - * - * @param fun Function object; must have been created with "make_function" - * @param sgs Subgroup size - */ -inline void set_subgroup_size(func &fun, std::int32_t sgs) { - CHECK_STATUS(tinytc_function_set_subgroup_size(fun.get(), sgs)); -} - //////////////////////////// /////////// Prog /////////// //////////////////////////// @@ -1289,50 +2366,107 @@ class prog : public shared_handle { public: using shared_handle::shared_handle; + /** + * @brief Append function to program + * + * @param fun function + */ + inline void add_function(func fun) { + CHECK_STATUS(tinytc_prog_add_function(obj_, fun.release())); + } + /** * @brief Dump program to stderr */ - void dump() const { CHECK_STATUS(tinytc_prog_dump(obj_)); } + void dump() const { CHECK_STATUS(tinytc_prog_dump(obj_)); } + /** + * @brief Get context + * + * @return Compiler context + */ + auto get_compiler_context() const -> compiler_context { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_prog_get_compiler_context(obj_, &ctx)); + return compiler_context{ctx, true}; + } + /** + * @brief Dump program to file + * + * @param filename Path to file + */ + void print_to_file(char const *filename) const { + CHECK_STATUS(tinytc_prog_print_to_file(obj_, filename)); + } + /** + * @brief Dump program to string + * + * @return C-string (unique handle) + */ + auto print_to_string() const -> unique_handle { + char *str; + CHECK_STATUS(tinytc_prog_print_to_string(obj_, &str)); + return unique_handle{str}; + } +}; + +/** + * @brief Make program + * + * @param ctx Compiler context + * @param loc Source code location + * + * @return Program + */ +inline prog make_prog(compiler_context const &ctx, location const &loc = {}) { + tinytc_prog_t prg; + CHECK_STATUS_LOC(tinytc_prog_create(&prg, ctx.get(), &loc), loc); + return prog{prg}; +} + +//////////////////////////// +/////// SPIR-V Module ////// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_spv_mod_t handle) -> tinytc_status_t { + return tinytc_spv_mod_retain(handle); + } + static auto release(tinytc_spv_mod_t handle) -> tinytc_status_t { + return tinytc_spv_mod_release(handle); + } +}; +} // namespace internal + +//! @brief Reference-counting wrapper for tinytc_spv_mod_t +class spv_mod : public shared_handle { + public: + using shared_handle::shared_handle; + + /** + * @brief Dump module to stderr + */ + void dump() const { CHECK_STATUS(tinytc_spv_mod_dump(obj_)); } /** - * @brief Dump program to file + * @brief Dump module to file * * @param filename Path to file */ void print_to_file(char const *filename) const { - CHECK_STATUS(tinytc_prog_print_to_file(obj_, filename)); + CHECK_STATUS(tinytc_spv_mod_print_to_file(obj_, filename)); } /** - * @brief Dump program to string + * @brief Dump module to string * * @return C-string (unique handle) */ auto print_to_string() const -> unique_handle { char *str; - CHECK_STATUS(tinytc_prog_print_to_string(obj_, &str)); + CHECK_STATUS(tinytc_spv_mod_print_to_string(obj_, &str)); return unique_handle{str}; } }; -/** - * @brief Make program - * - * @param fun_list Vector of functions - * @param loc Source code location - * - * @return Program - */ -inline prog make_program(std::vector &fun_list, location const &loc = {}) { - tinytc_prog_t prg; - static_assert(internal::func_reinterpret_allowed); - auto len = fun_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("function list too long"); - } - tinytc_func_t *fl = reinterpret_cast(fun_list.data()); - CHECK_STATUS_LOC(tinytc_program_create(&prg, len, fl, &loc), loc); - return prog{prg}; -} - //////////////////////////// ////////// Builder ///////// //////////////////////////// @@ -1341,30 +2475,37 @@ inline prog make_program(std::vector &fun_list, location const &loc = {}) class region_builder { public: /** - * @brief Returns built product + * @brief ctor * - * @param loc Source code location + * @param reg region object + */ + region_builder(region reg) : reg_{reg}, ip_{reg_.end()} {} + /** + * @brief ctor * - * @return Region + * @param reg region object + * @param ip insertion point */ - inline auto get_product(location const &loc = {}) -> region { - return make_region(instructions_, loc); - } + region_builder(region reg, tinytc_inst_iterator_t ip) : reg_{reg}, ip_{ip} {} + + /** + * @brief Get insertion point + * + * @return Iterator + */ + inline auto get_insertion_point() const -> tinytc_inst_iterator_t { return ip_; } /** * @brief Add instruction * * @param i Instruction - * @param name Result name * * @return Value returned by instruction; may be empty */ - [[maybe_unused]] inline auto add(inst i, std::string const &name = "") -> value { - auto result = i.get_value(); - if (result && name.size() > 0) { - result.name(name); - } - instructions_.emplace_back(std::move(i)); + [[maybe_unused]] inline auto add(inst i) -> value { + auto result = value{}; + i.get_values(result); + reg_.insert(ip_, std::move(i)); return result; } @@ -1372,42 +2513,40 @@ class region_builder { * @brief Add instruction that returns multiple values * * @param i Instruction - * @param name Result name * * @return Values returned by instruction */ - [[maybe_unused]] inline auto add_multivalued(inst i, std::string const &name = "") - -> std::vector { - auto results = i.get_values(); - if (name.size() > 0) { - int counter = 0; - for (auto &result : results) { - result.name(name + std::to_string(counter++)); - } - } - instructions_.emplace_back(std::move(i)); + [[maybe_unused]] inline auto add_multivalued(inst i) -> std::vector { + auto num_results = i.get_values({}); + auto results = std::vector(static_cast(num_results)); + results.resize(i.get_values(results)); + reg_.insert(ip_, std::move(i)); return results; } /** - * @brief Build for-loop with functor f(region_builder&) -> void + * @brief Build for-loop with functor f(region_builder&, value) -> void + * + * The loop trip count is passed as second argument to the functor. * * @tparam F Functor type - * @param loop_var_ty Type of loop variable * @param from Loop variable start * @param to Loop variable bound + * @param loop_var_ty Type of loop variable * @param f Functor - * @param name Loop variable name + * @param attributes For attributes * @param loc Source code location */ template - void for_loop(scalar_type loop_var_ty, value const &from, value const &to, F &&f, - std::string const &name = "", location const &loc = {}) { - for_loop(std::move(loop_var_ty), std::move(from), std::move(to), value{nullptr}, - std::forward(f), name, loc); + void for_loop(data_type loop_var_ty, value from, value to, F &&f, attr attributes = nullptr, + location const &loc = {}) { + for_loop(std::move(loop_var_ty), std::move(from), std::move(to), nullptr, + std::forward(f), attributes, loc); } /** - * @brief Build for-loop with functor f(region_builder&) -> void + * @brief Build for-loop with functor f(region_builder&, value) -> void + * + * The loop trip count is passed as second argument to the functor. * * @tparam F Functor type * @param loop_var_ty Type of loop variable @@ -1415,62 +2554,113 @@ class region_builder { * @param to Loop variable bound * @param step Loop variable step * @param f Functor - * @param name Loop variable name + * @param attributes For attributes * @param loc Source code location */ template - void for_loop(scalar_type loop_var_ty, value const &from, value const &to, value const &step, - F &&f, std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(loop_var_ty); - if (name.size() > 0) { - loop_var.name(name); + void for_loop(data_type loop_var_ty, value from, value to, value step, F &&f, + attr attributes = nullptr, location const &loc = {}) { + auto fi = ::tinytc::make_for(loop_var_ty, from, to, step, {}, {}, loc); + fi.set_attr(attributes); + auto reg = region{}; + fi.get_regions(reg); + auto loop_var = value{}; + reg.get_parameters(loop_var); + if (!reg || !loop_var) { + throw status::internal_compiler_error; } - auto bb = region_builder{}; - f(bb); - add(::tinytc::make_for(std::move(loop_var), from, to, step, bb.get_product(), loc)); + reg_.insert(ip_, std::move(fi)); + auto bb = region_builder{reg}; + f(bb, loop_var); } /** - * @brief Build foreach-loop with functor f(region_builder&) -> void + * @brief Build for-loop with functor f(region_builder&, array_view) -> void + * + * The loop trip count is the first value in the array_view. + * The following values are the loop-carried values. * * @tparam F Functor type * @param loop_var_ty Type of loop variable * @param from Loop variable start * @param to Loop variable bound + * @param step Loop variable step + * @param initial_value_list Array of initial values; can be {} + * @param return_type_list Array of return types; can be {} + * @param f Functor + * @param attributes For attributes + * @param loc Source code location + */ + template + auto for_loop(data_type loop_var_ty, value from, value to, value step, + array_view initial_value_list, array_view return_type_list, + F &&f, attr attributes = nullptr, location const &loc = {}) + -> std::vector { + auto fi = ::tinytc::make_for(loop_var_ty, from, to, step, initial_value_list, + return_type_list, loc); + fi.set_attr(attributes); + auto reg = region{}; + fi.get_regions(reg); + auto num_params = reg.get_parameters({}); + auto params = std::vector(num_params); + reg.get_parameters(params); + if (!reg || num_params != 1 + initial_value_list.size()) { + throw status::internal_compiler_error; + } + auto results = add_multivalued(std::move(fi)); + auto bb = region_builder{reg}; + f(bb, array_view(params)); + return results; + } + /** + * @brief Build foreach-loop with functor f(region_builder&, array_view) -> void + * + * @tparam F Functor type + * @param loop_var_ty Type of loop variable + * @param from Loop variable start list + * @param to Loop variable bound list * @param f functor - * @param name Loop variable name * @param loc Source code location */ template - void foreach (data_type const &loop_var_ty, value const &from, value const &to, F && f, - std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(loop_var_ty); - if (name.size() > 0) { - loop_var.name(name); + void foreach_loop(data_type loop_var_ty, array_view from, array_view to, F &&f, + location const &loc = {}) { + auto fi = ::tinytc::make_foreach(loop_var_ty, std::move(from), std::move(to), loc); + auto reg = region{}; + fi.get_regions(reg); + auto num_params = reg.get_parameters({}); + auto params = std::vector(num_params); + reg.get_parameters(params); + if (!reg || num_params != from.size() || num_params != to.size()) { + throw status::internal_compiler_error; } - auto bb = region_builder{}; - f(bb); - add(::tinytc::make_foreach(std::move(loop_var), from, to, bb.get_product(), loc)); + reg_.insert(ip_, std::move(fi)); + auto bb = region_builder{reg}; + f(bb, array_view(params)); } /** * @brief Build if with functor then(region_builder&) -> void * + * Note: If the if instruction returns values then we must have a "yield" instruction in + * both the "then" and the "else" branch. So to return values use the "ifelse" function. + * * @tparam F Functor type * @param condition Condition value * @param then Then region functor - * @param return_type_list Types of returned values * @param loc Source code location * * @return Returned values */ - template - auto if_condition(value const &condition, F &&then, - std::vector const &return_type_list = {}, - location const &loc = {}) -> std::vector { - auto bb = region_builder{}; + template void if_condition(value condition, F &&then, location const &loc = {}) { + auto ii = ::tinytc::make_if(std::move(condition), {}, loc); + auto reg = region{}; + ii.get_regions(reg); + if (!reg) { + throw status::internal_compiler_error; + } + reg_.insert(ip_, std::move(ii)); + auto bb = region_builder{reg}; then(bb); - return add_multivalued(::tinytc::make_if(std::move(condition), bb.get_product(), region{}, - return_type_list, loc)); } /** * @brief Build if/else with functors then(region_builder&) -> void and @@ -1487,139 +2677,33 @@ class region_builder { * @return Returned values */ template - auto ifelse(value const &condition, F &&then, G &&otherwise, - std::vector const &return_type_list = {}, location const &loc = {}) + auto ifelse(value condition, F &&then, G &&otherwise, + array_view return_type_list = {}, location const &loc = {}) -> std::vector { - auto bb1 = region_builder{}; - then(bb1); - auto bb2 = region_builder{}; - otherwise(bb2); - return add_multivalued(::tinytc::make_if(std::move(condition), bb1.get_product(), - bb2.get_product(), return_type_list, loc)); - } - - private: - std::vector instructions_; -}; - -//! Builder for functions -class function_builder { - public: - /** - * @brief creates function \@name - * - * @param name Function name - */ - inline function_builder(std::string name) : name_(std::move(name)), body_{nullptr} {} - - /** - * @brief Returns built product - * - * @param loc Source code location - * - * @return Function - */ - inline func get_product(location const &loc = {}) { - auto proto = make_function_prototype(name_.c_str(), arguments_, loc); - auto fun = make_function(proto, body_); - if (x_ > 0 && y_ > 0) { - set_work_group_size(fun, x_, y_); - } - if (sgs_ > 0) { - set_subgroup_size(fun, sgs_); - } - return fun; - } - - /** - * @brief @code %name: %ty @endcode - * - * @param ty Argument type - * @param name Argument name - * @param loc Source code location - * - * @return Value - */ - inline value argument(data_type const &ty, std::string const &name = "", - location const &loc = {}) { - auto v = make_value(ty, loc); - if (name.size() > 0) { - v.name(name); + auto ii = ::tinytc::make_if(std::move(condition), return_type_list, loc); + std::array regs = {}; + ii.get_regions(regs); + if (!regs[0] || !regs[1]) { + throw status::internal_compiler_error; } - arguments_.emplace_back(std::move(v)); - return arguments_.back(); - } - - /** - * @brief @code work_group_size(%x, %y) @endcode - * - * @param x x - * @param y y - */ - inline void work_group_size(std::int32_t x, std::int32_t y) { - x_ = x; - y_ = y; - } - /** - * @brief @code subgroup_size(%subgroup_size) @endcode - * - * @param subgroup_size Subgroup size - */ - inline void subgroup_size(std::int32_t subgroup_size) { sgs_ = subgroup_size; } - - /** - * @brief Build function body with functor f(region_builder&) -> void - * - * @tparam F Functor type - * @param f Functor - * @param loc Source code location - */ - template void body(F &&f, location const &loc = {}) { - auto bb = region_builder{}; - f(bb); - body_ = bb.get_product(loc); + auto results = add_multivalued(std::move(ii)); + auto bb0 = region_builder{regs[0]}; + then(bb0); + auto bb1 = region_builder{regs[1]}; + otherwise(bb1); + return results; } - private: - std::string name_; - region body_; - std::vector arguments_; - std::int32_t x_ = 0, y_ = 0, sgs_ = 0; -}; - -//! Builder for programs -class program_builder { - public: - /** - * @brief create function \@name with functor f(function_builder&) -> void - * - * @tparam F Functor type - * @param name Function name - * @param f Functor - * @param loc Source code location - */ - template void create(std::string name, F &&f, location const &loc = {}) { - auto fb = function_builder(std::move(name)); - f(fb); - add(fb.get_product(loc)); - } - /** - * @brief Add function - * - * @param f function - */ - inline void add(func f) { functions_.emplace_back(std::move(f)); } /** - * @brief Returns built product - * - * @param loc Source code location + * @brief Get region * - * @return Program + * @return Region */ - inline prog get_product(location const &loc = {}) { return make_program(functions_, loc); } + inline auto get_region() -> region { return reg_; } private: - std::vector functions_; + region reg_; + tinytc_inst_iterator_t ip_; }; //////////////////////////// @@ -1645,13 +2729,13 @@ class core_info : public shared_handle { /** * @brief Get subgroup sizes * - * Cf. @ref tinytc_core_info_get_subgroup_sizes - * - * @param sgs_size Pointer to size of subgroup size array - * @param sgs Pointer ot subgroup size array + * @return Subgroup sizes */ - void get_subgroup_sizes(std::uint32_t *sgs_size, std::int32_t const **sgs) { - CHECK_STATUS(tinytc_core_info_get_subgroup_sizes(obj_, sgs_size, sgs)); + auto get_subgroup_sizes() const -> array_view { + std::uint32_t sgs_size = 0; + std::int32_t const *sgs = nullptr; + CHECK_STATUS(tinytc_core_info_get_subgroup_sizes(obj_, &sgs_size, &sgs)); + return array_view(sgs, static_cast(sgs_size)); } /** @@ -1659,7 +2743,7 @@ class core_info : public shared_handle { * * @return Register space */ - auto get_register_space() -> std::int32_t { + auto get_register_space() const -> std::int32_t { std::int32_t space; CHECK_STATUS(tinytc_core_info_get_register_space(obj_, &space)); return space; @@ -1684,6 +2768,51 @@ class core_info : public shared_handle { CHECK_STATUS(tinytc_core_info_get_core_features(obj_, &flags)); return flags; } + + /** + * @brief Set SPIR-V feature + * + * @param feature SPIR-V feature + * @param available true if feature is available and false otherwise + */ + void set_spirv_feature(spirv_feature feature, bool available) { + CHECK_STATUS(tinytc_core_info_set_spirv_feature( + obj_, static_cast(feature), available)); + } + + /** + * @brief Get SPIR-V feature + * + * @param feature SPIR-V feature + * + * @return true if feature is available and false otherwise + */ + auto have_spirv_feature(spirv_feature feature) const -> bool { + tinytc_bool_t available; + CHECK_STATUS(tinytc_core_info_have_spirv_feature( + obj_, static_cast(feature), &available)); + return available; + } + + /** + * @brief Get default alignment + * + * @return alignment in bytes + */ + auto get_default_alignment() const -> std::int32_t { + std::int32_t alignment; + CHECK_STATUS(tinytc_core_info_get_default_alignment(obj_, &alignment)); + return alignment; + } + + /** + * @brief Set default alignment + * + * @param alignment alignment in bytes + */ + void set_default_alignment(std::int32_t alignment) { + CHECK_STATUS(tinytc_core_info_set_default_alignment(obj_, alignment)); + } }; /** @@ -1696,7 +2825,7 @@ class core_info : public shared_handle { * @return Core info */ inline auto make_core_info_generic(std::int32_t register_space, std::int32_t max_work_group_size, - std::vector sgs) -> core_info { + array_view sgs) -> core_info { tinytc_core_info_t info; CHECK_STATUS(tinytc_core_info_generic_create(&info, register_space, max_work_group_size, sgs.size(), sgs.data())); @@ -1717,6 +2846,19 @@ inline auto make_core_info_intel_from_arch(intel_gpu_architecture arch) -> core_ return core_info{info}; } +/** + * @brief Get core info for Intel GPUs from lookup table + * + * @param name architecture name + * + * @return Core info + */ +inline auto make_core_info_intel_from_name(char const *name) -> core_info { + tinytc_core_info_t info; + CHECK_STATUS(tinytc_core_info_intel_create_from_name(&info, name)); + return core_info{info}; +} + /** * @brief Create core info for Intel GPUs manually * @@ -1728,7 +2870,7 @@ inline auto make_core_info_intel_from_arch(intel_gpu_architecture arch) -> core_ * @return Core info */ inline auto make_core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, - std::int32_t num_threads_per_eu, std::vector sgs) + std::int32_t num_threads_per_eu, array_view sgs) -> core_info { tinytc_core_info_t info; CHECK_STATUS(tinytc_core_info_intel_create(&info, ip_version, num_eus_per_subslice, @@ -1736,116 +2878,52 @@ inline auto make_core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_ return core_info{info}; } +//! Convert SPIR-V feature to string +inline char const *to_string(spirv_feature f) { + return ::tinytc_spirv_feature_to_string(static_cast(f)); +} + //////////////////////////// ////////// Parser ////////// //////////////////////////// -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_source_context_t handle) -> tinytc_status_t { - return tinytc_source_context_retain(handle); - } - static auto release(tinytc_source_context_t handle) -> tinytc_status_t { - return tinytc_source_context_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_source_context_t -class source_context : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Add source to context - * - * @param name File name - * @param text Source text - * - * @return Source id (should be set in position.source_id) - */ - inline auto add_source(char const *name, char const *text) -> std::int32_t { - std::int32_t source_id; - CHECK_STATUS(tinytc_source_context_add_source(obj_, name, text, &source_id)); - return source_id; - } - /** - * @brief Get error log - * - * @return C-string that is valid as long as source_context is not modified; empty string if - * source_context is empty - */ - inline auto get_error_log() const noexcept -> char const * { - if (obj_) { - char const *log; - // No need to call CHECK_STATUS, as the only possible error code is - // tinytc_status_invalid_arguments but we only pass valid arguments - tinytc_source_context_get_error_log(obj_, &log); - return log; - } - return ""; - } - /** - * @brief Enhance error message with source context; useful when builder is used - * - * @param loc Source location - * @param what Error description - * @param append True: append to error log; false: clear error log - */ - inline void report_error(location const &loc, char const *what, bool append = true) { - CHECK_STATUS(tinytc_source_context_report_error(obj_, &loc, what, - static_cast(append))); - } -}; - -/** - * @brief Create source context - * - * @return Source context - */ -inline auto make_source_context() -> source_context { - tinytc_source_context_t ctx; - CHECK_STATUS(tinytc_source_context_create(&ctx)); - return source_context{ctx}; -} - /** * @brief Parse source text from file * * @param filename Filename - * @param source_ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Program */ -inline auto parse_file(char const *filename, source_context source_ctx = {}) -> prog { +inline auto parse_file(char const *filename, compiler_context const &ctx = {}) -> prog { tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_file(&prg, filename, source_ctx.get())); + CHECK_STATUS(tinytc_parse_file(&prg, filename, ctx.get())); return prog(prg); } /** * @brief Parse source text from stdin * - * @param source_ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Program */ -inline auto parse_stdin(source_context source_ctx = {}) -> prog { +inline auto parse_stdin(compiler_context const &ctx = {}) -> prog { tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_stdin(&prg, source_ctx.get())); + CHECK_STATUS(tinytc_parse_stdin(&prg, ctx.get())); return prog(prg); } /** * @brief Parse source text from string * * @param src Source text - * @param source_ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Porgram */ -inline auto parse_string(std::string const &src, source_context source_ctx = {}) -> prog { +inline auto parse_string(std::string const &src, compiler_context const &ctx = {}) -> prog { tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_string(&prg, src.size(), src.c_str(), source_ctx.get())); + CHECK_STATUS(tinytc_parse_string(&prg, src.size(), src.c_str(), ctx.get())); return prog(prg); } @@ -1853,57 +2931,6 @@ inline auto parse_string(std::string const &src, source_context source_ctx = {}) ///////// Compiler ///////// //////////////////////////// -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_source_t handle) -> tinytc_status_t { - return tinytc_source_retain(handle); - } - static auto release(tinytc_source_t handle) -> tinytc_status_t { - return tinytc_source_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_source_t -class source : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get code - * - * @return Pointer to C-string that is bound to the lifetime of the source object - */ - inline auto get_code() const -> std::string_view { - char const *code = nullptr; - std::size_t length = 0; - CHECK_STATUS(tinytc_source_get_code(obj_, &length, &code)); - return std::string_view(code, length); - } - - /** - * @brief Get location - * - * @return Location - */ - inline auto get_location() const -> location { - location loc = {}; - CHECK_STATUS(tinytc_source_get_location(obj_, &loc)); - return loc; - } - - /** - * @brief Get OpenCL extension - * - * @param extensions_size Number of extensions - * @param extensions Array of extensions - */ - inline void get_extensions(std::uint32_t &extensions_size, - char const *const *&extensions) const { - CHECK_STATUS(tinytc_source_get_extensions(obj_, &extensions_size, &extensions)); - } -}; - namespace internal { template <> struct shared_handle_traits { static auto retain(tinytc_binary_t handle) -> tinytc_status_t { @@ -1932,19 +2959,29 @@ class binary : public shared_handle { * * @return Raw data */ - inline auto get_raw() -> raw { + inline auto get_raw() const -> raw { raw r; tinytc_bundle_format_t f; CHECK_STATUS(tinytc_binary_get_raw(obj_, &f, &r.data_size, &r.data)); r.format = bundle_format{std::underlying_type_t(f)}; return r; } + /** + * @brief Get compiler context + * + * @return Compiler context + */ + inline auto get_compiler_context() const -> compiler_context { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_binary_get_compiler_context(obj_, &ctx)); + return compiler_context{ctx, true}; + } /** * @brief Get core features * * @return Core features */ - inline auto get_core_features() -> tinytc_core_feature_flags_t { + inline auto get_core_features() const -> tinytc_core_feature_flags_t { tinytc_core_feature_flags_t cf; CHECK_STATUS(tinytc_binary_get_core_features(obj_, &cf)); return cf; @@ -1954,6 +2991,7 @@ class binary : public shared_handle { /** * @brief Make binary * + * @param ctx Compiler context * @param format Bundle format (SPIR-V or Native) * @param data_size Size of data in bytes * @param data Binary data; data is copied @@ -1962,27 +3000,75 @@ class binary : public shared_handle { * * @return Binary */ -inline auto make_binary(bundle_format format, std::size_t data_size, std::uint8_t const *data, - tinytc_core_feature_flags_t core_features) -> binary { +inline auto make_binary(compiler_context const &ctx, bundle_format format, std::size_t data_size, + std::uint8_t const *data, tinytc_core_feature_flags_t core_features) + -> binary { tinytc_binary_t bin; - CHECK_STATUS(tinytc_binary_create(&bin, static_cast(format), data_size, - data, core_features)); + CHECK_STATUS(tinytc_binary_create(&bin, ctx.get(), static_cast(format), + data_size, data, core_features)); return binary{bin}; } /** - * @brief Compile program to OpenCL-C + * @brief Run a function pass on every function of a program + * + * @param pass_name name of function pass; cf. list_function_passes + * @param prg tensor program; modified as compiler pass is run + * @param info core info object; might be nullptr if core info is not required for pass + */ +inline void run_function_pass(char const *pass_name, prog prg, core_info info = {}) { + CHECK_STATUS(tinytc_run_function_pass(pass_name, prg.get(), info.get())); +} + +/** + * @brief Get function pass names + * + * @param names_size Number of function pass names + * @param names Array of function pass names + */ +inline void list_function_passes(std::uint32_t &names_size, char const *const *&names) { + CHECK_STATUS(tinytc_list_function_passes(&names_size, &names)); +} + +/** + * @brief Convert tensor language to SPIR-V + * + * @param prg Program + * @param info Core info + * + * @return SPIR-V module + */ +inline auto compile_to_spirv(prog prg, core_info const &info) -> spv_mod { + tinytc_spv_mod_t mod; + CHECK_STATUS(tinytc_prog_compile_to_spirv(&mod, prg.get(), info.get())); + return spv_mod{mod}; +} + +/** + * @brief Compile program to SPIR-V and assemble * * @param prg Program * @param info Core info - * @param ctx Source context for improved error reporting * - * @return Source + * @return Binary + */ +inline auto compile_to_spirv_and_assemble(prog prg, core_info const &info) -> binary { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv_and_assemble(&bin, prg.get(), info.get())); + return binary{bin}; +} + +/** + * @brief Assemble SPIR-V module + * + * @param mod [in] SPIR-V module + * + * @return Binary */ -inline auto compile_to_opencl(prog prg, core_info const &info, source_context ctx = {}) -> source { - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, prg.get(), info.get(), ctx.get())); - return source{src}; +inline auto spirv_assemble(spv_mod const &mod) -> binary { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_spirv_assemble(&bin, mod.get())); + return binary{bin}; } //////////////////////////// @@ -1997,16 +3083,32 @@ inline auto compile_to_opencl(prog prg, core_info const &info, source_context ct template struct auto_mem_type; /** - * @brief True if T is either pointer to a fundamental type or a pointer to a pointer to a - * fundamental type + * @brief Check whether T maps to a scalar data type + * + * @tparam T type + */ +template +constexpr bool is_supported_scalar_type = std::is_same_v || // i8 + std::is_same_v || // i16 + std::is_same_v || // i32 + std::is_same_v || // i64 + std::is_same_v || // f32 + std::is_same_v || // f64 + std::is_same_v> || // c32 + std::is_same_v>; // c64 + +/** + * @brief True if T is either pointer to a support scalar type or a pointer to a pointer to a + * supported scalar type; void* is fine, too * * @tparam T type */ template -constexpr bool usm_pointer_type = - std::is_pointer_v && - (std::is_fundamental_v> || - std::is_fundamental_v>>); +constexpr bool is_usm_pointer_type = + std::is_same_v || + (std::is_pointer_v && + (is_supported_scalar_type> || + is_supported_scalar_type>>)); /** * @brief Specialize auto_mem_type for pointer to non-class types @@ -2016,7 +3118,7 @@ constexpr bool usm_pointer_type = * * @tparam T memory object type */ -template struct auto_mem_type>> { +template struct auto_mem_type>> { constexpr static mem_type value = mem_type::usm_pointer; ///< Pointer maps to USM pointer type }; @@ -2079,14 +3181,14 @@ class recipe : public shared_handle { } /** - * @brief Get source + * @brief Get binary * - * @return Source + * @return Binary */ - auto get_source() const -> source { - tinytc_source_t src; - CHECK_STATUS(tinytc_recipe_get_source(obj_, &src)); - return source{src}; + auto get_binary() const -> binary { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_recipe_get_binary(obj_, &bin)); + return binary{bin}; } }; @@ -2152,7 +3254,7 @@ class small_gemm_batched : public recipe { * @param strideB Stride of B-matrices * @param ldC Leading dimension of an C matrix * @param strideC Stride of C-matrices - * @param ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Small GEMM batched recipe */ @@ -2160,7 +3262,7 @@ inline auto make_small_gemm_batched(core_info const &info, scalar_type ty, trans transpose tB, std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t ldA, std::int64_t strideA, std::int64_t ldB, std::int64_t strideB, std::int64_t ldC, std::int64_t strideC, - source_context ctx = {}) -> small_gemm_batched { + compiler_context const &ctx = {}) -> small_gemm_batched { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_small_gemm_batched_create( &rec, info.get(), static_cast(ty), @@ -2209,13 +3311,13 @@ class tall_and_skinny : public recipe { * @param N Number of columns of B and C * @param K Number of columns of A, number of rows of B * @param M_block_size Chunk size for M-mode - * @param ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Tall and skinny recipe */ inline auto make_tall_and_skinny(core_info const &info, scalar_type ty, std::int64_t N, std::int64_t K, std::int32_t M_block_size = 0, - source_context ctx = {}) -> tall_and_skinny { + compiler_context const &ctx = {}) -> tall_and_skinny { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_tall_and_skinny_create( &rec, info.get(), static_cast(ty), N, K, M_block_size, ctx.get())); @@ -2235,23 +3337,37 @@ inline auto make_tall_and_skinny(core_info const &info, scalar_type ty, std::int * @param ldA Leading dimension of A; can be dynamic * @param ldB Leading dimension of B; can be dynamic * @param ldC Leading dimension of C; can be dynamic + * @param alignA [in] Memory alignment of A; can be 0 + * @param alignB [in] Memory alignment of B; can be 0 + * @param alignC [in] Memory alignment of C; can be 0 * @param M_block_size Chunk size for M-mode - * @param ctx Source context for improved error reporting + * @param ctx Compiler context * * @return Tall and skinny recipe */ inline auto make_tall_and_skinny_specialized(core_info const &info, scalar_type ty, std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, - std::int32_t M_block_size = 0, source_context ctx = {}) - -> tall_and_skinny { + std::int32_t alignA, std::int32_t alignB, + std::int32_t alignC, std::int32_t M_block_size = 0, + compiler_context const &ctx = {}) -> tall_and_skinny { tinytc_recipe_t rec; CHECK_STATUS(tinytc_recipe_tall_and_skinny_create_specialized( - &rec, info.get(), static_cast(ty), M, N, K, ldA, ldB, ldC, - M_block_size, ctx.get())); + &rec, info.get(), static_cast(ty), M, N, K, ldA, ldB, ldC, alignA, + alignB, alignC, M_block_size, ctx.get())); return tall_and_skinny{rec}; } } // namespace tinytc +namespace std { +template struct hash> { + size_t operator()(tinytc::lp_float const &val) const noexcept { + using h = hash::lp_format::bits_type>; + return h{}(val.bits()); + } +}; + +} // namespace std + #endif // TINYTC_20240403_HPP diff --git a/include/tinytc/tinytc_cl.h b/include/tinytc/tinytc_cl.h index 148958b2..dd62c114 100644 --- a/include/tinytc/tinytc_cl.h +++ b/include/tinytc/tinytc_cl.h @@ -57,22 +57,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *inf ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bundle [out] pointer to the kernel bundle (cl_program) object created - * @param context [in] context handle - * @param device [in] device handle - * @param src [in] source text and extensions - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source( - cl_program *bundle, cl_context context, cl_device_id device, const_tinytc_source_t src, - tinytc_source_context_t source_ctx); - /** * @brief Compile tensor program * @@ -82,14 +66,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source( * @param prg [inout] tensor program; modified as compiler passes are run * @param core_features [in][optional] requested core features; must be 0 (default) or a combination * of tinytc_core_feature_flag_t - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( cl_program *bundle, cl_context context, cl_device_id device, tinytc_prog_t prg, - tinytc_core_feature_flags_t core_features, tinytc_source_context_t source_ctx); + tinytc_core_feature_flags_t core_features); /** * @brief Create an OpenCL program from a tinytc binary @@ -98,14 +80,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( * @param context [in] context handle * @param device [in] device handle * @param bin [in] binary object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary( - cl_program *bundle, cl_context context, cl_device_id device, const_tinytc_binary_t bin, - tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, + cl_context context, + cl_device_id device, + const_tinytc_binary_t bin); /** * @brief Get work group size for kernel @@ -121,11 +102,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_get_group_size(cl_kernel kernel, size_t /** * @brief Convert group size to opencl global range * - * @param howmany group size + * @param num_groups [in][range(0,3)] pointer to number of groups of size >= 3 * @param local_size [in][range(0,3)] pointer to local size array of size >= 3 * @param global_size [out][range(0,3)] pointer to global size array of size >= 3 */ -TINYTC_EXPORT void tinytc_cl_get_global_size(int64_t howmany, const size_t *local_size, +TINYTC_EXPORT void tinytc_cl_get_global_size(const size_t *num_groups, const size_t *local_size, size_t *global_size); //////////////////////////// @@ -139,16 +120,13 @@ TINYTC_EXPORT void tinytc_cl_get_global_size(int64_t howmany, const size_t *loca * @param context [in] context handle * @param device [in] device handle * @param recipe [in] recipe object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cl_recipe_handler_create(tinytc_recipe_handler_t *handler, cl_context context, cl_device_id device, - tinytc_recipe_t recipe, - tinytc_source_context_t source_ctx); + tinytc_recipe_t recipe); /** * @brief Submit recipe to device diff --git a/include/tinytc/tinytc_cl.hpp b/include/tinytc/tinytc_cl.hpp index 93227d29..2dd01e7b 100644 --- a/include/tinytc/tinytc_cl.hpp +++ b/include/tinytc/tinytc_cl.hpp @@ -77,24 +77,6 @@ template <> struct shared_handle_traits { }; } // namespace internal -/** - * @brief Make an OpenCL program from a tinytc source - * - * @param context Context - * @param device Device - * @param src Source - * @param source_ctx Source context for improved error reporting - * - * @return cl_program (shared handle) - */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, source const &src, - source_context source_ctx = {}) -> shared_handle { - cl_program obj; - CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_source(&obj, context, device, src.get(), - source_ctx.get())); - return shared_handle{obj}; -} - /** * @brief Make an OpenCL program from a tinytc program * @@ -103,16 +85,15 @@ inline auto make_kernel_bundle(cl_context context, cl_device_id device, source c * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return cl_program (shared handle) */ inline auto make_kernel_bundle(cl_context context, cl_device_id device, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) -> shared_handle { + tinytc_core_feature_flags_t core_features = 0) + -> shared_handle { cl_program obj; CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_program(&obj, context, device, prg.get(), - core_features, source_ctx.get())); + core_features)); return shared_handle{obj}; } @@ -122,15 +103,13 @@ inline auto make_kernel_bundle(cl_context context, cl_device_id device, prog prg * @param context Context * @param device Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return cl_program (shared handle) */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, binary const &bin, - source_context source_ctx = {}) -> shared_handle { +inline auto make_kernel_bundle(cl_context context, cl_device_id device, binary const &bin) + -> shared_handle { cl_program obj; - CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_binary(&obj, context, device, bin.get(), - source_ctx.get())); + CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_binary(&obj, context, device, bin.get())); return shared_handle{obj}; } @@ -165,15 +144,16 @@ inline auto get_group_size(cl_kernel kernel) -> std::array { /** * @brief Convert group size to opencl global range * - * @param howmany Group size + * @param num_groups Number of groups * @param local_size Work-group size * * @return Global size */ -inline auto get_global_size(std::int64_t howmany, std::array const &local_size) +inline auto get_global_size(std::array const &num_groups, + std::array const &local_size) -> std::array { auto global_size = std::array{}; - tinytc_cl_get_global_size(howmany, local_size.data(), global_size.data()); + tinytc_cl_get_global_size(num_groups.data(), local_size.data(), global_size.data()); return global_size; } @@ -242,15 +222,13 @@ class opencl_recipe_handler : public recipe_handler { * @param context Context * @param device Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return OpenCL recipe handler */ -inline auto make_recipe_handler(cl_context context, cl_device_id device, recipe const &rec, - source_context source_ctx = {}) -> opencl_recipe_handler { +inline auto make_recipe_handler(cl_context context, cl_device_id device, recipe const &rec) + -> opencl_recipe_handler { tinytc_recipe_handler_t handler; - CHECK_STATUS( - tinytc_cl_recipe_handler_create(&handler, context, device, rec.get(), source_ctx.get())); + CHECK_STATUS(tinytc_cl_recipe_handler_create(&handler, context, device, rec.get())); return opencl_recipe_handler{handler}; } diff --git a/include/tinytc/tinytc_sycl.hpp b/include/tinytc/tinytc_sycl.hpp index 52bc2b80..35c4a2f8 100644 --- a/include/tinytc/tinytc_sycl.hpp +++ b/include/tinytc/tinytc_sycl.hpp @@ -41,20 +41,6 @@ TINYTC_EXPORT auto make_core_info(sycl::device const &dev) -> core_info; ////////// Kernel ////////// //////////////////////////// -/** - * @brief Make SYCL kernel bundle from tinytc source - * - * @param ctx Context - * @param dev Device - * @param src Source - * @param source_ctx Source context for improved error reporting - * - * @return SYCL kernel bundle - */ -TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, - source const &src, source_context source_ctx = {}) - -> sycl::kernel_bundle; - /** * @brief Make SYCL kernel bundle from tinytc program * @@ -63,13 +49,11 @@ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device con * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return SYCL kernel bundle */ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) + tinytc_core_feature_flags_t core_features = 0) -> sycl::kernel_bundle; /** @@ -78,12 +62,11 @@ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device con * @param ctx Context * @param dev Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return SYCL kernel bundle */ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, - binary const &bin, source_context source_ctx = {}) + binary const &bin) -> sycl::kernel_bundle; /** @@ -109,23 +92,29 @@ TINYTC_EXPORT auto get_group_size(sycl::kernel const &krnl) -> sycl::range<3u>; /** * @brief Convert group size to SYCL range * - * @param howmany Group size + * **Important:** num_groups is in SYCL ZYX order, meaning that the range should contain + * {num_groups_z, num_groups_y, num_groups_x}. + * + * @param num_groups Number of groups * @param local_size Work-group size * * @return Global size */ -TINYTC_EXPORT auto get_global_size(std::int64_t howmany, sycl::range<3u> const &local_size) - -> sycl::range<3u>; +TINYTC_EXPORT auto get_global_size(sycl::range<3u> const &num_groups, + sycl::range<3u> const &local_size) -> sycl::range<3u>; /** * @brief Get SYCL nd_range * + * **Important:** num_groups is in SYCL ZYX order, meaning that the range should contain + * {num_groups_z, num_groups_y, num_groups_x}. + * * @param krnl Kernel - * @param howmany Group size + * @param num_groups Number of groups * * @return ND range */ -TINYTC_EXPORT auto get_execution_range(sycl::kernel const &krnl, std::int64_t howmany) +TINYTC_EXPORT auto get_execution_range(sycl::kernel const &krnl, sycl::range<3u> const &num_groups) -> sycl::nd_range<3u>; //////////////////////////// @@ -179,24 +168,21 @@ class TINYTC_EXPORT sycl_recipe_handler : public recipe_handler { * @param ctx Context * @param dev Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return SYCL recipe handler */ TINYTC_EXPORT auto make_recipe_handler(sycl::context const &ctx, sycl::device const &dev, - recipe const &rec, source_context source_ctx = {}) - -> sycl_recipe_handler; + recipe const &rec) -> sycl_recipe_handler; /** * @brief Make recipe handler * * @param q Queue * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return SYCL recipe handler */ -TINYTC_EXPORT auto make_recipe_handler(sycl::queue const &q, recipe const &rec, - source_context source_ctx = {}) -> sycl_recipe_handler; +TINYTC_EXPORT auto make_recipe_handler(sycl::queue const &q, recipe const &rec) + -> sycl_recipe_handler; } // namespace tinytc diff --git a/include/tinytc/tinytc_ze.h b/include/tinytc/tinytc_ze.h index d5b4d45a..d485cad1 100644 --- a/include/tinytc/tinytc_ze.h +++ b/include/tinytc/tinytc_ze.h @@ -57,38 +57,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_core_info_create(tinytc_core_info_t *inf ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bin [out] pointer to the binary object created - * @param src [in] source text - * @param ip_version [in] IP version (pass tinytc_intel_gpu_architecture_t here) - * @param format [in] binary format (SPIR-V or native) - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_source_compile_to_binary( - tinytc_binary_t *bin, const_tinytc_source_t src, uint32_t ip_version, - tinytc_bundle_format_t format, tinytc_source_context_t source_ctx); - -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bundle [out] pointer to the kernel bundle (ze_module_handle_t) object created - * @param context [in] context handle - * @param device [in] device handle - * @param src [in] source text and extensions - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_source( - ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - const_tinytc_source_t src, tinytc_source_context_t source_ctx); - /** * @brief Compile tensor program * @@ -98,15 +66,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_source( * @param prg [inout] tensor program; modified as compiler passes are run * @param core_features [in][optional] requested core features; must be 0 (default) or a combination * of tinytc_core_feature_flag_t - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_program( ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - tinytc_prog_t prg, tinytc_core_feature_flags_t core_features, - tinytc_source_context_t source_ctx); + tinytc_prog_t prg, tinytc_core_feature_flags_t core_features); /** * @brief Create an OpenCL program from a tinytc binary @@ -115,14 +80,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_program( * @param context [in] context handle * @param device [in] device handle * @param bin [in] binary object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_binary( - ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - const_tinytc_binary_t bin, tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t +tinytc_ze_kernel_bundle_create_with_binary(ze_module_handle_t *bundle, ze_context_handle_t context, + ze_device_handle_t device, const_tinytc_binary_t bin); /** * @brief Create a kernel and set group size @@ -149,15 +112,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_create(ze_kernel_handle_t *krnl, TINYTC_EXPORT tinytc_status_t tinytc_ze_get_group_size(ze_kernel_handle_t kernel, uint32_t *x, uint32_t *y, uint32_t *z); -/** - * @brief Convert group size to level zero group count - * - * @param howmany group size - * - * @return group count - */ -TINYTC_EXPORT ze_group_count_t tinytc_ze_get_group_count(int64_t howmany); - //////////////////////////// ////////// Recipe ////////// //////////////////////////// @@ -169,16 +123,13 @@ TINYTC_EXPORT ze_group_count_t tinytc_ze_get_group_count(int64_t howmany); * @param context [in] context handle * @param device [in] device handle * @param recipe [in] recipe object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_ze_recipe_handler_create(tinytc_recipe_handler_t *handler, ze_context_handle_t context, ze_device_handle_t device, - tinytc_recipe_t recipe, - tinytc_source_context_t source_ctx); + tinytc_recipe_t recipe); /** * @brief Submit recipe to device diff --git a/include/tinytc/tinytc_ze.hpp b/include/tinytc/tinytc_ze.hpp index 57a2eac2..9932e0b2 100644 --- a/include/tinytc/tinytc_ze.hpp +++ b/include/tinytc/tinytc_ze.hpp @@ -58,24 +58,6 @@ inline auto make_core_info(ze_device_handle_t device) -> core_info { ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile source to binary - * - * @param src Source object - * @param ip_version IP version (pass tinytc_intel_gpu_architecture_t here) - * @param format Bundle format (SPIR-V or Native) - * @param ctx Source context for improved error reporting - * - * @return Binary - */ -inline auto compile_to_binary(source const &src, std::uint32_t ip_version, bundle_format format, - source_context ctx = {}) -> binary { - tinytc_binary_t bin; - CHECK_STATUS(tinytc_ze_source_compile_to_binary( - &bin, src.get(), ip_version, static_cast(format), ctx.get())); - return binary{bin}; -} - namespace internal { template <> struct unique_handle_traits { static void destroy(ze_kernel_handle_t obj) { zeKernelDestroy(obj); } @@ -85,25 +67,6 @@ template <> struct unique_handle_traits { }; } // namespace internal -/** - * @brief Make a Level Zero module from a tinytc source - * - * @param context Context - * @param device Device - * @param src Source - * @param source_ctx Source context for improved error reporting - * - * @return Level Zero module (unique handle) - */ -inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, - source const &src, source_context source_ctx = {}) - -> unique_handle { - ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_source(&obj, context, device, src.get(), - source_ctx.get())); - return unique_handle{obj}; -} - /** * @brief Make a Level Zero module from a tinytc program * @@ -112,17 +75,15 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return Level Zero module (unique handle) */ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) + tinytc_core_feature_flags_t core_features = 0) -> unique_handle { ze_module_handle_t obj; CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_program(&obj, context, device, prg.get(), - core_features, source_ctx.get())); + core_features)); return unique_handle{obj}; } @@ -132,16 +93,13 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * @param context Context * @param device Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return Level Zero module (unique handle) */ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, - binary const &bin, source_context source_ctx = {}) - -> unique_handle { + binary const &bin) -> unique_handle { ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_binary(&obj, context, device, bin.get(), - source_ctx.get())); + CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_binary(&obj, context, device, bin.get())); return unique_handle{obj}; } @@ -153,8 +111,8 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * * @return Level Zero kernel (unique handle) */ -inline auto make_kernel(ze_module_handle_t mod, char const *name) - -> unique_handle { +inline auto make_kernel(ze_module_handle_t mod, + char const *name) -> unique_handle { ze_kernel_handle_t obj; CHECK_STATUS(tinytc_ze_kernel_create(&obj, mod, name)); return unique_handle{obj}; @@ -173,17 +131,6 @@ inline auto get_group_size(ze_kernel_handle_t kernel) -> std::array ze_group_count_t { - return tinytc_ze_get_group_count(howmany); -} - //////////////////////////// ////////// Recipe ////////// //////////////////////////// @@ -218,16 +165,13 @@ class level_zero_recipe_handler : public recipe_handler { * @param context Context * @param device Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return Level Zero recipe handler */ inline auto make_recipe_handler(ze_context_handle_t context, ze_device_handle_t device, - recipe const &rec, source_context source_ctx = {}) - -> level_zero_recipe_handler { + recipe const &rec) -> level_zero_recipe_handler { tinytc_recipe_handler_t handler; - CHECK_STATUS( - tinytc_ze_recipe_handler_create(&handler, context, device, rec.get(), source_ctx.get())); + CHECK_STATUS(tinytc_ze_recipe_handler_create(&handler, context, device, rec.get())); return level_zero_recipe_handler{handler}; } diff --git a/include/tinytc/types.h b/include/tinytc/types.h index a404daa1..9249e302 100644 --- a/include/tinytc/types.h +++ b/include/tinytc/types.h @@ -39,6 +39,9 @@ typedef enum { tinytc_status_unsupported_backend = 0xc, ///< Unsupported backend (SYCL runtime) tinytc_status_invalid_kernel_arguments = 0xd, ///< Kernel got invalid arguments tinytc_status_unsupported_device = 0xe, ///< Unsupported device + tinytc_status_invalid_core_info = 0xf, ///< Invalid core info object + tinytc_status_unknown_pass_name = 0x10, ///< Invalid compiler pass name + tinytc_status_not_implemented = 0x11, ///< Not implemented // IR errors tinytc_status_ir_out_of_bounds = 0x100, ///< Out of bounds access tinytc_status_ir_invalid_shape = 0x101, ///< Invalid tensor shape @@ -46,19 +49,78 @@ typedef enum { tinytc_status_ir_shape_stride_mismatch = 0x103, ///< Mismatch of shape and stride tinytc_status_ir_scalar_mismatch = 0x104, ///< Mismatch of scalar types tinytc_status_ir_invalid_number_of_indices = 0x105, /// Invalid number of indices - tinytc_status_ir_expected_scalar = 0x106, ///< Expected a value of scalar type - tinytc_status_ir_expected_memref = 0x107, ///< Expected a value of memref type - tinytc_status_ir_expected_memref_or_scalar = 0x108, ///< Expected memref or scalar type - tinytc_status_ir_expected_memref_or_group = 0x109, ///< Expected a value of memref or group type - tinytc_status_ir_expected_vector_or_matrix = 0x10a, ///< Expected a vector or marix - tinytc_status_ir_unexpected_yield = 0x10b, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x10c, ///< Wrong number of yielded values - tinytc_status_ir_multiple_dynamic_modes = 0x10d, ///< At most one mode must be dynamic - tinytc_status_ir_invalid_slice = 0x10e, ///< Invalid slice - tinytc_status_ir_expand_shape_order_too_small = 0x10f, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x110, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x111, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x112, ///< Instruction does not support floating type + tinytc_status_ir_expected_boolean = 0x106, ///< Expected a value of boolean type + tinytc_status_ir_expected_scalar = 0x107, ///< Expected a value of scalar type + tinytc_status_ir_expected_int = 0x108, ///< Expected a value of integer type + tinytc_status_ir_expected_float = 0x109, ///< Expected a value of float type + tinytc_status_ir_expected_complex = 0x10a, ///< Expected a value of complex type + tinytc_status_ir_expected_i32 = 0x10b, ///< Expected a value of i32 type + tinytc_status_ir_expected_index = 0x10c, ///< Expected a value of index type + tinytc_status_ir_expected_coopmatrix = 0x10d, ///< Expected a value of coopmatrix type + tinytc_status_ir_expected_coopmatrix_or_scalar = + 0x10e, ///< Expected a value of coopmatrix or scalar type + tinytc_status_ir_expected_coopmatrix_scalar_or_boolean = + 0x10f, ///< Expected a value of coopmatrix, scalar type, or boolean + tinytc_status_ir_expected_memref = 0x110, ///< Expected a value of memref type + tinytc_status_ir_expected_memref_or_scalar = 0x111, ///< Expected memref or scalar type + tinytc_status_ir_expected_memref_or_group = 0x112, ///< Expected a value of memref or group type + tinytc_status_ir_expected_memref_order_0 = 0x113, ///< Expected memref of order 0 + tinytc_status_ir_expected_memref_order_1 = 0x114, ///< Expected memref of order 1 + tinytc_status_ir_expected_memref_order_2 = 0x115, ///< Expected memref of order 2 + tinytc_status_ir_expected_memref_order_0_or_1 = 0x116, ///< Expected memref of order 0 or 1 + tinytc_status_ir_expected_memref_order_1_or_2 = 0x117, ///< Expected memref of order 1 or 2 + tinytc_status_ir_expected_memref_order_0_1_or_2 = 0x118, ///< Expected memref of order 0, 1 or 2 + tinytc_status_ir_unexpected_yield = 0x119, ///< Unexpected yield instruction + tinytc_status_ir_yield_mismatch = 0x11a, ///< Wrong number of yielded values + tinytc_status_ir_subview_mismatch = 0x11b, ///< Mismatch in subview + tinytc_status_ir_invalid_slice = 0x11c, ///< Invalid slice + tinytc_status_ir_expand_shape_order_too_small = 0x11d, ///< Expand shape too small + tinytc_status_ir_expand_shape_mismatch = 0x11e, ///< Invalid expand shape + tinytc_status_ir_collective_called_from_spmd = 0x11f, ///< Collective instruction from SPMD + tinytc_status_ir_fp_unsupported = 0x120, ///< Instruction does not support floating type + tinytc_status_ir_spmd_called_from_collective = 0x121, ///< SPMD instruction from collective + tinytc_status_ir_expected_local_address_space = 0x122, ///< Expected local address space + tinytc_status_ir_expected_global_address_space = 0x123, ///< Expected global address space + tinytc_status_ir_address_space_mismatch = 0x124, ///< Address space must match + tinytc_status_ir_invalid_offset = 0x125, ///< Invalid offset + tinytc_status_ir_int_unsupported = 0x126, ///< Instruction does not support int type + tinytc_status_ir_boolean_unsupported = 0x127, ///< Instruction does not support boolean type + tinytc_status_ir_complex_unsupported = 0x128, ///< Instruction does not support complex type + tinytc_status_ir_coopmatrix_unsupported = + 0x129, ///< Instruction does not support coopmatrix type + tinytc_status_ir_forbidden_cast = 0x12a, ///< Forbidden cast + tinytc_status_ir_invalid_beta = 0x12b, ///< Invalid beta value + tinytc_status_ir_init_return_mismatch = 0x12c, ///< Mismatch of init values and returned values + tinytc_status_ir_invalid_matrix_use = 0x12d, ///< Invalid matrix use + tinytc_status_ir_unsupported_coopmatrix_shape = 0x12e, ///< Unsupported coopmatrix shape + tinytc_status_ir_forbidden_promotion = 0x12f, ///< Forbidden promotion + tinytc_status_ir_constant_mismatch = 0x130, ///< Constant mismatch + tinytc_status_ir_insufficient_alignment = 0x131, ///< Insufficient alignment + tinytc_status_ir_must_have_yield = 0x132, ///< Must have yield instruction + tinytc_status_ir_yield_in_else_branch_missing = + 0x133, ///< Must have yield instruction in else branch + tinytc_status_ir_from_to_mismatch = 0x134, ///< size(from) != size(to) in foreach + tinytc_status_ir_operand_type_must_match_return_type = + 0x135, /// Operand type must match return type + tinytc_status_ir_invalid_stride = 0x136, ///< Invalid stride + tinytc_status_ir_init_return_type_mismatch = 0x137, ///< Init return type mismatch + tinytc_status_ir_invalid_alignment = 0x138, ///< Invalid alignment + tinytc_status_ir_value_still_has_uses = 0x139, ///< Value still has uses + tinytc_status_ir_expected_array_attribute = 0x140, ///< Expected array attribute + tinytc_status_ir_expected_boolean_attribute = 0x141, ///< Expected boolean attribute + tinytc_status_ir_expected_dictionary_attribute = 0x142, ///< Expected dictionary attribute + tinytc_status_ir_expected_integer_attribute = 0x143, ///< Expected integer attribute + tinytc_status_ir_expected_string_attribute = 0x144, ///< Expected string attribute + tinytc_status_ir_duplicate_key_in_dictionary = 0x145, ///< Duplicate key + tinytc_status_ir_unexpected_array_attribute_size = 0x146, ///< Unexpected array size + tinytc_status_ir_expected_non_scalar_memref = 0x147, ///< Expected memref of dimension >= 1 + // SPIR-V errors + tinytc_status_spirv_forbidden_forward_declaration = + 0x1000, ///< Forward declaration of id is forbidden + tinytc_status_spirv_undefined_value = 0x1001, ///< Undefined value + tinytc_status_spirv_missing_dope_vector = 0x1002, ///< Missing dope vector + tinytc_status_spirv_unsupported_atomic_data_type = 0x1003, ///< Unsupported atomic data type + tinytc_status_spirv_required_feature_unavailable = 0x1004, ///< Required feature unavailable // Level zero errors tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST @@ -212,36 +274,63 @@ typedef enum { //! Scalar types typedef enum { - tinytc_scalar_type_i1 = 0, ///< Signed 1 bit integer (boolean) - tinytc_scalar_type_i8 = 1, ///< Signed 8 bit integer - tinytc_scalar_type_i16 = 2, ///< Signed 16 bit integer - tinytc_scalar_type_i32 = 3, ///< Signed 32 bit integer - tinytc_scalar_type_i64 = 4, ///< Signed 64 bit integer - tinytc_scalar_type_index = 5, ///< Integer type for indices - tinytc_scalar_type_f32 = 6, ///< Single precision floating point (32 bit) - tinytc_scalar_type_f64 = 7 ///< Double precision floating point (64 bit) + tinytc_scalar_type_i8 = 0, ///< Signed 8 bit integer + tinytc_scalar_type_i16 = 1, ///< Signed 16 bit integer + tinytc_scalar_type_i32 = 2, ///< Signed 32 bit integer + tinytc_scalar_type_i64 = 3, ///< Signed 64 bit integer + tinytc_scalar_type_index = 4, ///< Integer type for indices + tinytc_scalar_type_bf16 = 5, ///< Brain floating point format with 16 bits + tinytc_scalar_type_f16 = 6, ///< Half precision floating point (16 bit) + tinytc_scalar_type_f32 = 7, ///< Single precision floating point (32 bit) + tinytc_scalar_type_f64 = 8, ///< Double precision floating point (64 bit) + tinytc_scalar_type_c32 = 9, ///< Single precision complex (2x32 bit) + tinytc_scalar_type_c64 = 10 ///< Double precision complex (2x64 bit) } tinytc_scalar_type_t; +#define TINYTC_NUMBER_OF_SCALAR_TYPES 11 // @todo Keep up to date with tinytc_scalar_type_t //! Arithmetic operations typedef enum { - tinytc_arithmetic_add = 0, ///< add - tinytc_arithmetic_sub = 1, ///< subtract - tinytc_arithmetic_mul = 2, ///< multiply - tinytc_arithmetic_div = 3, ///< divide - tinytc_arithmetic_rem = 4, ///< division remainder - tinytc_arithmetic_shl = 5, ///< left shift - tinytc_arithmetic_shr = 6, ///< arithmetic right shift - tinytc_arithmetic_and = 7, ///< bitwise and - tinytc_arithmetic_or = 8, ///< bitwise or - tinytc_arithmetic_xor = 9 ///< bitwise xor + tinytc_arithmetic_add = 0, ///< add + tinytc_arithmetic_sub = 1, ///< subtract + tinytc_arithmetic_mul = 2, ///< multiply + tinytc_arithmetic_div = 3, ///< divide + tinytc_arithmetic_rem = 4, ///< division remainder + tinytc_arithmetic_shl = 5, ///< left shift + tinytc_arithmetic_shr = 6, ///< arithmetic right shift + tinytc_arithmetic_and = 7, ///< bitwise and + tinytc_arithmetic_or = 8, ///< bitwise or + tinytc_arithmetic_xor = 9, ///< bitwise xor + tinytc_arithmetic_min = 10, ///< minimum + tinytc_arithmetic_max = 11 ///< maximum } tinytc_arithmetic_t; //! Arithmetic operations (unary) typedef enum { - tinytc_arithmetic_unary_neg = 0, ///< negation - tinytc_arithmetic_unary_not = 1 ///< bitwise not + tinytc_arithmetic_unary_neg = 0, ///< negation + tinytc_arithmetic_unary_not = 1, ///< bitwise not + tinytc_arithmetic_unary_abs = 2, ///< absolute value + tinytc_arithmetic_unary_conj = 3, ///< complex conjugate + tinytc_arithmetic_unary_im = 4, ///< imaginary part + tinytc_arithmetic_unary_re = 5 ///< real part } tinytc_arithmetic_unary_t; +//! Builtin values +typedef enum { + tinytc_builtin_group_id_x = 0, ///< group_id.x + tinytc_builtin_group_id_y = 1, ///< group_id.y + tinytc_builtin_group_id_z = 2, ///< group_id.z + tinytc_builtin_num_groups_x = 3, ///< num_groups.x + tinytc_builtin_num_groups_y = 4, ///< num_groups.y + tinytc_builtin_num_groups_z = 5, ///< num_groups.z + tinytc_builtin_num_subgroups_x = 6, ///< num_subgroups.x + tinytc_builtin_num_subgroups_y = 7, ///< num_subgroups.y + tinytc_builtin_subgroup_size = 8, ///< subgroup_size + tinytc_builtin_subgroup_id_x = 9, ///< subgroup_id.x + tinytc_builtin_subgroup_id_y = 10, ///< subgroup_id.y + tinytc_builtin_subgroup_linear_id = 11, ///< subgroup_linear_id + tinytc_builtin_subgroup_local_id = 12 ///< subgroup_local_id +} tinytc_builtin_t; + //! Compare operation typedef enum { tinytc_cmp_condition_eq = 0, ///< equals @@ -252,12 +341,111 @@ typedef enum { tinytc_cmp_condition_le = 5 ///< less or equal than } tinytc_cmp_condition_t; +//! Math operations (unary) +typedef enum { + tinytc_math_unary_cos = 0, ///< Cosine + tinytc_math_unary_sin = 1, ///< Sine + tinytc_math_unary_exp = 2, ///< Base-e exponential + tinytc_math_unary_exp2 = 3, ///< Base-2 exponential + tinytc_math_unary_native_cos = 4, ///< Native cosine + tinytc_math_unary_native_sin = 5, ///< Native sine + tinytc_math_unary_native_exp = 6, ///< Native base-e exponential + tinytc_math_unary_native_exp2 = 7, ///< Native base-2 exponential +} tinytc_math_unary_t; + +//! Group arithmetic +typedef enum { + tinytc_group_arithmetic_add = 0, ///< Group add + tinytc_group_arithmetic_max = 1, ///< Group max + tinytc_group_arithmetic_min = 2 ///< Group min +} tinytc_group_arithmetic_t; + +//! Group operation +typedef enum { + tinytc_group_operation_exclusive_scan = 0, ///< Exclusive scan + tinytc_group_operation_inclusive_scan = 1, ///< Inclusive scan + tinytc_group_operation_reduce = 2 ///< Reduction +} tinytc_group_operation_t; + +//! Reduce mode +typedef enum { + tinytc_reduce_mode_row = 0, ///< Reduction over rows + tinytc_reduce_mode_column = 1, ///< Reducation over columns +} tinytc_reduce_mode_t; + //! Transpose typedef enum { tinytc_transpose_N = 0, ///< No transpose tinytc_transpose_T = 1 ///< Transpose } tinytc_transpose_t; +//! Address space +typedef enum { + tinytc_address_space_global = 0x1, ///< Global memory + tinytc_address_space_local = 0x2 ///< Local memory, returned by alloca +} tinytc_address_space_t; + +//! Type for combination of address spaces +typedef uint32_t tinytc_address_spaces_t; + +/** + * @brief Checked flag + * + * Checks can be combined by bitwise or, that is, + * + * tinytc_checked_flag_both = tinytc_checked_flag_rows | tinytc_checked_flag_cols + * tinytc_checked_flag_rows = tinytc_checked_flag_rows | tinytc_checked_flag_none + */ +typedef enum { + tinytc_checked_flag_none = 0x0, ///< Perform no checks + tinytc_checked_flag_rows = 0x1, ///< Check for out-of-bound rows + tinytc_checked_flag_cols = 0x2, ///< Check for out-of-bound cols + tinytc_checked_flag_both = 0x3 ///< Check for out-of-bound rows and cols +} tinytc_checked_flag_t; + +//! Store flag +typedef enum { + tinytc_store_flag_regular = 0, ///< Non-atomic store + tinytc_store_flag_atomic = 1, ///< Atomic store + tinytc_store_flag_atomic_add = 2, ///< Atomic fetch add + tinytc_store_flag_atomic_max = 3, ///< Atomic fetch max + tinytc_store_flag_atomic_min = 4 ///< Atomic fetch min +} tinytc_store_flag_t; + +//! Matrix use +typedef enum { + tinytc_matrix_use_a, ///< matrix_a + tinytc_matrix_use_b, ///< matrix_b + tinytc_matrix_use_acc ///< matrix_acc +} tinytc_matrix_use_t; + +//! SPIR-V features +typedef enum { + tinytc_spirv_feature_float16 = 0, ///< f16 support + tinytc_spirv_feature_float64 = 1, ///< f64 support + tinytc_spirv_feature_int64_atomics = 2, ///< i64 atomics support + tinytc_spirv_feature_groups = 3, ///< work group collectives + tinytc_spirv_feature_subgroup_dispatch = 4, ///< subgroup support + tinytc_spirv_feature_atomic_float16_add_local = 5, ///< f16 atomic add on local pointer + tinytc_spirv_feature_atomic_float16_add_global = 6, ///< f16 atomic add on global pointer + tinytc_spirv_feature_atomic_float32_add_local = 7, ///< f32 atomic add on local pointer + tinytc_spirv_feature_atomic_float32_add_global = 8, ///< f32 atomic add on global pointer + tinytc_spirv_feature_atomic_float64_add_local = 9, ///< f64 atomic add on local pointer + tinytc_spirv_feature_atomic_float64_add_global = 10, ///< f64 atomic add on global pointer + tinytc_spirv_feature_atomic_float16_min_max_local = 11, ///< f16 atomic min/max on local pointer + tinytc_spirv_feature_atomic_float16_min_max_global = + 12, ///< f16 atomic min/max on global pointer + tinytc_spirv_feature_atomic_float32_min_max_local = 13, ///< f32 atomic min/max on local pointer + tinytc_spirv_feature_atomic_float32_min_max_global = + 14, ///< f32 atomic min/max on global pointer + tinytc_spirv_feature_atomic_float64_min_max_local = 15, ///< f64 atomic min/max on local pointer + tinytc_spirv_feature_atomic_float64_min_max_global = + 16, ///< f64 atomic minmax on global pointer + tinytc_spirv_feature_bfloat16_conversion = 17, ///< bf16 -> f32 and f32 -> bf16 conversion + tinytc_spirv_feature_subgroup_buffer_block_io = 18 ///< subgroup block read/write support +} tinytc_spirv_feature_t; +#define TINYTC_NUMBER_OF_SPIRV_FEATURES 19 // @todo Keep up to date with tinytc_spirv_feature_t + //! Core features that may be optionally enabled typedef enum { /** @@ -273,6 +461,8 @@ typedef enum { //! Type for combination of core feature flags typedef uint32_t tinytc_core_feature_flags_t; +//! + /** * @brief IP versions for Intel GPUs * @@ -282,8 +472,10 @@ typedef uint32_t tinytc_core_feature_flags_t; */ typedef enum { tinytc_intel_gpu_architecture_tgl = 0x03000000, ///< Tiger Lake - tinytc_intel_gpu_architecture_pvc = 0x030f0007 ///< Ponte Vecchio + tinytc_intel_gpu_architecture_pvc = 0x030f0000, ///< Ponte Vecchio + tinytc_intel_gpu_architecture_bmg = 0x05004000 ///< BMG } tinytc_intel_gpu_architecture_t; +#define TINYTC_INTEL_GPU_ARCHITECTURE_SUB_VERSION_BITS 0xfff //! Target binary format typedef enum { @@ -291,6 +483,12 @@ typedef enum { tinytc_bundle_format_native = 1 ///< Native device binary } tinytc_bundle_format_t; +//! Flags for optimizer +typedef enum { + tinytc_optflag_unsafe_fp_math = 0 ///< Unsafe floating point math (e.g. 0.0 * x = 0.0) +} tinytc_optflag_t; +#define TINYTC_NUMBER_OF_OPTFLAGS 1 // @todo Keep up to date with tinytc_optflag_t + //! Memory object type typedef enum { tinytc_mem_type_buffer = 0x0, ///< Buffer object (e.g. cl_mem) @@ -315,9 +513,17 @@ typedef enum { //! @brief Bool type {0,1} typedef uint8_t tinytc_bool_t; +//! @struct tinytc_attr +//! @brief Opaque struct for an attribute +struct tinytc_attr; // IWYU pragma: export +//! @brief data_type handle +typedef struct tinytc_attr *tinytc_attr_t; +//! @brief const data_type handle +typedef const struct tinytc_attr *const_tinytc_attr_t; + //! @struct tinytc_data_type //! @brief Opaque struct for a data type -struct tinytc_data_type; +struct tinytc_data_type; // IWYU pragma: export //! @brief data_type handle typedef struct tinytc_data_type *tinytc_data_type_t; //! @brief const data_type handle @@ -325,7 +531,7 @@ typedef const struct tinytc_data_type *const_tinytc_data_type_t; //! @struct tinytc_value //! @brief Opaque struct for a value -struct tinytc_value; +struct tinytc_value; // IWYU pragma: export //! @brief value handle typedef struct tinytc_value *tinytc_value_t; //! @brief const value handle @@ -333,15 +539,18 @@ typedef const struct tinytc_value *const_tinytc_value_t; //! @struct tinytc_inst //! @brief Opaque struct for an instruction -struct tinytc_inst; +struct tinytc_inst; // IWYU pragma: export //! @brief inst handle typedef struct tinytc_inst *tinytc_inst_t; //! @brief const inst handle typedef const struct tinytc_inst *const_tinytc_inst_t; +//! @brief inst iterator handle +typedef struct tinytc_inst *tinytc_inst_iterator_t; + //! @struct tinytc_region //! @brief Opaque struct for a region -struct tinytc_region; +struct tinytc_region; // IWYU pragma: export //! @brief region handle typedef struct tinytc_region *tinytc_region_t; //! @brief const region handle @@ -349,7 +558,7 @@ typedef const struct tinytc_region *const_tinytc_region_t; //! @struct tinytc_func //! @brief Opaque struct for a function -struct tinytc_func; +struct tinytc_func; // IWYU pragma: export //! @brief func handle typedef struct tinytc_func *tinytc_func_t; //! @brief const func handle @@ -357,39 +566,39 @@ typedef const struct tinytc_func *const_tinytc_func_t; //! @struct tinytc_prog //! @brief Opaque struct for a program -struct tinytc_prog; +struct tinytc_prog; // IWYU pragma: export //! @brief prog handle typedef struct tinytc_prog *tinytc_prog_t; //! @brief const prog handle typedef const struct tinytc_prog *const_tinytc_prog_t; +//! @struct tinytc_spv_mod +//! @brief Opaque struct for a SPIR-V module +struct tinytc_spv_mod; // IWYU pragma: export +//! @brief spv_mod handle +typedef struct tinytc_spv_mod *tinytc_spv_mod_t; +//! @brief const spv_mod handle +typedef const struct tinytc_spv_mod *const_tinytc_spv_mod_t; + //! @struct tinytc_core_info; //! @brief Opaque struct for core information -struct tinytc_core_info; +struct tinytc_core_info; // IWYU pragma: export //! @brief core_info handle typedef struct tinytc_core_info *tinytc_core_info_t; //! @brief const core_info handle typedef const struct tinytc_core_info *const_tinytc_core_info_t; -//! @struct tinytc_source; -//! @brief Opaque struct for source text -struct tinytc_source; -//! @brief source handle -typedef struct tinytc_source *tinytc_source_t; -//! @brief const source handle -typedef const struct tinytc_source *const_tinytc_source_t; - -//! @struct tintyc_source_context -//! @brief Opaque struct for source context -struct tinytc_source_context; -//! @brief source_context handle -typedef struct tinytc_source_context *tinytc_source_context_t; -//! @brief const source_context handle -typedef const struct tinytc_source_context *const_tinytc_source_context_t; +//! @struct tintyc_compiler_context +//! @brief Opaque struct for compiler context +struct tinytc_compiler_context; // IWYU pragma: export +//! @brief compiler_context handle +typedef struct tinytc_compiler_context *tinytc_compiler_context_t; +//! @brief const compiler_context handle +typedef const struct tinytc_compiler_context *const_tinytc_compiler_context_t; //! @struct tinytc_binary; //! @brief Opaque struct for a binary -struct tinytc_binary; +struct tinytc_binary; // IWYU pragma: export //! @brief binary handle typedef struct tinytc_binary *tinytc_binary_t; //! @brief const binary handle @@ -397,7 +606,7 @@ typedef const struct tinytc_binary *const_tinytc_binary_t; //! @struct tinytc_recipe; //! @brief Opaque struct for a recipe -struct tinytc_recipe; +struct tinytc_recipe; // IWYU pragma: export //! @brief recipe handle typedef struct tinytc_recipe *tinytc_recipe_t; //! @brief const recipe handle @@ -405,7 +614,7 @@ typedef const struct tinytc_recipe *const_tinytc_recipe_t; //! @struct tinytc_recipe_handler; //! @brief Opaque struct for a recipe handler -struct tinytc_recipe_handler; +struct tinytc_recipe_handler; // IWYU pragma: export //! @brief recipe_handler handle typedef struct tinytc_recipe_handler *tinytc_recipe_handler_t; //! @brief const recipe_handler handle @@ -415,6 +624,12 @@ typedef const struct tinytc_recipe_handler *const_tinytc_recipe_handler_t; ////////// Structs ///////// //////////////////////////// +//! @brief Named attribute +typedef struct tinytc_named_attr { + tinytc_attr_t name; ///< Name stored as string attribute + tinytc_attr_t attr; ///< Attribute +} tinytc_named_attr_t; + //! @brief Source code position typedef struct tinytc_position { int32_t source_id; ///< Source file identifier; 0 is "unknown source" @@ -428,6 +643,20 @@ typedef struct tinytc_location { tinytc_position_t end; ///< End position } tinytc_location_t; +//////////////////////////// +///////// Callbacks //////// +//////////////////////////// + +/** + * @brief Signature for error reporting callback + * + * @param what Error description + * @param location Source code location + * @param user_data user data that is passed on to callback + */ +typedef void (*tinytc_error_reporter_t)(char const *what, const tinytc_location_t *location, + void *user_data); + #ifdef __cplusplus } #endif diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp index 06514da5..51df5dbf 100644 --- a/include/tinytc/types.hpp +++ b/include/tinytc/types.hpp @@ -49,6 +49,9 @@ enum class status { unsupported_backend = tinytc_status_unsupported_backend, invalid_kernel_arguments = tinytc_status_invalid_kernel_arguments, unsupported_device = tinytc_status_unsupported_device, + invalid_core_info = tinytc_status_invalid_core_info, + unknown_pass_name = tinytc_status_unknown_pass_name, + not_implemented = tinytc_status_not_implemented, // IR errors ir_out_of_bounds = tinytc_status_ir_out_of_bounds, ir_invalid_shape = tinytc_status_ir_invalid_shape, @@ -56,20 +59,72 @@ enum class status { ir_shape_stride_mismatch = tinytc_status_ir_shape_stride_mismatch, ir_scalar_mismatch = tinytc_status_ir_scalar_mismatch, ir_invalid_number_of_indices = tinytc_status_ir_invalid_number_of_indices, + ir_expected_boolean = tinytc_status_ir_expected_boolean, ir_expected_scalar = tinytc_status_ir_expected_scalar, + ir_expected_int = tinytc_status_ir_expected_int, + ir_expected_float = tinytc_status_ir_expected_float, + ir_expected_complex = tinytc_status_ir_expected_complex, + ir_expected_i32 = tinytc_status_ir_expected_i32, + ir_expected_index = tinytc_status_ir_expected_index, + ir_expected_coopmatrix = tinytc_status_ir_expected_coopmatrix, + ir_expected_coopmatrix_or_scalar = tinytc_status_ir_expected_coopmatrix_or_scalar, + ir_expected_coopmatrix_scalar_or_boolean = + tinytc_status_ir_expected_coopmatrix_scalar_or_boolean, ir_expected_memref = tinytc_status_ir_expected_memref, ir_expected_memref_or_scalar = tinytc_status_ir_expected_memref_or_scalar, ir_expected_memref_or_group = tinytc_status_ir_expected_memref_or_group, - ir_expected_vector_or_matrix = tinytc_status_ir_expected_vector_or_matrix, + ir_expected_memref_order_0 = tinytc_status_ir_expected_memref_order_0, + ir_expected_memref_order_1 = tinytc_status_ir_expected_memref_order_1, + ir_expected_memref_order_2 = tinytc_status_ir_expected_memref_order_2, + ir_expected_memref_order_0_or_1 = tinytc_status_ir_expected_memref_order_0_or_1, + ir_expected_memref_order_1_or_2 = tinytc_status_ir_expected_memref_order_1_or_2, + ir_expected_memref_order_0_1_or_2 = tinytc_status_ir_expected_memref_order_0_1_or_2, ir_unexpected_yield = tinytc_status_ir_unexpected_yield, ir_yield_mismatch = tinytc_status_ir_yield_mismatch, - ir_multiple_dynamic_modes = tinytc_status_ir_multiple_dynamic_modes, + ir_subview_mismatch = tinytc_status_ir_subview_mismatch, ir_invalid_slice = tinytc_status_ir_invalid_slice, ir_expand_shape_order_too_small = tinytc_status_ir_expand_shape_order_too_small, ir_expand_shape_mismatch = tinytc_status_ir_expand_shape_mismatch, ir_collective_called_from_spmd = tinytc_status_ir_collective_called_from_spmd, ir_fp_unsupported = tinytc_status_ir_fp_unsupported, - // Level Zero errors + ir_spmd_called_from_collective = tinytc_status_ir_spmd_called_from_collective, + ir_expected_local_address_space = tinytc_status_ir_expected_local_address_space, + ir_expected_global_address_space = tinytc_status_ir_expected_global_address_space, + ir_address_space_mismatch = tinytc_status_ir_address_space_mismatch, + ir_invalid_offset = tinytc_status_ir_invalid_offset, + ir_int_unsupported = tinytc_status_ir_int_unsupported, + ir_boolean_unsupported = tinytc_status_ir_boolean_unsupported, + ir_complex_unsupported = tinytc_status_ir_complex_unsupported, + ir_coopmatrix_unsupported = tinytc_status_ir_coopmatrix_unsupported, + ir_forbidden_cast = tinytc_status_ir_forbidden_cast, + ir_invalid_beta = tinytc_status_ir_invalid_beta, + ir_init_return_mismatch = tinytc_status_ir_init_return_mismatch, + ir_invalid_matrix_use = tinytc_status_ir_invalid_matrix_use, + ir_unsupported_coopmatrix_shape = tinytc_status_ir_unsupported_coopmatrix_shape, + ir_forbidden_promotion = tinytc_status_ir_forbidden_promotion, + ir_constant_mismatch = tinytc_status_ir_constant_mismatch, + ir_insufficient_alignment = tinytc_status_ir_insufficient_alignment, + ir_must_have_yield = tinytc_status_ir_must_have_yield, + ir_yield_in_else_branch_missing = tinytc_status_ir_yield_in_else_branch_missing, + ir_from_to_mismatch = tinytc_status_ir_from_to_mismatch, + ir_operand_type_must_match_return_type = tinytc_status_ir_operand_type_must_match_return_type, + ir_invalid_stride = tinytc_status_ir_invalid_stride, + ir_init_return_type_mismatch = tinytc_status_ir_init_return_type_mismatch, + ir_invalid_alignment = tinytc_status_ir_invalid_alignment, + ir_value_still_has_uses = tinytc_status_ir_value_still_has_uses, + ir_expected_array_attribute = tinytc_status_ir_expected_array_attribute, + ir_expected_boolean_attribute = tinytc_status_ir_expected_boolean_attribute, + ir_expected_dictionary_attribute = tinytc_status_ir_expected_dictionary_attribute, + ir_expected_integer_attribute = tinytc_status_ir_expected_integer_attribute, + ir_expected_string_attribute = tinytc_status_ir_expected_string_attribute, + ir_duplicate_key_in_dictionary = tinytc_status_ir_duplicate_key_in_dictionary, + ir_unexpected_array_attribute_size = tinytc_status_ir_unexpected_array_attribute_size, + ir_expected_non_scalar_memref = tinytc_status_ir_expected_non_scalar_memref, + spirv_forbidden_forward_declaration = tinytc_status_spirv_forbidden_forward_declaration, + spirv_undefined_value = tinytc_status_spirv_undefined_value, + spirv_missing_dope_vector = tinytc_status_spirv_missing_dope_vector, + spirv_unsupported_atomic_data_type = tinytc_status_spirv_unsupported_atomic_data_type, + spirv_required_feature_unavailable = tinytc_status_spirv_required_feature_unavailable, ze_result_not_ready = tinytc_status_ze_result_not_ready, ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, @@ -197,14 +252,17 @@ enum class status { //! Scalar types enum class scalar_type { - i1 = tinytc_scalar_type_i1, ///< Signed 1 bit integer (boolean) i8 = tinytc_scalar_type_i8, ///< Signed 8 bit integer i16 = tinytc_scalar_type_i16, ///< Signed 16 bit integer i32 = tinytc_scalar_type_i32, ///< Signed 32 bit integer i64 = tinytc_scalar_type_i64, ///< Signed 64 bit integer index = tinytc_scalar_type_index, ///< Unsigned Integer type for indices + f16 = tinytc_scalar_type_f16, ///< Half precision floating point (16 bit) + bf16 = tinytc_scalar_type_bf16, ///< Brain floating point format with 16 bits f32 = tinytc_scalar_type_f32, ///< Single precision floating point (32 bit) - f64 = tinytc_scalar_type_f64 ///< Double precision floating point (64 bit) + f64 = tinytc_scalar_type_f64, ///< Double precision floating point (64 bit) + c32 = tinytc_scalar_type_c32, ///< Single precision complex (2x32 bit) + c64 = tinytc_scalar_type_c64 ///< Double precision complex (2x64 bit) }; //! Arithmetic operations @@ -218,13 +276,36 @@ enum class arithmetic { shr = tinytc_arithmetic_shr, ///< arithmetic right shift and_ = tinytc_arithmetic_and, ///< bitwise and or_ = tinytc_arithmetic_or, ///< bitwise or - xor_ = tinytc_arithmetic_xor ///< bitwise xor + xor_ = tinytc_arithmetic_xor, ///< bitwise xor + min = tinytc_arithmetic_min, ///< minimum + max = tinytc_arithmetic_max ///< maximum }; //! Arithmetic operations (unary) enum class arithmetic_unary { - neg = tinytc_arithmetic_unary_neg, ///< negation - not_ = tinytc_arithmetic_unary_not ///< bitwise not + neg = tinytc_arithmetic_unary_neg, ///< negation + not_ = tinytc_arithmetic_unary_not, ///< bitwise not + abs = tinytc_arithmetic_unary_abs, ///< absolute value + conj = tinytc_arithmetic_unary_conj, ///< complex conjugate + im = tinytc_arithmetic_unary_im, ///< imaginary part + re = tinytc_arithmetic_unary_re ///< real part +}; + +//! Builtin values +enum class builtin { + group_id_x = tinytc_builtin_group_id_x, ///< group_id.x + group_id_y = tinytc_builtin_group_id_y, ///< group_id.y + group_id_z = tinytc_builtin_group_id_z, ///< group_id.z + num_groups_x = tinytc_builtin_num_groups_x, ///< num_groups.x + num_groups_y = tinytc_builtin_num_groups_y, ///< num_groups.y + num_groups_z = tinytc_builtin_num_groups_z, ///< num_groups.z + num_subgroups_x = tinytc_builtin_num_subgroups_x, ///< num_subgroups.x + num_subgroups_y = tinytc_builtin_num_subgroups_y, ///< num_subgroups.y + subgroup_size = tinytc_builtin_subgroup_size, ///< subgroup_size + subgroup_id_x = tinytc_builtin_subgroup_id_x, ///< subgroup_id.x + subgroup_id_y = tinytc_builtin_subgroup_id_y, ///< subgroup_id.y + subgroup_linear_id = tinytc_builtin_subgroup_linear_id, ///< subgroup_linear_id + subgroup_local_id = tinytc_builtin_subgroup_local_id ///< subgroup_local_id }; //! Compare operation @@ -236,12 +317,98 @@ enum class cmp_condition { lt = tinytc_cmp_condition_lt, ///< less than le = tinytc_cmp_condition_le ///< less or equal than }; + +//! Group arithmetic +enum class group_arithmetic { + add = tinytc_group_arithmetic_add, ///< Group add + max = tinytc_group_arithmetic_max, ///< Group max + min = tinytc_group_arithmetic_min ///< Group min +}; + +//! Group operation +enum class group_operation { + exclusive_scan = tinytc_group_operation_exclusive_scan, ///< Exclusive scan + inclusive_scan = tinytc_group_operation_inclusive_scan, ///< Inclusive scan + reduce = tinytc_group_operation_reduce ///< Reduction +}; + +//! Reduce mode +enum class reduce_mode { + row = tinytc_reduce_mode_row, ///< Row reduction + column = tinytc_reduce_mode_column, ///< Column reduction +}; + +//! Math operations (unary) +enum class math_unary { + cos = tinytc_math_unary_cos, ///< Cosine + sin = tinytc_math_unary_sin, ///< Sine + exp = tinytc_math_unary_exp, ///< Base-e exponential + exp2 = tinytc_math_unary_exp2, ///< Base-2 exponential + native_cos = tinytc_math_unary_native_cos, ///< Native cosine + native_sin = tinytc_math_unary_native_sin, ///< Native sine + native_exp = tinytc_math_unary_native_exp, ///< Native base-e exponential + native_exp2 = tinytc_math_unary_native_exp2, ///< Native base-2 exponential +}; + //! Transpose enum class transpose { N = tinytc_transpose_N, ///< no transpose T = tinytc_transpose_T ///< transpose }; +//! Address space +enum class address_space { + global = tinytc_address_space_global, ///< Global memory + local = tinytc_address_space_local ///< Local memory, returned by alloca +}; + +//! Checked flag +enum class checked_flag { + none = tinytc_checked_flag_none, ///< Perform no checks + rows = tinytc_checked_flag_rows, ///< Check for out-of-bound rows + cols = tinytc_checked_flag_cols, ///< Check for out-of-bound cols + both = tinytc_checked_flag_both ///< Check for out-of-bound rows and cols +}; + +//! Store flag +enum class store_flag { + regular = tinytc_store_flag_regular, ///< Non-atomic non-block store + atomic = tinytc_store_flag_atomic, ///< Atomic store + atomic_add = tinytc_store_flag_atomic_add, ///< Atomic fetch add + atomic_max = tinytc_store_flag_atomic_max, ///< Atomic fetch max + atomic_min = tinytc_store_flag_atomic_min ///< Atomic fetch min +}; + +//! Matrix use +enum class matrix_use { + a = tinytc_matrix_use_a, ///< matrix_a + b = tinytc_matrix_use_b, ///< matrix_b + acc = tinytc_matrix_use_acc ///< matrix_acc +}; + +//! @brief Cf. @ref tinytc_spirv_feature_t +enum class spirv_feature { + float16 = tinytc_spirv_feature_float16, + float64 = tinytc_spirv_feature_float64, + int64_atomics = tinytc_spirv_feature_int64_atomics, + groups = tinytc_spirv_feature_groups, + subgroup_dispatch = tinytc_spirv_feature_subgroup_dispatch, + atomic_float16_add_local = tinytc_spirv_feature_atomic_float16_add_local, + atomic_float16_add_global = tinytc_spirv_feature_atomic_float16_add_global, + atomic_float32_add_local = tinytc_spirv_feature_atomic_float32_add_local, + atomic_float32_add_global = tinytc_spirv_feature_atomic_float32_add_global, + atomic_float64_add_local = tinytc_spirv_feature_atomic_float64_add_local, + atomic_float64_add_global = tinytc_spirv_feature_atomic_float64_add_global, + atomic_float16_min_max_local = tinytc_spirv_feature_atomic_float16_min_max_local, + atomic_float16_min_max_global = tinytc_spirv_feature_atomic_float16_min_max_global, + atomic_float32_min_max_local = tinytc_spirv_feature_atomic_float32_min_max_local, + atomic_float32_min_max_global = tinytc_spirv_feature_atomic_float32_min_max_global, + atomic_float64_min_max_local = tinytc_spirv_feature_atomic_float64_min_max_local, + atomic_float64_min_max_global = tinytc_spirv_feature_atomic_float64_min_max_global, + bfloat16_conversion = tinytc_spirv_feature_bfloat16_conversion, + subgroup_buffer_block_io = tinytc_spirv_feature_subgroup_buffer_block_io, +}; + //! @brief Cf. @ref tinytc_core_feature_flag_t enum class core_feature_flag { large_register_file = tinytc_core_feature_flag_large_register_file }; @@ -257,6 +424,9 @@ enum class bundle_format { native = tinytc_bundle_format_native ///< Native device binary }; +//! Flags for optimizer +enum class optflag { unsafe_fp_math = tinytc_optflag_unsafe_fp_math }; + //! Memory object type enum class mem_type { buffer = tinytc_mem_type_buffer, ///< Buffer object (e.g. cl_mem) @@ -278,10 +448,14 @@ enum class support_level { /////// Type aliases /////// //////////////////////////// +//! @brief Alias for tinytc_named_attr in namespace tinytc +using named_attr = ::tinytc_named_attr; //! @brief Alias for tinytc_position in namespace tinytc using position = ::tinytc_position; //! @brief Alias for tinytc_location in namespace tinytc using location = ::tinytc_location; +//! @brief Alias for tinytc_error_reporter_t in namespace tinytc +using error_reporter_t = ::tinytc_error_reporter_t; } // namespace tinytc diff --git a/include/tinytc/version.h.in b/include/tinytc/version.h.in index 5559caae..36cfcc36 100644 --- a/include/tinytc/version.h.in +++ b/include/tinytc/version.h.in @@ -5,12 +5,12 @@ #define VERSION_20240408_H // clang-format off -#define TINYTC_VERSION_MAJOR = @GIT_MAJOR_VERSION@; ///< Major version (X.x.x) -#define TINYTC_VERSION_MINOR = @GIT_MINOR_VERSION@; ///< Minor version (x.X.x) -#define TINYTC_VERSION_PATCH = @GIT_PATCH_VERSION@; ///< Patch version (x.x.X) -#define TINYTC_VERSION_HASH = "@GIT_COMMIT@"; ///< Git commit hash -#define TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE = @GIT_COMMITS_SINCE_RELEASE@; ///< Number of commits since last tag -#define TINYTC_VERSION_DESCRIPTION = "v@GIT_MAJOR_VERSION@.@GIT_MINOR_VERSION@.@GIT_PATCH_VERSION@-@GIT_COMMITS_SINCE_RELEASE@-@GIT_COMMIT@"; ///< Version string (vx.x.x-x-x) +#define TINYTC_VERSION_MAJOR @GIT_MAJOR_VERSION@ ///< Major version (X.x.x) +#define TINYTC_VERSION_MINOR @GIT_MINOR_VERSION@ ///< Minor version (x.X.x) +#define TINYTC_VERSION_PATCH @GIT_PATCH_VERSION@ ///< Patch version (x.x.X) +#define TINYTC_VERSION_HASH "@GIT_COMMIT@" ///< Git commit hash +#define TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE @GIT_COMMITS_SINCE_RELEASE@ ///< Number of commits since last tag +#define TINYTC_VERSION_DESCRIPTION "v@GIT_MAJOR_VERSION@.@GIT_MINOR_VERSION@.@GIT_PATCH_VERSION@-@GIT_COMMITS_SINCE_RELEASE@-@GIT_COMMIT@" ///< Version string (vx.x.x-x-x) // clang-format on #endif // VERSION_20240408_H diff --git a/iwyu.imp b/iwyu.imp new file mode 100644 index 00000000..ed56ff03 --- /dev/null +++ b/iwyu.imp @@ -0,0 +1,4 @@ +[ + { "symbol": ["std::array", "private", "", "public"] }, + { "symbol": ["std::pair", "private", "", "public"] } +] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3d74f3db..9ca401b9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -12,55 +12,106 @@ else () set(type static) endif () -find_package(clir 0.5.1 REQUIRED ${type}) -find_package(re2c REQUIRED) -find_package(BISON 3.8.2 REQUIRED) - set(SOURCES + analysis/aa_results.cpp + analysis/alias.cpp + analysis/cfg.cpp + analysis/gcd.cpp + analysis/equal.cpp + analysis/stack.cpp + attribute.cpp binary.cpp codegen_tools.cpp compiler.cpp + compiler_context.cpp + compiler_context_cache.cpp + coopmatrix_layout.cpp data_type.cpp device_info.cpp error.cpp func.cpp - gemm_generator.cpp + gemm_tools.cpp + half.cpp inst.cpp location.cpp + matrix_ext_info.cpp + node/attr_node.cpp node/data_type_node.cpp + node/function_node.cpp node/inst_node.cpp + node/region_node.cpp + node/program_node.cpp + node/value_node.cpp parser/parse_context.cpp parser.cpp - passes.cpp - precision_helper.cpp + pass/check_ir.cpp + pass/clone.cpp + pass/constant_folding.cpp + pass/constant_propagation.cpp + pass/convert_to_spirv.cpp + pass/dead_code_elimination.cpp + pass/dump_cfg.cpp + pass/dump_def_use.cpp + pass/dump_gcd.cpp + pass/dump_ir.cpp + pass/insert_barrier.cpp + pass/insert_lifetime_stop.cpp + pass/lower_coopmatrix.cpp + pass/lower_foreach.cpp + pass/lower_linalg.cpp + pass/slot_tracker.cpp + pass/stack.cpp + pass/work_group_size.cpp prog.cpp recipe.cpp recipe/small_gemm_batched.cpp recipe/tall_and_skinny.cpp region.cpp - required_extensions.cpp scalar_type.cpp - source.cpp + spv/block2d_diy.cpp + spv/capex_util.cpp + spv/converter.cpp + spv/converter_aux.cpp + spv/coopmatrix_impl.cpp + spv/coopmatrix_impl_block.cpp + spv/coopmatrix_impl_dpas.cpp + spv/dope_vector.cpp + spv/inst_assembler.cpp + spv/matrix_walker.cpp + spv/module.cpp + spv/names.cpp + spv/opencl.std.cpp + spv/pass/assemble.cpp + spv/pass/assign_ids.cpp + spv/pass/dump_asm.cpp + spv/pass/capex.cpp + spv/uniquifier.cpp + support/temp_counter.cpp tiling.cpp value.cpp - visitor/aa_results.cpp - visitor/alias_analysis.cpp - visitor/check_ir.cpp - visitor/dump_ir.cpp - visitor/equal.cpp - visitor/insert_barrier.cpp - visitor/lifetime_analysis.cpp - visitor/metadata.cpp - visitor/opencl_ast.cpp - visitor/slot_tracker.cpp - visitor/stack.cpp - visitor/work_group_size.cpp -) -set(RE2C_SOURCES - parser/lexer.re + support/walk.cpp ) -BISON_TARGET(parser parser/parser_impl.yy ${CMAKE_CURRENT_BINARY_DIR}/parser/parser_impl.cpp - DEFINES_FILE ${CMAKE_CURRENT_BINARY_DIR}/parser/parser_impl.hpp) + +add_library(tinytc-objects OBJECT ${SOURCES}) + +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/parser/lexer.cpp) + message(STATUS "Pre-generated lexer available -- skipping re2c dependency") + target_sources(tinytc-objects PRIVATE parser/lexer.cpp) +else() + find_package(re2c REQUIRED) + add_re2c_to_target(TARGET tinytc-objects SOURCES parser/lexer.re) +endif() + +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/parser/parser_impl.cpp) + message(STATUS "Pre-generated parser available -- skipping bison dependency") + target_sources(tinytc-objects PRIVATE parser/parser_impl.cpp) +else() + find_package(BISON 3.8.2 REQUIRED) + BISON_TARGET(parser parser/parser_impl.yy ${CMAKE_CURRENT_BINARY_DIR}/parser/parser_impl.cpp + DEFINES_FILE ${CMAKE_CURRENT_BINARY_DIR}/parser/parser_impl.hpp) + target_sources(tinytc-objects PRIVATE ${BISON_parser_OUTPUTS}) +endif() + set(PUBLIC_HEADERS tinytc.h tinytc.hpp @@ -71,16 +122,12 @@ list(TRANSFORM PUBLIC_HEADERS PREPEND "${PROJECT_SOURCE_DIR}/include/tinytc/") add_flag_if_available_to_source_files(CXX "${BISON_parser_OUTPUTS}" "-Wno-unused-but-set-variable") -add_library(tinytc-objects OBJECT ${SOURCES} ${BISON_parser_OUTPUTS}) -add_re2c_to_target(TARGET tinytc-objects SOURCES ${RE2C_SOURCES}) set_cxx_common_options(tinytc-objects) -target_link_libraries(tinytc-objects PUBLIC clir::clir) target_compile_definitions(tinytc-objects PUBLIC "$<$>:TINYTC_STATIC_DEFINE>") add_library(tinytc $) add_library(tinytc::tinytc ALIAS tinytc) -target_link_libraries(tinytc PRIVATE clir::clir) set_cxx_common_options(tinytc) # Generate export header @@ -98,6 +145,7 @@ configure_file(${tinytc_version_header_in} ${tinytc_version_header}) target_include_directories(tinytc-objects PRIVATE "$" "$" + "$" ) target_include_directories(tinytc-objects PUBLIC "$" @@ -113,7 +161,6 @@ target_sources(tinytc PUBLIC FILE_SET HEADERS # install - install_lib(tinytc tinytc) # subdirs @@ -126,3 +173,14 @@ endif() if(BUILD_LEVEL_ZERO) add_subdirectory(ze) endif() + +# cpack + +include(GeneratedFiles) +get_generated_files(GENERATED_FILES tinytc-objects) +if(BUILD_OPENCL) + list(APPEND GENERATED_FILES ${GENERATED_FILES_TINYTC_CL}) +endif() + +include(CPackSetup) +cpack_setup() diff --git a/src/visitor/aa_results.cpp b/src/analysis/aa_results.cpp similarity index 81% rename from src/visitor/aa_results.cpp rename to src/analysis/aa_results.cpp index dbfafba3..391be83e 100644 --- a/src/visitor/aa_results.cpp +++ b/src/analysis/aa_results.cpp @@ -1,7 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/aa_results.hpp" +#include "analysis/aa_results.hpp" #include "node/value_node.hpp" #include @@ -12,14 +12,14 @@ aa_results::aa_results(std::unordered_map allocs) : alias_(std::move(alias)), allocs_(std::move(allocs)) {} -auto aa_results::root(value_node const &a) -> value_node const * { +auto aa_results::root(value_node const &a) const -> value_node const * { auto root = &a; - if (alias_.find(root) != alias_.end()) { - root = alias_[root]; + if (auto it = alias_.find(root); it != alias_.end()) { + root = it->second; } return root; } -bool aa_results::alias(value_node const &a, value_node const &b) { +bool aa_results::alias(value_node const &a, value_node const &b) const { auto ra = root(a); auto rb = root(b); if (ra == rb) { diff --git a/src/visitor/aa_results.hpp b/src/analysis/aa_results.hpp similarity index 85% rename from src/visitor/aa_results.hpp rename to src/analysis/aa_results.hpp index e65f2613..3c553e64 100644 --- a/src/visitor/aa_results.hpp +++ b/src/analysis/aa_results.hpp @@ -13,21 +13,19 @@ namespace tinytc { class aa_results { public: - aa_results() = default; - auto root(::tinytc_value const &a) -> ::tinytc_value const *; - bool alias(::tinytc_value const &a, ::tinytc_value const &b); - - private: struct allocation { std::int64_t start, stop; }; aa_results(std::unordered_map<::tinytc_value const *, ::tinytc_value const *> alias, std::unordered_map<::tinytc_value const *, allocation> allocs); + + auto root(::tinytc_value const &a) const -> ::tinytc_value const *; + bool alias(::tinytc_value const &a, ::tinytc_value const &b) const; + + private: std::unordered_map<::tinytc_value const *, ::tinytc_value const *> alias_; std::unordered_map<::tinytc_value const *, allocation> allocs_; - - friend class alias_analyser; }; } // namespace tinytc diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp new file mode 100644 index 00000000..d3cefade --- /dev/null +++ b/src/analysis/alias.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/alias.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "support/walk.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/visit.hpp" + +#include +#include +#include + +namespace tinytc { + +class alias_analysis_visitor { + public: + void operator()(inst_node const &); + void operator()(alloca_inst const &a); + void operator()(expand_inst const &e); + void operator()(fuse_inst const &f); + void operator()(subview_inst const &s); + + auto get_result() && -> aa_results { return aa_results(std::move(alias_), std::move(allocs_)); } + + private: + std::unordered_map allocs_; + std::unordered_map alias_; +}; + +void alias_analysis_visitor::operator()(inst_node const &) {} +void alias_analysis_visitor::operator()(alloca_inst const &a) { + if (a.stack_ptr() >= 0) { + auto t = dyn_cast(a.result()->ty()); + if (t == nullptr) { + throw compilation_error(a.loc(), status::ir_expected_memref); + } + allocs_[a.result()] = + aa_results::allocation{a.stack_ptr(), a.stack_ptr() + t->size_in_bytes()}; + } +} +void alias_analysis_visitor::operator()(expand_inst const &e) { + value_node const *source = &e.operand(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[e.result()] = source; +} +void alias_analysis_visitor::operator()(fuse_inst const &f) { + value_node const *source = &f.operand(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[f.result()] = source; +} + +void alias_analysis_visitor::operator()(subview_inst const &s) { + value_node const *source = &s.operand(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[s.result()] = source; +} + +auto alias_analysis::run_on_function(function_node &fn) -> aa_results { + auto visitor = alias_analysis_visitor{}; + + walk(fn, [&visitor](inst_node &i) { visit(visitor, i); }); + + return std::move(visitor).get_result(); +} + +} // namespace tinytc diff --git a/src/analysis/alias.hpp b/src/analysis/alias.hpp new file mode 100644 index 00000000..65b029d4 --- /dev/null +++ b/src/analysis/alias.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ALIAS_20240912_HPP +#define ALIAS_20240912_HPP + +#include "analysis/aa_results.hpp" +#include "node/function_node.hpp" + +namespace tinytc { + +class alias_analysis { + public: + auto run_on_function(function_node &fn) -> aa_results; +}; + +} // namespace tinytc + +#endif // ALIAS_20240912_HPP diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp new file mode 100644 index 00000000..2c9f2d42 --- /dev/null +++ b/src/analysis/cfg.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/cfg.hpp" +#include "node/inst_node.hpp" +#include "util/casting.hpp" +#include "util/ilist_base.hpp" + +#include +#include + +namespace tinytc { + +void control_flow_graph::insert_before(inst_node *before_inst, inst_node *new_inst) { + add_node(new_inst, adj_[before_inst].kind_max); + adj_[new_inst].pred = std::move(adj_[before_inst].pred); + add_edge(new_inst, before_inst); +} + +auto control_flow_graph::node_queue() const -> std::queue { + auto q = std::queue{}; + for (auto &[key, neighbors] : adj_) { + q.push(key); + } + return q; +} + +auto get_control_flow_graph(region_node &topreg) -> control_flow_graph { + auto cfg = control_flow_graph{}; + + const auto add_region = + [&cfg](region_node ®, region_kind kind_max, + auto &add_region_ref) -> std::pair> { + if (reg.empty()) { + return {}; + } + + auto pred_nodes = std::queue{}; + const auto visit_inst = [&](inst_node *node) { + bool empty_child_regions = true; + if (node->num_child_regions() > 0) { + for (auto &subreg : node->child_regions()) { + auto [substart, subexits] = + add_region_ref(subreg, std::max(kind_max, subreg.kind()), add_region_ref); + if (substart != nullptr && !subexits.empty()) { + empty_child_regions = false; + cfg.add_edge(node, substart); + if (isa(*node)) { + for (; !subexits.empty(); subexits.pop()) { + cfg.add_edge(subexits.front(), node); + } + pred_nodes.push(node); + } else { + for (; !subexits.empty(); subexits.pop()) { + pred_nodes.push(subexits.front()); + } + } + } + } + } + if (empty_child_regions) { + pred_nodes.push(node); + } + }; + + auto start = reg.begin().get(); + cfg.add_node(start, kind_max); + visit_inst(start); + + for (auto it = ++reg.begin(); it != reg.end(); ++it) { + inst_node *node = it.get(); + cfg.add_node(node, kind_max); + + for (; !pred_nodes.empty(); pred_nodes.pop()) { + cfg.add_edge(pred_nodes.front(), node); + } + + visit_inst(node); + } + + return std::make_pair(std::move(start), std::move(pred_nodes)); + }; + + add_region(topreg, topreg.kind(), add_region); + + return cfg; +} + +} // namespace tinytc diff --git a/src/analysis/cfg.hpp b/src/analysis/cfg.hpp new file mode 100644 index 00000000..db8ee336 --- /dev/null +++ b/src/analysis/cfg.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CFG_20240919_HPP +#define CFG_20240919_HPP + +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "util/iterator.hpp" + +#include +#include +#include + +namespace tinytc { + +class control_flow_graph { + public: + inline void add_node(inst_node *a, region_kind kind_max) { + adj_[a] = adjacency_list{}; + adj_[a].kind_max = kind_max; + } + inline void add_edge(inst_node *a, inst_node *b) { + adj_[a].succ.push_back(b); + adj_[b].pred.push_back(a); + } + void insert_before(inst_node *before_inst, inst_node *new_inst); + + auto node_queue() const -> std::queue; + + inline auto kind_max(inst_node *a) -> region_kind { return adj_[a].kind_max; } + + inline auto pred_begin(inst_node *a) { return adj_[a].pred.begin(); } + inline auto pred_end(inst_node *a) { return adj_[a].pred.end(); } + inline auto predecessors(inst_node *a) + -> iterator_range_wrapper::iterator> { + return {pred_begin(a), pred_end(a)}; + } + + inline auto succ_begin(inst_node *a) { return adj_[a].succ.begin(); } + inline auto succ_end(inst_node *a) { return adj_[a].succ.end(); } + inline auto successors(inst_node *a) + -> iterator_range_wrapper::iterator> { + return {succ_begin(a), succ_end(a)}; + } + + private: + struct adjacency_list { + region_kind kind_max = region_kind::mixed; + std::vector pred; + std::vector succ; + }; + std::unordered_map adj_; +}; + +auto get_control_flow_graph(region_node ®) -> control_flow_graph; + +} // namespace tinytc + +#endif // CFG_20240919_HPP diff --git a/src/visitor/equal.cpp b/src/analysis/equal.cpp similarity index 66% rename from src/visitor/equal.cpp rename to src/analysis/equal.cpp index d84f99c2..906c6655 100644 --- a/src/visitor/equal.cpp +++ b/src/analysis/equal.cpp @@ -1,27 +1,26 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/equal.hpp" -#include "tinytc/tinytc.hpp" - -#include +#include "analysis/equal.hpp" +#include "util/visit.hpp" #include -using clir::visit; - namespace tinytc { bool equal::operator()(data_type_node const &, data_type_node const &) { return false; } bool equal::operator()(void_data_type const &, void_data_type const &) { return true; } bool equal::operator()(group_data_type const &a, group_data_type const &b) { - return visit(*this, *a.ty(), *b.ty()); + return visit(*this, *a.ty(), *b.ty()) && a.size() == b.size() && a.offset() == b.offset(); } bool equal::operator()(memref_data_type const &a, memref_data_type const &b) { - return a.element_ty() == b.element_ty() && a.shape() == b.shape() && a.stride() == b.stride(); + return a.element_ty() == b.element_ty() && a.shape() == b.shape() && a.stride() == b.stride() && + a.addrspace() == b.addrspace(); } bool equal::operator()(scalar_data_type const &a, scalar_data_type const &b) { return a.ty() == b.ty(); } +bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b) { return visit(equal{}, a, b); } + } // namespace tinytc diff --git a/src/visitor/equal.hpp b/src/analysis/equal.hpp similarity index 87% rename from src/visitor/equal.hpp rename to src/analysis/equal.hpp index 657716b2..17f75c33 100644 --- a/src/visitor/equal.hpp +++ b/src/analysis/equal.hpp @@ -5,6 +5,7 @@ #define EQUAL_20240208_HPP #include "node/data_type_node.hpp" +#include "tinytc/types.h" namespace tinytc { @@ -18,6 +19,8 @@ class equal { bool operator()(scalar_data_type const &a, scalar_data_type const &b); }; +bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b); + } // namespace tinytc #endif // EQUAL_20240208_HPP diff --git a/src/analysis/gcd.cpp b/src/analysis/gcd.cpp new file mode 100644 index 00000000..8864979c --- /dev/null +++ b/src/analysis/gcd.cpp @@ -0,0 +1,374 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/gcd.hpp" +#include "codegen_tools.hpp" +#include "error.hpp" +#include "node/attr_node.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "support/walk.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include // IWYU pragma: keep +#include +#include +#include +#include +#include + +namespace tinytc { + +memref_info::memref_info(std::int64_t offset_gcd, std::vector shape_gcd, + std::vector stride_gcd) + : offset_gcd_(offset_gcd), shape_gcd_(std::move(shape_gcd)), + stride_gcd_(std::move(stride_gcd)) {} + +auto gcd_analysis_result::get(::const_tinytc_value_t a) const -> std::int64_t { + const auto g = get_if(a); + return g ? *g : 1; +} +auto gcd_analysis_result::get(::tinytc_value const &a) const -> std::int64_t { return get(&a); } +auto gcd_analysis_result::get_if(::const_tinytc_value_t a) const -> std::optional { + if (auto it = gcd_.find(a); it != gcd_.end()) { + return it->second; + } + return std::nullopt; +} +auto gcd_analysis_result::get_if(::tinytc_value const &a) const -> std::optional { + return get_if(&a); +} +void gcd_analysis_result::set(::tinytc_value const &a, std::int64_t g) { gcd_[&a] = g; } + +auto gcd_analysis_result::get_memref_if(::const_tinytc_value_t a) const -> memref_info const * { + if (auto it = memref_info_.find(a); it != memref_info_.end()) { + return &it->second; + } + return nullptr; +} +auto gcd_analysis_result::get_memref_if(::tinytc_value const &a) const -> memref_info const * { + return get_memref_if(&a); +} +void gcd_analysis_result::set_memref(::tinytc_value const &a, memref_info g) { + memref_info_[&a] = std::move(g); +} + +class gcd_helper { + public: + inline gcd_helper(std::int32_t default_alignment) : default_alignment_{default_alignment} {} + + void operator()(inst_node const &in); + void operator()(alloca_inst const &in); + void operator()(arith_inst const &in); + void operator()(arith_unary_inst const &in); + void operator()(cast_inst const &in); + void operator()(constant_inst const &in); + void operator()(expand_inst const &in); + void operator()(for_inst const &in); + void operator()(fuse_inst const &in); + void operator()(load_inst const &in); + void operator()(size_inst const &in); + void operator()(subgroup_broadcast_inst const &in); + void operator()(subview_inst const &in); + + void set_from_attributes(function_node const &fn); + + auto get_result() && { return std::move(gcd_); } + + private: + std::int32_t default_alignment_; + gcd_analysis_result gcd_; +}; + +void gcd_helper::operator()(inst_node const &) {} +void gcd_helper::operator()(alloca_inst const &in) { + if (in.stack_ptr() >= 0) { + const auto rt = get_memref_type(in.result(0).ty()); + std::int32_t i = rt->element_alignment(); + while (i < default_alignment_) { + const auto i2 = 2 * i; + if (in.stack_ptr() % i2 != 0) { + break; + } + i = i2; + } + // alloca shape/stride must be static, therefore we can set shape_gcd/stride_gcd to + // shape/stride + gcd_.set_memref(in.result(0), + memref_info(i / size(rt->element_ty()), rt->shape(), rt->stride())); + } +} +void gcd_helper::operator()(arith_inst const &in) { + auto compute_gcd = [&]() -> std::optional { + const auto ga = gcd_.get(in.a()); + const auto gb = gcd_.get(in.b()); + switch (in.operation()) { + case arithmetic::add: + return std::gcd(ga, gb); + case arithmetic::mul: + return ga * gb; + case arithmetic::div: { + return ga % gb == 0 ? ga / gb : 1; + } + default: + break; + } + return std::nullopt; + }; + auto g = compute_gcd(); + if (g) { + gcd_.set(in.result(0), *g); + } +} +void gcd_helper::operator()(arith_unary_inst const &in) { + auto compute_gcd = [&]() -> std::optional { + switch (in.operation()) { + case arithmetic_unary::abs: + case arithmetic_unary::not_: + return gcd_.get(in.a()); + default: + break; + } + return std::nullopt; + }; + auto g = compute_gcd(); + if (g) { + gcd_.set(in.result(0), *g); + } +} +void gcd_helper::operator()(cast_inst const &in) { + auto g = gcd_.get_if(in.a()); + if (g) { + gcd_.set(in.result(0), *g); + } +} +void gcd_helper::operator()(constant_inst const &in) { + if (std::holds_alternative(in.value())) { + gcd_.set(in.result(0), std::abs(std::get(in.value()))); + } +} +void gcd_helper::operator()(expand_inst const &in) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + const auto mt = get_memref_type(in.operand()); + const auto offset_gcd = mi->offset_gcd(); + auto shape_gcd = std::vector{}; + auto stride_gcd = std::vector{}; + + auto static_shape = in.static_expand_shape(); + auto dyn_shape = in.expand_shape(); + + shape_gcd.reserve(mt->dim() + static_shape.size() - 1); + stride_gcd.reserve(mt->dim() + static_shape.size() - 1); + + for (std::int64_t i = 0; i < in.expanded_mode(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } + + auto get_shape = [&, j = std::size_t{0}](std::int64_t s) mutable { + if (is_dynamic_value(s)) { + return gcd_.get(dyn_shape[j++]); + } + return s; + }; + stride_gcd.push_back(mi->stride_gcd(in.expanded_mode())); + shape_gcd.push_back(get_shape(static_shape[0])); + for (std::size_t j = 1; j < static_shape.size(); ++j) { + stride_gcd.push_back(stride_gcd.back() * shape_gcd.back()); + shape_gcd.push_back(get_shape(static_shape[j])); + } + + for (std::int64_t i = in.expanded_mode() + 1; i < mt->dim(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } + + gcd_.set_memref(in.result(0), memref_info(offset_gcd, shape_gcd, stride_gcd)); + } +} +void gcd_helper::operator()(for_inst const &in) { + if (in.has_step()) { + auto g = std::gcd(gcd_.get(in.from()), gcd_.get(in.step())); + gcd_.set(in.loop_var(), g); + } +} +void gcd_helper::operator()(fuse_inst const &in) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + const auto mt = get_memref_type(in.operand()); + const auto offset_gcd = mi->offset_gcd(); + auto shape_gcd = std::vector{}; + auto stride_gcd = std::vector{}; + + shape_gcd.reserve(mt->dim()); + stride_gcd.reserve(mt->dim()); + + std::int64_t i = 0; + for (; i < in.from(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } + std::int64_t prod = mi->shape_gcd(i++); + for (; i <= in.to(); ++i) { + prod *= mi->shape_gcd(i); + } + shape_gcd.push_back(prod); + stride_gcd.push_back(mi->stride_gcd(in.from())); + for (i = in.to() + 1; i < mt->dim(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } + + gcd_.set_memref(in.result(0), memref_info(offset_gcd, shape_gcd, stride_gcd)); + } +} +void gcd_helper::operator()(load_inst const &in) { + if (auto mi = gcd_.get_memref_if(in.operand()); + mi && isa(*in.operand().ty())) { + gcd_.set_memref(in.result(0), *mi); + } +} +void gcd_helper::operator()(size_inst const &in) { + const auto size = + visit(overloaded{[&](group_data_type const &g) -> std::int64_t { + return !is_dynamic_value(g.size()) ? g.size() : 1; + }, + [&](memref_data_type const &m) -> std::int64_t { + const auto s_i = m.shape(in.mode()); + if (is_dynamic_value(s_i)) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + return mi->shape_gcd(in.mode()); + } + return 1; + } + return s_i; + }, + [&](auto const &) -> std::int64_t { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + }}, + *in.operand().ty()); + + gcd_.set(in.result(0), size); +} +void gcd_helper::operator()(subgroup_broadcast_inst const &in) { + auto g = gcd_.get_if(in.a()); + if (g) { + gcd_.set(in.result(0), *g); + } +} +void gcd_helper::operator()(subview_inst const &in) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + const auto mt = get_memref_type(in.operand()); + + auto shape_gcd = std::vector{}; + auto stride_gcd = std::vector{}; + + shape_gcd.reserve(mt->dim()); + stride_gcd.reserve(mt->dim()); + auto dyn_offsets = in.offsets(); + auto dyn_sizes = in.sizes(); + std::int64_t offset_gcd = mi->offset_gcd(); + for (std::int64_t i = 0, joffset = 0, jsize = 0; i < mt->dim(); ++i) { + const std::int64_t offset = in.static_offsets()[i]; + + auto const get_offset = [&]() -> std::int64_t { + if (is_dynamic_value(offset)) { + return gcd_.get(dyn_offsets[joffset++]); + } + return offset; + }; + offset_gcd = std::gcd(offset_gcd, get_offset() * mi->stride_gcd(i)); + + const std::int64_t size = in.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { + auto const get_size = [&]() -> std::int64_t { + if (is_dynamic_value(size)) { + return gcd_.get(dyn_sizes[jsize++]); + } + return size; + }; + shape_gcd.emplace_back(get_size()); + stride_gcd.emplace_back(mi->stride_gcd(i)); + } + } + + gcd_.set_memref(in.result(0), memref_info(offset_gcd, shape_gcd, stride_gcd)); + } +} + +void gcd_helper::set_from_attributes(function_node const &fn) { + auto known_memref_info = [&](memref_data_type const *mr, tinytc_attr_t dict) -> memref_info { + const std::int64_t alignment = [&]() -> std::int64_t { + if (auto alignment_attr = get_attr(dict, "alignment"); alignment_attr) { + auto ia = dyn_cast(alignment_attr); + if (ia) { + return ia->value(); + } + throw status::ir_expected_integer_attribute; + } + return default_alignment_; + }(); + + auto shape_gcd = [&]() -> std::vector { + if (auto shape_attr = get_attr(dict, "shape_gcd"); shape_attr) { + return get_array_attr_as(shape_attr); + } + return std::vector{}; + }(); + auto const dim = static_cast(mr->dim()); + if (shape_gcd.size() > dim) { + shape_gcd.resize(dim); + } else if (shape_gcd.size() < dim) { + std::size_t i = shape_gcd.size(); + shape_gcd.resize(mr->dim()); + for (; i < shape_gcd.size(); ++i) { + const auto s = mr->shape(i); + shape_gcd[i] = !is_dynamic_value(s) ? s : 1; + } + } + + auto stride_gcd = [&]() -> std::vector { + if (auto stride_attr = get_attr(dict, "stride_gcd"); stride_attr) { + return get_array_attr_as(stride_attr); + } + return std::vector{}; + }(); + if (stride_gcd.size() > dim) { + stride_gcd.resize(dim); + } else if (stride_gcd.size() < dim) { + std::size_t i = stride_gcd.size(); + stride_gcd.resize(mr->dim()); + for (; i < stride_gcd.size(); ++i) { + const auto s = mr->stride(i); + stride_gcd[i] = !is_dynamic_value(s) ? s : 1; + } + } + + return memref_info(alignment / size(mr->element_ty()), std::move(shape_gcd), + std::move(stride_gcd)); + }; + for (std::int32_t arg_no = 0; arg_no < fn.num_params(); ++arg_no) { + auto ty = fn.params()[arg_no].ty(); + if (auto g = dyn_cast(ty); g) { + ty = g->ty(); + } + if (auto mr = dyn_cast(ty); mr) { + gcd_.set_memref(fn.params()[arg_no], known_memref_info(mr, fn.param_attr(arg_no))); + } + } +} + +auto gcd_analysis::run_on_function(function_node const &fn) -> gcd_analysis_result { + auto visitor = gcd_helper{default_alignment_}; + visitor.set_from_attributes(fn); + + walk(fn, [&visitor](inst_node const &i) { visit(visitor, i); }); + + return std::move(visitor).get_result(); +} + +} // namespace tinytc diff --git a/src/analysis/gcd.hpp b/src/analysis/gcd.hpp new file mode 100644 index 00000000..3162b642 --- /dev/null +++ b/src/analysis/gcd.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GCD_20241203_HPP +#define GCD_20241203_HPP + +#include "node/function_node.hpp" +#include "tinytc/types.h" + +#include +#include +#include +#include +#include + +namespace tinytc { + +class memref_info { + public: + memref_info() = default; + memref_info(std::int64_t offset_gcd, std::vector shape_gcd, + std::vector stride_gcd); + + inline auto offset_gcd() const { return offset_gcd_; } + inline auto shape_gcd_begin() const { return shape_gcd_.begin(); } + inline auto shape_gcd_end() const { return shape_gcd_.end(); } + inline auto shape_gcd() const -> std::vector const & { return shape_gcd_; } + inline auto shape_gcd(std::size_t i) const -> std::int64_t { return shape_gcd_[i]; } + inline auto stride_gcd_begin() const { return stride_gcd_.begin(); } + inline auto stride_gcd_end() const { return stride_gcd_.end(); } + inline auto stride_gcd() const -> std::vector const & { return stride_gcd_; } + inline auto stride_gcd(std::size_t i) const -> std::int64_t { return stride_gcd_[i]; } + + private: + std::int64_t offset_gcd_; + std::vector shape_gcd_, stride_gcd_; +}; + +class gcd_analysis_result { + public: + auto get(::const_tinytc_value_t a) const -> std::int64_t; + auto get(::tinytc_value const &a) const -> std::int64_t; + auto get_if(::const_tinytc_value_t a) const -> std::optional; + auto get_if(::tinytc_value const &a) const -> std::optional; + void set(::tinytc_value const &a, std::int64_t g); + + auto get_memref_if(::const_tinytc_value_t a) const -> memref_info const *; + auto get_memref_if(::tinytc_value const &a) const -> memref_info const *; + void set_memref(::tinytc_value const &a, memref_info g); + + private: + std::unordered_map<::tinytc_value const *, std::int64_t> gcd_; + std::unordered_map<::tinytc_value const *, memref_info> memref_info_; +}; + +/** + * In the "GCD-analysis" we want to infer at compile time which integer divide an SSA value is + * divisible. For example, for + * + * %0 = constant 32 : index + * %1 = arith.mul %0, %x : index + * + * we know that %1 is at least divisible by 2, 4, 8, 16, 32 without knowing anything about %x. + * + * For the analysis, let P(%x) be the set of known prime factors of a value %x. We define + * + * "%x = constant C": P(%x) := set of prime factors of |C| + * + * For multiplication and addition formulae we define + * "%z = arith.add %x, %y": P(%z) := P(%x) ∩ P(%y) + * "%z = arith.sub %x, %y": P(%z) := P(%x) ∩ P(%y) + * "%z = arith.mul %x, %y": P(%z) := P(%x) ∪ P(%y) + * "%z = arith.div %x, %y": P(%z) := P(%x) ∖ P(%y) if P(%y) ⊆ P(%x) else {1} + * + * If nothing is known about %x we let + * P(%x) := {1} + * + * For efficiency, we encode the set of prime factors by its product, that is, by the integer + * + * p(%x) := \prod_{f\in P(%x)} f + * + * We can update p without having to resort to P as following: + * + * "%x = constant C": p(%x) := C + * "%z = arith.add %x, %y": p(%z) := gcd(p(%x),p(%y)) + * "%z = arith.sub %x, %y": p(%z) := gcd(p(%x),p(%y)) + * "%z = arith.mul %x, %y": p(%z) := p(%x) * p(%y) + * "%z = arith.div %x, %y": p(%z) := p(%x) / p(%y) if p(%x) % p(%y) == 0 else 1 + * "%x = unknown": p(%x) := 1 + * + * where gcd is the greatest common divisor. + * + * One special case is when we have %zero = constant 0. By definition gcd(%zero) = 0. + * We then have automatically + * "%z = arith.add %zero, %x": p(%z) = p(%x) // Fine, 0 + x = x so we must have p(%z) = p(%x) + * "%z = arith.mul %zero, %x": p(%z) = 0 // Fine, 0 * x = 0 so we must have p(%z) = p(%zero) + */ +class gcd_analysis { + public: + inline gcd_analysis(std::int32_t default_alignment) : default_alignment_(default_alignment) {} + + auto run_on_function(function_node const &fn) -> gcd_analysis_result; + + private: + std::int32_t default_alignment_; +}; + +} // namespace tinytc + +#endif // GCD_20241203_HPP diff --git a/src/analysis/stack.cpp b/src/analysis/stack.cpp new file mode 100644 index 00000000..8e33509c --- /dev/null +++ b/src/analysis/stack.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/stack.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "support/walk.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" + +#include +#include + +namespace tinytc { + +auto stack_high_water_mark::run_on_function(function_node const &fn) -> std::int64_t { + std::int64_t high_water_mark = 0; + + walk(fn, [&high_water_mark](inst_node const &i) { + if (auto *a = dyn_cast(&i); a) { + auto t = dyn_cast(a->result(0).ty()); + if (t == nullptr) { + throw compilation_error(a->loc(), status::ir_expected_memref); + } + high_water_mark = std::max(high_water_mark, a->stack_ptr() + t->size_in_bytes()); + } + }); + + return high_water_mark; +} + +} // namespace tinytc diff --git a/src/analysis/stack.hpp b/src/analysis/stack.hpp new file mode 100644 index 00000000..ed18b2c8 --- /dev/null +++ b/src/analysis/stack.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef STACK_20241112_HPP +#define STACK_20241112_HPP + +#include "node/function_node.hpp" + +#include + +namespace tinytc { + +class stack_high_water_mark { + public: + auto run_on_function(function_node const &fn) -> std::int64_t; +}; + +} // namespace tinytc + +#endif // STACK_20241112_HPP diff --git a/src/attribute.cpp b/src/attribute.cpp new file mode 100644 index 00000000..9271d195 --- /dev/null +++ b/src/attribute.cpp @@ -0,0 +1,73 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "error.hpp" +#include "node/attr_node.hpp" +#include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" + +#include +#include + +using namespace tinytc; + +extern "C" { + +tinytc_status_t tinytc_array_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + uint32_t array_size, const tinytc_attr_t *array) { + if (attr == nullptr || ctx == nullptr || (array_size != 0 && array == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *attr = array_attr::get(ctx, array_view(array, array_size)); }); +} + +tinytc_status_t tinytc_boolean_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + tinytc_bool_t value) { + if (attr == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *attr = boolean_attr::get(ctx, value); }); +} + +tinytc_status_t tinytc_dictionary_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + uint32_t items_size, tinytc_named_attr_t *items) { + TINYTC_CHECK_STATUS(tinytc_dictionary_attr_sort(items_size, items)); + return tinytc_dictionary_attr_get_with_sorted(attr, ctx, items_size, items); +} + +tinytc_status_t tinytc_dictionary_attr_get_with_sorted(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + uint32_t items_size, + const tinytc_named_attr_t *items) { + if (attr == nullptr || ctx == nullptr || (items_size != 0 && items == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *attr = dictionary_attr::get(ctx, array_view(items, items_size)); + }); +} + +tinytc_status_t tinytc_dictionary_attr_sort(uint32_t items_size, tinytc_named_attr_t *items) { + return exception_to_status_code( + [&] { dictionary_attr::sort(mutable_array_view(items, items_size)); }); +} + +tinytc_status_t tinytc_integer_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + int64_t value) { + if (attr == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *attr = integer_attr::get(ctx, value); }); +} + +tinytc_status_t tinytc_string_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + uint32_t str_length, char const *str) { + if (attr == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *attr = string_attr::get(ctx, std::string_view(str, str_length)); }); +} +} diff --git a/src/binary.cpp b/src/binary.cpp index 360fc0ab..ad9aa16e 100644 --- a/src/binary.cpp +++ b/src/binary.cpp @@ -5,7 +5,7 @@ #include "error.hpp" #include "tinytc/tinytc.h" #include "tinytc/types.h" -#include "util.hpp" +#include "util/casting.hpp" #include #include @@ -13,20 +13,23 @@ using namespace tinytc; -tinytc_binary::tinytc_binary(std::vector data, bundle_format format, - tinytc_core_feature_flags_t core_features) - : data_(std::move(data)), format_(format), core_features_(core_features) {} +tinytc_binary::tinytc_binary(compiler_context ctx, std::vector data, + bundle_format format, tinytc_core_feature_flags_t core_features) + : ctx_(std::move(ctx)), data_(std::move(data)), format_(format), core_features_(core_features) { +} extern "C" { -tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, tinytc_bundle_format_t format, - size_t data_size, uint8_t const *data, +tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, tinytc_compiler_context_t ctx, + tinytc_bundle_format_t format, size_t data_size, + uint8_t const *data, tinytc_core_feature_flags_t core_features) { - if (bin == nullptr || data == nullptr) { + if (bin == nullptr || ctx == nullptr || data == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *bin = std::make_unique(std::vector(data, data + data_size), + *bin = std::make_unique(compiler_context{ctx, true}, + std::vector(data, data + data_size), enum_cast(format), core_features) .release(); }); @@ -43,6 +46,14 @@ tinytc_status_t tinytc_binary_get_raw(const_tinytc_binary_t bin, tinytc_bundle_f return tinytc_status_success; } +tinytc_status_t tinytc_binary_get_compiler_context(const_tinytc_binary_t bin, + tinytc_compiler_context_t *ctx) { + if (bin == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ctx = bin->share_context().release(); }); +} + tinytc_status_t tinytc_binary_get_core_features(const_tinytc_binary_t bin, tinytc_core_feature_flags_t *core_features) { if (bin == nullptr || core_features == nullptr) { diff --git a/src/binary.hpp b/src/binary.hpp index 3607c0db..56510463 100644 --- a/src/binary.hpp +++ b/src/binary.hpp @@ -5,13 +5,17 @@ #define BINARY_20240308_HPP #include "reference_counted.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include "tinytc/types.hpp" #include #include #include +namespace tinytc { +enum class bundle_format; +} + /** * @brief Container encapsulating a SPIR-V or native device binary */ @@ -20,14 +24,17 @@ struct tinytc_binary : tinytc::reference_counted { /** * @brief Create binary * + * @param ctx Compiler context * @param data Binary data * @param format Binary format (SPIR-V or native device binary) * @param metadata_map Dictionary kernel name -> kernel metadata * @param core_features Required core features */ - tinytc_binary(std::vector data, tinytc::bundle_format format, - tinytc_core_feature_flags_t core_features); + tinytc_binary(tinytc::compiler_context ctx, std::vector data, + tinytc::bundle_format format, tinytc_core_feature_flags_t core_features); + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::compiler_context { return ctx_; } //! Get raw data inline auto data() const noexcept -> std::uint8_t const * { return data_.data(); } //! Get size of raw data @@ -40,6 +47,7 @@ struct tinytc_binary : tinytc::reference_counted { } private: + tinytc::compiler_context ctx_; std::vector data_; tinytc::bundle_format format_; tinytc_core_feature_flags_t core_features_; diff --git a/src/cl/CMakeLists.txt b/src/cl/CMakeLists.txt index e0cf379d..f0def75b 100644 --- a/src/cl/CMakeLists.txt +++ b/src/cl/CMakeLists.txt @@ -6,7 +6,6 @@ include(GNUInstallDirs) include(InstallLib) find_package(OpenCL REQUIRED) -find_package(re2c REQUIRED) set(SOURCES device_info.cpp @@ -14,9 +13,6 @@ set(SOURCES kernel.cpp recipe_handler.cpp ) -set(RE2C_SOURCES - device_info_helper.re -) set(PUBLIC_HEADERS tinytc_cl.h tinytc_cl.hpp @@ -24,13 +20,20 @@ set(PUBLIC_HEADERS list(TRANSFORM PUBLIC_HEADERS PREPEND "${PROJECT_SOURCE_DIR}/include/tinytc/") add_library(tinytc_cl-objects OBJECT ${SOURCES}) -add_re2c_to_target(TARGET tinytc_cl-objects SOURCES ${RE2C_SOURCES} FLAGS "--tags") set_cxx_common_options(tinytc_cl-objects) target_link_libraries(tinytc_cl-objects PUBLIC tinytc OpenCL::OpenCL) target_include_directories(tinytc_cl-objects PRIVATE "$" ) +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/device_info_helper.cpp) + message(STATUS "Pre-generated device info helper available -- skipping re2c dependency") + target_sources(tinytc_cl-objects PRIVATE device_info_helper.cpp) +else() + find_package(re2c REQUIRED) + add_re2c_to_target(TARGET tinytc_cl-objects SOURCES device_info_helper.re FLAGS "--tags") +endif() + add_library(tinytc_cl $) add_library(tinytc::tinytc_cl ALIAS tinytc_cl) @@ -52,3 +55,9 @@ install(FILES "${PROJECT_SOURCE_DIR}/cmake/FindOpenCL.cmake" DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/tinytc_cl ) + +# cpack + +include(GeneratedFiles) +get_generated_files(GENERATED_FILES_TINYTC_CL tinytc_cl-objects) +set(GENERATED_FILES_TINYTC_CL ${GENERATED_FILES_TINYTC_CL} PARENT_SCOPE) diff --git a/src/cl/argument_handler.hpp b/src/cl/argument_handler.hpp index 7cc8d5f2..ef1da91e 100644 --- a/src/cl/argument_handler.hpp +++ b/src/cl/argument_handler.hpp @@ -19,13 +19,16 @@ class opencl_argument_handler { using clSetKernelArgMemPointerINTEL_t = cl_int (*)(cl_kernel kernel, cl_uint arg_index, const void *arg_value); //! ctor - inline opencl_argument_handler() : clSetKernelArgMemPointerINTEL_(nullptr) {} + inline opencl_argument_handler() = default; //! ctor; checks whether cl_intel_unified_shared_memory is available and gets //! clSetKernelArgMemPointerINTEL - inline opencl_argument_handler(cl_platform_id plat) - : clSetKernelArgMemPointerINTEL_( - (clSetKernelArgMemPointerINTEL_t)clGetExtensionFunctionAddressForPlatform( - plat, "clSetKernelArgMemPointerINTEL")) {} + inline opencl_argument_handler(cl_platform_id plat) { set_platform(plat); } + + inline void set_platform(cl_platform_id plat) { + clSetKernelArgMemPointerINTEL_ = + (clSetKernelArgMemPointerINTEL_t)clGetExtensionFunctionAddressForPlatform( + plat, "clSetKernelArgMemPointerINTEL"); + } /** * @brief Set single kernel argument @@ -69,7 +72,7 @@ class opencl_argument_handler { } private: - clSetKernelArgMemPointerINTEL_t clSetKernelArgMemPointerINTEL_; + clSetKernelArgMemPointerINTEL_t clSetKernelArgMemPointerINTEL_ = nullptr; }; } // namespace tinytc diff --git a/src/cl/device_info.cpp b/src/cl/device_info.cpp index 76c7fe2b..5b4f5a62 100644 --- a/src/cl/device_info.cpp +++ b/src/cl/device_info.cpp @@ -3,7 +3,9 @@ #include "../device_info.hpp" #include "device_info_helper.hpp" +#include "error.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/tinytc_cl.h" #include "tinytc/types.h" @@ -12,31 +14,101 @@ #include #include #include -#include #include +#ifndef CL_DEVICE_SINGLE_FP_ATOMIC_CAPABILITIES_EXT +#define CL_DEVICE_SINGLE_FP_ATOMIC_CAPABILITIES_EXT 0x4231 +#endif +#ifndef CL_DEVICE_DOUBLE_FP_ATOMIC_CAPABILITIES_EXT +#define CL_DEVICE_DOUBLE_FP_ATOMIC_CAPABILITIES_EXT 0x4232 +#endif +#ifndef CL_DEVICE_HALF_FP_ATOMIC_CAPABILITIES_EXT +#define CL_DEVICE_HALF_FP_ATOMIC_CAPABILITIES_EXT 0x4233 +#endif + +#ifndef CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT +#define CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT (1 << 1) +#endif +#ifndef CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT +#define CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT (1 << 2) +#endif +#ifndef CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT +#define CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT (1 << 17) +#endif +#ifndef CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT +#define CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT (1 << 18) +#endif + +namespace tinytc { +void set_spirv_features(tinytc_core_info_t info, cl_device_id device) { + auto const set_feature = [&info](tinytc_spirv_feature_t feature, bool available) { + CHECK_STATUS(tinytc_core_info_set_spirv_feature(info, feature, available)); + }; + + const auto ocl_exts = get_opencl_extensions(device); + const auto ocl_version = get_opencl_version(device); + const auto max_num_subgroups = device_info(device, CL_DEVICE_MAX_NUM_SUB_GROUPS); + const auto double_fp_config = + device_info(device, CL_DEVICE_DOUBLE_FP_CONFIG); + + set_feature(tinytc_spirv_feature_float16, ocl_exts & opencl_ext_cl_khr_fp16); + set_feature(tinytc_spirv_feature_float64, + double_fp_config != 0 || ocl_exts & opencl_ext_cl_khr_fp64); + set_feature(tinytc_spirv_feature_groups, ocl_version.major >= 2 && max_num_subgroups != 0); + set_feature(tinytc_spirv_feature_subgroup_dispatch, + ocl_version.major >= 2 && max_num_subgroups != 0); + set_feature(tinytc_spirv_feature_subgroup_buffer_block_io, + ocl_exts & opencl_ext_cl_intel_spirv_subgroups); + set_feature(tinytc_spirv_feature_int64_atomics, + ocl_exts & (opencl_ext_cl_khr_int64_base_atomics | + opencl_ext_cl_khr_int64_extended_atomics)); + if (ocl_exts & opencl_ext_cl_ext_float_atomics) { + auto f16_flags = + device_info(device, CL_DEVICE_HALF_FP_ATOMIC_CAPABILITIES_EXT); + auto f32_flags = + device_info(device, CL_DEVICE_SINGLE_FP_ATOMIC_CAPABILITIES_EXT); + auto f64_flags = + device_info(device, CL_DEVICE_DOUBLE_FP_ATOMIC_CAPABILITIES_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_add_local, + f16_flags & CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_add_local, + f32_flags & CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_add_local, + f64_flags & CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_add_global, + f16_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_add_global, + f32_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_add_global, + f64_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_min_max_local, + f16_flags & CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_min_max_local, + f32_flags & CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_min_max_local, + f64_flags & CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_min_max_global, + f16_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_min_max_global, + f32_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_min_max_global, + f64_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT); + } +} +} // namespace tinytc + extern "C" { tinytc_status_t tinytc_cl_get_support_level(cl_device_id device, tinytc_support_level_t *level) { if (level == nullptr) { return tinytc_status_invalid_arguments; } - std::size_t extensions_size; - TINYTC_CL_CHECK_STATUS( - clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, nullptr, &extensions_size)); - std::string extensions; - extensions.resize(extensions_size); - TINYTC_CL_CHECK_STATUS( - clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, extensions_size, extensions.data(), nullptr)); - - bool has_subgroup = tinytc::has_subgroup_extension(extensions.size(), extensions.c_str()); + auto ocl_exts = tinytc::get_opencl_extensions(device); + bool has_subgroup = + ocl_exts & (tinytc::opencl_ext_cl_intel_subgroups | tinytc::opencl_ext_cl_khr_subgroups); if (!has_subgroup) { - char version_str[32]; - std::size_t version_str_size; - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_VERSION, sizeof(version_str) - 1, - version_str, &version_str_size)); - auto version = tinytc::get_opencl_version(version_str_size, version_str); + auto version = tinytc::get_opencl_version(device); if (version.major >= 3) { std::size_t features_size; TINYTC_CL_CHECK_STATUS( @@ -63,7 +135,11 @@ tinytc_status_t tinytc_cl_get_support_level(cl_device_id device, tinytc_support_ cl_version ip_ver = 0; cl_int err = clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, sizeof(ip_ver), &ip_ver, nullptr); - if (err == CL_SUCCESS && ip_ver == tinytc_intel_gpu_architecture_pvc) { + const auto is_arch = [&ip_ver](auto arch) { + return arch <= ip_ver && ip_ver <= arch + TINYTC_INTEL_GPU_ARCHITECTURE_SUB_VERSION_BITS; + }; + if (err == CL_SUCCESS && (is_arch(tinytc_intel_gpu_architecture_pvc) || + is_arch(tinytc_intel_gpu_architecture_bmg))) { *level = tinytc_support_level_tuned; } @@ -75,16 +151,18 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i return tinytc_status_invalid_arguments; } - cl_uint vendor_id; + cl_uint vendor_id, mem_base_addr_align; TINYTC_CL_CHECK_STATUS( clGetDeviceInfo(device, CL_DEVICE_VENDOR_ID, sizeof(vendor_id), &vendor_id, nullptr)); if (vendor_id == 0x8086) { - cl_version ip_ver; - cl_uint num_eus_per_subslice, num_threads_per_eu; + cl_device_type device_type; std::size_t subgroup_sizes_size = 0; + TINYTC_CL_CHECK_STATUS( + clGetDeviceInfo(device, CL_DEVICE_TYPE, sizeof(device_type), &device_type, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_SUB_GROUP_SIZES_INTEL, 0, nullptr, &subgroup_sizes_size)); auto subgroup_sizes_long = @@ -95,19 +173,36 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i auto subgroup_sizes = std::vector(subgroup_sizes_long.begin(), subgroup_sizes_long.end()); - TINYTC_CL_CHECK_STATUS( - clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, sizeof(ip_ver), &ip_ver, nullptr)); - - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, - sizeof(num_eus_per_subslice), &num_eus_per_subslice, - nullptr)); - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_THREADS_PER_EU_INTEL, - sizeof(num_threads_per_eu), &num_threads_per_eu, - nullptr)); - - TINYTC_CHECK_STATUS(tinytc_core_info_intel_create(info, ip_ver, num_eus_per_subslice, - num_threads_per_eu, subgroup_sizes.size(), - subgroup_sizes.data())); + if (device_type == CL_DEVICE_TYPE_GPU) { + cl_version ip_ver; + cl_uint num_eus_per_subslice, num_threads_per_eu; + + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, + sizeof(ip_ver), &ip_ver, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, + sizeof(num_eus_per_subslice), + &num_eus_per_subslice, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_THREADS_PER_EU_INTEL, + sizeof(num_threads_per_eu), &num_threads_per_eu, + nullptr)); + + TINYTC_CHECK_STATUS(tinytc_core_info_intel_create( + info, ip_ver, num_eus_per_subslice, num_threads_per_eu, subgroup_sizes.size(), + subgroup_sizes.data())); + } else if (device_type == CL_DEVICE_TYPE_CPU) { + // 32 zmm registers + // @todo: need to do something smarter here + std::uint32_t register_space = 32 * 64; + size_t max_work_group_size; + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, + sizeof(max_work_group_size), + &max_work_group_size, nullptr)); + TINYTC_CHECK_STATUS( + tinytc_core_info_generic_create(info, register_space, max_work_group_size, + subgroup_sizes.size(), subgroup_sizes.data())); + } else { + return tinytc_status_unsupported_device; + } } else if (vendor_id == 0x1002) { // 512 KB / 32 wavefronts // @todo: can this info be queried? @@ -127,7 +222,13 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i return tinytc_status_unsupported_device; } - return tinytc_status_success; + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, + sizeof(mem_base_addr_align), &mem_base_addr_align, + nullptr)); + // mem_base_addr_align is in bits -> convert to bytes + TINYTC_CHECK_STATUS(tinytc_core_info_set_default_alignment(*info, mem_base_addr_align / 8)); + + return tinytc::exception_to_status_code_cl([&] { set_spirv_features(*info, device); }); } } diff --git a/src/cl/device_info_helper.hpp b/src/cl/device_info_helper.hpp index c2e265a0..5cbfb337 100644 --- a/src/cl/device_info_helper.hpp +++ b/src/cl/device_info_helper.hpp @@ -4,7 +4,13 @@ #ifndef CL_DEVICE_INFO_HELPER_20240503_HPP #define CL_DEVICE_INFO_HELPER_20240503_HPP +#include "tinytc/tinytc_cl.hpp" + +#include + #include +#include +#include namespace tinytc { @@ -13,9 +19,43 @@ struct opencl_version { int minor; }; -bool has_subgroup_extension(std::size_t str_length, const char *str); -bool has_additional_subgroup_extensions(std::size_t str_length, const char *str); +enum opencl_ext_t { + opencl_ext_cl_khr_fp16 = 0x1, + opencl_ext_cl_khr_fp64 = 0x2, + opencl_ext_cl_khr_subgroups = 0x4, + opencl_ext_cl_intel_subgroups = 0x8, + opencl_ext_cl_intel_required_subgroup_size = 0x10, + opencl_ext_cl_intel_subgroups_long = 0x20, + opencl_ext_cl_intel_subgroups_short = 0x40, + opencl_ext_cl_khr_int64_base_atomics = 0x80, + opencl_ext_cl_khr_int64_extended_atomics = 0x100, + opencl_ext_cl_ext_float_atomics = 0x200, + opencl_ext_cl_intel_spirv_subgroups = 0x400, +}; +//! Type for combination of core feature flags +using opencl_exts_t = std::uint32_t; + +auto get_opencl_extensions(std::size_t str_length, const char *str) -> opencl_exts_t; +auto get_opencl_extensions(cl_device_id device) -> opencl_exts_t; auto get_opencl_version(std::size_t str_length, const char *str) -> opencl_version; +auto get_opencl_version(cl_device_id device) -> opencl_version; + +template auto device_info(cl_device_id device, cl_device_info param_name) -> T { + T val = {}; + CL_CHECK_STATUS(clGetDeviceInfo(device, param_name, sizeof(T), &val, nullptr)); + return val; +} + +template <> +inline auto device_info(cl_device_id device, cl_device_info param_name) + -> std::string { + std::string str; + std::size_t str_len; + CL_CHECK_STATUS(clGetDeviceInfo(device, param_name, 0, nullptr, &str_len)); + str.resize(str_len); + CL_CHECK_STATUS(clGetDeviceInfo(device, param_name, str_len, str.data(), nullptr)); + return str; +} } // namespace tinytc diff --git a/src/cl/device_info_helper.re b/src/cl/device_info_helper.re index 8004f97a..10c7a258 100644 --- a/src/cl/device_info_helper.re +++ b/src/cl/device_info_helper.re @@ -3,12 +3,15 @@ #include "device_info_helper.hpp" +#include + namespace tinytc { -bool has_subgroup_extension(std::size_t str_length, const char *str) { +auto get_opencl_extensions(std::size_t str_length, const char *str) -> opencl_exts_t { const char *YYLIMIT = str + str_length; const char *YYCURSOR = str; const char *YYMARKER; + opencl_exts_t result = 0; for (;;) { /*!re2c re2c:yyfill:enable = 0; @@ -17,36 +20,49 @@ bool has_subgroup_extension(std::size_t str_length, const char *str) { whitespace = [ \t\v\r]+; - "cl_intel_subgroups" { return true; } - "cl_khr_subgroups" { return true; } - whitespace { continue; } - * { continue; } - $ { break; } + + "cl_khr_fp16" { result |= opencl_ext_cl_khr_fp16; continue; } + "cl_khr_fp64" { result |= opencl_ext_cl_khr_fp64; continue; } + "cl_khr_subgroups" { result |= opencl_ext_cl_khr_subgroups; continue; } + "cl_intel_subgroups" { result |= opencl_ext_cl_intel_subgroups; continue; } + "cl_intel_required_subgroup_size" { + result |= opencl_ext_cl_intel_required_subgroup_size; continue; + } + "cl_intel_subgroups_long" { + result |= opencl_ext_cl_intel_subgroups_long; continue; + } + "cl_intel_subgroups_short" { + result |= opencl_ext_cl_intel_subgroups_short; continue; + } + "cl_intel_spirv_subgroups" { + result |= opencl_ext_cl_intel_spirv_subgroups; continue; + } + "cl_khr_int64_base_atomics" { + result |= opencl_ext_cl_khr_int64_base_atomics; continue; + } + "cl_khr_int64_extended_atomics" { + result |= opencl_ext_cl_khr_int64_extended_atomics; continue; + } + "cl_ext_float_atomics" { + result |= opencl_ext_cl_ext_float_atomics; continue; + } + whitespace { continue; } + * { + // skip remaining characters until we find whitespace + while (!std::isspace(*YYCURSOR) && YYCURSOR < YYLIMIT) { + ++YYCURSOR; + } + continue; + } + $ { break; } */ } - return false; + return result; } -bool has_additional_subgroup_extensions(std::size_t str_length, const char *str) { - const char *YYLIMIT = str + str_length; - const char *YYCURSOR = str; - const char *YYMARKER; - bool has_reqd_subgroup_size = false, has_subgroups_long = false, has_subgroups_short = false; - for (;;) { - /*!re2c - re2c:yyfill:enable = 0; - re2c:define:YYCTYPE = char; - re2c:eof = 0; - - "cl_intel_required_subgroup_size" { has_reqd_subgroup_size = true; continue; } - "cl_intel_subgroups_long" { has_subgroups_long = true; continue; } - "cl_intel_subgroups_short" { has_subgroups_short = true; continue; } - whitespace { continue; } - * { continue; } - $ { break; } - */ - } - return has_reqd_subgroup_size && has_subgroups_long && has_subgroups_short; +auto get_opencl_extensions(cl_device_id device) -> opencl_exts_t { + std::string extensions = device_info(device, CL_DEVICE_EXTENSIONS); + return get_opencl_extensions(extensions.size(), extensions.c_str()); } auto get_opencl_version(std::size_t str_length, const char *str) -> opencl_version { @@ -81,4 +97,9 @@ ret: return {major, minor}; } +auto get_opencl_version(cl_device_id device) -> opencl_version { + std::string version = device_info(device, CL_DEVICE_VERSION); + return get_opencl_version(version.size(), version.c_str()); +} + } // namespace tinytc diff --git a/src/cl/kernel.cpp b/src/cl/kernel.cpp index 23814944..409ca376 100644 --- a/src/cl/kernel.cpp +++ b/src/cl/kernel.cpp @@ -3,75 +3,31 @@ #include "../compiler_options.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/tinytc_cl.h" #include "tinytc/types.h" #include #include -#include #include #include -#include #include -#include +#include -extern "C" { - -tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, cl_context context, - cl_device_id device, - const_tinytc_source_t src, - tinytc_source_context_t source_ctx) { - if (bundle == nullptr || src == nullptr) { - return tinytc_status_invalid_arguments; - } - - size_t length = 0; - char const *code = nullptr; - tinytc_core_feature_flags_t core_features = 0; - TINYTC_CL_CHECK_STATUS(tinytc_source_get_code(src, &length, &code)); - TINYTC_CL_CHECK_STATUS(tinytc_source_get_core_features(src, &core_features)); - - cl_int err; - cl_program p = clCreateProgramWithSource(context, 1, &code, &length, &err); - TINYTC_CL_CHECK_STATUS(err); - - auto options = std::ostringstream{}; - for (auto const &opt : tinytc::default_compiler_options) { - options << opt << " "; - } - if (core_features & tinytc_core_feature_flag_large_register_file) { - options << tinytc::large_register_file_compiler_option_cl; - } - auto options_str = std::move(options).str(); - if (err = clBuildProgram(p, 1, &device, options_str.c_str(), nullptr, nullptr); - err != CL_SUCCESS) { - if (source_ctx) { - std::string log; - std::size_t log_size; - clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); - log.resize(log_size); - clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr); +using tinytc::compiler_context; - tinytc_location_t loc = {}; - tinytc_source_get_location(src, &loc); - tinytc_source_context_report_error(source_ctx, &loc, log.c_str(), true); - } - clReleaseProgram(p); - TINYTC_CL_CHECK_STATUS(err); - } - *bundle = p; - return tinytc_status_success; -} +extern "C" { -tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( - cl_program *bundle, cl_context context, cl_device_id device, tinytc_prog_t prg, - tinytc_core_feature_flags_t core_features, tinytc_source_context_t source_ctx) { +tinytc_status_t +tinytc_cl_kernel_bundle_create_with_program(cl_program *bundle, cl_context context, + cl_device_id device, tinytc_prog_t prg, + tinytc_core_feature_flags_t core_features) { if (bundle == nullptr || prg == nullptr) { return tinytc_status_invalid_arguments; } tinytc_core_info_t info = nullptr; - tinytc_source_t src = nullptr; + tinytc_binary_t bin = nullptr; tinytc_status_t status = tinytc_status_success; if (status = tinytc_cl_core_info_create(&info, device); status != tinytc_status_success) { @@ -81,17 +37,16 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( status != tinytc_status_success) { goto err; } - if (status = tinytc_prog_compile_to_opencl(&src, prg, info, source_ctx); + if (status = tinytc_prog_compile_to_spirv_and_assemble(&bin, prg, info); status != tinytc_status_success) { goto err; } - if (status = - tinytc_cl_kernel_bundle_create_with_source(bundle, context, device, src, source_ctx); + if (status = tinytc_cl_kernel_bundle_create_with_binary(bundle, context, device, bin); status != tinytc_status_success) { goto err; } err: - tinytc_source_release(src); + tinytc_binary_release(bin); tinytc_core_info_release(info); return status; @@ -99,8 +54,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, cl_context context, cl_device_id device, - const_tinytc_binary_t bin, - tinytc_source_context_t source_ctx) { + const_tinytc_binary_t bin) { if (bin == nullptr) { return tinytc_status_invalid_arguments; } @@ -121,12 +75,16 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, c tinytc_core_feature_flags_t core_features; TINYTC_CHECK_STATUS(tinytc_binary_get_core_features(bin, &core_features)); + tinytc_compiler_context_t ctx = nullptr; + TINYTC_CHECK_STATUS(tinytc_binary_get_compiler_context(bin, &ctx)); + auto ctx_ = compiler_context{ctx}; // Clean-up ctx when ctx_ gets out of scope + char const *options = ""; if (core_features & tinytc_core_feature_flag_large_register_file) { options = tinytc::large_register_file_compiler_option_cl; } if (err = clBuildProgram(p, 1, &device, options, nullptr, nullptr); err != CL_SUCCESS) { - if (source_ctx) { + if (ctx_.get()) { std::string log; std::size_t log_size; clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); @@ -134,7 +92,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, c clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr); tinytc_location_t loc = {}; - tinytc_source_context_report_error(source_ctx, &loc, log.c_str(), true); + tinytc_compiler_context_report_error(ctx_.get(), &loc, log.c_str()); } clReleaseProgram(p); TINYTC_CL_CHECK_STATUS(err); @@ -147,19 +105,33 @@ tinytc_status_t tinytc_cl_get_group_size(cl_kernel kernel, size_t *local_size) { if (local_size == nullptr) { return tinytc_status_invalid_arguments; } + constexpr int short_dev_list = 4; cl_program p; cl_device_id d; + cl_uint num_devices; TINYTC_CL_CHECK_STATUS(clGetKernelInfo(kernel, CL_KERNEL_PROGRAM, sizeof(p), &p, nullptr)); - TINYTC_CL_CHECK_STATUS(clGetProgramInfo(p, CL_PROGRAM_DEVICES, sizeof(d), &d, nullptr)); + TINYTC_CL_CHECK_STATUS( + clGetProgramInfo(p, CL_PROGRAM_NUM_DEVICES, sizeof(num_devices), &num_devices, nullptr)); + if (num_devices <= short_dev_list) { + cl_device_id dbuf[4]; + TINYTC_CL_CHECK_STATUS( + clGetProgramInfo(p, CL_PROGRAM_DEVICES, sizeof(dbuf), &dbuf, nullptr)); + d = dbuf[0]; + } else { + auto dbuf = std::vector(num_devices); + TINYTC_CL_CHECK_STATUS(clGetProgramInfo( + p, CL_PROGRAM_DEVICES, num_devices * sizeof(cl_device_id), dbuf.data(), nullptr)); + d = dbuf[0]; + } return tinytc_cl_convert_status( clGetKernelWorkGroupInfo(kernel, d, CL_KERNEL_COMPILE_WORK_GROUP_SIZE, 3 * sizeof(std::size_t), local_size, nullptr)); } -void tinytc_cl_get_global_size(int64_t howmany, const size_t *local_size, size_t *global_size) { +void tinytc_cl_get_global_size(const size_t *num_groups, const size_t *local_size, + size_t *global_size) { for (size_t i = 0; i < 3; ++i) { - global_size[i] = local_size[i]; + global_size[i] = num_groups[i] * local_size[i]; } - global_size[2] *= howmany; } } diff --git a/src/cl/recipe_handler.cpp b/src/cl/recipe_handler.cpp index e9f562f6..035b0fa5 100644 --- a/src/cl/recipe_handler.cpp +++ b/src/cl/recipe_handler.cpp @@ -3,7 +3,6 @@ #include "recipe_handler.hpp" #include "../recipe.hpp" -#include "../reference_counted.hpp" #include "error.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/tinytc_cl.h" @@ -17,11 +16,10 @@ namespace tinytc { -cl_recipe_handler::cl_recipe_handler(cl_context context, cl_device_id device, recipe rec, - source_context source_ctx) +cl_recipe_handler::cl_recipe_handler(cl_context context, cl_device_id device, recipe rec) : ::tinytc_recipe_handler(std::move(rec)) { - module_ = make_kernel_bundle(context, device, get_recipe().get_source(), std::move(source_ctx)); + module_ = make_kernel_bundle(context, device, get_recipe().get_binary()); auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); @@ -52,7 +50,8 @@ void cl_recipe_handler::mem_arg(std::uint32_t arg_index, const void *value, } void cl_recipe_handler::howmany(std::int64_t num) { - global_size_ = get_global_size(num, local_size()); + global_size_ = get_global_size( + std::array{static_cast(num), 1u, 1u}, local_size()); } auto cl_recipe_handler::kernel() -> cl_kernel { return kernels_[active_kernel_].get(); } @@ -69,14 +68,12 @@ extern "C" { tinytc_status_t tinytc_cl_recipe_handler_create(tinytc_recipe_handler_t *handler, cl_context context, cl_device_id device, - tinytc_recipe_t rec, - tinytc_source_context_t source_ctx) { + tinytc_recipe_t rec) { if (handler == nullptr || rec == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code_cl([&] { - *handler = std::make_unique(context, device, recipe(rec, true), - source_context(source_ctx, true)) + *handler = std::make_unique(context, device, recipe(rec, true)) .release(); }); } diff --git a/src/cl/recipe_handler.hpp b/src/cl/recipe_handler.hpp index 6c4ebea7..c2b03479 100644 --- a/src/cl/recipe_handler.hpp +++ b/src/cl/recipe_handler.hpp @@ -19,8 +19,7 @@ namespace tinytc { struct cl_recipe_handler : ::tinytc_recipe_handler { public: - cl_recipe_handler(cl_context context, cl_device_id device, recipe rec, - source_context source_ctx); + cl_recipe_handler(cl_context context, cl_device_id device, recipe rec); void active_kernel(int kernel_num) override; void arg(std::uint32_t arg_index, std::size_t arg_size, const void *arg_value) override; diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 1ff7acf7..4240ded2 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -2,231 +2,427 @@ // SPDX-License-Identifier: BSD-3-Clause #include "codegen_tools.hpp" +#include "compiler_context.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "pass/constant_folding.hpp" #include "scalar_type.hpp" -#include "util.hpp" - -#include -#include -#include -#include -#include -#include +#include "tinytc/types.h" +#include "util/casting.hpp" +#include "util/ilist_base.hpp" +#include "util/visit.hpp" +#include #include +#include +#include #include - -using namespace clir; +#include namespace tinytc { -expr vload_helper(short vec_size, expr offset, expr ptr) { - switch (vec_size) { - case 1: - return ptr[std::move(offset)]; - case 2: - return vload2(std::move(offset), std::move(ptr)); - case 3: - return vload3(std::move(offset), std::move(ptr)); - case 4: - return vload4(std::move(offset), std::move(ptr)); - case 8: - return vload8(std::move(offset), std::move(ptr)); - case 16: - return vload16(std::move(offset), std::move(ptr)); - default: - break; +auto get_core_config_and_tiling(function_node const &fn, const_tinytc_core_info_t info) + -> std::pair { + const auto get_core_config = [&]() -> core_config { + try { + return info->get_core_config(fn.subgroup_size()); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } }; - return nullptr; + core_config core_cfg = get_core_config(); + const auto wgs = fn.work_group_size(); + local_tiling tiling = {wgs[0] / core_cfg.subgroup_size, wgs[1]}; + return {core_cfg, tiling}; } -void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, address_space as, - expr value, expr beta) { - if (is_atomic) { - atomic_store_helper(bb, std::move(dst), ty, as, std::move(value), std::move(beta)); - } else { - bb.assign(dereference(dst), std::move(value) + std::move(beta) * dereference(dst)); - } -} - -void atomic_store_helper(block_builder &bb, expr dst, scalar_type ty, address_space as, expr value, - expr beta) { - int mode = -1; - visit(overloaded{ - [&](clir::internal::int_imm &c) { - mode = c.value() == 0 ? 0 : (c.value() == 1 ? 1 : -1); - }, - [&](clir::internal::uint_imm &c) { - mode = c.value() == 0u ? 0 : (c.value() == 1u ? 1 : -1); - }, - [&](clir::internal::float_imm &c) { - mode = c.value() == 0.0 ? 0 : (c.value() == 1.0 ? 1 : -1); - }, - [&](auto &) {}, - }, - *beta); - auto pointer_ty = pointer_to(to_clir_atomic_ty(ty, as, type_qualifier::volatile_t)); - auto atomic_dst = cast(std::move(pointer_ty), dst); - if (mode == 0) { - bb.add(call_builtin(builtin_function::atomic_store_explicit, - {std::move(atomic_dst), std::move(value), memory_order::relaxed, - memory_scope::work_group})); - } else if (mode == 1) { - bb.add(call_builtin(builtin_function::atomic_fetch_add_explicit, - {std::move(atomic_dst), std::move(value), memory_order::relaxed, - memory_scope::work_group})); - } else { - auto expected = bb.declare_assign(to_clir_ty(ty), "expected", dereference(dst)); - auto desired = bb.declare(to_clir_ty(ty), "desired"); - auto cmpxchg = - call_builtin(builtin_function::atomic_compare_exchange_strong_explicit, - {std::move(atomic_dst), address_of(std::move(expected)), desired, - memory_order::relaxed, memory_order::relaxed, memory_scope::work_group}); - bb.add(while_loop_builder(std::move(cmpxchg), true) - .body([&](block_builder &bb) { bb.assign(desired, value + beta * expected); }) - .get_product()); - } -} - -void dispatch_constant_dynamic(expr e, std::function const &const_case, - std::function const &dyn_case) { - visit( - overloaded{ - [&](clir::internal::int_imm &c) { const_case(c.value()); }, - [&](clir::internal::uint_imm &c) { const_case(static_cast(c.value())); }, - [&](auto &) { dyn_case(std::move(e)); }, - }, - *e); -} - -void tile_loop_by_sgs(block_builder &bb, expr loop_trip_count, unsigned sgs, unsigned num_tiles, - var sg_id, sgs_loop_body_builder const &body) { - dispatch_constant_dynamic( - std::move(loop_trip_count), - [&](std::int64_t c) { - tile_loop_by_sgs_constant(bb, c, sgs, num_tiles, std::move(sg_id), body); - }, - [&](expr d) { - tile_loop_by_sgs_dynamic(bb, std::move(d), sgs, num_tiles, std::move(sg_id), body); +void tile_loop_by_sgs(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder const &body, attr for_attributes) { + auto ity = loop_trip_count->ty(); + auto bool_ty = boolean_data_type::get(ity->context()); + auto c_sgs = bb.add(make_constant(sgs, ity)); + auto c_sgs_tiles = bb.add(make_constant(sgs * num_tiles, ity)); + auto c0 = bb.add(make_constant(0, ity)); + auto c_tiles_1 = bb.add(make_constant(num_tiles - 1, ity)); + + auto blocks = + instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, c_sgs, ity)); + auto rem = + instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, c_sgs, ity)); + + auto sg_id_cast = instant_constant_fold_add(bb, make_cast(sg_id, ity)); + auto is_blocks_gt_0 = + instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, blocks, c0, bool_ty)); + bb.if_condition(is_blocks_gt_0, [&](region_builder &bb) { + auto block_start = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, sg_id_cast, ity)); + auto block_end = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, c_sgs, blocks, ity)); + bb.for_loop( + ity, std::move(block_start), std::move(block_end), c_sgs_tiles, + [&](region_builder &bb, value block) { body(bb, block, false, c_sgs); }, + for_attributes); + }); + + auto condition0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0, bool_ty)); + bb.if_condition(condition0, [&](region_builder &bb) { + auto condition1 = instant_constant_fold_add( + bb, make_cmp(cmp_condition::eq, sg_id_cast, c_tiles_1, bool_ty)); + bb.if_condition(condition1, [&](region_builder &bb) { + auto block = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks, c_sgs, ity)); + body(bb, block, true, rem); }); + }); } -void tile_loop_by_sgs_constant(block_builder &bb, unsigned loop_trip_count, unsigned sgs, - unsigned num_tiles, var sg_id, sgs_loop_body_builder const &body) { - auto blocks = loop_trip_count / sgs; - auto rem = loop_trip_count % sgs; - - auto block = bb.declare(generic_uint(), "blck"); - if (blocks > 0) { - bb.add(for_loop_builder(assignment(block, sgs * sg_id), block < sgs * blocks, - add_into(block, sgs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, false, sgs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - } - if (rem > 0) { - bb.assign(block, blocks * sgs); - bb.add(if_selection_builder(sg_id == num_tiles - 1u) - .then([&](block_builder &bb) { body(bb, block, true, rem); }) - .get_product()); - } -} - -void tile_loop_by_sgs_dynamic(block_builder &bb, expr loop_trip_count, unsigned sgs, - unsigned num_tiles, var sg_id, sgs_loop_body_builder const &body) { - auto blocks = bb.declare_assign(generic_uint(), "blocks", loop_trip_count / sgs); - auto rem = bb.declare_assign(generic_uint(), "rem", std::move(loop_trip_count) % sgs); - - auto block = bb.declare(generic_uint(), "blck"); - bb.add(for_loop_builder(assignment(block, sgs * sg_id), block < sgs * blocks, - add_into(block, sgs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, false, sgs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - - bb.add(if_selection_builder(rem > 0) - .then([&](block_builder &bb) { - bb.assign(block, blocks * sgs); - bb.add(if_selection_builder(sg_id == num_tiles - 1u) - .then([&](block_builder &bb) { body(bb, block, true, rem); }) - .get_product()); - }) - .get_product()); -} - -unsigned tile_loop_uniformly_max_block_size(unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles) { - auto blocks = 1 + (loop_trip_count - 1) / block_size; - blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; - auto bs = loop_trip_count / blocks; - auto rem = loop_trip_count % blocks; - return rem > 0 ? bs + 1 : bs; -} - -void tile_loop_uniformly(block_builder &bb, expr loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, uniform_loop_body_builder const &body) { - dispatch_constant_dynamic( - std::move(loop_trip_count), - [&](std::int64_t c) { - tile_loop_uniformly_constant(bb, c, block_size, num_tiles, std::move(sg_id), body); - }, - [&](expr d) { - tile_loop_uniformly_dynamic(bb, std::move(d), block_size, num_tiles, std::move(sg_id), - body); - }); +void tile_loop_uniformly(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, + value sg_id, uniform_loop_body_builder const &body, attr for_attributes) { + auto ity = loop_trip_count->ty(); + auto bool_ty = boolean_data_type::get(ity->context()); + auto c0 = bb.add(make_constant(0, ity)); + auto c1 = bb.add(make_constant(1, ity)); + auto c_tiles = bb.add(make_constant(num_tiles, ity)); + + // Here we compute + // blocks = ceil(loop_trip_count / block_size) = 1 + (loop_trip_count - 1) / block_size + // blocks = ceil(blocks / num_tiles) * num_tiles = (1 + (blocks - 1) / num_tiles) * + // num_tiles + auto c_block_size = bb.add(make_constant(block_size, ity)); + auto blocks0 = + instant_constant_fold_add(bb, make_arith(arithmetic::sub, loop_trip_count, c1, ity)); + auto blocks1 = + instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks0, c_block_size, ity)); + auto blocks2 = + instant_constant_fold_add(bb, make_arith(arithmetic::div, blocks1, c_tiles, ity)); + auto blocks3 = instant_constant_fold_add(bb, make_arith(arithmetic::add, c1, blocks2, ity)); + auto blocks = instant_constant_fold_add(bb, make_arith(arithmetic::mul, blocks3, c_tiles, ity)); + + auto bs = + instant_constant_fold_add(bb, make_arith(arithmetic::div, loop_trip_count, blocks, ity)); + auto bs_1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, bs, c1, ity)); + auto rem = + instant_constant_fold_add(bb, make_arith(arithmetic::rem, loop_trip_count, blocks, ity)); + + auto sg_id_cast = instant_constant_fold_add(bb, make_cast(sg_id, ity)); + // The following if makes it easy to eliminate the remainder handler in optimization if rem + // == 0 is known at compile time. Without the if, we would need to prove that block_start_1 + // is non-negative to eliminate the for-loop. + auto is_rem_gt_0 = instant_constant_fold_add(bb, make_cmp(cmp_condition::gt, rem, c0, bool_ty)); + bb.if_condition(is_rem_gt_0, [&](region_builder &bb) { + auto block_start_1 = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, sg_id_cast, ity)); + auto block_end_1 = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem, ity)); + auto step_1 = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, c_tiles, ity)); + bb.for_loop( + ity, std::move(block_start_1), std::move(block_end_1), std::move(step_1), + [&](region_builder &bb, value block) { body(bb, block, bs_1); }, for_attributes); + }); + + auto tmp0 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, rem, c_tiles, ity)); + auto tmp1 = instant_constant_fold_add(bb, make_arith(arithmetic::add, sg_id_cast, tmp0, ity)); + auto sg_id_1 = instant_constant_fold_add(bb, make_arith(arithmetic::rem, tmp1, c_tiles, ity)); + auto tmp2 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, sg_id_1, ity)); + auto tmp3 = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs_1, rem, ity)); + auto block_start = instant_constant_fold_add(bb, make_arith(arithmetic::add, tmp3, tmp2, ity)); + auto step = instant_constant_fold_add(bb, make_arith(arithmetic::mul, bs, c_tiles, ity)); + bb.for_loop( + ity, std::move(block_start), loop_trip_count, std::move(step), + [&](region_builder &bb, value block) { body(bb, block, bs); }, for_attributes); } -void tile_loop_uniformly_constant(block_builder &bb, unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, - uniform_loop_body_builder const &body) { - // Find minimum number of blocks such that the block sizes are smaller or equal block_size - auto blocks = 1 + (loop_trip_count - 1) / block_size; - // Increase the number of blocks if such that the number of blocks is a multiple - // of the number of tiles - blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; - auto bs = loop_trip_count / blocks; - auto bs_1 = bs + 1; - auto rem = loop_trip_count % blocks; - - auto block = bb.declare(generic_uint(), "blck"); - if (rem > 0) { - bb.add(for_loop_builder(assignment(block, bs_1 * sg_id), block < bs_1 * rem, - add_into(block, bs_1 * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs_1); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - } - - auto sg_id_1 = (std::move(sg_id) + rem % num_tiles) % num_tiles; - bb.add(for_loop_builder(assignment(block, bs_1 * rem + bs * std::move(sg_id_1)), - block < loop_trip_count, add_into(block, bs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); -} - -void tile_loop_uniformly_dynamic(block_builder &bb, expr loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, - uniform_loop_body_builder const &body) { - auto blocks = - bb.declare_assign(generic_uint(), "blocks", 1 + (loop_trip_count - 1) / block_size); - bb.assign(blocks, (1 + (blocks - 1) / num_tiles) * num_tiles); - auto bs = bb.declare_assign(generic_uint(), "bs", loop_trip_count / blocks); - auto bs_1 = bb.declare_assign(generic_uint(), "bs_1", bs + 1); - auto rem = bb.declare_assign(generic_uint(), "rem", loop_trip_count % blocks); - - auto block = bb.declare(generic_uint(), "blck"); - bb.add(for_loop_builder(assignment(block, bs_1 * sg_id), block < bs_1 * rem, - add_into(block, bs_1 * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs_1); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - - auto sg_id_1 = (std::move(sg_id) + rem % num_tiles) % num_tiles; - bb.add(for_loop_builder(assignment(block, bs_1 * rem + bs * std::move(sg_id_1)), - block < std::move(loop_trip_count), add_into(block, bs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); +auto mixed_precision_arithmetic(region_builder &bb, scalar_type result_ty, arithmetic operation, + value a, value b, location const &loc) -> value { + scalar_data_type *at = dyn_cast(a->ty()); + scalar_data_type *bt = dyn_cast(b->ty()); + if (at == nullptr || bt == nullptr) { + throw compilation_error(loc, status::ir_expected_scalar); + } + if (at->ty() != result_ty || bt->ty() != result_ty) { + if (!promotable(at->ty(), result_ty) || !promotable(bt->ty(), result_ty)) { + throw compilation_error(loc, status::ir_forbidden_promotion); + } + auto promoted_ty = scalar_data_type::get(at->context(), result_ty); + + if (at->ty() != result_ty) { + a = bb.add(make_cast(a, promoted_ty, loc)); + } + if (bt->ty() != result_ty) { + b = bb.add(make_cast(b, promoted_ty, loc)); + } + } + return bb.add( + make_arith(operation, a, b, scalar_data_type::get(at->context(), result_ty), loc)); +} +auto mixed_precision_coopmatrix_scale(region_builder &bb, value a, value b, location const &loc) + -> value { + scalar_data_type *at = dyn_cast(a->ty()); + if (at == nullptr) { + throw compilation_error(loc, status::ir_expected_scalar); + } + coopmatrix_data_type *bt = dyn_cast(b->ty()); + if (bt == nullptr) { + throw compilation_error(loc, status::ir_expected_coopmatrix); + } + const auto a_ty = at->ty(); + const auto b_ty = bt->component_ty(); + if (a_ty != b_ty) { + if (!promotable(a_ty, b_ty)) { + throw compilation_error(loc, status::ir_forbidden_promotion); + } + a = bb.add(make_cast(a, bt->ty(), loc)); + } + return bb.add(make_cooperative_matrix_scale(a, b, bt, loc)); +} + +auto get_atomic_store_flag(value beta) -> std::optional { + constant_inst *beta_cst = dyn_cast(beta->defining_inst()); + if (beta_cst) { + if (beta_cst->is_zero()) { + return store_flag::atomic; + } else if (beta_cst->is_identity()) { + return store_flag::atomic_add; + } + } + return std::nullopt; +} +void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value beta, value C, + array_view index_list, location const &loc) { + memref_data_type *ct = dyn_cast(C->ty()); + if (ct == nullptr) { + throw compilation_error(loc, {C.get()}, status::ir_expected_scalar); + } + auto alpha_ab = + mixed_precision_arithmetic(bb, ct->element_ty(), arithmetic::mul, alpha, ab, loc); + if (atomic) { + auto flag = get_atomic_store_flag(beta); + if (!flag) { + throw compilation_error(loc, status::ir_invalid_beta); + } + bb.add(make_store(*flag, alpha_ab, C, index_list, loc)); + } else { + auto c = bb.add(make_load(C, index_list, ct->element_data_ty(), loc)); + auto beta_c = + mixed_precision_arithmetic(bb, ct->element_ty(), arithmetic::mul, beta, c, loc); + auto alpha_ab_plus_beta_c = mixed_precision_arithmetic( + bb, ct->element_ty(), arithmetic::add, alpha_ab, beta_c, loc); + bb.add(make_store(store_flag::regular, alpha_ab_plus_beta_c, C, index_list, loc)); + } +} + +auto instant_constant_fold_add(region_builder &bb, inst i) -> value { + auto ctx = i->context(); + if (!ctx) { + throw compilation_error(i->loc(), status::internal_compiler_error); + } + + auto fold = visit(constant_folding{ctx->opt_flag(optflag::unsafe_fp_math)}, *i); + auto val = std::visit(overloaded{[](tinytc_value_t &v) -> tinytc_value_t { return v; }, + [&bb](inst &j) -> tinytc_value_t { + if (j) { + return bb.add(std::move(j)); + } + return nullptr; + }}, + fold); + if (val) { + return val; + } + return bb.add(std::move(i)); +} + +auto get_bool_constant(tinytc_value_t val) -> std::optional { + if (auto i = val->defining_inst(); i) { + if (auto *ci = dyn_cast(i); ci) { + if (std::holds_alternative(ci->value())) { + return std::get(ci->value()); + } + } + } + return std::nullopt; +} + +auto get_int_constant(const_tinytc_value_t val) -> std::optional { + if (auto i = val->defining_inst(); i) { + if (auto *ci = dyn_cast(i); ci) { + if (std::holds_alternative(ci->value())) { + return std::get(ci->value()); + } + } + } + return std::nullopt; +} + +auto get_int_constant(tinytc_value const &val) -> std::optional { + return get_int_constant(&val); +} + +auto get_coopmatrix_type(tinytc_value const &v) -> coopmatrix_data_type const * { + auto ct = dyn_cast(v.ty()); + if (!ct) { + throw compilation_error(v.loc(), status::ir_expected_coopmatrix); + } + return ct; +} +auto get_memref_type(tinytc_value const &v) -> memref_data_type const * { + auto mt = dyn_cast(v.ty()); + if (!mt) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return mt; +} +auto get_scalar_type(tinytc_value const &v) -> scalar_type { + auto st = dyn_cast(v.ty()); + if (!st) { + throw compilation_error(v.loc(), status::ir_expected_scalar); + } + return st->ty(); +} +auto get_yield(location const &loc, tinytc_region const ®) -> yield_inst const * { + const yield_inst *y = nullptr; + if (auto it = reg.end(); --it != reg.end()) { + y = dyn_cast(it.get()); + } + if (!y) { + throw compilation_error(loc, status::ir_must_have_yield); + } + return y; +} + +auto add_check(checked_flag flag, checked_flag new_flag) -> checked_flag { + return checked_flag{std::underlying_type_t(flag) | + std::underlying_type_t(new_flag)}; +} + +work_group_op::work_group_op(std::int32_t num_tiles, std::int32_t subgroup_size, data_type ty) + : num_tiles_{num_tiles}, subgroup_size_{subgroup_size}, ty_{ty}, tmp_{nullptr} {} + +void work_group_op::setup(region_builder &bb, location const &loc) { + if (num_tiles_ > 1) { + auto tmp_ty = get_memref(ty_, {num_tiles_}, {}, address_space::local, loc); + tmp_ = bb.add(make_alloca(tmp_ty, loc)); + } +} + +void work_group_op::teardown(region_builder &bb) { + if (tmp_) { + bb.add(inst{std::make_unique(tmp_).release()}); + } +} + +auto work_group_reduce::make(region_builder &bb, value a, location const &loc) -> value { + auto a_reduced = bb.add( + make_subgroup_operation(group_arithmetic::add, group_operation::reduce, a, ty_, loc)); + + if (num_tiles_ > 1) { + auto ctx = compiler_context{a->context(), true}; + auto bool_ty = get_boolean(ctx); + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto index_ty = get_scalar(ctx, scalar_type::index); + + auto sgid = bb.add(make_builtin(builtin::subgroup_linear_id, i32_ty, loc)); + auto sglid = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, loc)); + auto c_zero = bb.add(make_constant_zero(i32_ty, loc)); + auto is_sglid_0 = bb.add(make_cmp(cmp_condition::eq, sglid, c_zero, bool_ty, loc)); + bb.if_condition( + is_sglid_0, + [&](region_builder &bb) { + auto sgid_index = bb.add(make_cast(sgid, index_ty, loc)); + bb.add(make_store(store_flag::regular, a_reduced, tmp_, {sgid_index}, loc)); + }, + loc); + bb.add(make_barrier(static_cast(address_space::local), loc)); + + auto is_lid_0 = bb.add(make_cmp(cmp_condition::eq, sgid, c_zero, bool_ty, loc)); + bb.if_condition( + is_lid_0, + [&](region_builder &bb) { + auto c_num_tiles = bb.add(make_constant(num_tiles_, i32_ty, loc)); + auto c_sgs = bb.add(make_constant(subgroup_size_, i32_ty, loc)); + auto c_init = bb.add(make_constant_zero(ty_, loc)); + auto acc = bb.for_loop( + i32_ty, sglid, c_num_tiles, c_sgs, {c_init}, {ty_}, + [&](region_builder &bb, array_view args) { + auto lv_index = bb.add(make_cast(args[0], index_ty, loc)); + auto a_sg_reduced = bb.add(make_load(tmp_, {lv_index}, ty_, loc)); + auto sum = + bb.add(make_arith(arithmetic::add, args[1], a_sg_reduced, ty_, loc)); + bb.add(make_yield({sum}, loc)); + }); + a_reduced = bb.add(make_subgroup_operation( + group_arithmetic::add, group_operation::reduce, acc[0], ty_, loc)); + return a_reduced; + }, + loc); + } + return a_reduced; +} + +auto work_group_inclusive_scan::make(region_builder &bb, value a, bool compute_sum, + location const &loc) -> std::pair { + auto a_scan = bb.add(make_subgroup_operation(group_arithmetic::add, + group_operation::inclusive_scan, a, ty_, loc)); + + auto ctx = compiler_context{a->context(), true}; + auto i32_ty = get_scalar(ctx, scalar_type::i32); + + if (num_tiles_ > 1) { + auto bool_ty = get_boolean(ctx); + auto index_ty = get_scalar(ctx, scalar_type::index); + + auto sgid = bb.add(make_builtin(builtin::subgroup_linear_id, i32_ty, loc)); + auto sglid = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, loc)); + + auto c_sgs_1 = bb.add(make_constant(subgroup_size_ - 1, i32_ty, loc)); + auto is_last_sglid = bb.add(make_cmp(cmp_condition::eq, sglid, c_sgs_1, bool_ty, loc)); + bb.if_condition( + is_last_sglid, + [&](region_builder &bb) { + auto sgid_index = bb.add(make_cast(sgid, index_ty, loc)); + bb.add(make_store(store_flag::regular, a_scan, tmp_, {sgid_index}, loc)); + }, + loc); + bb.add(make_barrier(static_cast(address_space::local), loc)); + + auto c_zero = bb.add(make_constant_zero(i32_ty, loc)); + a_scan = bb.for_loop(i32_ty, c_zero, sgid, nullptr, {a_scan}, {ty_}, + [&](region_builder &bb, array_view args) { + auto lv_index = bb.add(make_cast(args[0], index_ty, loc)); + auto prefix = bb.add(make_load(tmp_, {lv_index}, ty_, loc)); + auto scan = + bb.add(make_arith(arithmetic::add, args[1], prefix, ty_, loc)); + bb.add(make_yield({scan}, loc)); + })[0]; + + if (compute_sum) { + auto c_num_tiles_1 = bb.add(make_constant(num_tiles_ - 1, i32_ty, loc)); + auto c_num_tiles_1_index = bb.add(make_cast(c_num_tiles_1, index_ty, loc)); + auto is_last_sgid = + bb.add(make_cmp(cmp_condition::eq, sgid, c_num_tiles_1, bool_ty, loc)); + auto is_last_work_item = + bb.add(make_arith(arithmetic::and_, is_last_sglid, is_last_sgid, bool_ty, loc)); + bb.if_condition( + is_last_work_item, + [&](region_builder &bb) { + bb.add( + make_store(store_flag::regular, a_scan, tmp_, {c_num_tiles_1_index}, loc)); + }, + loc); + bb.add(make_barrier(static_cast(address_space::local), loc)); + auto sum = bb.add(make_load(tmp_, {c_num_tiles_1_index}, ty_, loc)); + return {a_scan, sum}; + } + } else if (compute_sum) { + auto c_sgs_1 = bb.add(make_constant(subgroup_size_ - 1, i32_ty, loc)); + auto sum = bb.add(make_subgroup_broadcast(a_scan, c_sgs_1, ty_, loc)); + return {a_scan, sum}; + } + return {a_scan, nullptr}; } } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 4825b0a5..d27970fd 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -4,53 +4,100 @@ #ifndef CODEGEN_TOOLS_20240229_HPP #define CODEGEN_TOOLS_20240229_HPP +#include "device_info.hpp" +#include "node/function_node.hpp" +#include "tiling.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" #include #include - -#include -#include -#include -#include +#include +#include +#include namespace tinytc { -clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); - -void store_helper(clir::block_builder &bb, bool is_atomic, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, clir::expr beta); -void atomic_store_helper(clir::block_builder &bb, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, clir::expr beta); - -void dispatch_constant_dynamic(clir::expr e, std::function const &const_case, - std::function const &dyn_case); - -using sgs_loop_body_builder = - std::function; -using uniform_loop_body_builder = - std::function; - -void tile_loop_by_sgs(clir::block_builder &bb, clir::expr loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, sgs_loop_body_builder const &body); -void tile_loop_by_sgs_constant(clir::block_builder &bb, unsigned loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, - sgs_loop_body_builder const &body); -void tile_loop_by_sgs_dynamic(clir::block_builder &bb, clir::expr loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, - sgs_loop_body_builder const &body); - -unsigned tile_loop_uniformly_max_block_size(unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles); -void tile_loop_uniformly(clir::block_builder &bb, clir::expr loop_trip_count, unsigned block_size, - unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); -void tile_loop_uniformly_constant(clir::block_builder &bb, unsigned loop_trip_count, - unsigned block_size, unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); -void tile_loop_uniformly_dynamic(clir::block_builder &bb, clir::expr loop_trip_count, - unsigned block_size, unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); +class coopmatrix_data_type; +class memref_data_type; +class yield_inst; + +auto get_core_config_and_tiling(function_node const &fn, const_tinytc_core_info_t info) + -> std::pair; + +using sgs_loop_body_builder = std::function; +using uniform_loop_body_builder = std::function; + +void tile_loop_by_sgs(region_builder &bb, value loop_trip_count, int sgs, int num_tiles, + value sg_id, sgs_loop_body_builder const &body, + attr for_attributes = nullptr); + +void tile_loop_uniformly(region_builder &bb, value loop_trip_count, int block_size, int num_tiles, + value sg_id, uniform_loop_body_builder const &body, + attr for_attributes = nullptr); + +auto mixed_precision_arithmetic(region_builder &bb, scalar_type result_ty, arithmetic operation, + value a, value b, location const &loc) -> value; +auto mixed_precision_coopmatrix_scale(region_builder &bb, value a, value b, location const &loc) + -> value; + +auto get_atomic_store_flag(value beta) -> std::optional; +void blas_update(region_builder &bb, bool atomic, value alpha, value ab, value beta, value C, + array_view index_list, location const &loc); + +auto instant_constant_fold_add(region_builder &bb, inst i) -> value; +auto get_bool_constant(tinytc_value_t val) -> std::optional; +auto get_int_constant(const_tinytc_value_t val) -> std::optional; +auto get_int_constant(tinytc_value const &val) -> std::optional; +auto get_coopmatrix_type(tinytc_value const &v) -> coopmatrix_data_type const *; +auto get_memref_type(tinytc_value const &v) -> memref_data_type const *; +auto get_scalar_type(tinytc_value const &v) -> scalar_type; +auto get_yield(location const &loc, tinytc_region const ®) -> yield_inst const *; + +template auto get_int_constants(T &&val_range) -> std::vector { + auto result = std::vector{}; + result.reserve(val_range.size()); + for (auto &val : val_range) { + const auto cst = get_int_constant(val); + result.emplace_back(cst ? *cst : dynamic); + } + return result; +} + +auto add_check(checked_flag flag, checked_flag new_flag) -> checked_flag; + +class work_group_op { + public: + work_group_op(std::int32_t num_tiles, std::int32_t subgroup_size, data_type ty); + + void setup(region_builder &bb, location const &loc); + void teardown(region_builder &bb); + + inline auto num_tiles() const -> std::int32_t { return num_tiles_; } + inline auto subgroup_size() const -> std::int32_t { return subgroup_size_; } + inline auto ty() const -> data_type { return ty_; } + + protected: + std::int32_t num_tiles_, subgroup_size_; + data_type ty_; + value tmp_; +}; + +class work_group_reduce : public work_group_op { + public: + using work_group_op::work_group_op; + + auto make(region_builder &bb, value a, location const &loc) -> value; +}; + +class work_group_inclusive_scan : public work_group_op { + public: + using work_group_op::work_group_op; + + auto make(region_builder &bb, value a, bool compute_sum, location const &loc) + -> std::pair; +}; } // namespace tinytc diff --git a/src/compiler.cpp b/src/compiler.cpp index 8493d738..b726ae91 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -1,59 +1,166 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "device_info.hpp" +#include "compiler_context.hpp" #include "error.hpp" #include "node/program_node.hpp" -#include "parser.hpp" +// IWYU pragma: begin_keep +#include "pass/dump_cfg.hpp" +#include "pass/dump_def_use.hpp" +#include "pass/dump_gcd.hpp" +#include "pass/dump_ir.hpp" +// IWYU pragma: end_keep +#include "pass/check_ir.hpp" +#include "pass/constant_propagation.hpp" +#include "pass/convert_to_spirv.hpp" +#include "pass/dead_code_elimination.hpp" +#include "pass/insert_barrier.hpp" +#include "pass/insert_lifetime_stop.hpp" +#include "pass/lower_coopmatrix.hpp" +#include "pass/lower_foreach.hpp" +#include "pass/lower_linalg.hpp" +#include "pass/stack.hpp" +#include "pass/work_group_size.hpp" #include "passes.hpp" -#include "reference_counted.hpp" -#include "required_extensions.hpp" -#include "source.hpp" +#include "spv/pass/assemble.hpp" +#include "spv/pass/assign_ids.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include "tinytc/types.hpp" -#include -#include - -#include -#include +#include +#include +#include // IWYU pragma: keep #include -#include using namespace tinytc; +namespace tinytc { + +template struct optflag_setter { + PassT &pass; + tinytc_compiler_context_t ctx; + + template void operator()(Flags &&...flags) { + (pass.set_opt_flag(flags, ctx->opt_flag(flags)), ...); + } +}; + +void apply_default_optimization_pipeline(tinytc_prog_t prg, const_tinytc_core_info_t info) { + auto ctx = prg->context(); + const auto opt_level = ctx->opt_level(); + + // passes + auto cpp = constant_propagation_pass{}; + optflag_setter{cpp, ctx}(tinytc::optflag::unsafe_fp_math); + + run_function_pass(check_ir_pass{}, *prg); + + if (opt_level >= 1) { + // We run constant propagation + dead code elimination early to capture dead allocas + // (later on they are maybe "in use" due to the lifetime_stop instruction) + run_function_pass(cpp, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } + + run_function_pass(insert_lifetime_stop_pass{}, *prg); + run_function_pass(set_stack_ptr_pass{}, *prg); + run_function_pass(insert_barrier_pass{}, *prg); + run_function_pass(work_group_size_pass{info}, *prg); + + run_function_pass(lower_linalg_pass{info}, *prg); + // Run set stack ptr again as lower linalg may introduce allocas for the duration of the + // linalg op. Lower linalg is expected to insert lifetime_stop instructions, after it is done + // so we do not need to run the lifetime stop pass again. + run_function_pass(set_stack_ptr_pass{}, *prg); + run_function_pass(lower_foreach_pass{info}, *prg); + if (opt_level >= 1) { + run_function_pass(cpp, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } + run_function_pass(lower_coopmatrix_pass{info}, *prg); + + run_function_pass(check_ir_pass{}, *prg); +} + +} // namespace tinytc + extern "C" { -tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx) { - if (src == nullptr || prg == nullptr || info == nullptr) { +tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (prg == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( [&] { - // passes - check_ir(*prg); - insert_lifetime_stop_inst(*prg); - set_stack_ptrs(*prg); - insert_barriers(*prg); - set_work_group_size(*prg, *info); - // opencl - auto ast = generate_opencl_ast(*prg, *info); - clir::make_names_unique(ast); - - auto oss = std::ostringstream{}; - auto ext = required_extensions(ast); - for (auto const &e : ext) { - oss << "#pragma OPENCL EXTENSION " << e << " : enable" << std::endl; - } - - clir::generate_opencl(oss, std::move(ast)); - - *src = std::make_unique<::tinytc_source>(oss.str(), prg->loc(), std::move(ext), - info->core_features()) - .release(); +#define FUNCTION_PASS(NAME, CREATE_PASS, ...) \ + if (strcmp(NAME, pass_name) == 0) { \ + auto pass = CREATE_PASS; \ + optflag_setter{pass, prg->context()}(__VA_ARGS__); \ + return run_function_pass(std::move(pass), *prg); \ + } +#define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) \ + if (strcmp(NAME, pass_name) == 0) { \ + return run_function_pass(CREATE_PASS(info), *prg); \ + } +#include "passes.def" +#undef FUNCTION_PASS +#undef FUNCTION_PASS_WITH_INFO + throw status::unknown_pass_name; }, - ctx); + prg->context()); +} + +tinytc_status_t tinytc_list_function_passes(uint32_t *names_size, char const *const **names) { + if (names_size == nullptr || names == nullptr) { + return tinytc_status_invalid_arguments; + } +#define FUNCTION_PASS(NAME, CREATE_PASS, ...) NAME, +#define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) NAME, + static char const *const pass_names[] = { +#include "passes.def" + }; +#undef FUNCTION_PASS +#undef FUNCTION_PASS_WITH_INFO + *names_size = sizeof(pass_names) / sizeof(char const *); + *names = pass_names; + + return tinytc_status_success; +} + +tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_spv_mod_t *mod, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (mod == nullptr || prg == nullptr || info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { + apply_default_optimization_pipeline(prg, info); + + *mod = convert_to_spirv_pass{info}.run_on_program(*prg).release(); + spv::id_assigner{}.run_on_module(**mod); + }, + prg->context()); +} + +tinytc_status_t tinytc_prog_compile_to_spirv_and_assemble(tinytc_binary_t *bin, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (bin == nullptr || prg == nullptr || info == nullptr) { + return tinytc_status_invalid_arguments; + } + tinytc_spv_mod_t mod; + TINYTC_CHECK_STATUS(tinytc_prog_compile_to_spirv(&mod, prg, info)); + auto mod_ = spv_mod{mod}; // For clean-up + TINYTC_CHECK_STATUS(tinytc_spirv_assemble(bin, mod_.get())); + return tinytc_status_success; +} + +tinytc_status_t tinytc_spirv_assemble(tinytc_binary_t *bin, const_tinytc_spv_mod_t mod) { + if (bin == nullptr || mod == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *bin = spv::assembler{}.run_on_module(*mod).release(); }); } } diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp new file mode 100644 index 00000000..6c23efdc --- /dev/null +++ b/src/compiler_context.cpp @@ -0,0 +1,142 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context.hpp" +#include "compiler_context_cache.hpp" +#include "error.hpp" +#include "node/value_node.hpp" +#include "tinytc/tinytc.h" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { +void default_error_reporter(char const *, const tinytc_location_t *, void *) {} +} // namespace tinytc + +using namespace tinytc; + +extern "C" { + +tinytc_compiler_context::tinytc_compiler_context() + : cache_{std::make_unique(this)} { + opt_flags_.fill(-1); +} + +auto tinytc_compiler_context::source_name(std::int32_t source_id) + -> std::pair { + if (has_source_id(source_id)) { + auto &si = sources_[source_id - 1]; + return {si.name.c_str(), si.name.size()}; + } + return {unavailable_source_name, sizeof(unavailable_source_name) / sizeof(char) - 1}; +} +auto tinytc_compiler_context::source_text(std::int32_t source_id) + -> std::pair { + if (has_source_id(source_id)) { + auto &si = sources_[source_id - 1]; + return {si.text.c_str(), si.text.size()}; + } + return {"", 0}; +} +void tinytc_compiler_context::report_error(location const &l, char const *what) { + report_error(l, {}, what); +} + +void tinytc_compiler_context::report_error(tinytc_location const &l, + array_view const &ref_values, + char const *what) { + auto [name, name_size] = source_name(l.begin.source_id); + auto [text, text_size] = source_text(l.begin.source_id); + auto err = report_error_with_context(text, text_size, name, l, what); + reporter_(err.c_str(), &l, user_data_); + for (auto &ref_value : ref_values) { + if (ref_value) { + auto err = report_error_with_context(text, text_size, name, ref_value->loc(), + "value defined here"); + reporter_(err.c_str(), &ref_value->loc(), user_data_); + } + } +} + +auto tinytc_compiler_context::opt_flag(tinytc_optflag_t flag) const -> bool { + const auto state = opt_flags_[flag]; + if (state >= 0) { + return state > 0; + } + const auto clamped_opt_level = std::min(2, std::max(0, opt_level_)); + return default_opt_flags[clamped_opt_level][flag]; +} + +tinytc_status_t tinytc_compiler_context_create(tinytc_compiler_context_t *ctx) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *ctx = std::make_unique().release(); }); +} + +tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler_context_t ctx, char const *name, + char const *text, int32_t *source_id) { + if (ctx == nullptr || name == nullptr || text == nullptr || source_id == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *source_id = ctx->add_source(name, text); }); +} + +tinytc_status_t tinytc_compiler_context_set_error_reporter(tinytc_compiler_context_t ctx, + tinytc_error_reporter_t reporter, + void *user_data) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->set_error_reporter(reporter, user_data); }); +} + +tinytc_status_t tinytc_compiler_context_set_optimization_flag(tinytc_compiler_context_t ctx, + tinytc_optflag_t flag, + int32_t state) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->opt_flag(flag, state); }); +} + +tinytc_status_t tinytc_compiler_context_set_optimization_level(tinytc_compiler_context_t ctx, + int32_t level) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + ctx->opt_level(level); + return tinytc_status_success; +} + +tinytc_status_t tinytc_compiler_context_report_error(tinytc_compiler_context_t ctx, + const tinytc_location_t *location, + char const *what) { + if (ctx == nullptr || location == nullptr || what == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->report_error(*location, what); }); +} + +tinytc_status_t tinytc_compiler_context_release(tinytc_compiler_context_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + auto ref_count = obj->dec_ref(); + if (ref_count == 0) { + delete obj; + } + return tinytc_status_success; +} + +tinytc_status_t tinytc_compiler_context_retain(tinytc_compiler_context_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + obj->inc_ref(); + return tinytc_status_success; +} +} diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp new file mode 100644 index 00000000..8ebfd922 --- /dev/null +++ b/src/compiler_context.hpp @@ -0,0 +1,84 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COMPILER_CONTEXT_20240924_HPP +#define COMPILER_CONTEXT_20240924_HPP + +#include "compiler_context_cache.hpp" +#include "reference_counted.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { +void default_error_reporter(char const *what, const tinytc_location_t *location, void *user_data); +} // namespace tinytc + +struct tinytc_compiler_context : tinytc::reference_counted { + public: + constexpr static const char unavailable_source_name[] = "Source name unavailable"; + constexpr static std::array, 3u> default_opt_flags = + {{{false}, {false}, {true}}}; + + tinytc_compiler_context(); + + inline auto cache() -> tinytc::compiler_context_cache * { return cache_.get(); } + + inline void set_error_reporter(tinytc::error_reporter_t reporter, void *user_data) { + reporter_ = reporter; + user_data_ = user_data; + } + + // source / error handling + inline auto add_source(std::string name, std::string text) -> std::int32_t { + sources_.emplace_back(source_input{std::move(name), std::move(text)}); + return static_cast(sources_.size()); + } + inline auto add_source(char const *name, char const *text) -> std::int32_t { + sources_.emplace_back(source_input{std::string(name), std::string(text)}); + return static_cast(sources_.size()); + } + auto source_name(std::int32_t source_id) -> std::pair; + auto source_text(std::int32_t source_id) -> std::pair; + void report_error(tinytc_location const &l, char const *what); + void report_error(tinytc_location const &l, + tinytc::array_view const &ref_values, char const *what); + + auto opt_flag(tinytc_optflag_t flag) const -> bool; + inline void opt_flag(tinytc_optflag_t flag, std::int32_t state) { opt_flags_[flag] = state; } + inline auto opt_flag(tinytc::optflag flag) const -> bool { + return opt_flag(static_cast(flag)); + } + inline void opt_flag(tinytc::optflag flag, std::int32_t state) { + opt_flag(static_cast(flag), state); + } + + inline auto opt_level() const noexcept -> std::int32_t { return opt_level_; } + inline void opt_level(std::int32_t level) noexcept { opt_level_ = level; } + + private: + struct source_input { + std::string name, text; + }; + + inline bool has_source_id(std::int32_t source_id) const { + return source_id >= 1 && static_cast(source_id) <= sources_.size(); + } + + std::unique_ptr cache_; + tinytc::error_reporter_t reporter_ = &tinytc::default_error_reporter; + void *user_data_ = nullptr; + std::vector sources_; + std::array opt_flags_; + std::int32_t opt_level_ = 2; +}; + +#endif // COMPILER_CONTEXT_20240924_HPP diff --git a/src/compiler_context_cache.cpp b/src/compiler_context_cache.cpp new file mode 100644 index 00000000..d7e13610 --- /dev/null +++ b/src/compiler_context_cache.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context_cache.hpp" +#include "node/attr_node.hpp" +#include "node/data_type_node.hpp" +#include "util/casting.hpp" + +namespace tinytc { + +enum class scalar_type; + +compiler_context_cache::compiler_context_cache(tinytc_compiler_context_t ctx) { + bool_ty = std::unique_ptr(new boolean_data_type(ctx)); + void_ty = std::unique_ptr(new void_data_type(ctx)); + for (int i = 0; i < TINYTC_NUMBER_OF_SCALAR_TYPES; ++i) { + scalar_tys[i] = + std::unique_ptr(new scalar_data_type(ctx, enum_cast(i))); + } + + false_attr = std::unique_ptr(new boolean_attr(ctx, false)); + true_attr = std::unique_ptr(new boolean_attr(ctx, true)); +} + +compiler_context_cache::~compiler_context_cache() {} + +} // namespace tinytc + diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp new file mode 100644 index 00000000..5a18c601 --- /dev/null +++ b/src/compiler_context_cache.hpp @@ -0,0 +1,71 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COMPILER_CONTEXT_CACHE_20240925_HPP +#define COMPILER_CONTEXT_CACHE_20240925_HPP + +#include "tinytc/types.h" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace std { +template <> class hash> { + public: + auto operator()(std::pair const &key) const -> std::size_t { + return tinytc::fnv1a_combine(key.first, key.second); + } +}; +} // namespace std + +namespace tinytc { + +template class unique_storage { + public: + ~unique_storage() { + for (auto &m : map_) { + delete m.second; + } + } + + template + auto get(std::uint64_t hash, EqualFun &&is_equal, MakeFun &&make) -> T { + auto range = map_.equal_range(hash); + for (auto it = range.first; it != range.second; ++it) { + if (is_equal(it->second)) { + return it->second; + } + } + + return map_.emplace(hash, make())->second; + } + + private: + std::unordered_multimap map_; +}; + +class compiler_context_cache { + public: + compiler_context_cache(tinytc_compiler_context_t ctx); + ~compiler_context_cache(); + + compiler_context_cache(compiler_context_cache const &) = delete; + compiler_context_cache &operator=(compiler_context_cache const &) = delete; + + std::unique_ptr void_ty, bool_ty; + std::array, TINYTC_NUMBER_OF_SCALAR_TYPES> scalar_tys; + unique_storage coopmatrix_tys, group_tys, memref_tys; + + unique_storage array_attrs, dictionary_attrs, integer_attrs, string_attrs; + std::unique_ptr false_attr, true_attr; +}; + +} // namespace tinytc + +#endif // COMPILER_CONTEXT_CACHE_20240925_HPP diff --git a/src/coopmatrix_layout.cpp b/src/coopmatrix_layout.cpp new file mode 100644 index 00000000..bb007076 --- /dev/null +++ b/src/coopmatrix_layout.cpp @@ -0,0 +1,38 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "node/data_type_node.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { + +auto get_layout(core_config const &cfg, coopmatrix_data_type const *ct) -> coopmatrix_layout { + auto l = coopmatrix_layout{}; + l.sty = ct->component_ty(); + l.rows = std::min(ct->shape(0), static_cast(cfg.subgroup_size)); + l.cols = (1 + (l.rows * ct->shape(1) - 1) / cfg.subgroup_size) * cfg.subgroup_size / l.rows; + l.blocks = ct->shape(0) / l.rows; + l.length = l.rows * l.cols * l.blocks / cfg.subgroup_size; + l.shape1 = ct->shape(1); + l.blocks1 = 1; + if (ct->use() == matrix_use::b && l.blocks > 1) { + const auto omega_b = std::max(1, static_cast(2 / size(l.sty))); + l.blocks1 = omega_b; + } + l.ops_per_chan = 1; + if (ct->use() == matrix_use::a) { + const auto omega = std::max(1, static_cast(4 / size(l.sty))); + if (l.cols % l.ops_per_chan == 0) { + l.ops_per_chan = omega; + } + } + + return l; +} + +} // namespace tinytc diff --git a/src/coopmatrix_layout.hpp b/src/coopmatrix_layout.hpp new file mode 100644 index 00000000..dac1efa0 --- /dev/null +++ b/src/coopmatrix_layout.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_LAYOUT_20250428_HPP +#define COOPMATRIX_LAYOUT_20250428_HPP + +#include "util/fnv1a.hpp" + +#include +#include + +namespace tinytc { +class coopmatrix_data_type; +class core_config; +enum class scalar_type; + +struct coopmatrix_layout { + scalar_type sty; + std::int64_t rows, cols, blocks, length, shape1, blocks1; + std::int32_t ops_per_chan; + + inline auto operator==(coopmatrix_layout const &other) const { + return sty == other.sty && rows == other.rows && cols == other.cols && + blocks == other.blocks && length == other.length && shape1 == other.shape1 && + blocks1 == other.blocks1 && ops_per_chan == other.ops_per_chan; + } + + inline auto component_no(std::int64_t block1, std::int64_t col, std::int64_t block2) const + -> std::int64_t { + return block1 + col * blocks1 + block2 * blocks1 * (length / blocks); + } + + inline auto component_no(std::int64_t col, std::int64_t block) const -> std::int64_t { + return component_no(block % blocks1, col, block / blocks1); + } +}; + +auto get_layout(core_config const &cfg, coopmatrix_data_type const *ct) -> coopmatrix_layout; + +} // namespace tinytc + +namespace std { +template <> struct hash { + inline auto operator()(tinytc::coopmatrix_layout const &key) const -> std::size_t { + return fnv1a_combine(key.sty, key.rows, key.cols, key.blocks, key.length, key.shape1, + key.blocks1, key.ops_per_chan); + } +}; +} // namespace std + +#endif // COOPMATRIX_LAYOUT_20250428_HPP diff --git a/src/data_type.cpp b/src/data_type.cpp index cb9827b3..5c2ae5ba 100644 --- a/src/data_type.cpp +++ b/src/data_type.cpp @@ -7,79 +7,95 @@ #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include "tinytc/types.hpp" -#include "util.hpp" +#include "util/casting.hpp" #include -#include -#include -#include + +namespace tinytc { +enum class address_space; +enum class matrix_use; +enum class scalar_type; +} // namespace tinytc using namespace tinytc; extern "C" { -tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t type, - const tinytc_location_t *loc) { - if (dt == nullptr) { + +char const *tinytc_matrix_use_to_string(tinytc_matrix_use_t u) { + switch (u) { + case tinytc_matrix_use_a: + return "matrix_a"; + case tinytc_matrix_use_b: + return "matrix_b"; + case tinytc_matrix_use_acc: + return "matrix_acc"; + } + return "unknown"; +} + +tinytc_status_t tinytc_boolean_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx) { + if (dt == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *dt = std::make_unique(enum_cast(type), get_optional(loc)) - .release(); - }); + return exception_to_status_code([&] { *dt = boolean_data_type::get(ctx); }); } -tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t scalar_ty, - uint32_t shape_size, const int64_t *shape, - uint32_t stride_size, const int64_t *stride, - const tinytc_location_t *loc) { - if (dt == nullptr) { +tinytc_status_t tinytc_scalar_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx, + tinytc_scalar_type_t type) { + if (dt == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - auto shape_vec = std::vector(shape, shape + shape_size); - auto stride_vec = std::vector(); - if (stride_size > 0) { - stride_vec.insert(stride_vec.end(), stride, stride + stride_size); - } - *dt = std::make_unique(enum_cast(scalar_ty), - std::move(shape_vec), std::move(stride_vec), - get_optional(loc)) - .release(); - }); + return exception_to_status_code( + [&] { *dt = scalar_data_type::get(ctx, enum_cast(type)); }); } -tinytc_status_t tinytc_group_type_create(tinytc_data_type_t *dt, tinytc_data_type_t memref_ty, - int64_t offset, const tinytc_location_t *loc) { - if (dt == nullptr) { +tinytc_status_t tinytc_memref_type_get(tinytc_data_type_t *dt, tinytc_data_type_t scalar_ty, + uint32_t shape_size, const int64_t *shape, + uint32_t stride_size, const int64_t *stride, + tinytc_address_space_t addrspace, + const tinytc_location_t *loc) { + if (dt == nullptr || (shape_size != 0 && shape == nullptr) || + (stride_size != 0 && stride == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *dt = - std::make_unique(data_type(memref_ty, true), offset, get_optional(loc)) - .release(); + *dt = memref_data_type::get(scalar_ty, array_view{shape, shape_size}, + array_view{stride, stride_size}, + enum_cast(addrspace), get_optional(loc)); }); } -tinytc_status_t tinytc_data_type_release(tinytc_data_type_t obj) { - if (obj == nullptr) { +tinytc_status_t tinytc_group_type_get(tinytc_data_type_t *dt, tinytc_data_type_t memref_ty, + int64_t size, int64_t offset, const tinytc_location_t *loc) { + if (dt == nullptr) { return tinytc_status_invalid_arguments; } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; + + return exception_to_status_code( + [&] { *dt = group_data_type::get(memref_ty, size, offset, get_optional(loc)); }); +} + +tinytc_status_t tinytc_coopmatrix_type_get(tinytc_data_type_t *dt, tinytc_data_type_t scalar_ty, + int64_t rows, int64_t cols, tinytc_matrix_use_t use, + const tinytc_location_t *loc) { + if (dt == nullptr || scalar_ty == nullptr) { + return tinytc_status_invalid_arguments; } - return tinytc_status_success; + + return exception_to_status_code([&] { + *dt = coopmatrix_data_type::get(scalar_ty, rows, cols, enum_cast(use), + get_optional(loc)); + }); } -tinytc_status_t tinytc_data_type_retain(tinytc_data_type_t obj) { - if (obj == nullptr) { +tinytc_status_t tinytc_void_type_get(tinytc_data_type_t *dt, tinytc_compiler_context_t ctx) { + if (dt == nullptr || ctx == nullptr) { return tinytc_status_invalid_arguments; } - obj->inc_ref(); - return tinytc_status_success; + + return exception_to_status_code([&] { *dt = void_data_type::get(ctx); }); } } diff --git a/src/device_info.cpp b/src/device_info.cpp index 5d239378..371f4b98 100644 --- a/src/device_info.cpp +++ b/src/device_info.cpp @@ -4,10 +4,14 @@ #include "device_info.hpp" #include "error.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/fnv1a.hpp" #include +#include #include #include #include @@ -22,7 +26,7 @@ core_info_generic::core_info_generic(std::int32_t register_space, std::int32_t m : register_space_(register_space), max_work_group_size_(max_work_group_size), subgroup_sizes_(std::move(subgroup_sizes)) {} -auto core_info_generic::subgroup_sizes() const -> std::vector const & { +auto core_info_generic::subgroup_sizes() const -> array_view { return subgroup_sizes_; } auto core_info_generic::register_space() const -> std::int32_t { return register_space_; } @@ -36,8 +40,9 @@ auto core_info_generic::get_core_config(std::int32_t subgroup_size) const -> tin subgroup_sizes_.end()) { throw std::out_of_range("Requested subgroup size not available"); } - return core_config{subgroup_size, max_work_group_size_, register_space_, false}; + return core_config{subgroup_size, max_work_group_size_, register_space_, &matrix_}; } +auto core_info_generic::matrix() const -> matrix_ext_info const & { return matrix_; } core_info_intel::core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, std::int32_t num_threads_per_eu, @@ -48,43 +53,58 @@ core_info_intel::core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_ std::sort(subgroup_sizes_.begin(), subgroup_sizes_.end()); register_size_ = 32; - if (ip_version_ >= static_cast(intel_gpu_architecture::pvc)) { + if (is_arch(tinytc_intel_gpu_architecture_pvc)) { register_size_ = 64; + set_spirv_feature(spirv_feature::bfloat16_conversion, true); + + const auto block_info = matrix_ext_block_io_info{.base_address_alignment = 64, + .min_stride = 64, + .max_stride = (1 << 24) - 1, + .pos0_alignment = 4, + .stride_alignment = 8, + .width_alignment = 4}; + matrix_ = matrix_ext_info(16, block_info, pvc_matrix_ext_types); + } else if (is_arch(tinytc_intel_gpu_architecture_bmg)) { + register_size_ = 64; + set_spirv_feature(spirv_feature::bfloat16_conversion, true); + + const auto block_info = matrix_ext_block_io_info{.base_address_alignment = 64, + .min_stride = 64, + .max_stride = (1 << 24) - 1, + .pos0_alignment = 4, + .stride_alignment = 16, + .width_alignment = 4}; + matrix_ = matrix_ext_info(16, block_info, pvc_matrix_ext_types); } - num_registers_per_thread_ = num_reg_small_grf(); } auto core_info_intel::num_reg_small_grf() const -> std::int32_t { return 128; } auto core_info_intel::num_reg_large_grf() const -> std::int32_t { - return ip_version_ >= static_cast(intel_gpu_architecture::pvc) - ? 256 - : num_reg_small_grf(); + if (is_arch(tinytc_intel_gpu_architecture_pvc) || is_arch(tinytc_intel_gpu_architecture_bmg)) { + return 256; + } + return num_reg_small_grf(); } -auto core_info_intel::subgroup_sizes() const -> std::vector const & { - return subgroup_sizes_; +auto core_info_intel::num_reg() const -> std::int32_t { + return core_features_ & tinytc_core_feature_flag_large_register_file ? num_reg_large_grf() + : num_reg_small_grf(); } -auto core_info_intel::register_space() const -> std::int32_t { - return register_size_ * num_registers_per_thread_; -} +auto core_info_intel::subgroup_sizes() const -> array_view { return subgroup_sizes_; } + +auto core_info_intel::register_space() const -> std::int32_t { return register_size_ * num_reg(); } auto core_info_intel::core_features() const -> tinytc_core_feature_flags_t { return core_features_; } -void core_info_intel::core_features(tinytc_core_feature_flags_t flags) { - if (flags & tinytc_core_feature_flag_large_register_file) { - num_registers_per_thread_ = num_reg_large_grf(); - } else { - num_registers_per_thread_ = num_reg_small_grf(); - } -} +void core_info_intel::core_features(tinytc_core_feature_flags_t flags) { core_features_ = flags; } auto core_info_intel::max_work_group_size(std::int32_t subgroup_size) const -> std::int32_t { auto const num_threads_per_eu_due_to_register_use = - num_threads_per_eu_ * num_reg_small_grf() / num_registers_per_thread_; + num_threads_per_eu_ * num_reg_small_grf() / num_reg(); auto const num_threads_per_eu_due_to_subgroup_size = num_threads_per_eu_ * subgroup_sizes_.front() / subgroup_size; auto const num_threads_per_eu = @@ -107,12 +127,12 @@ auto core_info_intel::get_core_config(std::int32_t subgroup_size) const -> core_ throw std::out_of_range("Requested subgroup size not available"); } - bool block_read_write_supported = !(subgroup_size == 32 && register_size_ == 32); - return core_config{subgroup_size, max_work_group_size(subgroup_size), register_space(), - block_read_write_supported}; + &matrix_}; } +auto core_info_intel::matrix() const -> matrix_ext_info const & { return matrix_; } + } // namespace tinytc using namespace tinytc; @@ -142,11 +162,64 @@ tinytc_status_t tinytc_core_info_intel_create_from_arch(tinytc_core_info_t *info *info = std::make_unique(static_cast(arch), 8, 7, std::vector{8, 16, 32}) .release(); + (*info)->set_spirv_feature(spirv_feature::float16, true); + (*info)->set_spirv_feature(spirv_feature::float64, false); + (*info)->set_spirv_feature(spirv_feature::int64_atomics, true); + (*info)->set_spirv_feature(spirv_feature::groups, true); + (*info)->set_spirv_feature(spirv_feature::subgroup_dispatch, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_local, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_global, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_local, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_global, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_local, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_global, false); + (*info)->set_spirv_feature(spirv_feature::bfloat16_conversion, false); + (*info)->set_spirv_feature(spirv_feature::subgroup_buffer_block_io, true); break; case tinytc_intel_gpu_architecture_pvc: + case tinytc_intel_gpu_architecture_bmg: *info = std::make_unique(static_cast(arch), 8, 8, std::vector{16, 32}) .release(); + (*info)->set_spirv_feature(spirv_feature::float16, true); + (*info)->set_spirv_feature(spirv_feature::float64, true); + (*info)->set_spirv_feature(spirv_feature::int64_atomics, true); + (*info)->set_spirv_feature(spirv_feature::groups, true); + (*info)->set_spirv_feature(spirv_feature::subgroup_dispatch, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_local, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_global, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_local, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_global, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_local, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_global, true); + (*info)->set_spirv_feature(spirv_feature::bfloat16_conversion, true); + (*info)->set_spirv_feature(spirv_feature::subgroup_buffer_block_io, true); + break; + default: + *info = nullptr; + throw status::invalid_arguments; + } + }); +} + +tinytc_status_t tinytc_core_info_intel_create_from_name(tinytc_core_info_t *info, + char const *name) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + switch (fnv1a(name, std::strlen(name))) { + case "tgl"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_tgl)); + break; + case "pvc"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_pvc)); + break; + case "bmg"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_bmg)); break; default: *info = nullptr; @@ -203,7 +276,7 @@ tinytc_status_t tinytc_core_info_set_core_features(tinytc_core_info_t info, return exception_to_status_code([&] { info->core_features(flags); }); } -tinytc_status_t tinytc_core_info_get_core_features(tinytc_core_info_t info, +tinytc_status_t tinytc_core_info_get_core_features(const_tinytc_core_info_t info, tinytc_core_feature_flags_t *flags) { if (info == nullptr || flags == nullptr) { return tinytc_status_invalid_arguments; @@ -211,6 +284,42 @@ tinytc_status_t tinytc_core_info_get_core_features(tinytc_core_info_t info, return exception_to_status_code([&] { *flags = info->core_features(); }); } +tinytc_status_t tinytc_core_info_set_spirv_feature(tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t available) { + + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { info->set_spirv_feature(enum_cast(feature), available); }); +} + +tinytc_status_t tinytc_core_info_have_spirv_feature(const_tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t *available) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *available = info->have_spirv_feature(enum_cast(feature)); }); +} + +tinytc_status_t tinytc_core_info_get_default_alignment(const_tinytc_core_info_t info, + int32_t *alignment) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *alignment = info->alignment(); }); +} + +tinytc_status_t tinytc_core_info_set_default_alignment(tinytc_core_info_t info, int32_t alignment) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { info->alignment(alignment); }); +} + tinytc_status_t tinytc_core_info_release(tinytc_core_info_t obj) { if (obj == nullptr) { return tinytc_status_invalid_arguments; @@ -229,4 +338,48 @@ tinytc_status_t tinytc_core_info_retain(tinytc_core_info_t obj) { obj->inc_ref(); return tinytc_status_success; } + +char const *tinytc_spirv_feature_to_string(tinytc_spirv_feature_t f) { + switch (f) { + case tinytc_spirv_feature_float16: + return "float16"; + case tinytc_spirv_feature_float64: + return "float64"; + case tinytc_spirv_feature_int64_atomics: + return "int64_atomics"; + case tinytc_spirv_feature_groups: + return "groups"; + case tinytc_spirv_feature_subgroup_dispatch: + return "subgroup_dispatch"; + case tinytc_spirv_feature_atomic_float16_add_local: + return "atomic_float16_add_local"; + case tinytc_spirv_feature_atomic_float16_add_global: + return "atomic_float16_add_global"; + case tinytc_spirv_feature_atomic_float32_add_local: + return "atomic_float32_add_local"; + case tinytc_spirv_feature_atomic_float32_add_global: + return "atomic_float32_add_global"; + case tinytc_spirv_feature_atomic_float64_add_local: + return "atomic_float64_add_local"; + case tinytc_spirv_feature_atomic_float64_add_global: + return "atomic_float64_add_global"; + case tinytc_spirv_feature_atomic_float16_min_max_local: + return "atomic_float16_min_max_local"; + case tinytc_spirv_feature_atomic_float16_min_max_global: + return "atomic_float16_min_max_global"; + case tinytc_spirv_feature_atomic_float32_min_max_local: + return "atomic_float32_min_max_local"; + case tinytc_spirv_feature_atomic_float32_min_max_global: + return "atomic_float32_min_max_global"; + case tinytc_spirv_feature_atomic_float64_min_max_local: + return "atomic_float64_min_max_local"; + case tinytc_spirv_feature_atomic_float64_min_max_global: + return "atomic_float64_min_max_global"; + case tinytc_spirv_feature_bfloat16_conversion: + return "bfloat16_conversion"; + case tinytc_spirv_feature_subgroup_buffer_block_io: + return "subgroup_buffer_block_io"; + } + return "unknown"; +} } diff --git a/src/device_info.hpp b/src/device_info.hpp index a0e638ba..d785f12d 100644 --- a/src/device_info.hpp +++ b/src/device_info.hpp @@ -4,21 +4,26 @@ #ifndef DEVICE_INFO_20240304_HPP #define DEVICE_INFO_20240304_HPP +#include "matrix_ext_info.hpp" #include "reference_counted.hpp" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include #include #include namespace tinytc { +enum class spirv_feature; + //! Core parameters for a specific choice of subgroup size and core feature flags class core_config { public: std::int32_t subgroup_size; ///< Smallest unit of execution std::int32_t max_work_group_size; ///< Maximum size of local work group in number of works items std::int32_t register_space; ///< Size of register file in bytes - bool block_read_write_supported; ///< True if block reads / block writes are suppported + matrix_ext_info const *matrix; }; } // namespace tinytc @@ -26,8 +31,7 @@ class core_config { struct tinytc_core_info : tinytc::reference_counted { //! empty dtor virtual ~tinytc_core_info(); - //! Returns available subgroup sizes - virtual auto subgroup_sizes() const -> std::vector const & = 0; + virtual auto subgroup_sizes() const -> tinytc::array_view = 0; //! Returns availabe register space per subgroup virtual auto register_space() const -> std::int32_t = 0; //! Get core features @@ -39,29 +43,52 @@ struct tinytc_core_info : tinytc::reference_counted { virtual auto minmax_work_group_size() const -> std::int32_t = 0; //! Return core config for specific subgroup size and number of registers per tile virtual auto get_core_config(std::int32_t subgroup_size) const -> tinytc::core_config = 0; + virtual void set_spirv_feature(tinytc::spirv_feature f, bool available) = 0; + virtual auto have_spirv_feature(tinytc::spirv_feature f) const -> bool = 0; + virtual auto matrix() const -> tinytc::matrix_ext_info const & = 0; + virtual auto alignment() const -> std::int32_t = 0; + virtual void alignment(std::int32_t alignment) = 0; }; namespace tinytc { -class core_info_generic : public ::tinytc_core_info { +class core_info_common : public ::tinytc_core_info { + public: + inline void set_spirv_feature(spirv_feature f, bool available) override { + spv_feature_[static_cast(f)] = available; + } + inline auto have_spirv_feature(spirv_feature f) const -> bool override { + return spv_feature_[static_cast(f)]; + } + inline auto alignment() const -> std::int32_t override { return alignment_; } + inline void alignment(std::int32_t alignment) override { alignment_ = alignment; } + + private: + std::array spv_feature_ = {}; + std::int32_t alignment_ = 128; +}; + +class core_info_generic : public core_info_common { public: core_info_generic(std::int32_t register_space, std::int32_t max_work_group_size, std::vector subgroup_sizes); - auto subgroup_sizes() const -> std::vector const & override; + auto subgroup_sizes() const -> array_view override; auto register_space() const -> std::int32_t override; auto core_features() const -> tinytc_core_feature_flags_t override; void core_features(tinytc_core_feature_flags_t flags) override; auto minmax_work_group_size() const -> std::int32_t override; auto get_core_config(std::int32_t subgroup_size) const -> tinytc::core_config override; + auto matrix() const -> matrix_ext_info const & override; private: std::int32_t register_space_; std::int32_t max_work_group_size_; std::vector subgroup_sizes_; + matrix_ext_info matrix_; }; //! Set of core configurations for Intel GPUs -class core_info_intel : public ::tinytc_core_info { +class core_info_intel : public core_info_common { public: /** * @brief ctor @@ -75,8 +102,7 @@ class core_info_intel : public ::tinytc_core_info { core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, std::int32_t num_threads_per_eu, std::vector subgroup_sizes); - //! @copydoc ::tinytc_core_info::subgroup_sizes - auto subgroup_sizes() const -> std::vector const & override; + auto subgroup_sizes() const -> array_view override; //! @copydoc ::tinytc_core_info::register_space auto register_space() const -> std::int32_t override; //! @copydoc ::tinytc_core_info::core_features() @@ -87,10 +113,16 @@ class core_info_intel : public ::tinytc_core_info { auto minmax_work_group_size() const -> std::int32_t override; //! @copydoc ::tinytc_core_info::get_core_config auto get_core_config(std::int32_t subgroup_size) const -> core_config override; + auto matrix() const -> matrix_ext_info const & override; private: + inline auto is_arch(tinytc_intel_gpu_architecture_t arch) const -> bool { + return arch <= ip_version_ && + ip_version_ <= arch + TINYTC_INTEL_GPU_ARCHITECTURE_SUB_VERSION_BITS; + } auto num_reg_small_grf() const -> std::int32_t; auto num_reg_large_grf() const -> std::int32_t; + auto num_reg() const -> std::int32_t; auto max_work_group_size(std::int32_t subgroup_size) const -> std::int32_t; std::uint32_t ip_version_; @@ -98,8 +130,8 @@ class core_info_intel : public ::tinytc_core_info { std::int32_t num_threads_per_eu_; std::vector subgroup_sizes_; std::int32_t register_size_; - std::int32_t num_registers_per_thread_; tinytc_core_feature_flags_t core_features_; + matrix_ext_info matrix_; }; } // namespace tinytc diff --git a/src/error.cpp b/src/error.cpp index 2431eecc..50b63841 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -5,6 +5,7 @@ #include "location.hpp" #include "tinytc/tinytc.h" +#include #include #include #include @@ -12,12 +13,28 @@ namespace tinytc { compilation_error::compilation_error(location const &loc, status code, std::string extra_info) - : loc_(loc), code_(code), extra_info_(std::move(extra_info)) {} + : loc_(loc), ref_values_{}, num_ref_values_{0}, code_(code), + extra_info_(std::move(extra_info)) {} + +compilation_error::compilation_error(location const &loc, + array_view ref_values, status code, + std::string extra_info) + : loc_(loc), code_(code), extra_info_(std::move(extra_info)) { + num_ref_values_ = std::min(error_max_ref, ref_values.size()); + for (std::size_t i = 0; i < num_ref_values_; ++i) { + ref_values_[i] = ref_values[i]; + } +} auto report_error_with_context(char const *code, std::size_t code_len, std::string const &file_name, location const &l, std::string const &what) -> std::string { constexpr int additional_context_lines = 2; + auto oerr = std::ostringstream{}; + oerr << file_name << ":"; + print_range(oerr, l.begin, l.end); + oerr << ": " << what; + int cur_line = 1; const char *begin = code; const char *limit = begin + code_len; @@ -27,7 +44,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri } ++begin; } - auto oerr = std::ostringstream{}; + char const *end = begin; int start_col = -1; while (cur_line <= l.end.line && *end != '\0' && end <= limit) { @@ -39,7 +56,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri if (start_col < 0) { start_col = static_cast(end - begin); } - oerr << std::string(begin, end) << std::endl; + oerr << std::endl << std::string(begin, end) << std::endl; if (cur_line >= l.begin.line) { int col_begin = 0; int num_col = 0; @@ -63,7 +80,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri col_begin = l.begin.column > 1 ? l.begin.column - 1 : 0; num_col = l.end.column > l.begin.column ? l.end.column - l.begin.column : 1; } - oerr << std::string(col_begin, ' ') << std::string(num_col, '~') << std::endl; + oerr << std::string(col_begin, ' ') << std::string(num_col, '~'); } ++cur_line; start_col = -1; @@ -71,10 +88,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri } ++end; } - oerr << file_name << ":"; - print_range(oerr, l.begin, l.end); - oerr << ": " << what; - return oerr.str(); + return std::move(oerr).str(); } } // namespace tinytc @@ -112,11 +126,18 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Invalid arguments passed to kernel"; case tinytc_status_unsupported_device: return "Unsupported device"; + case tinytc_status_invalid_core_info: + return "Invalid core info object (e.g. max work group size is 0 or subgroup sizes vector " + "is empty)"; + case tinytc_status_unknown_pass_name: + return "Unknown compiler pass name"; + case tinytc_status_not_implemented: + return "Not implemented"; // IR case tinytc_status_ir_out_of_bounds: return "Argument is out of bounds"; case tinytc_status_ir_invalid_shape: - return "Mode size must be non-negative"; + return "Invalid shape"; case tinytc_status_ir_incompatible_shapes: return "Incompatible tensor shapes"; case tinytc_status_ir_shape_stride_mismatch: @@ -125,32 +146,141 @@ char const *tinytc_error_string(tinytc_status_t status) { return "Scalar type mismatch"; case tinytc_status_ir_invalid_number_of_indices: return "Number of indices must match memref order or must be 1 for group types"; + case tinytc_status_ir_expected_boolean: + return "Expected boolean type"; case tinytc_status_ir_expected_scalar: return "Expected scalar type"; + case tinytc_status_ir_expected_int: + return "Expected integer type"; + case tinytc_status_ir_expected_float: + return "Expected floating point type"; + case tinytc_status_ir_expected_complex: + return "Expected complex type"; + case tinytc_status_ir_expected_i32: + return "Expected i32 type"; + case tinytc_status_ir_expected_index: + return "Expected index type"; + case tinytc_status_ir_expected_coopmatrix: + return "Expected coopmatrix type"; + case tinytc_status_ir_expected_coopmatrix_or_scalar: + return "Expected coopmatrix type or scalar type"; + case tinytc_status_ir_expected_coopmatrix_scalar_or_boolean: + return "Expected coopmatrix type, scalar type, or boolean type"; case tinytc_status_ir_expected_memref: return "Expected memref type"; case tinytc_status_ir_expected_memref_or_scalar: return "Expected memref type or scalar type"; case tinytc_status_ir_expected_memref_or_group: return "Expected memref or group operand"; - case tinytc_status_ir_expected_vector_or_matrix: - return "Expected vector or matrix input"; + case tinytc_status_ir_expected_memref_order_0: + return "Expected memref of order 0 (scalar)"; + case tinytc_status_ir_expected_memref_order_1: + return "Expected memref of order 1 (vector)"; + case tinytc_status_ir_expected_memref_order_2: + return "Expected memref of order 2 (matrix)"; + case tinytc_status_ir_expected_memref_order_0_or_1: + return "Expected memref of order 0 or 1 (scalar or vector)"; + case tinytc_status_ir_expected_memref_order_1_or_2: + return "Expected memref of order 1 or 2 (vector or matrix)"; + case tinytc_status_ir_expected_memref_order_0_1_or_2: + return "Expected memref of order 0, 1, or 2 (scalar, vector, or matrix)"; case tinytc_status_ir_unexpected_yield: return "Yield encountered in non-yielding region"; case tinytc_status_ir_yield_mismatch: - return "Number of yielded values does not match number of values yielded by region"; - case tinytc_status_ir_multiple_dynamic_modes: - return "At most one mode must be dynamic ('?')"; + return "Number of yielded values does not match number of values yielded by region or the " + "types are different"; + case tinytc_status_ir_subview_mismatch: + return "Number of dynamic offsets and sizes must match number of dynamic operands"; case tinytc_status_ir_invalid_slice: - return "Offset must be non-negative and must not be '?'; size must be positive or '?'"; + return "Static offset and size must be non-negative or dynamic ('?')"; case tinytc_status_ir_expand_shape_order_too_small: return "Expand shape must have at least 2 entries"; case tinytc_status_ir_expand_shape_mismatch: - return "Product of expand shape must equal mode size"; + return "Number of dynamic expand shape operands must equal number of dynamic modes in " + "static expand shape"; case tinytc_status_ir_collective_called_from_spmd: return "Collective instruction must not be called from SPMD region"; case tinytc_status_ir_fp_unsupported: - return "Floating point type unsupported for instruction"; + return "Floating point type unsupported by instruction"; + case tinytc_status_ir_spmd_called_from_collective: + return "SPMD instruction must not be called from collective region"; + case tinytc_status_ir_expected_local_address_space: + return "A memref with local address space is expected"; + case tinytc_status_ir_expected_global_address_space: + return "A memref with global address space is expected"; + case tinytc_status_ir_address_space_mismatch: + return "Address space must match"; + case tinytc_status_ir_invalid_offset: + return "Offset must be non-negative or dynamic"; + case tinytc_status_ir_int_unsupported: + return "int type unsupported by instruction"; + case tinytc_status_ir_boolean_unsupported: + return "boolean type unsupported by instruction"; + case tinytc_status_ir_complex_unsupported: + return "complex type unsupported by instruction"; + case tinytc_status_ir_coopmatrix_unsupported: + return "coopmatrix type unsupported by instruction"; + case tinytc_status_ir_forbidden_cast: + return "Forbidden cast"; + case tinytc_status_ir_invalid_beta: + return "beta must be constant and 0 or 1 for atomic linear algebra operations"; + case tinytc_status_ir_init_return_mismatch: + return "The number or types of the initial values does not match the return type list"; + case tinytc_status_ir_invalid_matrix_use: + return "Operands have invalid matrix use"; + case tinytc_status_ir_unsupported_coopmatrix_shape: + return "Unsupported coopmatrix shape for the combination of scalar type, matrix use, and " + "target architecture"; + case tinytc_status_ir_forbidden_promotion: + return "Scalar type promotion is forbidden"; + case tinytc_status_ir_constant_mismatch: + return "Type of constant does not match type of returned value"; + case tinytc_status_ir_insufficient_alignment: + return "Pointer does not satisfy minimum alignment requirements"; + case tinytc_status_ir_must_have_yield: + return "Last instruction of region that returns values must be \"yield\""; + case tinytc_status_ir_yield_in_else_branch_missing: + return "Else-branch must have yield instruction if then-branch has yield instruction"; + case tinytc_status_ir_from_to_mismatch: + return "length(from) must equal length(to) and length must be greater than 0"; + case tinytc_status_ir_operand_type_must_match_return_type: + return "Type of operand must match return type"; + case tinytc_status_ir_invalid_stride: + return "Invalid stride"; + case tinytc_status_ir_init_return_type_mismatch: + return "Type of initializer does not match return type or the number of return types is " + "not equal the number of initializers"; + case tinytc_status_ir_invalid_alignment: + return "Invalid alignment, must be a positive power of two"; + case tinytc_status_ir_value_still_has_uses: + return "A value shall be erased that still has uses"; + case tinytc_status_ir_expected_array_attribute: + return "Expected array attribute"; + case tinytc_status_ir_expected_boolean_attribute: + return "Expected boolean attribute"; + case tinytc_status_ir_expected_dictionary_attribute: + return "Expected dictionary attribute"; + case tinytc_status_ir_expected_integer_attribute: + return "Expected integer attribute"; + case tinytc_status_ir_expected_string_attribute: + return "Expected string attribute"; + case tinytc_status_ir_duplicate_key_in_dictionary: + return "Duplicate key detected in list of named attributes passed to dictionary"; + case tinytc_status_ir_unexpected_array_attribute_size: + return "Unexpected size in array attribute"; + case tinytc_status_ir_expected_non_scalar_memref: + return "Expected memref of dimension greater or equal than 1"; + // SPIR-V + case tinytc_status_spirv_forbidden_forward_declaration: + return "Forward declaration of id is forbidden"; + case tinytc_status_spirv_undefined_value: + return "Undefined SPIR-V value"; + case tinytc_status_spirv_missing_dope_vector: + return "Dope vector missing (internal compiler error)"; + case tinytc_status_spirv_unsupported_atomic_data_type: + return "Atomic data type unsupported by SPIR-V"; + case tinytc_status_spirv_required_feature_unavailable: + return "A required SPIR-V feature is unavailable"; // Level Zero case tinytc_status_ze_result_not_ready: return "ZE_RESULT_NOT_READY"; diff --git a/src/error.hpp b/src/error.hpp index 3b48d27c..d0d8b483 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -4,13 +4,13 @@ #ifndef ERROR_20240410_HPP #define ERROR_20240410_HPP -#include "parser.hpp" +#include "compiler_context.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include -#include #include #include #include @@ -24,12 +24,19 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri //! Compilation error class compilation_error : public std::exception { public: + constexpr static std::size_t error_max_ref = 4; + //! ctor; taking location, status code, and expanatory string compilation_error(location const &loc, status code, std::string extra_info = {}); + compilation_error(location const &loc, array_view ref_values, status code, + std::string extra_info = {}); //! Get status code inline auto code() const noexcept { return code_; } //! Get location inline auto loc() const noexcept -> location const & { return loc_; } + inline auto ref_values() const noexcept -> array_view { + return array_view(ref_values_.data(), num_ref_values_); + } //! Get explanatory string inline char const *what() const noexcept override { return error_string(code_); } //! Get additional information @@ -37,6 +44,8 @@ class compilation_error : public std::exception { private: location loc_; + std::array ref_values_; + std::size_t num_ref_values_; status code_; std::string extra_info_; }; @@ -47,7 +56,8 @@ class internal_compiler_error : public std::exception { }; template -auto exception_to_status_code(F &&f, tinytc_source_context_t context = nullptr) -> tinytc_status_t { +auto exception_to_status_code(F &&f, tinytc_compiler_context_t context = nullptr) + -> tinytc_status_t { try { f(); } catch (internal_compiler_error const &e) { @@ -64,8 +74,9 @@ auto exception_to_status_code(F &&f, tinytc_source_context_t context = nullptr) if (e.extra_info().size() > 0) { auto what = (std::ostringstream{} << e.what() << " (" << e.extra_info() << ')').str(); + context->report_error(e.loc(), e.ref_values(), what.c_str()); } else { - context->report_error(e.loc(), e.what()); + context->report_error(e.loc(), e.ref_values(), e.what()); } } return static_cast(e.code()); diff --git a/src/func.cpp b/src/func.cpp index 53d5fcfa..cfadc16c 100644 --- a/src/func.cpp +++ b/src/func.cpp @@ -11,74 +11,45 @@ #include #include #include -#include -#include using namespace tinytc; extern "C" { -tinytc_status_t tinytc_function_prototype_create(tinytc_func_t *fun, char const *name, - uint32_t arg_list_size, tinytc_value_t *arg_list, - const tinytc_location_t *loc) { - if (fun == nullptr || (arg_list_size > 0 && arg_list == nullptr)) { +tinytc_status_t tinytc_func_create(tinytc_func_t *fun, uint32_t name_length, char const *name, + uint32_t num_params, const tinytc_data_type_t *param_type_list, + tinytc_data_type_t ty, const tinytc_location_t *loc) { + if (fun == nullptr || (num_params > 0 && param_type_list == nullptr) || ty == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto arg_vec = std::vector(); - arg_vec.reserve(arg_list_size); - for (uint32_t i = 0; i < arg_list_size; ++i) { - arg_vec.emplace_back(value(arg_list[i], true)); - } - *fun = std::make_unique(std::string(name), std::move(arg_vec), get_optional(loc)) + *fun = std::make_unique(std::string(name, name_length), + array_view(param_type_list, num_params), ty, + get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_function_create(tinytc_func_t *fun, tinytc_func_t prototype, - tinytc_region_t body, const tinytc_location_t *loc) { - if (fun == nullptr || prototype == nullptr || body == nullptr) { +tinytc_status_t tinytc_func_set_parameter_attr(tinytc_func_t fun, int32_t arg_no, attr a) { + if (fun == nullptr || arg_no < 0 || arg_no >= fun->num_params()) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - *fun = - std::make_unique(func{prototype, true}, region{body, true}, get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, int32_t y) { - function *f = dynamic_cast(fun); - if (f == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { f->work_group_size({x, y}); }); + return exception_to_status_code([&] { fun->param_attr(arg_no, a); }); } -tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs) { - function *f = dynamic_cast(fun); - if (f == nullptr) { +tinytc_status_t tinytc_func_set_attr(tinytc_func_t fun, tinytc_attr_t a) { + if (fun == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { f->subgroup_size(sgs); }); + return exception_to_status_code([&] { fun->attr(a); }); } -tinytc_status_t tinytc_func_release(tinytc_func_t obj) { - if (obj == nullptr) { +tinytc_status_t tinytc_func_get_body(tinytc_func_t fun, tinytc_region_t *body) { + if (fun == nullptr || body == nullptr) { return tinytc_status_invalid_arguments; } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; + return exception_to_status_code([&] { *body = &fun->body(); }); } -tinytc_status_t tinytc_func_retain(tinytc_func_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} +void tinytc_func_destroy(tinytc_func_t obj) { delete obj; } } diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp deleted file mode 100644 index f101d219..00000000 --- a/src/gemm_generator.cpp +++ /dev/null @@ -1,509 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "gemm_generator.hpp" -#include "codegen_tools.hpp" -#include "device_info.hpp" -#include "precision_helper.hpp" -#include "scalar_type.hpp" -#include "tiling.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -using namespace clir; - -namespace tinytc { - -gemm_scalar_type::gemm_scalar_type(scalar_type ty) : alpha(ty), A(ty), B(ty), beta(ty), C(ty) {} -gemm_scalar_type::gemm_scalar_type(scalar_type alphaAB, scalar_type betaC) - : alpha(alphaAB), A(alphaAB), B(alphaAB), beta(betaC), C(betaC) {} -gemm_scalar_type::gemm_scalar_type(scalar_type alpha, scalar_type A, scalar_type B, - scalar_type beta, scalar_type C) - : alpha(alpha), A(A), B(B), beta(beta), C(C) {} - -std::string gemm_configuration::identifier(std::string_view prefix) const { - std::ostringstream oss; - auto const dyn_val = [&oss](std::int64_t v) { - if (v == dynamic) { - oss << "d"; - } else { - oss << v; - } - }; - auto const stride = [&oss, &dyn_val](char X, std::array const &s) { - oss << "_" << X << "stride"; - dyn_val(s[0]); - oss << "_"; - dyn_val(s[1]); - }; - oss << prefix << "_"; - if (atomic) { - oss << "atomic_"; - } - oss << to_string(ty.alpha) << to_string(ty.A) << to_string(ty.B) << to_string(ty.beta) - << to_string(ty.C) << "_A" << to_string(transA) << "_B" << to_string(transB) << "_M"; - dyn_val(M); - oss << "_N"; - dyn_val(N); - oss << "_K"; - dyn_val(K); - stride('A', A_stride); - stride('B', B_stride); - stride('C', C_stride); - auto const format_optional = [&](std::optional const &val) { - if (val) { - auto f = oss.flags(); - auto v = *val; - oss << std::hex << std::bit_cast(v); - oss.flags(f); - } else { - oss << "d"; - } - }; - oss << "_alpha"; - format_optional(alpha); - oss << "_beta"; - format_optional(beta); - return oss.str(); -} - -constexpr static int max_K_unrolling = 8; - -auto max_register_block_gemm(std::uint32_t C_scalar_type_size_in_bytes, std::uint32_t sgs, - std::uint32_t register_space, - std::pair max_fill_fraction) - -> std::pair { - auto const arithmetic_intensity = [&sgs](std::uint32_t row_blocks, std::uint32_t cols) { - return (row_blocks * sgs * cols) / static_cast(row_blocks * sgs + cols); - }; - - auto const max_scalars = register_space * max_fill_fraction.first / - (max_fill_fraction.second * C_scalar_type_size_in_bytes); - - // The required number of scalars is given by - // row_blocks * sgs * (cols + max_K_unrolling) + cols * max_K_unrolling - auto const max_row_blocks = [&sgs, &max_scalars](std::uint32_t cols) { - return (max_scalars - cols * max_K_unrolling) / (sgs * (cols + max_K_unrolling)); - }; - auto const max_cols = [&sgs, &max_scalars](std::uint32_t row_blocks) { - return (max_scalars - row_blocks * sgs * max_K_unrolling) / - (row_blocks * sgs + max_K_unrolling); - }; - - double max_ai = 0.0; - std::uint32_t row_blocks = 1, cols = 1; - for (std::uint32_t r = 1; r <= max_row_blocks(1); ++r) { - for (std::uint32_t c = 1; c <= max_cols(r); ++c) { - auto const ai = arithmetic_intensity(r, c); - if (ai > max_ai) { - max_ai = ai; - row_blocks = r; - cols = c; - } - } - } - - return std::make_pair(row_blocks, cols); -} - -class generator { - public: - generator(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, address_space As, address_space Bs, address_space Cs) - : gemm_cfg(gemm_cfg), tiling(tiling), core_cfg(core_cfg), Aspace(As), Bspace(Bs), - Cspace(Cs) {} - void add_microkernel(block_builder &bb, bool is_remainder, expr M, expr N, var A, var B, var C, - expr C_offset, expr alpha, expr beta); - void add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, - expr beta); - void add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta); - ::clir::func function(std::string_view name); - - private: - gemm_configuration const gemm_cfg; - local_tiling const tiling; - core_config const core_cfg; - address_space Aspace, Bspace, Cspace; - unsigned row_blocks_in_register = 1; - unsigned cols_in_register = 1; - var c, m; - std::array MNK; - std::array A_stride, B_stride, C_stride; -}; - -void generator::add_microkernel(block_builder &bb, bool is_remainder, expr M, expr N, var A, var B, - var C, expr C_offset, expr alpha, expr beta) { - std::int64_t n_bs = 0; - bool is_N_constant = false; - dispatch_constant_dynamic( - N, - [&](std::int64_t n) { - n_bs = n; - is_N_constant = true; - }, - [&](expr) { - n_bs = static_cast(cols_in_register); - is_N_constant = false; - }); - std::int64_t const n_blocks = - 1 + (n_bs - 1) / static_cast(core_cfg.subgroup_size); - auto n = var("n"); - - auto my_row_blocks_in_register = row_blocks_in_register; - dispatch_constant_dynamic( - M, - [&](std::int64_t m) { - while (my_row_blocks_in_register > 1 && - m < static_cast(my_row_blocks_in_register) * - core_cfg.subgroup_size) { - --my_row_blocks_in_register; - } - }, - [&](expr) {}); - - auto const am = gemm_cfg.transA == transpose::T ? 1 : 0; - auto const ak = gemm_cfg.transA == transpose::T ? 0 : 1; - auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), "Ab", A); - auto const Aoffset = [&](unsigned m_block) { - return A_stride[am] * (m + m_block * core_cfg.subgroup_size); - }; - - auto const bn = gemm_cfg.transB == transpose::T ? 0 : 1; - auto const bk = gemm_cfg.transB == transpose::T ? 1 : 0; - auto Bb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), "Bb", B); - auto const Boffset = [&](int n_block) { - return B_stride[bn] * (m + n_block * core_cfg.subgroup_size); - }; - - auto const cmn = [&](unsigned m_block, expr n) { - return c[m_block + row_blocks_in_register * std::move(n)]; - }; - - bb.add(for_loop_builder(declaration_assignment(generic_short(), n, 0), n < n_bs, ++n) - .body([&](block_builder &bb) { - for (std::size_t m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - bb.assign(cmn(m_block, n), precision_helper{gemm_cfg.ty.C}.zero()); - } - }) - .attribute(opencl_unroll_hint(n_bs)) - .get_product()); - - auto const compute_c = [&](block_builder &bb, std::int64_t Kb, ::clir::expr K0, - ::clir::expr K1) { - auto kb = var("kb"); - bb.add( - for_loop_builder(declaration_assignment(generic_short(), kb, std::move(K0)), - kb < std::move(K1), add_into(kb, Kb)) - .body([&](block_builder &bb) { - auto at = precision_helper{gemm_cfg.ty.A}; - auto a = bb.declare(array_of(at.type(), my_row_blocks_in_register * Kb), "a"); - auto amk = [&](unsigned m_block, unsigned k) { - return a[m_block + my_row_blocks_in_register * k]; - }; - bool const map_b_to_vec_type = - gemm_cfg.B_stride[bk] == 1 && - (Kb == 2 || Kb == 3 || Kb == 4 || Kb == 8 || Kb == 16); - int k_load_block_size = map_b_to_vec_type ? Kb : 1; - auto bt = precision_helper{gemm_cfg.ty.B}; - auto b = map_b_to_vec_type - ? bb.declare(array_of(bt.type(Kb), n_blocks), "b") - : bb.declare(array_of(bt.type(), n_blocks * Kb), "b"); - auto const read_A = [&](block_builder &bb, unsigned m_block, unsigned k, - bool check) { - auto condition = m + m_block * core_cfg.subgroup_size < M; - auto rhs = Ab[Aoffset(m_block)]; - auto rhs_checked = - check ? ternary_conditional(std::move(condition), rhs, 0) : rhs; - bb.assign(amk(m_block, k), std::move(rhs_checked)); - }; - auto block_read_A = [&](block_builder &bb, unsigned m_block, unsigned k) { - bb.assign( - amk(m_block, k), - at.sub_group_block_read(Ab + m_block * core_cfg.subgroup_size, Aspace)); - }; - for (unsigned k = 0; k < Kb; ++k) { - for (unsigned m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - if (!is_remainder && core_cfg.block_read_write_supported && - gemm_cfg.A_stride[am] == 1) { - block_read_A(bb, m_block, k); - } else { - read_A(bb, m_block, k, is_remainder); - } - } - bb.add(add_into(Ab, A_stride[ak])); - } - - auto const read_B = [&](block_builder &bb, int k, int n_block, bool check) { - auto condition = m + n_block * core_cfg.subgroup_size < N; - if (map_b_to_vec_type) { - auto rhs = vload_helper(Kb, 0, Bb + Boffset(n_block)); - if (rhs) { - auto rhs_checked = - check ? ternary_conditional(condition, rhs, - init_vector(bt.type(Kb), {0})) - : rhs; - bb.assign(b[n_block], std::move(rhs_checked)); - } else { - throw std::logic_error("Vload for native type missing"); - } - } else { - auto rhs = Bb[Boffset(n_block)]; - auto rhs_checked = check ? ternary_conditional(condition, rhs, 0) : rhs; - bb.assign(b[k + n_block * Kb], std::move(rhs_checked)); - } - }; - int first_n_block_with_check = - n_bs < n_blocks * static_cast(core_cfg.subgroup_size) - ? n_blocks - 1 - : n_blocks; - if (!is_N_constant) { - first_n_block_with_check = 0; - } - for (int k = 0; k < Kb; k += k_load_block_size) { - for (int n_block = 0; n_block < first_n_block_with_check; ++n_block) { - read_B(bb, k, n_block, false); - } - for (int n_block = first_n_block_with_check; n_block < n_blocks; - ++n_block) { - read_B(bb, k, n_block, true); - } - bb.add(add_into(Bb, k_load_block_size * B_stride[bk])); - } - - const int nbb = 4; - for (unsigned m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - for (std::int64_t nb = 0; nb < n_bs; nb += nbb) { - for (int k = 0; k < Kb; ++k) { - for (std::int64_t n = 0; n < nbb; ++n) { - if (nb + n < n_bs) { - auto const n_block = (nb + n) / core_cfg.subgroup_size; - auto const n_offset = (nb + n) % core_cfg.subgroup_size; - auto my_a = amk(m_block, k); - auto bkn = map_b_to_vec_type ? b[n_block].s(k) - : b[k + n_block * Kb]; - auto my_b = sub_group_broadcast(std::move(bkn), n_offset); - auto my_c = cmn(m_block, nb + n); - if (gemm_cfg.ty.A == gemm_cfg.ty.B && - gemm_cfg.ty.B == gemm_cfg.ty.C) { - bb.assign(my_c, - fma(std::move(my_a), std::move(my_b), my_c)); - } else { - bb.add(add_into(std::move(my_c), - std::move(my_a) * std::move(my_b))); - } - } - } - } - } - } - }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - }; - dispatch_constant_dynamic( - MNK[2], - [&](std::int64_t K) { - static_assert(max_K_unrolling % 2 == 0, "max_K_unrolling must be a multiple of 2"); - auto Kb = max_K_unrolling; - while (K < Kb && Kb > 1) { - Kb /= 2; - } - auto KmultipleKb = (K / Kb) * Kb; - compute_c(bb, Kb, 0, KmultipleKb); - if (K - KmultipleKb > 0) { - compute_c(bb, 1, KmultipleKb, K); - } - }, - [&](expr K) { - auto KmultipleKb = bb.declare_assign(generic_uint(), "KmultipleKb", - (K / max_K_unrolling) * max_K_unrolling); - compute_c(bb, max_K_unrolling, 0, KmultipleKb); - bb.add(if_selection_builder(K - KmultipleKb > 0) - .then([&](block_builder &bb) { compute_c(bb, 1, KmultipleKb, K); }) - .get_product()); - }); - auto write_C = [&](block_builder &bb) { - auto n_to = is_N_constant ? n_bs : min(N, cast(generic_uint(), n_bs)); - auto n_unroll = is_N_constant ? n_bs : 1; -#if 0 - // We can use block writes if - // 1. They are supported (no block writes for SIMD32 before PVC) - // 2. We are not in a remainder loop - // 3. Data is adjacent in memory - // 4. The address is 16 byte aligned - // 5. We are not writing atomically - bool const use_block_write = - core_cfg.block_read_write_supported && !is_remainder && gemm_cfg.C_stride[0] == 1 && - gemm_cfg.C_stride[1] * size(gemm_cfg.ty.C) % 16 == 0 && !gemm_cfg.atomic -#endif - - // Block writes are disabled for now; would need to track memref alignment in subview / load - // instruction AND would need to impose alignment requirement in calling convention - constexpr bool use_block_write = false; - auto Cb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), "Cb", - C + C_offset); - if (!use_block_write) { - bb.add(add_into(Cb, C_stride[0] * m)); - } - bb.add( - for_loop_builder(declaration_assignment(generic_short(), n, 0), n < std::move(n_to), - ++n) - .body([&](block_builder &bb) { - for (std::size_t m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - auto my_c = alpha * cmn(m_block, n); - auto C_offset_m = C_stride[0] * (m_block * core_cfg.subgroup_size); - if (use_block_write) { - bb.add(precision_helper{gemm_cfg.ty.C}.sub_group_block_write( - Cb + m_block * core_cfg.subgroup_size, - std::move(my_c) + beta * Cb[std::move(C_offset_m)], Cspace)); - } else { - auto const write_C_mn = [&](block_builder &bb) { - store_helper(bb, gemm_cfg.atomic, Cb + C_offset_m, gemm_cfg.ty.C, - Cspace, my_c, beta); - }; - if (is_remainder) { - bb.add( - if_selection_builder(m + m_block * core_cfg.subgroup_size < M) - .then(write_C_mn) - .get_product()); - } else { - write_C_mn(bb); - } - } - } - bb.add(add_into(Cb, cast(generic_uint(), C_stride[1]))); - }) - .attribute(opencl_unroll_hint(n_unroll)) - .get_product()); - }; - write_C(bb); -} - -void generator::add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, - expr beta) { - auto sg_m = bb.declare_assign(generic_uint(), "sg_m", get_sub_group_id() % tiling.m_tiles()); - tile_loop_by_sgs( - bb, MNK[0], core_cfg.subgroup_size * row_blocks_in_register, tiling.m_tiles(), - std::move(sg_m), - [&](block_builder &bb, expr block, bool is_remainder, expr inner_trip_count) { - auto Astride_m = gemm_cfg.transA == transpose::T ? A_stride[1] : A_stride[0]; - auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), - "Ab", A + std::move(Astride_m) * block); - add_microkernel(bb, is_remainder, std::move(inner_trip_count), N, std::move(Ab), B, C, - C_stride[0] * std::move(block) + C_offset, alpha, beta); - }); -} - -void generator::add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta) { - m = bb.declare_assign(generic_uint(), "m", get_sub_group_local_id()); - c = var("c"); - - auto [max_row_blocks, max_cols] = max_register_block_gemm( - size(gemm_cfg.ty.C), core_cfg.subgroup_size, core_cfg.register_space); - row_blocks_in_register = max_row_blocks; - cols_in_register = max_cols; - if (!is_dynamic_value(gemm_cfg.M)) { - auto const row_blocks_needed_to_cover_M = 1 + (gemm_cfg.M - 1) / core_cfg.subgroup_size; - if (row_blocks_needed_to_cover_M < max_row_blocks) { - row_blocks_in_register = row_blocks_needed_to_cover_M; - } else { - auto blocks = gemm_cfg.M / row_blocks_in_register; - auto sg_blocks = 1 + (blocks - 1) / tiling.m_tiles(); - while (sg_blocks < tiling.m_tiles() && row_blocks_in_register >= 2) { - row_blocks_in_register /= 2; - blocks = gemm_cfg.M / row_blocks_in_register; - sg_blocks = 1 + (blocks - 1) / tiling.m_tiles(); - } - } - } - if (!is_dynamic_value(gemm_cfg.N)) { - cols_in_register = - tile_loop_uniformly_max_block_size(gemm_cfg.N, cols_in_register, tiling.n_tiles()); - } - bb.declare( - array_of(precision_helper{gemm_cfg.ty.C}.type(), row_blocks_in_register * cols_in_register), - c); - - auto sg_n = bb.declare_assign(generic_uint(), "sg_n", get_sub_group_id() / tiling.m_tiles()); - tile_loop_uniformly( - bb, MNK[1], max_cols, tiling.n_tiles(), std::move(sg_n), - [&](block_builder &bb, expr block, expr inner_trip_count) { - auto Bstride_n = gemm_cfg.transB == transpose::T ? B_stride[0] : B_stride[1]; - auto Bb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), - "Bb", B + std::move(Bstride_n) * block); - add_mloop(bb, std::move(inner_trip_count), A, std::move(Bb), C, - C_stride[1] * std::move(block), alpha, beta); - }); -} - -::clir::func generator::function(std::string_view name) { - auto A = var("A"); - auto B = var("B"); - auto C = var("C"); - - auto fb = ::clir::function_builder{std::string(name)}; - auto const scalar = [&](precision_helper const &fph, std::optional const &val, - std::string const &prefix) -> expr { - auto v = var{prefix}; - fb.argument(fph.type(), v); - return val ? fph.constant(*val) : v; - }; - auto const shape = [&](std::int64_t shape, expr &target, std::string const &prefix) { - auto v = var{prefix}; - fb.argument(to_clir_ty(scalar_type::index), v); - target = is_dynamic_value(shape) ? expr{std::move(v)} : expr{shape}; - }; - auto const stride = [&](std::array const &stride, std::array &target, - std::string const &prefix) { - for (std::size_t i = 0; i < stride.size(); ++i) { - auto v = var{prefix}; - fb.argument(to_clir_ty(scalar_type::index), v); - target[i] = is_dynamic_value(stride[i]) ? expr{std::move(v)} : expr{stride[i]}; - } - }; - - shape(gemm_cfg.M, MNK[0], "M"); - shape(gemm_cfg.N, MNK[1], "N"); - shape(gemm_cfg.K, MNK[2], "K"); - expr alpha = scalar(precision_helper{gemm_cfg.ty.alpha}, gemm_cfg.alpha, "alpha"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), A); - stride(gemm_cfg.A_stride, A_stride, "A_stride"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), B); - stride(gemm_cfg.B_stride, B_stride, "B_stride"); - expr beta = scalar(precision_helper{gemm_cfg.ty.beta}, gemm_cfg.beta, "beta"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), C); - stride(gemm_cfg.C_stride, C_stride, "C_stride"); - - fb.body([&](block_builder &bb) { add_function_body(bb, A, B, C, alpha, beta); }); - - auto f = fb.get_product(); - make_names_unique(f); - unsafe_simplify(f); - - return f; -} - -::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, std::string_view name, address_space As, - address_space Bs, address_space Cs) { - return generator{gemm_cfg, tiling, core_cfg, As, Bs, Cs}.function(name); -} - -} // namespace tinytc diff --git a/src/gemm_generator.hpp b/src/gemm_generator.hpp deleted file mode 100644 index 8646ef70..00000000 --- a/src/gemm_generator.hpp +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef GEMM_GENERATOR_20240314_HPP -#define GEMM_GENERATOR_20240314_HPP - -#include "device_info.hpp" -#include "tiling.hpp" -#include "tinytc/types.hpp" - -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace tinytc { - -//! Struct to handle mixed precision GEMMs -struct gemm_scalar_type { - //! alpha, A, B, beta, C all have the same type - gemm_scalar_type(scalar_type ty); - //! alpha's, A's, and B's type is different from beta's and C's type - gemm_scalar_type(scalar_type alphaAB, scalar_type betaC); - //! All operands potentially have a different type - gemm_scalar_type(scalar_type alpha, scalar_type A, scalar_type B, scalar_type beta, - scalar_type C); - scalar_type alpha, ///< @f$\alpha@f$ type - A, ///< A element type - B, ///< B element type - beta, ///< @f$\beta@f$ type - C; ///< C element type -}; - -/** - * @brief GEMM configuration struct - * - * The interface supports the operation - * - * C = alpha * opA(A) * opB(B) + beta * C, - * - * where - * - * opA/B(X) = transA/B == T ? X^T : X - * - * C is an MxN matrix, A is a MxK matrix, and B is a KxN matrix. - * - * The address of a matrix is calculated as following. Let X be element of {A,B,C}, then - * - * X(i,j) = X[i * X_stride[0] + j * X_stride[1]] - * - * If the atomic flag is set, C is updated atomically, either using - * - * * beta = 0: atomic store - * * beta = 1: atomic fetch add - * * general beta: atomic compare exchange - */ -struct gemm_configuration { - gemm_scalar_type ty; ///< scalar types of alpha, A, B, beta, C - transpose transA; ///< Transposition of A - transpose transB; ///< Transposition of B - std::int64_t M; ///< M, can be set to dynamic - std::int64_t N; ///< N, can be set to dynamic - std::int64_t K; ///< K, can be set to dynamic - std::array A_stride; ///< stride of A, entries can be set to dynamic - std::array B_stride; ///< stride of B, entries can be set to dynamic - std::array C_stride; ///< stride of C, entries can be set to dynamic - std::optional alpha; ///< fixed alpha if set; dynamic alpha if std::nullopt - std::optional beta; ///< fixed beta if set; dynamic beta if std::nullopt - bool atomic = false; ///< update C atomically - - std::string identifier( - std::string_view prefix = "gemm") const; ///< convert configuration to identification string -}; - -/** - * @brief Generate GEMM - * - * @param gemm_cfg configuration - * @param tiling Size of 2D subgroup grid - * @param core_cfg Core configuration - * @param name Routine prefix - * @param As Memory space of A (global or local) - * @param Bs Memory space of B (global or local) - * @param Cs Memory space of C (global or local) - * - * @return OpenCL-C AST - */ -::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, std::string_view name, - ::clir::address_space As = ::clir::address_space::global_t, - ::clir::address_space Bs = ::clir::address_space::global_t, - ::clir::address_space Cs = ::clir::address_space::global_t); - -/** - * @brief Calculate maximum register blocking size of GEMM - * - * @param C_scalar_type_size_in_bytes Size of scalar type of result matrix in bytes - * @param sgs Subgroup size - * @param register_space Size of register file per core in bytes - * @param max_fill_fraction Fraction of register file that shall be blocked at most - * - * @return {number of row-blocks (block size = subgroup size), number of columns} - */ -auto max_register_block_gemm(std::uint32_t C_scalar_type_size_in_bytes, std::uint32_t sgs, - std::uint32_t register_space, - std::pair max_fill_fraction = {1, 2}) - -> std::pair; - -} // namespace tinytc - -#endif // GEMM_GENERATOR_20240314_HPP diff --git a/src/gemm_tools.cpp b/src/gemm_tools.cpp new file mode 100644 index 00000000..5f1fe374 --- /dev/null +++ b/src/gemm_tools.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "gemm_tools.hpp" + +#include + +namespace tinytc { + +auto max_register_block_gemm(std::int32_t A_size, std::int32_t B_size, std::int32_t C_size, + std::int32_t subgroup_size, std::int32_t register_space, + std::int32_t C_blocks, + std::pair max_fill_fraction) + -> std::pair { + auto const arithmetic_intensity = [](std::int32_t rows, std::int32_t cols) { + return (rows * cols) / static_cast(rows + cols); + }; + + auto const max_bytes = register_space * max_fill_fraction.first / max_fill_fraction.second; + + constexpr std::int32_t max_K = standard_K_block_sizes.back(); + + // The required number of bytes is given by + // num_bytes = rows * (cols * C_blocks * C_size + max_K * A_size) + cols * max_K * B_size. + // Thus + // rows <= (max_bytes - cols * max_K * B_size) / (cols * C_blocks * C_size + max_K * A_size). + // Moreover, we require rows % subgroup_size = 0, so we set rows = k * subgroup_size and get + // k <= (max_bytes - cols * max_K * B_size) / + // (subgroup_size * (cols * C_blocks * C_size + max_K * A_size)). + auto const max_rows = [&](std::int32_t cols) { + const auto k = (max_bytes - cols * max_K * B_size) / + (subgroup_size * (cols * C_blocks * C_size + max_K * A_size)); + return k * subgroup_size; + }; + // Here, we have + // cols <= (max_bytes - rows * max_K * A_size) / (rows * C_blocks * C_size + max_K * B_size). + auto const max_cols = [&](std::int32_t rows) { + return (max_bytes - rows * max_K * A_size) / (rows * C_blocks * C_size + max_K * B_size); + }; + + double max_ai = 0.0; + std::int32_t rows = subgroup_size, cols = 1; + for (std::int32_t r = subgroup_size; r <= max_rows(1); r += subgroup_size) { + for (std::int32_t c = 1; c <= max_cols(r); ++c) { + auto const ai = arithmetic_intensity(r, c); + if (ai > max_ai) { + max_ai = ai; + rows = r; + cols = c; + } + } + } + + return std::make_pair(rows, cols); +} + +// We have block_size(k) = k * subgroup_size, where k is a positive integer, +// and num_blocks(k) = ceil(size / block_size(k)) +// We want to solve +// max_k block_size(k) s.t. +// block_size(k) <= max_block_size ; must not exceed max block size +// and num_blocks(k) % num_tiles == 0 ; no load imbalance +// and block_size(k) - size < sgs ; no excessive block size +// +// If the above optimization does not have a solution, the minimum block size (= subgroup size) is +// returned) +// +auto compute_m_block_size(std::int32_t subgroup_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t { + auto const block_size = [&subgroup_size](std::int32_t k) -> std::int32_t { + return k * subgroup_size; + }; + auto const num_blocks = [&block_size, &size](std::int32_t k) -> std::int64_t { + return 1 + (size - 1) / block_size(k); + }; + std::int32_t k = max_block_size / subgroup_size; + while (k > 1 && (num_blocks(k) % num_tiles != 0 || block_size(k) - size >= subgroup_size)) { + --k; + } + return k * subgroup_size; +} + +auto choose_block_size_multiple(std::int32_t min_block_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t { + auto const block_size = [&min_block_size](std::int32_t k) -> std::int32_t { + return k * min_block_size; + }; + auto const num_blocks = [&block_size, &size](std::int32_t k) -> std::int64_t { + return 1 + (size - 1) / block_size(k); + }; + std::int32_t k = 1; + while (2 * k * min_block_size < max_block_size) { + k *= 2; + } + while (k > 1 && (num_blocks(k) % num_tiles != 0 || block_size(k) - size >= min_block_size)) { + k /= 2; + } + return k; +} + +// Similar as compute_m_block_size for fixed sizes +// block_sizes array must be sorted in ascending order +auto choose_block_size(array_view block_sizes, std::int32_t num_tiles, + std::int64_t size) -> std::int32_t { + auto const num_blocks = [&size](std::int32_t block_size) -> std::int64_t { + return 1 + (size - 1) / block_size; + }; + std::size_t k = block_sizes.size() - 1; + while (k > 1 && (num_blocks(block_sizes[k]) % num_tiles != 0 || + block_sizes[k] - size >= block_sizes[0])) { + --k; + } + return block_sizes[k]; +} + +auto choose_k_block_size(array_view block_sizes, std::int64_t K) -> std::int32_t { + std::int64_t j = static_cast(block_sizes.size()) - 1; + for (; K < block_sizes[j] && j > 0; --j) { + } + return block_sizes[j]; +} + +} // namespace tinytc diff --git a/src/gemm_tools.hpp b/src/gemm_tools.hpp new file mode 100644 index 00000000..94c01445 --- /dev/null +++ b/src/gemm_tools.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GEMM_TOOLS_20241022_HPP +#define GEMM_TOOLS_20241022_HPP + +#include "tinytc/tinytc.hpp" + +#include +#include +#include + +namespace tinytc { + +constexpr static std::array standard_K_block_sizes = {1, 2, 4, 8}; + +/** + * @brief Calculate maximum register blocking size of GEMM + * + * @param A_size Size of scalar type of A matrix in bytes + * @param B_size Size of scalar type of B matrix in bytes + * @param C_size Size of scalar type of result matrix in bytes + * @param subgroup_size Subgroup size + * @param register_space Size of register file per core in bytes + * @param C_blocks Number of register blocks needed for C, usually 1, 2 for complex + * @param max_fill_fraction Fraction of register file that shall be blocked at most + * + * @return {number of rows, number of columns} + */ +auto max_register_block_gemm(std::int32_t A_size, std::int32_t B_size, std::int32_t C_size, + std::int32_t subgroup_size, std::int32_t register_space, + std::int32_t C_blocks = 1, + std::pair max_fill_fraction = {1, 2}) + -> std::pair; + +auto compute_m_block_size(std::int32_t subgroup_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t; +auto choose_block_size_multiple(std::int32_t min_block_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t; +auto choose_block_size(array_view block_sizes, std::int32_t num_tiles, + std::int64_t size) -> std::int32_t; +auto choose_k_block_size(array_view block_sizes, std::int64_t K) -> std::int32_t; + +} // namespace tinytc + +#endif // GEMM_TOOLS_20241022_HPP diff --git a/src/half.cpp b/src/half.cpp new file mode 100644 index 00000000..20d5e2d6 --- /dev/null +++ b/src/half.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" + +#include +#include + +using namespace tinytc; + +extern "C" { + +uint16_t tinytc_f32_to_f16_as_ui16(float x) { + return ieee754_truncate(std::bit_cast(x)); +} + +float tinytc_f16_as_ui16_to_f32(uint16_t x) { + const auto y = ieee754_extend(x); + return std::bit_cast(y); +} + +uint16_t tinytc_f32_to_bf16_as_ui16(float x) { + return ieee754_truncate(std::bit_cast(x)); +} + +float tinytc_bf16_as_ui16_to_f32(uint16_t x) { + const auto y = ieee754_extend(x); + return std::bit_cast(y); +} +} diff --git a/src/inst.cpp b/src/inst.cpp index 06f41460..d421b46a 100644 --- a/src/inst.cpp +++ b/src/inst.cpp @@ -3,25 +3,37 @@ #include "error.hpp" #include "location.hpp" +#include "node/data_type_node.hpp" #include "node/inst_node.hpp" -#include "slice.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "util.hpp" +#include "util/casting.hpp" #include +#include #include #include #include #include -#include -#include using namespace tinytc; extern "C" { + +char const *tinytc_address_space_to_string(tinytc_address_space_t as) { + switch (as) { + case tinytc_address_space_global: + return "global"; + case tinytc_address_space_local: + return "local"; + } + return "unknown"; +} + char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op) { switch (op) { case tinytc_arithmetic_add: @@ -44,16 +56,74 @@ char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op) { return "or"; case tinytc_arithmetic_xor: return "xor"; + case tinytc_arithmetic_min: + return "min"; + case tinytc_arithmetic_max: + return "max"; } return "unknown"; } char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op) { switch (op) { - case tinytc_arithmetic_unary_neg: - return "neg"; + case tinytc_arithmetic_unary_abs: + return "abs"; case tinytc_arithmetic_unary_not: return "not"; + case tinytc_arithmetic_unary_neg: + return "neg"; + case tinytc_arithmetic_unary_conj: + return "conj"; + case tinytc_arithmetic_unary_im: + return "im"; + case tinytc_arithmetic_unary_re: + return "re"; + } + return "unknown"; +} + +char const *tinytc_builtin_to_string(tinytc_builtin_t b) { + switch (b) { + case tinytc_builtin_group_id_x: + return "group_id.x"; + case tinytc_builtin_group_id_y: + return "group_id.y"; + case tinytc_builtin_group_id_z: + return "group_id.z"; + case tinytc_builtin_num_groups_x: + return "num_groups.x"; + case tinytc_builtin_num_groups_y: + return "num_groups.y"; + case tinytc_builtin_num_groups_z: + return "num_groups.z"; + case tinytc_builtin_num_subgroups_x: + return "num_subgroups.x"; + case tinytc_builtin_num_subgroups_y: + return "num_subgroups.y"; + case tinytc_builtin_subgroup_size: + return "subgroup_size"; + case tinytc_builtin_subgroup_id_x: + return "subgroup_id.x"; + case tinytc_builtin_subgroup_id_y: + return "subgroup_id.y"; + case tinytc_builtin_subgroup_linear_id: + return "subgroup_linear_id"; + case tinytc_builtin_subgroup_local_id: + return "subgroup_local_id"; + } + return "unknown"; +} + +char const *tinytc_checked_flag_to_string(tinytc_checked_flag_t flag) { + switch (flag) { + case tinytc_checked_flag_none: + return ""; + case tinytc_checked_flag_rows: + return "rows_checked"; + case tinytc_checked_flag_cols: + return "cols_checked"; + case tinytc_checked_flag_both: + return "both_checked"; } return "unknown"; } @@ -76,6 +146,78 @@ char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond) { return "unknown"; } +char const *tinytc_math_unary_to_string(tinytc_math_unary_t op) { + switch (op) { + case tinytc_math_unary_cos: + return "cos"; + case tinytc_math_unary_sin: + return "sin"; + case tinytc_math_unary_exp: + return "exp"; + case tinytc_math_unary_exp2: + return "exp2"; + case tinytc_math_unary_native_cos: + return "native_cos"; + case tinytc_math_unary_native_sin: + return "native_sin"; + case tinytc_math_unary_native_exp: + return "native_exp"; + case tinytc_math_unary_native_exp2: + return "native_exp2"; + } + return "unknown"; +} + +char const *tinytc_store_flag_to_string(tinytc_store_flag_t flag) { + switch (flag) { + case tinytc_store_flag_regular: + return ""; + case tinytc_store_flag_atomic: + return "atomic"; + case tinytc_store_flag_atomic_add: + return "atomic_add"; + case tinytc_store_flag_atomic_max: + return "atomic_max"; + case tinytc_store_flag_atomic_min: + return "atomic_min"; + } + return "unknown"; +} + +char const *tinytc_group_arithmetic_to_string(tinytc_group_arithmetic_t op) { + switch (op) { + case tinytc_group_arithmetic_add: + return "add"; + case tinytc_group_arithmetic_max: + return "max"; + case tinytc_group_arithmetic_min: + return "min"; + } + return "unknown"; +} + +char const *tinytc_group_operation_to_string(tinytc_group_operation_t op) { + switch (op) { + case tinytc_group_operation_exclusive_scan: + return "exclusive_scan"; + case tinytc_group_operation_inclusive_scan: + return "inclusive_scan"; + case tinytc_group_operation_reduce: + return "reduce"; + } + return "unknown"; +} + +char const *tinytc_reduce_mode_to_string(tinytc_reduce_mode_t m) { + switch (m) { + case tinytc_reduce_mode_row: + return "row"; + case tinytc_reduce_mode_column: + return "column"; + } + return "unknown"; +} + char const *tinytc_transpose_to_string(tinytc_transpose_t t) { switch (t) { case tinytc_transpose_T: @@ -87,65 +229,320 @@ char const *tinytc_transpose_to_string(tinytc_transpose_t t) { } tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_t op, - tinytc_value_t a, tinytc_value_t b, + tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), value(a, true), - value(b, true), get_optional(loc)) - .release(); + *instr = + std::make_unique(enum_cast(op), a, b, ty, get_optional(loc)) + .release(); }); } tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_unary_t op, - tinytc_value_t a, const tinytc_location_t *loc) { + tinytc_value_t a, tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), value(a, true), + *instr = std::make_unique(enum_cast(op), a, ty, get_optional(loc)) .release(); }); } +tinytc_status_t tinytc_barrier_inst_create(tinytc_inst_t *instr, + tinytc_address_spaces_t fence_flags, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(fence_flags, get_optional(loc)).release(); }); +} + tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - tinytc_scalar_type_t to_ty, const tinytc_location_t *loc) { + tinytc_data_type_t to_ty, const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(a, to_ty, get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_condition_t cond, + tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), enum_cast(to_ty), - get_optional(loc)) + *instr = std::make_unique(enum_cast(cond), a, b, ty, + get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_condition_t cond, - tinytc_value_t a, tinytc_value_t b, - const tinytc_location_t *loc) { +tinytc_status_t tinytc_constant_inst_create_boolean(tinytc_inst_t *instr, tinytc_bool_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(cond), value(a, true), - value(b, true), get_optional(loc)) + *instr = std::make_unique(value != 0, ty, get_optional(loc)).release(); + }); +} + +tinytc_status_t tinytc_constant_inst_create_complex(tinytc_inst_t *instr, double value_re, + double value_im, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(std::complex(value_re, value_im), ty, + get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type_t ty, - const tinytc_location_t *loc) { +tinytc_status_t tinytc_constant_inst_create_float(tinytc_inst_t *instr, double value, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(value, ty, get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_constant_inst_create_int(tinytc_inst_t *instr, int64_t value, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(value, ty, get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + + if (const auto *bt = dyn_cast(ty); bt != nullptr) { + return exception_to_status_code([&] { + *instr = std::make_unique(true, ty, get_optional(loc)).release(); + }); + } + + scalar_type sty; + if (const auto *st = dyn_cast(ty); st != nullptr) { + sty = st->ty(); + } else if (const auto *ct = dyn_cast(ty); ct != nullptr) { + sty = ct->component_ty(); + } else { + return tinytc_status_invalid_arguments; + } + + return exception_to_status_code([&] { + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + *instr = + std::make_unique(std::int64_t{1}, ty, get_optional(loc)).release(); + break; + case scalar_type::bf16: + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + *instr = std::make_unique(double{1}, ty, get_optional(loc)).release(); + break; + case scalar_type::c32: + case scalar_type::c64: + *instr = std::make_unique(std::complex{1}, ty, get_optional(loc)) + .release(); + break; + } + }); +} + +tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } + + if (const auto *bt = dyn_cast(ty); bt != nullptr) { + return exception_to_status_code([&] { + *instr = std::make_unique(false, ty, get_optional(loc)).release(); + }); + } + + scalar_type sty; + if (const auto *st = dyn_cast(ty); st != nullptr) { + sty = st->ty(); + } else if (const auto *ct = dyn_cast(ty); ct != nullptr) { + sty = ct->component_ty(); + } else { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { - *instr = std::make_unique(data_type(ty, true), get_optional(loc)).release(); + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + *instr = + std::make_unique(std::int64_t{0}, ty, get_optional(loc)).release(); + break; + case scalar_type::bf16: + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + *instr = std::make_unique(double{0}, ty, get_optional(loc)).release(); + break; + case scalar_type::c32: + case scalar_type::c64: + *instr = std::make_unique(std::complex{0}, ty, get_optional(loc)) + .release(); + break; + } }); } +tinytc_status_t tinytc_cooperative_matrix_apply_inst_create(tinytc_inst_t *instr, + tinytc_value_t mat, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr || mat == nullptr || ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = + std::make_unique(mat, ty, get_optional(loc)).release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_extract_inst_create(tinytc_inst_t *instr, + tinytc_value_t mat, int64_t index, + tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr || mat == nullptr || ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = + std::make_unique(mat, index, ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_insert_inst_create(tinytc_inst_t *instr, + tinytc_value_t val, tinytc_value_t mat, + int64_t index, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr || val == nullptr || mat == nullptr || ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = + std::make_unique(val, mat, index, ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_load_inst_create( + tinytc_inst_t *instr, tinytc_transpose_t trans, tinytc_checked_flag_t flag, tinytc_value_t op, + tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, const tinytc_location_t *loc) { + if (instr == nullptr || op == nullptr || p0 == nullptr || p1 == nullptr || to_ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(enum_cast(trans), + enum_cast(flag), op, + p0, p1, to_ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_mul_add_inst_create(tinytc_inst_t *instr, + tinytc_value_t a, tinytc_value_t b, + tinytc_value_t c, + tinytc_data_type_t to_ty, + const tinytc_location_t *loc) { + if (instr == nullptr || a == nullptr || b == nullptr || c == nullptr || to_ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = + std::make_unique(a, b, c, to_ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_prefetch_inst_create( + tinytc_inst_t *instr, int32_t cache_level, tinytc_value_t op, tinytc_value_t p0, + tinytc_value_t p1, int32_t rows, int32_t cols, const tinytc_location_t *loc) { + if (instr == nullptr || op == nullptr || p0 == nullptr || p1 == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(cache_level, op, p0, p1, rows, + cols, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_scale_inst_create(tinytc_inst_t *instr, tinytc_value_t a, + tinytc_value_t b, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr || a == nullptr || b == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = + std::make_unique(a, b, ty, get_optional(loc)).release(); + }); +} + +tinytc_status_t tinytc_cooperative_matrix_store_inst_create(tinytc_inst_t *instr, + tinytc_checked_flag_t cflag, + tinytc_store_flag_t sflag, + tinytc_value_t val, tinytc_value_t op, + tinytc_value_t p0, tinytc_value_t p1, + const tinytc_location_t *loc) { + if (instr == nullptr || val == nullptr || op == nullptr || p0 == nullptr || p1 == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(enum_cast(cflag), + enum_cast(sflag), val, + op, p0, p1, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(ty, get_optional(loc)).release(); }); +} + tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, tinytc_value_t B, @@ -154,72 +551,77 @@ tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_ return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(beta, true), value(B, true), + *instr = std::make_unique(enum_cast(tA), alpha, A, beta, B, bool(atomic), get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, - uint32_t expand_shape_size, tinytc_value_t *expand_shape, - const tinytc_location_t *loc) { - if (instr == nullptr || expand_shape == nullptr) { +tinytc_status_t tinytc_builtin_inst_create(tinytc_inst_t *instr, tinytc_builtin_t btype, + tinytc_data_type_t ty, const tinytc_location_t *loc) { + if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto eshape_vec = std::vector(); - eshape_vec.reserve(expand_shape_size); - for (uint32_t i = 0; i < expand_shape_size; ++i) { - eshape_vec.emplace_back(value(expand_shape[i], true)); - } - *instr = std::make_unique(value(a, true), mode, std::move(eshape_vec), - get_optional(loc)) + *instr = std::make_unique(enum_cast(btype), ty, get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t from, - int64_t to, const tinytc_location_t *loc) { +tinytc_status_t tinytc_cumsum_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomic, + tinytc_value_t alpha, tinytc_value_t A, int64_t mode, + tinytc_value_t beta, tinytc_value_t B, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), from, to, get_optional(loc)).release(); + *instr = + std::make_unique(alpha, A, mode, beta, B, bool(atomic), get_optional(loc)) + .release(); }); } -tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t index_list_size, tinytc_value_t *index_list, - const tinytc_location_t *loc) { - if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { +tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, + int64_t expanded_mode, uint32_t static_expand_shape_size, + const int64_t *static_expand_shape, + uint32_t expand_shape_size, + const tinytc_value_t *expand_shape, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr || static_expand_shape == nullptr || + (expand_shape_size > 0 && expand_shape == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto il_vec = std::vector(); - il_vec.reserve(index_list_size); - for (uint32_t i = 0; i < index_list_size; ++i) { - il_vec.emplace_back(value(index_list[i], true)); - } - *instr = std::make_unique(value(a, true), std::move(il_vec), get_optional(loc)) + *instr = std::make_unique( + a, expanded_mode, array_view{static_expand_shape, static_expand_shape_size}, + array_view{expand_shape, expand_shape_size}, ty, get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { +tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t from, + int64_t to, tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + return exception_to_status_code([&] { + *instr = std::make_unique(a, from, to, ty, get_optional(loc)).release(); + }); } -tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { - if (instr == nullptr) { +tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, + uint32_t index_list_size, const tinytc_value_t *index_list, + tinytc_data_type_t ty, const tinytc_location_t *loc) { + if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { return tinytc_status_invalid_arguments; } - return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); + return exception_to_status_code([&] { + *instr = std::make_unique(a, array_view{index_list, index_list_size}, ty, + get_optional(loc)) + .release(); + }); } tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, @@ -232,9 +634,7 @@ tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tinytc_transpose_t } return exception_to_status_code([&] { *instr = std::make_unique(enum_cast(tA), enum_cast(tB), - value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) + alpha, A, B, beta, C, bool(atomic), get_optional(loc)) .release(); }); } @@ -247,9 +647,8 @@ tinytc_status_t tinytc_gemv_inst_create(tinytc_inst_t *instr, tinytc_transpose_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(B, true), value(beta, true), - value(C, true), bool(atomic), get_optional(loc)) + *instr = std::make_unique(enum_cast(tA), alpha, A, B, beta, C, + bool(atomic), get_optional(loc)) .release(); }); } @@ -262,9 +661,7 @@ tinytc_status_t tinytc_ger_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomi return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) + *instr = std::make_unique(alpha, A, B, beta, C, bool(atomic), get_optional(loc)) .release(); }); } @@ -277,58 +674,103 @@ tinytc_status_t tinytc_hadamard_inst_create(tinytc_inst_t *instr, tinytc_bool_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) - .release(); + *instr = + std::make_unique(alpha, A, B, beta, C, bool(atomic), get_optional(loc)) + .release(); }); } -tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, - const tinytc_location_t *loc) { +tinytc_status_t tinytc_math_unary_inst_create(tinytc_inst_t *instr, tinytc_math_unary_t op, + tinytc_value_t a, tinytc_data_type_t ty, + const tinytc_location_t *loc) { if (instr == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), mode, get_optional(loc)).release(); + *instr = + std::make_unique(enum_cast(op), a, ty, get_optional(loc)) + .release(); }); } -tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t slice_list_size, tinytc_value_t *offset_list, - tinytc_value_t *size_list, - const tinytc_location_t *loc) { +tinytc_status_t tinytc_parallel_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, + tinytc_data_type_t ty, const tinytc_location_t *loc) { + if (instr == nullptr || a == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = std::make_unique(a, mode, ty, get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_subgroup_broadcast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, + tinytc_value_t idx, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr || a == nullptr || idx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(a, idx, ty, get_optional(loc)).release(); + }); +} + +tinytc_status_t tinytc_subgroup_operation_inst_create(tinytc_inst_t *instr, + tinytc_group_arithmetic_t arith, + tinytc_group_operation_t operation, + tinytc_value_t a, tinytc_data_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr || a == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = std::make_unique(enum_cast(arith), + enum_cast(operation), a, + ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t +tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, uint32_t static_list_size, + const int64_t *static_offset_list, const int64_t *static_size_list, + uint32_t offset_list_size, const tinytc_value_t *offset_list, + uint32_t size_list_size, const tinytc_value_t *size_list, + tinytc_data_type_t ty, const tinytc_location_t *loc) { if (instr == nullptr || - (slice_list_size > 0 && (offset_list == nullptr || size_list == nullptr))) { + (static_list_size > 0 && (static_offset_list == nullptr || static_size_list == nullptr)) || + (offset_list_size > 0 && offset_list == nullptr) || + (size_list_size > 0 && size_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto slice_vec = std::vector(); - slice_vec.reserve(slice_list_size); - for (uint32_t i = 0; i < slice_list_size; ++i) { - slice_vec.emplace_back(value(offset_list[i], true), value(size_list[i], true)); - } - *instr = - std::make_unique(value(a, true), std::move(slice_vec), get_optional(loc)) - .release(); + *instr = std::make_unique(a, array_view{static_offset_list, static_list_size}, + array_view{static_size_list, static_list_size}, + array_view{offset_list, offset_list_size}, + array_view{size_list, size_list_size}, ty, + get_optional(loc)) + .release(); }); } -tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, tinytc_value_t a, - uint32_t index_list_size, tinytc_value_t *index_list, +tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_store_flag_t flag, + tinytc_value_t val, tinytc_value_t a, + uint32_t index_list_size, const tinytc_value_t *index_list, const tinytc_location_t *loc) { if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto il_vec = std::vector(); - il_vec.reserve(index_list_size); - for (uint32_t i = 0; i < index_list_size; ++i) { - il_vec.emplace_back(value(index_list[i], true)); - } - *instr = std::make_unique(value(val, true), value(a, true), std::move(il_vec), - get_optional(loc)) - .release(); + *instr = + std::make_unique(enum_cast(flag), val, a, + array_view{index_list, index_list_size}, get_optional(loc)) + .release(); }); } @@ -340,129 +782,135 @@ tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(beta, true), value(B, true), + *instr = std::make_unique(enum_cast(tA), alpha, A, beta, B, bool(atomic), get_optional(loc)) .release(); }); } -tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, +tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - tinytc_region_t body, const tinytc_location_t *loc) { - if (instr == nullptr || loop_var == nullptr || from == nullptr || to == nullptr || - body == nullptr) { + uint32_t init_return_list_size, + const tinytc_value_t *initial_value_list, + const tinytc_data_type_t *return_type_list, + const tinytc_location_t *loc) { + if (instr == nullptr || loop_var_type == nullptr || from == nullptr || to == nullptr || + (init_return_list_size != 0 && + (initial_value_list == nullptr || return_type_list == nullptr))) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(value(loop_var, true), value(from, true), value(to, true), - value(step, true), region(body, true), get_optional(loc)) - .release(); + *instr = std::make_unique(loop_var_type, from, to, step, + array_view{initial_value_list, init_return_list_size}, + array_view{return_type_list, init_return_list_size}, + get_optional(loc)) + .release(); }); } -tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_region_t body, const tinytc_location_t *loc) { - if (instr == nullptr || loop_var == nullptr || from == nullptr || to == nullptr || - body == nullptr) { +tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_data_type_t loop_var_type, + uint32_t dim, const tinytc_value_t *from_list, + const tinytc_value_t *to_list, + const tinytc_location_t *loc) { + if (instr == nullptr || loop_var_type == nullptr || from_list == nullptr || + to_list == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *instr = - std::make_unique(value(loop_var, true), value(from, true), - value(to, true), region(body, true), get_optional(loc)) - .release(); + *instr = std::make_unique(loop_var_type, array_view{from_list, dim}, + array_view{to_list, dim}, get_optional(loc)) + .release(); }); } tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, - tinytc_region_t then, tinytc_region_t otherwise, uint32_t return_type_list_size, - tinytc_scalar_type_t *return_type_list, + const tinytc_data_type_t *return_type_list, const tinytc_location_t *loc) { - if (instr == nullptr || condition == nullptr || then == nullptr || + if (instr == nullptr || condition == nullptr || (return_type_list_size > 0 && return_type_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto rt = std::vector(); - rt.reserve(return_type_list_size); - for (uint32_t i = 0; i < return_type_list_size; ++i) { - rt.emplace_back(enum_cast(return_type_list[i])); - } - *instr = - std::make_unique(value(condition, true), region(then, true), - region(otherwise, true), std::move(rt), get_optional(loc)) - .release(); + *instr = std::make_unique(condition, + array_view{return_type_list, return_type_list_size}, + get_optional(loc)) + .release(); }); } tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_list_size, - tinytc_value_t *yield_list, const tinytc_location_t *loc) { - if (instr == nullptr || yield_list_size == 0 || yield_list == nullptr) { + const tinytc_value_t *yield_list, + const tinytc_location_t *loc) { + if (instr == nullptr || (yield_list_size != 0 && yield_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto yl = std::vector(); - yl.reserve(yield_list_size); - for (uint32_t i = 0; i < yield_list_size; ++i) { - yl.emplace_back(value(yield_list[i], true)); - } - *instr = std::make_unique(std::move(yl), get_optional(loc)).release(); + *instr = + std::make_unique(array_view{yield_list, yield_list_size}, get_optional(loc)) + .release(); }); } -tinytc_status_t tinytc_inst_release(tinytc_inst_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} +void tinytc_inst_destroy(tinytc_inst_t obj) { delete obj; } -tinytc_status_t tinytc_inst_retain(tinytc_inst_t obj) { - if (obj == nullptr) { +tinytc_status_t tinytc_inst_get_parent_region(tinytc_inst_t instr, tinytc_region_t *parent) { + if (instr == nullptr || parent == nullptr) { return tinytc_status_invalid_arguments; } - obj->inc_ref(); - return tinytc_status_success; + return exception_to_status_code([&] { *parent = instr->parent(); }); } -tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, tinytc_value_t *result) { - if (instr == nullptr || result == nullptr) { +tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, uint32_t *result_list_size, + tinytc_value_t *result_list) { + if (instr == nullptr || result_list_size == nullptr || + (*result_list_size > 0 && result_list == nullptr)) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { *result = instr->result().release(); }); + return exception_to_status_code([&] { + auto const num_results = instr->num_results(); + if (num_results > std::numeric_limits::max()) { + throw std::out_of_range("too many results"); + } + auto num = static_cast(num_results); + if (*result_list_size > 0) { + num = std::min(num, *result_list_size); + auto results = instr->result_begin(); + for (uint32_t i = 0; i < num; ++i) { + result_list[i] = &results[i]; + } + } + *result_list_size = num; + }); } -tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, uint32_t *result_list_size, - tinytc_value_t *result_list) { +tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, uint32_t *result_list_size, + tinytc_region_t *result_list) { if (instr == nullptr || result_list_size == nullptr || (*result_list_size > 0 && result_list == nullptr)) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto const num_results = instr->num_results(); + auto const num_results = instr->num_child_regions(); if (num_results > std::numeric_limits::max()) { throw std::out_of_range("too many results"); } - auto const num = static_cast(num_results); + auto num = static_cast(num_results); if (*result_list_size > 0) { - auto results = instr->results(); - if (results.size() != num_results) { - throw internal_compiler_error(); - } - auto const limit = std::min(num, *result_list_size); - for (uint32_t i = 0; i < limit; ++i) { - result_list[i] = results[i].release(); + auto results = instr->child_regions_begin(); + num = std::min(num, *result_list_size); + for (uint32_t i = 0; i < num; ++i) { + result_list[i] = &results[i]; } } *result_list_size = num; }); } + +tinytc_status_t tinytc_inst_set_attr(tinytc_inst_t instr, tinytc_attr_t a) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { instr->attr(a); }); +} } diff --git a/src/matrix_ext_info.cpp b/src/matrix_ext_info.cpp new file mode 100644 index 00000000..0e3a1d44 --- /dev/null +++ b/src/matrix_ext_info.cpp @@ -0,0 +1,170 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "matrix_ext_info.hpp" +#include "node/data_type_node.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +matrix_ext_type::matrix_ext_type(scalar_type a, scalar_type b, std::vector acc, + std::vector mnk) + : a_{a}, b_{b}, acc_(std::move(acc)), mnk_(std::move(mnk)) {} + +auto matrix_ext_type::have_acc(scalar_type acc) const -> bool { + return std::find(acc_.begin(), acc_.end(), acc) != acc_.end(); +} +auto matrix_ext_type::have_type(scalar_type sty, std::int64_t rows, std::int64_t cols, + matrix_use use) const -> bool { + auto find_shape = [&rows, &cols](array_view mnks, std::int64_t gemm_mnk::*shape0, + std::int64_t gemm_mnk::*shape1) { + for (auto const &mnk : mnks) { + if (mnk.*shape0 == rows && mnk.*shape1 == cols) { + return true; + } + } + return false; + }; + switch (use) { + case matrix_use::a: + return a() == sty && find_shape(mnk(), &gemm_mnk::M, &gemm_mnk::K); + case matrix_use::b: + return b() == sty && find_shape(mnk(), &gemm_mnk::K, &gemm_mnk::N); + case matrix_use::acc: + return have_acc(sty) && find_shape(mnk(), &gemm_mnk::M, &gemm_mnk::N); + } + throw status::internal_compiler_error; +} + +template +auto block_sizes(std::vector const &mnks, Get get) -> std::vector { + auto bs = std::vector{}; + bs.reserve(mnks.size()); + for (auto &mnk : mnks) { + auto val = get(mnk); + if (val) { + bs.push_back(*val); + } + } + std::sort(bs.begin(), bs.end()); + auto end = std::unique(bs.begin(), bs.end()); + bs.erase(end, bs.end()); + return bs; +} + +auto matrix_ext_type::M_block_sizes() const -> std::vector { + return block_sizes(mnk_, + [](gemm_mnk const &mnk) -> std::optional { return mnk.M; }); +} +auto matrix_ext_type::N_block_sizes(std::int32_t M) const -> std::vector { + return block_sizes(mnk_, [&M](gemm_mnk const &mnk) -> std::optional { + return mnk.M == M ? std::make_optional(mnk.N) : std::nullopt; + }); +} +auto matrix_ext_type::K_block_sizes(std::int32_t M, std::int32_t N) const + -> std::vector { + return block_sizes(mnk_, [&M, N](gemm_mnk const &mnk) -> std::optional { + return mnk.M == M && mnk.N == N ? std::make_optional(mnk.K) : std::nullopt; + }); +} + +auto matrix_ext_info::get_precision(scalar_type a, scalar_type b, scalar_type acc) const + -> matrix_ext_type const * { + for (auto const &type : types_) { + if (type.a() == a && type.b() == b && type.have_acc(acc)) { + return &type; + } + } + return nullptr; +} + +auto matrix_ext_info::have_gemm(scalar_type a, scalar_type b, scalar_type c, scalar_type d, + std::int64_t M, std::int64_t N, std::int64_t K) const -> bool { + for (auto const &type : types_) { + if (type.have_type(a, M, K, matrix_use::a) && type.have_type(b, K, N, matrix_use::b) && + type.have_type(c, M, N, matrix_use::acc) && type.have_type(d, M, N, matrix_use::acc)) { + return true; + } + } + return false; +} + +auto matrix_ext_info::have_precision(scalar_type a, scalar_type b, scalar_type acc) const -> bool { + return get_precision(a, b, acc) != nullptr; +} + +auto matrix_ext_info::have_type(scalar_type sty, std::int64_t rows, std::int64_t cols, + matrix_use use) const -> bool { + for (auto const &type : types_) { + if (type.have_type(sty, rows, cols, use)) { + return true; + } + } + return false; +} + +auto matrix_ext_info::have_type(const coopmatrix_data_type *ty) const -> bool { + return have_type(ty->component_ty(), ty->rows(), ty->cols(), ty->use()); +} + +const std::array pvc_matrix_ext_types = { + {{scalar_type::i8, + scalar_type::i8, + {scalar_type::i32}, + {{16, 8, 32}, + {32, 8, 32}, + {64, 8, 32}, + {16, 16, 32}, + {32, 16, 32}, + {64, 16, 32}, + {16, 32, 32}, + {32, 32, 32}, + {64, 32, 32}, + {16, 8, 64}, + {32, 8, 64}, + {64, 8, 64}, + {16, 16, 64}, + {32, 16, 64}, + {64, 16, 64}, + {16, 32, 64}, + {32, 32, 64}, + {64, 32, 64}}}, + {scalar_type::f16, + scalar_type::f16, + {scalar_type::f16, scalar_type::f32}, + {{16, 8, 16}, + {32, 8, 16}, + {16, 16, 16}, + {32, 16, 16}, + {16, 24, 16}, + {32, 24, 16}, + {16, 32, 16}, + {32, 32, 16}, + {16, 8, 32}, + {32, 8, 32}, + {16, 16, 32}, + {32, 16, 32}, + {16, 24, 32}, + {32, 24, 32}, + {16, 32, 32}, + {32, 32, 32}}}, + {scalar_type::bf16, + scalar_type::bf16, + {scalar_type::bf16, scalar_type::f32}, + {{16, 8, 16}, + {32, 8, 16}, + {16, 16, 16}, + {32, 16, 16}, + {16, 32, 16}, + {32, 32, 16}, + {16, 8, 32}, + {32, 8, 32}, + {16, 16, 32}, + {32, 16, 32}, + {16, 32, 32}, + {32, 32, 32}}}}}; + +} // namespace tinytc diff --git a/src/matrix_ext_info.hpp b/src/matrix_ext_info.hpp new file mode 100644 index 00000000..20a24b94 --- /dev/null +++ b/src/matrix_ext_info.hpp @@ -0,0 +1,88 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef MATRIX_EXT_INFO_20241204_HPP +#define MATRIX_EXT_INFO_20241204_HPP + +#include "tinytc/tinytc.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +class coopmatrix_data_type; +enum class matrix_use; +enum class scalar_type; + +struct gemm_mnk { + std::int64_t M, N, K; +}; + +class matrix_ext_type { + public: + matrix_ext_type(scalar_type a, scalar_type b, std::vector acc, + std::vector mnk); + + inline auto a() const -> scalar_type { return a_; } + inline auto b() const -> scalar_type { return b_; } + inline auto acc() const -> array_view { return acc_; } + inline auto mnk() const -> array_view { return mnk_; } + + auto M_block_sizes() const -> std::vector; + auto N_block_sizes(std::int32_t M) const -> std::vector; + auto K_block_sizes(std::int32_t M, std::int32_t N) const -> std::vector; + + auto have_acc(scalar_type acc) const -> bool; + auto have_type(scalar_type sty, std::int64_t rows, std::int64_t cols, matrix_use use) const + -> bool; + + private: + scalar_type a_, b_; + std::vector acc_; + std::vector mnk_; +}; + +struct matrix_ext_block_io_info { + std::int32_t base_address_alignment; + std::int32_t min_stride; + std::int32_t max_stride; + std::int32_t pos0_alignment; + std::int32_t stride_alignment; + std::int32_t width_alignment; +}; + +class matrix_ext_info { + public: + matrix_ext_info() = default; + inline matrix_ext_info(std::int32_t required_subgroup_size, matrix_ext_block_io_info block_io, + array_view types) + : required_sgs_{required_subgroup_size}, block_io_{block_io}, types_(std::move(types)) {} + + auto get_precision(scalar_type a, scalar_type b, scalar_type acc) const + -> matrix_ext_type const *; + auto have_gemm(scalar_type a, scalar_type b, scalar_type c, scalar_type d, std::int64_t M, + std::int64_t N, std::int64_t K) const -> bool; + auto have_precision(scalar_type a, scalar_type b, scalar_type acc) const -> bool; + auto have_type(scalar_type sty, std::int64_t rows, std::int64_t cols, matrix_use use) const + -> bool; + auto have_type(const coopmatrix_data_type *ty) const -> bool; + + inline auto required_subgroup_size() const -> std::int32_t { return required_sgs_; } + inline auto block_io() const -> matrix_ext_block_io_info const & { return block_io_; } + + inline auto have_dpas() const { return types_.size() > 0; } + + private: + std::int32_t required_sgs_; + matrix_ext_block_io_info block_io_; + array_view types_; +}; + +extern const std::array pvc_matrix_ext_types; + +} // namespace tinytc + +#endif // MATRIX_EXT_INFO_20241204_HPP diff --git a/src/node/attr_node.cpp b/src/node/attr_node.cpp new file mode 100644 index 00000000..efee6ed4 --- /dev/null +++ b/src/node/attr_node.cpp @@ -0,0 +1,160 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/attr_node.hpp" +#include "compiler_context.hpp" +#include "compiler_context_cache.hpp" +#include "support/fnv1a_array_view.hpp" // IWYU pragma: keep +#include "util/casting.hpp" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +auto array_attr::get(tinytc_compiler_context_t ctx, array_view values) + -> tinytc_attr_t { + const auto hash = fnv1a_combine(values); + const auto is_equal = [&](tinytc_attr_t a) { + const auto aa = dyn_cast(a); + return aa && + std::equal(values.begin(), values.end(), aa->values().begin(), aa->values().end()); + }; + const auto make = [&]() { return new array_attr(ctx, values); }; + + auto &attrs = ctx->cache()->array_attrs; + return attrs.get(hash, std::move(is_equal), std::move(make)); +} + +array_attr::array_attr(tinytc_compiler_context_t ctx, std::vector values) + : tinytc_attr(AK::array, ctx), values_{std::move(values)} {} + +auto boolean_attr::get(tinytc_compiler_context_t ctx, bool value) -> tinytc_attr_t { + auto cache = ctx->cache(); + return value ? cache->true_attr.get() : cache->false_attr.get(); +} + +boolean_attr::boolean_attr(tinytc_compiler_context_t ctx, bool value) + : tinytc_attr(AK::boolean, ctx), value_{value} {} + +auto dictionary_attr::get(tinytc_compiler_context_t ctx, array_view sorted_attrs) + -> tinytc_attr_t { + const auto hash = [&] { + auto h = fnv1a0(); + for (auto const &na : sorted_attrs) { + h = fnv1a_step(h, na.name); + h = fnv1a_step(h, na.attr); + } + return h; + }; + const auto is_equal = [&](tinytc_attr_t a) { + const auto da = dyn_cast(a); + return da && std::equal(sorted_attrs.begin(), sorted_attrs.end(), da->begin(), da->end(), + [](named_attr const &a, named_attr const &b) { + return a.name == b.name && a.attr == b.attr; + }); + }; + const auto make = [&]() { return new dictionary_attr(ctx, sorted_attrs); }; + + auto &attrs = ctx->cache()->dictionary_attrs; + return attrs.get(hash(), std::move(is_equal), std::move(make)); +} + +auto dictionary_attr::get_name_string(tinytc_attr_t name) -> std::string_view { + auto stra = dyn_cast(name); + if (stra) { + return stra->str(); + } + throw status::ir_expected_string_attribute; +} + +void dictionary_attr::sort(mutable_array_view unsorted_attrs) { + if (unsorted_attrs.empty()) { + return; + } + std::sort(unsorted_attrs.begin(), unsorted_attrs.end(), + [](named_attr const &a, named_attr const &b) { + return get_name_string(a.name) < get_name_string(b.name); + }); + for (std::size_t i = 1; i < unsorted_attrs.size(); ++i) { + if (unsorted_attrs[i - 1].name == unsorted_attrs[i].name) { + throw status::ir_duplicate_key_in_dictionary; + } + } +} + +dictionary_attr::dictionary_attr(tinytc_compiler_context_t ctx, + std::vector sorted_attrs) + : tinytc_attr(AK::dictionary, ctx), attrs_{std::move(sorted_attrs)} {} + +auto dictionary_attr::find(tinytc_attr_t name) -> tinytc_attr_t { + if (!attrs_.empty() && name) { + auto namestr = get_name_string(name); + std::size_t lb = 0; + std::size_t ub = attrs_.size(); + + do { + std::size_t mid = (lb + ub) / 2; + auto cmp = namestr.compare(get_name_string(attrs_[mid].name)); + if (cmp == 0) { + return attrs_[mid].attr; + } else if (cmp < 0) { + ub = mid; + } else { + lb = mid + 1; + } + } while (ub > lb); + } + return nullptr; +} +auto dictionary_attr::find(std::string_view name) -> tinytc_attr_t { + return find(string_attr::get(context(), name)); +} + +auto integer_attr::get(tinytc_compiler_context_t ctx, std::int64_t value) -> tinytc_attr_t { + const auto hash = fnv1a_combine(value); + const auto is_equal = [&](tinytc_attr_t a) { + const auto ia = dyn_cast(a); + return ia && value == ia->value(); + }; + const auto make = [&]() { return new integer_attr(ctx, value); }; + + auto &attrs = ctx->cache()->integer_attrs; + return attrs.get(hash, std::move(is_equal), std::move(make)); +} + +integer_attr::integer_attr(tinytc_compiler_context_t ctx, std::int64_t value) + : tinytc_attr(AK::integer, ctx), value_{value} {} + +auto string_attr::get(tinytc_compiler_context_t ctx, std::string_view str) -> tinytc_attr_t { + const auto hash = fnv1a_combine(str); + const auto is_equal = [&](tinytc_attr_t a) { + const auto stra = dyn_cast(a); + return stra && str == stra->str(); + }; + const auto make = [&]() { return new string_attr(ctx, std::string{str}); }; + + auto &attrs = ctx->cache()->string_attrs; + return attrs.get(hash, std::move(is_equal), std::move(make)); +} + +string_attr::string_attr(tinytc_compiler_context_t ctx, std::string str) + : tinytc_attr(AK::string, ctx), str_{std::move(str)} {} + +auto get_attr(tinytc_attr_t dict, tinytc_attr_t name) -> tinytc_attr_t { + if (auto da = dyn_cast(dict); da) { + return da->find(name); + } + return nullptr; +} +auto get_attr(tinytc_attr_t dict, std::string_view name) -> tinytc_attr_t { + if (auto da = dyn_cast(dict); da) { + return da->find(name); + } + return nullptr; +} + +} // namespace tinytc diff --git a/src/node/attr_node.hpp b/src/node/attr_node.hpp new file mode 100644 index 00000000..24eb3582 --- /dev/null +++ b/src/node/attr_node.hpp @@ -0,0 +1,150 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ATTR_NODE_20250205_HPP +#define ATTR_NODE_20250205_HPP + +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/type_list.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { +enum class AK { array, boolean, dictionary, integer, string }; +using attr_nodes = type_list; +} // namespace tinytc + +struct tinytc_attr { + public: + using leaves = tinytc::attr_nodes; + + inline tinytc_attr(tinytc::AK tid, tinytc_compiler_context_t ctx) : tid_(tid), ctx_(ctx) {} + virtual ~tinytc_attr() = default; + inline auto type_id() const -> tinytc::AK { return tid_; } + inline auto context() const -> tinytc_compiler_context_t { return ctx_; } + + private: + tinytc::AK tid_; + tinytc_compiler_context_t ctx_; +}; + +namespace tinytc { + +class array_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::array; } + static auto get(tinytc_compiler_context_t ctx, array_view values) + -> tinytc_attr_t; + + inline auto begin() const -> std::vector::const_iterator { + return values_.begin(); + } + inline auto end() const -> std::vector::const_iterator { return values_.end(); } + inline auto size() const { return values_.size(); } + inline auto const &values() const { return values_; } + inline auto value(std::size_t i) const -> tinytc_attr_t { return values_[i]; } + + protected: + array_attr(tinytc_compiler_context_t ctx, std::vector values); + + private: + std::vector values_; +}; + +class boolean_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::boolean; } + static auto get(tinytc_compiler_context_t ctx, bool value) -> tinytc_attr_t; + + inline auto value() const { return value_; } + + protected: + boolean_attr(tinytc_compiler_context_t ctx, bool value); + friend class compiler_context_cache; + + private: + bool value_; +}; + +class dictionary_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::dictionary; } + static auto get(tinytc_compiler_context_t ctx, array_view sorted_attrs) + -> tinytc_attr_t; + static void sort(mutable_array_view unsorted_attrs); + + inline auto begin() const -> std::vector::const_iterator { return attrs_.begin(); } + inline auto end() const -> std::vector::const_iterator { return attrs_.end(); } + inline auto const &attrs() const { return attrs_; } + + auto find(tinytc_attr_t name) -> tinytc_attr_t; + auto find(std::string_view name) -> tinytc_attr_t; + + protected: + dictionary_attr(tinytc_compiler_context_t ctx, std::vector sorted_attrs); + + private: + static auto get_name_string(tinytc_attr_t name) -> std::string_view; + + std::vector attrs_; +}; + +class integer_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::integer; } + static auto get(tinytc_compiler_context_t ctx, std::int64_t value) -> tinytc_attr_t; + + inline auto value() const -> std::int64_t { return value_; } + + protected: + integer_attr(tinytc_compiler_context_t ctx, std::int64_t value); + + private: + std::int64_t value_; +}; + +class string_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::string; } + static auto get(tinytc_compiler_context_t ctx, std::string_view str) -> tinytc_attr_t; + + inline auto str() const -> std::string_view { return str_; } + + protected: + string_attr(tinytc_compiler_context_t ctx, std::string str); + + private: + std::string str_; +}; + +auto get_attr(tinytc_attr_t dict, tinytc_attr_t name) -> tinytc_attr_t; +auto get_attr(tinytc_attr_t dict, std::string_view name) -> tinytc_attr_t; + +template auto get_array_attr_as(tinytc_attr_t a) -> std::vector { + auto aa = dyn_cast(a); + if (!aa) { + throw status::ir_expected_array_attribute; + } + auto result = std::vector{}; + result.reserve(aa->size()); + for (auto const &va : *aa) { + auto val = dyn_cast_or_throw(va, [&] { + return status::ir_expected_integer_attribute; + })->value(); + result.emplace_back(static_cast(val)); + } + return result; +} + +} // namespace tinytc + +#endif // ATTR_NODE_20250205_HPP diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp index df570f75..8ba9c42d 100644 --- a/src/node/data_type_node.cpp +++ b/src/node/data_type_node.cpp @@ -2,47 +2,180 @@ // SPDX-License-Identifier: BSD-3-Clause #include "node/data_type_node.hpp" +#include "compiler_context.hpp" +#include "compiler_context_cache.hpp" #include "error.hpp" +#include "scalar_type.hpp" +#include "support/fnv1a_array_view.hpp" // IWYU pragma: keep #include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/fnv1a.hpp" +#include "util/math.hpp" +#include +#include #include +#include #include namespace tinytc { -memref_data_type::memref_data_type(scalar_type type, std::vector shape, - std::vector stride, location const &lc) - : element_ty_(std::move(type)), shape_(std::move(shape)), stride_(std::move(stride)) { - loc(lc); +auto boolean_data_type::get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t { + return ctx->cache()->bool_ty.get(); +} + +auto coopmatrix_data_type::get(tinytc_data_type_t component_ty, std::int64_t rows, + std::int64_t cols, matrix_use use, location const &lc) + -> tinytc_data_type_t { + const auto hash = fnv1a_combine(component_ty, rows, cols, use); + const auto is_equal = [&](tinytc_data_type_t ty) { + const auto ct = dyn_cast(ty); + return ct && component_ty == ct->ty() && rows == ct->rows() && cols == ct->cols() && + use == ct->use(); + }; + const auto make = [&]() { return new coopmatrix_data_type(component_ty, rows, cols, use, lc); }; + + auto &tys = component_ty->context()->cache()->coopmatrix_tys; + return tys.get(hash, is_equal, make); +} + +coopmatrix_data_type::coopmatrix_data_type(tinytc_data_type_t ty, std::int64_t rows0, + std::int64_t cols0, matrix_use use, location const &lc) + : data_type_node(DTK::coopmatrix, ty->context()), ty_(std::move(ty)), shape_{rows0, cols0}, + use_(use) { + if (!isa(*ty_)) { + throw compilation_error(lc, status::ir_expected_scalar); + } + if (rows() < 0 || is_dynamic_value(rows())) { + throw compilation_error(lc, status::ir_invalid_shape); + } + if (!is_positive_power_of_two(rows())) { + throw compilation_error(lc, status::ir_unsupported_coopmatrix_shape); + } + if (cols() < 0 || is_dynamic_value(cols())) { + throw compilation_error(lc, status::ir_invalid_shape); + } +} + +auto coopmatrix_data_type::component_ty() const -> scalar_type { + return dyn_cast(ty_)->ty(); +} + +auto group_data_type::get(tinytc_data_type_t memref_ty, std::int64_t size, std::int64_t offset, + location const &lc) -> tinytc_data_type_t { + const auto hash = fnv1a_combine(memref_ty, size, offset); + const auto is_equal = [&](tinytc_data_type_t ty) { + const auto gt = dyn_cast(ty); + return gt && memref_ty == gt->ty() && size == gt->size() && offset == gt->offset(); + }; + const auto make = [&]() { return new group_data_type(memref_ty, size, offset, lc); }; + + auto &tys = memref_ty->context()->cache()->group_tys; + return tys.get(hash, std::move(is_equal), std::move(make)); +} + +group_data_type::group_data_type(tinytc_data_type_t ty, std::int64_t size, std::int64_t offset, + location const &lc) + : data_type_node(DTK::group, ty->context()), ty_(std::move(ty)), size_(size), offset_(offset) { + if (!isa(*ty_)) { + throw compilation_error(lc, status::ir_expected_memref); + } + if (size < 0 && !is_dynamic_value(size)) { + throw compilation_error(lc, status::ir_invalid_shape); + } + if (offset < 0 && !is_dynamic_value(offset)) { + throw compilation_error(lc, status::ir_invalid_offset); + } +} + +memref_data_type::memref_data_type(tinytc_data_type_t element_ty, std::vector shape, + std::vector stride, address_space addrspace, + location const &lc) + : data_type_node(DTK::memref, element_ty->context()), element_ty_(element_ty), + shape_(std::move(shape)), stride_(std::move(stride)), addrspace_(addrspace) { + if (!isa(*element_ty_)) { + throw compilation_error(lc, status::ir_expected_scalar); + } + if (stride_.size() != shape_.size()) { + throw compilation_error(lc, status::ir_shape_stride_mismatch); + } for (auto const &s : shape_) { if (s < 0 && !is_dynamic_value(s)) { - throw compilation_error(loc(), status::ir_invalid_shape); + throw compilation_error(lc, status::ir_invalid_shape); } } - if (stride_.empty()) { - stride_ = canonical_stride(); - } else { - for (auto const &s : stride_) { - if (s < 0 && !is_dynamic_value(s)) { - throw compilation_error(loc(), status::ir_invalid_shape); - } + for (auto const &s : stride_) { + if (s < 0 && !is_dynamic_value(s)) { + throw compilation_error(lc, status::ir_invalid_shape); } } - if (stride_.size() != shape_.size()) { - throw compilation_error(loc(), status::ir_shape_stride_mismatch); +} + +scalar_type memref_data_type::element_ty() const { + return dyn_cast(element_ty_)->ty(); +} + +auto memref_data_type::element_alignment() const -> std::int32_t { + return ::tinytc::alignment(element_ty()); +} +auto memref_data_type::size_in_bytes() const -> std::int64_t { + if (is_dynamic()) { + return dynamic; } + std::size_t s = size(element_ty()); + if (dim() > 0) { + s *= stride_.back() * shape_.back(); + } + return s; } -auto memref_data_type::canonical_stride() const -> std::vector { - if (shape_.empty()) { +auto memref_data_type::get(tinytc_data_type_t element_ty, array_view shape, + array_view stride, address_space addrspace, + location const &lc) -> tinytc_data_type_t { + + auto stride_buffer = std::vector{}; + if (stride.empty()) { + stride_buffer = canonical_stride(shape); + stride = array_view{stride_buffer}; + } + + const auto hash = fnv1a_combine(element_ty, shape, stride, addrspace); + const auto is_equal = [&](tinytc_data_type_t ty) { + const auto mt = dyn_cast(ty); + return mt && element_ty == mt->element_data_ty() && addrspace == mt->addrspace() && + std::equal(shape.begin(), shape.end(), mt->shape().begin(), mt->shape().end()) && + std::equal(stride.begin(), stride.end(), mt->stride().begin(), mt->stride().end()); + }; + const auto make = [&]() { + if (!stride_buffer.empty()) { + return new memref_data_type(element_ty, shape, std::move(stride_buffer), addrspace, lc); + } + return new memref_data_type(element_ty, shape, stride, addrspace, lc); + }; + + auto &tys = element_ty->context()->cache()->memref_tys; + return tys.get(hash, std::move(is_equal), std::move(make)); +} + +auto memref_data_type::canonical_stride(array_view shape) + -> std::vector { + if (shape.empty()) { return {}; } - auto stride = std::vector(shape_.size(), dynamic); + auto stride = std::vector(shape.size(), dynamic); stride[0] = 1; - for (std::size_t i = 0; i < shape_.size() - 1 && !is_dynamic_value(shape_[i]); ++i) { - stride[i + 1] = stride[i] * shape_[i]; + for (std::size_t i = 0; i < shape.size() - 1 && !is_dynamic_value(shape[i]); ++i) { + stride[i + 1] = stride[i] * shape[i]; } return stride; } +auto scalar_data_type::get(tinytc_compiler_context_t ctx, scalar_type ty) -> tinytc_data_type_t { + return ctx->cache()->scalar_tys[static_cast(ty)].get(); +} + +auto void_data_type::get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t { + return ctx->cache()->void_ty.get(); +} + } // namespace tinytc diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp index 8321554c..21490815 100644 --- a/src/node/data_type_node.hpp +++ b/src/node/data_type_node.hpp @@ -4,75 +4,122 @@ #ifndef DATA_TYPE_NODE_20230309_HPP #define DATA_TYPE_NODE_20230309_HPP -#include "reference_counted.hpp" -#include "scalar_type.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" - -#include -#include -#include +#include "util/type_list.hpp" #include +#include #include -#include #include namespace tinytc { -using data_type_nodes = clir::virtual_type_list; -} +enum class DTK { bool_, coopmatrix, group, memref, scalar, void_ }; +using data_type_nodes = + type_list; +} // namespace tinytc -struct tinytc_data_type : tinytc::reference_counted, tinytc::data_type_nodes { +struct tinytc_data_type { public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + using leaves = tinytc::data_type_nodes; + + inline tinytc_data_type(tinytc::DTK tid, tinytc_compiler_context_t ctx) + : tid_(tid), ctx_(ctx) {} + virtual ~tinytc_data_type() = default; + inline auto type_id() const -> tinytc::DTK { return tid_; } + inline auto context() const -> tinytc_compiler_context_t { return ctx_; } private: - tinytc::location loc_; + tinytc::DTK tid_; + tinytc_compiler_context_t ctx_; }; namespace tinytc { using data_type_node = ::tinytc_data_type; -class group_data_type : public clir::visitable { +class boolean_data_type : public data_type_node { + public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::bool_; } + static auto get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t; + + protected: + inline boolean_data_type(tinytc_compiler_context_t ctx) : data_type_node(DTK::bool_, ctx) {} + friend class compiler_context_cache; +}; + +class coopmatrix_data_type : public data_type_node { public: - inline group_data_type(data_type ty, std::int64_t offset = 0, location const &lc = {}) - : ty_(std::move(ty)), offset_(offset) { - loc(lc); + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::coopmatrix; } + static auto get(tinytc_data_type_t ty, std::int64_t rows, std::int64_t cols, matrix_use use, + location const &lc = {}) -> tinytc_data_type_t; + + inline auto ty() const -> tinytc_data_type_t { return ty_; } + auto component_ty() const -> scalar_type; + inline auto shape() const -> std::array { return shape_; } + inline auto shape(int mode) const -> std::int64_t { return shape_[mode]; } + inline auto rows() const -> std::int64_t { return shape_[0]; } + inline auto cols() const -> std::int64_t { return shape_[1]; } + inline auto use() const -> matrix_use { return use_; } + inline auto distributed_mode() const -> int { return use_ == matrix_use::b ? 1 : 0; } + inline auto num_blocks(std::int32_t subgroup_size) const -> std::int64_t { + return 1 + (shape(distributed_mode()) - 1) / subgroup_size; } + // Number of components per work-item + inline auto length(std::int32_t subgroup_size) const -> std::int64_t { + return num_blocks(subgroup_size) * shape(1 - distributed_mode()); + } + + protected: + coopmatrix_data_type(tinytc_data_type_t ty, std::int64_t rows, std::int64_t cols, + matrix_use use, location const &lc = {}); + + private: + tinytc_data_type_t ty_; + std::array shape_; + matrix_use use_; +}; + +class group_data_type : public data_type_node { + public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::group; } + static auto get(tinytc_data_type_t ty, std::int64_t size, std::int64_t offset, + location const &lc = {}) -> tinytc_data_type_t; - inline auto ty() const -> data_type const & { return ty_; } + inline auto ty() const -> tinytc_data_type_t { return ty_; } + inline auto size() const -> std::int64_t { return size_; } inline auto offset() const -> std::int64_t { return offset_; } + protected: + group_data_type(tinytc_data_type_t memref_ty, std::int64_t size, std::int64_t offset = 0, + location const &lc = {}); + private: - data_type ty_; + tinytc_data_type_t ty_; + std::int64_t size_; std::int64_t offset_; }; -class void_data_type : public clir::visitable {}; - -class memref_data_type : public clir::visitable { +class memref_data_type : public data_type_node { public: - memref_data_type(scalar_type type, std::vector shape, - std::vector stride = {}, location const &lc = {}); - - inline scalar_type element_ty() const { return element_ty_; } - inline clir::data_type clir_element_ty() const { return to_clir_ty(element_ty_, addrspace_); } - inline clir::data_type clir_atomic_element_ty() const { - return to_clir_atomic_ty(element_ty_, addrspace_); - } + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::memref; } + static auto canonical_stride(array_view shape) -> std::vector; + static auto get(tinytc_data_type_t element_ty, array_view shape, + array_view stride, + address_space addrspace = address_space::global, location const &lc = {}) + -> tinytc_data_type_t; + + scalar_type element_ty() const; + inline tinytc_data_type_t element_data_ty() const { return element_ty_; } inline std::int64_t dim() const { return shape_.size(); } inline auto const &shape() const { return shape_; } inline std::int64_t shape(std::int64_t i) const { return shape_[i]; } inline auto const &stride() const { return stride_; } inline std::int64_t stride(std::int64_t i) const { return stride_[i]; } - inline std::int64_t size_in_bytes() const { - return is_dynamic() ? dynamic : size(element_ty_) * stride_.back() * shape_.back(); - } - inline clir::address_space addrspace() const { return addrspace_; } - inline void addrspace(clir::address_space space) { addrspace_ = space; } + inline auto addrspace() const -> address_space { return addrspace_; } + inline void addrspace(address_space space) { addrspace_ = space; } inline bool is_dynamic_shape() const { return std::any_of(shape_.begin(), shape_.end(), is_dynamic_value); @@ -81,27 +128,47 @@ class memref_data_type : public clir::visitable std::vector; + auto element_alignment() const -> std::int32_t; + auto size_in_bytes() const -> std::int64_t; - scalar_type element_ty_; + protected: + memref_data_type(tinytc_data_type_t element_ty, std::vector shape, + std::vector stride, + address_space addrspace = address_space::global, location const &lc = {}); + + tinytc_data_type_t element_ty_; std::vector shape_, stride_; - clir::address_space addrspace_ = clir::address_space::global_t; + address_space addrspace_ = address_space::global; }; -class scalar_data_type : public clir::visitable { +class scalar_data_type : public data_type_node { public: - inline scalar_data_type(scalar_type type, location const &lc) : ty_(type) { loc(lc); } + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::scalar; } + static auto get(tinytc_compiler_context_t ctx, scalar_type ty) -> tinytc_data_type_t; inline scalar_type ty() const { return ty_; } - inline clir::data_type clir_ty() const { return to_clir_ty(ty_); } + + protected: + inline scalar_data_type(tinytc_compiler_context_t ctx, scalar_type type) + : data_type_node(DTK::scalar, ctx), ty_(type) {} + friend class compiler_context_cache; private: scalar_type ty_; }; +class void_data_type : public data_type_node { + public: + inline static bool classof(data_type_node const &d) { return d.type_id() == DTK::void_; } + static auto get(tinytc_compiler_context_t ctx) -> tinytc_data_type_t; + + protected: + inline void_data_type(tinytc_compiler_context_t ctx) : data_type_node(DTK::void_, ctx) {} + friend class compiler_context_cache; +}; + } // namespace tinytc #endif // DATA_TYPE_NODE_20230309_HPP diff --git a/src/node/function_node.cpp b/src/node/function_node.cpp new file mode 100644 index 00000000..ca6607ba --- /dev/null +++ b/src/node/function_node.cpp @@ -0,0 +1,69 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/function_node.hpp" +#include "error.hpp" +#include "node/attr_node.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" + +#include +#include + +using namespace tinytc; + +tinytc_func::tinytc_func(std::string name, tinytc::array_view params, + tinytc_data_type_t ty, tinytc_location const &lc) + : name_(std::move(name)), ty_{ty}, loc_{lc} { + body_.kind(tinytc::region_kind::collective); + body_.loc(loc_); + body_.set_params(std::move(params)); +} + +void tinytc_func::param_attr(std::int32_t param_no, tinytc_attr_t a) { + if (param_no < 0 || param_no >= num_params()) { + throw compilation_error(loc(), status::invalid_arguments); + } + if (static_cast(param_attr_.size()) != num_params()) { + param_attr_.resize(num_params(), nullptr); + } + param_attr_[param_no] = a; +} +auto tinytc_func::param_attr(std::int32_t param_no) const -> tinytc_attr_t { + if (param_no < 0 || param_no >= num_params()) { + throw compilation_error(loc(), status::invalid_arguments); + } + if (param_attr_.empty()) { + return nullptr; + } + return param_attr_[param_no]; +} + +auto tinytc_func::subgroup_size() const -> std::int32_t { + if (auto sgs_attr = get_attr(attr_, "subgroup_size"); sgs_attr) { + auto sgs = dyn_cast_or_throw(sgs_attr, [&] { + return compilation_error(loc_, status::ir_expected_integer_attribute); + }); + return sgs->value(); + } + throw compilation_error(loc_, status::internal_compiler_error, "Subgroup size is missing"); +} + +auto tinytc_func::work_group_size() const -> std::array { + if (auto wgs_attr = get_attr(attr_, "work_group_size"); wgs_attr) { + auto wgs_array = dyn_cast_or_throw( + wgs_attr, [&] { return compilation_error(loc_, status::ir_expected_array_attribute); }); + if (wgs_array->size() != 2) { + throw compilation_error(loc_, status::ir_unexpected_array_attribute_size, + "Work group size attribute must have 2 entries"); + } + auto wgs = std::array{}; + for (std::size_t i = 0; i < 2; ++i) { + wgs[i] = dyn_cast_or_throw(wgs_array->value(i), [&] { + return compilation_error(loc_, status::ir_expected_integer_attribute); + })->value(); + } + return wgs; + } + throw compilation_error(loc_, status::internal_compiler_error, "Work group size is missing"); +} diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp index aa964558..1c03f025 100644 --- a/src/node/function_node.hpp +++ b/src/node/function_node.hpp @@ -4,78 +4,55 @@ #ifndef FUNCTION_NODE_20230310_HPP #define FUNCTION_NODE_20230310_HPP -#include "location.hpp" -#include "reference_counted.hpp" +#include "node/region_node.hpp" #include "tinytc/tinytc.hpp" - -#include +#include "tinytc/types.h" #include #include #include -#include +#include #include -namespace tinytc { -using function_nodes = clir::virtual_type_list; -} - -struct tinytc_func : tinytc::reference_counted, tinytc::function_nodes { +struct tinytc_func final { public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + tinytc_func(std::string name, tinytc::array_view params, + tinytc_data_type_t ty, tinytc_location const &lc = {}); - virtual auto name() const -> std::string_view = 0; + inline auto loc() const noexcept -> tinytc_location const & { return loc_; } + inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } - private: - tinytc::location loc_; -}; + inline auto ty() const noexcept -> tinytc_data_type_t { return ty_; } -namespace tinytc { + inline auto params() { return body_.params(); } + inline auto params() const { return body_.params(); } + inline auto num_params() const noexcept { return body_.num_params(); } -using function_node = ::tinytc_func; + inline auto name() const -> std::string_view { return name_; } + inline auto body() -> tinytc_region & { return body_; } + inline auto body() const -> tinytc_region const & { return body_; } -class prototype : public clir::visitable { - public: - inline prototype(std::string name, std::vector args = {}, location const &lc = {}) - : name_(std::move(name)), args_(std::move(args)) { - loc(lc); - } + inline void attr(tinytc_attr_t a) { attr_ = a; } + inline auto attr() const -> tinytc_attr_t { return attr_; } + + void param_attr(std::int32_t param_no, tinytc_attr_t a); + auto param_attr(std::int32_t param_no) const -> tinytc_attr_t; - inline auto name() const -> std::string_view override { return name_; } - inline auto args() const -> std::vector const & { return args_; } + auto subgroup_size() const -> std::int32_t; + auto work_group_size() const -> std::array; private: std::string name_; - std::vector args_; + tinytc_data_type_t ty_; + tinytc_region body_; + tinytc_location loc_; + tinytc_attr_t attr_ = nullptr; + std::vector param_attr_; }; -class function : public clir::visitable { - public: - inline function(func prototype, region body, location const &lc = {}) - : prototype_(std::move(prototype)), body_(std::move(body)), work_group_size_{0, 0}, - subgroup_size_{0} { - loc(lc); - } - - inline auto name() const -> std::string_view override { return prototype_->name(); } - - inline auto prototype() const -> func const & { return prototype_; } - inline auto body() const -> region const & { return body_; } - inline auto work_group_size() const -> std::array { return work_group_size_; } - - inline void work_group_size(std::array const &work_group_size) { - work_group_size_ = work_group_size; - } - inline auto subgroup_size() const -> std::int32_t { return subgroup_size_; } - inline void subgroup_size(std::int32_t subgroup_size) { subgroup_size_ = subgroup_size; } +namespace tinytc { - private: - func prototype_; - region body_; - std::array work_group_size_; - std::int32_t subgroup_size_; -}; +using function_node = ::tinytc_func; } // namespace tinytc diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp index 0fa368ea..4965f11b 100644 --- a/src/node/inst_node.cpp +++ b/src/node/inst_node.cpp @@ -4,82 +4,256 @@ #include "node/inst_node.hpp" #include "error.hpp" #include "node/data_type_node.hpp" +#include "node/region_node.hpp" #include "node/value_node.hpp" #include "scalar_type.hpp" #include "tinytc/types.hpp" -#include "util.hpp" - -#include -#include +#include "util/casting.hpp" +#include "util/visit.hpp" +#include #include #include +#include #include +#include + +auto tinytc_inst::context() const -> tinytc_compiler_context_t { + if (num_results() > 0) { + return result(0).context(); + } else if (num_operands() > 0) { + return op(0).context(); + } + return nullptr; +} + +void tinytc_inst::subs(tinytc_value_t old_value, tinytc_value_t new_value, bool recursive) { + for (auto op = op_begin_; op != op_end_; ++op) { + if (op->get() == old_value) { + op->set(new_value); + } + } + if (recursive) { + for (auto ® : child_regions()) { + for (auto &in : reg) { + in.subs(old_value, new_value, true); + } + } + } +} + +auto tinytc_inst::kind() const -> tinytc::inst_execution_kind { + switch (type_id()) { + case tinytc::IK::alloca: + case tinytc::IK::lifetime_stop: + case tinytc::IK::foreach_loop: + case tinytc::IK::parallel: + case tinytc::IK::blas_a2: + case tinytc::IK::axpby_blas_a2: + case tinytc::IK::cumsum_blas_a2: + case tinytc::IK::sum_blas_a2: + case tinytc::IK::last_blas_a2: + case tinytc::IK::blas_a3: + case tinytc::IK::gemm_blas_a3: + case tinytc::IK::gemv_blas_a3: + case tinytc::IK::ger_blas_a3: + case tinytc::IK::hadamard_blas_a3: + case tinytc::IK::last_blas_a3: + return tinytc::inst_execution_kind::collective; + case tinytc::IK::arith: + case tinytc::IK::arith_unary: + case tinytc::IK::barrier: + case tinytc::IK::cast: + case tinytc::IK::compare: + case tinytc::IK::constant: + case tinytc::IK::expand: + case tinytc::IK::fuse: + case tinytc::IK::if_: + case tinytc::IK::load: + case tinytc::IK::math_unary_: + case tinytc::IK::size: + case tinytc::IK::store: + case tinytc::IK::subview: + case tinytc::IK::yield: + case tinytc::IK::loop: + case tinytc::IK::for_loop: + case tinytc::IK::last_loop: + return tinytc::inst_execution_kind::mixed; + case tinytc::IK::cooperative_matrix_apply: + case tinytc::IK::cooperative_matrix_extract: + case tinytc::IK::cooperative_matrix_insert: + case tinytc::IK::cooperative_matrix_load: + case tinytc::IK::cooperative_matrix_mul_add: + case tinytc::IK::cooperative_matrix_prefetch: + case tinytc::IK::cooperative_matrix_reduce: + case tinytc::IK::cooperative_matrix_scale: + case tinytc::IK::cooperative_matrix_store: + case tinytc::IK::subgroup_broadcast: + case tinytc::IK::subgroup_operation: + return tinytc::inst_execution_kind::spmd; + case tinytc::IK::builtin: + return tinytc::dyn_cast(this)->kind(); + }; + throw tinytc::internal_compiler_error(); +} namespace tinytc { -scalar_data_type *get_scalar_type(location const &loc, value &v) { - auto m = dynamic_cast(v->ty().get()); +coopmatrix_data_type *get_coopmatrix_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); if (m == nullptr) { - throw compilation_error(loc, status::ir_expected_scalar); + throw compilation_error(loc, {&v}, status::ir_expected_coopmatrix); } return m; } -memref_data_type *get_memref_type(location const &loc, value &v) { - auto m = dynamic_cast(v->ty().get()); +scalar_data_type *get_scalar_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); if (m == nullptr) { - throw compilation_error(loc, status::ir_expected_memref); + throw compilation_error(loc, {&v}, status::ir_expected_scalar); } return m; } -blas_a2_inst::blas_a2_inst(value alpha, value A, value beta, value B, bool atomic) - : alpha_(std::move(alpha)), A_(std::move(A)), beta_(std::move(beta)), B_(std::move(B)), - atomic_(atomic) {} +memref_data_type *get_memref_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); + if (m == nullptr) { + throw compilation_error(loc, {&v}, status::ir_expected_memref); + } + return m; +} -blas_a3_inst::blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic) - : alpha_(std::move(alpha)), A_(std::move(A)), B_(std::move(B)), beta_(std::move(beta)), - C_(std::move(C)), atomic_(atomic) {} +void check_index_ty(location const &loc, tinytc_value const &v) { + if (auto sty = dyn_cast(v.ty()); !sty || sty->ty() != scalar_type::index) { + throw compilation_error(loc, {&v}, status::ir_expected_index); + } +} -loop_inst::loop_inst(value loop_var, value from, value to, region body, location const &lc) - : loop_inst(std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), lc) {} +void check_memref_shape(memref_data_type *rt, std::int64_t ri, memref_data_type *ot, + std::int64_t oi, location const &loc) { + if (rt->shape(ri) != ot->shape(oi)) { + auto extra_info = std::ostringstream{} << "Size of mode " << ri + << " does not match operand mode " << oi << " [" + << rt->shape(ri) << "!=" << ot->shape(oi) << "]"; + throw compilation_error(loc, status::ir_invalid_shape, std::move(extra_info).str()); + } +} +void check_memref_stride(memref_data_type *rt, std::int64_t ri, memref_data_type *ot, + std::int64_t oi, location const &loc) { + if (!is_dynamic_value(rt->stride(ri)) && rt->stride(ri) != ot->stride(oi)) { + auto extra_info = std::ostringstream{} << "Stride of mode " << ri + << " does not match operand stride " << oi << " [" + << rt->stride(ri) << "!=" << ot->stride(oi) << "]"; + throw compilation_error(loc, status::ir_invalid_stride, std::move(extra_info).str()); + } +} -loop_inst::loop_inst(value loop_var, value from, value to, value step, region body, - location const &lc) - : loop_var_(std::move(loop_var)), from_(std::move(from)), to_(std::move(to)), - step_(std::move(step)), body_(std::move(body)) { +void check_memref_mode(memref_data_type *rt, std::int64_t ri, memref_data_type *ot, std::int64_t oi, + location const &loc) { + check_memref_shape(rt, ri, ot, oi, loc); + check_memref_stride(rt, ri, ot, oi, loc); +} + +auto get_and_check_memref_type_addrspace(tinytc_value const &operand, tinytc_data_type_t ty, + location const &loc) + -> std::pair { + auto rt = dyn_cast(ty); + if (!rt) { + throw compilation_error(loc, status::ir_expected_memref); + } + auto ot = get_memref_type(loc, operand); + if (rt->element_data_ty() != ot->element_data_ty()) { + throw compilation_error(loc, {&operand}, status::ir_scalar_mismatch); + } + if (rt->addrspace() != ot->addrspace()) { + throw compilation_error(loc, {&operand}, status::ir_address_space_mismatch); + } + return {ot, rt}; +} + +blas_a2_inst::blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic, location const &lc) + : standard_inst{tid}, atomic_(atomic) { + op(op_alpha, alpha); + op(op_A, A); + op(op_beta, beta); + op(op_B, B); loc(lc); - auto lvt = get_scalar_type(loc(), loop_var_); - auto fromt = get_scalar_type(loc(), from_); - auto tot = get_scalar_type(loc(), to_); - bool step_ok = true; - if (step_) { - auto stept = get_scalar_type(loc(), step_); - step_ok = lvt->ty() == stept->ty(); + + auto At = get_memref_type(loc(), op(op_A)); + auto Bt = get_memref_type(loc(), op(op_B)); + auto alphat = get_scalar_type(loc(), op(op_alpha)); + auto betat = get_scalar_type(loc(), op(op_beta)); + + if (!promotable(alphat->ty(), At->element_ty())) { + throw compilation_error(loc(), {&op(op_alpha), &op(op_A)}, status::ir_forbidden_promotion); } + if (!promotable(At->element_ty(), Bt->element_ty())) { + throw compilation_error(loc(), {&op(op_A), &op(op_B)}, status::ir_forbidden_promotion); + } + if (!promotable(betat->ty(), Bt->element_ty())) { + throw compilation_error(loc(), {&op(op_beta), &op(op_B)}, status::ir_forbidden_promotion); + } +} - if (lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || !step_ok) { - throw compilation_error(loc(), status::ir_scalar_mismatch); +blas_a3_inst::blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic, location const &lc) + : standard_inst{tid}, atomic_(atomic) { + op(op_alpha, alpha); + op(op_A, A); + op(op_B, B); + op(op_beta, beta); + op(op_C, C); + loc(lc); + + auto At = get_memref_type(loc(), op(op_A)); + auto Bt = get_memref_type(loc(), op(op_B)); + auto Ct = get_memref_type(loc(), op(op_C)); + auto alphat = get_scalar_type(loc(), op(op_alpha)); + auto betat = get_scalar_type(loc(), op(op_beta)); + + const auto AB_ty = promote(At->element_ty(), Bt->element_ty()); + if (!AB_ty) { + throw compilation_error(loc(), {&op(op_A), &op(op_B)}, status::ir_forbidden_promotion); + } + if (!promotable(alphat->ty(), *AB_ty)) { + throw compilation_error(loc(), {&op(op_alpha), &op(op_A), &op(op_B)}, + status::ir_forbidden_promotion); + } + if (!promotable(*AB_ty, Ct->element_ty())) { + throw compilation_error(loc(), {&op(op_A), &op(op_B), &op(op_C)}, + status::ir_forbidden_promotion); + } + if (!promotable(betat->ty(), Ct->element_ty())) { + throw compilation_error(loc(), {&op(op_beta), &op(op_C)}, status::ir_forbidden_promotion); } } -alloca_inst::alloca_inst(data_type ty, location const &lc) - : result_{make_value(std::move(ty))}, stack_ptr_{-1} { +alloca_inst::alloca_inst(tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::alloca}, stack_ptr_{-1} { loc(lc); - auto memref = dynamic_cast(result_->ty().get()); + + result(0) = value_node{ty, this, lc}; + auto memref = dyn_cast(result(0).ty()); if (memref == nullptr) { throw compilation_error(loc(), status::ir_expected_memref); } - memref->addrspace(clir::address_space::local_t); + if (memref->addrspace() != address_space::local) { + throw compilation_error(loc(), status::ir_expected_local_address_space); + } } -axpby_inst::axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(beta), std::move(B), atomic), tA_(tA) { - loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); +axpby_inst::axpby_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t beta0, + tinytc_value_t B0, bool atomic, location const &lc) + : blas_a2_inst(IK::axpby_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), + std::move(B0), atomic, lc), + tA_(tA) { + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + + if (b->dim() < 0 || b->dim() > 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_0_1_or_2); + } bool shape_equal = false; if (tA_ == transpose::T && a->dim() == 2 && b->dim() == 2) { @@ -89,445 +263,1217 @@ axpby_inst::axpby_inst(transpose tA, value alpha, value A, value beta, value B, } if (!shape_equal) { - throw compilation_error(loc(), status::ir_incompatible_shapes); + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); } +} + +arith_inst::arith_inst(arithmetic operation, tinytc_value_t a0, tinytc_value_t b0, + tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::arith}, operation_(operation) { + op(op_a, a0); + op(op_b, b0); + loc(lc); - if (b->dim() > 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix); + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } + if (b().ty() != ty) { + throw compilation_error(loc(), {&b()}, status::ir_operand_type_must_match_return_type); } + + if (isa(*ty)) { + auto const inst_supports_bool = [&] { + switch (operation) { + case arithmetic::and_: + case arithmetic::or_: + case arithmetic::xor_: + return true; + default: + return false; + } + }(); + if (!inst_supports_bool) { + throw compilation_error(loc(), status::ir_boolean_unsupported); + } + } else { + auto const check_scalar_ty = [&](scalar_type sty) { + bool inst_supports_fp = true; + bool inst_supports_complex = true; + switch (operation) { + case arithmetic::add: + case arithmetic::sub: + case arithmetic::mul: + case arithmetic::div: + break; + case arithmetic::min: + case arithmetic::max: + case arithmetic::rem: + inst_supports_complex = false; + break; + case arithmetic::and_: + case arithmetic::or_: + case arithmetic::xor_: + inst_supports_fp = false; + inst_supports_complex = false; + break; + case arithmetic::shl: + case arithmetic::shr: + inst_supports_fp = false; + inst_supports_complex = false; + break; + } + if (!inst_supports_fp && is_floating_type(sty)) { + throw compilation_error(loc(), status::ir_fp_unsupported); + } + if (!inst_supports_complex && is_complex_type(sty)) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } + }; + + if (auto ct = dyn_cast(ty); ct) { + check_scalar_ty(ct->component_ty()); + } else { + check_scalar_ty(get_scalar_type(loc(), ty)->ty()); + } + } + + result(0) = value_node{ty, this, lc}; } -arith_inst::arith_inst(arithmetic op, value a, value b, location const &lc) - : op_(op), a_(std::move(a)), b_(std::move(b)) { +arith_unary_inst::arith_unary_inst(arithmetic_unary operation, tinytc_value_t a0, + tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::arith_unary}, operation_(operation) { + op(op_a, a0); loc(lc); - auto at = get_scalar_type(loc(), a_); - auto bt = get_scalar_type(loc(), b_); + result(0) = value_node{ty, this, lc}; - if (at->ty() != bt->ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - } - bool inst_supports_fp = false; - switch (op) { - case arithmetic::add: - case arithmetic::sub: - case arithmetic::mul: - case arithmetic::div: - case arithmetic::rem: - inst_supports_fp = true; + if (isa(*ty)) { + if (operation_ != arithmetic_unary::not_) { + throw compilation_error(loc(), status::ir_boolean_unsupported); + } + } else { + auto const check_scalar_ty = [&](scalar_type a_ty, scalar_type r_ty) { + // Check if inst is supported for combination of a type and result type + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::im: + case arithmetic_unary::re: { + if (r_ty != component_type(a_ty)) { + throw compilation_error(loc(), {&a()}, + status::ir_operand_type_must_match_return_type); + } + break; + } + default: + if (a_ty != r_ty) { + throw compilation_error(loc(), {&a()}, + status::ir_operand_type_must_match_return_type); + } + break; + } + + bool inst_supports_int = true; + bool inst_supports_fp = true; + bool inst_supports_complex = true; + switch (operation_) { + case arithmetic_unary::abs: + case arithmetic_unary::neg: + break; + case arithmetic_unary::not_: + inst_supports_fp = false; + inst_supports_complex = false; + break; + case arithmetic_unary::conj: + case arithmetic_unary::im: + case arithmetic_unary::re: + inst_supports_int = false; + inst_supports_fp = false; + break; + } + if (!inst_supports_int && is_integer_type(a_ty)) { + throw compilation_error(loc(), {&a()}, status::ir_int_unsupported); + } + if (!inst_supports_fp && is_floating_type(a_ty)) { + throw compilation_error(loc(), {&a()}, status::ir_fp_unsupported); + } + if (!inst_supports_complex && is_complex_type(a_ty)) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } + }; + + auto ct = dyn_cast(a().ty()); + auto rt = dyn_cast(ty); + if (ct && rt) { + check_scalar_ty(ct->component_ty(), rt->component_ty()); + } else { + check_scalar_ty(get_scalar_type(loc(), a())->ty(), + get_scalar_type(loc(), result(0))->ty()); + } + } +} + +builtin_inst::builtin_inst(builtin btype, tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::builtin}, btype_{btype} { + loc(lc); + + auto rt = dyn_cast(ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_scalar); + } + + switch (builtin_type()) { + case builtin::group_id_x: + case builtin::group_id_y: + case builtin::group_id_z: + case builtin::num_groups_x: + case builtin::num_groups_y: + case builtin::num_groups_z: + if (rt->ty() != scalar_type::index) { + throw compilation_error(loc(), status::ir_expected_index); + } break; - case arithmetic::shl: - case arithmetic::shr: - case arithmetic::and_: - case arithmetic::or_: - case arithmetic::xor_: - inst_supports_fp = false; + case builtin::num_subgroups_x: + case builtin::num_subgroups_y: + case builtin::subgroup_size: + case builtin::subgroup_id_x: + case builtin::subgroup_id_y: + case builtin::subgroup_linear_id: + case builtin::subgroup_local_id: + if (rt->ty() != scalar_type::i32) { + throw compilation_error(loc(), status::ir_expected_i32); + } break; } - if (!inst_supports_fp && is_floating_type(at->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); + + result(0) = value_node{ty, this, lc}; +} + +auto builtin_inst::kind() const -> tinytc::inst_execution_kind { + switch (builtin_type()) { + case builtin::group_id_x: + case builtin::group_id_y: + case builtin::group_id_z: + case builtin::num_groups_x: + case builtin::num_groups_y: + case builtin::num_groups_z: + case builtin::num_subgroups_x: + case builtin::num_subgroups_y: + case builtin::subgroup_size: + return tinytc::inst_execution_kind::mixed; + case builtin::subgroup_id_x: + case builtin::subgroup_id_y: + case builtin::subgroup_linear_id: + case builtin::subgroup_local_id: + return tinytc::inst_execution_kind::spmd; + } + return tinytc::inst_execution_kind::spmd; +} + +cast_inst::cast_inst(tinytc_value_t a0, tinytc_data_type_t to_ty, location const &lc) + : standard_inst{IK::cast} { + op(op_a, a0); + loc(lc); + + if (auto rt = dyn_cast(to_ty); rt) { + auto ct = dyn_cast(a().ty()); + if (!ct) { + throw compilation_error(loc(), {&a()}, status::ir_expected_coopmatrix); + } + if (ct->rows() != rt->rows() || ct->cols() != rt->cols()) { + throw compilation_error(lc, {&a()}, status::ir_forbidden_cast); + } + const bool use_matches = ct->use() == rt->use(); + const bool use_conversion_allowed = + ct->use() == matrix_use::acc && + (rt->use() == matrix_use::a || rt->use() == matrix_use::b); + if (!use_matches && !use_conversion_allowed) { + throw compilation_error(lc, {&a()}, status::ir_forbidden_cast); + } + if (!is_cast_allowed(ct->component_ty(), rt->component_ty())) { + throw compilation_error(loc(), {&a()}, status::ir_forbidden_cast); + } + } else { + auto to_ty_scalar = dyn_cast(to_ty); + if (to_ty_scalar == nullptr) { + throw compilation_error(lc, status::ir_expected_scalar); + } + + auto at = get_scalar_type(loc(), a()); + if (!is_cast_allowed(at->ty(), to_ty_scalar->ty())) { + throw compilation_error(loc(), {&a()}, status::ir_forbidden_cast); + } } - result_ = make_value(at->ty()); + + result(0) = value_node{to_ty, this, loc()}; } -arith_unary_inst::arith_unary_inst(arithmetic_unary op, value a, location const &lc) - : op_(op), a_(std::move(a)) { +compare_inst::compare_inst(cmp_condition cond, tinytc_value_t a0, tinytc_value_t b0, + tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::compare}, cond_(cond) { + op(op_a, a0); + op(op_b, b0); loc(lc); - auto at = get_scalar_type(loc(), a_); - bool inst_supports_fp = false; - switch (op) { - case arithmetic_unary::neg: - inst_supports_fp = true; + if (!isa(*ty)) { + throw compilation_error(loc(), status::ir_expected_boolean); + } + + auto at = get_scalar_type(loc(), a()); + auto bt = get_scalar_type(loc(), b()); + + if (at->ty() != bt->ty()) { + throw compilation_error(loc(), {&a(), &b()}, status::ir_scalar_mismatch); + } + + bool inst_supports_complex = true; + switch (cond_) { + case cmp_condition::eq: + case cmp_condition::ne: break; - case arithmetic_unary::not_: - inst_supports_fp = false; + case cmp_condition::gt: + case cmp_condition::ge: + case cmp_condition::lt: + case cmp_condition::le: + inst_supports_complex = false; break; } - if (!inst_supports_fp && is_floating_type(at->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); + if (!inst_supports_complex && is_complex_type(at->ty())) { + throw compilation_error(loc(), {&a(), &b()}, status::ir_complex_unsupported); } - result_ = make_value(at->ty()); + + result(0) = value_node{ty, this, lc}; } -cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) - : a_(std::move(a)), result_{make_value(to_ty)} { +constant_inst::constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::constant}, value_(value) { loc(lc); + + const auto type_ok = [](value_type const &val, scalar_type ty) { + return (is_integer_type(ty) && std::holds_alternative(val)) || + (is_floating_type(ty) && std::holds_alternative(val)) || + (is_complex_type(ty) && std::holds_alternative>(val)); + }; + + if (auto bt = dyn_cast(ty); bt) { + if (!std::holds_alternative(value_)) { + throw compilation_error(loc(), status::ir_constant_mismatch); + } + } else if (auto st = dyn_cast(ty); st) { + if (!type_ok(value_, st->ty())) { + throw compilation_error(loc(), status::ir_constant_mismatch); + } + } else if (auto ct = dyn_cast(ty); ct) { + if (!type_ok(value_, ct->component_ty())) { + throw compilation_error(loc(), status::ir_constant_mismatch); + } + } else { + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); + } + + result(0) = value_node{ty, this, lc}; +} + +auto constant_inst::is_zero() const -> bool { + return std::visit([](auto const &v) { return v == decltype(v){0}; }, value_); +} +auto constant_inst::is_identity() const -> bool { + return std::visit([](auto const &v) { return v == decltype(v){1}; }, value_); } -compare_inst::compare_inst(cmp_condition cond, value a, value b, location const &lc) - : cond_(cond), a_(std::move(a)), b_(std::move(b)), result_{make_value(scalar_type::i1)} { +cooperative_matrix_apply_inst::cooperative_matrix_apply_inst(tinytc_value_t a0, + tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_apply} { + op(0, a0); loc(lc); - auto at = get_scalar_type(loc(), a_); - auto bt = get_scalar_type(loc(), b_); + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } + + auto at = get_coopmatrix_type(loc(), a()); - if (at->ty() != bt->ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + auto i32_ty = scalar_data_type::get(at->context(), scalar_type::i32); + + body().loc(lc); + body().defining_inst(this); + body().set_num_params(3); + body().set_param(0, i32_ty); + body().set_param(1, i32_ty); + body().set_param(2, at->ty()); + + result(0) = value_node{ty, this, lc}; +} + +cooperative_matrix_extract_inst::cooperative_matrix_extract_inst(tinytc_value_t mat0, + std::int64_t index, + tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_extract}, index_{index} { + op(0, mat0); + loc(lc); + + auto matt = get_coopmatrix_type(loc(), mat()); + if (matt->ty() != ty) { + throw compilation_error(loc(), {&mat()}, status::ir_scalar_mismatch); } + + result(0) = value_node{ty, this, lc}; } -gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, - bool atomic, location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic), - tA_(tA), tB_(tB) { +cooperative_matrix_insert_inst::cooperative_matrix_insert_inst(tinytc_value_t val0, + tinytc_value_t mat0, + std::int64_t index, + tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_insert}, index_{index} { + op(0, val0); + op(1, mat0); loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - if (a->dim() != 2 || b->dim() != 2 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "gemm only supported for memref of order 2 (matrices)"); + if (mat().ty() != ty) { + throw compilation_error(loc(), {&mat()}, status::ir_operand_type_must_match_return_type); } - auto ak = tA_ == transpose::T ? 0 : 1; - auto bk = tB_ == transpose::T ? 1 : 0; - auto M = c->shape(0); - auto N = c->shape(1); - auto K = a->shape(ak); - if (a->shape(1 - ak) != M || b->shape(bk) != K || b->shape(1 - bk) != N) { - std::ostringstream oss; - oss << "Got "; - oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; - oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; - oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + auto valt = get_scalar_type(loc(), val()); + auto matt = get_coopmatrix_type(loc(), mat()); + if (matt->ty() != valt) { + throw compilation_error(loc(), {&val(), &mat()}, status::ir_scalar_mismatch); } + + result(0) = value_node{ty, this, lc}; } -gemv_inst::gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic), - tA_(tA) { +cooperative_matrix_load_inst::cooperative_matrix_load_inst(transpose t, checked_flag flag, + tinytc_value_t op0, tinytc_value_t p0, + tinytc_value_t p1, + tinytc_data_type_t to_ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_load}, t_(t), flag_(flag) { + op(op_operand, op0); + op(op_pos0, p0); + op(op_pos1, p1); loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - if (a->dim() != 2 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "gemv only supports matrix-vector products"); + auto rt = dyn_cast(to_ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_coopmatrix); } - auto ak = tA_ == transpose::T ? 0 : 1; - auto M = c->shape(0); - auto K = a->shape(ak); - if (a->shape(1 - ak) != M || b->shape(0) != K) { - std::ostringstream oss; - oss << "Got "; - oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "c=" << c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + auto ot = get_memref_type(loc(), operand()); + if (ot->element_ty() != rt->component_ty()) { + throw compilation_error(loc(), {&operand()}, status::ir_scalar_mismatch); + } + if (ot->dim() != 2) { + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); } + + check_index_ty(lc, pos0()); + check_index_ty(lc, pos1()); + + result(0) = value_node{to_ty, this, lc}; } -ger_inst::ger_inst(value alpha, value A, value B, value beta, value C, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic) { +cooperative_matrix_mul_add_inst::cooperative_matrix_mul_add_inst(tinytc_value_t a0, + tinytc_value_t b0, + tinytc_value_t c0, + tinytc_data_type_t to_ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_mul_add} { + op(op_a, a0); + op(op_b, b0); + op(op_c, c0); loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "ger requires two vectors as input and one matrix as output"); + auto rt = dyn_cast(to_ty); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_memref); + } + if (rt->use() != matrix_use::acc) { + throw compilation_error(loc(), status::ir_invalid_matrix_use); } - auto M = c->shape(0); - auto N = c->shape(1); - if (a->shape(0) != M || b->shape(0) != N) { + auto at = get_coopmatrix_type(loc(), a()); + auto bt = get_coopmatrix_type(loc(), b()); + auto ct = get_coopmatrix_type(loc(), c()); + if (at->use() != matrix_use::a) { + throw compilation_error(loc(), {&a()}, status::ir_invalid_matrix_use); + } + if (bt->use() != matrix_use::b) { + throw compilation_error(loc(), {&b()}, status::ir_invalid_matrix_use); + } + if (ct->use() != matrix_use::acc) { + throw compilation_error(loc(), {&c()}, status::ir_invalid_matrix_use); + } + + auto M = rt->rows(); + auto N = rt->cols(); + auto K = at->cols(); + if (rt->rows() != M || rt->cols() != N || ct->rows() != M || ct->cols() != N || + at->rows() != M || bt->rows() != K || bt->cols() != N) { std::ostringstream oss; oss << "Got "; - oss << "a=" << a->shape(0) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + oss << "A=" << at->rows() << "x" << at->cols() << ", "; + oss << "B=" << bt->rows() << "x" << bt->cols() << ", "; + oss << "C=" << ct->rows() << "x" << ct->cols() << ", "; + oss << "result=" << rt->rows() << "x" << rt->cols(); + throw compilation_error(loc(), {&a(), &b(), &c()}, status::ir_incompatible_shapes, + oss.str()); + } + + const auto AB_ty = promote(at->component_ty(), bt->component_ty()); + if (!AB_ty) { + throw compilation_error(loc(), {&a(), &b()}, status::ir_forbidden_promotion); + } + if (!promotable(*AB_ty, ct->component_ty())) { + throw compilation_error(loc(), {&a(), &b(), &c()}, status::ir_forbidden_promotion); + } + if (!is_cast_allowed(ct->component_ty(), rt->component_ty())) { + throw compilation_error(loc(), {&c()}, status::ir_forbidden_cast); + } + + result(0) = value_node{to_ty, this, lc}; +} + +auto cooperative_matrix_mul_add_inst::is_c_zero() const -> bool { + if (auto c_def = c().defining_inst(); c_def) { + if (auto *c_def_const = dyn_cast(c_def); c_def_const) { + return std::visit( + overloaded{[](bool) { return false; }, [](auto v) { return v == decltype(v){0}; }}, + c_def_const->value()); + } } + return false; } -hadamard_inst::hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic) { +cooperative_matrix_prefetch_inst::cooperative_matrix_prefetch_inst( + std::int32_t cache_level, tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, + std::int32_t rows, std::int32_t cols, location const &lc) + : standard_inst{IK::cooperative_matrix_prefetch}, cache_level_{cache_level}, rows_{rows}, + cols_{cols} { + op(op_operand, op0); + op(op_pos0, p0); + op(op_pos1, p1); loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "hadamard requires two vectors as input and one vector as output"); + auto ot = get_memref_type(loc(), operand()); + if (ot->dim() != 2) { + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); + } + if (rows_ <= 0 || cols_ <= 0) { + throw compilation_error(loc(), {}, status::ir_invalid_shape); } - auto M = c->shape(0); - if (a->shape(0) != M || b->shape(0) != M) { - std::ostringstream oss; - oss << "Got "; - oss << "a=" << a->shape(0) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "c=" << c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); + check_index_ty(lc, pos0()); + check_index_ty(lc, pos1()); +} + +cooperative_matrix_reduce_inst::cooperative_matrix_reduce_inst(group_arithmetic arith, + reduce_mode mode, tinytc_value_t a0, + tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_reduce}, arith_{arith}, mode_{mode} { + op(0, a0); + loc(lc); + + auto at = get_coopmatrix_type(loc(), a()); + auto rt = get_coopmatrix_type(loc(), ty); + if (at->ty() != rt->ty()) { + throw compilation_error(loc(), {&a()}, status::ir_scalar_mismatch); } + if (at->use() != rt->use()) { + throw compilation_error(loc(), {&a()}, status::ir_invalid_matrix_use); + } + const int m = mode_ == reduce_mode::column ? 0 : 1; + if (rt->shape(1 - m) != at->shape(1 - m) || rt->shape(m) != 1) { + throw compilation_error(loc(), {&a()}, status::ir_invalid_shape); + } + + result(0) = value_node{ty, this, lc}; } -expand_inst::expand_inst(value op, std::int64_t mode, std::vector expand_shape, +cooperative_matrix_scale_inst::cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, + tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::cooperative_matrix_scale} { + op(op_a, a0); + op(op_b, b0); + loc(lc); + + if (b().ty() != ty) { + throw compilation_error(loc(), {&b()}, status::ir_operand_type_must_match_return_type); + } + + auto at = get_scalar_type(loc(), a()); + auto bt = get_coopmatrix_type(loc(), b()); + + if (at->ty() != bt->component_ty()) { + throw compilation_error(loc(), {&a(), &b()}, status::ir_scalar_mismatch); + } + + result(0) = value_node{ty, this, lc}; +} + +cooperative_matrix_store_inst::cooperative_matrix_store_inst(checked_flag cflag, store_flag sflag, + tinytc_value_t val0, + tinytc_value_t op0, tinytc_value_t p0, + tinytc_value_t p1, location const &lc) + : standard_inst{IK::cooperative_matrix_store}, cflag_(cflag), sflag_(sflag) { + op(op_val, val0); + op(op_operand, op0); + op(op_pos0, p0); + op(op_pos1, p1); + loc(lc); + + auto vt = get_coopmatrix_type(loc(), val()); + auto ot = get_memref_type(loc(), operand()); + if (vt->component_ty() != ot->element_ty()) { + throw compilation_error(loc(), {&val(), &operand()}, status::ir_scalar_mismatch); + } + if (ot->dim() != 2) { + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); + } + + check_index_ty(lc, pos0()); + check_index_ty(lc, pos1()); +} + +cumsum_inst::cumsum_inst(tinytc_value_t alpha0, tinytc_value_t A0, std::int64_t mode, + tinytc_value_t beta0, tinytc_value_t B0, bool atomic, location const &lc) + : blas_a2_inst(IK::cumsum_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), + std::move(B0), atomic, lc), + mode_(mode) { + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + + if (a->dim() < 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_non_scalar_memref); + } + if (mode_ >= a->dim()) { + throw compilation_error(loc(), {&A()}, status::ir_out_of_bounds); + } + + bool shape_equal = a->dim() == b->dim(); + if (shape_equal) { + for (std::int64_t i = 0; i < a->dim(); ++i) { + shape_equal = shape_equal && a->shape(i) == b->shape(i); + } + } + + if (!shape_equal) { + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); + } +} + +expand_inst::expand_inst(tinytc_value_t op0, std::int64_t expanded_mode, + array_view static_expand_shape0, + array_view expand_shape0, tinytc_data_type_t ty, location const &lc) - : op_(std::move(op)), mode_(mode), expand_shape_(std::move(expand_shape)) { + : standard_inst{IK::expand, static_cast(1 + expand_shape0.size())}, + expanded_mode_(expanded_mode), static_expand_shape_(std::move(static_expand_shape0)) { + op(0, op0); + for (std::size_t i = 0; i < expand_shape0.size(); ++i) { + check_index_ty(loc(), *expand_shape0[i]); + op(1 + i, expand_shape0[i]); + } loc(lc); - auto m = get_memref_type(loc(), op_); - bool const range_ok = 0 <= mode_ && mode_ < m->dim(); + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); + + bool const range_ok = 0 <= expanded_mode_ && expanded_mode_ < ot->dim(); if (!range_ok) { - throw compilation_error(loc(), status::ir_out_of_bounds); + throw compilation_error(loc(), {&operand()}, status::ir_out_of_bounds); } - if (expand_shape_.size() < 2) { + if (static_expand_shape_.size() < 2) { throw compilation_error(loc(), status::ir_expand_shape_order_too_small); } + if (std::count(static_expand_shape_.begin(), static_expand_shape_.end(), dynamic) != + num_operands() - 1) { + throw compilation_error(loc(), status::ir_expand_shape_mismatch); + } - auto known_expand_shape = std::vector(); - known_expand_shape.reserve(expand_shape_.size()); - std::size_t dyn_count = 0, non_imm_count = 0; - for (auto &s : expand_shape_) { - visit(overloaded{[&](int_imm &i) { - if (is_dynamic_value(i.value())) { - known_expand_shape.push_back(dynamic); - ++dyn_count; - return; - } - if (i.value() < 0) { - throw compilation_error(loc(), status::ir_invalid_shape); - } - known_expand_shape.push_back(i.value()); - }, - [&](auto &) { - known_expand_shape.push_back(dynamic); - ++non_imm_count; - }}, - *s); + for (std::int64_t i = 0; i < expanded_mode_; ++i) { + check_memref_mode(rt, i, ot, i, loc()); + } + auto stride = ot->stride(expanded_mode_); + for (std::size_t i = 0; i < static_expand_shape_.size(); ++i) { + const auto mode = expanded_mode_ + i; + if (rt->shape(mode) != static_expand_shape()[i]) { + auto extra_info = std::ostringstream{} + << "Size of mode " << mode << " does not match static expand shape (" + << rt->shape(mode) << "!=" << static_expand_shape()[i] << ")"; + throw compilation_error(loc(), status::ir_invalid_shape, std::move(extra_info).str()); + } + if (!is_dynamic_value(rt->stride(mode)) && rt->stride(mode) != stride) { + auto extra_info = std::ostringstream{} << "Stride of mode " << mode << " is invalid (" + << rt->stride(mode) << "!=" << stride << ")"; + throw compilation_error(loc(), status::ir_invalid_stride, std::move(extra_info).str()); + } + stride = is_dynamic_value(stride) || is_dynamic_value(rt->shape(mode)) + ? dynamic + : stride * rt->shape(mode); } + for (std::int64_t i = expanded_mode_ + 1; i < ot->dim(); ++i) { + check_memref_mode(rt, i + static_expand_shape_.size() - 1, ot, i, loc()); + } + + result(0) = value_node{ty, this, lc}; +} - if (dyn_count > 1) { - throw compilation_error(loc(), status::ir_multiple_dynamic_modes); +for_inst::for_inst(tinytc_data_type_t loop_var_type, tinytc_value_t from0, tinytc_value_t to0, + tinytc_value_t step0, array_view init_values, + array_view return_types, location const &lc) + : loop_inst{IK::for_loop, (step0 ? 3 : 2) + static_cast(init_values.size()), + static_cast(init_values.size())} { + op(op_from, from0); + op(op_to, to0); + if (step0) { + op(op_step, step0); } + loc(lc); - auto size = m->shape(mode_); - if (!is_dynamic_value(size) && non_imm_count == 0) { - std::int64_t prod = 1; - std::int64_t dyn_mode = -1; - for (std::size_t i = 0; i < known_expand_shape.size(); ++i) { - auto const s = known_expand_shape[i]; - if (is_dynamic_value(s)) { - dyn_mode = i; - } else { - prod *= s; - } - } - if (dyn_mode >= 0) { - std::int64_t const s = size / prod; - known_expand_shape[dyn_mode] = s; - expand_shape_[dyn_mode] = make_imm(s); - prod *= s; + body().loc(lc); + body().defining_inst(this); + body().set_num_params(1 + init_values.size()); + body().set_param(0, loop_var_type); + for (std::size_t i = 0; i < return_types.size(); ++i) { + if (!isa(*return_types[i]) && !isa(*return_types[i]) && + !isa(*return_types[i])) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); } - if (prod != size) { - throw compilation_error(loc(), status::ir_expand_shape_mismatch); + body().set_param(1 + i, return_types[i]); + result(i) = value_node{return_types[i], this, lc}; + } + if (init_values.size() != return_types.size()) { + throw compilation_error(loc(), status::ir_init_return_type_mismatch); + } + for (std::size_t i = 0; i < init_values.size(); ++i) { + if (init_values[i]->ty() != return_types[i]) { + throw compilation_error(loc(), {init_values[i]}, status::ir_init_return_type_mismatch); } + op(op_init() + i, init_values[i]); } - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim() + known_expand_shape.size() - 1); - stride.reserve(m->dim() + known_expand_shape.size() - 1); - for (std::int64_t i = 0; i < mode_; ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); + auto lvt = get_scalar_type(loc(), loop_var()); + auto fromt = get_scalar_type(loc(), from()); + auto tot = get_scalar_type(loc(), to()); + + if (!is_integer_type(lvt->ty())) { + throw compilation_error(loc(), status::ir_expected_int); } + if (lvt->ty() != fromt->ty()) { + throw compilation_error(loc(), {&from()}, status::ir_scalar_mismatch); + } + if (lvt->ty() != tot->ty()) { + throw compilation_error(loc(), {&to()}, status::ir_scalar_mismatch); + } + if (has_step()) { + auto stept = get_scalar_type(loc(), step()); + if (lvt->ty() != stept->ty()) { + throw compilation_error(loc(), {&step()}, status::ir_scalar_mismatch); + } + } +} - stride.push_back(m->stride(mode_)); - shape.push_back(known_expand_shape[0]); - for (std::size_t j = 1; j < known_expand_shape.size(); ++j) { - stride.push_back(is_dynamic_value(stride.back()) || is_dynamic_value(shape.back()) - ? dynamic - : stride.back() * shape.back()); - shape.push_back(known_expand_shape[j]); +foreach_inst::foreach_inst(tinytc_data_type_t loop_var_type, array_view from, + array_view to, location const &lc) + : loop_inst{IK::foreach_loop, static_cast(from.size() + to.size()), + std::int64_t{0}} { + std::int64_t op_no = 0; + for (auto &v : from) { + op(op_no++, v); + } + for (auto &v : to) { + op(op_no++, v); } - for (std::int64_t i = mode_ + 1; i < m->dim(); ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); + body().loc(lc); + body().defining_inst(this); + body().set_num_params(from.size()); + for (std::int64_t i = 0; i < static_cast(from.size()); ++i) { + body().set_param(i, loop_var_type); + } + child_region(0).kind(region_kind::spmd); + loc(lc); + + if (from.size() == 0 || from.size() != to.size()) { + throw compilation_error(loc(), status::ir_from_to_mismatch); } - auto r = std::make_unique(m->element_ty(), shape, stride); - r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); + if (auto lv_ty = dyn_cast(loop_var_type); lv_ty) { + if (!is_integer_type(lv_ty->ty()) || + std::any_of(op_begin(), op_end(), [&loop_var_type](tinytc_value &val) { + return val.ty() != loop_var_type; + })) { + throw compilation_error(loc(), status::ir_scalar_mismatch); + } + } else { + throw compilation_error(loc(), status::ir_expected_scalar); + } } -fuse_inst::fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc) - : op_(std::move(op)), from_(from), to_(to) { +fuse_inst::fuse_inst(tinytc_value_t op0, std::int64_t from, std::int64_t to, tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::fuse}, from_(from), to_(to) { + op(0, op0); loc(lc); - auto m = get_memref_type(loc(), op_); - bool const range_ok = 0 <= from_ && from_ < to_ && to_ < m->dim(); + + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); + + bool const range_ok = 0 <= from_ && from_ < to_ && to_ < ot->dim(); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); } - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim()); - stride.reserve(m->dim()); - std::int64_t i = 0; - for (; i < from_; ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); + + for (std::int64_t i = 0; i < from_; ++i) { + check_memref_mode(rt, i, ot, i, loc()); } + std::int64_t prod = 1; - for (; i <= to_; ++i) { - if (is_dynamic_value(m->shape(i))) { + for (std::int64_t i = from_; i <= to_; ++i) { + if (is_dynamic_value(ot->shape(i))) { prod = dynamic; break; } - prod *= m->shape(i); + prod *= ot->shape(i); } - shape.push_back(prod); - stride.push_back(m->stride(from_)); - for (i = to_ + 1; i < m->dim(); ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); + if (rt->shape(from_) != prod) { + auto extra_info = std::ostringstream{} << "Size of mode " << from_ + << " does not match shape product (" + << rt->shape(from_) << "!=" << prod << ")"; + throw compilation_error(loc(), status::ir_invalid_shape, std::move(extra_info).str()); } - auto r = std::make_unique(m->element_ty(), shape, stride); - - r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); -} + check_memref_stride(rt, from_, ot, from_, loc()); -if_inst::if_inst(value condition, region then, region otherwise, - std::vector const &return_types, location const &lc) - : condition_(std::move(condition)), then_(std::move(then)), otherwise_(std::move(otherwise)) { - loc(lc); - for (auto &ty : return_types) { - results_.push_back(make_value(ty)); + for (std::int64_t i = to_ + 1; i < ot->dim(); ++i) { + check_memref_mode(rt, i - to_ + from_, ot, i, loc()); } + + result(0) = value_node{ty, this, lc}; } -load_inst::load_inst(value op, std::vector index_list, location const &lc) - : op_(std::move(op)), index_list_(std::move(index_list)) { +load_inst::load_inst(tinytc_value_t op0, array_view index_list0, + tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::load, static_cast(1 + index_list0.size())} { + op(0, op0); + for (std::size_t i = 0; i < index_list0.size(); ++i) { + check_index_ty(lc, *index_list0[i]); + op(1 + i, index_list0[i]); + } loc(lc); + visit(overloaded{ [&](group_data_type &g) { - if (static_cast(index_list_.size()) != 1) { + if (g.ty() != ty) { + throw compilation_error(loc(), {&operand()}, + status::ir_operand_type_must_match_return_type); + } + if (static_cast(index_list().size()) != 1) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result_ = make_value(g.ty()); + result(0) = value_node{ty, this, lc}; }, [&](memref_data_type &m) { - if (m.dim() != static_cast(index_list_.size())) { + if (m.element_data_ty() != ty) { + throw compilation_error(loc(), {&operand()}, + status::ir_operand_type_must_match_return_type); + } + if (m.dim() != static_cast(index_list().size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } - result_ = make_value(m.element_ty()); + result(0) = value_node{ty, this, lc}; }, [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, - *op_->ty()); + *operand().ty()); +} + +gemm_inst::gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha0, tinytc_value_t A0, + tinytc_value_t B0, tinytc_value_t beta0, tinytc_value_t C0, bool atomic, + location const &lc) + : blas_a3_inst(IK::gemm_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic, lc), + tA_(tA), tB_(tB) { + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() != 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_2); + } + if (c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_2); + } + + auto ak = tA_ == transpose::T ? 0 : 1; + auto bk = tB_ == transpose::T ? 1 : 0; + auto M = c->shape(0); + auto N = c->shape(1); + auto K = a->shape(ak); + if (a->shape(1 - ak) != M || b->shape(bk) != K || b->shape(1 - bk) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } } -size_inst::size_inst(value op, std::int64_t mode, location const &lc) - : op_(std::move(op)), mode_(mode) { +gemv_inst::gemv_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, + tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) + : blas_a3_inst(IK::gemv_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic, lc), + tA_(tA) { + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1); + } + if (c->dim() != 1) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_1); + } + + auto ak = tA_ == transpose::T ? 0 : 1; + auto M = c->shape(0); + auto K = a->shape(ak); + if (a->shape(1 - ak) != M || b->shape(0) != K) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "c=" << c->shape(0); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } +} + +ger_inst::ger_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, + tinytc_value_t beta0, tinytc_value_t C0, bool atomic, location const &lc) + : blas_a3_inst(IK::ger_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic, lc) { + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1); + } + if (b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1); + } + if (c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_2); + } + + auto M = c->shape(0); + auto N = c->shape(1); + if (a->shape(0) != M || b->shape(0) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "a=" << a->shape(0) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } +} + +hadamard_inst::hadamard_inst(tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t B0, + tinytc_value_t beta0, tinytc_value_t C0, bool atomic, + location const &lc) + : blas_a3_inst(IK::hadamard_blas_a3, std::move(alpha0), std::move(A0), std::move(B0), + std::move(beta0), std::move(C0), atomic, lc) { + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 1 && a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1_or_2); + } + if (b->dim() != 1 && b->dim() != 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1_or_2); + } + if (c->dim() != 1 && c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_1_or_2); + } + if (c->dim() != a->dim() || c->dim() != b->dim()) { + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes); + } + + auto M = c->shape(0); + if (c->dim() == 1) { + if (a->shape(0) != M || b->shape(0) != M) { + std::ostringstream oss; + oss << "Got "; + oss << "a=" << a->shape(0) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "c=" << c->shape(0); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } + } else if (c->dim() == 2) { + auto N = c->shape(1); + if (a->shape(0) != M || a->shape(1) != N || b->shape(0) != M || b->shape(1) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } + } +} + +if_inst::if_inst(tinytc_value_t condition, array_view return_types, + location const &lc) + : standard_inst{IK::if_, 1, static_cast(return_types.size())} { + op(0, condition); loc(lc); - auto m = get_memref_type(loc(), op_); - bool const range_ok = 0 <= mode_ && mode_ < m->dim(); + then().loc(lc); + then().defining_inst(this); + otherwise().loc(lc); + otherwise().defining_inst(this); + if (!isa(*condition->ty())) { + throw compilation_error(loc(), {condition}, status::ir_expected_boolean); + } + for (std::size_t i = 0; i < return_types.size(); ++i) { + if (!isa(*return_types[i]) && !isa(*return_types[i]) && + !isa(*return_types[i])) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_scalar_or_boolean); + } + result(i) = value_node{return_types[i], this, lc}; + } +} + +math_unary_inst::math_unary_inst(math_unary operation, tinytc_value_t a0, tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::math_unary_}, operation_(operation) { + op(0, a0); + loc(lc); + + result(0) = value_node{ty, this, lc}; + + // Check if inst is supported for combination of a type and result type + auto a_ty = get_scalar_type(loc(), a()); + + const auto complex_supported = [](math_unary op) { + switch (op) { + case math_unary::exp: + case math_unary::exp2: + case math_unary::native_exp: + case math_unary::native_exp2: + return true; + default: + return false; + } + }(operation_); + + if (is_integer_type(a_ty->ty())) { + throw compilation_error(loc(), {&a()}, status::ir_int_unsupported); + } else if (is_complex_type(a_ty->ty()) && !complex_supported) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } +} + +parallel_inst::parallel_inst(location const &lc) : standard_inst{IK::parallel} { + loc(lc); + + child_region(0).kind(region_kind::spmd); + child_region(0).loc(lc); + child_region(0).defining_inst(this); +} + +size_inst::size_inst(tinytc_value_t op0, std::int64_t mode, tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::size}, mode_(mode) { + op(0, op0); + loc(lc); + + auto rt = dyn_cast(ty); + if (!rt || rt->ty() != scalar_type::index) { + throw compilation_error(loc(), status::ir_expected_index); + } + + const bool range_ok = + visit(overloaded{[&](group_data_type &) -> bool { return 0 <= mode_ && mode_ < 1; }, + [&](memref_data_type &m) -> bool { return 0 <= mode_ && mode_ < m.dim(); }, + [&](auto &) -> bool { + throw compilation_error(loc(), status::ir_expected_memref_or_group); + }}, + *operand().ty()); if (!range_ok) { throw compilation_error(loc(), status::ir_out_of_bounds); } - result_ = make_value(scalar_type::index); + result(0) = value_node{ty, this, lc}; } -subview_inst::subview_inst(value op, std::vector slices, location const &lc) - : op_(std::move(op)), slices_(std::move(slices)) { +subgroup_broadcast_inst::subgroup_broadcast_inst(tinytc_value_t a0, tinytc_value_t idx0, + tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::subgroup_broadcast} { + op(0, a0); + op(1, idx0); loc(lc); - auto m = get_memref_type(loc(), op_); - if (m->dim() != static_cast(slices_.size())) { - throw compilation_error(loc(), status::ir_invalid_number_of_indices); + + if (!isa(*ty)) { + throw compilation_error(loc(), status::ir_expected_scalar); } - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim()); - stride.reserve(m->dim()); - for (std::int64_t i = 0; i < m->dim(); ++i) { - auto &slice = slices_[i]; - visit(overloaded{[&](int_imm &i) { - if (i.value() < 0) { - throw compilation_error(loc(), status::ir_invalid_slice); - } - }, - [](auto &) {}}, - *slice.first); - if (slice.second) { // if size is given - visit(overloaded{[&](int_imm &i) { - if (i.value() < 1 && !is_dynamic_value(i.value())) { - throw compilation_error(loc(), status::ir_invalid_slice); - } - }, - [](auto &) {}}, - *slice.second); - auto size = visit(overloaded{[&](int_imm &offset, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return is_dynamic_value(m->shape(i)) - ? dynamic - : m->shape(i) - offset.value(); - } - return size.value(); - }, - [&](val &, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return dynamic; - } - return size.value(); - }, - [](auto &, auto &) -> std::int64_t { return dynamic; }}, - *slice.first, *slice.second); - shape.push_back(size); - stride.push_back(m->stride(i)); - } - } - auto r = std::make_unique(m->element_ty(), shape, stride); - - r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); -} - -store_inst::store_inst(value val, value op, std::vector index_list, location const &lc) - : val_(std::move(val)), op_(std::move(op)), index_list_(std::move(index_list)) { - loc(lc); - auto v = get_scalar_type(loc(), val_); - auto o = get_memref_type(loc(), op_); + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } - if (v->ty() != o->element_ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); + if (auto idxt = dyn_cast(idx().ty()); + !idxt || idxt->ty() != scalar_type::i32) { + throw compilation_error(loc(), {&idx()}, status::ir_expected_i32); } - if (o->dim() != static_cast(index_list_.size())) { + result(0) = value_node{ty, this, lc}; +} + +subgroup_operation_inst::subgroup_operation_inst(group_arithmetic arith0, group_operation operation, + tinytc_value_t a0, tinytc_data_type_t ty, + location const &lc) + : standard_inst{IK::subgroup_operation}, arith_{arith0}, operation_{operation} { + op(0, a0); + loc(lc); + + auto sty = get_scalar_type(loc(), a()); + if (arith() != group_arithmetic::add && is_complex_type(sty->ty())) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } + + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } + + result(0) = value_node{ty, this, lc}; +} + +subview_inst::subview_inst(tinytc_value_t op0, array_view static_offsets0, + array_view static_sizes0, + array_view offsets0, array_view sizes0, + tinytc_data_type_t ty, location const &lc) + : standard_inst{IK::subview, static_cast(1 + offsets0.size() + sizes0.size())}, + static_offsets_(std::move(static_offsets0)), static_sizes_(std::move(static_sizes0)) { + op(0, op0); + { + std::size_t i = 1; + for (auto const &val : offsets0) { + check_index_ty(loc(), *val); + op(i++, val); + } + num_dyn_offsets_ = i - 1; + for (auto const &val : sizes0) { + check_index_ty(loc(), *val); + op(i++, val); + } + } + loc(lc); + + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); + + if (ot->dim() != static_cast(static_offsets_.size()) || + ot->dim() != static_cast(static_sizes_.size())) { throw compilation_error(loc(), status::ir_invalid_number_of_indices); } + if (std::count(static_offsets_.begin(), static_offsets_.end(), dynamic) != num_dyn_offsets_ || + std::count(static_sizes_.begin(), static_sizes_.end(), dynamic) != + num_operands() - num_dyn_offsets_ - 1) { + throw compilation_error(loc(), status::ir_subview_mismatch); + } + + std::int64_t ri = 0; + for (std::int64_t i = 0; i < ot->dim(); ++i) { + auto offset = static_offsets_[i]; + auto size = static_sizes_[i]; + if ((offset < 0 && !is_dynamic_value(offset)) || (size < 0 && !is_dynamic_value(size))) { + throw compilation_error(loc(), status::ir_invalid_slice); + } + if (size > 0 || is_dynamic_value(size)) { + if (rt->shape(ri) != size) { + auto extra_info = std::ostringstream{} << "Size of mode " << ri + << " does not match slice size [" + << rt->shape(ri) << "!=" << size << "]"; + throw compilation_error(loc(), status::ir_invalid_shape, + std::move(extra_info).str()); + } + check_memref_stride(rt, ri, ot, i, loc()); + ++ri; + } + } + + result(0) = value_node{ty, this, lc}; } -sum_inst::sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(beta), std::move(B), atomic), tA_(tA) { +store_inst::store_inst(store_flag flag, tinytc_value_t val0, tinytc_value_t op0, + array_view index_list0, location const &lc) + : standard_inst{IK::store, static_cast(2 + index_list0.size())}, flag_{flag} { + op(op_val, val0); + op(op_operand, op0); + { + std::size_t i = op_operand; + for (auto const &val : index_list0) { + check_index_ty(lc, *val); + op(++i, val); + } + } loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - bool const size_ok = (a->dim() == 2 && b->dim() == 1) || (a->dim() == 1 && b->dim() == 0); - if (!size_ok) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix); + auto v = get_scalar_type(loc(), val()); + auto o = get_memref_type(loc(), operand()); + + if (v->ty() != o->element_ty()) { + throw compilation_error(loc(), {&val(), &operand()}, status::ir_scalar_mismatch); + } + + if (o->dim() != static_cast(index_list0.size())) { + throw compilation_error(loc(), {&operand()}, status::ir_invalid_number_of_indices); + } +} + +sum_inst::sum_inst(transpose tA, tinytc_value_t alpha0, tinytc_value_t A0, tinytc_value_t beta0, + tinytc_value_t B0, bool atomic, location const &lc) + : blas_a2_inst(IK::sum_blas_a2, std::move(alpha0), std::move(A0), std::move(beta0), + std::move(B0), atomic, lc), + tA_(tA) { + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + + if (b->dim() == 1 && a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() == 0 && a->dim() != 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1); + } + if (b->dim() != 0 && b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_0_or_1); } if (a->dim() == 2) { if (a->shape(tA_ == transpose::T ? 1 : 0) != b->shape(0)) { - throw compilation_error(loc(), status::ir_incompatible_shapes); + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); } } } +yield_inst::yield_inst(array_view vals, location const &lc) + : standard_inst{IK::yield, static_cast(vals.size())} { + loc(lc); + for (std::size_t i = 0; i < vals.size(); ++i) { + op(i, vals[i]); + } +} + } // namespace tinytc diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp index 1dc54846..b1152223 100644 --- a/src/node/inst_node.hpp +++ b/src/node/inst_node.hpp @@ -4,135 +4,366 @@ #ifndef INST_NODE_20230327_HPP #define INST_NODE_20230327_HPP -#include "reference_counted.hpp" -#include "slice.hpp" +#include "error.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/type_list.hpp" -#include - +#include +#include #include #include -#include +#include +#include +#include +#include #include namespace tinytc { //! Instruction classification -enum class inst_kind { - replicated, ///< replicated instruction executed in every work-item - collective ///< collective instruction distributed among work-items +enum class inst_execution_kind { + mixed, ///< mixed instruction on uniform or varying data + collective, ///< collective instruction on uniform data, distributed among work-items + spmd ///< SPMD instruction on varying data }; +enum class IK { + alloca, + arith, + arith_unary, + barrier, + builtin, + cast, + compare, + constant, + cooperative_matrix_apply, + cooperative_matrix_extract, + cooperative_matrix_insert, + cooperative_matrix_load, + cooperative_matrix_mul_add, + cooperative_matrix_prefetch, + cooperative_matrix_reduce, + cooperative_matrix_scale, + cooperative_matrix_store, + expand, + fuse, + load, + lifetime_stop, + if_, + math_unary_, + parallel, + size, + subgroup_broadcast, + subgroup_operation, + subview, + store, + yield, + // blas a2 + blas_a2, + axpby_blas_a2, + cumsum_blas_a2, + sum_blas_a2, + last_blas_a2, + // blas a3 + blas_a3, + gemm_blas_a3, + gemv_blas_a3, + ger_blas_a3, + hadamard_blas_a3, + last_blas_a3, + // loop inst + loop, + for_loop, + foreach_loop, + last_loop +}; using inst_nodes = - clir::virtual_type_list; + type_list; + +using result_range = iterator_range_wrapper; +using const_result_range = iterator_range_wrapper; + +using region_range = iterator_range_wrapper; +using const_region_range = iterator_range_wrapper; } // namespace tinytc -struct tinytc_inst : tinytc::reference_counted, tinytc::inst_nodes { +struct tinytc_inst : tinytc::ilist_node_with_parent { public: + using leaves = tinytc::inst_nodes; + + using op_iterator = + tinytc::indirect_random_access_iterator; + using const_op_iterator = + tinytc::indirect_random_access_iterator; + + using op_range = tinytc::iterator_range_wrapper; + using const_op_range = tinytc::iterator_range_wrapper; + + static_assert(std::random_access_iterator); + static_assert(std::random_access_iterator); + static_assert(std::ranges::random_access_range); + static_assert(std::ranges::random_access_range); + + inline tinytc_inst(tinytc::IK tid) : tid_(tid) {} + virtual ~tinytc_inst() = default; + + tinytc_inst(tinytc_inst const &other) = delete; + tinytc_inst(tinytc_inst &&other) = delete; + tinytc_inst &operator=(tinytc_inst const &other) = delete; + tinytc_inst &operator=(tinytc_inst &&other) = delete; + + auto context() const -> tinytc_compiler_context_t; + inline auto type_id() const -> tinytc::IK { return tid_; } + + inline auto attr() const noexcept -> tinytc_attr_t { return attr_; } + inline void attr(tinytc_attr_t attr) noexcept { attr_ = attr; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - virtual tinytc::value result() const = 0; - inline virtual auto results() const -> std::vector { - if (auto r = result(); r) { - return {std::move(r)}; - } - return {}; + // Iterator over operands + inline auto op_begin() -> op_iterator { return {op_begin_}; } + inline auto op_end() -> op_iterator { return {op_end_}; } + inline auto operands() -> op_range { return {op_begin(), op_end()}; } + inline auto op_begin() const -> const_op_iterator { return {op_begin_}; } + inline auto op_end() const -> const_op_iterator { return {op_end_}; } + inline auto operands() const -> const_op_range { return {op_begin(), op_end()}; } + inline auto op(std::size_t pos) -> tinytc_value & { return *op_begin_[pos].get(); } + inline auto op(std::size_t pos) const -> tinytc_value const & { return *op_begin_[pos].get(); } + inline void op(std::size_t pos, tinytc_value_t val) { op_begin_[pos] = val; } + inline auto get_use(std::size_t pos) -> tinytc::use & { return op_begin_[pos]; } + inline auto get_use(std::size_t pos) const -> tinytc::use const & { return op_begin_[pos]; } + inline auto num_operands() const -> std::int64_t { return op_end_ - op_begin_; } + + void subs(tinytc_value_t old_value, tinytc_value_t new_value, bool recursive = true); + + // Iterator over results + inline auto result_begin() -> tinytc_value_t { return result_begin_; } + inline auto result_end() -> tinytc_value_t { return result_end_; } + inline auto results() -> tinytc::result_range { return {result_begin_, result_end_}; } + inline auto result_begin() const -> const_tinytc_value_t { return result_begin_; } + inline auto result_end() const -> const_tinytc_value_t { return result_end_; } + inline auto results() const -> tinytc::const_result_range { + return {result_begin_, result_end_}; + } + inline auto result() const -> tinytc_value_t { + return num_results() > 0 ? result_begin_ : nullptr; + } + inline auto result(std::size_t pos) -> tinytc_value & { return result_begin_[pos]; } + inline auto result(std::size_t pos) const -> tinytc_value const & { return result_begin_[pos]; } + inline auto num_results() const -> std::int64_t { return result_end_ - result_begin_; } + + // Iterator over regions + inline auto child_regions_begin() -> tinytc_region_t { return child_regions_begin_; } + inline auto child_regions_end() -> tinytc_region_t { return child_regions_end_; } + inline auto child_regions() -> tinytc::region_range { + return tinytc::region_range{child_regions_begin(), child_regions_end()}; + } + inline auto child_regions_begin() const -> const_tinytc_region_t { + return child_regions_begin_; + } + inline auto child_regions_end() const -> const_tinytc_region_t { return child_regions_end_; } + inline auto child_regions() const -> tinytc::const_region_range { + return tinytc::const_region_range{child_regions_begin(), child_regions_end()}; + } + auto child_region(std::size_t pos) -> tinytc_region & { return child_regions_begin_[pos]; } + auto child_region(std::size_t pos) const -> tinytc_region const & { + return child_regions_begin_[pos]; + } + auto num_child_regions() const -> std::int64_t { + return child_regions_end_ - child_regions_begin_; + } + + auto kind() const -> tinytc::inst_execution_kind; + + protected: + inline auto set_op_range(tinytc::use *begin, tinytc::use *end) noexcept { + op_begin_ = begin; + op_end_ = end; + } + inline auto set_result_range(tinytc_value_t begin, tinytc_value_t end) noexcept { + result_begin_ = begin; + result_end_ = end; + } + inline auto set_child_regions_range(tinytc_region_t begin, tinytc_region_t end) noexcept { + child_regions_begin_ = begin; + child_regions_end_ = end; } - inline virtual auto num_results() const -> std::size_t { return result() ? 1u : 0u; } - virtual tinytc::inst_kind kind() const = 0; private: + tinytc::IK tid_; tinytc::location loc_; + tinytc::use *op_begin_ = nullptr, *op_end_ = nullptr; + tinytc_value_t result_begin_ = nullptr, result_end_ = nullptr; + tinytc_region_t child_regions_begin_ = nullptr, child_regions_end_ = nullptr; + tinytc_attr_t attr_ = nullptr; }; namespace tinytc { using inst_node = ::tinytc_inst; -class scalar_inst : public inst_node {}; +template class object_container { + public: + object_container(std::int64_t num_objects) { + // Check that num_objects is not larger than container size + // Smaller is ok too support optional arguments + if (num_objects > NumObjects) { + throw internal_compiler_error(); + } + } + auto get() -> T * { + if constexpr (NumObjects == 0) { + return nullptr; + } + return objs_.data(); + } + + private: + std::array objs_; +}; + +template class object_container { + public: + object_container(std::int64_t num_objects) : objs_{std::make_unique(num_objects)} {} + + auto get() -> T * { return objs_.get(); } + + private: + std::unique_ptr objs_; +}; + +template +class standard_inst : public inst_node { + public: + standard_inst(IK tid, std::int64_t num_operands = NumOperands, + std::int64_t num_results = NumResults, + std::int64_t num_child_regions = NumChildRegions) + : inst_node{tid}, ops_{num_operands}, results_{num_results}, + child_regions_{num_child_regions} { + if (num_operands > 0) { + auto *op_begin = ops_.get(); + set_op_range(op_begin, op_begin + num_operands); + if constexpr (NumOperands != 0) { + for (std::int64_t i = 0; i < num_operands; ++i) { + op_begin[i].owner(this); + } + } + } + if (num_results > 0) { + auto *result_begin = results_.get(); + set_result_range(result_begin, result_begin + num_results); + } + if (num_child_regions > 0) { + set_child_regions_range(child_regions_.get(), child_regions_.get() + num_child_regions); + } + } + + private: + object_container ops_; + object_container results_; + object_container child_regions_; +}; -class blas_a2_inst : public inst_node { +class blas_a2_inst : public standard_inst<4, 0> { public: - blas_a2_inst(value alpha, value A, value beta, value B, bool atomic); + inline static bool classof(inst_node const &i) { + return i.type_id() >= IK::blas_a2 && i.type_id() <= IK::last_blas_a2; + } + enum op_number { op_alpha = 0, op_A = 1, op_beta = 2, op_B = 3 }; + blas_a2_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic, location const &lc); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return alpha_; } - inline auto A() const -> value const & { return A_; } - inline auto beta() const -> value const & { return beta_; } - inline auto B() const -> value const & { return B_; } - inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline auto alpha() -> tinytc_value & { return op(op_alpha); } + inline auto alpha() const -> tinytc_value const & { return op(op_alpha); } + inline auto A() -> tinytc_value & { return op(op_A); } + inline auto A() const -> tinytc_value const & { return op(op_A); } + inline auto beta() -> tinytc_value & { return op(op_beta); } + inline auto beta() const -> tinytc_value const & { return op(op_beta); } + inline auto B() -> tinytc_value & { return op(op_B); } + inline auto B() const -> tinytc_value const & { return op(op_B); } protected: - value alpha_, A_, beta_, B_; bool atomic_; }; -class blas_a3_inst : public inst_node { +class blas_a3_inst : public standard_inst<5, 0> { public: - blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic); + inline static bool classof(inst_node const &i) { + return i.type_id() >= IK::blas_a3 && i.type_id() <= IK::last_blas_a3; + } + enum op_number { op_alpha = 0, op_A = 1, op_B = 2, op_beta = 3, op_C = 4 }; + blas_a3_inst(IK tid, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic, location const &lc); inline bool atomic() const { return atomic_; } inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return alpha_; } - inline auto A() const -> value const & { return A_; } - inline auto B() const -> value const & { return B_; } - inline auto beta() const -> value const & { return beta_; } - inline auto C() const -> value const & { return C_; } - inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline auto alpha() -> tinytc_value & { return op(op_alpha); } + inline auto alpha() const -> tinytc_value const & { return op(op_alpha); } + inline auto A() -> tinytc_value & { return op(op_A); } + inline auto A() const -> tinytc_value const & { return op(op_A); } + inline auto B() -> tinytc_value & { return op(op_B); } + inline auto B() const -> tinytc_value const & { return op(op_B); } + inline auto beta() -> tinytc_value & { return op(op_beta); } + inline auto beta() const -> tinytc_value const & { return op(op_beta); } + inline auto C() -> tinytc_value & { return op(op_C); } + inline auto C() const -> tinytc_value const & { return op(op_C); } protected: - value alpha_, A_, B_, beta_, C_; bool atomic_; }; -class loop_inst : public inst_node { +class loop_inst : public standard_inst { public: - loop_inst(value loop_var, value from, value to, region body, location const &loc = {}); - loop_inst(value loop_var, value from, value to, value step, region body, - location const &loc = {}); - inline auto loop_var() const -> value const & { return loop_var_; } - inline auto from() const -> value const & { return from_; } - inline auto to() const -> value const & { return to_; } - inline auto step() const -> value const & { return step_; } - inline auto body() const -> region const & { return body_; } - inline value result() const override { return value{}; } + inline static bool classof(inst_node const &i) { + return i.type_id() >= IK::loop && i.type_id() <= IK::last_loop; + } + inline loop_inst(IK tid, std::int64_t num_operands, std::int64_t num_results) + : standard_inst{tid, num_operands, num_results} {} - private: - value loop_var_, from_, to_, step_; - region body_; + inline auto body() -> tinytc_region & { return child_region(0); } + inline auto body() const -> tinytc_region const & { return child_region(0); } }; -class alloca_inst : public clir::visitable { +class alloca_inst : public standard_inst<0, 1> { public: - alloca_inst(data_type ty, location const &loc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::alloca; } + alloca_inst(tinytc_data_type_t ty, location const &loc = {}); - inline value result() const override { return result_; } inline std::int64_t stack_ptr() const { return stack_ptr_; } inline void stack_ptr(std::int64_t ptr) { stack_ptr_ = ptr; } - inline inst_kind kind() const override { return inst_kind::collective; } private: - value result_; std::int64_t stack_ptr_; }; -class axpby_inst : public clir::visitable { +class axpby_inst : public blas_a2_inst { public: - using super = clir::visitable; - axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, - location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::axpby_blas_a2; } + axpby_inst(transpose tA, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } @@ -140,155 +371,375 @@ class axpby_inst : public clir::visitable { transpose tA_; }; -class arith_inst : public clir::visitable { +class arith_inst : public standard_inst<2, 1> { public: - arith_inst(arithmetic op, value a, value b, location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith; } + enum op_number { op_a = 0, op_b = 1 }; + arith_inst(arithmetic op, tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, + location const &lc = {}); - inline arithmetic op() const { return op_; } - inline auto a() const -> value const & { return a_; } - inline auto b() const -> value const & { return b_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline arithmetic operation() const { return operation_; } + inline auto a() -> tinytc_value & { return op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } + inline auto b() const -> tinytc_value const & { return op(op_b); } private: - arithmetic op_; - value a_, b_, result_; + arithmetic operation_; }; -class arith_unary_inst : public clir::visitable { +class arith_unary_inst : public standard_inst<1, 1> { public: - arith_unary_inst(arithmetic_unary op, value a, location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::arith_unary; } + enum op_number { op_a = 0 }; + arith_unary_inst(arithmetic_unary op, tinytc_value_t a, tinytc_data_type_t ty, + location const &lc = {}); - inline arithmetic_unary op() const { return op_; } - inline auto a() const -> value const & { return a_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline arithmetic_unary operation() const { return operation_; } + inline auto a() -> tinytc_value & { return op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } private: - arithmetic_unary op_; - value a_, result_; + arithmetic_unary operation_; }; -class barrier_inst : public clir::visitable { +class barrier_inst : public standard_inst<0, 0> { public: - inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::barrier; } + inline barrier_inst(tinytc_address_spaces_t fence_flags, location const &lc = {}) + : standard_inst{IK::barrier}, fence_flags_(fence_flags) { + loc(lc); + } + + inline auto fence_flags() const -> tinytc_address_spaces_t { return fence_flags_; } + inline auto fence_flags(tinytc_address_spaces_t fence_flags) { fence_flags_ = fence_flags; } + inline auto has_fence(address_space as) const { + return (fence_flags_ & static_cast(as)) > 0; + } + + private: + tinytc_address_spaces_t fence_flags_; }; -class cast_inst : public clir::visitable { +class builtin_inst : public standard_inst<0, 1> { public: - cast_inst(value a, scalar_type to_ty, location const &lc = {}); - inline auto a() const -> value const & { return a_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::builtin; } + builtin_inst(builtin btype, tinytc_data_type_t ty, location const &lc = {}); + + inline auto builtin_type() const -> builtin { return btype_; } + + auto kind() const -> tinytc::inst_execution_kind; private: - value a_, result_; + builtin btype_; }; -class compare_inst : public clir::visitable { +class cast_inst : public standard_inst<1, 1> { public: - compare_inst(cmp_condition cond, value a, value b, location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::cast; } + enum op_number { op_a = 0 }; + cast_inst(tinytc_value_t a, tinytc_data_type_t to_ty, location const &lc = {}); + + inline auto a() -> tinytc_value & { return op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } +}; + +class compare_inst : public standard_inst<2, 1> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::compare; } + enum op_number { op_a = 0, op_b = 1 }; + compare_inst(cmp_condition cond, tinytc_value_t a, tinytc_value_t b, tinytc_data_type_t ty, + location const &lc = {}); inline cmp_condition cond() const { return cond_; } - inline auto a() const -> value const & { return a_; } - inline auto b() const -> value const & { return b_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline auto a() -> tinytc_value & { return op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } + inline auto b() const -> tinytc_value const & { return op(op_b); } private: cmp_condition cond_; - value a_, b_, result_; }; -class expand_inst : public clir::visitable { +class constant_inst : public standard_inst<0, 1> { public: - expand_inst(value op, std::int64_t mode, std::vector expand_shape, - location const &lc = {}); + using value_type = std::variant>; - inline auto operand() const -> value const & { return op_; } - inline std::int64_t mode() const { return mode_; } - inline auto expand_shape() const -> std::vector const & { return expand_shape_; } - inline auto expand_shape(std::int64_t i) const -> value const & { return expand_shape_[i]; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::constant; } + constant_inst(value_type const &value, tinytc_data_type_t ty, location const &lc = {}); + + auto value() const -> value_type const & { return value_; } + auto is_zero() const -> bool; + auto is_identity() const -> bool; private: - value op_, result_; - std::int64_t mode_; - std::vector expand_shape_; + value_type value_; +}; + +class cooperative_matrix_apply_inst : public standard_inst<1, 1, 1> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_apply; + } + cooperative_matrix_apply_inst(tinytc_value_t a, tinytc_data_type_t ty, + location const &loc = {}); + + inline auto a() -> tinytc_value & { return op(0); } + inline auto a() const -> tinytc_value const & { return op(0); } + + inline auto body() -> tinytc_region & { return child_region(0); } + inline auto body() const -> tinytc_region const & { return child_region(0); } + inline auto row() -> tinytc_value & { return body().param(0); } + inline auto row() const -> tinytc_value const & { return body().param(0); } + inline auto col() -> tinytc_value & { return body().param(1); } + inline auto col() const -> tinytc_value const & { return body().param(1); } + inline auto val() -> tinytc_value & { return body().param(2); } + inline auto val() const -> tinytc_value const & { return body().param(2); } }; -class fuse_inst : public clir::visitable { +class cooperative_matrix_extract_inst : public standard_inst<1, 1, 0> { public: - fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc = {}); + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_extract; + } + cooperative_matrix_extract_inst(tinytc_value_t mat, std::int64_t index, tinytc_data_type_t ty, + location const &loc = {}); - inline auto operand() const -> value const & { return op_; } - inline std::int64_t from() const { return from_; } - inline std::int64_t to() const { return to_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline auto mat() -> tinytc_value & { return op(0); } + inline auto mat() const -> tinytc_value const & { return op(0); } + inline auto index() const -> std::int64_t { return index_; } private: - value op_, result_; - std::int64_t from_, to_; + std::int64_t index_; }; -class load_inst : public clir::visitable { +class cooperative_matrix_insert_inst : public standard_inst<2, 1, 0> { public: - load_inst(value op, std::vector index_list, location const &lc = {}); + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_insert; + } + cooperative_matrix_insert_inst(tinytc_value_t val, tinytc_value_t mat, std::int64_t index, + tinytc_data_type_t ty, location const &loc = {}); - inline auto operand() const -> value const & { return op_; } - inline auto index_list() const -> std::vector const & { return index_list_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline auto val() -> tinytc_value & { return op(0); } + inline auto val() const -> tinytc_value const & { return op(0); } + inline auto mat() -> tinytc_value & { return op(1); } + inline auto mat() const -> tinytc_value const & { return op(1); } + inline auto index() const -> std::int64_t { return index_; } private: - value op_; - std::vector index_list_; - value result_; + std::int64_t index_; }; -class group_id_inst : public clir::visitable { +class cooperative_matrix_load_inst : public standard_inst<3, 1, 0> { public: - inline group_id_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { - loc(lc); + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_load; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + enum op_number { op_operand = 0, op_pos0 = 1, op_pos1 = 2 }; + cooperative_matrix_load_inst(transpose t, checked_flag flag, tinytc_value_t op0, + tinytc_value_t p0, tinytc_value_t p1, tinytc_data_type_t to_ty, + location const &lc = {}); + + inline auto t() const -> transpose { return t_; } + inline auto checked() const -> checked_flag { return flag_; } + inline auto operand() -> tinytc_value & { return op(op_operand); } + inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto pos0() -> tinytc_value & { return op(op_pos0); } + inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } + inline auto pos1() -> tinytc_value & { return op(op_pos1); } + inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } + + auto kind() const -> tinytc::inst_execution_kind; private: - value result_; + transpose t_; + checked_flag flag_; }; -class group_size_inst : public clir::visitable { +class cooperative_matrix_mul_add_inst : public standard_inst<3, 1, 0> { public: - inline group_size_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { - loc(lc); + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_mul_add; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + enum op_number { op_a = 0, op_b = 1, op_c = 2 }; + cooperative_matrix_mul_add_inst(tinytc_value_t a0, tinytc_value_t b0, tinytc_value_t c0, + tinytc_data_type_t to_ty, location const &lc = {}); + + inline auto a() -> tinytc_value & { return op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } + inline auto b() const -> tinytc_value const & { return op(op_b); } + inline auto c() -> tinytc_value & { return op(op_c); } + inline auto c() const -> tinytc_value const & { return op(op_c); } + + auto is_c_zero() const -> bool; +}; + +class cooperative_matrix_prefetch_inst : public standard_inst<3, 0, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_prefetch; + } + enum op_number { op_operand = 0, op_pos0 = 1, op_pos1 = 2 }; + cooperative_matrix_prefetch_inst(std::int32_t cache_level, tinytc_value_t op0, + tinytc_value_t p0, tinytc_value_t p1, std::int32_t rows, + std::int32_t cols, location const &lc = {}); + + inline auto cache_level() const -> std::int32_t { return cache_level_; } + inline auto operand() -> tinytc_value & { return op(op_operand); } + inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto pos0() -> tinytc_value & { return op(op_pos0); } + inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } + inline auto pos1() -> tinytc_value & { return op(op_pos1); } + inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } + inline auto rows() const -> std::int32_t { return rows_; } + inline auto cols() const -> std::int32_t { return cols_; } + + private: + std::int32_t cache_level_, rows_, cols_; +}; + +class cooperative_matrix_reduce_inst : public standard_inst<1, 1, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_reduce; + } + cooperative_matrix_reduce_inst(group_arithmetic arith, reduce_mode mode, tinytc_value_t a, + tinytc_data_type_t ty, location const &lc = {}); + + inline auto arith() const -> group_arithmetic { return arith_; } + inline auto mode() const -> reduce_mode { return mode_; } + inline auto a() -> tinytc_value & { return op(0); } + inline auto a() const -> tinytc_value const & { return op(0); } + + private: + group_arithmetic arith_; + reduce_mode mode_; +}; + +class cooperative_matrix_scale_inst : public standard_inst<2, 1, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_scale; + } + enum op_number { op_a = 0, op_b = 1 }; + cooperative_matrix_scale_inst(tinytc_value_t a0, tinytc_value_t b0, tinytc_data_type_t ty, + location const &lc = {}); + + inline auto a() -> tinytc_value & { return op(op_a); } + inline auto a() const -> tinytc_value const & { return op(op_a); } + inline auto b() -> tinytc_value & { return op(op_b); } + inline auto b() const -> tinytc_value const & { return op(op_b); } +}; + +class cooperative_matrix_store_inst : public standard_inst<4, 0, 0> { + public: + inline static bool classof(inst_node const &i) { + return i.type_id() == IK::cooperative_matrix_store; + } + enum op_number { op_val = 0, op_operand = 1, op_pos0 = 2, op_pos1 = 3 }; + cooperative_matrix_store_inst(checked_flag cflag, store_flag sflag, tinytc_value_t val0, + tinytc_value_t op0, tinytc_value_t p0, tinytc_value_t p1, + location const &lc = {}); + + inline auto checked() const -> checked_flag { return cflag_; } + inline auto flag() const -> store_flag { return sflag_; } + inline auto val() -> tinytc_value & { return op(op_val); } + inline auto val() const -> tinytc_value const & { return op(op_val); } + inline auto operand() -> tinytc_value & { return op(op_operand); } + inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto pos0() -> tinytc_value & { return op(op_pos0); } + inline auto pos0() const -> tinytc_value const & { return op(op_pos0); } + inline auto pos1() -> tinytc_value & { return op(op_pos1); } + inline auto pos1() const -> tinytc_value const & { return op(op_pos1); } private: - value result_; + checked_flag cflag_; + store_flag sflag_; }; -class lifetime_stop_inst : public clir::visitable { +class cumsum_inst : public blas_a2_inst { public: - inline lifetime_stop_inst(value obj) : obj_(std::move(obj)) {} - inline auto object() const -> value const & { return obj_; } - inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::cumsum_blas_a2; } + cumsum_inst(tinytc_value_t alpha, tinytc_value_t A, std::int64_t mode, tinytc_value_t beta, + tinytc_value_t B, bool atomic = false, location const &lc = {}); + + inline std::int64_t mode() const { return mode_; } + + private: + std::int64_t mode_; +}; + +class expand_inst : public standard_inst { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::expand; } + expand_inst(tinytc_value_t op, std::int64_t expanded_mode, + array_view static_expand_shape, + array_view expand_shape, tinytc_data_type_t ty, + location const &lc = {}); + + inline std::int64_t expanded_mode() const { return expanded_mode_; } + inline auto static_expand_shape() const -> array_view { + return static_expand_shape_; + } + + inline auto operand() -> tinytc_value & { return op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } + inline auto expand_shape() { return operands() | std::views::drop(1); } + inline auto expand_shape() const { return operands() | std::views::drop(1); } + inline auto expand_shape(std::int64_t i) const -> tinytc_value const & { return op(i + 1); } private: - value obj_; + std::int64_t expanded_mode_; + std::vector static_expand_shape_; }; -class gemm_inst : public clir::visitable { +class fuse_inst : public standard_inst<1, 1> { public: - using super = clir::visitable; - gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, - bool atomic = false, location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::fuse; } + fuse_inst(tinytc_value_t op, std::int64_t from, std::int64_t to, tinytc_data_type_t ty, + location const &lc = {}); + + inline auto operand() -> tinytc_value & { return op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } + inline std::int64_t from() const { return from_; } + inline std::int64_t to() const { return to_; } + + private: + std::int64_t from_, to_; +}; + +class load_inst : public standard_inst { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::load; } + load_inst(tinytc_value_t op, array_view index_list, tinytc_data_type_t ty, + location const &lc = {}); + + inline auto operand() -> tinytc_value & { return op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } + inline auto index_list() { return operands() | std::views::drop(1); } + inline auto index_list() const { return operands() | std::views::drop(1); } +}; + +class lifetime_stop_inst : public standard_inst<1, 0> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::lifetime_stop; } + inline lifetime_stop_inst(tinytc_value_t obj, location const &lc = {}) + : standard_inst{IK::lifetime_stop} { + op(0, obj); + loc(lc); + } + + inline auto object() -> tinytc_value & { return op(0); } + inline auto object() const -> tinytc_value const & { return op(0); } +}; + +class gemm_inst : public blas_a3_inst { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::gemm_blas_a3; } + gemm_inst(transpose tA, transpose tB, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } inline transpose tB() const { return tB_; } @@ -297,11 +748,11 @@ class gemm_inst : public clir::visitable { transpose tA_, tB_; }; -class gemv_inst : public clir::visitable { +class gemv_inst : public blas_a3_inst { public: - using super = clir::visitable; - gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::gemv_blas_a3; } + gemv_inst(transpose tA, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, + tinytc_value_t beta, tinytc_value_t C, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } @@ -309,106 +760,199 @@ class gemv_inst : public clir::visitable { transpose tA_; }; -class ger_inst : public clir::visitable { +class ger_inst : public blas_a3_inst { public: - using super = clir::visitable; - ger_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::ger_blas_a3; } + ger_inst(tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, + tinytc_value_t C, bool atomic = false, location const &lc = {}); }; -class for_inst : public clir::visitable { +class for_inst : public loop_inst { public: - using super = clir::visitable; - using super::super; - inline inst_kind kind() const override { return inst_kind::replicated; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::for_loop; } + enum op_number { op_from = 0, op_to = 1, op_step = 2 }; + for_inst(tinytc_data_type_t loop_var_type, tinytc_value_t from, tinytc_value_t to, + tinytc_value_t step, array_view init_values, + array_view return_types, location const &loc = {}); + + inline auto from() -> tinytc_value & { return op(op_from); } + inline auto from() const -> tinytc_value const & { return op(op_from); } + inline auto to() -> tinytc_value & { return op(op_to); } + inline auto to() const -> tinytc_value const & { return op(op_to); } + inline auto has_step() const -> bool { return op_init() == 3; } + inline auto step() -> tinytc_value & { return op(op_step); } + inline auto step() const -> tinytc_value const & { return op(op_step); } + inline auto loop_var() -> tinytc_value & { return body().param(0); } + inline auto loop_var() const -> tinytc_value const & { return body().param(0); } + inline auto iter_arg(std::int64_t no) -> tinytc_value & { return body().param(no + 1); } + inline auto iter_arg(std::int64_t no) const -> tinytc_value const & { + return body().param(no + 1); + } + inline auto iter_init(std::int64_t no) -> tinytc_value & { return op(op_init() + no); } + inline auto iter_init(std::int64_t no) const -> tinytc_value const & { + return op(op_init() + no); + } + inline auto iter_init() { return operands() | std::views::drop(op_init()); } + inline auto iter_init() const { return operands() | std::views::drop(op_init()); } + + private: + inline auto op_init() const -> std::int64_t { return num_operands() - num_results(); } }; -class foreach_inst : public clir::visitable { +class foreach_inst : public loop_inst { public: - using super = clir::visitable; - inline foreach_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : super(std::move(loop_var), std::move(from), std::move(to), std::move(body), loc) {} - inline inst_kind kind() const override { return inst_kind::collective; } + inline static bool classof(inst_node const &i) { return i.type_id() == IK::foreach_loop; } + foreach_inst(tinytc_data_type_t loop_var_type, array_view from, + array_view to, location const &lc = {}); + + inline auto dim() const -> std::int64_t { return num_operands() / 2; } + inline auto loop_vars() { return body().params(); } + inline auto loop_vars() const { return body().params(); } + inline auto from() { return operands() | std::views::take(dim()); } + inline auto from() const { return operands() | std::views::take(dim()); } + inline auto to() { return operands() | std::views::drop(dim()); } + inline auto to() const { return operands() | std::views::drop(dim()); } }; -class hadamard_inst : public clir::visitable { +class hadamard_inst : public blas_a3_inst { public: - using super = clir::visitable; - hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::hadamard_blas_a3; } + hadamard_inst(tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, + tinytc_value_t C, bool atomic = false, location const &lc = {}); }; -class if_inst : public clir::visitable { +class if_inst : public standard_inst<1, dynamic, 2> { public: - if_inst(value condition, region then, region otherwise = {}, - std::vector const &return_types = {}, location const &lc = {}); - inline auto condition() const -> value const & { return condition_; } - inline auto then() const -> region const & { return then_; } - inline auto otherwise() const -> region const & { return otherwise_; } - inline value result() const override { - return results_.size() > 0 ? results_.front() : value{}; + inline static bool classof(inst_node const &i) { return i.type_id() == IK::if_; } + enum child_region_number { child_region_then = 0, child_region_otherwise = 1 }; + if_inst(tinytc_value_t condition, array_view return_types = {}, + location const &lc = {}); + + inline auto condition() -> tinytc_value & { return op(0); } + inline auto condition() const -> tinytc_value const & { return op(0); } + inline auto then() -> tinytc_region & { return child_region(child_region_then); } + inline auto then() const -> tinytc_region const & { return child_region(child_region_then); } + inline auto otherwise() -> tinytc_region & { return child_region(child_region_otherwise); } + inline auto otherwise() const -> tinytc_region const & { + return child_region(child_region_otherwise); } - inline auto results() const -> std::vector override { return results_; } - inline auto num_results() const -> std::size_t override { return results_.size(); } - inline auto results_ref() -> std::vector & { return results_; } - inline auto results_ref() const -> std::vector const & { return results_; } - inline inst_kind kind() const override { return inst_kind::replicated; } + inline bool is_otherwise_empty() const { return otherwise().insts().empty(); } +}; + +class math_unary_inst : public standard_inst<1, 1> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::math_unary_; } + math_unary_inst(math_unary op, tinytc_value_t a, tinytc_data_type_t ty, + location const &lc = {}); + + inline auto operation() const -> math_unary { return operation_; } + inline auto a() -> tinytc_value & { return op(0); } + inline auto a() const -> tinytc_value const & { return op(0); } private: - value condition_; - region then_, otherwise_; - std::vector results_; + math_unary operation_; +}; + +class parallel_inst : public standard_inst<0, 0, 1> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::parallel; } + parallel_inst(location const &lc = {}); + + inline auto body() -> tinytc_region & { return child_region(0); } + inline auto body() const -> tinytc_region const & { return child_region(0); } }; -class size_inst : public clir::visitable { +class size_inst : public standard_inst<1, 1> { public: - size_inst(value op, std::int64_t mode, location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::size; } + size_inst(tinytc_value_t op, std::int64_t mode, tinytc_data_type_t ty, location const &lc = {}); - inline auto operand() const -> value const & { return op_; } + inline auto operand() -> tinytc_value & { return op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } inline std::int64_t mode() const { return mode_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } private: - value op_, result_; std::int64_t mode_; }; -class subview_inst : public clir::visitable { +class subgroup_broadcast_inst : public standard_inst<2, 1> { public: - subview_inst(value op, std::vector slices, location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_broadcast; } + subgroup_broadcast_inst(tinytc_value_t a, tinytc_value_t idx, tinytc_data_type_t ty, + location const &lc = {}); + + inline auto a() -> tinytc_value & { return op(0); } + inline auto a() const -> tinytc_value const & { return op(0); } + inline auto idx() -> tinytc_value & { return op(1); } + inline auto idx() const -> tinytc_value const & { return op(1); } +}; - inline auto slices() const -> std::vector const & { return slices_; } - inline auto operand() const -> value const & { return op_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } +class subgroup_operation_inst : public standard_inst<1, 1> { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::subgroup_operation; } + subgroup_operation_inst(group_arithmetic arith, group_operation operation, tinytc_value_t a, + tinytc_data_type_t ty, location const &lc = {}); + + inline auto arith() const -> group_arithmetic { return arith_; } + inline auto operation() const -> group_operation { return operation_; } + inline auto a() -> tinytc_value & { return op(0); } + inline auto a() const -> tinytc_value const & { return op(0); } private: - value op_; - std::vector slices_; - value result_; + group_arithmetic arith_; + group_operation operation_; }; -class store_inst : public clir::visitable { +class subview_inst : public standard_inst { public: - store_inst(value val, value op, std::vector index_list, location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::subview; } + subview_inst(tinytc_value_t op, array_view static_offsets, + array_view static_sizes, array_view offsets, + array_view sizes, tinytc_data_type_t ty, location const &lc = {}); + + inline auto static_offsets() const -> array_view { return static_offsets_; } + inline auto static_sizes() const -> array_view { return static_sizes_; } + + inline auto operand() -> tinytc_value & { return op(0); } + inline auto operand() const -> tinytc_value const & { return op(0); } + inline auto offsets() { + return operands() | std::views::drop(1) | std::views::take(num_dyn_offsets_); + } + inline auto offsets() const { + return operands() | std::views::drop(1) | std::views::take(num_dyn_offsets_); + } + inline auto sizes() { return operands() | std::views::drop(1 + num_dyn_offsets_); } + inline auto sizes() const { return operands() | std::views::drop(1 + num_dyn_offsets_); } - inline auto val() const -> value const & { return val_; } - inline auto operand() const -> value const & { return op_; } - inline auto index_list() const -> std::vector const & { return index_list_; } - inline value result() const override { return {}; } - inline inst_kind kind() const override { return inst_kind::replicated; } + private: + std::vector static_offsets_, static_sizes_; + std::int32_t num_dyn_offsets_; +}; + +class store_inst : public standard_inst { + public: + inline static bool classof(inst_node const &i) { return i.type_id() == IK::store; } + enum op_number { op_val = 0, op_operand = 1 }; + store_inst(store_flag flag, tinytc_value_t val, tinytc_value_t op, + array_view index_list, location const &lc = {}); + + inline auto flag() const -> store_flag { return flag_; } + inline auto val() -> tinytc_value & { return op(op_val); } + inline auto val() const -> tinytc_value const & { return op(op_val); } + inline auto operand() -> tinytc_value & { return op(op_operand); } + inline auto operand() const -> tinytc_value const & { return op(op_operand); } + inline auto index_list() { return operands() | std::views::drop(2); } + inline auto index_list() const { return operands() | std::views::drop(2); } private: - value val_, op_; - std::vector index_list_; + store_flag flag_; }; -class sum_inst : public clir::visitable { +class sum_inst : public blas_a2_inst { public: - using super = clir::visitable; - sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, - location const &lc = {}); + inline static bool classof(inst_node const &i) { return i.type_id() == IK::sum_blas_a2; } + sum_inst(transpose tA, tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t beta, + tinytc_value_t B, bool atomic = false, location const &lc = {}); inline transpose tA() const { return tA_; } @@ -416,17 +960,10 @@ class sum_inst : public clir::visitable { transpose tA_; }; -class yield_inst : public clir::visitable { +class yield_inst : public standard_inst { public: - inline yield_inst(std::vector vals, location const &lc = {}) : vals_(std::move(vals)) { - loc(lc); - } - inline value result() const override { return value{}; } - inline auto vals() const -> std::vector const & { return vals_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - std::vector vals_; + inline static bool classof(inst_node const &i) { return i.type_id() == IK::yield; } + yield_inst(array_view vals, location const &lc = {}); }; } // namespace tinytc diff --git a/src/node/program_node.cpp b/src/node/program_node.cpp new file mode 100644 index 00000000..c2c8cc5e --- /dev/null +++ b/src/node/program_node.cpp @@ -0,0 +1,17 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/program_node.hpp" + +#include + +using namespace tinytc; + +extern "C" { + +tinytc_prog::tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc) + : ctx_{std::move(ctx)} { + loc(lc); +} +} + diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp index 372c41d7..76f76b45 100644 --- a/src/node/program_node.hpp +++ b/src/node/program_node.hpp @@ -4,44 +4,44 @@ #ifndef PROGRAM_NODE_20240208_HPP #define PROGRAM_NODE_20240208_HPP -#include "location.hpp" #include "reference_counted.hpp" #include "tinytc/tinytc.hpp" - -#include +#include "tinytc/types.h" +#include "util/iterator.hpp" #include #include -namespace tinytc { -using program_nodes = clir::virtual_type_list; -} - -struct tinytc_prog : tinytc::reference_counted, tinytc::program_nodes { +struct tinytc_prog final : tinytc::reference_counted { public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + using iterator = tinytc::indirect_random_access_iterator::iterator>; + using const_iterator = + tinytc::indirect_random_access_iterator::const_iterator>; - private: - tinytc::location loc_; -}; + tinytc_prog(tinytc::compiler_context ctx, tinytc_location const &lc = {}); -namespace tinytc { + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::compiler_context { return ctx_; } -using program_node = ::tinytc_prog; + inline auto loc() const noexcept -> tinytc_location const & { return loc_; } + inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } -class program : public clir::visitable { - public: - inline program(std::vector decls, location const &lc = {}) : decls_(std::move(decls)) { - loc(lc); - } - inline auto declarations() -> std::vector & { return decls_; } - inline auto declarations() const -> std::vector const & { return decls_; } + inline auto begin() -> iterator { return iterator{funcs_.begin()}; } + inline auto end() -> iterator { return iterator{funcs_.end()}; } + inline auto begin() const -> const_iterator { return const_iterator{funcs_.begin()}; } + inline auto end() const -> const_iterator { return const_iterator{funcs_.end()}; } + inline void push_back(tinytc::func fun) { funcs_.push_back(std::move(fun)); } private: - std::vector decls_; + tinytc::compiler_context ctx_; + std::vector funcs_; + tinytc_location loc_; }; +namespace tinytc { + +using program_node = ::tinytc_prog; + } // namespace tinytc #endif // PROGRAM_NODE_20240208_HPP diff --git a/src/node/region_node.cpp b/src/node/region_node.cpp new file mode 100644 index 00000000..afc3ae33 --- /dev/null +++ b/src/node/region_node.cpp @@ -0,0 +1,54 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/region_node.hpp" +#include "node/inst_node.hpp" +#include "tinytc/tinytc.h" + +namespace tinytc { + +auto ilist_callbacks::get_parent_region() -> tinytc_region * { + return reinterpret_cast(reinterpret_cast(this) - + tinytc_region::inst_list_offset()); +} + +void ilist_callbacks::node_added(tinytc_inst_t node) { + node->parent(get_parent_region()); +} +void ilist_callbacks::node_moved(tinytc_inst_t node) { + node->parent(get_parent_region()); +} +void ilist_callbacks::node_removed(tinytc_inst_t node) { tinytc_inst_destroy(node); } + +} // namespace tinytc + +using namespace tinytc; + +tinytc_region::tinytc_region() : def_inst_{nullptr}, kind_{tinytc::region_kind::mixed} {} + +tinytc_region::~tinytc_region() {} + +void tinytc_region::loc(tinytc::location const &loc) { + loc_ = loc; + for (auto ¶m : params_) { + param.loc(loc_); + } +} +void tinytc_region::defining_inst(tinytc_inst_t def_inst) { + def_inst_ = def_inst; + for (auto ¶m : params_) { + param.defining_inst(def_inst_); + } +} + +void tinytc_region::set_params(array_view param_types) { + params_.resize(param_types.size()); + for (std::size_t i = 0; i < param_types.size(); ++i) { + set_param(i, param_types[i]); + } +} + +void tinytc_region::set_num_params(std::size_t num_params) { params_.resize(num_params); } +void tinytc_region::set_param(std::size_t idx, tinytc_data_type_t param_type) { + params_[idx] = tinytc_value{param_type, def_inst_, loc_}; +} diff --git a/src/node/region_node.hpp b/src/node/region_node.hpp index bcf1d983..74edb049 100644 --- a/src/node/region_node.hpp +++ b/src/node/region_node.hpp @@ -4,45 +4,97 @@ #ifndef REGION_NODE_20230908_HPP #define REGION_NODE_20230908_HPP -#include "reference_counted.hpp" +#include "node/value_node.hpp" #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/iterator.hpp" -#include - -#include +#include +#include +#include #include namespace tinytc { -using region_nodes = clir::virtual_type_list; -} -struct tinytc_region : tinytc::reference_counted, tinytc::region_nodes { +//! Instruction classification +enum class region_kind { mixed = 0x0, collective = 0x1, spmd = 0x2 }; + +template <> struct ilist_callbacks { + auto get_parent_region() -> tinytc_region *; + void node_added(tinytc_inst_t node); + void node_moved(tinytc_inst_t node); + void node_removed(tinytc_inst_t node); +}; + +} // namespace tinytc + +struct tinytc_region final { public: + using iterator = tinytc::ilist::iterator; + using const_iterator = tinytc::ilist::const_iterator; + + tinytc_region(); + ~tinytc_region(); + + tinytc_region(tinytc_region const &) = delete; + tinytc_region(tinytc_region &&) = delete; + tinytc_region &operator=(tinytc_region const &) = delete; + tinytc_region &operator=(tinytc_region &&) = delete; + + inline auto kind() const noexcept -> tinytc::region_kind { return kind_; } + inline void kind(tinytc::region_kind kind) noexcept { kind_ = kind; } + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + void loc(tinytc::location const &loc); + + // Can be nullptr, e.g. if the region is the body of a function + inline auto defining_inst() const -> tinytc_inst_t { return def_inst_; } + void defining_inst(tinytc_inst_t def_inst); + + inline auto begin() -> iterator { return insts_.begin(); } + inline auto end() -> iterator { return insts_.end(); } + inline auto insts() -> tinytc::ilist & { return insts_; } + inline auto begin() const -> const_iterator { return insts_.cbegin(); } + inline auto end() const -> const_iterator { return insts_.cend(); } + inline auto insts() const -> tinytc::ilist const & { return insts_; } + inline auto empty() const -> bool { return insts_.empty(); } + + inline auto param_begin() { return params_.begin(); } + inline auto param_end() { return params_.end(); } + inline auto params() { return tinytc::iterator_range_wrapper{param_begin(), param_end()}; } + inline auto param_begin() const { return params_.begin(); } + inline auto param_end() const { return params_.end(); } + inline auto param(std::int64_t pos) -> tinytc_value & { return params_[pos]; } + inline auto param(std::int64_t pos) const -> tinytc_value const & { return params_[pos]; } + inline auto params() const { + return tinytc::iterator_range_wrapper{param_begin(), param_end()}; + } + inline auto num_params() const noexcept -> std::int64_t { return params_.size(); } + void set_params(tinytc::array_view param_types); + void set_num_params(std::size_t num_params); + void set_param(std::size_t idx, tinytc_data_type_t param_type); private: + static auto inst_list_offset() -> std::size_t { + static_assert(std::is_standard_layout_v, "offsetof not guaranteed to work"); + return offsetof(tinytc_region, insts_); + } + friend struct tinytc::ilist_callbacks; + + tinytc_inst_t def_inst_; + tinytc::region_kind kind_; tinytc::location loc_; + // params_ must come before insts_ such that the dtors are called in the correct order + std::vector params_; + tinytc::ilist insts_; }; namespace tinytc { using region_node = ::tinytc_region; -class rgn : public clir::visitable { - public: - inline rgn(std::vector insts = {}, location const &lc = {}) : insts_(std::move(insts)) { - loc(lc); - } - - inline auto insts() -> std::vector & { return insts_; } - inline auto insts() const -> std::vector const & { return insts_; } - inline void insts(std::vector insts) { insts_ = std::move(insts); } - - private: - std::vector insts_; -}; - } // namespace tinytc #endif // REGION_NODE_20230908_HPP diff --git a/src/node/value_node.cpp b/src/node/value_node.cpp new file mode 100644 index 00000000..35941bad --- /dev/null +++ b/src/node/value_node.cpp @@ -0,0 +1,102 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/value_node.hpp" + +#include + +using namespace tinytc; + +tinytc_value::tinytc_value(tinytc_data_type_t ty, tinytc_inst_t def_inst, location const &lc) + : ty_{std::move(ty)}, loc_{lc}, def_inst_{def_inst} {} + +tinytc_value::~tinytc_value() { + assert(!has_uses() && "Destructor called for value that still has uses"); +} + +auto tinytc_value::use_begin() -> use_iterator { return {first_use_}; } +auto tinytc_value::use_end() -> use_iterator { return {nullptr}; } +auto tinytc_value::uses() -> iterator_range_wrapper { + return {use_begin(), use_end()}; +} +auto tinytc_value::use_begin() const -> const_use_iterator { return {first_use_}; } +auto tinytc_value::use_end() const -> const_use_iterator { return {nullptr}; } +auto tinytc_value::uses() const -> iterator_range_wrapper { + return {use_begin(), use_end()}; +} +auto tinytc_value::has_uses() const -> bool { return first_use_ != nullptr; } + +namespace tinytc { + +use::use(tinytc_inst_t owner) : owner_{owner} {} + +use::~use() { + if (value_) { + remove_use_from_current_list(); + } +} + +use &use::operator=(value_node *val) { + set(val); + return *this; +} + +void use::set(value_node *value) { + if (value_) { + remove_use_from_current_list(); + } + value_ = value; + if (value_) { + add_use_to_list(&value_->first_use_); + } +} + +/* + * Let next = &A.n and we have + * + * ...A|.p|.n-->B|.p|.n-->C|.p|.n... + * ...----| ^-------| ^-------| + * + * After inserting T (T = this) we want the following new or adjusted pointers + * + * ...A|.p|.n==>T|.p|.n==>B|.p|.n-->C|.p|.n... + * ...---| ^======| ^======| ^------| + * + * We need to set + * next_ = T.n -> B = *next + * next_->prev_ = B.p -> &T.n = &next_ + * prev_ = T.p -> &A.n = next + * *next = A.n -> T = this + */ +void use::add_use_to_list(use **next) { + next_ = *next; + if (next_) { + next_->prev_ = &next_; + } + prev_ = next; + *next = this; +} + +/* + * We want to remove T (T = this): + * + * ...A|.p|.n-->T|.p|.n-->B|.p|.n-->C|.p|.n... + * ...---| ^------| ^------| ^------| + * + * After removing T we want the following adjusted pointers + * + * ...A|.p|.n==>B|.p|.n-->C|.p|.n... + * ...---| ^======| ^------| + * + * We need to set + * next_->prev_ = B.p -> &A.n = prev_ + * *prev_ = A.n -> B = next_ + */ +void use::remove_use_from_current_list() { + if (next_) { + next_->prev_ = prev_; + } + *prev_ = next_; +} + +} // namespace tinytc diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp index 461bbf8d..f08ce828 100644 --- a/src/node/value_node.hpp +++ b/src/node/value_node.hpp @@ -4,88 +4,139 @@ #ifndef VALUE_NODE_20230309_HPP #define VALUE_NODE_20230309_HPP -#include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" +#include "node/data_type_node.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/iterator.hpp" -#include - -#include +#include +#include #include +#include #include namespace tinytc { -using value_nodes = clir::virtual_type_list; -} +class use; +class use_iterator; +class const_use_iterator; +}; // namespace tinytc -struct tinytc_value : tinytc::reference_counted, tinytc::value_nodes { +struct tinytc_value final { public: + tinytc_value(tinytc_data_type_t ty = nullptr, tinytc_inst_t def_inst_ = nullptr, + tinytc::location const &lc = {}); + ~tinytc_value(); + + tinytc_value(tinytc_value const &) = delete; + tinytc_value(tinytc_value &&) = default; + tinytc_value &operator=(tinytc_value const &) = delete; + tinytc_value &operator=(tinytc_value &&) = default; + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - virtual tinytc::data_type ty() const = 0; - virtual void ty(tinytc::data_type ty) = 0; - virtual auto name() const -> char const * = 0; - virtual void name(std::string name) = 0; - virtual auto has_name() const -> bool = 0; + inline auto ty() const -> tinytc_data_type_t { return ty_; } + + inline auto context() const -> tinytc_compiler_context_t { return ty_->context(); } + + inline auto name() const -> char const * { return name_.c_str(); } + inline void name(std::string name) { name_ = std::move(name); } + auto has_name() const -> bool { return !name_.empty(); } + + auto use_begin() -> tinytc::use_iterator; + auto use_end() -> tinytc::use_iterator; + auto uses() -> tinytc::iterator_range_wrapper; + auto use_begin() const -> tinytc::const_use_iterator; + auto use_end() const -> tinytc::const_use_iterator; + auto uses() const -> tinytc::iterator_range_wrapper; + auto has_uses() const -> bool; + + // Can be nullptr, e.g. if value is a region parameter + inline auto defining_inst() const -> tinytc_inst_t { return def_inst_; } + inline void defining_inst(tinytc_inst_t def_inst) { def_inst_ = def_inst; } private: + tinytc_data_type_t ty_; tinytc::location loc_; + tinytc_inst_t def_inst_ = nullptr; + std::string name_; + + friend class tinytc::use; + tinytc::use *first_use_ = nullptr; }; namespace tinytc { using value_node = ::tinytc_value; -class float_imm : public clir::visitable { +class use final { public: - inline float_imm(double v, scalar_type ty = scalar_type::f64) - : ty_{make_scalar(ty)}, value_(v) {} + use() = default; + use(tinytc_inst_t owner); + ~use(); - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } - inline auto name() const -> char const * override { return ""; } - inline void name(std::string) override {} - auto has_name() const -> bool override { return false; } + use(use &&other) = delete; + use(use const &other) = delete; + use &operator=(use &&other) = delete; + use &operator=(use const &other) = delete; - inline double value() const { return value_; } + use &operator=(value_node *val); - private: - data_type ty_; - double value_; -}; + inline auto get() -> value_node * { return value_; } + inline auto get() const -> value_node const * { return value_; } + void set(value_node *value); -class int_imm : public clir::visitable { - public: - inline int_imm(std::int64_t v, scalar_type ty = scalar_type::i64) - : ty_{make_scalar(ty)}, value_(v) {} + inline auto owner() const -> tinytc_inst_t { return owner_; } + inline void owner(tinytc_inst_t owner) { owner_ = owner; } - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } - inline auto name() const -> char const * override { return ""; } - inline void name(std::string) override {} - auto has_name() const -> bool override { return false; } - - inline std::int64_t value() const { return value_; } + inline auto next() -> use * { return next_; } + inline auto next() const -> use const * { return next_; } private: - data_type ty_; - std::int64_t value_; + void add_use_to_list(use **next); + void remove_use_from_current_list(); + + tinytc_inst_t owner_ = nullptr; + value_node *value_ = nullptr; + use **prev_ = nullptr; + use *next_ = nullptr; }; -class val : public clir::visitable { +namespace detail { +template class use_iterator_base { public: - inline val(data_type ty) : ty_(std::move(ty)) {} - - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } - inline auto name() const -> char const * override { return name_.c_str(); } - inline void name(std::string name) override { name_ = std::move(name); } - virtual auto has_name() const -> bool override { return !name_.empty(); } + using value_type = std::conditional_t; + using pointer = value_type *; + using reference = value_type &; + using difference_type = std::ptrdiff_t; + + use_iterator_base() : pos_{nullptr} {} + use_iterator_base(pointer pos) : pos_{std::move(pos)} {} + + auto operator*() const -> reference { return *pos_; } + auto operator->() const -> pointer { return pos_; } + auto operator++() -> Derived & { + pos_ = pos_->next(); + return *static_cast(this); + } + auto operator++(int) -> Derived { + auto old_pos = pos_; + pos_ = pos_->next(); + return Derived{old_pos}; + } + auto operator==(use_iterator_base const &other) const -> bool { return pos_ == other.pos_; } + auto operator!=(use_iterator_base const &other) const -> bool { return pos_ != other.pos_; } private: - data_type ty_; - std::string name_; + pointer pos_; }; +} // namespace detail + +class use_iterator : public detail::use_iterator_base {}; +class const_use_iterator : public detail::use_iterator_base {}; + +static_assert(std::forward_iterator); +static_assert(std::forward_iterator); } // namespace tinytc diff --git a/src/parser.cpp b/src/parser.cpp index 9abe3fb7..3ebe6c7e 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -3,8 +3,8 @@ #include "parser.hpp" +#include "compiler_context.hpp" #include "error.hpp" -#include "location.hpp" #include "parser/lexer.hpp" #include "parser/parse_context.hpp" #include "parser/parser_impl.hpp" @@ -17,69 +17,31 @@ #include #include #include -#include -#include +#include #include namespace tinytc { -auto parse(std::uint64_t size, char const *input) -> prog { - auto const initial_loc = location{position{0, 1, 1}, position{0, 1, 1}}; - auto lex = lexer(size, input, initial_loc); - auto ctx = parse_context{}; - auto p = parser(lex, ctx); - if (p() == 0) { - return ctx.program(); - } - return prog{}; -} -} // namespace tinytc -using namespace tinytc; - -extern "C" { - -tinytc_source_context::tinytc_source_context() {} - -auto tinytc_source_context::parse(std::string name, std::string text) -> prog { - sources_.emplace_back(source_input{std::move(name), std::move(text)}); - std::int32_t source_id = static_cast(sources_.size()); +auto parse(std::string name, std::string text, compiler_context const &compiler_ctx) -> prog { + std::int32_t source_id = compiler_ctx->add_source(std::move(name), std::move(text)); auto const initial_loc = location{position{source_id, 1, 1}, position{source_id, 1, 1}}; - auto const &input = sources_.back(); - auto lex = lexer(input.text.size(), input.text.c_str(), initial_loc); - auto ctx = parse_context{}; - auto p = parser(lex, ctx); + auto [ir, ir_size] = compiler_ctx->source_text(source_id); + auto lex = lexer(ir_size, ir, initial_loc); + auto parse_ctx = parse_context{compiler_ctx}; + auto p = parser(lex, parse_ctx); if (p() == 0) { - return ctx.program(); - } - last_error_log_.clear(); - for (auto const &err : ctx.errors()) { - last_error_log_ = report_error_with_context(input.text.c_str(), input.text.size(), - input.name, err.first, err.second); + return parse_ctx.program(); } return prog{}; } -void tinytc_source_context::report_error(location const &l, char const *what, bool append) { - auto err = std::string{}; - if (l.begin.source_id >= 1 && static_cast(l.begin.source_id) <= sources_.size()) { - auto const &src = sources_[l.begin.source_id - 1]; - err = report_error_with_context(src.text.c_str(), src.text.size(), src.name, l, what); - } else { - err = (std::ostringstream{} << "\n" - << l << ": " << what) - .str(); - } - if (append) { - last_error_log_ += std::move(err); - } else { - last_error_log_ = std::move(err); - } -} +} // namespace tinytc + +using namespace tinytc; tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, - tinytc_source_context_t source_ctx) { + tinytc_compiler_context_t ctx) { if (prg == nullptr || filename == nullptr) { return tinytc_status_invalid_arguments; } @@ -89,9 +51,8 @@ tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, throw status::file_io_error; } auto ir = std::string(std::istreambuf_iterator{ir_stream}, {}); - - auto prog = source_ctx ? source_ctx->parse(std::string(filename), std::move(ir)) - : parse(ir.size(), ir.c_str()); + auto ctx_ = ctx ? compiler_context{ctx, true} : make_compiler_context(); + auto prog = parse(std::string(filename), std::move(ir), ctx_); if (!prog) { throw status::parse_error; } @@ -99,14 +60,14 @@ tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, }); } -tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t source_ctx) { +tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_compiler_context_t ctx) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { auto ir = std::string(std::istreambuf_iterator{std::cin}, {}); - auto prog = - source_ctx ? source_ctx->parse("", std::move(ir)) : parse(ir.size(), ir.c_str()); + auto ctx_ = ctx ? compiler_context{ctx, true} : make_compiler_context(); + auto prog = parse("", std::move(ir), ctx_); if (!prog) { throw status::parse_error; } @@ -115,71 +76,16 @@ tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t s } tinytc_status_t tinytc_parse_string(tinytc_prog_t *prg, size_t source_size, char const *source, - tinytc_source_context_t source_ctx) { + tinytc_compiler_context_t ctx) { if (prg == nullptr || source_size == 0 || source == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto prog = source_ctx - ? source_ctx->parse("", std::string(source, source + source_size)) - : parse(source_size, source); + auto ctx_ = ctx ? compiler_context{ctx, true} : make_compiler_context(); + auto prog = parse("", std::string(source, source + source_size), ctx_); if (!prog) { throw status::parse_error; } *prg = prog.release(); }); } - -tinytc_status_t tinytc_source_context_create(tinytc_source_context_t *ctx) { - if (ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *ctx = std::make_unique().release(); }); -} - -tinytc_status_t tinytc_source_context_add_source(tinytc_source_context_t ctx, char const *name, - char const *text, int32_t *source_id) { - if (ctx == nullptr || name == nullptr || text == nullptr || source_id == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *source_id = ctx->add_source(name, text); }); -} - -tinytc_status_t tinytc_source_context_get_error_log(const_tinytc_source_context_t ctx, - char const **log) { - if (ctx == nullptr || log == nullptr) { - return tinytc_status_invalid_arguments; - } - *log = ctx->last_error_log().c_str(); // last_error_log and c_str are noexcept - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_context_report_error(tinytc_source_context_t ctx, - const tinytc_location_t *location, - char const *what, tinytc_bool_t append) { - if (ctx == nullptr || location == nullptr || what == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { ctx->report_error(*location, what, bool(append)); }); -} - -tinytc_status_t tinytc_source_context_release(tinytc_source_context_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_context_retain(tinytc_source_context_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/parser.hpp b/src/parser.hpp index 0d1cd1a9..f613edf8 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -4,48 +4,12 @@ #ifndef PARSER_20230614_HPP #define PARSER_20230614_HPP -#include "reference_counted.hpp" #include "tinytc/tinytc.hpp" -#include "tinytc/types.h" #include -#include -#include namespace tinytc { auto parse(std::uint64_t size, char const *input) -> prog; } -/** - * @brief Source manager - * - * The source manager can parse tensor programs from files, stdin, or memory. - * Source code is stored in the manager such that error messages can be enhanced - * with code context. - */ -struct tinytc_source_context : tinytc::reference_counted { - public: - //! @brief ctor - tinytc_source_context(); - - auto parse(std::string name, std::string text) -> tinytc::prog; - - inline auto add_source(char const *name, char const *text) -> std::int32_t { - sources_.emplace_back(source_input{std::string(name), std::string(text)}); - return static_cast(sources_.size()); - } - - //! Annotate context to error message - void report_error(tinytc_location const &l, char const *what, bool append = false); - //! Return error log of last parse call - inline auto last_error_log() const noexcept -> std::string const & { return last_error_log_; } - - private: - struct source_input { - std::string name, text; - }; - std::vector sources_; - std::string last_error_log_; -}; - #endif // PARSER_20230614_HPP diff --git a/src/parser/lexer.hpp b/src/parser/lexer.hpp index b49ad4d4..ad30db76 100644 --- a/src/parser/lexer.hpp +++ b/src/parser/lexer.hpp @@ -26,8 +26,6 @@ class lexer { std::uint64_t lex_number(char const *s, char const *e); std::int64_t lex_integer_constant(char const *s, char const *e); double lex_floating_constant(char const *s, char const *e); - scalar_type lex_integer_type(char const *s, char const *e); - scalar_type lex_floating_type(char const *s, char const *e); char const *input_; std::size_t len_; diff --git a/src/parser/lexer.re b/src/parser/lexer.re index e76b2b87..9d80d69c 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -25,6 +25,7 @@ lex: char const *b = YYCURSOR; step(loc_); auto const adv_loc = [&]() { columns(loc_, YYCURSOR - b); }; + // clang-format off /*!re2c re2c:define:YYCTYPE = char; re2c:yyfill:enable = 0; @@ -35,13 +36,12 @@ lex: newline = "\r"? "\n"; whitespace = [ \t\v\r]+; - identifier = [0-9]+ | ([a-zA-Z] [a-zA-Z0-9_]*); - local_identifier = "%" identifier; - global_identifier = "@" identifier; - - integer_type = "i" ("1" | "8" | "16" | "32" | "64") | "index"; - floating_type = "f" ("32" | "64"); - + unnamed_identifier = [0-9]+; + named_identifier = [a-zA-Z] [a-zA-Z0-9_]*; + local_unnamed_identifier = "%" unnamed_identifier; + local_named_identifier = "%" named_identifier; + global_identifier = "@" (unnamed_identifier | named_identifier); + string = "\"" [^\"]* "\""; digit = [0-9]; hexdigit = [0-9a-fA-F]; integer_constant = [-+]? digit+; @@ -49,12 +49,17 @@ lex: mantissa_hex = (hexdigit* "." hexdigit+ | hexdigit+ "."); exponent = [eE] [-+]? digit+; exponent_hex = [pP] [-+]? digit+; - floating_constant = [-+]? (mantissa exponent? | digit+ exponent); + floating_constant = [-+]? (mantissa exponent? | digit+ exponent | "inf" | "nan"); floating_constant_hex = [-+]? "0x" (mantissa_hex exponent_hex? | hexdigit+ exponent_hex); // identifier - local_identifier { + local_unnamed_identifier { + adv_loc(); + std::int64_t id = lex_integer_constant(b + 1, YYCURSOR); + return parser::make_LOCAL_IDENTIFIER(std::move(id), loc_); + } + local_named_identifier { adv_loc(); auto id = std::string(b + 1, YYCURSOR); return parser::make_LOCAL_IDENTIFIER(std::move(id), loc_); @@ -81,17 +86,23 @@ lex: // keywords "func" { adv_loc(); return parser::make_FUNC(loc_); } - "work_group_size" { adv_loc(); return parser::make_WORK_GROUP_SIZE(loc_); } - "subgroup_size" { adv_loc(); return parser::make_SUBGROUP_SIZE(loc_); } - "->" { adv_loc(); return parser::make_RETURNS(loc_); } + "->" { adv_loc(); return parser::make_ARROW(loc_); } "?" { adv_loc(); return parser::make_DYNAMIC(loc_); } ".n" { adv_loc(); return parser::make_NOTRANS(loc_); } ".t" { adv_loc(); return parser::make_TRANS(loc_); } ".atomic" { adv_loc(); return parser::make_ATOMIC(loc_); } + ".atomic_add" { adv_loc(); return parser::make_ATOMIC_ADD(loc_); } + ".atomic_max" { adv_loc(); return parser::make_ATOMIC_MAX(loc_); } + ".atomic_min" { adv_loc(); return parser::make_ATOMIC_MIN(loc_); } + "init" { adv_loc(); return parser::make_INIT(loc_); } + "local" { adv_loc(); return parser::make_LOCAL(loc_); } + "global" { adv_loc(); return parser::make_GLOBAL(loc_); } + ".local" { adv_loc(); return parser::make_LOCAL_ATTR(loc_); } + ".global" { adv_loc(); return parser::make_GLOBAL_ATTR(loc_); } // constants - "true" { adv_loc(); return parser::make_INTEGER_CONSTANT(1, loc_); } - "false" { adv_loc(); return parser::make_INTEGER_CONSTANT(0, loc_); } + "true" { adv_loc(); return parser::make_BOOLEAN_CONSTANT(true, loc_); } + "false" { adv_loc(); return parser::make_BOOLEAN_CONSTANT(false, loc_); } integer_constant { adv_loc(); auto i = lex_integer_constant(b, YYCURSOR); @@ -103,17 +114,26 @@ lex: return parser::make_FLOATING_CONSTANT(f, loc_); } - // types - integer_type { - adv_loc(); - auto t = lex_integer_type(b, YYCURSOR); - return parser::make_INTEGER_TYPE(t, loc_); - } - floating_type { - adv_loc(); - auto t = lex_floating_type(b, YYCURSOR); - return parser::make_FLOATING_TYPE(t, loc_); + // attributes + "attributes" { return parser::make_ATTRIBUTES(loc_); } + "alignment" | "shape_gcd" | "stride_gcd" | "subgroup_size" | "unroll" | "work_group_size" { + adv_loc(); return parser::make_ATTR_NAME(std::string(b, YYCURSOR), loc_); } + + // types + "bool" { return parser::make_BOOLEAN(loc_); } + "i8" { adv_loc(); return parser::make_INTEGER_TYPE(scalar_type::i8, loc_); } + "i16" { adv_loc(); return parser::make_INTEGER_TYPE(scalar_type::i16, loc_); } + "i32" { adv_loc(); return parser::make_INTEGER_TYPE(scalar_type::i32, loc_); } + "i64" { adv_loc(); return parser::make_INTEGER_TYPE(scalar_type::i64, loc_); } + "index" { adv_loc(); return parser::make_INTEGER_TYPE(scalar_type::index, loc_); } + "bf16" { adv_loc(); return parser::make_FLOATING_TYPE(scalar_type::bf16, loc_); } + "f16" { adv_loc(); return parser::make_FLOATING_TYPE(scalar_type::f16, loc_); } + "f32" { adv_loc(); return parser::make_FLOATING_TYPE(scalar_type::f32, loc_); } + "f64" { adv_loc(); return parser::make_FLOATING_TYPE(scalar_type::f64, loc_); } + "c32" { adv_loc(); return parser::make_FLOATING_TYPE(scalar_type::c32, loc_); } + "c64" { adv_loc(); return parser::make_FLOATING_TYPE(scalar_type::c64, loc_); } + "coopmatrix" { adv_loc(); return parser::make_COOPMATRIX(loc_); } "memref" { adv_loc(); return parser::make_MEMREF(loc_); } "group" { adv_loc(); return parser::make_GROUP(loc_); } @@ -121,60 +141,127 @@ lex: "offset" { adv_loc(); return parser::make_OFFSET(loc_); } "strided" { adv_loc(); return parser::make_STRIDED(loc_); } + // matrix use + "matrix_a" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::a, loc_); } + "matrix_b" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::b, loc_); } + "matrix_acc" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::acc, loc_); } + + // checked flag + ".rows_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::rows, loc_); } + ".cols_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::cols, loc_); } + ".both_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::both, loc_); } + // instructions - "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } - "arith" { adv_loc(); return parser::make_ARITH(loc_); } - "gemm" { adv_loc(); return parser::make_GEMM(loc_); } - "gemv" { adv_loc(); return parser::make_GEMV(loc_); } - "ger" { adv_loc(); return parser::make_GER(loc_); } - "hadamard" { adv_loc(); return parser::make_HADAMARD(loc_); } - "alloca" { adv_loc(); return parser::make_ALLOCA(loc_); } - "cast" { adv_loc(); return parser::make_CAST(loc_); } - "cmp" { adv_loc(); return parser::make_CMP(loc_); } - "expand" { adv_loc(); return parser::make_EXPAND(loc_); } - "fuse" { adv_loc(); return parser::make_FUSE(loc_); } - "load" { adv_loc(); return parser::make_LOAD(loc_); } - "group_id" { adv_loc(); return parser::make_GROUP_ID(loc_); } - "group_size" { adv_loc(); return parser::make_GROUP_SIZE(loc_); } - "for" { adv_loc(); return parser::make_FOR(loc_); } - "foreach" { adv_loc(); return parser::make_FOREACH(loc_); } - "if" { adv_loc(); return parser::make_IF(loc_); } - "else" { adv_loc(); return parser::make_ELSE(loc_); } - "size" { adv_loc(); return parser::make_SIZE(loc_); } - "subview" { adv_loc(); return parser::make_SUBVIEW(loc_); } - "store" { adv_loc(); return parser::make_STORE(loc_); } - "sum" { adv_loc(); return parser::make_SUM(loc_); } - "yield" { adv_loc(); return parser::make_YIELD(loc_); } + "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } + "barrier" { adv_loc(); return parser::make_BARRIER(loc_); } + "cumsum" { adv_loc(); return parser::make_CUMSUM(loc_); } + "gemm" { adv_loc(); return parser::make_GEMM(loc_); } + "gemv" { adv_loc(); return parser::make_GEMV(loc_); } + "ger" { adv_loc(); return parser::make_GER(loc_); } + "hadamard" { adv_loc(); return parser::make_HADAMARD(loc_); } + "alloca" { adv_loc(); return parser::make_ALLOCA(loc_); } + "cast" { adv_loc(); return parser::make_CAST(loc_); } + "constant" { adv_loc(); return parser::make_CONSTANT(loc_); } + "cooperative_matrix_apply" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_APPLY(loc_); } + "cooperative_matrix_extract" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_EXTRACT(loc_); } + "cooperative_matrix_insert" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_INSERT(loc_); } + "cooperative_matrix_load" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_LOAD(loc_); } + "cooperative_matrix_mul_add" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_MUL_ADD(loc_); } + "cooperative_matrix_prefetch" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_PREFETCH(loc_); } + "cooperative_matrix_scale" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_SCALE(loc_); } + "cooperative_matrix_store" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_STORE(loc_); } + "expand" { adv_loc(); return parser::make_EXPAND(loc_); } + "fuse" { adv_loc(); return parser::make_FUSE(loc_); } + "load" { adv_loc(); return parser::make_LOAD(loc_); } + "for" { adv_loc(); return parser::make_FOR(loc_); } + "foreach" { adv_loc(); return parser::make_FOREACH(loc_); } + "if" { adv_loc(); return parser::make_IF(loc_); } + "parallel" { adv_loc(); return parser::make_PARALLEL(loc_); } + "else" { adv_loc(); return parser::make_ELSE(loc_); } + "size" { adv_loc(); return parser::make_SIZE(loc_); } + "subgroup_broadcast" { adv_loc(); return parser::make_SUBGROUP_BROADCAST(loc_); } + "subview" { adv_loc(); return parser::make_SUBVIEW(loc_); } + "store" { adv_loc(); return parser::make_STORE(loc_); } + "sum" { adv_loc(); return parser::make_SUM(loc_); } + "yield" { adv_loc(); return parser::make_YIELD(loc_); } // binary op - ".add" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::add, loc_); } - ".sub" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::sub, loc_); } - ".mul" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::mul, loc_); } - ".div" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::div, loc_); } - ".rem" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::rem, loc_); } - ".shl" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::shl, loc_); } - ".shr" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::shr, loc_); } - ".and" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::and_, loc_); } - ".or" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::or_, loc_); } - ".xor" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::xor_, loc_); } + "arith.add" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::add, loc_); } + "arith.sub" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::sub, loc_); } + "arith.mul" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::mul, loc_); } + "arith.div" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::div, loc_); } + "arith.rem" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::rem, loc_); } + "arith.shl" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::shl, loc_); } + "arith.shr" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::shr, loc_); } + "arith.and" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::and_, loc_); } + "arith.or" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::or_, loc_); } + "arith.xor" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::xor_, loc_); } + "arith.min" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::min, loc_); } + "arith.max" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::max, loc_); } // unary op - ".neg" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::neg, loc_); } - ".not" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::not_, loc_); } + "arith.abs" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::abs, loc_); } + "arith.neg" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::neg, loc_); } + "arith.not" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::not_, loc_); } + "arith.conj" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::conj, loc_); } + "arith.im" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::im, loc_); } + "arith.re" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::re, loc_); } + + // builtin + "builtin.group_id.x" { adv_loc(); return parser::make_BUILTIN(builtin::group_id_x, loc_); } + "builtin.group_id.y" { adv_loc(); return parser::make_BUILTIN(builtin::group_id_y, loc_); } + "builtin.group_id.z" { adv_loc(); return parser::make_BUILTIN(builtin::group_id_z, loc_); } + "builtin.num_groups.x" { adv_loc(); return parser::make_BUILTIN(builtin::num_groups_x, loc_); } + "builtin.num_groups.y" { adv_loc(); return parser::make_BUILTIN(builtin::num_groups_y, loc_); } + "builtin.num_groups.z" { adv_loc(); return parser::make_BUILTIN(builtin::num_groups_z, loc_); } + "builtin.num_subgroups.x" { adv_loc(); return parser::make_BUILTIN(builtin::num_subgroups_x, loc_); } + "builtin.num_subgroups.y" { adv_loc(); return parser::make_BUILTIN(builtin::num_subgroups_y, loc_); } + "builtin.subgroup_size" { adv_loc(); return parser::make_BUILTIN(builtin::subgroup_size, loc_); } + "builtin.subgroup_id.x" { adv_loc(); return parser::make_BUILTIN(builtin::subgroup_id_x, loc_); } + "builtin.subgroup_id.y" { adv_loc(); return parser::make_BUILTIN(builtin::subgroup_id_y, loc_); } + "builtin.subgroup_linear_id" { adv_loc(); return parser::make_BUILTIN(builtin::subgroup_linear_id, loc_); } + "builtin.subgroup_local_id" { adv_loc(); return parser::make_BUILTIN(builtin::subgroup_local_id, loc_); } // comparison condition - ".eq" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::eq, - loc_); } - ".ne" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ne, - loc_); } - ".gt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::gt, - loc_); } - ".ge" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ge, - loc_); } - ".lt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::lt, - loc_); } - ".le" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::le, - loc_); } + "cmp.eq" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::eq, loc_); } + "cmp.ne" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ne, loc_); } + "cmp.gt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::gt, loc_); } + "cmp.ge" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ge, loc_); } + "cmp.lt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::lt, loc_); } + "cmp.le" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::le, loc_); } + + // math op + "math.cos" { adv_loc(); return parser::make_MATH_UNARY(math_unary::cos, loc_); } + "math.sin" { adv_loc(); return parser::make_MATH_UNARY(math_unary::sin, loc_); } + "math.exp" { adv_loc(); return parser::make_MATH_UNARY(math_unary::exp, loc_); } + "math.exp2" { adv_loc(); return parser::make_MATH_UNARY(math_unary::exp2, loc_); } + "math.native_cos" { adv_loc(); return parser::make_MATH_UNARY(math_unary::native_cos, loc_); } + "math.native_sin" { adv_loc(); return parser::make_MATH_UNARY(math_unary::native_sin, loc_); } + "math.native_exp" { adv_loc(); return parser::make_MATH_UNARY(math_unary::native_exp, loc_); } + "math.native_exp2" { adv_loc(); return parser::make_MATH_UNARY(math_unary::native_exp2, loc_); } + + // coopmatrix reduce + "cooperative_matrix_reduce.add.row" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE({group_arithmetic::add, reduce_mode::row}, loc_); } + "cooperative_matrix_reduce.add.column" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE({group_arithmetic::add, reduce_mode::column}, loc_); } + "cooperative_matrix_reduce.max.row" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE({group_arithmetic::max, reduce_mode::row}, loc_); } + "cooperative_matrix_reduce.max.column" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE({group_arithmetic::max, reduce_mode::column}, loc_); } + "cooperative_matrix_reduce.min.row" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE({group_arithmetic::min, reduce_mode::row}, loc_); } + "cooperative_matrix_reduce.min.column" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE({group_arithmetic::min, reduce_mode::column}, loc_); } + + // subgroup op + "subgroup_operation.add.exclusive_scan" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::add, group_operation::exclusive_scan}, loc_); } + "subgroup_operation.add.inclusive_scan" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::add, group_operation::inclusive_scan}, loc_); } + "subgroup_operation.add.reduce" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::add, group_operation::reduce}, loc_); } + "subgroup_operation.max.exclusive_scan" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::max, group_operation::exclusive_scan}, loc_); } + "subgroup_operation.max.inclusive_scan" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::max, group_operation::inclusive_scan}, loc_); } + "subgroup_operation.max.reduce" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::max, group_operation::reduce}, loc_); } + "subgroup_operation.min.exclusive_scan" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::min, group_operation::exclusive_scan}, loc_); } + "subgroup_operation.min.inclusive_scan" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::min, group_operation::inclusive_scan}, loc_); } + "subgroup_operation.min.reduce" { adv_loc(); return parser::make_SUBGROUP_OPERATION({group_arithmetic::min, group_operation::reduce}, loc_); } + + // other strings + string { adv_loc(); return parser::make_STRING(std::string(b+1, YYCURSOR-1), loc_); } + whitespace { adv_loc(); goto lex; } comment { adv_loc(); goto lex; } @@ -186,6 +273,7 @@ lex: throw parser::syntax_error(loc_, "Unknown token"); } */ + // clang-format on } std::uint64_t lexer::lex_number(char const *s, char const *e) { @@ -235,35 +323,4 @@ double lexer::lex_floating_constant(char const *s, char const *e) { return d; } -scalar_type lexer::lex_integer_type(char const *s, char const *) { - char const *YYMARKER; - /*!re2c - re2c:yyfill:enable = 0; - re2c:define:YYCURSOR = s; - - "i1" { return scalar_type::i1; } - "i8" { return scalar_type::i8; } - "i16" { return scalar_type::i16; } - "i32" { return scalar_type::i32; } - "i64" { return scalar_type::i64; } - "index" { return scalar_type::index; } - $ { return {}; } - * { return {}; } - */ - return scalar_type{}; -} -scalar_type lexer::lex_floating_type(char const *s, char const *) { - char const *YYMARKER; - /*!re2c - re2c:yyfill:enable = 0; - re2c:define:YYCURSOR = s; - - "f32" { return scalar_type::f32; } - "f64" { return scalar_type::f64; } - $ { return {}; } - * { return {}; } - */ - return scalar_type{}; -} - } // namespace tinytc diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 336c0226..2b906443 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -2,8 +2,8 @@ // SPDX-License-Identifier: BSD-3-Clause #include "parse_context.hpp" +#include "compiler_context.hpp" #include "location.hpp" -#include "node/function_node.hpp" #include "node/value_node.hpp" #include "parser/parser_impl.hpp" @@ -12,48 +12,80 @@ namespace tinytc { -void parse_context::push_scope() { id_map_.push_back({}); } -void parse_context::pop_scope() { id_map_.pop_back(); } +parse_context::parse_context(compiler_context compiler_ctx) : compiler_ctx_(compiler_ctx) {} -void parse_context::val(std::string const &id, value val, location const &l) { - for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { - if (auto other = it->find(id); other != it->end()) { - auto oss = std::ostringstream{}; - oss << "Identifier %" << id << " was already used at " << other->second->loc(); - throw parser::syntax_error(l, oss.str()); - } - } - val->loc(l); - id_map_.back()[id] = std::move(val); +void parse_context::push_scope() { + unnamed_id_map_.push_back({}); + named_id_map_.push_back({}); } +void parse_context::pop_scope() { + named_id_map_.pop_back(); + unnamed_id_map_.pop_back(); +} + +void parse_context::push_region(tinytc_region_t r) { regions_.push(r); } +void parse_context::pop_region() { regions_.pop(); } +auto parse_context::top_region() -> tinytc_region_t { return regions_.top(); } +auto parse_context::has_regions() -> bool { return !regions_.empty(); } -value parse_context::val(std::string const &id, location const &l) { - for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { - if (auto j = it->find(id); j != it->end()) { - return j->second; - } +void parse_context::add_global_name(std::string const &name, location const &l) { + if (auto other = global_names_.find(name); other != global_names_.end()) { + auto oss = std::ostringstream{}; + oss << "Identifier @" << name << " was already used at " << other->second; + throw parser::syntax_error(l, std::move(oss).str()); } - throw parser::syntax_error(l, "Undefined identifier %" + id); + global_names_[name] = l; } -void parse_context::prototype(std::string const &id, func p) { - if (auto other = prototype_map_.find(id); other != prototype_map_.end()) { - auto oss = std::ostringstream{}; - oss << "Identifier @" << id << " was already used at " << other->second->loc(); - throw parser::syntax_error(p->loc(), oss.str()); +void parse_context::val(std::variant const &id, tinytc_value &val, + location const &l) { + const auto handle_val = + [&l, &val](KeyT const &id, + std::vector> &map) { + if (map.empty()) { + throw parser::syntax_error(l, "No active scope"); + } + for (auto it = map.rbegin(); it != map.rend(); ++it) { + if (auto other = it->find(id); other != it->end()) { + auto oss = std::ostringstream{}; + oss << "Identifier %" << id << " was already used at " << other->second->loc(); + throw parser::syntax_error(l, std::move(oss).str()); + } + } + val.loc(l); + map.back()[id] = &val; + }; + if (std::holds_alternative(id)) { + handle_val(std::get(id), unnamed_id_map_); + } else { + auto const &sid = std::get(id); + handle_val(sid, named_id_map_); + val.name(sid); } - prototype_map_[id] = std::move(p); } -func parse_context::prototype(std::string const &id, location const &l) { - if (auto j = prototype_map_.find(id); j != prototype_map_.end()) { - return j->second; +auto parse_context::val(std::variant const &id, + location const &l) -> tinytc_value_t { + const auto handle_val = + [&l](KeyT const &id, + std::vector> &map) { + for (auto it = map.rbegin(); it != map.rend(); ++it) { + if (auto j = it->find(id); j != it->end()) { + return j->second; + } + } + auto oss = std::ostringstream{}; + oss << "Undefined identifier %" << id; + throw parser::syntax_error(l, std::move(oss).str()); + }; + if (std::holds_alternative(id)) { + return handle_val(std::get(id), unnamed_id_map_); } - throw parser::syntax_error(l, "Undefined identifier @" + id); + return handle_val(std::get(id), named_id_map_); } -void parse_context::add_error(location const &loc, std::string const &what) { - errors_.emplace_back(std::make_pair(loc, what)); +void parse_context::report_error(location const &loc, std::string const &what) { + compiler_ctx_->report_error(loc, what.c_str()); } } // namespace tinytc diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 832cc3a8..4a04867a 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -5,40 +5,51 @@ #define PARSE_CONTEXT_20231221_HPP #include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" +#include +#include #include #include #include +#include #include namespace tinytc { class parse_context { public: - inline parse_context() { id_map_.push_back({}); } + parse_context(compiler_context compiler_ctx); inline auto program() { return program_; } inline void program(prog p) { program_ = std::move(p); } + void val(std::variant const &id, tinytc_value &val, + location const &l); + auto val(std::variant const &id, + location const &l) -> tinytc_value_t; + + void report_error(location const &loc, std::string const &what); + + auto cctx() -> compiler_context const & { return compiler_ctx_; } + void push_scope(); void pop_scope(); - void val(std::string const &id, value val, location const &l); - value val(std::string const &id, location const &l); - - void prototype(std::string const &id, func p); - func prototype(std::string const &id, location const &l); - void add_error(location const &loc, std::string const &what); + void push_region(tinytc_region_t r); + void pop_region(); + auto top_region() -> tinytc_region_t; + auto has_regions() -> bool; - inline auto errors() const -> std::vector> const & { - return errors_; - } + void add_global_name(std::string const &name, location const &l); private: - std::vector> id_map_; - std::unordered_map prototype_map_; + compiler_context compiler_ctx_; + std::vector> unnamed_id_map_; + std::vector> named_id_map_; + std::stack regions_; + std::unordered_map global_names_; prog program_; - std::vector> errors_; }; } // namespace tinytc diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 07ff319e..8afa78a1 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -5,68 +5,67 @@ %language "c++" %code requires { - #include "node/function_node.hpp" - #include "slice.hpp" + #include "node/inst_node.hpp" #include "tinytc/tinytc.hpp" + #include "tinytc/types.h" #include "tinytc/types.hpp" - #include + #include #include #include + #include #include #include + #include + #include namespace tinytc { class parse_context; class lexer; + + using int_or_val = std::variant; + using unique_ptr_to_if_inst = std::unique_ptr; + + using identifier = std::variant; + struct param_attrs { + identifier id; + location loc; + attr dict; + }; } } %code { + #include "compiler_context.hpp" #include "error.hpp" - #include "node/data_type_node.hpp" - #include "node/inst_node.hpp" + #include "node/attr_node.hpp" + #include "node/function_node.hpp" #include "node/program_node.hpp" #include "node/region_node.hpp" #include "node/value_node.hpp" #include "parser/lexer.hpp" #include "parser/parse_context.hpp" - #include "passes.hpp" - #include "util.hpp" - - #include + #include "util/ilist.hpp" + #include "util/iterator.hpp" + #include "util/visit.hpp" - #include + #include #include - #include - #include - #include + #include + #include #include + #include namespace tinytc { - void check_scalar_type(value & val, scalar_type const& sty, location & loc1, - location & loc2) { - clir::visit( - overloaded{[&](int_imm &i) { i.ty(make_scalar(sty)); }, - [&](float_imm &i) { i.ty(make_scalar(sty)); }, - [&](auto &) { - if (!val->ty() || !is_equal(*val->ty(), *make_scalar(sty))) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error( - loc, "Type of SSA value does not match operand type"); - } - }}, - *val); - } - void check_type(value & val, data_type & ty, location & loc1, - location & loc2) { - if (!val->ty() || !is_equal(*val->ty(), *ty)) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - }; + + void report_error(compiler_context const& cctx, compilation_error const& e) { + if (e.extra_info().size() > 0) { + auto what = (std::ostringstream{} << e.what() << " (" << e.extra_info() << ')').str(); + cctx.get()->report_error(e.loc(), e.ref_values(), what.c_str()); + } else { + cctx.get()->report_error(e.loc(), e.ref_values(), e.what()); + } } + } // namespace tinytc } %header @@ -98,26 +97,44 @@ LSQBR "[" RSQBR "]" FUNC "func" - WORK_GROUP_SIZE "work_group_size" - SUBGROUP_SIZE "subgroup_size" - RETURNS "->" + ATTRIBUTES "attributes" + ARROW "->" DYNAMIC "?" NOTRANS ".n" TRANS ".t" ATOMIC ".atomic" + ATOMIC_ADD ".atomic_add" + ATOMIC_MAX ".atomic_max" + ATOMIC_MIN ".atomic_min" + INIT "init" + LOCAL "local" + GLOBAL "global" + LOCAL_ATTR ".local" + GLOBAL_ATTR ".global" + BOOLEAN "bool" + COOPMATRIX "coopmatrix" MEMREF "memref" GROUP "group" OFFSET "offset" STRIDED "strided" AXPBY "axpby" - ARITH "arith" + BARRIER "barrier" + CUMSUM "cumsum" GEMM "gemm" GEMV "gemv" GER "ger" HADAMARD "hadamard" ALLOCA "alloca" CAST "cast" - CMP "cmp" + CONSTANT "constant" + COOPERATIVE_MATRIX_APPLY "cooperative_matrix_apply" + COOPERATIVE_MATRIX_EXTRACT "cooperative_matrix_extract" + COOPERATIVE_MATRIX_INSERT "cooperative_matrix_insert" + COOPERATIVE_MATRIX_LOAD "cooperative_matrix_load" + COOPERATIVE_MATRIX_MUL_ADD "cooperative_matrix_mul_add" + COOPERATIVE_MATRIX_PREFETCH "cooperative_matrix_prefetch" + COOPERATIVE_MATRIX_SCALE "cooperative_matrix_scale" + COOPERATIVE_MATRIX_STORE "cooperative_matrix_store" EXPAND "expand" FUSE "fuse" LOAD "load" @@ -125,98 +142,134 @@ FOREACH "foreach" IF "if" ELSE "else" - GROUP_ID "group_id" - GROUP_SIZE "group_size" + PARALLEL "parallel" SIZE "size" + SUBGROUP_BROADCAST "subgroup_broadcast" SUBVIEW "subview" STORE "store" SUM "sum" YIELD "yield" ; -%token LOCAL_IDENTIFIER +%token LOCAL_IDENTIFIER %token GLOBAL_IDENTIFIER +%token ATTR_NAME +%token STRING +%token BOOLEAN_CONSTANT %token INTEGER_CONSTANT %token FLOATING_CONSTANT %token INTEGER_TYPE %token FLOATING_TYPE %token ARITHMETIC %token ARITHMETIC_UNARY +%token BUILTIN %token CMP_CONDITION +%token > COOPERATIVE_MATRIX_REDUCE +%token > SUBGROUP_OPERATION +%token MATH_UNARY +%token MATRIX_USE +%token CHECKED %nterm prog %nterm > func_list %nterm func -%nterm > arguments -%nterm <::tinytc::value> argument -%nterm >> attributes -%nterm > attribute -%nterm data_type -%nterm scalar_type -%nterm memref_type +%nterm ,std::vector>> parameters +%nterm > parameter +%nterm function_attributes +%nterm attribute +%nterm array_attribute +%nterm > attribute_list +%nterm dictionary_attribute +%nterm > named_attribute_list +%nterm named_attribute +%nterm attribute_name +%nterm optional_dictionary_attribute +%nterm data_type +%nterm boolean_type +%nterm scalar_type +%nterm coopmatrix_type +%nterm memref_type +%nterm optional_address_space %nterm > mode_list %nterm > optional_stride_list %nterm > stride_list %nterm constant_or_dynamic -%nterm group_type +%nterm group_type %nterm group_offset -%nterm memref_or_group_type -%nterm region -%nterm <::tinytc::value> var -%nterm > instructions +%nterm var %nterm instruction %nterm axpby_inst %nterm atomic -%nterm <::tinytc::value> identifier_or_constant -%nterm > optional_identifier_or_constant_list -%nterm > identifier_or_constant_list +%nterm > optional_value_list +%nterm > value_list +%nterm barrier_inst +%nterm optional_global_attr +%nterm optional_local_attr +%nterm cumsum_inst %nterm gemm_inst %nterm gemv_inst %nterm ger_inst %nterm transpose %nterm for_inst -%nterm <::tinytc::value> optional_step +%nterm , std::vector, std::vector>> optional_loop_carried_values +%nterm , std::vector>> init_value_list +%nterm > init_value +%nterm optional_step %nterm foreach_inst %nterm hadamard_inst %nterm if_inst -%nterm > optional_returned_values -%nterm > optional_scalar_type_list -%nterm > scalar_type_list -%nterm else_region +%nterm > optional_returned_values +%nterm > optional_return_type_list +%nterm > return_type_list %nterm sum_inst %nterm yield_inst -%nterm for_loop_var_type +%nterm for_loop_var_type %nterm var_definition -%nterm > identifier_list +%nterm > identifier_list %nterm valued_inst %nterm alloca_inst %nterm arith_inst %nterm arith_unary_inst +%nterm builtin_inst %nterm cast_inst %nterm compare_inst +%nterm constant_inst +%nterm cooperative_matrix_apply_inst +%nterm cooperative_matrix_extract_inst +%nterm cooperative_matrix_insert_inst +%nterm cooperative_matrix_load_inst +%nterm cooperative_matrix_mul_add_inst +%nterm cooperative_matrix_prefetch_inst +%nterm cooperative_matrix_reduce_inst +%nterm cooperative_matrix_scale_inst +%nterm cooperative_matrix_store_inst +%nterm checked %nterm expand_inst -%nterm <::tinytc::value> constant_or_dynamic_or_identifier -%nterm > expand_shape +%nterm integer_constant_or_identifier +%nterm > expand_shape %nterm fuse_inst %nterm load_inst -%nterm > optional_index_list -%nterm > index_list -%nterm <::tinytc::value> index_identifier_or_const -%nterm group_id_inst -%nterm group_size_inst +%nterm math_unary_inst +%nterm parallel_inst %nterm size_inst +%nterm subgroup_broadcast_inst +%nterm subgroup_operation_inst %nterm store_inst +%nterm store_flag %nterm subview_inst -%nterm > optional_slice_list -%nterm > slice_list -%nterm slice -%nterm <::tinytc::value> slice_size +%nterm >> optional_slice_list +%nterm >> slice_list +%nterm > slice +%nterm slice_size %% prog: func_list { - auto p = prog { std::make_unique(std::move($func_list), @prog).release() }; + auto p = prog { std::make_unique(ctx.cctx(), @prog).release() }; ctx.program(p); $$ = std::move(p); + for (auto& f : $func_list) { + $$.add_function(std::move(f)); + } } ; @@ -225,104 +278,155 @@ func_list: | func_list func { $$ = std::move($1); $$.emplace_back(std::move($func)); } func: - FUNC { - ctx.push_scope(); - } GLOBAL_IDENTIFIER LPAREN arguments RPAREN attributes region { + FUNC GLOBAL_IDENTIFIER LPAREN parameters RPAREN function_attributes { auto loc = @FUNC; loc.end = @RPAREN.end; - auto proto = func{ - std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), loc).release()}; - ctx.prototype($GLOBAL_IDENTIFIER, proto); - auto func_node = - std::make_unique(std::move(proto), std::move($region), @func).release(); - for (auto &attr : $attributes) { - attr(*func_node); - } - $func = func{func_node}; + try { + ctx.add_global_name($GLOBAL_IDENTIFIER, loc); + auto func_node = std::make_unique($GLOBAL_IDENTIFIER, $parameters.second, + get_void(ctx.cctx()), loc); + func_node->attr($function_attributes); + ctx.push_scope(); + auto name_it = $parameters.first.begin(); + for (auto &p : func_node->params()) { + ctx.val(name_it->id, p, name_it->loc); + if (name_it->dict != 0) { + func_node->param_attr(name_it - $parameters.first.begin(), name_it->dict); + } + ++name_it; + } + ctx.push_region(&func_node->body()); + $$ = func{func_node.release()}; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + }[prototype] region { + ctx.pop_region(); ctx.pop_scope(); + $$ = std::move($prototype); } ; -arguments: +parameters: %empty {} - | argument { $$.emplace_back(std::move($argument)); } - | arguments COMMA argument { $$ = std::move($1); $$.emplace_back(std::move($argument)); } + | parameter { + $$.first.emplace_back(std::move($parameter.first)); + $$.second.emplace_back(std::move($parameter.second)); + } + | parameters COMMA parameter { + $$.first = std::move($1.first); + $$.second = std::move($1.second); + $$.first.emplace_back(std::move($parameter.first)); + $$.second.emplace_back(std::move($parameter.second)); + } ; -argument: - LOCAL_IDENTIFIER COLON data_type { - auto v = make_value(std::move($data_type)); - v.name($LOCAL_IDENTIFIER); - ctx.val($LOCAL_IDENTIFIER, v, @LOCAL_IDENTIFIER); - $$ = std::move(v); +parameter: + LOCAL_IDENTIFIER COLON data_type optional_dictionary_attribute[dict] { + $$ = std::make_pair(param_attrs{$LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER, $dict}, $data_type); } ; -attributes: +function_attributes: %empty {} - | attributes attribute { $$ = std::move($1); $$.emplace_back(std::move($attribute)); } + | ATTRIBUTES dictionary_attribute { $$ = $dictionary_attribute; } ; attribute: - WORK_GROUP_SIZE LPAREN INTEGER_CONSTANT[m] COMMA INTEGER_CONSTANT[n] RPAREN { - if ($m <= 0) { - throw parser::syntax_error(@m, "Must be a non-negative number"); - } - if ($n <= 0) { - throw parser::syntax_error(@n, "Must be a non-negative number"); - } - auto const wgs = std::array{static_cast($m), - static_cast($n)}; - $$ = [=](function &f) { f.work_group_size(wgs); }; + array_attribute { $$ = $array_attribute; } + | BOOLEAN_CONSTANT { $$ = boolean_attr::get(ctx.cctx().get(), $BOOLEAN_CONSTANT); } + | dictionary_attribute { $$ = $dictionary_attribute; } + | INTEGER_CONSTANT { $$ = integer_attr::get(ctx.cctx().get(), $INTEGER_CONSTANT); } + | STRING { $$ = string_attr::get(ctx.cctx().get(), $STRING); } +; + +array_attribute: + LSQBR RSQBR { $$ = array_attr::get(ctx.cctx().get(), {}); } + | LSQBR attribute_list RSQBR { $$ = array_attr::get(ctx.cctx().get(), $attribute_list); } + +attribute_list: + attribute { $$.push_back($attribute); } + | attribute_list COMMA attribute { $$ = std::move($1); $$.push_back($attribute); } +; + +dictionary_attribute: + LBRACE RBRACE { $$ = dictionary_attr::get(ctx.cctx().get(), {}); } + | LBRACE named_attribute_list RBRACE { + dictionary_attr::sort($named_attribute_list); + $$ = dictionary_attr::get(ctx.cctx().get(), $named_attribute_list); } - | SUBGROUP_SIZE LPAREN INTEGER_CONSTANT RPAREN { - if ($INTEGER_CONSTANT <= 0) { - throw parser::syntax_error(@INTEGER_CONSTANT, "Must be a non-negative number"); - } - auto const sgs = static_cast($INTEGER_CONSTANT); - $$ = [=](function &f) { f.subgroup_size(sgs); }; +; + +named_attribute_list: + named_attribute { $$.push_back($named_attribute); } + | named_attribute_list COMMA named_attribute { $$ = std::move($1); $$.push_back($named_attribute); } +; + +named_attribute: + attribute_name EQUALS attribute { + $$ = named_attr{$attribute_name, $attribute}; } ; +attribute_name: + ATTR_NAME { $$ = string_attr::get(ctx.cctx().get(), $ATTR_NAME); } + | STRING { $$ = string_attr::get(ctx.cctx().get(), $STRING); } + +optional_dictionary_attribute: + %empty { $$ = nullptr; } + | dictionary_attribute { $$ = $dictionary_attribute; } +; + data_type: - scalar_type { $$ = make_scalar($scalar_type); $$->loc(@scalar_type); } - | memref_type + boolean_type + | coopmatrix_type | group_type + | memref_type + | scalar_type +; + +boolean_type: + BOOLEAN { $$ = get_boolean(ctx.cctx()); } ; scalar_type: - INTEGER_TYPE - | FLOATING_TYPE + INTEGER_TYPE { $$ = get_scalar(ctx.cctx(), $INTEGER_TYPE); } + | FLOATING_TYPE { $$ = get_scalar(ctx.cctx(), $FLOATING_TYPE); } +; + +coopmatrix_type: + COOPMATRIX LCHEV scalar_type TIMES INTEGER_CONSTANT[rows] TIMES INTEGER_CONSTANT[cols] COMMA MATRIX_USE RCHEV { + try { + $$ = get_coopmatrix($scalar_type, $rows, $cols, $MATRIX_USE, @coopmatrix_type); + } catch (compilation_error const& e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } ; memref_type: - MEMREF LCHEV scalar_type mode_list RCHEV { + MEMREF LCHEV scalar_type mode_list optional_address_space RCHEV { try { - $$ = data_type { - std::make_unique($scalar_type, std::move($mode_list), - std::vector{}, @memref_type) - .release() - }; + $$ = get_memref($scalar_type, $mode_list, {}, $optional_address_space, @memref_type); } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } - | MEMREF LCHEV scalar_type mode_list COMMA STRIDED LCHEV optional_stride_list RCHEV RCHEV { + | MEMREF LCHEV scalar_type mode_list COMMA STRIDED LCHEV optional_stride_list RCHEV optional_address_space RCHEV { if ($mode_list.size() != $optional_stride_list.size()) { auto loc = @scalar_type; loc.end = @optional_stride_list.end; throw syntax_error(loc, "Shape and stride list must have the same length"); } try { - $$ = data_type { - std::make_unique($scalar_type, std::move($mode_list), - std::move($optional_stride_list), @memref_type) - .release() - }; + $$ = get_memref($scalar_type, $mode_list, $optional_stride_list, + $optional_address_space, @memref_type); } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -333,6 +437,12 @@ mode_list: | mode_list TIMES constant_or_dynamic { $$ = std::move($1); $$.push_back($constant_or_dynamic); } ; +optional_address_space: + %empty { $$ = address_space::global; } + | COMMA GLOBAL { $$ = address_space::global; } + | COMMA LOCAL { $$ = address_space::local; } +; + optional_stride_list: %empty {} | stride_list { $$ = std::move($1); } @@ -349,9 +459,8 @@ constant_or_dynamic: ; group_type: - GROUP LCHEV memref_type group_offset RCHEV { - $$ = make_group(std::move($memref_type), $group_offset); - $$->loc(@group_type); + GROUP LCHEV memref_type TIMES constant_or_dynamic[group_size] group_offset RCHEV { + $$ = get_group(std::move($memref_type), $group_size, $group_offset, @group_type); } ; @@ -360,63 +469,56 @@ group_offset: | COMMA OFFSET COLON constant_or_dynamic { $$ = $constant_or_dynamic; } ; -memref_or_group_type: - memref_type - | group_type +var: + LOCAL_IDENTIFIER { $$ = ctx.val($LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER); } ; region: - LBRACE { - ctx.push_scope(); - } instructions RBRACE { - $$ = region{std::make_unique(std::move($instructions), @region).release()}; - ctx.pop_scope(); - } -; - -var: - LOCAL_IDENTIFIER { $$ = ctx.val($LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER); } + LBRACE { ctx.push_scope(); } instructions { ctx.pop_scope(); } RBRACE {} ; instructions: %empty {} | instructions instruction { - $$ = std::move($1); - $$.emplace_back(std::move($instruction)); + if (!ctx.has_regions()) { + error(@instruction, "Internal error: missing region"); + YYERROR; + } + tinytc_region_t reg = ctx.top_region(); + reg->insts().push_back($instruction.release()); } ; instruction: - axpby_inst - | gemm_inst - | gemv_inst - | ger_inst - | for_inst - | foreach_inst - | hadamard_inst - | if_inst - | var_definition - | store_inst - | sum_inst - | yield_inst + axpby_inst { $$ = std::move($1); } + | barrier_inst { $$ = std::move($1); } + | cooperative_matrix_prefetch_inst { $$ = std::move($1); } + | cooperative_matrix_store_inst { $$ = std::move($1); } + | cumsum_inst { $$ = std::move($1); } + | gemm_inst { $$ = std::move($1); } + | gemv_inst { $$ = std::move($1); } + | ger_inst { $$ = std::move($1); } + | for_inst { $$ = std::move($1); } + | foreach_inst { $$ = std::move($1); } + | hadamard_inst { $$ = std::move($1); } + | if_inst { $$ = std::move($1); } + | parallel_inst { $$ = std::move($1); } + | var_definition { $$ = std::move($1); } + | store_inst { $$ = std::move($1); } + | sum_inst { $$ = std::move($1); } + | yield_inst { $$ = std::move($1); } ; axpby_inst: - AXPBY transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($b, $mb, @b, @mb); + AXPBY transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] { try { $$ = inst { - std::make_unique($ta, std::move($alpha), std::move($a), - std::move($beta), std::move($b), $atomic, @axpby_inst) + std::make_unique($ta, std::move($alpha), std::move($a), std::move($beta), + std::move($b), $atomic, @axpby_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -427,34 +529,60 @@ atomic: | ATOMIC { $$ = true; } ; -identifier_or_constant: - var { $$ = $var; } - | INTEGER_CONSTANT { $$ = make_imm($INTEGER_CONSTANT); $$->loc(@INTEGER_CONSTANT); } - | FLOATING_CONSTANT { $$ = make_imm($FLOATING_CONSTANT); $$->loc(@FLOATING_CONSTANT); } -; - -optional_identifier_or_constant_list: +optional_value_list: %empty {} - | identifier_or_constant_list { $$ = std::move($1); } + | value_list { $$ = std::move($1); } +; -identifier_or_constant_list: - identifier_or_constant { $$.push_back(std::move($identifier_or_constant)); } - | identifier_or_constant_list COMMA identifier_or_constant { +value_list: + var { $$.push_back(std::move($var)); } + | value_list COMMA var { $$ = std::move($1); - $$.push_back(std::move($identifier_or_constant)); + $$.push_back(std::move($var)); + } +; + +barrier_inst: + BARRIER optional_global_attr optional_local_attr { + int32_t fence_flags = 0; + fence_flags |= $optional_global_attr; + fence_flags |= $optional_local_attr; + try { + $$ = inst { std::make_unique(fence_flags, @barrier_inst).release() }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } +; + +optional_global_attr: + %empty { $$ = 0; } + | GLOBAL_ATTR { $$ = tinytc_address_space_global; } +; + +optional_local_attr: + %empty { $$ = 0; } + | LOCAL_ATTR { $$ = tinytc_address_space_local; } +; + +cumsum_inst: + CUMSUM atomic var[alpha] COMMA var[a] COMMA INTEGER_CONSTANT[mode] COMMA var[beta] COMMA var[b] { + try { + $$ = inst { + std::make_unique(std::move($alpha), std::move($a), $mode, std::move($beta), + std::move($b), $atomic, @cumsum_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } } ; gemm_inst: - GEMM transpose[ta] transpose[tb] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + GEMM transpose[ta] transpose[tb] atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique($ta, $tb, std::move($alpha), std::move($a), @@ -463,22 +591,14 @@ gemm_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; gemv_inst: - GEMV transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + GEMV transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique($ta, std::move($alpha), std::move($a), std::move($b), @@ -486,7 +606,7 @@ gemv_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } @@ -498,15 +618,7 @@ transpose: ; ger_inst: - GER atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + GER atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique(std::move($alpha), std::move($a), std::move($b), @@ -514,89 +626,111 @@ ger_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; for_inst: - FOR LOCAL_IDENTIFIER[loop_var] - EQUALS identifier_or_constant[from] COMMA identifier_or_constant[to] optional_step - for_loop_var_type { - check_scalar_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type($to, $for_loop_var_type, @to, @for_loop_var_type); - if ($optional_step) { - check_scalar_type($optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); - } - auto v = make_value($for_loop_var_type); - v.name($loop_var); - ctx.val($loop_var, std::move(v), @loop_var); - } region { - try { - $$ = inst { - std::make_unique(ctx.val($loop_var, @loop_var), $from, $to, - $optional_step, std::move($region), @for_inst) - .release() - }; + FOR LOCAL_IDENTIFIER[loop_var] for_loop_var_type EQUALS var[from] COMMA var[to] optional_step optional_loop_carried_values[lcv] { + try { + auto &[lcv_id, lcv_init, lcv_type] = $lcv; + location loc = @FOR; + loc.end = @lcv.end; + auto inode = std::make_unique($for_loop_var_type, $from, $to, $optional_step, + lcv_init, lcv_type, loc); + ctx.push_scope(); + auto &loop_var = inode->loop_var(); + ctx.val($loop_var, loop_var, @loop_var); + for (std::int64_t i = 0; i < inode->num_results(); ++i) { + ctx.val(lcv_id[i], inode->iter_arg(i), @lcv); + } + ctx.push_region(&inode->body()); + $$ = inst{inode.release()}; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } + }[loop_header] region optional_dictionary_attribute { + ctx.pop_region(); + ctx.pop_scope(); + $loop_header->attr($optional_dictionary_attribute); + $$ = std::move($loop_header); } ; optional_step: %empty { $$ = {}; } - | COMMA identifier_or_constant { $$ = $identifier_or_constant; } + | COMMA var { $$ = $var; } +; + +optional_loop_carried_values: + %empty { $$ = {}; } + | INIT LPAREN init_value_list RPAREN ARROW LPAREN return_type_list RPAREN { + $$ = std::make_tuple(std::move($init_value_list.first), std::move($init_value_list.second), + std::move($return_type_list)); + } +; + +init_value_list: + init_value { + $$.first.emplace_back($init_value.first); + $$.second.emplace_back($init_value.second); + } + | init_value_list COMMA init_value { + $$ = std::move($1); + $$.first.emplace_back($init_value.first); + $$.second.emplace_back($init_value.second); + } +; + +init_value: + LOCAL_IDENTIFIER EQUALS var { $$ = std::make_pair($LOCAL_IDENTIFIER, $var); } +; foreach_inst: - FOREACH LOCAL_IDENTIFIER[loop_var] - EQUALS identifier_or_constant[from] COMMA identifier_or_constant[to] for_loop_var_type { - check_scalar_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type($to, $for_loop_var_type, @to, @for_loop_var_type); - auto v = make_value($for_loop_var_type); - v.name($loop_var); - ctx.val($loop_var, std::move(v), @loop_var); - } region { - try { - $$ = inst { - std::make_unique(ctx.val($loop_var, @loop_var), $from, $to, - std::move($region), @foreach_inst) - .release() - }; + FOREACH LPAREN identifier_list[loop_var] RPAREN for_loop_var_type EQUALS + LPAREN value_list[from] RPAREN COMMA LPAREN value_list[to] RPAREN { + try { + location loc = @FOREACH; + loc.end = @for_loop_var_type.end; + auto inode = + std::make_unique($for_loop_var_type, $from, $to, loc); + ctx.push_scope(); + auto loop_vars = inode->loop_vars().begin(); + for (std::int64_t i = 0; i < inode->dim(); ++i) { + ctx.val($loop_var[i], loop_vars[i], @loop_var); + } + ctx.push_region(&inode->body()); + $$ = inst{inode.release()}; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } + }[loop_header] region { + ctx.pop_region(); + ctx.pop_scope(); + $$ = std::move($loop_header); } ; for_loop_var_type: - %empty { $$ = scalar_type::index; } - | COLON INTEGER_TYPE { $$ = $INTEGER_TYPE; } + %empty { $$ = get_scalar(ctx.cctx(), scalar_type::index); } + | COLON INTEGER_TYPE { $$ = get_scalar(ctx.cctx(), $INTEGER_TYPE); } ; var_definition: identifier_list EQUALS valued_inst { $$ = std::move($valued_inst); - if ($identifier_list.size() == 1) { - if (!$$->result()) { - throw syntax_error(@identifier_list, "Instruction does not return value"); - } - $$->result()->name($identifier_list[0]); - ctx.val($identifier_list[0], $$->result(), @identifier_list); - } else { - auto results = $$->results(); - if (results.size() != $identifier_list.size()) { - throw syntax_error( - @identifier_list, - "Number of identifiers does not equal number of returned values"); - } - for (std::size_t i = 0; i < results.size(); ++i) { - results[i]->name($identifier_list[i]); - ctx.val($identifier_list[i], results[i], @identifier_list); - } + if (static_cast($identifier_list.size()) != $$->num_results()) { + throw syntax_error( + @identifier_list, + "Number of identifiers does not equal number of returned values"); + } + auto results = $$->result_begin(); + for (std::int64_t i = 0; i < $$->num_results(); ++i) { + ctx.val($identifier_list[i], results[i], @identifier_list); } } ; @@ -608,15 +742,7 @@ identifier_list: hadamard_inst: - HADAMARD atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); + HADAMARD atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { try { $$ = inst { std::make_unique(std::move($alpha), std::move($a), std::move($b), @@ -625,20 +751,14 @@ hadamard_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; sum_inst: - SUM transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($b, $mb, @b, @mb); + SUM transpose[ta] atomic var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] { try { $$ = inst { std::make_unique($ta, std::move($alpha), std::move($a), std::move($beta), @@ -646,304 +766,580 @@ sum_inst: .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; yield_inst: - YIELD optional_identifier_or_constant_list[vals] COLON optional_scalar_type_list[tys] { - if ($vals.size() != $tys.size()) { - location loc = @vals; - loc.end = @tys.end; - throw syntax_error(loc, "Identifier and scalar type list must have the same length"); - } - for (std::size_t i = 0; i < $vals.size(); ++i) { - check_scalar_type($vals[i], $tys[i], @vals, @tys); - } - $$ = inst{std::make_unique(std::move($vals)).release()}; + YIELD LPAREN optional_value_list[vals] RPAREN { + $$ = inst{std::make_unique(std::move($vals), @yield_inst).release()}; } ; valued_inst: - alloca_inst - | arith_inst - | arith_unary_inst - | cast_inst - | compare_inst - | expand_inst - | fuse_inst - | if_inst - | load_inst - | group_id_inst - | group_size_inst - | size_inst - | subview_inst + alloca_inst { $$ = std::move($1); } + | arith_inst { $$ = std::move($1); } + | arith_unary_inst { $$ = std::move($1); } + | builtin_inst { $$ = std::move($1); } + | cast_inst { $$ = std::move($1); } + | compare_inst { $$ = std::move($1); } + | constant_inst { $$ = std::move($1); } + | cooperative_matrix_apply_inst { $$ = std::move($1); } + | cooperative_matrix_extract_inst { $$ = std::move($1); } + | cooperative_matrix_insert_inst { $$ = std::move($1); } + | cooperative_matrix_load_inst { $$ = std::move($1); } + | cooperative_matrix_mul_add_inst { $$ = std::move($1); } + | cooperative_matrix_reduce_inst { $$ = std::move($1); } + | cooperative_matrix_scale_inst { $$ = std::move($1); } + | expand_inst { $$ = std::move($1); } + | for_inst { $$ = std::move($1); } + | fuse_inst { $$ = std::move($1); } + | if_inst { $$ = std::move($1); } + | load_inst { $$ = std::move($1); } + | math_unary_inst { $$ = std::move($1); } + | size_inst { $$ = std::move($1); } + | subgroup_broadcast_inst { $$ = std::move($1); } + | subgroup_operation_inst { $$ = std::move($1); } + | subview_inst { $$ = std::move($1); } ; alloca_inst: - ALLOCA RETURNS memref_type { + ALLOCA optional_dictionary_attribute[dict] COLON memref_type { try { $$ = inst { std::make_unique(std::move($memref_type), @alloca_inst).release() }; + $$->attr($dict); } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; arith_inst: - ARITH ARITHMETIC identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); - check_scalar_type($b, $ty, @b, @ty); + ARITHMETIC var[a] COMMA var[b] COLON data_type[ty] { try { $$ = inst { - std::make_unique($ARITHMETIC, std::move($a), std::move($b), @arith_inst) + std::make_unique($ARITHMETIC, std::move($a), std::move($b), std::move($ty), + @arith_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; arith_unary_inst: - ARITH ARITHMETIC_UNARY identifier_or_constant[a] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); + ARITHMETIC_UNARY var[a] COLON data_type[ty] { try { $$ = inst { - std::make_unique($ARITHMETIC_UNARY, std::move($a), + std::make_unique($ARITHMETIC_UNARY, std::move($a), std::move($ty), @arith_unary_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; +builtin_inst: + BUILTIN COLON data_type[ty] { + $$ = inst{std::make_unique($BUILTIN, $ty, @builtin_inst).release()}; + } +; cast_inst: - CAST identifier_or_constant[a] COLON scalar_type[from] RETURNS scalar_type[to] { - check_scalar_type($a, $from, @a, @from); + CAST var[a] COLON data_type[to] { try { $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; compare_inst: - CMP CMP_CONDITION identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); - check_scalar_type($b, $ty, @b, @ty); + CMP_CONDITION var[a] COMMA var[b] COLON boolean_type { try { $$ = inst { std::make_unique($CMP_CONDITION, std::move($a), std::move($b), - @compare_inst) + std::move($boolean_type), @compare_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; -expand_inst: - EXPAND var LSQBR INTEGER_CONSTANT[mode] RETURNS expand_shape RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); +constant_inst: + CONSTANT LSQBR FLOATING_CONSTANT[re] COMMA FLOATING_CONSTANT[im] RSQBR COLON data_type { + try { + $$ = inst { + std::make_unique(std::complex{$re, $im}, $data_type, @constant_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } + | CONSTANT FLOATING_CONSTANT COLON data_type { + try { + $$ = inst { + std::make_unique($FLOATING_CONSTANT, $data_type, @constant_inst).release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } + | CONSTANT INTEGER_CONSTANT COLON data_type { + try { + $$ = inst { + std::make_unique($INTEGER_CONSTANT, $data_type, @constant_inst).release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; } + } + | CONSTANT BOOLEAN_CONSTANT COLON data_type { try { $$ = inst { - std::make_unique(std::move($var), $mode, std::move($expand_shape), - @expand_inst) + std::make_unique($BOOLEAN_CONSTANT, $data_type, @constant_inst).release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } +; + +cooperative_matrix_apply_inst: + COOPERATIVE_MATRIX_APPLY + LPAREN LOCAL_IDENTIFIER[row] COMMA LOCAL_IDENTIFIER[col] COMMA LOCAL_IDENTIFIER[val] RPAREN + EQUALS var ARROW data_type[result_ty] { + try { + location loc = @COOPERATIVE_MATRIX_APPLY; + loc.end = @result_ty.end; + auto inode = std::make_unique($var, $result_ty, loc); + ctx.push_scope(); + auto &row = inode->row(); + ctx.val($row, row, @row); + auto &col = inode->col(); + ctx.val($col, col, @col); + auto &val = inode->val(); + ctx.val($val, val, @val); + ctx.push_region(&inode->body()); + $$ = inst{inode.release()}; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + }[apply_header] region { + ctx.pop_region(); + ctx.pop_scope(); + $$ = std::move($apply_header); + } +; + +cooperative_matrix_extract_inst: + COOPERATIVE_MATRIX_EXTRACT var[mat] LSQBR INTEGER_CONSTANT[index] RSQBR COLON data_type[ty] { + try { + $$ = inst { + std::make_unique(std::move($mat), $index, std::move($ty), + @cooperative_matrix_extract_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; -expand_shape: - constant_or_dynamic_or_identifier[a] TIMES constant_or_dynamic_or_identifier[b] { - $$ = std::vector{$a, $b}; +cooperative_matrix_insert_inst: + COOPERATIVE_MATRIX_INSERT var[val] COMMA var[mat] LSQBR INTEGER_CONSTANT[index] RSQBR COLON data_type[ty] { + try { + $$ = inst { + std::make_unique(std::move($val), std::move($mat), $index, + std::move($ty), + @cooperative_matrix_insert_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } } - | expand_shape TIMES constant_or_dynamic_or_identifier[a] { $$ = std::move($1); $$.push_back($a); } ; -constant_or_dynamic_or_identifier: - var { - check_scalar_type($var, scalar_type::index, @var, @var); - $$ = $var; +cooperative_matrix_load_inst: + COOPERATIVE_MATRIX_LOAD transpose checked var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[result_ty] { + try { + $$ = inst { + std::make_unique( + $transpose, $checked, std::move($op), std::move($p0), std::move($p1), + std::move($result_ty), @cooperative_matrix_load_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } } - | INTEGER_CONSTANT { $$ = make_index($INTEGER_CONSTANT); $$->loc(@INTEGER_CONSTANT); } - | DYNAMIC { $$ = make_dynamic(); $$->loc(@DYNAMIC); } ; -fuse_inst: - FUSE var LSQBR INTEGER_CONSTANT[from] COMMA INTEGER_CONSTANT[to] RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); +checked: + %empty { $$ = checked_flag::none; } + | CHECKED { $$ = $CHECKED; } +; + +cooperative_matrix_mul_add_inst: + COOPERATIVE_MATRIX_MUL_ADD var[a] COMMA var[b] COMMA var[c] COLON data_type[to_ty] { + try { + $$ = inst { + std::make_unique(std::move($a), std::move($b), + std::move($c), std::move($to_ty), + @cooperative_matrix_mul_add_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; } + } +; + +cooperative_matrix_prefetch_inst: + COOPERATIVE_MATRIX_PREFETCH INTEGER_CONSTANT[cache_level] COMMA var[op] LSQBR var[p0] COMMA var[p1] RSQBR COMMA INTEGER_CONSTANT[rows] COMMA INTEGER_CONSTANT[cols] { try { $$ = inst { - std::make_unique(std::move($var), $from, $to, @fuse_inst).release() + std::make_unique($cache_level, std::move($op), + std::move($p0), std::move($p1), $rows, + $cols, @cooperative_matrix_prefetch_inst) + .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; -load_inst: - LOAD var LSQBR optional_index_list RSQBR COLON memref_or_group_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_or_group_type)) { - auto loc = @var; - loc.end = @memref_or_group_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); +cooperative_matrix_reduce_inst: + COOPERATIVE_MATRIX_REDUCE var[a] COLON data_type[ty] { + try { + $$ = inst { + std::make_unique( + $COOPERATIVE_MATRIX_REDUCE.first, $COOPERATIVE_MATRIX_REDUCE.second, std::move($a), + std::move($ty), @cooperative_matrix_reduce_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; } + } +; + +cooperative_matrix_scale_inst: + COOPERATIVE_MATRIX_SCALE var[a] COMMA var[b] COLON data_type[ty] { try { $$ = inst { - std::make_unique(std::move($var), std::move($optional_index_list), - @load_inst) + std::make_unique( + std::move($a), std::move($b), std::move($ty), @cooperative_matrix_scale_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; -optional_index_list: - %empty {} - | index_list { $$ = std::move($1); } +cooperative_matrix_store_inst: + COOPERATIVE_MATRIX_STORE checked store_flag var[val] COMMA var[op] LSQBR var[p0] COMMA var[p1] RSQBR { + try { + $$ = inst { + std::make_unique( + $checked, $store_flag, std::move($val), std::move($op), std::move($p0), std::move($p1), + @cooperative_matrix_store_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } ; -index_list: - index_identifier_or_const { $$.push_back($index_identifier_or_const); } - | index_list COMMA index_identifier_or_const { $$ = std::move($1); $$.push_back($index_identifier_or_const); } +expand_inst: + EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] ARROW expand_shape RSQBR COLON memref_type[ty] { + try { + auto static_shape = std::vector{}; + static_shape.reserve($expand_shape.size()); + auto dynamic_shape = std::vector{}; + dynamic_shape.reserve($expand_shape.size()); + for (auto &s : $expand_shape) { + std::visit(overloaded{ + [&](std::int64_t i) { static_shape.push_back(i); }, + [&](tinytc_value_t v) { + static_shape.push_back(dynamic); + dynamic_shape.push_back(v); + }, + }, s); + } + $$ = inst { + std::make_unique(std::move($var), $expanded_mode, std::move(static_shape), + std::move(dynamic_shape), $ty, @expand_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } catch (std::exception const& e) { + error(@expand_inst, e.what()); + } + } ; -index_identifier_or_const: +expand_shape: + integer_constant_or_identifier[a] TIMES integer_constant_or_identifier[b] { + $$ = std::vector{$a, $b}; + } + | expand_shape TIMES integer_constant_or_identifier[a] { $$ = std::move($1); $$.push_back($a); } +; + +integer_constant_or_identifier: var { - check_scalar_type($var, scalar_type::index, @var, @var); $$ = $var; } | INTEGER_CONSTANT { - $$ = make_index($INTEGER_CONSTANT); - $$->loc(@INTEGER_CONSTANT); + $$ = $INTEGER_CONSTANT; } ; -store_inst: - STORE var[a] COMMA var[b] LSQBR optional_index_list RSQBR COLON memref_type { - if (!$b->ty() || !is_equal(*$b->ty(), *$memref_type)) { - auto loc = @b; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); +fuse_inst: + FUSE var LSQBR INTEGER_CONSTANT[from] COMMA INTEGER_CONSTANT[to] RSQBR COLON memref_type[ty] { + try { + $$ = inst { + std::make_unique(std::move($var), $from, $to, $ty, @fuse_inst).release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; } + } +; + +load_inst: + LOAD var LSQBR optional_value_list RSQBR COLON data_type { try { $$ = inst { - std::make_unique(std::move($a), std::move($b), - std::move($optional_index_list), @store_inst) + std::make_unique(std::move($var), std::move($optional_value_list), + std::move($data_type), @load_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; -group_id_inst: - GROUP_ID { $$ = inst{std::make_unique().release()}; } +store_inst: + STORE store_flag var[a] COMMA var[b] LSQBR optional_value_list RSQBR { + try { + $$ = inst { + std::make_unique($store_flag, std::move($a), std::move($b), + std::move($optional_value_list), @store_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } ; -group_size_inst: - GROUP_SIZE { $$ = inst{std::make_unique().release()}; } +store_flag: + %empty { $$ = store_flag::regular; } + | ATOMIC { $$ = store_flag::atomic; } + | ATOMIC_ADD { $$ = store_flag::atomic_add; } + | ATOMIC_MAX { $$ = store_flag::atomic_max; } + | ATOMIC_MIN { $$ = store_flag::atomic_min; } ; if_inst: - IF identifier_or_constant[condition] optional_returned_values region else_region { - check_scalar_type($condition, scalar_type::i1, @condition, @condition); - $$ = inst{std::make_unique(std::move($condition), std::move($region), - std::move($else_region), - std::move($optional_returned_values)) - .release()}; - $$->loc(@if_inst); + IF var[condition] optional_returned_values { + try { + auto loc = @IF; + loc.end = @optional_returned_values.end; + auto inode = std::make_unique(std::move($condition), + std::move($optional_returned_values), loc); + ctx.push_region(&inode->then()); + $$ = std::move(inode); + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + }[header] region { + ctx.pop_region(); + ctx.push_region(&$header->otherwise()); + } else_region { + ctx.pop_region(); + $$ = inst{$header.release()}; } ; else_region: - %empty { $$ = {}; } - | ELSE region{ $$ = std::move($region); } + %empty {} + | ELSE region {} ; optional_returned_values: %empty { $$ = {}; } - | RETURNS LPAREN optional_scalar_type_list[tys] RPAREN { $$ = std::move($tys); } + | ARROW LPAREN optional_return_type_list[tys] RPAREN { $$ = std::move($tys); } ; -optional_scalar_type_list: +optional_return_type_list: %empty {} - | scalar_type_list { $$ = std::move($1); } + | return_type_list { $$ = std::move($1); } +; + +return_type_list: + data_type { $$.push_back($data_type); } + | return_type_list COMMA data_type { + $$ = std::move($1); $$.push_back($data_type); + } ; -scalar_type_list: - scalar_type { $$.push_back($scalar_type); } - | scalar_type_list COMMA scalar_type { $$ = std::move($1); $$.push_back($scalar_type); } +math_unary_inst: + MATH_UNARY var[a] COLON data_type[ty] { + try { + $$ = inst { + std::make_unique($MATH_UNARY, std::move($a), std::move($ty), + @math_unary_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } ; +parallel_inst: + PARALLEL { + try { + auto inode = std::make_unique(@PARALLEL); + ctx.push_region(&inode->body()); + $$ = inst{inode.release()}; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + }[header] region { + ctx.pop_region(); + $$ = std::move($header); + } +; size_inst: - SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); + SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON scalar_type { + try { + $$ = inst { + std::make_unique(std::move($var), $mode, $scalar_type, @size_inst).release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; + } + } +; + +subgroup_broadcast_inst: + SUBGROUP_BROADCAST var[a] COMMA var[idx] COLON scalar_type { + try { + $$ = inst { + std::make_unique(std::move($a), std::move($idx), $scalar_type, + @subgroup_broadcast_inst) + .release() + }; + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + YYERROR; } + } +; + +subgroup_operation_inst: + SUBGROUP_OPERATION var[a] COLON scalar_type { try { - $$ = inst { std::make_unique(std::move($var), $mode, @size_inst).release() }; + $$ = inst { + std::make_unique($SUBGROUP_OPERATION.first, + $SUBGROUP_OPERATION.second, std::move($a), + $scalar_type, @subgroup_operation_inst) + .release() + }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; } } ; subview_inst: - SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } + SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type[ty] { try { + auto static_offsets = std::vector{}; + auto static_sizes = std::vector{}; + auto offsets = std::vector{}; + auto sizes = std::vector{}; + static_offsets.reserve($optional_slice_list.size()); + static_sizes.reserve($optional_slice_list.size()); + offsets.reserve($optional_slice_list.size()); + sizes.reserve($optional_slice_list.size()); + for (auto &s : $optional_slice_list) { + std::visit(overloaded{ + [&](std::int64_t i) { static_offsets.push_back(i); }, + [&](tinytc_value_t v) { + static_offsets.push_back(dynamic); + offsets.push_back(v); + }, + }, + s.first); + std::visit(overloaded{ + [&](std::int64_t i) { static_sizes.push_back(i); }, + [&](tinytc_value_t v) { + static_sizes.push_back(dynamic); + sizes.push_back(v); + }, + }, + s.second); + } $$ = inst { - std::make_unique(std::move($var), std::move($optional_slice_list), - @subview_inst) + std::make_unique( + std::move($var), std::move(static_offsets), std::move(static_sizes), std::move(offsets), + std::move(sizes), std::move($ty), @subview_inst) .release() }; } catch (compilation_error const &e) { - error(e.loc(), e.what()); + report_error(ctx.cctx(), e); YYERROR; + } catch (std::exception const &e) { + error(@subview_inst, e.what()); } } ; @@ -954,25 +1350,28 @@ optional_slice_list: ; slice_list: - slice { $$.push_back($slice); } - | slice_list COMMA slice { $$ = std::move($1); $$.push_back($slice); } + slice { + $$.emplace_back(std::move($slice)); + } + | slice_list COMMA slice { + $$ = std::move($1); + $$.emplace_back(std::move($slice)); + } ; slice: - COLON { $$ = slice(make_index(0), make_dynamic()); } - | index_identifier_or_const slice_size { $$ = slice(std::move($1), std::move($2)); } + integer_constant_or_identifier slice_size { $$ = std::make_pair(std::move($1), std::move($2)); } ; slice_size: %empty { $$ = {}; } - | COLON index_identifier_or_const { $$ = $2; } - | COLON DYNAMIC { $$ = make_dynamic(); } + | COLON integer_constant_or_identifier { $$ = $2; } ; %% namespace tinytc { void parser::error(location_type const& l, std::string const& m) { - ctx.add_error(l, m); + ctx.report_error(l, m); } } diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp new file mode 100644 index 00000000..1506c6d7 --- /dev/null +++ b/src/pass/check_ir.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/check_ir.hpp" +#include "error.hpp" +#include "node/value_node.hpp" +#include "support/walk.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist_base.hpp" +#include "util/visit.hpp" + +#include +#include + +namespace tinytc { + +void check_ir_pass::check_yield(region_node const ®, inst_node const &in, + status yield_missing_status) { + auto last_inst = --reg.end(); + if (last_inst == reg.end()) { + throw compilation_error(reg.loc(), yield_missing_status); + } + auto yield = dyn_cast(last_inst.get()); + if (!yield) { + throw compilation_error(reg.loc(), yield_missing_status); + } + if (yield->num_operands() != in.num_results()) { + throw compilation_error(yield->loc(), status::ir_yield_mismatch); + } + for (std::int64_t i = 0; i < in.num_results(); ++i) { + if (yield->op(i).ty() != in.result(i).ty()) { + throw compilation_error(yield->loc(), {&yield->op(i)}, status::ir_yield_mismatch); + } + } +} + +void check_ir_pass::operator()(inst_node const &) {} +void check_ir_pass::operator()(for_inst const &in) { + if (in.num_results() > 0) { + check_yield(in.body(), in); + } +} +void check_ir_pass::operator()(if_inst const &in) { + if (in.num_results() > 0) { + check_yield(in.then(), in); + check_yield(in.otherwise(), in, status::ir_yield_in_else_branch_missing); + } +} + +void check_ir_pass::run_on_function(function_node &fn) { + walk(fn, [this](inst_node const &i, walk_stage const &stage) { + const bool child_region_is_spmd_region = + i.num_child_regions() > 0 && i.child_region(0).kind() == region_kind::spmd; + + if (stage.is_before_all_regions()) { + if (i.kind() == inst_execution_kind::collective && inside_spmd_region_) { + throw compilation_error(i.loc(), status::ir_collective_called_from_spmd); + } else if (i.kind() == inst_execution_kind::spmd && !inside_spmd_region_) { + throw compilation_error(i.loc(), status::ir_spmd_called_from_collective); + } + + if (child_region_is_spmd_region) { + inside_spmd_region_ = true; + } + } + + if (child_region_is_spmd_region && stage.is_after_all_regions()) { + inside_spmd_region_ = false; + } + + visit(*this, i); + }); +} + +} // namespace tinytc diff --git a/src/visitor/check_ir.hpp b/src/pass/check_ir.hpp similarity index 54% rename from src/visitor/check_ir.hpp rename to src/pass/check_ir.hpp index da135768..0d44fecf 100644 --- a/src/visitor/check_ir.hpp +++ b/src/pass/check_ir.hpp @@ -6,30 +6,23 @@ #include "node/function_node.hpp" #include "node/inst_node.hpp" -#include "node/program_node.hpp" #include "node/region_node.hpp" +#include "tinytc/types.hpp" namespace tinytc { -class ir_checker { +class check_ir_pass { public: - /* Stmt nodes */ void operator()(inst_node const &in); - void operator()(for_inst const &p); - void operator()(foreach_inst const &p); + void operator()(for_inst const &in); void operator()(if_inst const &in); - /* Region nodes */ - void operator()(rgn const &b); - - /* Func nodes */ - void operator()(prototype const &); - void operator()(function const &fn); - - /* Program nodes */ - void operator()(program const &p); + void run_on_function(function_node &fn); private: + void check_yield(region_node const ®, inst_node const &in, + status yield_missing_status = status::ir_must_have_yield); + bool inside_spmd_region_ = false; }; diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp new file mode 100644 index 00000000..beda7c4c --- /dev/null +++ b/src/pass/clone.cpp @@ -0,0 +1,231 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/clone.hpp" +#include "node/data_type_node.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include +#include +#include + +namespace tinytc { + +void inst_cloner::reset_subs() { subs_map_.clear(); } +void inst_cloner::set_subs(tinytc_value_t in_val, tinytc_value_t out_val) { + subs_map_[in_val] = out_val; +} +auto inst_cloner::subs(tinytc_value_t val) -> tinytc_value_t { + if (auto it = subs_map_.find(val); it != subs_map_.end()) { + return it->second; + } + return val; +} + +auto inst_cloner::operator()(alloca_inst &in) -> std::unique_ptr { + return std::make_unique(in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(axpby_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), subs(&in.alpha()), subs(&in.A()), subs(&in.beta()), + subs(&in.B()), in.atomic(), in.loc()); +} +auto inst_cloner::operator()(arith_inst &in) -> std::unique_ptr { + return std::make_unique(in.operation(), subs(&in.a()), subs(&in.b()), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(arith_unary_inst &in) -> std::unique_ptr { + return std::make_unique(in.operation(), subs(&in.a()), in.result(0).ty(), + in.loc()); +} +auto inst_cloner::operator()(barrier_inst &in) -> std::unique_ptr { + return std::make_unique(in.fence_flags(), in.loc()); +} +auto inst_cloner::operator()(builtin_inst &in) -> std::unique_ptr { + return std::make_unique(in.builtin_type(), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cast_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.a()), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(compare_inst &in) -> std::unique_ptr { + return std::make_unique(in.cond(), subs(&in.a()), subs(&in.b()), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(constant_inst &in) -> std::unique_ptr { + return std::make_unique(in.value(), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_apply_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.a()), in.result(0).ty(), + in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_extract_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.mat()), in.index(), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_insert_inst &in) -> std::unique_ptr { + return std::make_unique( + subs(&in.val()), subs(&in.mat()), in.index(), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_load_inst &in) -> std::unique_ptr { + return std::make_unique(in.t(), in.checked(), subs(&in.operand()), + subs(&in.pos0()), subs(&in.pos1()), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_mul_add_inst &in) -> std::unique_ptr { + return std::make_unique( + subs(&in.a()), subs(&in.b()), subs(&in.c()), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_prefetch_inst &in) -> std::unique_ptr { + return std::make_unique(in.cache_level(), subs(&in.operand()), + subs(&in.pos0()), subs(&in.pos1()), + in.rows(), in.cols(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_reduce_inst &in) -> std::unique_ptr { + return std::make_unique(in.arith(), in.mode(), subs(&in.a()), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_scale_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.a()), subs(&in.b()), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(cooperative_matrix_store_inst &in) -> std::unique_ptr { + return std::make_unique(in.checked(), in.flag(), subs(&in.val()), + subs(&in.operand()), subs(&in.pos0()), + subs(&in.pos1()), in.loc()); +} +auto inst_cloner::operator()(cumsum_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.alpha()), subs(&in.A()), in.mode(), + subs(&in.beta()), subs(&in.B()), in.atomic(), in.loc()); +} +auto inst_cloner::operator()(expand_inst &in) -> std::unique_ptr { + return std::make_unique( + subs(&in.operand()), in.expanded_mode(), in.static_expand_shape(), + subs_value_range(in.expand_shape()), in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(fuse_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), in.from(), in.to(), in.result(0).ty(), + in.loc()); +} + +auto inst_cloner::operator()(lifetime_stop_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.object()), in.loc()); +} +auto inst_cloner::operator()(load_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), subs_value_range(in.index_list()), + in.result(0).ty(), in.loc()); +} +auto inst_cloner::operator()(gemm_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), in.tB(), subs(&in.alpha()), subs(&in.A()), + subs(&in.B()), subs(&in.beta()), subs(&in.C()), in.atomic(), + in.loc()); +} + +auto inst_cloner::operator()(gemv_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), subs(&in.alpha()), subs(&in.A()), subs(&in.B()), + subs(&in.beta()), subs(&in.C()), in.atomic(), in.loc()); +} + +auto inst_cloner::operator()(ger_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.alpha()), subs(&in.A()), subs(&in.B()), + subs(&in.beta()), subs(&in.C()), in.atomic(), in.loc()); +} +auto inst_cloner::operator()(for_inst &in) -> std::unique_ptr { + auto return_types = std::vector(in.num_results()); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + return_types[i] = in.result(0).ty(); + } + return std::make_unique(in.body().param(0).ty(), subs(&in.from()), subs(&in.to()), + in.has_step() ? subs(&in.step()) : nullptr, + subs_value_range(in.iter_init()), return_types, in.loc()); +} + +auto inst_cloner::operator()(foreach_inst &in) -> std::unique_ptr { + return std::make_unique(in.body().param(0).ty(), subs_value_range(in.from()), + subs_value_range(in.to()), in.loc()); +} + +auto inst_cloner::operator()(hadamard_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.alpha()), subs(&in.A()), subs(&in.B()), + subs(&in.beta()), subs(&in.C()), in.atomic(), in.loc()); +} + +auto inst_cloner::operator()(if_inst &in) -> std::unique_ptr { + auto return_types = std::vector(in.num_results()); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + return_types[i] = in.result(i).ty(); + } + return std::make_unique(subs(&in.condition()), return_types, in.loc()); +} + +auto inst_cloner::operator()(math_unary_inst &in) -> std::unique_ptr { + return std::make_unique(in.operation(), subs(&in.a()), in.result(0).ty(), + in.loc()); +} + +auto inst_cloner::operator()(parallel_inst &in) -> std::unique_ptr { + return std::make_unique(in.loc()); +} + +auto inst_cloner::operator()(size_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.operand()), in.mode(), in.result(0).ty(), in.loc()); +} + +auto inst_cloner::operator()(subgroup_broadcast_inst &in) -> std::unique_ptr { + return std::make_unique(subs(&in.a()), subs(&in.idx()), + in.result(0).ty(), in.loc()); +} + +auto inst_cloner::operator()(subgroup_operation_inst &in) -> std::unique_ptr { + return std::make_unique(in.arith(), in.operation(), subs(&in.a()), + in.result(0).ty(), in.loc()); +} + +auto inst_cloner::operator()(subview_inst &in) -> std::unique_ptr { + return std::make_unique( + subs(&in.operand()), in.static_offsets(), in.static_sizes(), subs_value_range(in.offsets()), + subs_value_range(in.sizes()), in.result(0).ty(), in.loc()); +} + +auto inst_cloner::operator()(store_inst &in) -> std::unique_ptr { + return std::make_unique(in.flag(), subs(&in.val()), subs(&in.operand()), + subs_value_range(in.index_list()), in.loc()); +} + +auto inst_cloner::operator()(sum_inst &in) -> std::unique_ptr { + return std::make_unique(in.tA(), subs(&in.alpha()), subs(&in.A()), subs(&in.beta()), + subs(&in.B()), in.atomic(), in.loc()); +} + +auto inst_cloner::operator()(yield_inst &in) -> std::unique_ptr { + return std::make_unique(subs_value_range(std::views::all(in.operands())), in.loc()); +} + +auto inst_cloner::clone_instruction(inst_node &in) -> std::unique_ptr { + auto cloned = visit(*this, in); + for (auto res_orig = in.result_begin(), res_cloned = cloned->result_begin(); + res_orig != in.result_end() && res_cloned != cloned->result_end(); + ++res_orig, ++res_cloned) { + set_subs(&(*res_orig), &(*res_cloned)); + } + for (auto reg_orig = in.child_regions_begin(), reg_cloned = cloned->child_regions_begin(); + reg_orig != in.child_regions_end() && reg_cloned != cloned->child_regions_end(); + ++reg_orig, ++reg_cloned) { + for (auto p_orig = reg_orig->param_begin(), p_cloned = reg_cloned->param_begin(); + p_orig != reg_orig->param_end() && p_cloned != reg_cloned->param_end(); + ++p_orig, ++p_cloned) { + set_subs(&(*p_orig), &(*p_cloned)); + } + clone_region(*reg_orig, *reg_cloned); + } + return cloned; +} + +void inst_cloner::clone_region(region_node &source, region_node &target) { + for (auto &in_orig : source.insts()) { + target.insts().push_back(clone_instruction(in_orig).release()); + } +} + +} // namespace tinytc diff --git a/src/pass/clone.hpp b/src/pass/clone.hpp new file mode 100644 index 00000000..c48df7ac --- /dev/null +++ b/src/pass/clone.hpp @@ -0,0 +1,82 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CLONE_20241118_HPP +#define CLONE_20241118_HPP + +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "tinytc/types.h" + +#include +#include +#include + +namespace tinytc { + +class inst_cloner { + public: + auto operator()(alloca_inst &in) -> std::unique_ptr; + auto operator()(axpby_inst &in) -> std::unique_ptr; + auto operator()(arith_inst &in) -> std::unique_ptr; + auto operator()(arith_unary_inst &in) -> std::unique_ptr; + auto operator()(barrier_inst &in) -> std::unique_ptr; + auto operator()(builtin_inst &in) -> std::unique_ptr; + auto operator()(cast_inst &in) -> std::unique_ptr; + auto operator()(compare_inst &in) -> std::unique_ptr; + auto operator()(constant_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_apply_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_extract_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_insert_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_load_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_mul_add_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_prefetch_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_reduce_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_scale_inst &in) -> std::unique_ptr; + auto operator()(cooperative_matrix_store_inst &in) -> std::unique_ptr; + auto operator()(cumsum_inst &in) -> std::unique_ptr; + auto operator()(expand_inst &in) -> std::unique_ptr; + auto operator()(fuse_inst &in) -> std::unique_ptr; + auto operator()(load_inst &in) -> std::unique_ptr; + auto operator()(lifetime_stop_inst &in) -> std::unique_ptr; + auto operator()(gemm_inst &in) -> std::unique_ptr; + auto operator()(gemv_inst &in) -> std::unique_ptr; + auto operator()(ger_inst &in) -> std::unique_ptr; + auto operator()(for_inst &in) -> std::unique_ptr; + auto operator()(foreach_inst &in) -> std::unique_ptr; + auto operator()(hadamard_inst &in) -> std::unique_ptr; + auto operator()(if_inst &in) -> std::unique_ptr; + auto operator()(math_unary_inst &in) -> std::unique_ptr; + auto operator()(parallel_inst &in) -> std::unique_ptr; + auto operator()(size_inst &in) -> std::unique_ptr; + auto operator()(subgroup_broadcast_inst &in) -> std::unique_ptr; + auto operator()(subgroup_operation_inst &in) -> std::unique_ptr; + auto operator()(subview_inst &in) -> std::unique_ptr; + auto operator()(store_inst &in) -> std::unique_ptr; + auto operator()(sum_inst &in) -> std::unique_ptr; + auto operator()(yield_inst &in) -> std::unique_ptr; + + void reset_subs(); + void set_subs(tinytc_value_t in_val, tinytc_value_t out_val); + auto subs(tinytc_value_t val) -> tinytc_value_t; + + auto clone_instruction(inst_node &in) -> std::unique_ptr; + void clone_region(region_node &source, region_node &target); + + private: + template auto subs_value_range(T &&range) { + auto vec = std::vector(); + vec.reserve(range.size()); + for (auto &r : range) { + vec.emplace_back(subs(&r)); + } + return vec; + } + + std::unordered_map subs_map_; +}; + +} // namespace tinytc + +#endif // CLONE_20241118_HPP diff --git a/src/pass/constant_folding.cpp b/src/pass/constant_folding.cpp new file mode 100644 index 00000000..fef81a44 --- /dev/null +++ b/src/pass/constant_folding.cpp @@ -0,0 +1,344 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/constant_folding.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "scalar_type.hpp" +#include "tinytc/tinytc.hpp" +#include "util/casting.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +template class unary_op_dispatcher { + private: + scalar_type switch_ty; + F computer; + + public: + unary_op_dispatcher(scalar_type sw_ty, F &&f) + : switch_ty{sw_ty}, computer{std::forward(f)} {} + + auto operator()(bool const &) -> fold_result { + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + } + auto operator()(std::int64_t const &A) -> fold_result { + switch (switch_ty) { + case scalar_type::i8: + return computer.template operator()(A); + case scalar_type::i16: + return computer.template operator()(A); + case scalar_type::i32: + return computer.template operator()(A); + case scalar_type::i64: + return computer.template operator()(A); + case scalar_type::index: + return computer.template operator()(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + }; + } + auto operator()(double const &A) -> fold_result { + switch (switch_ty) { + case scalar_type::bf16: + return computer.template operator()(A); + case scalar_type::f16: + return computer.template operator()(A); + case scalar_type::f32: + return computer.template operator()(A); + case scalar_type::f64: + return computer.template operator()(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + auto operator()(std::complex const &A) -> fold_result { + switch (switch_ty) { + case scalar_type::c32: + return computer.template operator()>(A); + case scalar_type::c64: + return computer.template operator()>(A); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } +}; + +template class binary_op_dispatcher { + private: + scalar_type switch_ty; + F computer; + + public: + binary_op_dispatcher(scalar_type sw_ty, F &&f) + : switch_ty{sw_ty}, computer{std::forward(f)} {} + + auto operator()(std::int64_t const &A, std::int64_t const &B) -> fold_result { + switch (switch_ty) { + case scalar_type::i8: + return computer.template operator()(A, B); + case scalar_type::i16: + return computer.template operator()(A, B); + case scalar_type::i32: + return computer.template operator()(A, B); + case scalar_type::i64: + return computer.template operator()(A, B); + case scalar_type::index: + return computer.template operator()(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + }; + } + auto operator()(double const &A, double const &B) -> fold_result { + switch (switch_ty) { + case scalar_type::bf16: + return computer.template operator()(A, B); + case scalar_type::f16: + return computer.template operator()(A, B); + case scalar_type::f32: + return computer.template operator()(A, B); + case scalar_type::f64: + return computer.template operator()(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + auto operator()(std::complex const &A, std::complex const &B) -> fold_result { + switch (switch_ty) { + case scalar_type::c32: + return computer.template operator()>(A, B); + case scalar_type::c64: + return computer.template operator()>(A, B); + default: + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + break; + } + } + template auto operator()(T const &, U const &) -> fold_result { + throw compilation_error(computer.loc, status::ir_scalar_mismatch); + } +}; + +constant_folding::constant_folding(bool unsafe_fp_math) : unsafe_fp_math_(unsafe_fp_math) {} + +auto constant_folding::get_memref_type(value_node const &v) const -> const memref_data_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return t; +} + +auto constant_folding::operator()(inst_node &) -> fold_result { return tinytc_value_t{}; } + +auto constant_folding::operator()(arith_inst &in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + constant_inst *b_const = dyn_cast(op_b.defining_inst()); + + if (isa(*op_a.ty())) { + if ((a_const && !std::holds_alternative(a_const->value())) || + (b_const && !std::holds_alternative(b_const->value()))) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + if (a_const != nullptr && b_const != nullptr) { + return compute_binary_op{in.operation(), op_a.ty(), in.loc()}( + std::get(a_const->value()), std::get(b_const->value())); + } else if (a_const != nullptr) { + return compute_binop_identities{unsafe_fp_math_, in.operation(), op_b, true, + in.loc()}(std::get(a_const->value())); + } else if (b_const != nullptr) { + return compute_binop_identities{unsafe_fp_math_, in.operation(), op_a, false, + in.loc()}(std::get(b_const->value())); + } + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_scalar_or_boolean); + } + at = dyn_cast(ct->ty()); + } + + if (a_const != nullptr && b_const != nullptr) { + auto computer = compute_binary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); + } else if (a_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.operation(), op_b, true, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); + } else if (b_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.operation(), op_a, false, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), b_const->value()); + } + return tinytc_value_t{}; +} + +auto constant_folding::operator()(arith_unary_inst &in) -> fold_result { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const == nullptr) { + return tinytc_value_t{}; + } + + if (isa(*op_a.ty())) { + if (!std::holds_alternative(a_const->value())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + return compute_unary_op{in.operation(), op_a.ty(), + in.loc()}(std::get(a_const->value())); + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_scalar); + } + at = dyn_cast(ct->ty()); + } + + auto computer = compute_unary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); +} + +auto constant_folding::operator()(cast_inst &in) -> fold_result { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const == nullptr) { + return tinytc_value_t{}; + } + + auto rt = dyn_cast(in.result(0).ty()); + if (rt == nullptr) { + // Cast on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(in.result(0).ty()); + if (ct == nullptr) { + throw compilation_error(in.result(0).loc(), status::ir_expected_coopmatrix_or_scalar); + } + rt = dyn_cast(ct->ty()); + } + + return std::visit( + overloaded{[&](auto A) -> fold_result { return compute_cast(rt, A, in.loc()); }}, + a_const->value()); +} + +auto constant_folding::operator()(compare_inst &in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + constant_inst *b_const = dyn_cast(op_b.defining_inst()); + if (a_const == nullptr || b_const == nullptr) { + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_scalar); + } + + auto computer = compute_compare{in.cond(), in.result(0).ty(), in.loc()}; + auto dispatcher = binary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value(), b_const->value()); +} + +auto constant_folding::operator()(cooperative_matrix_scale_inst &in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_scalar); + } + + if (a_const != nullptr) { + auto computer = + compute_binop_identities{unsafe_fp_math_, arithmetic::mul, op_b, true, in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); + } + return tinytc_value_t{}; +} + +auto constant_folding::operator()(math_unary_inst &in) -> fold_result { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const == nullptr) { + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + return tinytc_value_t{}; + } + + auto computer = compute_math_unary_op{in.operation(), op_a.ty(), in.loc()}; + auto dispatcher = unary_op_dispatcher{at->ty(), std::move(computer)}; + return std::visit(std::move(dispatcher), a_const->value()); +} + +auto constant_folding::operator()(size_inst &in) -> fold_result { + auto mode_size = visit( + overloaded{[&](group_data_type const &g) -> std::int64_t { return g.size(); }, + [&](memref_data_type const &m) -> std::int64_t { return m.shape(in.mode()); }, + [&](auto const &) -> std::int64_t { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + }}, + *in.operand().ty()); + + if (!is_dynamic_value(mode_size)) { + return make_constant( + mode_size, scalar_data_type::get(in.operand().context(), scalar_type::index), in.loc()); + } + return tinytc_value_t{}; +} + +auto constant_folding::operator()(subgroup_broadcast_inst &in) -> fold_result { + auto &op_a = in.a(); + + constant_inst *a_const = dyn_cast(op_a.defining_inst()); + if (a_const != nullptr) { + return &a_const->result(0); + } + return tinytc_value_t{}; +} + +} // namespace tinytc diff --git a/src/pass/constant_folding.hpp b/src/pass/constant_folding.hpp new file mode 100644 index 00000000..33ff5c75 --- /dev/null +++ b/src/pass/constant_folding.hpp @@ -0,0 +1,561 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONSTANT_FOLDING_HELPER_20241011_HPP +#define CONSTANT_FOLDING_HELPER_20241011_HPP + +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "scalar_type.hpp" +#include "support/fp_util.hpp" // IWYU pragma: keep +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +using fold_result = std::variant; + +class constant_folding { + public: + constant_folding(bool unsafe_fp_math); + + auto operator()(inst_node &) -> fold_result; + auto operator()(arith_inst &) -> fold_result; + auto operator()(arith_unary_inst &) -> fold_result; + auto operator()(cooperative_matrix_scale_inst &) -> fold_result; + auto operator()(cast_inst &) -> fold_result; + auto operator()(compare_inst &) -> fold_result; + auto operator()(math_unary_inst &) -> fold_result; + auto operator()(size_inst &in) -> fold_result; + auto operator()(subgroup_broadcast_inst &in) -> fold_result; + + private: + auto get_memref_type(value_node const &v) const -> const memref_data_type *; + + bool unsafe_fp_math_; +}; + +struct compute_unary_op { + arithmetic_unary operation; + data_type ty; + location const &loc; + + auto operator()(bool a) -> fold_result { + bool val = false; + switch (operation) { + case arithmetic_unary::not_: + val = !a; + break; + default: + throw compilation_error(loc, status::ir_boolean_unsupported); + } + return make_constant(val, ty, loc); + } + + template + requires(std::is_integral_v && !std::is_same_v) + auto operator()(T a) -> fold_result { + T val = 0; + switch (operation) { + case arithmetic_unary::abs: + val = a < 0 ? -a : a; + break; + case arithmetic_unary::neg: + val = -a; + break; + case arithmetic_unary::not_: + val = ~a; + break; + default: + throw compilation_error(loc, status::ir_int_unsupported); + } + return make_constant(val, ty, loc); + } + + template + requires(is_floating_point_or_lp_float_v) + auto operator()(T a) -> fold_result { + T val = 0; + switch (operation) { + case arithmetic_unary::abs: + val = a < T{0} ? -a : a; + break; + case arithmetic_unary::neg: + val = -a; + break; + default: + throw compilation_error(loc, status::ir_fp_unsupported); + } + return make_constant(val, ty, loc); + } + + template + requires(is_complex_v) + auto operator()(U const &A) -> fold_result { + const auto neg_conj = [&](T const &a) { + T val = {}; + switch (operation) { + case arithmetic_unary::neg: + val = -a; + break; + case arithmetic_unary::conj: + val = std::conj(a); + break; + default: + return inst{nullptr}; + } + return make_constant(val, ty, loc); + }; + const auto abs_im_re = [&](T const &a) -> inst { + typename T::value_type val = {}; + switch (operation) { + case arithmetic_unary::abs: + val = std::abs(a); + break; + case arithmetic_unary::im: + val = std::imag(a); + break; + case arithmetic_unary::re: + val = std::real(a); + break; + default: + return inst{nullptr}; + } + scalar_data_type *sty = dyn_cast(ty); + if (!sty) { + throw compilation_error(loc, status::ir_expected_scalar); + } + auto cst_ty = scalar_data_type::get(sty->context(), component_type(sty->ty())); + return make_constant(val, cst_ty, loc); + }; + + const auto a = static_cast(A); + auto result = neg_conj(a); + if (result) { + return result; + } + result = abs_im_re(a); + if (result) { + return result; + } + throw compilation_error(loc, status::ir_complex_unsupported); + } +}; + +struct compute_binary_op { + arithmetic operation; + data_type ty; + location const &loc; + + auto operator()(bool a, bool b) -> fold_result { + bool val = false; + switch (operation) { + case arithmetic::and_: + val = a && b; + break; + case arithmetic::or_: + val = a || b; + break; + case arithmetic::xor_: + val = a != b; + break; + default: + throw compilation_error(loc, status::ir_boolean_unsupported); + } + return make_constant(val, ty, loc); + } + + template + requires(std::is_integral_v && !std::is_same_v) + auto operator()(T a, T b) -> fold_result { + T val = 0; + switch (operation) { + case arithmetic::add: + val = a + b; + break; + case arithmetic::sub: + val = a - b; + break; + case arithmetic::mul: + val = a * b; + break; + case arithmetic::div: + val = a / b; + break; + case arithmetic::rem: + val = a % b; + break; + case arithmetic::shl: + val = a << b; + break; + case arithmetic::shr: + val = a >> b; + break; + case arithmetic::and_: + val = a & b; + break; + case arithmetic::or_: + val = a | b; + break; + case arithmetic::xor_: + val = a ^ b; + break; + case arithmetic::min: + val = std::min(a, b); + break; + case arithmetic::max: + val = std::max(a, b); + break; + } + return make_constant(val, ty, loc); + } + + template + requires(!std::is_integral_v) + auto operator()(U const &A, U const &B) -> fold_result { + const auto a = static_cast(A); + const auto b = static_cast(B); + T val = {}; + switch (operation) { + case arithmetic::add: + val = a + b; + break; + case arithmetic::sub: + val = a - b; + break; + case arithmetic::mul: + val = a * b; + break; + case arithmetic::div: + val = a / b; + break; + case arithmetic::rem: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } else { + val = std::fmod(a, b); + } + break; + case arithmetic::min: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } else { + val = std::min(a, b); + } + break; + case arithmetic::max: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } else { + val = std::max(a, b); + } + break; + default: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } + throw compilation_error(loc, status::ir_fp_unsupported); + break; + } + return make_constant(val, ty, loc); + } +}; + +struct compute_binop_identities { + bool unsafe_fp_math; + arithmetic operation; + tinytc_value &operand; + bool is_second_operand; + location const &loc; + + auto operator()(bool a) -> fold_result { + switch (operation) { + case arithmetic::and_: + if (!a) { + return make_constant(false, operand.ty(), loc); + } + break; + case arithmetic::or_: + case arithmetic::xor_: + if (!a) { + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } + + template + requires(std::is_integral_v && !std::is_same_v) + auto operator()(T a) -> fold_result { + switch (operation) { + case arithmetic::add: + if (a == T{0}) { // operand + 0 or 0 + operand + return &operand; + } + break; + case arithmetic::sub: + if (a == T{0} && !is_second_operand) { // operand - 0 + return &operand; + } + break; + case arithmetic::mul: + if (a == T{0}) { // operand * 0 or 0 * operand + return make_constant(T{0}, operand.ty(), loc); + } else if (a == T{1}) { // operand * 1 or 1 * operand + return &operand; + } + break; + case arithmetic::div: + if (a == T{1} && !is_second_operand) { // operand / 1 + return &operand; + } + break; + case arithmetic::rem: + if (a == T{1} && !is_second_operand) { // operand % 1 + return make_constant(T{0}, operand.ty(), loc); + } + break; + case arithmetic::shl: + case arithmetic::shr: + if (a == T{0}) { + if (is_second_operand) { // 0 << operand + return make_constant(T{0}, operand.ty(), loc); + } else { // operand << 0 + return &operand; + } + } + break; + case arithmetic::and_: + if (a == T{0}) { + return make_constant(T{0}, operand.ty(), loc); + } + break; + case arithmetic::or_: + case arithmetic::xor_: + if (a == T{0}) { + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } + + template + requires(!std::is_integral_v) + auto operator()(U const &A) -> fold_result { + const auto a = static_cast(A); + switch (operation) { + case arithmetic::add: + if (a == T{0}) { // operand + 0 or 0 + operand + return &operand; + } + break; + case arithmetic::sub: + if (a == T{0} && !is_second_operand) { // operand - 0 + return &operand; + } + break; + case arithmetic::mul: + if (unsafe_fp_math && a == T{0}) { // operand * 0 or 0 * operand + return make_constant(T{0}, operand.ty(), loc); + } else if (a == T{1}) { // operand * 1 or 1 * operand + return &operand; + } + break; + case arithmetic::div: + if (a == T{1} && !is_second_operand) { // operand / 1 + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } +}; + +struct compute_compare { + cmp_condition cond; + data_type ty; + location const &loc; + + template + requires(std::is_integral_v || is_floating_point_or_lp_float_v) + auto operator()(T a, T b) -> fold_result { + bool val = false; + switch (cond) { + case cmp_condition::eq: + val = (a == b); + break; + case cmp_condition::ne: + val = (a != b); + break; + case cmp_condition::gt: + val = (a > b); + break; + case cmp_condition::ge: + val = (a >= b); + break; + case cmp_condition::lt: + val = (a < b); + break; + case cmp_condition::le: + val = (a <= b); + break; + }; + return make_constant(val, ty, loc); + } + + template + auto operator()(std::complex const &A, std::complex const &B) -> fold_result { + const auto a = static_cast(A); + const auto b = static_cast(B); + bool val = false; + switch (cond) { + case cmp_condition::eq: + val = (a == b); + break; + case cmp_condition::ne: + val = (a != b); + break; + default: + throw compilation_error(loc, status::ir_complex_unsupported); + break; + }; + return make_constant(val, ty, loc); + } +}; + +template struct value_cast_impl { + auto operator()(U const &u) { return static_cast(u); } +}; + +template struct value_cast_impl, U> { + auto operator()(U const &u) { return std::complex{static_cast(u), static_cast(0)}; } +}; + +template struct value_cast_impl, std::complex> { + auto operator()(std::complex const &u) { return static_cast>(u); } +}; + +template struct value_cast_impl { + auto operator()(U const &u) { return u != U{}; } +}; + +template struct value_cast_impl> { + auto operator()(std::complex const &) -> bool { throw status::ir_forbidden_cast; } +}; + +template struct value_cast_impl> { + auto operator()(std::complex const &) -> T { throw status::ir_forbidden_cast; } +}; + +template auto value_cast(U const &u) { return value_cast_impl{}(u); } + +template +auto compute_cast(scalar_data_type *to_ty, T A, location const &loc) -> fold_result { + switch (to_ty->ty()) { + case scalar_type::i8: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::i16: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::i32: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::i64: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::index: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::bf16: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::f16: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::f32: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::f64: + return make_constant(value_cast(A), to_ty, loc); + case scalar_type::c32: + return make_constant(value_cast>(A), to_ty, loc); + case scalar_type::c64: + return make_constant(value_cast>(A), to_ty, loc); + }; + return {}; +}; + +struct compute_math_unary_op { + math_unary operation; + data_type ty; + location const &loc; + + template + requires(std::is_integral_v) + auto operator()(T) -> fold_result { + throw compilation_error(loc, status::ir_int_unsupported); + } + + template + requires(is_floating_point_or_lp_float_v) + auto operator()(T a) -> fold_result { + T val = {}; + switch (operation) { + case math_unary::cos: + case math_unary::native_cos: + val = std::cos(a); + break; + case math_unary::sin: + case math_unary::native_sin: + val = std::sin(a); + break; + case math_unary::exp: + case math_unary::native_exp: + val = std::exp(a); + break; + case math_unary::exp2: + case math_unary::native_exp2: + val = std::exp2(a); + break; + default: + throw compilation_error(loc, status::ir_fp_unsupported); + } + return make_constant(val, ty, loc); + } + + template + requires(is_complex_v) + auto operator()(U const &a) -> fold_result { + T val = {}; + switch (operation) { + case math_unary::exp: + case math_unary::native_exp: + val = std::exp(a); + break; + case math_unary::exp2: + case math_unary::native_exp2: + val = std::pow(T{std::complex{2.0, 0.0}}, a); + break; + default: + throw compilation_error(loc, status::ir_complex_unsupported); + } + return make_constant(val, ty, loc); + } +}; + +} // namespace tinytc + +#endif // CONSTANT_FOLDING_HELPER_20241011_HPP diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp new file mode 100644 index 00000000..66b97296 --- /dev/null +++ b/src/pass/constant_propagation.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/constant_propagation.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "pass/constant_folding.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/visit.hpp" + +#include + +namespace tinytc { + +void constant_propagation_pass::run_on_function(function_node &fn) { run_on_region(fn.body()); } + +void constant_propagation_pass::run_on_region(region_node ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + for (auto &subreg : it->child_regions()) { + run_on_region(subreg); + } + + const auto update_uses = [&it](tinytc_value_t with) { + if (it->num_results() != 1) { + throw status::internal_compiler_error; + } + auto r = it->result_begin(); + auto u = r->use_begin(); + while (r->has_uses()) { + u->set(with); + u = r->use_begin(); + } + if (r->has_uses()) { + throw status::internal_compiler_error; + } + }; + + fold_result fr = visit(constant_folding{unsafe_fp_math_}, *it); + std::visit(overloaded{[&](tinytc_value_t val) { + if (val) { + update_uses(val); + } + }, + [&](inst &new_constant) { + if (new_constant) { + if (new_constant->num_results() != 1) { + throw status::internal_compiler_error; + } + update_uses(&*new_constant->result_begin()); + // insert new constant + it = reg.insts().insert(it, new_constant.release()); + // skip over constant + ++it; + } + }}, + fr); + } +} + +void constant_propagation_pass::set_opt_flag(tinytc::optflag flag, bool enabled) { + if (flag == tinytc::optflag::unsafe_fp_math) { + unsafe_fp_math_ = enabled; + } +} + +} // namespace tinytc diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp new file mode 100644 index 00000000..9bda83f3 --- /dev/null +++ b/src/pass/constant_propagation.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONSTANT_PROPAGATION_20240807_HPP +#define CONSTANT_PROPAGATION_20240807_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +enum class optflag; + +class constant_propagation_pass { + public: + void run_on_function(::tinytc_func &fn); + void run_on_region(::tinytc_region ®); + + void set_opt_flag(tinytc::optflag flag, bool enabled); + + private: + bool unsafe_fp_math_ = false; +}; + +} // namespace tinytc + +#endif // CONSTANT_PROPAGATION_20240807_HPP diff --git a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp new file mode 100644 index 00000000..815697ab --- /dev/null +++ b/src/pass/convert_to_spirv.cpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/convert_to_spirv.hpp" +#include "spv/converter.hpp" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { + +convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { + if (info_ == nullptr) { + throw status::invalid_arguments; + } +} + +auto convert_to_spirv_pass::run_on_program(program_node const &p) -> spv_mod { + return spv::convert_prog_to_spirv(p, *info_); +} + +} // namespace tinytc diff --git a/src/pass/convert_to_spirv.hpp b/src/pass/convert_to_spirv.hpp new file mode 100644 index 00000000..6cc3de3d --- /dev/null +++ b/src/pass/convert_to_spirv.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONVERT_TO_SPIRV_20241029_HPP +#define CONVERT_TO_SPIRV_20241029_HPP + +#include "node/program_node.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" + +namespace tinytc { + +class convert_to_spirv_pass { + public: + convert_to_spirv_pass(::tinytc_core_info const *info); + + auto run_on_program(program_node const &p) -> spv_mod; + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // CONVERT_TO_SPIRV_20241029_HPP diff --git a/src/pass/dead_code_elimination.cpp b/src/pass/dead_code_elimination.cpp new file mode 100644 index 00000000..f4e6a0e3 --- /dev/null +++ b/src/pass/dead_code_elimination.cpp @@ -0,0 +1,84 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dead_code_elimination.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/visit.hpp" + +#include +#include + +namespace tinytc { + +class dead_code_analysis { + public: + auto operator()(inst_node &in) -> bool; + auto operator()(if_inst &in) -> bool; + auto operator()(for_inst &in) -> bool; +}; + +auto dead_code_analysis::operator()(inst_node &in) -> bool { + /* Instruction have side effects if either of the following is true + * + * - More than one child region (if, for, foreach, parallel, ...) + * - Instruction does not have results (barrier, GEMM, GER, ...) + * + */ + const bool has_side_effects = in.num_child_regions() > 0 || in.num_results() == 0; + + bool any_result_has_uses = false; + for (auto &res : in.results()) { + any_result_has_uses = any_result_has_uses || res.has_uses(); + } + + return !has_side_effects && !any_result_has_uses; +} + +auto dead_code_analysis::operator()(if_inst &in) -> bool { + constant_inst *cond_const = dyn_cast(in.condition().defining_inst()); + if (in.num_results() == 0 && cond_const) { + // If-instruction is dead if condition is constant and false + return std::holds_alternative(cond_const->value()) && + std::get(cond_const->value()) == false; + } + + return false; +} + +auto dead_code_analysis::operator()(for_inst &in) -> bool { + constant_inst *from_const = dyn_cast(in.from().defining_inst()); + constant_inst *to_const = dyn_cast(in.to().defining_inst()); + if (in.num_results() == 0 && from_const && to_const) { + // For-instruction is dead if from >= to + return std::holds_alternative(from_const->value()) && + std::holds_alternative(to_const->value()) && + std::get(from_const->value()) >= + std::get(to_const->value()); + } + return false; +} + +void dead_code_elimination_pass::run_on_function(function_node &fn) { run_on_region(fn.body()); } + +void dead_code_elimination_pass::run_on_region(region_node ®) { + auto prev_it = reg.end(); + while (prev_it != reg.begin()) { + auto it = --prev_it; + auto is_dead = visit(dead_code_analysis{}, *it); + if (is_dead) { + prev_it = reg.insts().erase(it); + } else { + for (auto &subreg : it->child_regions()) { + run_on_region(subreg); + } + } + } +} + +} // namespace tinytc diff --git a/src/pass/dead_code_elimination.hpp b/src/pass/dead_code_elimination.hpp new file mode 100644 index 00000000..e56d27f8 --- /dev/null +++ b/src/pass/dead_code_elimination.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DEAD_CODE_ELIMINATION_20241007_HPP +#define DEAD_CODE_ELIMINATION_20241007_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class dead_code_elimination_pass { + public: + void run_on_function(::tinytc_func &fn); + void run_on_region(::tinytc_region ®); +}; + +} // namespace tinytc + +#endif // DEAD_CODE_ELIMINATION_20241007_HPP diff --git a/src/pass/dump_cfg.cpp b/src/pass/dump_cfg.cpp new file mode 100644 index 00000000..710e51ec --- /dev/null +++ b/src/pass/dump_cfg.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_cfg.hpp" +#include "analysis/cfg.hpp" +#include "pass/dump_ir.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc { + +dump_cfg_pass::dump_cfg_pass(std::ostream &os) : os_(&os) {} + +void dump_cfg_pass::run_on_function(function_node &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + + *os_ << "digraph " << fn.name() << " {" << std::endl; + + auto cfg = get_control_flow_graph(fn.body()); + auto q = cfg.node_queue(); + for (; !q.empty(); q.pop()) { + auto &node = q.front(); + + *os_ << reinterpret_cast(node) << " [label=\""; + dump_ir.run_on_instruction(*node); + *os_ << "\"]" << std::endl; + + for (auto &neigh : cfg.successors(node)) { + *os_ << reinterpret_cast(node) << " -> " + << reinterpret_cast(neigh) << std::endl; + } + } + + *os_ << "}" << std::endl; +} + +} // namespace tinytc diff --git a/src/pass/dump_cfg.hpp b/src/pass/dump_cfg.hpp new file mode 100644 index 00000000..65d765d8 --- /dev/null +++ b/src/pass/dump_cfg.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_BACKWARD_CFG_20240919_HPP +#define DUMP_BACKWARD_CFG_20240919_HPP + +#include "node/function_node.hpp" + +#include + +namespace tinytc { + +class dump_cfg_pass { + public: + dump_cfg_pass(std::ostream &os); + + void run_on_function(function_node &fn); + + private: + std::ostream *os_; +}; + +} // namespace tinytc + +#endif // DUMP_BACKWARD_CFG_20240919_HPP diff --git a/src/pass/dump_def_use.cpp b/src/pass/dump_def_use.cpp new file mode 100644 index 00000000..285c194e --- /dev/null +++ b/src/pass/dump_def_use.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_def_use.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "pass/dump_ir.hpp" +#include "support/walk.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +dump_def_use_pass::dump_def_use_pass(std::ostream &os) : os_(&os) {} + +void dump_def_use_pass::run_on_function(function_node const &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + dump_ir.init_slot_tracker(fn); + + *os_ << "Def-use in @" << fn.name() << std::endl; + walk(fn, [&](inst_node const &i) { + if (i.num_results() > 0 || i.num_child_regions() > 0) { + *os_ << "> "; + visit(dump_ir, i); + *os_ << std::endl; + auto const def_use = [&](value_node const &v) { + *os_ << " def "; + dump_ir.dump_val(v); + *os_ << std::endl; + for (auto &u : v.uses()) { + *os_ << " > "; + visit(dump_ir, *u.owner()); + *os_ << std::endl; + } + }; + for (auto &res : i.results()) { + def_use(res); + } + for (auto ® : i.child_regions()) { + for (auto &p : reg.params()) { + def_use(p); + } + } + } + }); + *os_ << std::endl; +} + +} // namespace tinytc diff --git a/src/pass/dump_def_use.hpp b/src/pass/dump_def_use.hpp new file mode 100644 index 00000000..ccec17cb --- /dev/null +++ b/src/pass/dump_def_use.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_DEF_USE_20241002_HPP +#define DUMP_DEF_USE_20241002_HPP + +#include "node/function_node.hpp" + +#include + +namespace tinytc { + +class dump_def_use_pass { + public: + dump_def_use_pass(std::ostream &os); + + void run_on_function(function_node const &fn); + + private: + std::ostream *os_; +}; + +} // namespace tinytc + +#endif // DUMP_DEF_USE_20241002_HPP diff --git a/src/pass/dump_gcd.cpp b/src/pass/dump_gcd.cpp new file mode 100644 index 00000000..e9b32c40 --- /dev/null +++ b/src/pass/dump_gcd.cpp @@ -0,0 +1,87 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_gcd.hpp" +#include "analysis/gcd.hpp" +#include "device_info.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "pass/dump_ir.hpp" +#include "support/walk.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc { + +dump_gcd_pass::dump_gcd_pass(std::ostream &os, ::tinytc_core_info const *info) + : os_(&os), info_{info} {} + +void dump_gcd_pass::run_on_function(function_node const &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + dump_ir.init_slot_tracker(fn); + auto gcd = gcd_analysis{info_->alignment()}.run_on_function(fn); + + auto const dump_range = [&](auto begin, auto end) { + *os_ << "["; + for (auto it = begin; it != end; ++it) { + if (it != begin) { + *os_ << ","; + } + *os_ << *it; + } + *os_ << "]"; + }; + auto const dump_gcd = [&](value_node const &v) { + auto g = gcd.get_if(v); + if (g) { + *os_ << " gcd("; + dump_ir.dump_val(v); + *os_ << ") = " << *g << std::endl; + } + auto mi = gcd.get_memref_if(v); + if (mi) { + *os_ << " offset_gcd("; + dump_ir.dump_val(v); + *os_ << ") = " << mi->offset_gcd() << std::endl; + *os_ << " shape_gcd("; + dump_ir.dump_val(v); + *os_ << ") = "; + dump_range(mi->shape_gcd_begin(), mi->shape_gcd_end()); + *os_ << std::endl << " stride_gcd("; + dump_ir.dump_val(v); + *os_ << ") = "; + dump_range(mi->stride_gcd_begin(), mi->stride_gcd_end()); + *os_ << std::endl; + } + }; + + *os_ << "GCD in @" << fn.name() << std::endl; + for (auto &p : fn.params()) { + dump_gcd(p); + } + walk(fn, [&](inst_node const &i) { + if (i.num_results() > 0 || i.num_child_regions() > 0) { + *os_ << "> "; + visit(dump_ir, i); + *os_ << std::endl; + for (auto &res : i.results()) { + dump_gcd(res); + } + for (auto ® : i.child_regions()) { + for (auto &p : reg.params()) { + dump_gcd(p); + } + } + } + }); + *os_ << std::endl; +} + +} // namespace tinytc diff --git a/src/pass/dump_gcd.hpp b/src/pass/dump_gcd.hpp new file mode 100644 index 00000000..86abc4d6 --- /dev/null +++ b/src/pass/dump_gcd.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_GCD_20241203_HPP +#define DUMP_GCD_20241203_HPP + +#include "node/function_node.hpp" +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class dump_gcd_pass { + public: + dump_gcd_pass(std::ostream &os, ::tinytc_core_info const *info); + + void run_on_function(function_node const &fn); + + private: + std::ostream *os_; + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // DUMP_GCD_20241203_HPP diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp new file mode 100644 index 00000000..70c242cf --- /dev/null +++ b/src/pass/dump_ir.cpp @@ -0,0 +1,680 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_ir.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/fnv1a.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +dump_ir_pass::dump_ir_pass(std::ostream &os, int level_limit) : os_(&os), lvl_limit_(level_limit) {} + +/* Attribute nodes */ +void dump_ir_pass::operator()(array_attr const &a) { + *os_ << "["; + do_with_infix(a.begin(), a.end(), [&](auto const &a) { visit(*this, *a); }); + *os_ << "]"; +} +void dump_ir_pass::operator()(boolean_attr const &a) { *os_ << (a.value() ? "true" : "false"); } +void dump_ir_pass::operator()(dictionary_attr const &a) { + auto const is_keyword = [](std::string_view str) { + switch (fnv1a(str)) { + case "alignment"_fnv1a: + case "shape_gcd"_fnv1a: + case "stride_gcd"_fnv1a: + case "subgroup_size"_fnv1a: + case "unroll"_fnv1a: + case "work_group_size"_fnv1a: + return true; + default: + return false; + } + }; + auto const dump_name = [&](attr a) { + if (auto s = dyn_cast(a); s) { + if (is_keyword(s->str())) { + *os_ << s->str(); + } else { + this->operator()(*s); + } + } else { + throw status::ir_expected_string_attribute; + } + }; + *os_ << "{"; + do_with_infix( + a.begin(), a.end(), + [&](auto const &a) { + dump_name(a.name); + *os_ << "="; + visit(*this, *a.attr); + }, + ", "); + *os_ << "}"; +} +void dump_ir_pass::operator()(integer_attr const &a) { *os_ << a.value(); } +void dump_ir_pass::operator()(string_attr const &a) { *os_ << "\"" << a.str() << "\""; } + +/* Data type nodes */ +void dump_ir_pass::operator()(void_data_type const &) { *os_ << "void"; } +void dump_ir_pass::operator()(boolean_data_type const &) { *os_ << "bool"; } +void dump_ir_pass::operator()(coopmatrix_data_type const &ct) { + *os_ << "coopmatrix<"; + visit(*this, *ct.ty()); + *os_ << "x" << ct.rows() << "x" << ct.cols() << "," << to_string(ct.use()) << ">"; +} +void dump_ir_pass::operator()(group_data_type const &g) { + auto const val = [&](std::int64_t v) -> std::ostream & { + if (is_dynamic_value(v)) { + return *os_ << "?"; + } + return *os_ << v; + }; + *os_ << "group<"; + visit(*this, *g.ty()); + *os_ << "x"; + val(g.size()); + if (g.offset() != 0) { + *os_ << ", offset: "; + val(g.offset()); + } + *os_ << ">"; +} +void dump_ir_pass::operator()(memref_data_type const &d) { + auto const val = [&](std::int64_t v) -> std::ostream & { + if (is_dynamic_value(v)) { + return *os_ << "?"; + } + return *os_ << v; + }; + *os_ << "memref<" << to_string(d.element_ty()); + for (auto const &s : d.shape()) { + *os_ << "x"; + val(s); + } + if (!d.is_canonical_stride()) { + *os_ << ",strided<"; + do_with_infix(d.stride().begin(), d.stride().end(), [&](auto const &a) { val(a); }); + *os_ << ">"; + } + if (d.addrspace() != address_space::global) { + *os_ << "," << to_string(d.addrspace()); + } + *os_ << ">"; +} +void dump_ir_pass::operator()(scalar_data_type const &s) { *os_ << to_string(s.ty()); } + +/* Value nodes */ +void dump_ir_pass::dump_val(value_node const &v) { + *os_ << "%" << v.name(); + auto const slot = tracker_.get_slot(v); + if (slot >= 0) { + *os_ << slot; + } +} + +/* Inst nodes */ +void dump_ir_pass::dump_blas_a2(blas_a2_inst const &g) { + if (g.atomic()) { + *os_ << ".atomic"; + } + *os_ << ' '; + dump_val(g.alpha()); + *os_ << ", "; + dump_val(g.A()); + *os_ << ", "; + dump_val(g.beta()); + *os_ << ", "; + dump_val(g.B()); +} + +void dump_ir_pass::dump_blas_a3(blas_a3_inst const &g) { + if (g.atomic()) { + *os_ << ".atomic"; + } + *os_ << ' '; + dump_val(g.alpha()); + *os_ << ", "; + dump_val(g.A()); + *os_ << ", "; + dump_val(g.B()); + *os_ << ", "; + dump_val(g.beta()); + *os_ << ", "; + dump_val(g.C()); +} + +void dump_ir_pass::operator()(alloca_inst const &a) { + dump_val(a.result(0)); + *os_ << " = alloca : "; + visit(*this, *a.result()->ty()); +} + +void dump_ir_pass::operator()(axpby_inst const &a) { + *os_ << "axpby"; + *os_ << "." << to_string(a.tA()); + dump_blas_a2(static_cast(a)); +} + +void dump_ir_pass::operator()(arith_inst const &a) { + dump_val(a.result(0)); + *os_ << " = arith." << to_string(a.operation()) << " "; + dump_val(a.a()); + *os_ << ", "; + dump_val(a.b()); + *os_ << " : "; + visit(*this, *a.result(0).ty()); +} + +void dump_ir_pass::operator()(arith_unary_inst const &a) { + dump_val(a.result(0)); + *os_ << " = arith." << to_string(a.operation()) << " "; + dump_val(a.a()); + *os_ << " : "; + visit(*this, *a.result(0).ty()); +} + +void dump_ir_pass::operator()(barrier_inst const &b) { + *os_ << "barrier"; + if (b.has_fence(address_space::global)) { + *os_ << ".global"; + } + if (b.has_fence(address_space::local)) { + *os_ << ".local"; + } +} + +void dump_ir_pass::operator()(builtin_inst const &in) { + dump_val(in.result(0)); + *os_ << " = builtin." << to_string(in.builtin_type()) << " : "; + visit(*this, *in.result(0).ty()); +} + +void dump_ir_pass::operator()(cast_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cast "; + dump_val(c.a()); + *os_ << " : "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(compare_inst const &a) { + dump_val(a.result(0)); + *os_ << " = cmp." << to_string(a.cond()) << " "; + dump_val(a.a()); + *os_ << ", "; + dump_val(a.b()); + *os_ << " : "; + visit(*this, *a.result(0).ty()); +} + +void dump_ir_pass::operator()(constant_inst const &c) { + dump_val(c.result(0)); + *os_ << " = constant "; + std::visit(overloaded{ + [&](bool b) { *os_ << (b ? "true" : "false"); }, + [&](std::int64_t i) { + if (is_dynamic_value(i)) { + *os_ << "?"; + } else { + *os_ << i; + } + }, + [&](double d) { + auto flags = os_->flags(); + *os_ << std::hexfloat << d; + os_->flags(flags); + }, + [&](std::complex d) { + auto flags = os_->flags(); + *os_ << std::hexfloat << "[" << d.real() << "," << d.imag() << "]"; + os_->flags(flags); + }, + }, + c.value()); + *os_ << " : "; + visit(*this, *c.result()->ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_apply_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_apply ("; + dump_val(c.row()); + *os_ << ","; + dump_val(c.col()); + *os_ << ","; + dump_val(c.val()); + *os_ << ") in "; + dump_val(c.a()); + *os_ << " -> "; + visit(*this, *c.result(0).ty()); + dump_region(c.body()); +} + +void dump_ir_pass::operator()(cooperative_matrix_extract_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_extract "; + dump_val(c.mat()); + *os_ << "[" << c.index() << "] : "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_insert_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_insert "; + dump_val(c.val()); + *os_ << ", "; + dump_val(c.mat()); + *os_ << "[" << c.index() << "] : "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_load_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_load"; + *os_ << "." << to_string(c.t()); + if (c.checked() != checked_flag::none) { + *os_ << "." << to_string(c.checked()); + } + *os_ << " "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "] : "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_mul_add_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_mul_add "; + dump_val(c.a()); + *os_ << ", "; + dump_val(c.b()); + *os_ << ", "; + dump_val(c.c()); + *os_ << " : "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_prefetch_inst const &c) { + *os_ << "cooperative_matrix_prefetch "; + *os_ << c.cache_level(); + *os_ << ", "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "], "; + *os_ << c.rows(); + *os_ << ", "; + *os_ << c.cols(); +} + +void dump_ir_pass::operator()(cooperative_matrix_reduce_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_reduce."; + *os_ << to_string(c.arith()) << "." << to_string(c.mode()) << " "; + dump_val(c.a()); + *os_ << " : "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_scale_inst const &c) { + dump_val(c.result(0)); + *os_ << " = cooperative_matrix_scale "; + dump_val(c.a()); + *os_ << ", "; + dump_val(c.b()); + *os_ << " : "; + visit(*this, *c.result(0).ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_store_inst const &c) { + *os_ << "cooperative_matrix_store"; + if (c.checked() != checked_flag::none) { + *os_ << "." << to_string(c.checked()); + } + if (c.flag() != store_flag::regular) { + *os_ << '.' << to_string(c.flag()); + } + *os_ << " "; + dump_val(c.val()); + *os_ << ", "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "]"; +} + +void dump_ir_pass::operator()(cumsum_inst const &in) { + *os_ << "cumsum"; + if (in.atomic()) { + *os_ << ".atomic"; + } + *os_ << ' '; + dump_val(in.alpha()); + *os_ << ", "; + dump_val(in.A()); + *os_ << ", " << in.mode() << ", "; + dump_val(in.beta()); + *os_ << ", "; + dump_val(in.B()); +} + +void dump_ir_pass::operator()(expand_inst const &e) { + dump_val(e.result(0)); + *os_ << " = expand "; + dump_val(e.operand()); + *os_ << "[" << e.expanded_mode() << "->"; + auto const &ses = e.static_expand_shape(); + auto es = e.expand_shape(); + for (std::size_t i = 0, j = 0; i < ses.size(); ++i) { + if (i != 0) { + *os_ << " x "; + } + if (is_dynamic_value(ses[i])) { + dump_val(es[j++]); + } else { + *os_ << ses[i]; + } + } + *os_ << "] : "; + visit(*this, *e.result(0).ty()); +} + +void dump_ir_pass::operator()(fuse_inst const &f) { + dump_val(f.result(0)); + *os_ << " = fuse "; + dump_val(f.operand()); + *os_ << "[" << f.from() << "," << f.to() << "]"; + *os_ << " : "; + visit(*this, *f.result(0).ty()); +} + +void dump_ir_pass::operator()(load_inst const &e) { + dump_val(e.result(0)); + *os_ << " = load "; + dump_val(e.operand()); + *os_ << "["; + do_with_infix(e.index_list().begin(), e.index_list().end(), + [this](auto const &i) { dump_val(i); }); + *os_ << "] : "; + visit(*this, *e.result(0).ty()); +} + +void dump_ir_pass::operator()(lifetime_stop_inst const &l) { + *os_ << "lifetime_stop "; + dump_val(l.object()); +} + +void dump_ir_pass::operator()(gemm_inst const &g) { + *os_ << "gemm"; + *os_ << "." << to_string(g.tA()); + *os_ << "." << to_string(g.tB()); + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(gemv_inst const &g) { + *os_ << "gemv"; + *os_ << "." << to_string(g.tA()); + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(ger_inst const &g) { + *os_ << "ger"; + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(for_inst const &in) { + if (in.num_results() > 0) { + do_with_infix(in.result_begin(), in.result_end(), [this](auto const &i) { dump_val(i); }); + *os_ << " = "; + } + *os_ << "for "; + dump_val(in.loop_var()); + *os_ << ":"; + visit(*this, *in.loop_var().ty()); + *os_ << "="; + dump_val(in.from()); + *os_ << ","; + dump_val(in.to()); + if (in.has_step()) { + *os_ << ","; + dump_val(in.step()); + } + if (in.num_results() > 0) { + *os_ << " init("; + for (std::int64_t i = 0; i < in.num_results(); ++i) { + if (i != 0) { + *os_ << ","; + } + dump_val(in.iter_arg(i)); + *os_ << "="; + dump_val(in.iter_init(i)); + } + *os_ << ") -> ("; + do_with_infix(in.result_begin(), in.result_end(), + [this](auto const &i) { visit(*this, *i.ty()); }); + *os_ << ")"; + } + *os_ << " "; + dump_region(in.body()); + if (in.attr()) { + *os_ << " "; + visit(*this, *in.attr()); + } +} + +void dump_ir_pass::operator()(foreach_inst const &in) { + *os_ << "foreach ("; + do_with_infix(in.loop_vars().begin(), in.loop_vars().end(), + [this](auto const &i) { dump_val(i); }); + *os_ << "):"; + visit(*this, *in.loop_vars().begin()->ty()); + *os_ << "=("; + do_with_infix(in.from().begin(), in.from().end(), [this](auto const &i) { dump_val(i); }); + *os_ << "),("; + do_with_infix(in.to().begin(), in.to().end(), [this](auto const &i) { dump_val(i); }); + *os_ << ") "; + dump_region(in.body()); +} + +void dump_ir_pass::operator()(hadamard_inst const &g) { + *os_ << "hadamard"; + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(if_inst const &in) { + + if (in.num_results() > 0) { + do_with_infix(in.result_begin(), in.result_end(), [this](auto const &i) { dump_val(i); }); + *os_ << " = "; + } + *os_ << "if "; + dump_val(in.condition()); + *os_ << " "; + if (in.num_results() > 0) { + *os_ << "-> ("; + do_with_infix(in.result_begin(), in.result_end(), + [this](auto const &i) { visit(*this, *i.ty()); }); + *os_ << ") "; + } + dump_region(in.then()); + if (!in.is_otherwise_empty()) { + *os_ << " else "; + dump_region(in.otherwise()); + } +} + +void dump_ir_pass::operator()(math_unary_inst const &in) { + dump_val(in.result(0)); + *os_ << " = math." << to_string(in.operation()) << " "; + dump_val(in.a()); + *os_ << " : "; + visit(*this, *in.result(0).ty()); +} + +void dump_ir_pass::operator()(parallel_inst const &p) { + *os_ << "parallel "; + dump_region(p.body()); +} + +void dump_ir_pass::operator()(size_inst const &s) { + dump_val(s.result(0)); + *os_ << " = size "; + dump_val(s.operand()); + *os_ << "[" << s.mode() << "]"; + *os_ << " : "; + visit(*this, *s.result(0).ty()); +} + +void dump_ir_pass::operator()(subgroup_broadcast_inst const &in) { + dump_val(in.result(0)); + *os_ << " = subgroup_broadcast "; + dump_val(in.a()); + *os_ << ", "; + dump_val(in.idx()); + *os_ << " : "; + visit(*this, *in.result(0).ty()); +} + +void dump_ir_pass::operator()(subgroup_operation_inst const &in) { + dump_val(in.result(0)); + *os_ << " = subgroup." << to_string(in.arith()) << "." << to_string(in.operation()) << " "; + dump_val(in.a()); + *os_ << " : "; + visit(*this, *in.result(0).ty()); +} + +void dump_ir_pass::operator()(subview_inst const &s) { + dump_val(s.result(0)); + *os_ << " = subview "; + dump_val(s.operand()); + *os_ << "["; + auto dyn_offsets = s.offsets(); + auto dyn_sizes = s.sizes(); + for (std::size_t i = 0, joffset = 0, jsize = 0; i < s.static_offsets().size(); ++i) { + if (i != 0) { + *os_ << ","; + } + auto offset = s.static_offsets()[i]; + if (is_dynamic_value(offset)) { + dump_val(dyn_offsets[joffset++]); + } else { + *os_ << offset; + } + auto size = s.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { + *os_ << ":"; + if (is_dynamic_value(size)) { + dump_val(dyn_sizes[jsize++]); + } else { + *os_ << size; + } + } + } + *os_ << "] : "; + visit(*this, *s.result(0).ty()); +} + +void dump_ir_pass::operator()(store_inst const &e) { + *os_ << "store"; + if (e.flag() != store_flag::regular) { + *os_ << '.' << to_string(e.flag()); + } + *os_ << ' '; + dump_val(e.val()); + *os_ << ", "; + dump_val(e.operand()); + *os_ << "["; + do_with_infix(e.index_list().begin(), e.index_list().end(), + [this](auto const &i) { dump_val(i); }); + *os_ << "]"; +} + +void dump_ir_pass::operator()(sum_inst const &a) { + *os_ << "sum"; + *os_ << "." << to_string(a.tA()); + dump_blas_a2(static_cast(a)); +} + +void dump_ir_pass::operator()(yield_inst const &y) { + *os_ << "yield ("; + if (y.num_operands() > 0) { + do_with_infix(y.op_begin(), y.op_end(), [this](auto const &i) { dump_val(i); }, ", "); + } + *os_ << ")"; +} + +void dump_ir_pass::dump_region(region_node const ®) { + if (lvl_ < lvl_limit_) { + *os_ << "{" << std::endl; + ++lvl_; + auto ind = indent(); + for (auto const &i : reg) { + *os_ << ind; + visit(*this, i); + *os_ << std::endl; + } + --lvl_; + *os_ << indent() << "}"; + } else { + *os_ << "{...}"; + } +} + +void dump_ir_pass::run_on_function(function_node const &fn) { + init_slot_tracker(fn); + + *os_ << "func @" << fn.name() << "("; + std::string infix = ",\n "; + infix += std::string(fn.name().size(), ' '); + do_with_infix_enumerated( + fn.params().begin(), fn.params().end(), + [this, &fn](std::int32_t arg_no, auto const &a) { + dump_val(a); + *os_ << ": "; + visit(*this, *a.ty()); + if (auto pa = fn.param_attr(arg_no); pa) { + *os_ << " "; + visit(*this, *pa); + } + }, + infix); + *os_ << ")"; + if (fn.attr()) { + *os_ << " attributes"; + visit(*this, *fn.attr()); + } + *os_ << " "; + dump_region(fn.body()); + *os_ << std::endl; +} + +void dump_ir_pass::run_on_region(region_node const ®) { dump_region(reg); } +void dump_ir_pass::run_on_instruction(inst_node const &in) { visit(*this, in); } + +void dump_ir_pass::init_slot_tracker(function_node const &fn) { + tracker_ = slot_tracker{}; + tracker_.run_on_function(fn); +} + +} // namespace tinytc diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp new file mode 100644 index 00000000..6ae3c591 --- /dev/null +++ b/src/pass/dump_ir.hpp @@ -0,0 +1,123 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_IR_20230330_HPP +#define DUMP_IR_20230330_HPP + +#include "node/attr_node.hpp" // IWYU pragma: keep +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "pass/slot_tracker.hpp" + +#include +#include +#include + +namespace tinytc { + +class dump_ir_pass { + public: + dump_ir_pass(std::ostream &os, int level_limit = std::numeric_limits::max()); + + /* Attribute nodes */ + void operator()(array_attr const &a); + void operator()(boolean_attr const &a); + void operator()(dictionary_attr const &a); + void operator()(integer_attr const &a); + void operator()(string_attr const &a); + + /* Data type nodes */ + void operator()(void_data_type const &); + void operator()(boolean_data_type const &); + void operator()(coopmatrix_data_type const &ct); + void operator()(group_data_type const &g); + void operator()(memref_data_type const &m); + void operator()(scalar_data_type const &s); + + /* Inst nodes */ + void operator()(alloca_inst const &a); + void operator()(axpby_inst const &a); + void operator()(arith_inst const &a); + void operator()(arith_unary_inst const &a); + void operator()(barrier_inst const &b); + void operator()(builtin_inst const &in); + void operator()(cast_inst const &c); + void operator()(compare_inst const &c); + void operator()(constant_inst const &c); + void operator()(cooperative_matrix_apply_inst const &c); + void operator()(cooperative_matrix_extract_inst const &c); + void operator()(cooperative_matrix_insert_inst const &c); + void operator()(cooperative_matrix_load_inst const &c); + void operator()(cooperative_matrix_mul_add_inst const &c); + void operator()(cooperative_matrix_prefetch_inst const &c); + void operator()(cooperative_matrix_reduce_inst const &c); + void operator()(cooperative_matrix_scale_inst const &c); + void operator()(cooperative_matrix_store_inst const &c); + void operator()(cumsum_inst const &a); + void operator()(expand_inst const &e); + void operator()(fuse_inst const &f); + void operator()(load_inst const &e); + void operator()(lifetime_stop_inst const &l); + void operator()(gemm_inst const &g); + void operator()(gemv_inst const &g); + void operator()(ger_inst const &g); + void operator()(for_inst const &p); + void operator()(foreach_inst const &p); + void operator()(hadamard_inst const &g); + void operator()(if_inst const &in); + void operator()(math_unary_inst const &in); + void operator()(parallel_inst const &p); + void operator()(size_inst const &s); + void operator()(subgroup_broadcast_inst const &in); + void operator()(subgroup_operation_inst const &in); + void operator()(subview_inst const &s); + void operator()(store_inst const &s); + void operator()(sum_inst const &s); + void operator()(yield_inst const &y); + + void run_on_function(function_node const &fn); + void run_on_region(region_node const ®); + void run_on_instruction(inst_node const &in); + + void dump_val(value_node const &v); + void init_slot_tracker(function_node const &fn); + + private: + void dump_region(region_node const ®); + void dump_blas_a2(blas_a2_inst const &g); + void dump_blas_a3(blas_a3_inst const &g); + + template + void do_with_infix(Iterator begin, Iterator end, Action a, std::string const &infix = ",") { + for (auto it = begin; it != end; ++it) { + if (it != begin) { + *os_ << infix; + } + a(*it); + } + } + template + void do_with_infix_enumerated(Iterator begin, Iterator end, Action a, + std::string const &infix = ",") { + for (auto it = begin; it != end; ++it) { + if (it != begin) { + *os_ << infix; + } + a(it - begin, *it); + } + } + + inline auto indent() { return std::string(2 * lvl_, ' '); } + std::ostream *os_; + int lvl_limit_; + int lvl_ = 0; + + slot_tracker tracker_; +}; + +} // namespace tinytc + +#endif // DUMP_IR_20230330_HPP diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp new file mode 100644 index 00000000..5bec28ac --- /dev/null +++ b/src/pass/insert_barrier.cpp @@ -0,0 +1,226 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/insert_barrier.hpp" +#include "analysis/aa_results.hpp" +#include "analysis/alias.hpp" +#include "analysis/cfg.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +auto intersects(std::unordered_set<::tinytc_value const *> const &a, + std::unordered_set<::tinytc_value const *> const &b, aa_results const &aa) { + for (auto &av : a) { + for (auto &bv : b) { + if (aa.alias(*av, *bv)) { + return true; + } + } + } + return false; +} + +void insert_barrier_pass::reads_writes::clear() { + for (auto &rd : reads) { + rd.clear(); + } + for (auto &wr : writes) { + wr.clear(); + } +} + +void insert_barrier_pass::reads_writes::clear(address_space as) { + const auto space = address_space_to_index(as); + reads[space].clear(); + writes[space].clear(); +} + +void insert_barrier_pass::reads_writes::merge(reads_writes const &other) { + for (std::size_t i = 0; i < reads.size(); ++i) { + reads[i].insert(other.reads[i].begin(), other.reads[i].end()); + } + for (std::size_t i = 0; i < writes.size(); ++i) { + writes[i].insert(other.writes[i].begin(), other.writes[i].end()); + } +} + +void insert_barrier_pass::reads_writes::merge(reads_writes &&other) { + for (std::size_t i = 0; i < reads.size(); ++i) { + reads[i].merge(std::move(other.reads[i])); + } + for (std::size_t i = 0; i < writes.size(); ++i) { + writes[i].merge(std::move(other.writes[i])); + } +} + +void insert_barrier_pass::reads_writes::merge(address_space as, reads_writes const &other) { + const auto space = address_space_to_index(as); + reads[space].insert(other.reads[space].begin(), other.reads[space].end()); + writes[space].insert(other.writes[space].begin(), other.writes[space].end()); +} + +void insert_barrier_pass::reads_writes::emplace_read(address_space as, ::tinytc_value const *val) { + const auto space = address_space_to_index(as); + reads[space].emplace(val); +} +void insert_barrier_pass::reads_writes::emplace_write(address_space as, ::tinytc_value const *val) { + const auto space = address_space_to_index(as); + writes[space].emplace(val); +} +auto insert_barrier_pass::reads_writes::read_cardinal(address_space as) const -> std::size_t { + const auto space = address_space_to_index(as); + return reads[space].size(); +} +auto insert_barrier_pass::reads_writes::write_cardinal(address_space as) const -> std::size_t { + const auto space = address_space_to_index(as); + return writes[space].size(); +} + +bool insert_barrier_pass::reads_writes::raw(address_space as, reads_writes const &rw, + aa_results const &aa) const { + const auto space = address_space_to_index(as); + return intersects(rw.reads[space], writes[space], aa); +} +bool insert_barrier_pass::reads_writes::war(address_space as, reads_writes const &rw, + aa_results const &aa) const { + const auto space = address_space_to_index(as); + return intersects(rw.writes[space], reads[space], aa); +} +bool insert_barrier_pass::reads_writes::waw(address_space as, reads_writes const &rw, + aa_results const &aa) const { + const auto space = address_space_to_index(as); + return intersects(rw.writes[space], writes[space], aa); +} +bool insert_barrier_pass::reads_writes::raw_war_or_waw(address_space as, reads_writes const &rw, + aa_results const &aa) const { + return raw(as, rw, aa) || war(as, rw, aa) || waw(as, rw, aa); +} + +auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) -> std::size_t { + for (std::size_t i = 0; i < address_spaces.size(); ++i) { + if (as == address_spaces[i]) { + return i; + } + } + throw internal_compiler_error{}; +} + +void insert_barrier_pass::run_on_region(region_node ®, aa_results const &aa) { + // irw = reads and writes invisible to other threads + auto irw_in = std::unordered_map{}; + auto irw_out = std::unordered_map{}; + + auto const get_rw = [](inst_node &in) -> reads_writes { + auto rw = reads_writes{}; + auto const emplace_read = [&rw](value_node const &v) { + if (auto *m = dyn_cast(v.ty()); m) { + rw.emplace_read(m->addrspace(), &v); + } + }; + auto const emplace_write = [&rw](value_node const &v) { + if (auto *m = dyn_cast(v.ty()); m) { + rw.emplace_write(m->addrspace(), &v); + } + }; + + visit(overloaded{[&](blas_a2_inst &in) { + emplace_read(in.A()); + emplace_write(in.B()); + }, + [&](blas_a3_inst &in) { + emplace_read(in.A()); + emplace_read(in.B()); + emplace_write(in.C()); + }, + [&](load_inst &in) { emplace_read(in.operand()); }, + [&](store_inst &in) { emplace_write(in.operand()); }, [](inst_node &) {}}, + in); + return rw; + }; + + auto const get_cardinal = [](reads_writes const &rw) { + return std::array{ + rw.read_cardinal(address_space::global), rw.read_cardinal(address_space::local), + rw.write_cardinal(address_space::global), rw.write_cardinal(address_space::local)}; + }; + + auto cfg = get_control_flow_graph(reg); + auto q = cfg.node_queue(); + while (!q.empty()) { + auto n = q.front(); + q.pop(); + + const bool insert_barriers = cfg.kind_max(n) < region_kind::spmd; + + auto &in = irw_in[n]; + auto &out = irw_out[n]; + + in.clear(); + for (auto &p : cfg.predecessors(n)) { + in.merge(irw_out[p]); + } + + auto out_size_before_update = get_cardinal(out); + + if (auto *barrier = dyn_cast(n); insert_barriers && barrier) { + for (auto &as : reads_writes::address_spaces) { + if (!barrier->has_fence(as)) { + out.merge(as, in); + } + } + } else { + out = get_rw(*n); + + std::int32_t fence_flags = 0; + for (auto &as : reads_writes::address_spaces) { + if (insert_barriers && in.raw_war_or_waw(as, out, aa)) { + fence_flags |= static_cast(as); + } else { + out.merge(as, in); + } + } + if (fence_flags != 0) { + tinytc_region *subreg = n->parent(); + auto new_barrier = + subreg->insts() + .insert(n->iterator(), + std::make_unique(fence_flags).release()) + .get(); + // update cfg + cfg.insert_before(n, new_barrier); + q.push(new_barrier); + } + } + + // out has changed, need to enqueue successors again + if (out_size_before_update != get_cardinal(out)) { + for (auto &s : cfg.successors(n)) { + q.push(s); + } + } + } +} + +/* Function nodes */ +void insert_barrier_pass::run_on_function(function_node &fn) { + auto aa = alias_analysis{}.run_on_function(fn); + run_on_region(fn.body(), aa); +} + +} // namespace tinytc diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp new file mode 100644 index 00000000..4fa6644e --- /dev/null +++ b/src/pass/insert_barrier.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INSERT_BARRIER_20230310_HPP +#define INSERT_BARRIER_20230310_HPP + +#include "node/function_node.hpp" +#include "node/region_node.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include + +namespace tinytc { + +class aa_results; + +class insert_barrier_pass { + public: + void run_on_function(function_node &fn); + + private: + class reads_writes { + public: + constexpr static std::array address_spaces = {address_space::global, + address_space::local}; + + void clear(); + void clear(address_space as); + void merge(reads_writes const &other); + void merge(reads_writes &&other); + void merge(address_space as, reads_writes const &other); + void emplace_read(address_space as, ::tinytc_value const *val); + void emplace_write(address_space as, ::tinytc_value const *val); + auto read_cardinal(address_space as) const -> std::size_t; + auto write_cardinal(address_space as) const -> std::size_t; + + bool raw(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool war(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool waw(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool raw_war_or_waw(address_space as, reads_writes const &rw, aa_results const &aa) const; + + private: + static auto address_space_to_index(address_space as) -> std::size_t; + + std::array, address_spaces.size()> reads, writes; + }; + + void run_on_region(region_node ®, aa_results const &aa); +}; + +} // namespace tinytc + +#endif // INSERT_BARRIER_20230310_HPP diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp new file mode 100644 index 00000000..014df6ef --- /dev/null +++ b/src/pass/insert_lifetime_stop.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/insert_lifetime_stop.hpp" +#include "analysis/aa_results.hpp" +#include "analysis/alias.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "tinytc/types.h" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" + +#include +#include + +namespace tinytc { + +auto insert_lifetime_stop_pass::run_on_region(region_node ®, aa_results const &aa) + -> std::unordered_set { + if (reg.empty()) { + return {}; + } + + auto allocas = std::vector{}; + for (auto &i : reg) { + if (auto alloca = dyn_cast(&i); alloca != nullptr) { + allocas.emplace_back(&alloca->result(0)); + } + } + + auto rgn_ops = std::unordered_set{}; + auto prev_it = reg.end(); + while (prev_it != reg.begin()) { + auto &i = *(--prev_it); + for (auto &subreg : i.child_regions()) { + rgn_ops.merge(run_on_region(subreg, aa)); + } + for (auto &v : i.operands()) { + if (isa(*v.ty())) { + rgn_ops.insert(aa.root(v)); + } + } + for (auto &v : i.results()) { + if (isa(*v.ty())) { + rgn_ops.insert(aa.root(v)); + } + } + + auto alloca_it = allocas.begin(); + while (alloca_it != allocas.end()) { + if (rgn_ops.contains(*alloca_it)) { + prev_it = reg.insts().insert_after( + prev_it, std::make_unique(*alloca_it).release()); + --prev_it; + alloca_it = allocas.erase(alloca_it); + } else { + ++alloca_it; + } + } + } + return rgn_ops; +} + +void insert_lifetime_stop_pass::run_on_function(function_node &fn) { + auto aa = alias_analysis{}.run_on_function(fn); + run_on_region(fn.body(), aa); +} + +} // namespace tinytc diff --git a/src/pass/insert_lifetime_stop.hpp b/src/pass/insert_lifetime_stop.hpp new file mode 100644 index 00000000..e038e024 --- /dev/null +++ b/src/pass/insert_lifetime_stop.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INSERT_LIFETIME_STOP_20240912_HPP +#define INSERT_LIFETIME_STOP_20240912_HPP + +#include "node/function_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" + +#include + +namespace tinytc { + +class aa_results; + +class insert_lifetime_stop_pass { + public: + void run_on_function(function_node &fn); + + private: + auto run_on_region(region_node ®, + aa_results const &aa) -> std::unordered_set<::tinytc_value const *>; +}; + +} // namespace tinytc + +#endif // INSERT_LIFETIME_STOP_20240912_HPP diff --git a/src/pass/lower_coopmatrix.cpp b/src/pass/lower_coopmatrix.cpp new file mode 100644 index 00000000..59312e1a --- /dev/null +++ b/src/pass/lower_coopmatrix.cpp @@ -0,0 +1,192 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_coopmatrix.hpp" +#include "codegen_tools.hpp" +#include "coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "pass/clone.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +class coopmatrix_code_generator { + public: + coopmatrix_code_generator(core_config core_cfg, region_node ®) + : core_cfg_{std::move(core_cfg)}, bb_{®} {} + // Returns true if instruction was replaced + bool operator()(inst_node &in); + bool operator()(cooperative_matrix_apply_inst &in); + + void run_on_region(region_node ®); + + private: + auto needs_coopmatrix_vector_impl(inst_node &in); + + core_config core_cfg_; + region_builder bb_; +}; + +bool coopmatrix_code_generator::operator()(inst_node &) { return false; } + +bool coopmatrix_code_generator::operator()(cooperative_matrix_apply_inst &in) { + if (in.body().empty()) { + throw compilation_error(in.loc(), status::ir_must_have_yield); + } + + auto bool_ty = boolean_data_type::get(in.context()); + auto i32_ty = scalar_data_type::get(in.context(), scalar_type::i32); + + auto cloner = inst_cloner{}; + + auto ct = get_coopmatrix_type(in.a()); + auto cl = get_layout(core_cfg_, ct); + + auto p = bb_.add(make_builtin(builtin::subgroup_local_id, i32_ty, in.loc())); + auto i = p; + auto j0 = value{nullptr}; + if (cl.rows < core_cfg_.subgroup_size) { + auto cI = bb_.add(make_constant(cl.rows, i32_ty, in.loc())); + i = bb_.add(make_arith(arithmetic::rem, p, cI, i32_ty, in.loc())); + j0 = bb_.add(make_arith(arithmetic::div, p, cI, i32_ty, in.loc())); + } + const auto col_inc_factor = core_cfg_.subgroup_size / cl.rows; + + auto copy = value{&in.a()}; + for (std::int64_t v = 0; v < cl.length; ++v) { + const auto k1 = v % cl.blocks1; + const auto u = v / cl.blocks1 % cl.cols; + const auto k2 = v / (cl.blocks1 * cl.cols); + + auto row = i; + const auto block_offset = k1 * cl.rows + k2 * cl.rows * cl.blocks1; + if (block_offset) { + auto cblock_offset = bb_.add(make_constant(block_offset, i32_ty, in.loc())); + row = bb_.add(make_arith(arithmetic::add, i, cblock_offset, i32_ty, in.loc())); + } + auto j1 = bb_.add(make_constant(u * col_inc_factor, i32_ty, in.loc())); + auto col = j0 ? bb_.add(make_arith(arithmetic::add, j0, j1, i32_ty, in.loc())) : j1; + auto val = bb_.add(make_cooperative_matrix_extract(&in.a(), v, ct->ty(), in.loc())); + + cloner.set_subs(&in.row(), row); + cloner.set_subs(&in.col(), col); + cloner.set_subs(&in.val(), val); + + auto modified_val = value{}; + if ((u + 1) * col_inc_factor > cl.shape1) { + auto cshape1 = bb_.add(make_constant(cl.shape1, i32_ty, in.loc())); + auto cond = bb_.add(make_cmp(cmp_condition::lt, col, cshape1, bool_ty, in.loc())); + modified_val = bb_.ifelse( + cond, + [&](region_builder &bb) { + cloner.clone_region(in.body(), *bb.get_region()); + }, + [&](region_builder &bb) { + auto c0 = bb.add(make_constant_zero(ct->ty(), in.loc())); + bb.add(make_yield(c0)); + }, + {ct->ty()}, in.loc()) + .front(); + } else { + cloner.clone_region(in.body(), *bb_.get_region()); + + auto last_inst = --bb_.get_region()->end(); + if (last_inst != bb_.get_region()->end() && isa(*last_inst)) { + auto yi = dyn_cast(last_inst.get()); + if (yi->num_operands() != 1) { + throw compilation_error(in.loc(), status::ir_yield_mismatch); + } + modified_val = &yi->op(0); + bb_.get_region()->insts().erase(last_inst); + } else { + throw compilation_error(in.loc(), status::ir_must_have_yield); + } + } + copy = bb_.add( + make_cooperative_matrix_insert(modified_val, copy, v, in.result(0).ty(), in.loc())); + } + for (auto &r : in.results()) { + auto u = r.use_begin(); + while (r.has_uses()) { + u->set(copy); + u = r.use_begin(); + } + } + return true; +} + +void coopmatrix_code_generator::run_on_region(region_node ®) { + // Move all instructions to a temporary ilist. + // We later move the instructions back, except those that are lowered remain in old_ilist + // and are cleaned up at the end of the function. + auto old_ilist = std::move(reg.insts()); + + auto old_reg = bb_.get_region(); + bb_ = region_builder{®}; + + auto it = old_ilist.begin(); + while (it != old_ilist.end()) { + bool replaced = visit(*this, *it); + if (!replaced) { + auto instr = it.get(); + it = old_ilist.unlink(it); + reg.insts().push_back(instr); + for (auto &subreg : instr->child_regions()) { + run_on_region(subreg); + } + } else { + ++it; + } + } + it = old_ilist.end(); + while (it != old_ilist.begin()) { + --it; + for (auto &result : it->results()) { + if (result.has_uses()) { + throw compilation_error(result.loc(), status::ir_value_still_has_uses); + } + } + it = old_ilist.erase(it); + } + + bb_ = region_builder{old_reg}; +} + +lower_coopmatrix_pass::lower_coopmatrix_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void lower_coopmatrix_pass::run_on_function(function_node &fn) { + auto const subgroup_size = fn.subgroup_size(); + core_config core_cfg = {}; + try { + core_cfg = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + + auto gen = coopmatrix_code_generator{core_cfg, fn.body()}; + gen.run_on_region(fn.body()); +} + +} // namespace tinytc diff --git a/src/pass/lower_coopmatrix.hpp b/src/pass/lower_coopmatrix.hpp new file mode 100644 index 00000000..c1190912 --- /dev/null +++ b/src/pass/lower_coopmatrix.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_COOPMATRIX_20241206_HPP +#define LOWER_COOPMATRIX_20241206_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class lower_coopmatrix_pass { + public: + lower_coopmatrix_pass(::tinytc_core_info const *info); + + void run_on_function(::tinytc_func &fn); + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // LOWER_COOPMATRIX_20241206_HPP diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp new file mode 100644 index 00000000..82f90c2d --- /dev/null +++ b/src/pass/lower_foreach.cpp @@ -0,0 +1,168 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_foreach.hpp" +#include "codegen_tools.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "pass/clone.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +template +void make_loop0(region_builder &bb, value from, value to, value sg_id, int sgs, int num_tiles, + F &&make_body, location const &loc) { + auto ity = from->ty(); + auto ctx = compiler_context{sg_id->context(), true}; + auto bool_ty = get_boolean(ctx); + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto sg_lid_i32 = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, loc)); + auto sg_lid = bb.add(make_cast(sg_lid_i32, ity, loc)); + auto size = instant_constant_fold_add(bb, make_arith(arithmetic::sub, to, from, ity, loc)); + auto work_item_offset = bb.add(make_arith(arithmetic::add, from, sg_lid, ity, loc)); + tile_loop_by_sgs( + bb, size, sgs, num_tiles, sg_id, + [&](region_builder &bb, value block, bool is_remainder, value trip_count) { + auto loop_var0 = bb.add(make_arith(arithmetic::add, block, work_item_offset, ity, loc)); + if (is_remainder) { + auto cond = bb.add(make_cmp(cmp_condition::lt, sg_lid, trip_count, bool_ty, loc)); + bb.if_condition(cond, [&](region_builder &bb) { make_body(bb, loop_var0); }, loc); + } else { + make_body(bb, loop_var0); + } + }); +} + +class foreach_generator { + public: + foreach_generator(local_tiling tiling, core_config core_cfg) + : tiling_{std::move(tiling)}, core_cfg_{std::move(core_cfg)} {} + auto operator()(inst_node &) -> inst { return inst{}; } + auto operator()(foreach_inst &in) -> inst; + + private: + local_tiling tiling_ = {}; + core_config core_cfg_ = {}; +}; + +auto foreach_generator::operator()(foreach_inst &in) -> inst { + const int block_size0 = core_cfg_.subgroup_size; + + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto i32_ty = scalar_data_type::get(in.context(), scalar_type::i32); + + auto cloner = inst_cloner{}; + auto loop_vars = in.loop_vars().begin(); + auto from = in.from().begin(); + auto to = in.to().begin(); + auto ity = (*from).ty(); + + if (in.dim() > 1) { + auto const make_inner_loop_nest = [&](region_builder &bb, value from1, value to1) { + tinytc_region_t current_region = bb.get_region().get(); + for (std::int64_t i = in.dim() - 1; i > 1; --i) { + auto for_i = std::make_unique(ity, &from[i], &to[i], nullptr, + array_view{}, + array_view{}, in.loc()); + cloner.set_subs(&loop_vars[i], &for_i->loop_var()); + tinytc_region_t next_region = &for_i->body(); + current_region->insts().push_back(for_i.release()); + current_region = next_region; + } + region_builder{current_region}.for_loop( + ity, from1, to1, + [&](region_builder &bb, value loop_var1) { + cloner.set_subs(&loop_vars[1], loop_var1.get()); + cloner.clone_region(in.body(), *bb.get_region()); + }, + nullptr, in.loc()); + }; + + auto sg_id0 = bb.add(make_builtin(builtin::subgroup_id_x, i32_ty, in.loc())); + auto sg_id1 = bb.add(make_builtin(builtin::subgroup_id_y, i32_ty, in.loc())); + + auto size1 = bb.add(make_arith(arithmetic::sub, &to[1], &from[1], ity, in.loc())); + tile_loop_uniformly( + bb, size1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_id1, + [&](region_builder &bb, value block, value trip_count1) { + auto from1 = bb.add(make_arith(arithmetic::add, &from[1], block, ity, in.loc())); + auto to1 = bb.add(make_arith(arithmetic::add, from1, trip_count1, ity, in.loc())); + make_loop0( + bb, &from[0], &to[0], sg_id0, block_size0, tiling_.m_tiles(), + [&](region_builder &bb, value loop_var0) { + cloner.set_subs(&loop_vars[0], loop_var0.get()); + make_inner_loop_nest(bb, from1, to1); + }, + in.loc()); + }); + } else if (in.dim() == 1) { + auto sg_id = bb.add(make_builtin(builtin::subgroup_linear_id, i32_ty, in.loc())); + make_loop0( + bb, &from[0], &to[0], sg_id, block_size0, tiling_.m_tiles() * tiling_.n_tiles(), + [&](region_builder &bb, value loop_var0) { + cloner.set_subs(&loop_vars[0], loop_var0.get()); + cloner.clone_region(in.body(), *bb.get_region()); + }, + in.loc()); + } + + return parallel; +} + +lower_foreach_pass::lower_foreach_pass(::tinytc_core_info const *info) : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void lower_foreach_pass::run_on_function(function_node &fn) { + auto const subgroup_size = fn.subgroup_size(); + core_config core_cfg = {}; + try { + core_cfg = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + auto const work_group_size = fn.work_group_size(); + local_tiling tiling = {}; + tiling[0] = work_group_size[0] / subgroup_size; + tiling[1] = work_group_size[1]; + + walk(fn, [&](region_node ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + auto lowered_inst = visit(foreach_generator{tiling, core_cfg}, *it); + if (lowered_inst) { + it = reg.insts().erase(it); + it = reg.insts().insert(it, lowered_inst.release()); + } + } + }); +} + +} // namespace tinytc diff --git a/src/pass/lower_foreach.hpp b/src/pass/lower_foreach.hpp new file mode 100644 index 00000000..0486dd03 --- /dev/null +++ b/src/pass/lower_foreach.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_FOREACH_20241118_HPP +#define LOWER_FOREACH_20241118_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class lower_foreach_pass { + public: + lower_foreach_pass(::tinytc_core_info const *info); + + void run_on_function(::tinytc_func &fn); + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // LOWER_FOREACH_20241118_HPP diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp new file mode 100644 index 00000000..70742d64 --- /dev/null +++ b/src/pass/lower_linalg.cpp @@ -0,0 +1,817 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_linalg.hpp" +#include "codegen_tools.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "gemm_tools.hpp" +#include "matrix_ext_info.hpp" +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "scalar_type.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomic, value alpha, + value A, value B, value beta, value C, value K, value m_block, + std::int32_t m_block_size, std::int32_t num_m_blocks, bool m_check, + value n_block, std::int32_t n_block_size, std::int32_t num_n_blocks, + bool n_check, array_view K_block_sizes, data_type a_ty, + data_type b_ty, data_type c_ty, attr for_attributes, location const &loc) { + auto ctx = m_block->context(); + auto bool_ty = boolean_data_type::get(ctx); + auto index_ty = scalar_data_type::get(ctx, scalar_type::index); + + const auto check_a = m_check ? checked_flag::rows : checked_flag::none; + const auto check_b = n_check ? checked_flag::cols : checked_flag::none; + const auto check_c = [&] { + if (m_check && n_check) { + return checked_flag::both; + } else if (m_check) { + return checked_flag::rows; + } else if (n_check) { + return checked_flag::cols; + } + return checked_flag::none; + }(); + auto c_m_block_size = bb.add(make_constant(m_block_size, index_ty, loc)); + auto c_n_block_size = bb.add(make_constant(n_block_size, index_ty, loc)); + + const auto c_acc_ty = [&c_ty, &loc]() { + auto ct = dyn_cast(c_ty); + if (ct == nullptr) { + throw compilation_error(loc, status::internal_compiler_error); + } + return scalar_data_type::get(c_ty->context(), acc_type(ct->ty())); + }(); + + auto coopmatrix_c_ty = get_coopmatrix(c_ty, m_block_size, n_block_size, matrix_use::acc, loc); + auto coopmatrix_c_acc_ty = + get_coopmatrix(c_acc_ty, m_block_size, n_block_size, matrix_use::acc, loc); + auto const compute_c_step = [&](region_builder &bb, std::int32_t k_block_size, value k, + array_view const &c_acc, + array_view const &c_acc_tys, + bool check_k = false) { + value pos_a[2] = {m_block, k}; + int amode = 0; + if (tA == transpose::T) { + std::swap(pos_a[0], pos_a[1]); + amode = 1 - amode; + } + auto coopmatrix_a_ty = get_coopmatrix(a_ty, m_block_size, k_block_size, matrix_use::a, loc); + const auto my_check_a = check_k ? add_check(check_a, checked_flag::cols) : check_a; + auto a = std::vector{}; + a.reserve(num_m_blocks); + for (std::int32_t i = 0; i < num_m_blocks; ++i) { + a.emplace_back(bb.add(make_cooperative_matrix_load(tA, my_check_a, A, pos_a[0], + pos_a[1], coopmatrix_a_ty))); + if (i + 1 < num_m_blocks) { + pos_a[amode] = bb.add( + make_arith(arithmetic::add, pos_a[amode], c_m_block_size, index_ty, loc)); + } + } + + value pos_b[2] = {k, n_block}; + int bmode = 1; + if (tB == transpose::T) { + std::swap(pos_b[0], pos_b[1]); + bmode = 1 - bmode; + } + auto coopmatrix_b_ty = get_coopmatrix(b_ty, k_block_size, n_block_size, matrix_use::b, loc); + const auto my_check_b = check_k ? add_check(check_b, checked_flag::rows) : check_b; + auto b = std::vector{}; + b.reserve(num_n_blocks); + for (std::int32_t i = 0; i < num_n_blocks; ++i) { + b.emplace_back(bb.add(make_cooperative_matrix_load(tB, my_check_b, B, pos_b[0], + pos_b[1], coopmatrix_b_ty))); + if (i + 1 < num_n_blocks) { + pos_b[bmode] = bb.add( + make_arith(arithmetic::add, pos_b[bmode], c_n_block_size, index_ty, loc)); + } + } + + auto c_next = std::vector{}; + c_next.reserve(num_m_blocks * num_n_blocks); + for (std::int32_t n = 0; n < num_n_blocks; ++n) { + for (std::int32_t m = 0; m < num_m_blocks; ++m) { + c_next.emplace_back( + bb.add(make_cooperative_matrix_mul_add(a[m], b[n], c_acc[m + n * num_m_blocks], + c_acc_tys[m + n * num_m_blocks], loc))); + } + } + return c_next; + }; + auto const compute_c = [&](region_builder &bb, std::int32_t k_block_size, value K0, value K1, + std::vector const &c_acc, + std::vector const &c_acc_tys, + bool check_k = false) -> std::vector { + auto c_step = bb.add(make_constant(k_block_size, index_ty, loc)); + auto return_values = bb.for_loop( + index_ty, K0, K1, c_step, c_acc, c_acc_tys, + [&](region_builder &bb, array_view p) { + const auto k = p[0]; + auto c_acc_iter = array_view(p.begin() + 1, p.end()); + auto c_next = compute_c_step(bb, k_block_size, k, c_acc_iter, c_acc_tys, check_k); + bb.add(make_yield(c_next, loc)); + }, + for_attributes); + return return_values; + }; + + auto c_acc = std::vector{}; + c_acc.reserve(num_m_blocks * num_n_blocks); + for (std::int32_t i = 0; i < num_m_blocks * num_n_blocks; ++i) { + c_acc.emplace_back(bb.add(make_constant_zero(coopmatrix_c_acc_ty, loc))); + } + auto c_acc_tys = std::vector(c_acc.size()); + for (auto &ty : c_acc_tys) { + ty = coopmatrix_c_acc_ty; + } + auto c_tys = std::vector(c_acc.size()); + for (auto &ty : c_tys) { + ty = coopmatrix_c_ty; + } + + auto k_block_size = K_block_sizes.back(); + + const auto const_K = get_int_constant(K); + if (const_K) { + k_block_size = choose_k_block_size(K_block_sizes, *const_K); + } + + auto c_zero = bb.add(make_constant_zero(index_ty, loc)); + auto c_k_block_size = bb.add(make_constant(k_block_size, index_ty, loc)); + auto tmp = instant_constant_fold_add( + bb, make_arith(arithmetic::div, K, c_k_block_size, index_ty, loc)); + auto K0 = instant_constant_fold_add( + bb, make_arith(arithmetic::mul, tmp, c_k_block_size, index_ty, loc)); + auto needs_remainder = + instant_constant_fold_add(bb, make_cmp(cmp_condition::lt, K0, K, bool_ty, loc)); + auto r = get_bool_constant(needs_remainder); + if (r) { + if (*r != 0) { + c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc, c_acc_tys); + const auto K_block_size = K_block_sizes.front(); + c_acc = compute_c(bb, K_block_size, K0, K, c_acc, c_acc_tys, K_block_size > 1); + } else { + c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc, c_acc_tys); + } + } else { + c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc, c_acc_tys); + auto remainder = bb.ifelse( + needs_remainder, + [&](region_builder &bb) { + const auto K_block_size = K_block_sizes.front(); + auto c_next = + compute_c(bb, K_block_size, K0, K, c_acc, c_acc_tys, K_block_size > 1); + bb.add(make_yield(c_next, loc)); + }, + [&](region_builder &bb) { bb.add(make_yield(c_acc, loc)); }, c_acc_tys, loc); + c_acc = std::move(remainder); + } + + for (auto &a : c_acc) { + a = mixed_precision_coopmatrix_scale(bb, alpha, a, loc); + } + + const bool needs_final_cast = coopmatrix_c_ty != coopmatrix_c_acc_ty; + if (atomic) { + auto flag = get_atomic_store_flag(beta); + if (!flag) { + throw compilation_error(loc, status::ir_invalid_beta); + } + for (std::int32_t n = 0; n < num_n_blocks; ++n) { + auto pos1_offset = bb.add(make_constant(n * n_block_size, index_ty, loc)); + auto pos1 = bb.add(make_arith(arithmetic::add, n_block, pos1_offset, index_ty, loc)); + for (std::int32_t m = 0; m < num_m_blocks; ++m) { + auto pos0_offset = bb.add(make_constant(m * m_block_size, index_ty, loc)); + auto pos0 = + bb.add(make_arith(arithmetic::add, m_block, pos0_offset, index_ty, loc)); + auto alpha_ab_mn = c_acc[m + n * num_m_blocks]; + if (needs_final_cast) { + alpha_ab_mn = bb.add(make_cast(alpha_ab_mn, coopmatrix_c_ty, loc)); + } + bb.add( + make_cooperative_matrix_store(check_c, *flag, alpha_ab_mn, C, pos0, pos1, loc)); + } + } + } else { + for (std::int32_t n = 0; n < num_n_blocks; ++n) { + auto pos1_offset = bb.add(make_constant(n * n_block_size, index_ty, loc)); + auto pos1 = bb.add(make_arith(arithmetic::add, n_block, pos1_offset, index_ty, loc)); + for (std::int32_t m = 0; m < num_m_blocks; ++m) { + auto pos0_offset = bb.add(make_constant(m * m_block_size, index_ty, loc)); + auto pos0 = + bb.add(make_arith(arithmetic::add, m_block, pos0_offset, index_ty, loc)); + auto c_load = bb.add(make_cooperative_matrix_load(transpose::N, check_c, C, pos0, + pos1, coopmatrix_c_ty)); + auto &alpha_ab_mn = c_acc[m + n * num_m_blocks]; + auto alpha_ab_plus_beta_c = [&] { + if (needs_final_cast) { + auto c_load_acc = bb.add(make_cast(c_load, coopmatrix_c_acc_ty, loc)); + auto beta_c = mixed_precision_coopmatrix_scale(bb, beta, c_load_acc, loc); + auto alpha_ab_plus_beta_c = bb.add(make_arith( + arithmetic::add, alpha_ab_mn, beta_c, alpha_ab_mn->ty(), loc)); + return bb.add(make_cast(alpha_ab_plus_beta_c, coopmatrix_c_ty, loc)); + } else { + auto beta_c = mixed_precision_coopmatrix_scale(bb, beta, c_load, loc); + return bb.add(make_arith(arithmetic::add, alpha_ab_mn, beta_c, + alpha_ab_mn->ty(), loc)); + } + }(); + bb.add(make_cooperative_matrix_store(check_c, store_flag::regular, + alpha_ab_plus_beta_c, C, pos0, pos1, loc)); + } + } + } +} + +class linalg_generator { + public: + linalg_generator(local_tiling const &tiling, core_config const &core_cfg, region_node ®, + tinytc_inst_iterator_t ip) + : tiling_{tiling}, core_cfg_{core_cfg}, bb_{®, ip} {} + inline void operator()(inst_node &in) { + throw compilation_error(in.loc(), status::not_implemented); + } + void operator()(axpby_inst &in); + void operator()(cumsum_inst &in); + void operator()(gemm_inst &in); + void operator()(gemv_inst &in); + void operator()(ger_inst &in); + void operator()(hadamard_inst &in); + void operator()(sum_inst &in); + + inline auto insertion_point() -> tinytc_inst_iterator_t { return bb_.get_insertion_point(); } + + private: + auto get_memref_type(value_node const &v) const -> const memref_data_type *; + + local_tiling const &tiling_; + core_config const &core_cfg_; + region_builder bb_; +}; + +auto linalg_generator::get_memref_type(value_node const &v) const -> const memref_data_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return t; +} + +void linalg_generator::operator()(axpby_inst &in) { + auto ctx = compiler_context{in.alpha().context(), true}; + auto bool_ty = get_boolean(ctx); + auto index_ty = get_scalar(ctx, scalar_type::index); + + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + if (bt->dim() == 0) { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto sg_id = bb.add(make_builtin(builtin::subgroup_linear_id, i32_ty, in.loc())); + auto sg_lid = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, in.loc())); + auto c0 = bb.add(make_constant(0, i32_ty)); + auto cond0 = bb.add(make_cmp(cmp_condition::eq, sg_id, c0, bool_ty, in.loc())); + auto cond1 = bb.add(make_cmp(cmp_condition::eq, sg_lid, c0, bool_ty, in.loc())); + auto cond = bb.add(make_arith(arithmetic::and_, cond0, cond1, cond0->ty())); + bb.if_condition(cond, [&](region_builder &bb) { + auto a = bb.add(make_load(&in.A(), {}, at->element_data_ty(), in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {}, in.loc()); + }); + + bb_.add(std::move(parallel)); + } else if (bt->dim() == 1) { + auto c0 = bb_.add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb_, make_size(&in.B(), 0, index_ty, in.loc())); + bb_.foreach_loop( + index_ty, {c0.get()}, {c_shape0.get()}, + [&](region_builder &bb, auto loop_vars) { + auto a = + bb.add(make_load(&in.A(), {loop_vars[0]}, at->element_data_ty(), in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {loop_vars[0]}, + in.loc()); + }, + in.loc()); + } else if (bt->dim() == 2) { + auto c0 = bb_.add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb_, make_size(&in.B(), 0, index_ty, in.loc())); + auto c_shape1 = instant_constant_fold_add(bb_, make_size(&in.B(), 1, index_ty, in.loc())); + bb_.foreach_loop( + index_ty, {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, + [&](region_builder &bb, auto loop_vars) { + auto a_idx = std::array{loop_vars[0], loop_vars[1]}; + if (in.tA() == transpose::T) { + std::swap(a_idx[0], a_idx[1]); + } + auto a = bb.add(make_load(&in.A(), a_idx, at->element_data_ty(), in.loc())); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), + {loop_vars[0], loop_vars[1]}, in.loc()); + }, + in.loc()); + } +} + +void linalg_generator::operator()(cumsum_inst &in) { + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + + const auto num_tiles = tiling_.m_tiles() * tiling_.n_tiles(); + auto ctx = compiler_context{in.alpha().context(), true}; + auto bool_ty = get_boolean(ctx); + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto index_ty = get_scalar(ctx, scalar_type::index); + const auto &loc = in.loc(); + + auto const scan_loop_1d = [&](region_builder &bb, work_group_inclusive_scan &scan, value a_sub, + value b_sub) { + auto c_sgs = bb.add(make_constant(scan.subgroup_size(), i32_ty, loc)); + auto sglid = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, loc)); + auto from_index = [&]() -> value { + if (scan.num_tiles() > 1) { + auto sgid = bb.add(make_builtin(builtin::subgroup_linear_id, i32_ty, loc)); + auto from0 = bb.add(make_arith(arithmetic::mul, sgid, c_sgs, i32_ty, loc)); + auto from1 = bb.add(make_arith(arithmetic::add, from0, sglid, i32_ty, loc)); + return bb.add(make_cast(from1, index_ty, loc)); + } else { + return bb.add(make_cast(sglid, index_ty, loc)); + } + }(); + + auto c_step = bb.add(make_constant(scan.subgroup_size() * scan.num_tiles(), index_ty, loc)); + + auto c_1 = bb.add(make_constant_one(index_ty, loc)); + auto shape0 = instant_constant_fold_add(bb, make_size(a_sub, 0, index_ty, loc)); + auto tr0 = + instant_constant_fold_add(bb, make_arith(arithmetic::sub, shape0, c_1, index_ty, loc)); + auto tr1 = + instant_constant_fold_add(bb, make_arith(arithmetic::div, tr0, c_step, index_ty, loc)); + auto tr2 = + instant_constant_fold_add(bb, make_arith(arithmetic::add, tr1, c_1, index_ty, loc)); + auto trip_count = + instant_constant_fold_add(bb, make_arith(arithmetic::mul, tr2, c_step, index_ty, loc)); + + auto c_init = bb.add(make_constant_zero(bt->element_data_ty(), loc)); + auto a_scan = bb.for_loop( + index_ty, from_index, trip_count, c_step, {c_init}, {bt->element_data_ty()}, + [&](region_builder &bb, array_view args) { + auto is_in_bounds = + bb.add(make_cmp(cmp_condition::lt, args[0], shape0, bool_ty, loc)); + auto a = bb.ifelse( + is_in_bounds, + [&](region_builder &bb) { + auto a = bb.add(make_load(a_sub, {args[0]}, at->element_data_ty(), loc)); + if (at->element_data_ty() != bt->element_data_ty()) { + a = bb.add(make_cast(a, bt->element_data_ty(), loc)); + } + bb.add(make_yield({a}, loc)); + }, + [&](region_builder &bb) { bb.add(make_yield({c_init}, loc)); }, + {bt->element_data_ty()}, loc); + auto [a_scan, next_prefix] = scan.make(bb, a[0], true, loc); + a_scan = bb.add( + make_arith(arithmetic::add, args[1], a_scan, bt->element_data_ty(), loc)); + next_prefix = bb.add( + make_arith(arithmetic::add, args[1], next_prefix, bt->element_data_ty(), loc)); + bb.if_condition( + is_in_bounds, + [&](region_builder &bb) { + blas_update(bb, in.atomic(), &in.alpha(), a_scan, &in.beta(), b_sub, + {args[0]}, loc); + }, + loc); + bb.add(make_yield({next_prefix}, loc)); + }); + }; + + if (bt->dim() == 1) { + auto parallel = make_parallel(loc); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto scan = + work_group_inclusive_scan(num_tiles, core_cfg_.subgroup_size, bt->element_data_ty()); + scan.setup(bb_, loc); + + scan_loop_1d(bb, scan, &in.A(), &in.B()); + + bb_.add(std::move(parallel)); + scan.teardown(bb_); + } else if (bt->dim() >= 2 && in.mode() == 0) { + auto scan = work_group_inclusive_scan(1, core_cfg_.subgroup_size, bt->element_data_ty()); + scan.setup(bb_, loc); + + auto parallel = make_parallel(loc); + + auto c_zero = bb_.add(make_constant_zero(index_ty, loc)); + tinytc_region_t parent_region = ¶llel->child_region(0); + auto offsets = std::vector(bt->dim() - 1, nullptr); + for (std::int64_t i = bt->dim() - 1; i > 1; --i) { + auto bb = region_builder{parent_region}; + auto shape_i = bb.add(make_size(&in.B(), i, index_ty, loc)); + auto for_i = std::make_unique(index_ty, c_zero, shape_i, nullptr, + array_view{}, + array_view{}, loc); + offsets[i - 1] = &for_i->body().param(0); + parent_region = &for_i->body(); + bb.add(inst{for_i.release()}); + } + + auto bb = region_builder{parent_region}; + auto sgid = bb.add(make_builtin(builtin::subgroup_linear_id, i32_ty, loc)); + auto sgid_index = bb.add(make_cast(sgid, index_ty, loc)); + + auto shape0 = bb.add(make_size(&in.B(), 0, index_ty, loc)); + auto shape1 = bb.add(make_size(&in.B(), 1, index_ty, loc)); + auto c_num_tiles = bb.add(make_constant(num_tiles, index_ty, loc)); + bb.for_loop(index_ty, sgid_index, shape1, c_num_tiles, + [&](region_builder &bb, array_view args) { + auto static_offset = std::vector(bt->dim(), dynamic); + auto static_size = std::vector(bt->dim(), 0); + static_offset[0] = 0; + static_size[0] = dynamic; + auto a_sub_ty = get_memref(at->element_data_ty(), {dynamic}, + {at->stride(0)}, at->addrspace(), loc); + auto b_sub_ty = get_memref(bt->element_data_ty(), {dynamic}, + {bt->stride(0)}, bt->addrspace(), loc); + offsets[0] = args[0]; + auto a_sub = bb.add(make_subview(&in.A(), static_offset, static_size, + offsets, {shape0}, a_sub_ty, loc)); + auto b_sub = bb.add(make_subview(&in.B(), static_offset, static_size, + offsets, {shape0}, b_sub_ty, loc)); + scan_loop_1d(bb, scan, a_sub, b_sub); + }); + + bb_.add(std::move(parallel)); + scan.teardown(bb_); + } else if (bt->dim() >= 2) { + auto c_zero = bb_.add(make_constant_zero(index_ty, loc)); + auto lb = std::vector(bt->dim() - 1, c_zero); + auto ub = std::vector{}; + ub.reserve(bt->dim() - 1); + for (std::int64_t i = 0; i < bt->dim(); ++i) { + if (i != in.mode()) { + ub.emplace_back( + instant_constant_fold_add(bb_, make_size(&in.B(), i, index_ty, loc))); + } + } + + auto J = bb_.add(make_size(&in.B(), in.mode(), index_ty, loc)); + bb_.foreach_loop( + index_ty, lb, ub, + [&](region_builder &bb, auto loop_vars) { + auto static_offset = std::vector(bt->dim(), dynamic); + auto static_size = std::vector(bt->dim(), 0); + static_offset[in.mode()] = 0; + static_size[in.mode()] = dynamic; + auto a_sub_ty = get_memref(at->element_data_ty(), {dynamic}, + {at->stride(in.mode())}, at->addrspace(), loc); + auto a_sub = bb.add(make_subview(&in.A(), static_offset, static_size, loop_vars, + {J}, a_sub_ty, loc)); + auto b_sub_ty = get_memref(bt->element_data_ty(), {dynamic}, + {bt->stride(in.mode())}, bt->addrspace(), loc); + auto b_sub = bb.add(make_subview(&in.B(), static_offset, static_size, loop_vars, + {J}, b_sub_ty, loc)); + + auto c_init = bb.add(make_constant_zero(bt->element_data_ty())); + auto acc = bb.for_loop( + index_ty, c_zero, J, {}, {c_init}, {bt->element_data_ty()}, + [&](region_builder &bb, array_view args) { + auto a = bb.add(make_load(a_sub, {args[0]}, at->element_data_ty(), loc)); + auto prefix = mixed_precision_arithmetic(bb, bt->element_ty(), + arithmetic::add, args[1], a, loc); + blas_update(bb, in.atomic(), &in.alpha(), prefix, &in.beta(), b_sub, + {args[0]}, loc); + bb.add(make_yield({prefix}, loc)); + }); + }, + loc); + } +} + +void linalg_generator::operator()(gemm_inst &in) { + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto ct = get_memref_type(in.C()); + + auto ctx = compiler_context{in.alpha().context(), true}; + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto index_ty = get_scalar(ctx, scalar_type::index); + + auto sg_m = bb.add(make_builtin(builtin::subgroup_id_x, i32_ty, in.loc())); + auto sg_n = bb.add(make_builtin(builtin::subgroup_id_y, i32_ty, in.loc())); + + auto [max_rows, max_cols] = max_register_block_gemm( + size(at->element_ty()), size(bt->element_ty()), size(acc_type(ct->element_ty())), + core_cfg_.subgroup_size, core_cfg_.register_space, + is_complex_type(ct->element_ty()) ? 2 : 1); + + auto c_shape0 = instant_constant_fold_add(bb, make_size(&in.C(), 0, index_ty, in.loc())); + auto c_shape1 = instant_constant_fold_add(bb, make_size(&in.C(), 1, index_ty, in.loc())); + auto K = instant_constant_fold_add( + bb, make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); + + auto const_shape0 = get_int_constant(c_shape0); + auto const_shape1 = get_int_constant(c_shape1); + + const auto [block_size0, num_blocks0, block_size1, num_blocks1, do_tile_uniformly, + K_block_sizes] = + [&]() -> std::tuple> { + if (auto ext_type = core_cfg_.matrix->get_precision(at->element_ty(), bt->element_ty(), + ct->element_ty()); + ext_type) { + const auto M_bs = ext_type->M_block_sizes(); + // @todo Think about what do if we have multiple sizes for M + const auto block_size0 = M_bs.back(); + const auto shape0 = const_shape0 ? *const_shape0 : max_rows; + const auto num_blocks0 = + choose_block_size_multiple(block_size0, max_rows, tiling_.m_tiles(), shape0); + + // @todo Think about what do for multiple N sizes + const auto N_bs = ext_type->N_block_sizes(block_size0); + const auto block_size1 = N_bs.back(); + const auto shape1 = const_shape1 ? *const_shape1 : max_cols; + const auto num_blocks1 = + choose_block_size_multiple(block_size1, max_cols, tiling_.n_tiles(), shape1); + const auto K_bs = ext_type->K_block_sizes(block_size0, block_size1); + + return std::make_tuple(block_size0, num_blocks0, block_size1, num_blocks1, false, K_bs); + } + + const auto block_size0 = core_cfg_.subgroup_size; + const auto shape0 = const_shape0 ? *const_shape0 : max_rows; + const auto num_blocks0 = + choose_block_size_multiple(block_size0, max_rows, tiling_.m_tiles(), shape0); + const auto block_size1 = max_cols; + const auto num_blocks1 = 1; + + return std::make_tuple(block_size0, num_blocks0, block_size1, num_blocks1, + const_shape1.has_value(), + std::vector(standard_K_block_sizes.begin(), + standard_K_block_sizes.end())); + }(); + + if (do_tile_uniformly) { + tile_loop_uniformly( + bb, c_shape1, block_size1 * num_blocks1, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, value n_block, value trip_count) { + auto const_trip_count = get_int_constant(trip_count); + if (!const_trip_count) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + tile_loop_by_sgs(bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value m_block, bool m_check, value) { + gemm_microkernel(bb, in.tA(), in.tB(), in.atomic(), + &in.alpha(), &in.A(), &in.B(), &in.beta(), + &in.C(), K, m_block, block_size0, num_blocks0, + m_check, n_block, *const_trip_count, + num_blocks1, false, K_block_sizes, + at->element_data_ty(), bt->element_data_ty(), + ct->element_data_ty(), nullptr, in.loc()); + }); + }); + } else { + auto no_unroll = get_dictionary_attr_with_sorted( + ctx, named_attr{get_string_attr(ctx, "unroll"), get_boolean_attr(ctx, false)}); + tile_loop_by_sgs( + bb, c_shape1, block_size1 * num_blocks1, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, value n_block, bool n_check, value) { + tile_loop_by_sgs( + bb, c_shape0, block_size0 * num_blocks0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, value m_block, bool m_check, value) { + gemm_microkernel(bb, in.tA(), in.tB(), in.atomic(), &in.alpha(), &in.A(), + &in.B(), &in.beta(), &in.C(), K, m_block, block_size0, + num_blocks0, m_check, n_block, block_size1, num_blocks1, + n_check, K_block_sizes, at->element_data_ty(), + bt->element_data_ty(), ct->element_data_ty(), no_unroll, + in.loc()); + }, + no_unroll); + }, + no_unroll); + } + + bb_.add(std::move(parallel)); +} + +void linalg_generator::operator()(gemv_inst &in) { + auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); + auto c0 = bb_.add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb_, make_size(&in.C(), 0, index_ty, in.loc())); + auto ct = get_memref_type(in.C()); + bb_.foreach_loop( + index_ty, {c0.get()}, {c_shape0.get()}, + [&](region_builder &bb, auto loop_vars) { + auto c_init = bb.add(make_constant_zero(ct->element_data_ty())); + auto K = + bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); + auto c_acc = bb.for_loop( + index_ty, c0, K, {}, {c_init}, {ct->element_data_ty()}, + [&](region_builder &bb, array_view p) { + auto a_idx = std::array{loop_vars[0], p[0]}; + if (in.tA() == transpose::T) { + std::swap(a_idx[0], a_idx[1]); + } + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto a = bb.add(make_load(&in.A(), a_idx, at->element_data_ty(), in.loc())); + auto b = bb.add(make_load(&in.B(), {p[0]}, bt->element_data_ty(), in.loc())); + auto ab = mixed_precision_arithmetic(bb, ct->element_ty(), arithmetic::mul, a, + b, in.loc()); + auto ab_c = mixed_precision_arithmetic(bb, ct->element_ty(), arithmetic::add, + p[1], ab, in.loc()); + bb.add(make_yield({ab_c}, in.loc())); + }); + blas_update(bb, in.atomic(), &in.alpha(), c_acc[0], &in.beta(), &in.C(), {loop_vars[0]}, + in.loc()); + }, + in.loc()); +} + +void linalg_generator::operator()(ger_inst &in) { + auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); + auto c0 = bb_.add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb_, make_size(&in.C(), 0, index_ty, in.loc())); + auto c_shape1 = instant_constant_fold_add(bb_, make_size(&in.C(), 1, index_ty, in.loc())); + bb_.foreach_loop( + index_ty, {c0.get(), c0.get()}, {c_shape0.get(), c_shape1.get()}, + [&](region_builder &bb, auto loop_vars) { + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto ct = get_memref_type(in.C()); + auto a = bb.add(make_load(&in.A(), {loop_vars[0]}, at->element_data_ty(), in.loc())); + auto b = bb.add(make_load(&in.B(), {loop_vars[1]}, bt->element_data_ty(), in.loc())); + auto ab = + mixed_precision_arithmetic(bb, ct->element_ty(), arithmetic::mul, a, b, in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), + {loop_vars[0], loop_vars[1]}, in.loc()); + }, + in.loc()); +} + +void linalg_generator::operator()(hadamard_inst &in) { + auto index_ty = scalar_data_type::get(in.alpha().context(), scalar_type::index); + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto ct = get_memref_type(in.C()); + + auto lb = std::vector(ct->dim()); + auto ub = std::vector(ct->dim()); + + auto c0 = bb_.add(make_constant(0, index_ty, in.loc())); + for (std::int64_t i = 0; i < ct->dim(); ++i) { + lb[i] = c0; + ub[i] = instant_constant_fold_add(bb_, make_size(&in.C(), i, index_ty, in.loc())); + } + + bb_.foreach_loop( + index_ty, lb, ub, + [&](region_builder &bb, auto loop_vars) { + auto a = bb.add(make_load(&in.A(), loop_vars, at->element_data_ty(), in.loc())); + auto b = bb.add(make_load(&in.B(), loop_vars, bt->element_data_ty(), in.loc())); + auto ab = + mixed_precision_arithmetic(bb, ct->element_ty(), arithmetic::mul, a, b, in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), loop_vars, in.loc()); + }, + in.loc()); +} + +void linalg_generator::operator()(sum_inst &in) { + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + + auto ctx = compiler_context{in.alpha().context(), true}; + auto bool_ty = get_boolean(ctx); + auto i32_ty = get_scalar(ctx, scalar_type::i32); + auto index_ty = get_scalar(ctx, scalar_type::index); + + if (bt->dim() == 0) { + const auto num_tiles = tiling_.m_tiles() * tiling_.n_tiles(); + auto reducer = work_group_reduce(num_tiles, core_cfg_.subgroup_size, bt->element_data_ty()); + reducer.setup(bb_, in.loc()); + + auto parallel = make_parallel(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto c_sgs = bb.add(make_constant(core_cfg_.subgroup_size, i32_ty, in.loc())); + auto sgid = bb.add(make_builtin(builtin::subgroup_linear_id, i32_ty, in.loc())); + auto m = bb.add(make_builtin(builtin::subgroup_local_id, i32_ty, in.loc())); + auto from0 = bb.add(make_arith(arithmetic::mul, sgid, c_sgs, i32_ty, in.loc())); + auto from1 = bb.add(make_arith(arithmetic::add, from0, m, i32_ty, in.loc())); + auto from_index = bb.add(make_cast(from1, index_ty, in.loc())); + + auto c_trip_count = + instant_constant_fold_add(bb, make_size(&in.A(), 0, index_ty, in.loc())); + auto c_step = + bb.add(make_constant(core_cfg_.subgroup_size * num_tiles, index_ty, in.loc())); + auto c_init = bb.add(make_constant_zero(bt->element_data_ty(), in.loc())); + + auto acc = bb.for_loop( + index_ty, from_index, c_trip_count, c_step, {c_init}, {bt->element_data_ty()}, + [&](region_builder &bb, array_view args) { + auto a = bb.add(make_load(&in.A(), {args[0]}, at->element_data_ty(), in.loc())); + auto sum = mixed_precision_arithmetic(bb, bt->element_ty(), arithmetic::add, + args[1], a, in.loc()); + bb.add(make_yield({sum}, in.loc())); + }); + auto acc_reduced = reducer.make(bb, acc[0], in.loc()); + + auto c_zero = bb.add(make_constant_zero(i32_ty, in.loc())); + auto is_first_work_item = + bb.add(make_cmp(cmp_condition::eq, from1, c_zero, bool_ty, in.loc())); + bb.if_condition( + is_first_work_item, + [&](region_builder &bb) { + blas_update(bb, in.atomic(), &in.alpha(), acc_reduced, &in.beta(), &in.B(), {}, + in.loc()); + }, + in.loc()); + + bb_.add(std::move(parallel)); + reducer.teardown(bb_); + } else if (bt->dim() == 1) { + auto c0 = bb_.add(make_constant(0, index_ty, in.loc())); + auto c_shape0 = instant_constant_fold_add(bb_, make_size(&in.B(), 0, index_ty, in.loc())); + bb_.foreach_loop( + index_ty, array_view{c0.get()}, array_view{c_shape0.get()}, + [&](region_builder &bb, auto loop_vars) { + auto K = + bb.add(make_size(&in.A(), in.tA() == transpose::T ? 0 : 1, index_ty, in.loc())); + auto c_init = bb.add(make_constant_zero(bt->element_data_ty())); + auto acc = bb.for_loop( + index_ty, c0, K, {}, {c_init}, {bt->element_data_ty()}, + [&](region_builder &bb, array_view args) { + auto index_list = std::array{loop_vars[0], args[0]}; + if (in.tA() == transpose::T) { + std::swap(index_list[0], index_list[1]); + } + auto a = + bb.add(make_load(&in.A(), index_list, at->element_data_ty(), in.loc())); + auto sum = mixed_precision_arithmetic(bb, bt->element_ty(), arithmetic::add, + args[1], a, in.loc()); + bb.add(make_yield({sum}, in.loc())); + }); + blas_update(bb, in.atomic(), &in.alpha(), acc[0], &in.beta(), &in.B(), + {loop_vars[0]}, in.loc()); + }, + in.loc()); + } +} + +lower_linalg_pass::lower_linalg_pass(::tinytc_core_info const *info) : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void lower_linalg_pass::run_on_function(function_node &fn) { + auto [core_cfg, tiling] = get_core_config_and_tiling(fn, info_); + + walk(fn, [&](region_node ®) { + auto it = reg.begin(); + while (it != reg.end()) { + if (isa(*it) || isa(*it)) { + auto gen = linalg_generator{tiling, core_cfg, reg, it.get()}; + visit(gen, *it); + it = reg.insts().erase(gen.insertion_point()); + } else { + ++it; + } + } + }); +} + +} // namespace tinytc diff --git a/src/pass/lower_linalg.hpp b/src/pass/lower_linalg.hpp new file mode 100644 index 00000000..e16b92f5 --- /dev/null +++ b/src/pass/lower_linalg.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_LINALG_20240801_HPP +#define LOWER_LINALG_20240801_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class lower_linalg_pass { + public: + lower_linalg_pass(::tinytc_core_info const *info); + + void run_on_function(::tinytc_func &fn); + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // LOWER_LINALG_20240801_HPP diff --git a/src/pass/slot_tracker.cpp b/src/pass/slot_tracker.cpp new file mode 100644 index 00000000..e7612cc2 --- /dev/null +++ b/src/pass/slot_tracker.cpp @@ -0,0 +1,44 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/slot_tracker.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "support/walk.hpp" +#include "util/iterator.hpp" + +#include +#include +#include + +namespace tinytc { + +void slot_tracker::set_slot(value_node const &v) { + if (!v.has_name()) { + slot_map_[&v] = slot_++; + } +} + +void slot_tracker::run_on_function(function_node const &fn) { + slot_ = 0; + for (auto const &arg : fn.params()) { + set_slot(arg); + } + walk(fn, [this](inst_node const &i) { + for (auto const ® : i.child_regions()) { + for (auto const &p : reg.params()) { + set_slot(p); + } + } + for (auto const &result : i.results()) { + set_slot(result); + } + }); +} + +auto slot_tracker::get_slot(value_node const &v) -> std::int64_t { + auto it = slot_map_.find(&v); + return it != slot_map_.end() ? it->second : -1; +} + +} // namespace tinytc diff --git a/src/visitor/slot_tracker.hpp b/src/pass/slot_tracker.hpp similarity index 55% rename from src/visitor/slot_tracker.hpp rename to src/pass/slot_tracker.hpp index c92400b3..5a572537 100644 --- a/src/visitor/slot_tracker.hpp +++ b/src/pass/slot_tracker.hpp @@ -5,9 +5,6 @@ #define SLOT_TRACKER_20240418_HPP #include "node/function_node.hpp" -#include "node/inst_node.hpp" -#include "node/program_node.hpp" -#include "node/region_node.hpp" #include "node/value_node.hpp" #include @@ -17,20 +14,7 @@ namespace tinytc { class slot_tracker { public: - /* Stmt nodes */ - void operator()(inst_node const &in); - void operator()(loop_inst const &p); - void operator()(if_inst const &in); - - /* Region nodes */ - void operator()(rgn const &b); - - /* Func nodes */ - void operator()(prototype const &); - void operator()(function const &fn); - - /* Program nodes */ - void operator()(program const &p); + void run_on_function(function_node const &fn); auto get_slot(value_node const &v) -> std::int64_t; diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp new file mode 100644 index 00000000..59ff53f7 --- /dev/null +++ b/src/pass/stack.cpp @@ -0,0 +1,79 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/stack.hpp" +#include "error.hpp" +#include "node/attr_node.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "support/walk.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +void set_stack_ptr_pass::run_on_function(function_node &fn) { + struct allocation { + value_node *value; + std::int64_t start, stop; + }; + std::list allocs; + + walk(fn, [&allocs](inst_node &i) { + visit(overloaded{ + [&allocs](alloca_inst &a) { + auto t = dyn_cast(a.result()->ty()); + if (t == nullptr) { + throw compilation_error(a.loc(), status::ir_expected_memref); + } + const auto alignment = [&]() -> std::int32_t { + if (auto aa = get_attr(a.attr(), "alignment"); aa) { + auto val = dyn_cast_or_throw(aa, [&] { + return status::ir_expected_integer_attribute; + })->value(); + return val; + } + return t->element_alignment(); + }(); + auto size = t->size_in_bytes(); + std::int64_t stack_ptr = 0; + auto it = allocs.begin(); + for (; it != allocs.end(); ++it) { + if (it->start - stack_ptr >= size) { + break; + } + stack_ptr = (1 + (it->stop - 1) / alignment) * alignment; + } + allocs.insert(it, allocation{a.result(), stack_ptr, stack_ptr + size}); + a.stack_ptr(stack_ptr); + }, + [&allocs](lifetime_stop_inst &s) { + int num = 0; + auto &v = s.object(); + for (auto it = allocs.begin(); it != allocs.end();) { + if (it->value == &v) { + it = allocs.erase(it); + ++num; + } else { + ++it; + } + } + if (num != 1) { + throw compilation_error( + s.loc(), status::internal_compiler_error, + "Incorrect lifetime_stop: value not found in list of allocations"); + } + }, + [](inst_node &) {}}, + i); + }); +} + +} // namespace tinytc diff --git a/src/pass/stack.hpp b/src/pass/stack.hpp new file mode 100644 index 00000000..ef10177f --- /dev/null +++ b/src/pass/stack.hpp @@ -0,0 +1,18 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef STACK_20230413_HPP +#define STACK_20230413_HPP + +#include "node/function_node.hpp" + +namespace tinytc { + +class set_stack_ptr_pass { + public: + void run_on_function(function_node &fn); +}; + +} // namespace tinytc + +#endif // STACK_20230413_HPP diff --git a/src/pass/work_group_size.cpp b/src/pass/work_group_size.cpp new file mode 100644 index 00000000..fa2d35e5 --- /dev/null +++ b/src/pass/work_group_size.cpp @@ -0,0 +1,124 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/work_group_size.hpp" +#include "codegen_tools.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "node/attr_node.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +auto get_shapes(function_node &fn) -> std::vector { + auto shape_set = std::unordered_set{}; + + walk(fn, [&shape_set](inst_node &i) { + visit(overloaded{[&](blas_a2_inst &in) { + auto aty = get_memref_type(in.A())->element_ty(); + auto b = get_memref_type(in.B()); + if (b->dim() == 1) { + shape_set.insert({aty, aty, b->element_ty(), {b->shape(0), 0}}); + } else if (b->dim() >= 2) { + shape_set.insert( + {aty, aty, b->element_ty(), {b->shape(0), b->shape(1)}}); + } + }, + [&](blas_a3_inst &in) { + auto aty = get_memref_type(in.A())->element_ty(); + auto bty = get_memref_type(in.B())->element_ty(); + auto c = get_memref_type(in.C()); + if (c->dim() == 1) { + shape_set.insert({aty, bty, c->element_ty(), {c->shape(0), 0}}); + } else if (c->dim() >= 2) { + shape_set.insert({aty, + bty, + c->element_ty(), + {c->shape(0), c->shape(1)}, + isa(in)}); + } + }, + [](inst_node &) {}}, + i); + }); + + return std::vector(shape_set.begin(), shape_set.end()); +} + +work_group_size_pass::work_group_size_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void work_group_size_pass::run_on_function(function_node &fn) { + auto sgs_attr = get_attr(fn.attr(), "subgroup_size"); + auto wgs_attr = get_attr(fn.attr(), "work_group_size"); + + if (sgs_attr && wgs_attr) { + return; + } + + const auto shapes = get_shapes(fn); + + auto ctx = compiler_context{fn.ty()->context(), true}; + const auto subgroup_size = [&] { + if (!sgs_attr) { + auto sgs = suggest_subgroup_size(shapes, *info_); + sgs_attr = get_integer_attr(ctx, sgs); + return sgs; + } else { + return fn.subgroup_size(); + } + }(); + + core_config cfg = {}; + try { + cfg = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + + const auto work_group_size = [&] { + if (!wgs_attr) { + auto tiling = suggest_local_tiling(shapes, cfg); + auto wgs = std::array{tiling[0] * subgroup_size, tiling[1]}; + wgs_attr = + get_array_attr(ctx, {get_integer_attr(ctx, wgs[0]), get_integer_attr(ctx, wgs[1])}); + return wgs; + } else { + return fn.work_group_size(); + } + }(); + + if (work_group_size[0] % subgroup_size != 0) { + throw compilation_error(fn.loc(), status::unsupported_work_group_size, + "First work-group size mode must be divisible by subgroup size"); + } + if (work_group_size[0] * work_group_size[1] > cfg.max_work_group_size) { + throw compilation_error(fn.loc(), status::unsupported_work_group_size); + } + + fn.attr(get_dictionary_attr_with_sorted( + ctx, {named_attr{get_string_attr(ctx, "subgroup_size"), sgs_attr}, + named_attr{get_string_attr(ctx, "work_group_size"), wgs_attr}})); +} + +} // namespace tinytc diff --git a/src/pass/work_group_size.hpp b/src/pass/work_group_size.hpp new file mode 100644 index 00000000..571aebd5 --- /dev/null +++ b/src/pass/work_group_size.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef WORK_GROUP_SIZE_20240311_HPP +#define WORK_GROUP_SIZE_20240311_HPP + +#include "node/function_node.hpp" +#include "tinytc/types.h" + +namespace tinytc { + +class work_group_size_pass { + public: + work_group_size_pass(tinytc_core_info const *info); + + void run_on_function(function_node &fn); + + private: + tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // WORK_GROUP_SIZE_20240311_HPP diff --git a/src/passes.cpp b/src/passes.cpp deleted file mode 100644 index ef27bad0..00000000 --- a/src/passes.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "passes.hpp" -#include "device_info.hpp" -#include "kernel_metadata.hpp" -#include "node/data_type_node.hpp" -#include "node/function_node.hpp" -#include "node/program_node.hpp" -#include "visitor/check_ir.hpp" -#include "visitor/dump_ir.hpp" -#include "visitor/equal.hpp" -#include "visitor/insert_barrier.hpp" -#include "visitor/lifetime_analysis.hpp" -#include "visitor/metadata.hpp" -#include "visitor/opencl_ast.hpp" -#include "visitor/stack.hpp" -#include "visitor/work_group_size.hpp" - -#include - -using clir::visit; - -namespace tinytc { - -void check_ir(tinytc_prog const &p) { return visit(ir_checker{}, p); } - -void dump_ir(std::ostream &os, tinytc_func const &f) { visit(ir_dumper{os}, f); } -void dump_ir(std::ostream &os, tinytc_prog const &p) { visit(ir_dumper{os}, p); } - -clir::prog generate_opencl_ast(tinytc_prog const &p, ::tinytc_core_info const &info) { - return visit(opencl_ast{&info}, p); -} - -auto get_metadata(tinytc_prog const &p) -> std::unordered_map { - auto v = metadata{}; - visit(v, p); - return v.get_result(); -} - -void insert_barriers(tinytc_func &f) { visit(insert_barrier{}, f); } -void insert_barriers(tinytc_prog &p) { visit(insert_barrier{}, p); } - -void insert_lifetime_stop_inst(tinytc_func &f) { visit(lifetime_inserter{}, f); } -void insert_lifetime_stop_inst(tinytc_prog &p) { visit(lifetime_inserter{}, p); } - -bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b) { return visit(equal{}, a, b); } - -void set_stack_ptrs(tinytc_func &f) { visit(stack_ptr{}, f); } -void set_stack_ptrs(tinytc_prog &p) { visit(stack_ptr{}, p); } - -void set_work_group_size(tinytc_func &f, ::tinytc_core_info const &info) { - visit(work_group_size{&info}, f); -} -void set_work_group_size(tinytc_prog &p, ::tinytc_core_info const &info) { - visit(work_group_size{&info}, p); -} - -} // namespace tinytc - diff --git a/src/passes.def b/src/passes.def new file mode 100644 index 00000000..b1317da8 --- /dev/null +++ b/src/passes.def @@ -0,0 +1,17 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +FUNCTION_PASS("check-ir", check_ir_pass{}) +FUNCTION_PASS("constant-propagation", constant_propagation_pass{}, tinytc::optflag::unsafe_fp_math) +FUNCTION_PASS("dead-code-elimination", dead_code_elimination_pass{}) +FUNCTION_PASS("dump-control-flow-graph", dump_cfg_pass{std::cout}) +FUNCTION_PASS("dump-def-use", dump_def_use_pass{std::cout}) +FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) +FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) +FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) +FUNCTION_PASS("set-stack-ptr", set_stack_ptr_pass{}) +FUNCTION_PASS_WITH_INFO("dump-gcd", [](tinytc_core_info const* info) { return dump_gcd_pass(std::cout, info); }) +FUNCTION_PASS_WITH_INFO("lower-coopmatrix", [](tinytc_core_info const* info) { return lower_coopmatrix_pass{info}; }) +FUNCTION_PASS_WITH_INFO("lower-foreach", [](tinytc_core_info const* info) { return lower_foreach_pass{info}; }) +FUNCTION_PASS_WITH_INFO("lower-linalg", [](tinytc_core_info const* info) { return lower_linalg_pass{info}; }) +FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass{info}; }) diff --git a/src/passes.hpp b/src/passes.hpp index 19515e18..8a8223f1 100644 --- a/src/passes.hpp +++ b/src/passes.hpp @@ -4,44 +4,22 @@ #ifndef PASSES_20240314_HPP #define PASSES_20240314_HPP -#include "kernel_metadata.hpp" +#include "node/program_node.hpp" #include "tinytc/types.h" -#include -#include -#include -#include - namespace tinytc { -//! Check whether some IR rules are respected -void check_ir(tinytc_prog const &p); -//! Dump IR to ostream -void dump_ir(std::ostream &os, tinytc_func const &f); -//! Dump IR to ostream -void dump_ir(std::ostream &os, tinytc_prog const &p); -//! Generate OpenCL AST -clir::prog generate_opencl_ast(tinytc_prog const &p, tinytc_core_info const &info); -//! Get kernel metadata -auto get_metadata(tinytc_prog const &p) -> std::unordered_map; -//! Insert barriers where necessary -void insert_barriers(tinytc_func &f); -//! Insert barriers where necessary -void insert_barriers(tinytc_prog &p); -//! Insert lifetime stop instructions for set_stack_ptrs pass -void insert_lifetime_stop_inst(tinytc_func &f); -//! Insert lifetime stop instructions for set_stack_ptrs pass -void insert_lifetime_stop_inst(tinytc_prog &p); -//! Check whether data types a and b are equal -bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b); -//! Manage temporary memory requested by alloca -void set_stack_ptrs(tinytc_func &f); -//! Manage temporary memory requested by alloca -void set_stack_ptrs(tinytc_prog &p); -//! Choose work group and subgroup size heuristically if not given explicitly -void set_work_group_size(tinytc_func &f, tinytc_core_info const &info); -//! Choose work group and subgroup size heuristically if not given explicitly -void set_work_group_size(tinytc_prog &p, tinytc_core_info const &info); +template void run_function_pass(FunctionPass &&pass, tinytc_prog &p) { + for (auto &fun : p) { + pass.run_on_function(fun); + } +} + +template void run_function_pass(FunctionPass &&pass, tinytc_prog const &p) { + for (auto const &fun : p) { + pass.run_on_function(fun); + } +} } // namespace tinytc diff --git a/src/precision_helper.cpp b/src/precision_helper.cpp deleted file mode 100644 index e5498a03..00000000 --- a/src/precision_helper.cpp +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "precision_helper.hpp" -#include "scalar_type.hpp" -#include "tinytc/tinytc.hpp" - -#include -#include -#include -#include - -#include - -using clir::address_space; -using clir::as_char; -using clir::as_double; -using clir::as_float; -using clir::as_int; -using clir::as_long; -using clir::as_short; -using clir::as_uchar; -using clir::as_uint; -using clir::as_ulong; -using clir::builtin_type; -using clir::cast; -using clir::expr; -using clir::get_sub_group_local_id; -using clir::intel_sub_group_block_read_ui; -using clir::intel_sub_group_block_read_ul; -using clir::intel_sub_group_block_read_us; -using clir::intel_sub_group_block_write_ui; -using clir::intel_sub_group_block_write_ul; -using clir::intel_sub_group_block_write_us; -using clir::pointer_to; - -namespace tinytc { - -precision_helper::precision_helper(scalar_type ty) : ty_(ty) {} -builtin_type precision_helper::base_type() const { return to_clir_builtin_ty(ty_); } -builtin_type precision_helper::block_rw_base_type() const { - auto bt = base_type(); - switch (bt) { - case builtin_type::short_t: - return builtin_type::ushort_t; - case builtin_type::int_t: - case builtin_type::float_t: - return builtin_type::uint_t; - case builtin_type::long_t: - case builtin_type::double_t: - return builtin_type::ulong_t; - default: - break; - } - return bt; -} -expr precision_helper::as_type(builtin_type ty, expr e) const { - switch (ty) { - case builtin_type::char_t: - return as_char(std::move(e)); - case builtin_type::uchar_t: - return as_uchar(std::move(e)); - case builtin_type::short_t: - return as_short(std::move(e)); - case builtin_type::ushort_t: - return as_ushort(std::move(e)); - case builtin_type::int_t: - return as_int(std::move(e)); - case builtin_type::uint_t: - return as_uint(std::move(e)); - case builtin_type::long_t: - return as_long(std::move(e)); - case builtin_type::ulong_t: - return as_ulong(std::move(e)); - case builtin_type::float_t: - return as_float(std::move(e)); - case builtin_type::double_t: - return as_double(std::move(e)); - default: - break; - } - return e; -} -short precision_helper::bits() const { return size(ty_) * 8; } -clir::data_type precision_helper::type(address_space as) const { - return clir::data_type(base_type(), as); -} -clir::data_type precision_helper::type(short size, address_space as) const { - return clir::data_type(base_type(), size, as); -} -// TODO: Think of something for integer constants -expr precision_helper::constant(double value) const { return expr(value, bits()); } -expr precision_helper::zero() const { return constant(0.0); } - -expr precision_helper::sub_group_block_read(expr address, address_space as) const { - auto const make_read = [](builtin_type bt, expr address) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_read_us(std::move(address)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_read_ui(std::move(address)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_read_ul(std::move(address)); - default: - break; - } - return address[get_sub_group_local_id()]; - }; - auto bt = block_rw_base_type(); - address = cast(pointer_to(clir::data_type(bt, as)), std::move(address)); - auto inst = make_read(bt, std::move(address)); - if (bt != base_type()) { - return as_type(base_type(), std::move(inst)); - } - return inst; -} -expr precision_helper::sub_group_block_write(expr address, expr data, address_space as) const { - auto const make_write = [](builtin_type bt, expr address, expr data) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_write_us(std::move(address), std::move(data)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_write_ui(std::move(address), std::move(data)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_write_ul(std::move(address), std::move(data)); - default: - break; - } - return address[get_sub_group_local_id()] = std::move(data); - }; - auto bt = block_rw_base_type(); - address = cast(pointer_to(clir::data_type(bt, as)), std::move(address)); - if (bt != base_type()) { - data = as_type(bt, std::move(data)); - } - return make_write(bt, std::move(address), std::move(data)); -} - -} // namespace tinytc diff --git a/src/precision_helper.hpp b/src/precision_helper.hpp deleted file mode 100644 index 32445697..00000000 --- a/src/precision_helper.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef PRECISION_HELPER_20230214_HPP -#define PRECISION_HELPER_20230214_HPP - -#include "tinytc/types.hpp" - -#include "clir/builtin_type.hpp" -#include "clir/data_type.hpp" -#include "clir/expr.hpp" - -namespace tinytc { - -class precision_helper { - public: - precision_helper(scalar_type ty); - clir::builtin_type base_type() const; - clir::builtin_type block_rw_base_type() const; - clir::expr as_type(clir::builtin_type ty, clir::expr e) const; - short bits() const; - clir::data_type type(clir::address_space as = clir::address_space::generic_t) const; - clir::data_type type(short size, clir::address_space as = clir::address_space::generic_t) const; - clir::expr constant(double value) const; - clir::expr zero() const; - clir::expr sub_group_block_read(clir::expr address, - clir::address_space as = clir::address_space::generic_t) const; - clir::expr sub_group_block_write(clir::expr address, clir::expr data, - clir::address_space as = clir::address_space::generic_t) const; - - private: - scalar_type ty_; -}; - -} // namespace tinytc - -#endif // PRECISION_HELPER_20230214_HPP diff --git a/src/prog.cpp b/src/prog.cpp index 7ee2c714..accbf565 100644 --- a/src/prog.cpp +++ b/src/prog.cpp @@ -4,13 +4,13 @@ #include "error.hpp" #include "location.hpp" #include "node/program_node.hpp" +#include "pass/dump_ir.hpp" #include "passes.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include #include #include #include @@ -18,27 +18,29 @@ #include #include #include -#include using namespace tinytc; extern "C" { -tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, uint32_t fun_list_size, - tinytc_func_t *fun_list, const tinytc_location_t *loc) { - if (prg == nullptr || (fun_list_size > 0 && fun_list == nullptr)) { +tinytc_status_t tinytc_prog_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc) { + if (prg == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto fun_vec = std::vector(); - fun_vec.reserve(fun_list_size); - for (uint32_t i = 0; i < fun_list_size; ++i) { - fun_vec.emplace_back(func(fun_list[i], true)); - } - *prg = std::make_unique(std::move(fun_vec), get_optional(loc)).release(); + *prg = std::make_unique(compiler_context{ctx, true}, get_optional(loc)) + .release(); }); } +tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun) { + if (prg == nullptr || fun == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { prg->push_back(tinytc::func{fun}); }); +} + tinytc_status_t tinytc_prog_release(tinytc_prog_t obj) { if (obj == nullptr) { return tinytc_status_invalid_arguments; @@ -62,7 +64,15 @@ tinytc_status_t tinytc_prog_dump(const_tinytc_prog_t prg) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { dump_ir(std::cerr, *prg); }); + return exception_to_status_code([&] { run_function_pass(dump_ir_pass{std::cerr}, *prg); }); +} + +tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, + tinytc_compiler_context_t *ctx) { + if (prg == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ctx = prg->share_context().release(); }); } tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *filename) { @@ -74,7 +84,7 @@ tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *f if (!stream.good()) { throw status::file_io_error; } - dump_ir(stream, *prg); + run_function_pass(dump_ir_pass{stream}, *prg); }); } @@ -85,7 +95,7 @@ tinytc_status_t tinytc_prog_print_to_string(const_tinytc_prog_t prg, char **str) return exception_to_status_code([&] { auto const text = [&] { auto oss = std::ostringstream{}; - dump_ir(oss, *prg); + run_function_pass(dump_ir_pass{oss}, *prg); return std::move(oss).str(); }(); auto const length = text.size() + 1; // Need to include terminating null character diff --git a/src/recipe.cpp b/src/recipe.cpp index 75c9f18b..f4d730ee 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -7,6 +7,7 @@ #include "tinytc/types.hpp" #include +#include #include namespace tinytc { @@ -28,12 +29,18 @@ auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_va return is_argument_zero(arg_size, arg_value); case scalar_type::i64: return is_argument_zero(arg_size, arg_value); + case scalar_type::bf16: + return is_argument_zero(arg_size, arg_value); + case scalar_type::f16: + return is_argument_zero(arg_size, arg_value); case scalar_type::f32: return is_argument_zero(arg_size, arg_value); case scalar_type::f64: return is_argument_zero(arg_size, arg_value); - case scalar_type::i1: - break; + case scalar_type::c32: + return is_argument_zero>(arg_size, arg_value); + case scalar_type::c64: + return is_argument_zero>(arg_size, arg_value); }; throw status::invalid_arguments; } @@ -50,12 +57,12 @@ tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recipe, tinytc_prog [&] { *prg = tinytc::prog(recipe->get_program()).release(); }); } -tinytc_status_t tinytc_recipe_get_source(const_tinytc_recipe_t recipe, tinytc_source_t *src) { - if (recipe == nullptr || src == nullptr) { +tinytc_status_t tinytc_recipe_get_binary(const_tinytc_recipe_t recipe, tinytc_binary_t *bin) { + if (recipe == nullptr || bin == nullptr) { return tinytc_status_invalid_arguments; } return tinytc::exception_to_status_code( - [&] { *src = tinytc::source(recipe->get_source()).release(); }); + [&] { *bin = tinytc::binary(recipe->get_binary()).release(); }); } tinytc_status_t tinytc_recipe_release(tinytc_recipe_t obj) { diff --git a/src/recipe.hpp b/src/recipe.hpp index 879245a0..5b4e5e10 100644 --- a/src/recipe.hpp +++ b/src/recipe.hpp @@ -7,31 +7,31 @@ #include "reference_counted.hpp" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include "tinytc/types.hpp" #include #include #include namespace tinytc { +enum class scalar_type; auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_value) -> bool; } // namespace tinytc struct tinytc_recipe : tinytc::reference_counted { public: - inline tinytc_recipe(tinytc::prog prg, tinytc::source src) - : prg_(std::move(prg)), src_(std::move(src)) {} + inline tinytc_recipe(tinytc::prog prg, tinytc::binary bin) + : prg_(std::move(prg)), bin_(std::move(bin)) {} virtual ~tinytc_recipe() = default; inline auto get_program() const -> tinytc::prog const & { return prg_; } - inline auto get_source() const -> tinytc::source const & { return src_; } + inline auto get_binary() const -> tinytc::binary const & { return bin_; } virtual auto num_kernels() const -> int = 0; virtual auto kernel_name(int kernel_num) const -> char const * = 0; private: tinytc::prog prg_; - tinytc::source src_; + tinytc::binary bin_; }; struct tinytc_recipe_handler : tinytc::reference_counted { diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 7ace3283..b5825f09 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -3,21 +3,19 @@ #include "small_gemm_batched.hpp" #include "error.hpp" -#include "parser.hpp" #include "recipe.hpp" -#include "reference_counted.hpp" #include "tinytc/tinytc.h" #include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "util.hpp" +#include "util/casting.hpp" +#include #include #include #include #include #include -#include namespace tinytc { @@ -32,8 +30,8 @@ auto small_gemm_batched_kernel_name(small_gemm_batched_kernel k) -> char const * } throw status::invalid_arguments; } -small_gemm_batched_recipe::small_gemm_batched_recipe(prog prg, source src, scalar_type ty) - : ::tinytc_recipe(std::move(prg), std::move(src)), ty_(ty) {} +small_gemm_batched_recipe::small_gemm_batched_recipe(prog prg, binary bin, scalar_type ty) + : ::tinytc_recipe(std::move(prg), std::move(bin)), ty_(ty) {} auto small_gemm_batched_recipe::num_kernels() const -> int { return static_cast(small_gemm_batched_kernel::num_kernels); } @@ -46,12 +44,11 @@ auto small_gemm_batched_recipe::kernel_name(int kernel_num) const -> char const using namespace tinytc; extern "C" { -tinytc_status_t -tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_core_info_t info, - tinytc_scalar_type_t ty, tinytc_transpose_t tA, - tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, - int64_t ldA, int64_t strideA, int64_t ldB, int64_t strideB, - int64_t ldC, int64_t strideC, tinytc_source_context_t ctx) { +tinytc_status_t tinytc_recipe_small_gemm_batched_create( + tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, + tinytc_transpose_t tA, tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, int64_t ldA, + int64_t strideA, int64_t ldB, int64_t strideB, int64_t ldC, int64_t strideC, + tinytc_compiler_context_t ctx) { if (recipe == nullptr || info == nullptr || M == TINYTC_DYNAMIC || N == TINYTC_DYNAMIC || K == TINYTC_DYNAMIC || ldA == TINYTC_DYNAMIC || strideA == TINYTC_DYNAMIC || ldB == TINYTC_DYNAMIC || strideB == TINYTC_DYNAMIC || ldC == TINYTC_DYNAMIC || @@ -59,11 +56,10 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co return tinytc_status_invalid_arguments; } + auto ctx_ = ctx ? compiler_context{ctx, true} : make_compiler_context(); std::int32_t source_id = 0; - if (ctx) { - TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "small gemm batched recipe", "", &source_id)); - } + TINYTC_CHECK_STATUS(tinytc_compiler_context_add_source( + ctx_.get(), "recipe/small_gemm_batched.cpp", "", &source_id)); auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { auto l = location{}; @@ -74,61 +70,76 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co ++l.end.column; return l; }; - - auto const selA = [&](std::int64_t N1, std::int64_t N2) { - return tA == tinytc_transpose_T ? N2 : N1; - }; - auto const selB = [&](std::int64_t N1, std::int64_t N2) { - return tB == tinytc_transpose_T ? N2 : N1; + auto const make_static_sizes = [](transpose t, std::int64_t A, std::int64_t B) { + auto s = std::array{A, B, 0}; + if (t == transpose::T) { + std::swap(s[0], s[1]); + } + return s; }; + return exception_to_status_code( [&] { - auto const ty_ = enum_cast(ty); + auto const ty_ = get_scalar(ctx_, enum_cast(ty)); + auto const index_ty = get_scalar(ctx_, scalar_type::index); auto const tA_ = enum_cast(tA); auto const tB_ = enum_cast(tB); - auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = fb.argument(make_memref(ty_, {selA(M, K), selA(K, M), dynamic}, - {1, ldA, strideA}, my_loc()), - "A", my_loc()); - auto B = fb.argument(make_memref(ty_, {selB(K, N), selB(N, K), dynamic}, - {1, ldB, strideB}, my_loc()), - "B", my_loc()); - auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = fb.argument(make_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, my_loc()), - "C", my_loc()); - - auto beta = is_beta_nonzero ? std::move(beta_arg) : make_imm(0.0, ty_, my_loc()); - fb.body( - [&](region_builder &bb) { - auto gid = bb.add(make_group_id(my_loc())); - auto offsets = std::vector{make_index(0, my_loc()), - make_index(0, my_loc()), gid}; - auto size = std::vector{make_dynamic(my_loc()), - make_dynamic(my_loc()), value{}}; - auto a = bb.add(make_subview(A, offsets, size, my_loc())); - auto b = bb.add(make_subview(B, offsets, size, my_loc())); - auto c = bb.add(make_subview(C, offsets, size, my_loc())); - bb.add(make_gemm(tA_, tB_, false, alpha, std::move(a), std::move(b), beta, - std::move(c), my_loc())); - }, - my_loc()); + auto const kernel = [&](char const *name, bool is_beta_nonzero) { + auto const static_offsets = std::array{0, 0, dynamic}; + auto const A_static_sizes = make_static_sizes(tA_, M, K); + auto const B_static_sizes = make_static_sizes(tB_, K, N); + auto const C_static_sizes = make_static_sizes(transpose::N, M, N); + + auto A_ty = get_memref(ty_, {A_static_sizes[0], A_static_sizes[1], dynamic}, + {1, ldA, strideA}, address_space::global, my_loc()); + auto B_ty = get_memref(ty_, {B_static_sizes[0], B_static_sizes[1], dynamic}, + {1, ldB, strideB}, address_space::global, my_loc()); + auto C_ty = get_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, + address_space::global, my_loc()); + auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, get_void(ctx_), my_loc()); + auto fn_body = f.get_body(); + auto params = std::array{}; + fn_body.get_parameters(params); + params[0].set_name("alpha"); + params[1].set_name("A"); + params[2].set_name("B"); + params[3].set_name("beta"); + params[4].set_name("C"); + + auto bb = region_builder{fn_body}; + + auto gid = bb.add(make_builtin(builtin::group_id_x, index_ty, my_loc())); + auto at = get_memref(ty_, array_view(A_static_sizes.data(), 2), {1, ldA}, + address_space::global, my_loc()); + auto bt = get_memref(ty_, array_view(B_static_sizes.data(), 2), {1, ldB}, + address_space::global, my_loc()); + auto ct = get_memref(ty_, array_view(C_static_sizes.data(), 2), {1, ldC}, + address_space::global, my_loc()); + auto a = bb.add(make_subview(params[1], static_offsets, A_static_sizes, + array_view{gid}, {}, at, my_loc())); + auto b = bb.add(make_subview(params[2], static_offsets, B_static_sizes, + array_view{gid}, {}, bt, my_loc())); + auto c = bb.add(make_subview(params[4], static_offsets, C_static_sizes, + array_view{gid}, {}, ct, my_loc())); + auto beta = is_beta_nonzero ? params[3] : bb.add(make_constant_zero(ty_, my_loc())); + bb.add(make_gemm(tA_, tB_, false, params[0], std::move(a), std::move(b), beta, + std::move(c), my_loc())); + + return f; }; - auto pb = program_builder{}; - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); - auto p = pb.get_product(my_loc()); - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); - *recipe = std::make_unique(std::move(p), source(src), ty_) + auto p = make_prog(ctx_, my_loc()); + p.add_function( + kernel(small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), true)); + p.add_function(kernel( + small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), false)); + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv_and_assemble(&bin, p.get(), info)); + *recipe = std::make_unique(std::move(p), binary(bin), + enum_cast(ty)) .release(); }, - ctx); + ctx_.get()); } tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( diff --git a/src/recipe/small_gemm_batched.hpp b/src/recipe/small_gemm_batched.hpp index 0ffdc2da..5bd467cf 100644 --- a/src/recipe/small_gemm_batched.hpp +++ b/src/recipe/small_gemm_batched.hpp @@ -5,17 +5,19 @@ #define SMALL_GEMM_BATCHED_20240419_HPP #include "../recipe.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" namespace tinytc { +class binary; +class prog; +enum class scalar_type; + enum class small_gemm_batched_kernel : int { gemm = 0, gemm_beta0 = 1, num_kernels = 2 }; auto small_gemm_batched_kernel_name(small_gemm_batched_kernel k) -> char const *; struct small_gemm_batched_recipe : ::tinytc_recipe { public: - small_gemm_batched_recipe(prog prg, source src, scalar_type ty); + small_gemm_batched_recipe(prog prg, binary bin, scalar_type ty); auto num_kernels() const -> int override; auto kernel_name(int kernel_num) const -> char const * override; diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 90b2505f..cfd80120 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -4,13 +4,13 @@ #include "tall_and_skinny.hpp" #include "device_info.hpp" #include "error.hpp" -#include "parser.hpp" #include "recipe.hpp" -#include "reference_counted.hpp" #include "tiling.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" -#include "util.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" #include #include @@ -19,7 +19,6 @@ #include #include #include -#include namespace tinytc { @@ -34,10 +33,10 @@ auto tall_and_skinny_kernel_name(tall_and_skinny_kernel k) -> char const * { } throw status::invalid_arguments; } -tall_and_skinny_recipe::tall_and_skinny_recipe(prog prg, source src, scalar_type ty, std::int64_t M, +tall_and_skinny_recipe::tall_and_skinny_recipe(prog prg, binary bin, scalar_type ty, std::int64_t M, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, std::int32_t M_block_size) - : ::tinytc_recipe(std::move(prg), std::move(src)), ty_(ty), M_dyn_(is_dynamic_value(M)), + : ::tinytc_recipe(std::move(prg), std::move(bin)), ty_(ty), M_dyn_(is_dynamic_value(M)), ldA_dyn_(is_dynamic_value(ldA)), ldB_dyn_(is_dynamic_value(ldB)), ldC_dyn_(is_dynamic_value(ldC)), M_block_size_(M_block_size) {} auto tall_and_skinny_recipe::num_kernels() const -> int { @@ -56,25 +55,24 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create(tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t N, int64_t K, int32_t M_block_size, - tinytc_source_context_t ctx) { - return tinytc_recipe_tall_and_skinny_create_specialized(recipe, info, ty, TINYTC_DYNAMIC, N, K, - TINYTC_DYNAMIC, TINYTC_DYNAMIC, - TINYTC_DYNAMIC, M_block_size, ctx); + tinytc_compiler_context_t ctx) { + return tinytc_recipe_tall_and_skinny_create_specialized( + recipe, info, ty, TINYTC_DYNAMIC, N, K, TINYTC_DYNAMIC, TINYTC_DYNAMIC, TINYTC_DYNAMIC, 0, + 0, 0, M_block_size, ctx); } tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t M, - int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t M_block_size, - tinytc_source_context_t ctx) { + int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t alignA, int32_t alignB, + int32_t alignC, int32_t M_block_size, tinytc_compiler_context_t ctx) { if (recipe == nullptr || info == nullptr || N == TINYTC_DYNAMIC || K == TINYTC_DYNAMIC) { return tinytc_status_invalid_arguments; } + auto ctx_ = ctx ? compiler_context{ctx, true} : make_compiler_context(); std::int32_t source_id = 0; - if (ctx) { - TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "tall and skinny recipe", "", &source_id)); - } + TINYTC_CHECK_STATUS(tinytc_compiler_context_add_source(ctx_.get(), "recipe/tall_and_skinny.cpp", + "", &source_id)); auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { auto l = location{}; @@ -86,16 +84,22 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( return l; }; - if (M_block_size == 0u) { + if (M_block_size == 0) { TINYTC_CHECK_STATUS(tinytc_recipe_tall_and_skinny_suggest_block_size(info, &M_block_size)); } return exception_to_status_code( [&] { - auto const ty_ = enum_cast(ty); + auto const ty_ = get_scalar(ctx_, enum_cast(ty)); + auto const bool_ty = get_boolean(ctx_); + auto const index_ty = get_scalar(ctx_, scalar_type::index); - auto const shapes = std::vector{blas_shape{ty_, {M_block_size, N}}}; - auto [sgs, tiling] = suggest_subgroup_size_and_tiling(shapes, *info); + auto const bshape = blas_shape{enum_cast(ty), + enum_cast(ty), + enum_cast(ty), + {M_block_size, N}, + true}; + auto [sgs, tiling] = suggest_subgroup_size_and_tiling(array_view(bshape), *info); // We want to avoid working on too many columns in parallel as there is a high // chance to trash the cache due to the large pitch @@ -103,70 +107,111 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tiling[1] /= 2; } - auto const body = [&](region_builder &bb, value &alpha, value &A, value &B, value &beta, - value &C) { - auto const gemm = [&](region_builder &bb, std::vector const &offsets, - value const &block_size) { + auto const body = [&](region_builder &bb, value alpha, value A, value B, + bool is_beta_nonzero, value beta_arg, value C) { + auto c_M_block_size = bb.add(make_constant(M_block_size, index_ty, my_loc())); + auto gid = bb.add(make_builtin(builtin::group_id_x, index_ty, my_loc())); + auto m = bb.add( + make_arith(arithmetic::mul, gid, c_M_block_size, gid.get_type(), my_loc())); + auto beta = is_beta_nonzero ? beta_arg : bb.add(make_constant_zero(ty_, my_loc())); + + auto const static_offsets = std::array{dynamic, 0}; + auto const offsets = array_view(m); + + auto const static_gemm = [&](region_builder &bb) { + auto const A_static_sizes = std::array{M_block_size, K}; + auto const C_static_sizes = std::array{M_block_size, N}; + auto at = + get_memref(ty_, A_static_sizes, {1, ldA}, address_space::global, my_loc()); + auto ct = + get_memref(ty_, C_static_sizes, {1, ldC}, address_space::global, my_loc()); auto a = bb.add( - make_subview(A, offsets, {block_size, make_index(K, my_loc())}, my_loc())); + make_subview(A, static_offsets, A_static_sizes, offsets, {}, at, my_loc())); auto c = bb.add( - make_subview(C, offsets, {block_size, make_index(N, my_loc())}, my_loc())); + make_subview(C, static_offsets, C_static_sizes, offsets, {}, ct, my_loc())); + bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, + my_loc())); + }; + auto const dynamic_gemm = [&](region_builder &bb, value dyn_block_size) { + auto const A_static_sizes = std::array{dynamic, K}; + auto const C_static_sizes = std::array{dynamic, N}; + auto const sizes = array_view(dyn_block_size); + auto at = + get_memref(ty_, A_static_sizes, {1, ldA}, address_space::global, my_loc()); + auto ct = + get_memref(ty_, C_static_sizes, {1, ldC}, address_space::global, my_loc()); + auto a = bb.add(make_subview(A, static_offsets, A_static_sizes, offsets, sizes, + at, my_loc())); + auto c = bb.add(make_subview(C, static_offsets, C_static_sizes, offsets, sizes, + ct, my_loc())); bb.add(make_gemm(transpose::N, transpose::N, false, alpha, a, B, beta, c, my_loc())); }; - - auto const block_size_imm = make_index(M_block_size, my_loc()); - auto gid = bb.add(make_group_id(my_loc())); - auto m = bb.add( - make_arith(arithmetic::mul, gid, make_index(M_block_size, my_loc()), my_loc())); - auto const offsets = std::vector{m, make_index(0, my_loc())}; if (!is_dynamic_value(M) && M % M_block_size == 0) { - gemm(bb, offsets, block_size_imm); + static_gemm(bb); } else { - auto M_val = - is_dynamic_value(M) ? bb.add(make_size(C, 0, my_loc())) : make_index(M); - auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, my_loc())); - auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, - make_index(M_block_size, my_loc()), my_loc())); - auto const dynamic_imm = make_dynamic(my_loc()); + + auto M_val = bb.add(make_size(C, 0, index_ty, my_loc())); + auto M_val_sub_m = + bb.add(make_arith(arithmetic::sub, M_val, m, m.get_type(), my_loc())); + auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, c_M_block_size, + bool_ty, my_loc())); bb.ifelse( - cond, [&](region_builder &bb) { gemm(bb, offsets, dynamic_imm); }, - [&](region_builder &bb) { gemm(bb, offsets, block_size_imm); }, {}, - my_loc()); + cond, [&](region_builder &bb) { dynamic_gemm(bb, M_val_sub_m); }, + [&](region_builder &bb) { static_gemm(bb); }, {}, my_loc()); } }; - auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = fb.argument(make_memref(ty_, {M, K}, {1, ldA}, my_loc()), "A", my_loc()); - auto B = fb.argument(make_memref(ty_, {K, N}, {1, ldB}, my_loc()), "B", my_loc()); - auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = fb.argument(make_memref(ty_, {M, N}, {1, ldC}, my_loc()), "C", my_loc()); - fb.subgroup_size(sgs); + auto const kernel = [&](char const *name, bool is_beta_nonzero) { + auto A_ty = get_memref(ty_, {M, K}, {1, ldA}, address_space::global, my_loc()); + auto B_ty = get_memref(ty_, {K, N}, {1, ldB}, address_space::global, my_loc()); + auto C_ty = get_memref(ty_, {M, N}, {1, ldC}, address_space::global, my_loc()); + auto f = make_func(name, {ty_, A_ty, B_ty, ty_, C_ty}, get_void(ctx_), my_loc()); + + auto alignments = std::array, 3u>{ + {{1, alignA}, {2, alignB}, {4, alignC}}}; + auto align_attr = named_attr{get_string_attr(ctx_, "align"), nullptr}; + for (auto &[param_no, alignment] : alignments) { + if (alignment > 0) { + align_attr.attr = get_integer_attr(ctx_, alignment); + f.set_parameter_attr(param_no, + get_dictionary_attr_with_sorted(ctx_, align_attr)); + } + } + + auto fn_body = f.get_body(); + auto params = std::array{}; + fn_body.get_parameters(params); + params[0].set_name("alpha"); + params[1].set_name("A"); + params[2].set_name("B"); + params[3].set_name("beta"); + params[4].set_name("C"); auto const wgs = tiling.work_group_size(sgs); - fb.work_group_size(wgs[0], wgs[1]); + auto const wgs_attr = + named_attr{get_string_attr(ctx_, "work_group_size"), + get_array_attr(ctx_, {get_integer_attr(ctx_, wgs[0]), + get_integer_attr(ctx_, wgs[1])})}; + f.set_attr(get_dictionary_attr_with_sorted(ctx_, wgs_attr)); - auto beta = is_beta_nonzero ? beta_arg : make_imm(0.0, ty_, my_loc()); - fb.body([&](region_builder &bb) { body(bb, alpha, A, B, beta, C); }, my_loc()); + auto bb = region_builder{fn_body}; + body(bb, params[0], params[1], params[2], is_beta_nonzero, params[3], params[4]); + return f; }; - auto pb = program_builder{}; - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); - - auto p = pb.get_product(my_loc()); - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); - *recipe = std::make_unique(std::move(p), source(src), ty_, M, - ldA, ldB, ldC, M_block_size) + auto p = make_prog(ctx_, my_loc()); + p.add_function(kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), true)); + p.add_function( + kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), false)); + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv_and_assemble(&bin, p.get(), info)); + *recipe = std::make_unique(std::move(p), binary(bin), + enum_cast(ty), M, ldA, + ldB, ldC, M_block_size) .release(); }, - ctx); + ctx_.get()); } tinytc_status_t tinytc_recipe_tall_and_skinny_suggest_block_size(const_tinytc_core_info_t info, @@ -174,9 +219,12 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_suggest_block_size(const_tinytc_co if (info == nullptr || M_block_size == nullptr) { return tinytc_status_invalid_arguments; } - - return tinytc::exception_to_status_code( - [&] { *M_block_size = std::min(128, info->minmax_work_group_size()); }); + return tinytc::exception_to_status_code([&] { + if (info->minmax_work_group_size() <= 0) { + throw tinytc::status::invalid_core_info; + } + *M_block_size = std::min(128, info->minmax_work_group_size()); + }); } tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( diff --git a/src/recipe/tall_and_skinny.hpp b/src/recipe/tall_and_skinny.hpp index dd2aaca4..32712884 100644 --- a/src/recipe/tall_and_skinny.hpp +++ b/src/recipe/tall_and_skinny.hpp @@ -5,19 +5,21 @@ #define TALL_AND_SKINNY_20240422_HPP #include "../recipe.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" #include namespace tinytc { +class binary; +class prog; +enum class scalar_type; + enum class tall_and_skinny_kernel : int { gemm = 0, gemm_beta0 = 1, num_kernels = 2 }; auto tall_and_skinny_kernel_name(tall_and_skinny_kernel k) -> char const *; struct tall_and_skinny_recipe : ::tinytc_recipe { public: - tall_and_skinny_recipe(prog prg, source src, scalar_type ty, std::int64_t M, std::int64_t ldA, + tall_and_skinny_recipe(prog prg, binary bin, scalar_type ty, std::int64_t M, std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, std::int32_t M_block_size); auto num_kernels() const -> int override; auto kernel_name(int kernel_num) const -> char const * override; diff --git a/src/region.cpp b/src/region.cpp index ea9a507d..29cb4f6e 100644 --- a/src/region.cpp +++ b/src/region.cpp @@ -2,53 +2,104 @@ // SPDX-License-Identifier: BSD-3-Clause #include "error.hpp" -#include "location.hpp" +#include "node/inst_node.hpp" // IWYU pragma: keep #include "node/region_node.hpp" #include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include "util/ilist.hpp" +#include #include -#include -#include +#include +#include #include using namespace tinytc; extern "C" { -tinytc_status_t tinytc_region_create(tinytc_region_t *reg, uint32_t instruction_list_size, - tinytc_inst_t *instruction_list, - const tinytc_location_t *loc) { - if (reg == nullptr || (instruction_list_size > 0 && instruction_list == nullptr)) { +tinytc_status_t tinytc_region_append(tinytc_region_t reg, tinytc_inst_t instr) { + if (reg == nullptr || instr == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { - auto inst_vec = std::vector(); - inst_vec.reserve(instruction_list_size); - for (uint32_t i = 0; i < instruction_list_size; ++i) { - inst_vec.emplace_back(inst(instruction_list[i], true)); - } - *reg = std::make_unique(std::move(inst_vec), get_optional(loc)).release(); - }); + return exception_to_status_code([&] { reg->insts().push_back(instr); }); +} + +tinytc_status_t tinytc_region_begin(tinytc_region_t reg, tinytc_inst_iterator_t *iterator) { + if (reg == nullptr || iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *iterator = reg->insts().begin().get(); }); +} + +tinytc_status_t tinytc_region_end(tinytc_region_t reg, tinytc_inst_iterator_t *iterator) { + if (reg == nullptr || iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *iterator = reg->insts().end().get(); }); } -tinytc_status_t tinytc_region_release(tinytc_region_t obj) { - if (obj == nullptr) { +tinytc_status_t tinytc_region_erase(tinytc_region_t reg, tinytc_inst_iterator_t *iterator) { + if (reg == nullptr || iterator == nullptr) { return tinytc_status_invalid_arguments; } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; + return exception_to_status_code([&] { *iterator = reg->insts().erase(*iterator).get(); }); +} + +tinytc_status_t tinytc_region_insert(tinytc_region_t reg, tinytc_inst_iterator_t *iterator, + tinytc_inst_t instr) { + if (reg == nullptr || iterator == nullptr || instr == nullptr) { + return tinytc_status_invalid_arguments; } - return tinytc_status_success; + return exception_to_status_code( + [&] { *iterator = reg->insts().insert(*iterator, instr).get(); }); } -tinytc_status_t tinytc_region_retain(tinytc_region_t obj) { - if (obj == nullptr) { +tinytc_status_t tinytc_next_inst(tinytc_inst_iterator_t *iterator) { + if (iterator == nullptr) { return tinytc_status_invalid_arguments; } - obj->inc_ref(); - return tinytc_status_success; + return exception_to_status_code( + [&] { *iterator = static_cast((*iterator)->next()); }); +} + +tinytc_status_t tinytc_prev_inst(tinytc_inst_iterator_t *iterator) { + if (iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *iterator = static_cast((*iterator)->prev()); }); +} + +tinytc_status_t tinytc_region_get_parameter(tinytc_region_t reg, uint32_t param_no, + tinytc_value_t *result) { + if (reg == nullptr || result == nullptr || param_no >= reg->num_params()) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *result = ®->param(param_no); }); +} + +tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, uint32_t *result_list_size, + tinytc_value_t *result_list) { + + if (reg == nullptr || result_list_size == nullptr || + (*result_list_size > 0 && result_list == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const num_results = reg->num_params(); + if (num_results > std::numeric_limits::max()) { + throw std::out_of_range("too many results"); + } + auto num = static_cast(num_results); + if (*result_list_size > 0) { + auto results = reg->param_begin(); + num = std::min(num, *result_list_size); + for (uint32_t i = 0; i < num; ++i) { + result_list[i] = &results[i]; + } + } + *result_list_size = num; + }); } } diff --git a/src/required_extensions.cpp b/src/required_extensions.cpp deleted file mode 100644 index ea4d3edc..00000000 --- a/src/required_extensions.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "required_extensions.hpp" - -#include -#include - -#include - -namespace tinytc { - -auto ext_list(std::vector const &ext) -> std::vector { - auto result = std::vector{}; - result.reserve(ext.size() + 1); - for (auto const &e : ext) { - result.emplace_back(clir::to_string(e)); - } - result.emplace_back("cl_khr_fp64"); - return result; -} - -auto required_extensions(clir::func f) -> std::vector { - return ext_list(clir::get_required_extensions(std::move(f))); -} -auto required_extensions(clir::prog p) -> std::vector { - return ext_list(clir::get_required_extensions(std::move(p))); -} - -} // namespace tinytc diff --git a/src/required_extensions.hpp b/src/required_extensions.hpp deleted file mode 100644 index e4c5a23a..00000000 --- a/src/required_extensions.hpp +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef REQUIRED_EXTENSIONS_20240416_HPP -#define REQUIRED_EXTENSIONS_20240416_HPP - -#include - -#include -#include - -namespace tinytc { - -auto required_extensions(clir::func f) -> std::vector; -auto required_extensions(clir::prog p) -> std::vector; - -} // namespace tinytc - -#endif // REQUIRED_EXTENSIONS_20240416_HPP diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp index af4ef4d1..c5de0337 100644 --- a/src/scalar_type.cpp +++ b/src/scalar_type.cpp @@ -2,7 +2,9 @@ // SPDX-License-Identifier: BSD-3-Clause #include "scalar_type.hpp" +#include "error.hpp" #include "tinytc/tinytc.h" +#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -12,6 +14,8 @@ namespace tinytc { bool is_floating_type(scalar_type ty) { switch (ty) { + case scalar_type::bf16: + case scalar_type::f16: case scalar_type::f32: case scalar_type::f64: return true; @@ -21,60 +25,107 @@ bool is_floating_type(scalar_type ty) { return false; } -clir::builtin_type to_clir_builtin_ty(scalar_type ty) { +bool is_complex_type(scalar_type ty) { + switch (ty) { + case scalar_type::c32: + case scalar_type::c64: + return true; + default: + break; + } + return false; +} + +bool is_integer_type(scalar_type ty) { switch (ty) { - case scalar_type::i1: - return clir::builtin_type::bool_t; case scalar_type::i8: - return clir::builtin_type::char_t; case scalar_type::i16: - return clir::builtin_type::short_t; case scalar_type::i32: - return clir::builtin_type::int_t; case scalar_type::i64: - return clir::builtin_type::long_t; case scalar_type::index: - return clir::builtin_type::long_t; - case scalar_type::f32: - return clir::builtin_type::float_t; - case scalar_type::f64: - return clir::builtin_type::double_t; + return true; + default: + break; } - return clir::builtin_type::void_t; + return false; } -clir::data_type to_clir_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { - return clir::data_type(to_clir_builtin_ty(ty), as, q); +auto acc_type(scalar_type ty) -> scalar_type { + switch (ty) { + case scalar_type::i8: + return scalar_type::i32; + case scalar_type::bf16: + case scalar_type::f16: + return scalar_type::f32; + default: + return ty; + } } -clir::builtin_type to_clir_atomic_builtin_ty(scalar_type ty) { +auto component_count(scalar_type ty) -> vector_size { switch (ty) { - case scalar_type::i32: - return clir::builtin_type::atomic_int_t; - case scalar_type::i64: - return clir::builtin_type::atomic_long_t; - case scalar_type::index: - return clir::builtin_type::atomic_long_t; - case scalar_type::f32: - return clir::builtin_type::atomic_float_t; - case scalar_type::f64: - return clir::builtin_type::atomic_double_t; + case scalar_type::c32: + case scalar_type::c64: + return vector_size::v2; + default: + break; + } + return vector_size::v1; +} +auto component_type(scalar_type ty) -> scalar_type { + switch (ty) { + case scalar_type::c32: + return scalar_type::f32; + case scalar_type::c64: + return scalar_type::f64; default: break; } - return clir::builtin_type::void_t; + return ty; +} + +auto promotable(scalar_type a_ty, scalar_type b_ty) -> bool { + if (a_ty == b_ty) { + return true; + } + const auto a_cc = static_cast(component_count(a_ty)); + const auto b_cc = static_cast(component_count(b_ty)); + const auto a_ct = component_type(a_ty); + const auto b_ct = component_type(b_ty); + return (is_integer_type(a_ct) || !is_integer_type(b_ct)) && + (size(a_ct) < size(b_ct) || a_ct == b_ct) && a_cc <= b_cc; +} + +auto promote(scalar_type a_ty, scalar_type b_ty) -> std::optional { + if (promotable(a_ty, b_ty)) { + return b_ty; + } else if (promotable(b_ty, a_ty)) { + return a_ty; + } + return std::nullopt; +} + +auto promote_or_throw(scalar_type a_ty, scalar_type b_ty, location const &loc) -> scalar_type { + auto res = promote(a_ty, b_ty); + if (res) { + return *res; + } + throw compilation_error(loc, status::ir_forbidden_promotion); +} + +auto alignment(scalar_type ty, vector_size count) -> std::int32_t { + const std::int32_t scale = count == vector_size::v3 ? 4 : static_cast(count); + return scale * size(ty); } -clir::data_type to_clir_atomic_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { - return clir::data_type(to_clir_atomic_builtin_ty(ty), as, q); +auto is_cast_allowed(scalar_type from_ty, scalar_type to_ty) -> bool { + return !is_complex_type(from_ty) || is_complex_type(to_ty); } } // namespace tinytc char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { switch (ty) { - case tinytc_scalar_type_i1: - return "i1"; case tinytc_scalar_type_i8: return "i8"; case tinytc_scalar_type_i16: @@ -85,19 +136,28 @@ char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { return "i64"; case tinytc_scalar_type_index: return "index"; + case tinytc_scalar_type_bf16: + return "bf16"; + case tinytc_scalar_type_f16: + return "f16"; case tinytc_scalar_type_f32: return "f32"; case tinytc_scalar_type_f64: return "f64"; + case tinytc_scalar_type_c32: + return "c32"; + case tinytc_scalar_type_c64: + return "c64"; } return "unknown"; } size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty) { switch (ty) { - case tinytc_scalar_type_i1: case tinytc_scalar_type_i8: return 1; case tinytc_scalar_type_i16: + case tinytc_scalar_type_bf16: + case tinytc_scalar_type_f16: return 2; case tinytc_scalar_type_i32: case tinytc_scalar_type_f32: @@ -105,7 +165,10 @@ size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty) { case tinytc_scalar_type_i64: case tinytc_scalar_type_index: case tinytc_scalar_type_f64: + case tinytc_scalar_type_c32: return 8; + case tinytc_scalar_type_c64: + return 16; } return 0; } diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp index ff0428c1..5dd03ee0 100644 --- a/src/scalar_type.hpp +++ b/src/scalar_type.hpp @@ -6,19 +6,26 @@ #include "tinytc/types.hpp" -#include -#include +#include +#include namespace tinytc { +using host_index_type = std::int64_t; + +enum class vector_size { v1 = 1, v2 = 2, v3 = 3, v4 = 4, v8 = 8, v16 = 16 }; + bool is_floating_type(scalar_type ty); -clir::builtin_type to_clir_builtin_ty(scalar_type ty); -clir::builtin_type to_clir_atomic_builtin_ty(scalar_type ty); -clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, - clir::type_qualifier q = clir::type_qualifier::none); -clir::data_type to_clir_atomic_ty(scalar_type ty, - clir::address_space as = clir::address_space::generic_t, - clir::type_qualifier q = clir::type_qualifier::none); +bool is_complex_type(scalar_type ty); +bool is_integer_type(scalar_type ty); +auto acc_type(scalar_type ty) -> scalar_type; +auto component_count(scalar_type ty) -> vector_size; +auto component_type(scalar_type ty) -> scalar_type; +auto promotable(scalar_type a_ty, scalar_type b_ty) -> bool; +auto promote(scalar_type a_ty, scalar_type b_ty) -> std::optional; +auto promote_or_throw(scalar_type a_ty, scalar_type b_ty, location const &loc) -> scalar_type; +auto is_cast_allowed(scalar_type from_ty, scalar_type to_ty) -> bool; +auto alignment(scalar_type ty, vector_size count = vector_size::v1) -> std::int32_t; } // namespace tinytc diff --git a/src/slice.hpp b/src/slice.hpp deleted file mode 100644 index d545242e..00000000 --- a/src/slice.hpp +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef SLICE_20240412_HPP -#define SLICE_20240412_HPP - -#include "tinytc/tinytc.hpp" - -#include - -namespace tinytc { - -//! Slice storing offset:size -class slice : public std::pair { - public: - //! ctor - inline slice(value offset = {}, value size = {}) - : std::pair{std::move(offset), std::move(size)} {} -}; - -} // namespace tinytc - -#endif // SLICE_20240412_HPP diff --git a/src/source.cpp b/src/source.cpp deleted file mode 100644 index 18038a0d..00000000 --- a/src/source.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "source.hpp" -#include "error.hpp" -#include "tinytc/tinytc.h" - -#include - -using namespace tinytc; - -extern "C" { -tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, - char const **code) { - if (src == nullptr || length == nullptr || code == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *length = src->size(); - *code = src->code(); - }); -} - -tinytc_status_t tinytc_source_get_location(const_tinytc_source_t src, tinytc_location_t *loc) { - if (src == nullptr || loc == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *loc = src->code_loc(); }); -} - -tinytc_status_t tinytc_source_get_core_features(const_tinytc_source_t src, - tinytc_core_feature_flags_t *core_features) { - if (src == nullptr || core_features == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *core_features = src->core_features(); }); -} - -tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t src, uint32_t *extensions_size, - char const *const **extensions) { - if (src == nullptr || extensions_size == nullptr || extensions == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *extensions_size = src->required_extensions().size(); - *extensions = src->required_extensions().data(); - }); -} - -tinytc_status_t tinytc_source_release(tinytc_source_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_retain(tinytc_source_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/source.hpp b/src/source.hpp deleted file mode 100644 index d07017e7..00000000 --- a/src/source.hpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef SOURCE_20240412_HPP -#define SOURCE_20240412_HPP - -#include "reference_counted.hpp" -#include "tinytc/types.h" - -#include -#include -#include -#include - -struct tinytc_source : tinytc::reference_counted { - public: - inline tinytc_source(std::string code, tinytc_location const &code_loc, - std::vector required_extensions, - tinytc_core_feature_flags_t core_features) - : code_(std::move(code)), code_loc_(code_loc), - required_extensions_(std::move(required_extensions)), core_features_(core_features) {} - - inline auto code() const -> char const * { return code_.c_str(); } - inline auto code_loc() const -> tinytc_location const & { return code_loc_; } - inline auto size() const -> std::size_t { return code_.size(); } - inline auto const &required_extensions() const { return required_extensions_; } - inline auto core_features() const noexcept -> tinytc_core_feature_flags_t { - return core_features_; - } - - private: - std::string code_; - tinytc_location code_loc_; - std::vector required_extensions_; - tinytc_core_feature_flags_t core_features_; -}; - -#endif // SOURCE_20240412_HPP diff --git a/src/spv/block2d_diy.cpp b/src/spv/block2d_diy.cpp new file mode 100644 index 00000000..17ad6598 --- /dev/null +++ b/src/spv/block2d_diy.cpp @@ -0,0 +1,316 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/block2d_diy.hpp" +#include "spv/xe_constants.hpp" +#include "support/temp_counter.hpp" +#include "tinytc/types.hpp" +#include "util/math.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto block_config::block_size_in_bytes() const -> std::int32_t { + return element_size * array_length * rows * cols; +} +auto block_config::block_size_in_num_grf() const -> std::int32_t { + return block_size_in_bytes() / xe::grf_size; +} +auto block_config::byte_offset(std::int32_t row, std::int32_t col, std::int32_t array_idx, + std::int32_t col_block, std::int32_t row_block) const + -> std::int32_t { + std::int32_t offset = 0; + if (transpose) { + offset = col_block; + offset = row_block + offset * row_blocks; + } else { + offset = row_block; + offset = col_block + offset * col_blocks; + } + offset = array_idx + offset * array_length; + offset = col + offset * cols; + offset = row + offset * rows; + return offset * element_size; +} +auto block_config::origin(std::int32_t row, std::int32_t col, std::int32_t array_idx, + std::int32_t col_block, std::int32_t row_block) const + -> std::array { + const auto offset = byte_offset(row, col, array_idx, col_block, row_block); + return region_origin(element_size, offset); +} +auto block_config::total_rows() const -> std::int32_t { return array_length * rows * row_blocks; } + +auto lsc_data_size(std::int32_t element_size) -> std::int32_t { + if (!is_positive_power_of_two(element_size) || element_size > 8) { + throw status::internal_compiler_error; + } + return ilog2(element_size); +} +auto region_origin(std::int32_t element_size, std::int32_t byte_offset) + -> std::array { + return {byte_offset / xe::grf_size, byte_offset % xe::grf_size / element_size}; +} +auto visa_type(scalar_type sty) -> char const * { + switch (sty) { + case scalar_type::i8: + return "b"; + case scalar_type::i16: + return "w"; + case scalar_type::i32: + return "d"; + case scalar_type::i64: + case scalar_type::index: + return "q"; + case scalar_type::f16: + return "hf"; + case scalar_type::bf16: + return "bf"; + case scalar_type::f32: + return "f"; + case scalar_type::f64: + return "df"; + default: + throw status::internal_compiler_error; + } +} + +/** + * This routine generates transpose code for 8x8 matrices of d32 type. Multiple 8x8 matrices may be + * packed side-by-side. E.g. for num_8x8_blocks=2 the GRF layout, say starting at reg r93, + * is expected to be + * + * r93: a_11 ... a_18 b_11 ... b_18 + * ...: ... ... ... ... + * r100: a_81 ... a_88 b_81 ... b_88 + * + * The transposition is done in-place and we should get + * + * r93: a_11 ... a_81 b_11 ... b_81 + * ...: ... ... ... ... + * r100: a_18 ... a_88 b_18 ... b_88 + * + * This routine can also be called for 16x8 half types or 32x8 i8 types. + * Then, the routine generates transpose + VNNI transform. + * + */ +auto make_d32_transpose8x8(std::ostream &oasm, char const *matrix, std::size_t offset, + temp_counter &make_tmp, std::int32_t num_8x8_blocks = 1) { + constexpr std::int32_t element_size = 4; + const std::int32_t stride = 8 * element_size * num_8x8_blocks; + const std::int32_t num_elements = 8 * stride / element_size; + + auto dst_d = make_tmp("dst_d"); + auto dst_q = make_tmp("dst_q"); + oasm << ".decl " << dst_d << " v_type=G type=d num_elts=" << num_elements + << " align=wordx32 alias=<" << matrix << "," << offset << ">\n"; + oasm << ".decl " << dst_q << " v_type=G type=q num_elts=" << num_elements / 2 + << " align=wordx32 alias=<" << matrix << "," << offset << ">\n"; + + // 2x2 transpose + const std::int32_t exec_size = 4 * num_8x8_blocks; + for (std::int32_t r = 0; r < 4; ++r) { + auto ttmp = make_tmp("ttmp_d"); + oasm << ".decl " << ttmp << " v_type=G type=d num_elts=" << exec_size << " align=wordx32\n"; + auto [R1, C1] = region_origin(element_size, 2 * r * stride + element_size); + auto [R2, C2] = region_origin(element_size, 2 * r * stride + stride); + oasm << "mov (M1," << exec_size << ") " << ttmp << "(0,0)<1> " << dst_d << "(" << R1 << "," + << C1 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_d << "(" << R1 << "," << C1 << ")<2> " + << dst_d << "(" << R2 << "," << C2 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_d << "(" << R2 << "," << C2 << ")<2> " + << ttmp << "(0,0)<1;1,0>\n"; + } + // 4x4 transpose + for (std::int32_t r = 0; r < 4; r += 2) { + auto ttmp = make_tmp("ttmp_q"); + oasm << ".decl " << ttmp << " v_type=G type=q num_elts=" << exec_size << " align=wordx32\n"; + auto [R1, C1] = region_origin(2 * element_size, 2 * r * stride + 2 * element_size); + auto [R2, C2] = region_origin(2 * element_size, 2 * (r + 1) * stride); + oasm << "mov (M1," << exec_size << ") " << ttmp << "(0,0)<1> " << dst_q << "(" << R1 << "," + << C1 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_q << "(" << R1 << "," << C1 << ")<2> " + << dst_q << "(" << R2 << "," << C2 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_q << "(" << R2 << "," << C2 << ")<2> " + << ttmp << "(0,0)<1;1,0>\n"; + } + // 8x8 transpose + for (std::int32_t r = 0; r < 4; ++r) { + auto ttmp = make_tmp("ttmp_d"); + auto [R1, C1] = region_origin(element_size, r * stride + 4 * element_size); + auto [R2, C2] = region_origin(element_size, (r + 4) * stride); + oasm << ".decl " << ttmp << " v_type=G type=d num_elts=" << exec_size << " align=wordx32\n"; + oasm << "mov (M1," << exec_size << ") " << ttmp << "(0,0)<1> " << dst_d << "(" << R1 << "," + << C1 << ")<8;4,1>\n"; + for (std::int32_t b = 0; b < 8 * num_8x8_blocks; b += 8) { + oasm << "mov (M1,4) " << dst_d << "(" << R1 << "," << C1 + b << ")<1> " << dst_d << "(" + << R2 << "," << C2 + b << ")<1;1,0>\n"; + oasm << "mov (M1,4) " << dst_d << "(" << R2 << "," << C2 + b << ")<1> " << ttmp << "(0," + << b / 2 << ")<1;1,0>\n"; + } + } +} + +struct block2d_native_helper { + block2d_native_helper(std::ostream &oasm, block_config const &cfg, temp_counter &make_tmp, + std::int32_t first_address_operand); + void header(); + template void walk(F &&io); + + inline auto temp(std::int32_t m, std::int32_t n) -> std::string const & { + return temps[n + m * cfg.col_blocks]; + } + + std::ostream &oasm; + block_config const &cfg; + std::vector temps; + std::string tempq; + std::int32_t first_; +}; + +block2d_native_helper::block2d_native_helper(std::ostream &oasm, block_config const &cfg, + temp_counter &make_tmp, + std::int32_t first_address_operand) + : oasm(oasm), cfg(cfg), temps(cfg.row_blocks * cfg.col_blocks), tempq{make_tmp("tempq")}, + first_{first_address_operand} { + std::generate(temps.begin(), temps.end(), [&]() { return make_tmp("temp"); }); +} + +void block2d_native_helper::header() { + const std::uint32_t block_size = + ((cfg.array_length - 1) << 16) | ((cfg.cols - 1) << 8) | (cfg.rows - 1); + const auto &tmp0 = temp(0, 0); + oasm << ".decl " << tmp0 << " v_type=G type=ud num_elts=8 align=wordx32\n" + << ".decl " << tempq << " v_type=G type=uq num_elts=4 align=wordx32 alias=<" << tmp0 + << ",0>\n" + << "mov (M1,1) " << tempq << "(0,0)<1> $" << first_ << "(0,0)<0;1,0>\n" + << "add (M1,1) " << tmp0 << "(0,2)<1> $" << first_ + 1 << "(0,0)<0;1,0> -1:d\n" + << "add (M1,1) " << tmp0 << "(0,3)<1> $" << first_ + 2 << "(0,0)<0;1,0> -1:d\n" + << "add (M1,1) " << tmp0 << "(0,4)<1> $" << first_ + 3 << "(0,0)<0;1,0> -1:d\n"; + if (cfg.pos0_shr) { + oasm << "shr (M1,1) " << tmp0 << "(0,5)<1> $" << first_ + 4 << "(0,0)<0;1,0> " + << cfg.pos0_shr << ":d\n"; + } else { + oasm << "mov (M1,1) " << tmp0 << "(0,5)<1> $" << first_ + 4 << "(0,0)<0;1,0>\n"; + } + oasm << "mov (M1,1) " << tmp0 << "(0,6)<1> $" << first_ + 5 << "(0,0)<0;1,0>\n" + << "mov (M1,1) " << tmp0 << "(0,7)<1> 0x" << std::hex << block_size << ":ud\n"; + for (std::int32_t m = 0; m < cfg.row_blocks; ++m) { + for (std::int32_t n = 0; n < cfg.col_blocks; ++n) { + const auto &tmp = temp(m, n); + if (m > 0 || n > 0) { + oasm << ".decl " << tmp << " v_type=G type=ud num_elts=8 align=wordx32\n"; + oasm << "mov (M1,8) " << tmp << "(0,0)<1> " << tmp0 << "(0,0)<1;1,0>\n"; + } + if (m > 0) { + oasm << "add (M1,1) " << tmp << "(0,5)<1> " << tmp << "(0,5)<0;1,0> 0x" + << m * cfg.rows * cfg.array_length << ":ud\n"; + } + if (n > 0) { + oasm << "add (M1,1) " << tmp << "(0,6)<1> " << tmp << "(0,6)<0;1,0> 0x" + << n * cfg.cols << ":ud\n"; + } + } + } +} + +template void block2d_native_helper::walk(F &&io) { + for (std::int32_t m = 0; m < cfg.row_blocks; ++m) { + for (std::int32_t n = 0; n < cfg.col_blocks; ++n) { + io(m, n); + } + } +} + +auto load_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string { + const std::uint32_t num_dst = std::min(31, cfg.block_size_in_num_grf()); + const std::uint32_t desc = [&] { + const std::uint32_t data_size = lsc_data_size(cfg.element_size); + std::uint32_t d = 3; + if (cfg.vnni) { + d |= 1 << 7; + } + if (cfg.transpose && !cfg.vnni) { + d |= 1 << 15; + } + d |= data_size << 9; + d |= num_dst << 20; + d |= 1 << 25; + return d; + }(); + + auto oasm = std::ostringstream{}; + auto h = block2d_native_helper(oasm, cfg, make_tmp, 1); + + oasm << "{\n"; + h.header(); + h.walk([&](std::int32_t m, std::int32_t n) { + const auto offset = cfg.byte_offset(0, 0, 0, n, m); + oasm << std::dec << "raw_sends.15.1.0." << num_dst << " (M1, 1) 0x0:ud 0x" << std::hex + << desc << ":ud " << h.temp(m, n) << ".0 %null.0 $0." << std::dec << offset << "\n"; + + if (cfg.vnni && cfg.transpose) { + for (std::int32_t array_idx = 0; array_idx < cfg.array_length; ++array_idx) { + make_d32_transpose8x8(oasm, "$0", cfg.byte_offset(0, 0, array_idx, n, m), make_tmp); + } + } + }); + oasm << "}\n"; + + return std::move(oasm).str(); +} +auto prefetch_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string { + const std::uint32_t desc = [&] { + const std::uint32_t data_size = lsc_data_size(cfg.element_size); + const std::uint32_t cache_control = cfg.cache_level == 1 ? 2 : 4; + std::uint32_t d = 3; + d |= data_size << 9; + d |= cache_control << 17; + d |= 1 << 25; + return d; + }(); + + auto oasm = std::ostringstream{}; + auto h = block2d_native_helper(oasm, cfg, make_tmp, 0); + + oasm << "{\n"; + h.header(); + h.walk([&](std::int32_t m, std::int32_t n) { + oasm << std::dec << "raw_sends.15.1.0.0 (M1, 1) 0x0:ud 0x" << std::hex << desc << ":ud " + << h.temp(m, n) << ".0 %null.0 %null.0\n"; + }); + oasm << "}\n"; + + return std::move(oasm).str(); +} +auto store_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string { + const std::uint32_t num_src1 = std::min(31, cfg.block_size_in_num_grf()); + const std::uint32_t desc = [&] { + const std::uint32_t data_size = lsc_data_size(cfg.element_size); + std::uint32_t d = 7; + d |= data_size << 9; + d |= 1 << 25; + return d; + }(); + + auto oasm = std::ostringstream{}; + auto h = block2d_native_helper(oasm, cfg, make_tmp, 1); + + oasm << "{\n"; + h.header(); + h.walk([&](std::int32_t m, std::int32_t n) { + const auto offset = cfg.byte_offset(0, 0, 0, n, m); + oasm << "raw_sends.15.1." << num_src1 << ".0 (M1, 1) 0x0:ud 0x" << std::hex << desc + << ":ud " << h.temp(m, n) << ".0 $0." << std::dec << offset << " %null.0\n"; + }); + oasm << "}\n"; + + return std::move(oasm).str(); +} + +} // namespace tinytc::spv diff --git a/src/spv/block2d_diy.hpp b/src/spv/block2d_diy.hpp new file mode 100644 index 00000000..86ac4c2d --- /dev/null +++ b/src/spv/block2d_diy.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef BLOCK2D_DIY_20250219_HPP +#define BLOCK2D_DIY_20250219_HPP + +#include +#include +#include + +namespace tinytc { +class temp_counter; +enum class scalar_type; +} // namespace tinytc + +namespace tinytc::spv { + +struct block_config { + scalar_type sty; + std::int32_t element_size; + std::int32_t array_length; + std::int32_t rows; + std::int32_t cols; + std::int32_t row_blocks; + std::int32_t col_blocks; + bool transpose; + bool vnni; + std::int32_t pos0_shr; // Number of bits to shift pos0 to the right (= divide by 2^pos0_shr) + std::int32_t cache_level; + + auto block_size_in_bytes() const -> std::int32_t; + auto block_size_in_num_grf() const -> std::int32_t; + auto byte_offset(std::int32_t row, std::int32_t col, std::int32_t array_idx, + std::int32_t col_block, std::int32_t row_block) const -> std::int32_t; + auto origin(std::int32_t row, std::int32_t col, std::int32_t array_idx, std::int32_t col_block, + std::int32_t row_block) const -> std::array; + auto total_rows() const -> std::int32_t; +}; + +auto lsc_data_size(std::int32_t element_size) -> std::int32_t; +auto region_origin(std::int32_t element_size, std::int32_t byte_offset) + -> std::array; +auto visa_type(scalar_type sty) -> char const *; + +auto load_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string; +auto prefetch_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string; +auto store_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string; + +} // namespace tinytc::spv + +#endif // BLOCK2D_DIY_20250219_HPP diff --git a/src/spv/capex_util.cpp b/src/spv/capex_util.cpp new file mode 100644 index 00000000..0cd1fb98 --- /dev/null +++ b/src/spv/capex_util.cpp @@ -0,0 +1,686 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "capex_util.hpp" +#include "enums.hpp" + +namespace tinytc::spv { + +auto capabilities(ExecutionModel e) -> array_view { + switch (e) { + case ExecutionModel::Vertex: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::TessellationControl: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionModel::TessellationEvaluation: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionModel::Geometry: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionModel::Fragment: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::GLCompute: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::Kernel: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionModel::TaskNV: { + constexpr static Capability values[] = {Capability::MeshShadingNV}; + return {values, 1}; + } + case ExecutionModel::MeshNV: { + constexpr static Capability values[] = {Capability::MeshShadingNV}; + return {values, 1}; + } + case ExecutionModel::RayGenerationKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::IntersectionKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::AnyHitKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::ClosestHitKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::MissKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::CallableKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::TaskEXT: { + constexpr static Capability values[] = {Capability::MeshShadingEXT}; + return {values, 1}; + } + case ExecutionModel::MeshEXT: { + constexpr static Capability values[] = {Capability::MeshShadingEXT}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(AddressingModel e) -> array_view { + switch (e) { + case AddressingModel::Physical32: { + constexpr static Capability values[] = {Capability::Addresses}; + return {values, 1}; + } + case AddressingModel::Physical64: { + constexpr static Capability values[] = {Capability::Addresses}; + return {values, 1}; + } + case AddressingModel::PhysicalStorageBuffer64: { + constexpr static Capability values[] = {Capability::PhysicalStorageBufferAddresses}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(MemoryModel e) -> array_view { + switch (e) { + case MemoryModel::Simple: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case MemoryModel::GLSL450: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case MemoryModel::OpenCL: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case MemoryModel::Vulkan: { + constexpr static Capability values[] = {Capability::VulkanMemoryModel}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(ExecutionMode e) -> array_view { + switch (e) { + case ExecutionMode::Invocations: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::SpacingEqual: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::SpacingFractionalEven: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::SpacingFractionalOdd: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::VertexOrderCw: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::VertexOrderCcw: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::PixelCenterInteger: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::OriginUpperLeft: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::OriginLowerLeft: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::EarlyFragmentTests: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::PointMode: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::Xfb: { + constexpr static Capability values[] = {Capability::TransformFeedback}; + return {values, 1}; + } + case ExecutionMode::DepthReplacing: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthGreater: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthLess: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthUnchanged: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::LocalSizeHint: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::InputPoints: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::InputLines: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::InputLinesAdjacency: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::Triangles: { + constexpr static Capability values[] = {Capability::Geometry, Capability::Tessellation}; + return {values, 2}; + } + case ExecutionMode::InputTrianglesAdjacency: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::Quads: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::Isolines: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::OutputVertices: { + constexpr static Capability values[] = {Capability::Geometry, Capability::Tessellation, + Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 4}; + } + case ExecutionMode::OutputPoints: { + constexpr static Capability values[] = {Capability::Geometry, Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 3}; + } + case ExecutionMode::OutputLineStrip: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::OutputTriangleStrip: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::VecTypeHint: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::ContractionOff: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::Initializer: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::Finalizer: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::SubgroupSize: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::SubgroupsPerWorkgroup: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::SubgroupsPerWorkgroupId: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::LocalSizeHintId: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::NonCoherentColorAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageColorReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::NonCoherentDepthAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageDepthReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::NonCoherentStencilAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageStencilReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::SubgroupUniformControlFlowKHR: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::PostDepthCoverage: { + constexpr static Capability values[] = {Capability::SampleMaskPostDepthCoverage}; + return {values, 1}; + } + case ExecutionMode::DenormPreserve: { + constexpr static Capability values[] = {Capability::DenormPreserve}; + return {values, 1}; + } + case ExecutionMode::DenormFlushToZero: { + constexpr static Capability values[] = {Capability::DenormFlushToZero}; + return {values, 1}; + } + case ExecutionMode::SignedZeroInfNanPreserve: { + constexpr static Capability values[] = {Capability::SignedZeroInfNanPreserve}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTE: { + constexpr static Capability values[] = {Capability::RoundingModeRTE}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTZ: { + constexpr static Capability values[] = {Capability::RoundingModeRTZ}; + return {values, 1}; + } + case ExecutionMode::NonCoherentTileAttachmentReadQCOM: { + constexpr static Capability values[] = {Capability::TileShadingQCOM}; + return {values, 1}; + } + case ExecutionMode::TileShadingRateQCOM: { + constexpr static Capability values[] = {Capability::TileShadingQCOM}; + return {values, 1}; + } + case ExecutionMode::EarlyAndLateFragmentTestsAMD: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::StencilRefReplacingEXT: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::CoalescingAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::IsApiEntryAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::MaxNodeRecursionAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::StaticNumWorkgroupsAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::ShaderIndexAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::MaxNumWorkgroupsAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefGreaterFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefLessFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefGreaterBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefLessBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::QuadDerivativesKHR: { + constexpr static Capability values[] = {Capability::QuadControlKHR}; + return {values, 1}; + } + case ExecutionMode::RequireFullQuadsKHR: { + constexpr static Capability values[] = {Capability::QuadControlKHR}; + return {values, 1}; + } + case ExecutionMode::SharesInputWithAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::OutputLinesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::OutputPrimitivesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupQuadsKHR: { + constexpr static Capability values[] = {Capability::ComputeDerivativeGroupQuadsKHR}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupLinearKHR: { + constexpr static Capability values[] = {Capability::ComputeDerivativeGroupLinearKHR}; + return {values, 2}; + } + case ExecutionMode::OutputTrianglesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::PixelInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderPixelInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::PixelInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderPixelInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderSampleInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderSampleInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderShadingRateInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderShadingRateInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SharedLocalMemorySizeINTEL: { + constexpr static Capability values[] = {Capability::VectorComputeINTEL}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTPINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTNINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::FloatingPointModeALTINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::FloatingPointModeIEEEINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::MaxWorkgroupSizeINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::MaxWorkDimINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::NoGlobalOffsetINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::NumSIMDWorkitemsINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::SchedulerTargetFmaxMhzINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximallyReconvergesKHR: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::FPFastMathDefault: { + constexpr static Capability values[] = {Capability::FloatControls2}; + return {values, 1}; + } + case ExecutionMode::StreamingInterfaceINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::RegisterMapInterfaceINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesv2INTEL}; + return {values, 1}; + } + case ExecutionMode::NamedBarrierCountINTEL: { + constexpr static Capability values[] = {Capability::VectorComputeINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximumRegistersINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximumRegistersIdINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + case ExecutionMode::NamedMaximumRegistersINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + default: + return {}; + } +} +auto extensions(ExecutionModel e) -> array_view { + switch (e) { + default: + return {}; + } +} +auto extensions(AddressingModel e) -> array_view { + switch (e) { + case AddressingModel::PhysicalStorageBuffer64: { + constexpr static char const *values[] = {"SPV_EXT_physical_storage_buffer", + "SPV_KHR_physical_storage_buffer"}; + return {values, 2}; + } + default: + return {}; + } +} +auto extensions(MemoryModel e) -> array_view { + switch (e) { + case MemoryModel::Vulkan: { + constexpr static char const *values[] = {"SPV_KHR_vulkan_memory_model"}; + return {values, 1}; + } + default: + return {}; + } +} +auto extensions(ExecutionMode e) -> array_view { + switch (e) { + case ExecutionMode::SubgroupUniformControlFlowKHR: { + constexpr static char const *values[] = {"SPV_KHR_subgroup_uniform_control_flow"}; + return {values, 1}; + } + case ExecutionMode::PostDepthCoverage: { + constexpr static char const *values[] = {"SPV_KHR_post_depth_coverage"}; + return {values, 1}; + } + case ExecutionMode::DenormPreserve: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::DenormFlushToZero: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::SignedZeroInfNanPreserve: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTE: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTZ: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::EarlyAndLateFragmentTestsAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests"}; + return {values, 1}; + } + case ExecutionMode::StencilRefReplacingEXT: { + constexpr static char const *values[] = {"SPV_EXT_shader_stencil_export"}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefGreaterFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefLessFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefUnchangedBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefGreaterBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefLessBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::OutputLinesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::OutputPrimitivesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupQuadsKHR: { + constexpr static char const *values[] = {"SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives"}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupLinearKHR: { + constexpr static char const *values[] = {"SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives"}; + return {values, 2}; + } + case ExecutionMode::OutputTrianglesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::PixelInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::PixelInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::MaxWorkgroupSizeINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::MaxWorkDimINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::NoGlobalOffsetINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::NumSIMDWorkitemsINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::MaximallyReconvergesKHR: { + constexpr static char const *values[] = {"SPV_KHR_maximal_reconvergence"}; + return {values, 1}; + } + default: + return {}; + } +} + +} // namespace tinytc::spv diff --git a/src/spv/capex_util.hpp b/src/spv/capex_util.hpp new file mode 100644 index 00000000..a827fad3 --- /dev/null +++ b/src/spv/capex_util.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_CAPEX_UTIL_20250605_HPP +#define GENERATED_CAPEX_UTIL_20250605_HPP + +#include "tinytc/tinytc.hpp" + +namespace tinytc::spv { + +enum class Capability; +enum class AddressingModel; +enum class ExecutionMode; +enum class ExecutionModel; +enum class MemoryModel; +auto capabilities(ExecutionModel op) -> array_view; +auto capabilities(AddressingModel op) -> array_view; +auto capabilities(MemoryModel op) -> array_view; +auto capabilities(ExecutionMode op) -> array_view; +auto extensions(ExecutionModel op) -> array_view; +auto extensions(AddressingModel op) -> array_view; +auto extensions(MemoryModel op) -> array_view; +auto extensions(ExecutionMode op) -> array_view; + +} // namespace tinytc::spv + +#endif // GENERATED_CAPEX_UTIL_20250605_HPP diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp new file mode 100644 index 00000000..44596eb1 --- /dev/null +++ b/src/spv/converter.cpp @@ -0,0 +1,1164 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/converter.hpp" +#include "analysis/gcd.hpp" +#include "analysis/stack.hpp" +#include "codegen_tools.hpp" +#include "converter_aux.hpp" +#include "error.hpp" +#include "matrix_ext_info.hpp" +#include "node/attr_node.hpp" +#include "node/data_type_node.hpp" +#include "node/function_node.hpp" +#include "node/inst_node.hpp" +#include "node/program_node.hpp" +#include "node/region_node.hpp" +#include "node/value_node.hpp" +#include "scalar_type.hpp" +#include "spv/coopmatrix_impl_block.hpp" +#include "spv/coopmatrix_impl_dpas.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/opencl.std.hpp" +#include "spv/pass/capex.hpp" +#include "spv/uniquifier.hpp" +#include "spv/visit.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto convert_prog_to_spirv(tinytc_prog const &p, tinytc_core_info const &info) + -> ::tinytc::spv_mod { + auto m = ::tinytc::spv_mod{ + std::make_unique(p.share_context(), info.core_features()).release()}; + + auto conv = inst_converter{*m, info}; + + m->add_to(section::memory_model, AddressingModel::Physical64, + MemoryModel::OpenCL); + + for (auto const &fn : p) { + conv.run_on_function(fn); + } + + // Add missing capabilites and extensions + auto cx = capex{conv.unique()}; + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : m->insts(enum_cast

(s))) { + visit(cx, i); + } + } + + for (int i = 0; i < TINYTC_NUMBER_OF_SPIRV_FEATURES; ++i) { + const auto feature = enum_cast(i); + if (cx.requires_feature(feature) && !info.have_spirv_feature(feature)) { + throw compilation_error(p.loc(), status::spirv_required_feature_unavailable, + to_string(feature)); + } + } + + return m; +} + +inst_converter::inst_converter(tinytc_spv_mod &m, tinytc_core_info const &info) + : mod_(&m), info_(&info), unique_(m) {} + +auto inst_converter::get_dope_vector(tinytc_value const &v) -> dope_vector * { + if (auto it = dope_vec_.find(&v); it != dope_vec_.end()) { + return &it->second; + } + return nullptr; +} + +auto inst_converter::declare(value_node const &v, spv_inst *in) { vals_[&v] = in; } +auto inst_converter::val(value_node const &v) -> spv_inst * { + if (auto it = vals_.find(&v); it != vals_.end()) { + return it->second; + } + throw compilation_error(v.loc(), status::spirv_undefined_value); +} + +auto inst_converter::spv_ty(const_tinytc_data_type_t ty) -> spv_inst * { + return tinytc::visit( + overloaded{ + [&](void_data_type const &) -> spv_inst * { return unique_.void_ty(); }, + [&](boolean_data_type const &) -> spv_inst * { return unique_.bool_ty(); }, + [&](group_data_type const &g) -> spv_inst * { + return unique_.pointer_ty(StorageClass::CrossWorkgroup, spv_ty(g.ty()), + alignment(scalar_type::i64)); + }, + [&](memref_data_type const &mr) -> spv_inst * { return unique_.pointer_ty(&mr); }, + [&](scalar_data_type const &ty) -> spv_inst * { return unique_.scalar_ty(ty.ty()); }, + [&](coopmatrix_data_type const &ty) -> spv_inst * { return matrix_impl().spv_ty(&ty); }, + [](auto const &) -> spv_inst * { + // @todo + throw status::not_implemented; + }}, + *ty); +} + +auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { + if (dope_vec_.contains(&v)) { + throw compilation_error(v.loc(), status::internal_compiler_error); + } + + auto spv_index_ty = unique_.scalar_ty(scalar_type::index); + return ::tinytc::visit( + overloaded{[&](memref_data_type const &mr) -> dope_vector * { + return &(dope_vec_[&v] = dope_vector{spv_index_ty, mr.shape(), mr.stride()}); + }, + [&](group_data_type const &g) -> dope_vector * { + if (auto mt = dyn_cast(g.ty()); mt) { + auto pointer_ty = + unique_.pointer_ty(StorageClass::CrossWorkgroup, spv_index_ty, + alignment(scalar_type::i64)); + return &(dope_vec_[&v] = dope_vector{ + pointer_ty, mt->shape(), mt->stride(), spv_index_ty, + g.size(), spv_index_ty, g.offset()}); + } else { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + }, + [](auto const &) -> dope_vector * { return nullptr; }}, + *v.ty()); +} + +auto inst_converter::matrix_impl() -> coopmatrix_impl & { + if (matrix_impl_) { + return *matrix_impl_; + } + throw status::internal_compiler_error; +} + +void inst_converter::operator()(inst_node const &in) { + // @todo + throw compilation_error(in.loc(), status::not_implemented); +} + +void inst_converter::operator()(alloca_inst const &in) { + if (in.stack_ptr() < 0) { + throw compilation_error(in.loc(), status::internal_compiler_error, + "Invalid stack_ptr in alloca. Did you run set_stack_ptrs?"); + } + if (!stack_) { + throw compilation_error(in.loc(), status::internal_compiler_error, + "Stack required but not allocated"); + } + + auto mt = get_memref_type(in.result(0)); + if (in.stack_ptr() % mt->element_alignment() != 0) { + throw compilation_error(in.loc(), status::ir_insufficient_alignment); + } + + auto stack_element_ty = unique_.scalar_ty(scalar_type::i8); + auto stack_ptr_ty = + unique_.pointer_ty(StorageClass::Workgroup, stack_element_ty, alignment(scalar_type::i8)); + auto stack_ptr = mod_->add( + stack_ptr_ty, stack_, std::vector{unique_.constant(in.stack_ptr())}); + + auto memref_ptr_ty = unique_.pointer_ty(mt); + declare(in.result(0), mod_->add(memref_ptr_ty, stack_ptr)); + + // alloca only accepts fixed-size memrefs => dope vector is constant + auto rdv = make_dope_vector(in.result(0)); + for (std::int64_t i = 0; i < mt->dim(); ++i) { + rdv->shape(i, unique_.constant(mt->shape(i))); + } + for (std::int64_t i = 0; i < mt->dim(); ++i) { + rdv->stride(i, unique_.constant(mt->stride(i))); + } +} + +void inst_converter::operator()(arith_inst const &in) { + auto const make_boolean = [&](arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::and_: + return mod_->add(ty, a, b); + case arithmetic::or_: + return mod_->add(ty, a, b); + case arithmetic::xor_: + return mod_->add(ty, a, b); + default: + break; + } + throw compilation_error(in.loc(), status::ir_boolean_unsupported); + }; + + if (isa(*in.result(0).ty())) { + auto ty = unique_.bool_ty(); + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(0), make_boolean(in.operation(), ty, av, bv)); + } else if (auto st = dyn_cast(in.result(0).ty()); st) { + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(0), make_binary_op(unique_, st->ty(), in.operation(), av, bv, in.loc())); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(0), matrix_impl().arith(in, av, bv)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(arith_unary_inst const &in) { + auto const make_boolean = [&](arithmetic_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case arithmetic_unary::not_: + return mod_->add(ty, a); + default: + break; + } + throw compilation_error(in.loc(), status::ir_boolean_unsupported); + }; + if (isa(*in.a().ty())) { + auto ty = unique_.bool_ty(); + auto av = val(in.a()); + declare(in.result(0), make_boolean(in.operation(), ty, av)); + } else if (auto st = dyn_cast(in.a().ty()); st) { + auto av = val(in.a()); + declare(in.result(0), make_unary_op(unique_, st->ty(), in.operation(), av, in.loc())); + } else if (auto ct = dyn_cast(in.a().ty()); ct) { + auto av = val(in.a()); + declare(in.result(0), matrix_impl().arith_unary(in, av)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(barrier_inst const &in) { + std::int32_t fence = 0; + if (in.has_fence(address_space::global)) { + fence = fence | static_cast(MemorySemantics::CrossWorkgroupMemory) | + static_cast(MemorySemantics::SequentiallyConsistent); + } + if (in.has_fence(address_space::local)) { + fence = fence | static_cast(MemorySemantics::WorkgroupMemory) | + static_cast(MemorySemantics::SequentiallyConsistent); + } + auto scope = unique_.constant(static_cast(Scope::Workgroup)); + auto memory_semantics = unique_.constant(fence); + mod_->add(scope, scope, memory_semantics); +} + +void inst_converter::operator()(builtin_inst const &in) { + switch (in.builtin_type()) { + case builtin::group_id_x: + case builtin::group_id_y: + case builtin::group_id_z: { + auto gid = unique_.load_builtin(BuiltIn::WorkgroupId); + auto index_ty = unique_.scalar_ty(scalar_type::index); + const std::int32_t mode = static_cast(in.builtin_type()) - + static_cast(builtin::group_id_x); + declare(in.result(0), + mod_->add(index_ty, gid, std::vector{mode})); + break; + } + case builtin::num_groups_x: + case builtin::num_groups_y: + case builtin::num_groups_z: { + auto ng = unique_.load_builtin(BuiltIn::NumWorkgroups); + auto index_ty = unique_.scalar_ty(scalar_type::index); + const std::int32_t mode = static_cast(in.builtin_type()) - + static_cast(builtin::num_groups_x); + declare(in.result(0), + mod_->add(index_ty, ng, std::vector{mode})); + break; + } + case builtin::num_subgroups_x: + declare(in.result(0), unique_.constant(tiling_.m_tiles())); + break; + case builtin::num_subgroups_y: + declare(in.result(0), unique_.constant(tiling_.n_tiles())); + break; + case builtin::subgroup_size: + declare(in.result(0), unique_.load_builtin(BuiltIn::SubgroupSize)); + break; + case builtin::subgroup_id_x: { + auto i32_ty = unique_.scalar_ty(scalar_type::i32); + auto m_tiles = unique_.constant(tiling_.m_tiles()); + auto sgid = unique_.load_builtin(BuiltIn::SubgroupId); + declare(in.result(0), mod_->add(i32_ty, sgid, m_tiles)); + break; + } + case builtin::subgroup_id_y: { + auto i32_ty = unique_.scalar_ty(scalar_type::i32); + auto m_tiles = unique_.constant(tiling_.m_tiles()); + auto sgid = unique_.load_builtin(BuiltIn::SubgroupId); + declare(in.result(0), mod_->add(i32_ty, sgid, m_tiles)); + break; + } + case builtin::subgroup_linear_id: + declare(in.result(0), unique_.load_builtin(BuiltIn::SubgroupId)); + break; + case builtin::subgroup_local_id: + declare(in.result(0), unique_.load_builtin(BuiltIn::SubgroupLocalInvocationId)); + break; + } +} + +void inst_converter::operator()(cast_inst const &in) { + if (auto st = dyn_cast(in.result(0).ty()); st) { + auto av = val(in.a()); + auto a_ty = get_scalar_type(in.a()); + declare(in.result(0), make_cast(unique_, st->ty(), a_ty, av, in.loc())); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + declare(in.result(0), matrix_impl().cast(in, val(in.a()))); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(compare_inst const &in) { + auto const compare_int = [&](cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (cond) { + case cmp_condition::eq: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ne: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::gt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ge: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::lt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::le: + return mod_->add(spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const compare_float = [&](cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (cond) { + case cmp_condition::eq: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ne: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::gt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::ge: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::lt: + return mod_->add(spv_to_ty, a, b); + case cmp_condition::le: + return mod_->add(spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + auto const compare_complex = [&](cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (cond) { + case cmp_condition::eq: { + auto components_equal = mod_->add(unique_.bool2_ty(), a, b); + return mod_->add(spv_to_ty, components_equal); + } + case cmp_condition::ne: { + auto components_not_equal = mod_->add(unique_.bool2_ty(), a, b); + return mod_->add(spv_to_ty, components_not_equal); + } + default: + throw compilation_error(in.loc(), status::ir_complex_unsupported); + } + }; + auto const make = [&](scalar_type a_ty, cmp_condition cond, spv_inst *spv_to_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (a_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return compare_int(cond, spv_to_ty, a, b); + case scalar_type::bf16: { + auto float_ty = unique_.scalar_ty(scalar_type::f32); + auto af = mod_->add(float_ty, a); + auto bf = mod_->add(float_ty, b); + auto af_op_bf = compare_float(cond, float_ty, af, bf); + return mod_->add(spv_to_ty, af_op_bf); + } + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return compare_float(cond, spv_to_ty, a, b); + case scalar_type::c32: + case scalar_type::c64: + return compare_complex(cond, spv_to_ty, a, b); + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }; + + auto spv_to_ty = spv_ty(in.result(0).ty()); + auto av = val(in.a()); + auto bv = val(in.b()); + auto a_ty = get_scalar_type(in.a()); + declare(in.result(0), make(a_ty, in.cond(), spv_to_ty, av, bv)); +} + +void inst_converter::operator()(constant_inst const &in) { + if (isa(*in.result(0).ty())) { + if (!std::holds_alternative(in.value())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + declare(in.result(0), unique_.bool_constant(std::get(in.value()))); + } else if (auto st = dyn_cast(in.result(0).ty()); st) { + auto cst = make_constant(unique_, st->ty(), in.value()); + if (cst == nullptr) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + declare(in.result(0), cst); + } else if (auto ct = dyn_cast(in.result(0).ty()); ct) { + declare(in.result(0), matrix_impl().constant(in)); + } else { + throw compilation_error(in.loc(), status::ir_expected_coopmatrix_or_scalar); + } +} + +void inst_converter::operator()(cooperative_matrix_extract_inst const &in) { + declare(in.result(0), matrix_impl().extract(in, val(in.mat()))); +} +void inst_converter::operator()(cooperative_matrix_insert_inst const &in) { + declare(in.result(0), matrix_impl().insert(in, val(in.val()), val(in.mat()))); +} + +void inst_converter::operator()(cooperative_matrix_load_inst const &in) { + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + declare(in.result(0), + matrix_impl().load(in, *odv, val(in.operand()), val(in.pos0()), val(in.pos1()))); +} + +void inst_converter::operator()(cooperative_matrix_mul_add_inst const &in) { + declare(in.result(0), matrix_impl().mul_add(in, val(in.a()), val(in.b()), val(in.c()))); +} +void inst_converter::operator()(cooperative_matrix_prefetch_inst const &in) { + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + matrix_impl().prefetch(in, *odv, val(in.operand()), val(in.pos0()), val(in.pos1())); +} +void inst_converter::operator()(cooperative_matrix_reduce_inst const &in) { + declare(in.result(0), matrix_impl().reduce(in, val(in.a()))); +} +void inst_converter::operator()(cooperative_matrix_scale_inst const &in) { + declare(in.result(0), matrix_impl().scale(in, val(in.a()), val(in.b()))); +} +void inst_converter::operator()(cooperative_matrix_store_inst const &in) { + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + matrix_impl().store(in, *odv, val(in.val()), val(in.operand()), val(in.pos0()), val(in.pos1())); +} + +void inst_converter::operator()(expand_inst const &in) { + auto spv_index_ty = unique_.scalar_ty(scalar_type::index); + + auto shape = std::vector{}; + auto stride = std::vector{}; + auto const make_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + auto static_shape = in.static_expand_shape(); + auto dyn_shape = in.expand_shape(); + + shape.reserve(mt->dim() + static_shape.size() - 1); + stride.reserve(mt->dim() + static_shape.size() - 1); + + for (std::int64_t i = 0; i < in.expanded_mode(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + + auto get_shape = [&, j = std::size_t{0}](std::int64_t s) mutable { + if (is_dynamic_value(s)) { + return val(dyn_shape[j++]); + } + return unique_.constant(s); + }; + stride.push_back(dv->stride(in.expanded_mode())); + shape.push_back(get_shape(static_shape[0])); + for (std::size_t j = 1; j < static_shape.size(); ++j) { + stride.push_back(mod_->add(spv_index_ty, stride.back(), shape.back())); + shape.push_back(get_shape(static_shape[j])); + } + + for (std::int64_t i = in.expanded_mode() + 1; i < mt->dim(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + }; + make_shape_stride(); + declare(in.result(0), val(in.operand())); + + auto rdv = make_dope_vector(in.result(0)); + + if (shape.size() != static_cast(rdv->dim()) || + stride.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride[i]); + } +} + +void inst_converter::operator()(for_inst const &in) { + auto header_label_op = std::make_unique(); + auto body_label_op = std::make_unique(); + auto continue_label_op = std::make_unique(); + auto merge_label_op = std::make_unique(); + auto header_label = header_label_op.get(); + auto body_label = body_label_op.get(); + auto continue_label = continue_label_op.get(); + auto merge_label = merge_label_op.get(); + + auto entry_label = get_last_label(*mod_); + if (!entry_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + + mod_->add(header_label); + + // Header block + auto spv_bool_ty = unique_.bool_ty(); + auto spv_loop_var_ty = spv_ty(in.loop_var().ty()); + mod_->insts(section::function).push_back(header_label_op.release()); + // nullptr needs to be replaced by the loop var update once it is defined + auto loop_var_phi = mod_->add( + spv_loop_var_ty, std::vector{PairIdRefIdRef{val(in.from()), entry_label}, + PairIdRefIdRef{nullptr, continue_label}}); + declare(in.loop_var(), loop_var_phi); + auto const &make_iter_arg_phi = [&]() -> std::vector { + auto phis = std::vector{}; + phis.reserve(in.num_results()); + for (std::int64_t i = 0; i < in.num_results(); ++i) { + auto ty = spv_ty(in.iter_arg(i).ty()); + phis.emplace_back(mod_->add( + ty, std::vector{PairIdRefIdRef{val(in.iter_init(i)), entry_label}, + PairIdRefIdRef{nullptr, continue_label}})); + declare(in.iter_arg(i), phis.back()); + } + return phis; + }; + auto iter_arg_phis = make_iter_arg_phi(); + + auto condition = mod_->add(spv_bool_ty, loop_var_phi, val(in.to())); + auto loop_control = [&]() -> std::pair> { + auto unroll = get_attr(in.attr(), "unroll"); + if (unroll) { + auto ba = dyn_cast(unroll); + if (ba) { + return {ba->value() ? LoopControl::Unroll : LoopControl::DontUnroll, std::nullopt}; + } + auto ia = dyn_cast(unroll); + if (ia) { + return {LoopControl::PartialCount, ia->value()}; + } + throw status::ir_expected_boolean_attribute; + } + return {LoopControl::None, std::nullopt}; + }(); + mod_->add(merge_label, continue_label, loop_control.first, loop_control.second); + mod_->add(condition, body_label, merge_label, + std::vector{}); + + // Body block + mod_->insts(section::function).push_back(body_label_op.release()); + + auto yielded_for = run_on_region_with_yield(in.body(), in.num_results()); + // Update phis with yielded values + for (std::int64_t i = 0; i < in.num_results(); ++i) { + iter_arg_phis[i]->op0().back().first = yielded_for[i]; + } + + mod_->add(continue_label); + + // Continue block + mod_->insts(section::function).push_back(continue_label_op.release()); + auto step = [&]() -> spv_inst * { + if (in.has_step()) { + return val(in.step()); + } + return make_constant(unique_, get_scalar_type(in.loop_var()), std::int64_t{1}); + }(); + auto loop_var_update = mod_->add(spv_loop_var_ty, loop_var_phi, step); + loop_var_phi->op0().back().first = loop_var_update; + mod_->add(header_label); + + // Merge block + mod_->insts(section::function).push_back(merge_label_op.release()); + + auto const &set_results = [&] { + for (std::int64_t i = 0; i < in.num_results(); ++i) { + declare(in.result(i), val(in.iter_arg(i))); + } + }; + set_results(); +} + +void inst_converter::operator()(fuse_inst const &in) { + auto spv_index_ty = unique_.scalar_ty(scalar_type::index); + + auto shape = std::vector{}; + auto stride = std::vector{}; + auto const make_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + shape.reserve(mt->dim()); + stride.reserve(mt->dim()); + std::int64_t i = 0; + for (; i < in.from(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + spv_inst *prod = dv->shape(i++); + for (; i <= in.to(); ++i) { + prod = mod_->add(spv_index_ty, prod, dv->shape(i)); + } + shape.push_back(prod); + stride.push_back(dv->stride(in.from())); + for (i = in.to() + 1; i < mt->dim(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + }; + make_shape_stride(); + declare(in.result(0), val(in.operand())); + + auto rdv = make_dope_vector(in.result(0)); + + if (shape.size() != static_cast(rdv->dim()) || + stride.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride[i]); + } +} + +void inst_converter::operator()(if_inst const &in) { + auto then_label = std::make_unique(); + auto otherwise_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto conditionv = val(in.condition()); + mod_->add(merge_label.get(), SelectionControl::None); + mod_->add(conditionv, then_label.get(), otherwise_label.get(), + std::vector{}); + mod_->insts(section::function).push_back(then_label.release()); + auto yielded_then = run_on_region_with_yield(in.then(), in.num_results()); + mod_->add(merge_label.get()); + auto then_last_label = get_last_label(*mod_); + if (!then_last_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + mod_->insts(section::function).push_back(otherwise_label.release()); + auto yielded_otherwise = run_on_region_with_yield(in.otherwise(), in.num_results()); + mod_->add(merge_label.get()); + auto otherwise_last_label = get_last_label(*mod_); + if (!otherwise_last_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + + mod_->insts(section::function).push_back(merge_label.release()); + + std::int64_t val_no = 0; + for (std::int64_t i = 0; i < in.num_results(); ++i) { + auto ty = spv_ty(in.result(i).ty()); + auto phi_inst = mod_->add( + ty, std::vector{ + PairIdRefIdRef{yielded_then[val_no], then_last_label}, + PairIdRefIdRef{yielded_otherwise[val_no], otherwise_last_label}}); + ++val_no; + declare(in.result(i), phi_inst); + } +} + +void inst_converter::operator()(lifetime_stop_inst const &) {} + +void inst_converter::operator()(load_inst const &in) { + auto spv_index_ty = unique_.scalar_ty(scalar_type::index); + auto spv_pointer_index_ty = + unique_.pointer_ty(StorageClass::CrossWorkgroup, spv_index_ty, alignment(scalar_type::i64)); + auto spv_pointer_ty = spv_ty(in.operand().ty()); + auto spv_result_ty = spv_ty(in.result(0).ty()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + if (auto group_ty = dyn_cast(in.operand().ty()); group_ty) { + auto offset = mod_->add(spv_index_ty, dv->offset(), val(in.index_list()[0])); + auto pointer = mod_->add(spv_pointer_ty, val(in.operand()), + offset, std::vector{}); + declare(in.result(0), mod_->add(spv_result_ty, pointer)); + auto rdv = make_dope_vector(in.result(0)); + + auto const make_dope_par = [&](std::int64_t static_s, spv_inst *s) -> spv_inst * { + if (is_dynamic_value(static_s)) { + auto pointer = mod_->add(spv_pointer_index_ty, s, offset, + std::vector{}); + return mod_->add(spv_index_ty, pointer); + } + return s; + }; + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, make_dope_par(dv->static_shape(i), dv->shape(i))); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, make_dope_par(dv->static_stride(i), dv->stride(i))); + } + } else if (auto memref_ty = dyn_cast(in.operand().ty()); memref_ty) { + const auto pointer = [&](spv_inst *additional_offset0 = nullptr) -> spv_inst * { + if (memref_ty->dim() == 0) { + return val(in.operand()); + } + + auto idx0 = val(in.index_list()[0]); + spv_inst *offset = memref_ty->stride(0) != 1 + ? mod_->add(spv_index_ty, idx0, dv->stride(0)) + : idx0; + for (std::int64_t i = 1; i < memref_ty->dim(); ++i) { + auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); + offset = mod_->add(spv_index_ty, offset, tmp); + } + if (additional_offset0) { + offset = mod_->add(spv_index_ty, offset, additional_offset0); + } + return mod_->add(spv_pointer_ty, val(in.operand()), offset, + std::vector{}); + }; + declare(in.result(0), mod_->add(spv_result_ty, pointer())); + } else { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + } +} + +void inst_converter::operator()(math_unary_inst const &in) { + auto const make_float = [&](math_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + auto const make_ext_inst = [&](OpenCLEntrypoint ep) { + return mod_->add(ty, unique_.opencl_ext(), static_cast(ep), + std::vector{a}); + }; + switch (op) { + case math_unary::cos: + return make_ext_inst(OpenCLEntrypoint::cos); + case math_unary::sin: + return make_ext_inst(OpenCLEntrypoint::sin); + case math_unary::exp: + return make_ext_inst(OpenCLEntrypoint::exp); + case math_unary::exp2: + return make_ext_inst(OpenCLEntrypoint::exp2); + case math_unary::native_cos: + return make_ext_inst(OpenCLEntrypoint::native_cos); + case math_unary::native_sin: + return make_ext_inst(OpenCLEntrypoint::native_sin); + case math_unary::native_exp: + return make_ext_inst(OpenCLEntrypoint::native_exp); + case math_unary::native_exp2: + return make_ext_inst(OpenCLEntrypoint::native_exp2); + default: + throw compilation_error(in.loc(), status::internal_compiler_error); + } + }; + auto const make_complex = [&](math_unary op, scalar_type sty, spv_inst *ty, spv_inst *a, + LiteralContextDependentNumber log2) -> spv_inst * { + auto spv_float_ty = unique_.scalar_ty(component_type(sty)); + auto const make_complex_exp = [&](auto exp_ep, auto cos_ep, auto sin_ep, + spv_inst *im_scale = nullptr) { + auto a0 = + mod_->add(spv_float_ty, a, std::vector{0}); + spv_inst *a1 = + mod_->add(spv_float_ty, a, std::vector{1}); + if (im_scale) { + a1 = mod_->add(spv_float_ty, a1, im_scale); + } + auto e = + mod_->add(spv_float_ty, unique_.opencl_ext(), + static_cast(exp_ep), std::vector{a0}); + auto c = + mod_->add(spv_float_ty, unique_.opencl_ext(), + static_cast(cos_ep), std::vector{a1}); + auto s = + mod_->add(spv_float_ty, unique_.opencl_ext(), + static_cast(sin_ep), std::vector{a1}); + auto r = mod_->add(spv_float_ty, e, c); + auto i = mod_->add(spv_float_ty, e, s); + auto dummy = mod_->add(ty); + auto result = + mod_->add(ty, r, dummy, std::vector{0}); + return mod_->add(ty, i, result, std::vector{1}); + }; + switch (op) { + case math_unary::exp: + return make_complex_exp(OpenCLEntrypoint::exp, OpenCLEntrypoint::cos, + OpenCLEntrypoint::sin); + case math_unary::exp2: + return make_complex_exp(OpenCLEntrypoint::exp2, OpenCLEntrypoint::cos, + OpenCLEntrypoint::sin, unique_.constant(log2)); + case math_unary::native_exp: + return make_complex_exp(OpenCLEntrypoint::native_exp, OpenCLEntrypoint::native_cos, + OpenCLEntrypoint::native_sin); + case math_unary::native_exp2: + return make_complex_exp(OpenCLEntrypoint::native_exp2, OpenCLEntrypoint::native_cos, + OpenCLEntrypoint::native_sin, unique_.constant(log2)); + default: + throw compilation_error(in.loc(), status::internal_compiler_error); + } + }; + auto const make = [&](scalar_type sty, math_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (sty) { + case scalar_type::bf16: { + auto float_ty = unique_.scalar_ty(scalar_type::f32); + auto af = mod_->add(float_ty, a); + auto op_af = make_float(op, float_ty, af); + return mod_->add(ty, op_af); + } + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return make_float(op, ty, a); + case scalar_type::c32: + return make_complex(op, sty, ty, a, std::log(2.0f)); + case scalar_type::c64: + return make_complex(op, sty, ty, a, std::log(2.0)); + default: + throw compilation_error(in.loc(), status::internal_compiler_error); + } + }; + + auto sty = get_scalar_type(in.a()); + auto ty = spv_ty(in.result(0).ty()); + auto av = val(in.a()); + declare(in.result(0), make(sty, in.operation(), ty, av)); +} + +void inst_converter::operator()(parallel_inst const &in) { run_on_region(in.body()); } + +void inst_converter::operator()(size_inst const &in) { + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + const auto shape = ::tinytc::visit( + overloaded{[&](group_data_type const &) -> spv_inst * { return dv->size(); }, + [&](memref_data_type const &) -> spv_inst * { return dv->shape(in.mode()); }, + [&](auto const &) -> spv_inst * { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + }}, + *in.operand().ty()); + declare(in.result(0), shape); +} + +void inst_converter::operator()(subgroup_broadcast_inst const &in) { + auto broadcast_scope = unique_.constant(static_cast(Scope::Subgroup)); + auto ty = spv_ty(in.result(0).ty()); + auto av = val(in.a()); + auto idxv = val(in.idx()); + declare(in.result(0), mod_->add(ty, broadcast_scope, av, idxv)); +} + +void inst_converter::operator()(subgroup_operation_inst const &in) { + auto sty = get_scalar_type(in.a()); + declare(in.result(0), + make_subgroup_op(unique_, sty, in.arith(), in.operation(), val(in.a()), in.loc())); +} + +void inst_converter::operator()(store_inst const &in) { + auto spv_index_ty = unique_.scalar_ty(scalar_type::index); + auto spv_pointer_ty = spv_ty(in.operand().ty()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + if (auto memref_ty = dyn_cast(in.operand().ty()); memref_ty) { + const auto pointer = [&]() -> spv_inst * { + if (memref_ty->dim() == 0) { + return val(in.operand()); + } + + auto idx0 = val(in.index_list()[0]); + auto offset = memref_ty->stride(0) != 1 + ? mod_->add(spv_index_ty, idx0, dv->stride(0)) + : idx0; + for (std::int64_t i = 1; i < memref_ty->dim(); ++i) { + auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); + offset = mod_->add(spv_index_ty, offset, tmp); + } + + return mod_->add(spv_pointer_ty, val(in.operand()), offset, + std::vector{}); + }; + + make_store(unique_, in.flag(), memref_ty->element_ty(), memref_ty->addrspace(), pointer(), + val(in.val()), in.loc()); + } else { + throw compilation_error(in.loc(), status::ir_expected_memref); + } +} + +void inst_converter::operator()(subview_inst const &in) { + auto spv_index_ty = unique_.scalar_ty(scalar_type::index); + auto spv_result_ty = spv_ty(in.result(0).ty()); + + auto shape_out = std::vector{}; + auto stride_out = std::vector{}; + auto const make_offset_and_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + shape_out.reserve(mt->dim()); + stride_out.reserve(mt->dim()); + auto dyn_offsets = in.offsets(); + auto dyn_sizes = in.sizes(); + auto offset_acc = unique_.null_constant(spv_index_ty); + for (std::int64_t i = 0, joffset = 0, jsize = 0; i < mt->dim(); ++i) { + const std::int64_t offset = in.static_offsets()[i]; + + auto const offset_inst = [&]() -> spv_inst * { + if (is_dynamic_value(offset)) { + return val(dyn_offsets[joffset++]); + } + return unique_.constant(offset); + }; + auto tmp = mod_->add(spv_index_ty, offset_inst(), dv->stride(i)); + offset_acc = mod_->add(spv_index_ty, offset_acc, tmp); + + const std::int64_t size = in.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { + auto const size_inst = [&]() -> spv_inst * { + if (is_dynamic_value(size)) { + return val(dyn_sizes[jsize++]); + } + return unique_.constant(size); + }; + shape_out.emplace_back(size_inst()); + stride_out.emplace_back(dv->stride(i)); + } + } + return offset_acc; + }; + + auto offset = make_offset_and_shape_stride(); + declare(in.result(0), mod_->add(spv_result_ty, val(in.operand()), + offset, std::vector{})); + + auto rdv = make_dope_vector(in.result(0)); + + if (shape_out.size() != static_cast(rdv->dim()) || + stride_out.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape_out[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride_out[i]); + } +} + +void inst_converter::operator()(yield_inst const &in) { + if (yielded_vals_.empty()) { + throw compilation_error(in.loc(), status::ir_unexpected_yield); + } + + auto &top = yielded_vals_.top(); + if (static_cast(top.size()) != in.num_operands()) { + throw compilation_error(in.loc(), status::ir_yield_mismatch); + } + + std::int64_t i = 0; + for (auto &op : in.operands()) { + top[i++] = val(op); + } +} + +void inst_converter::run_on_region(region_node const ®) { + for (auto const &i : reg) { + visit(*this, i); + } +} + +auto inst_converter::run_on_region_with_yield(region_node const ®, std::int64_t num_results) + -> std::vector { + yielded_vals_.push(std::vector(num_results, nullptr)); + run_on_region(reg); + auto yielded_vals = std::move(yielded_vals_.top()); + if (static_cast(yielded_vals.size()) != num_results || + std::any_of(yielded_vals.begin(), yielded_vals.end(), + [](spv_inst *in) { return in == nullptr; })) { + throw compilation_error(reg.loc(), status::ir_yield_mismatch); + } + yielded_vals_.pop(); + return yielded_vals; +} + +void inst_converter::run_on_function(function_node const &fn) { + try { + core_cfg_ = info_->get_core_config(fn.subgroup_size()); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + + auto vars_used_by_function = std::vector{}; + + // Stack + auto const make_stack = [&] { + const auto high_water_mark = stack_high_water_mark{}.run_on_function(fn); + if (high_water_mark > 0) { + auto stack_element_ty = unique_.scalar_ty(scalar_type::i8); + auto stack_array_ty = unique_.array_ty(stack_element_ty, high_water_mark); + auto stack_ptr_ty = unique_.pointer_ty(StorageClass::Workgroup, stack_array_ty, 0); + stack_ = mod_->add_to(section::type_const_var, stack_ptr_ty, + StorageClass::Workgroup); + const std::int32_t alignment = info_->alignment(); + mod_->add_to(section::decoration, stack_, Decoration::Alignment, + DecorationAttr{alignment}); + vars_used_by_function.emplace_back(stack_); + } else { + stack_ = nullptr; + } + }; + make_stack(); + + // Function type + auto fun_ty = unique_.function_ty(unique_.void_ty(), [&] { + auto params = std::vector{}; + params.reserve(fn.num_params()); + for (auto const &p : fn.params()) { + params.emplace_back(spv_ty(p.ty())); + auto dv = make_dope_vector(p); + if (dv) { + for (std::int64_t i = 0; i < dv->num_dynamic(); ++i) { + params.emplace_back(dv->ty()); + } + if (is_dynamic_value(dv->static_size())) { + params.emplace_back(dv->size_ty()); + } + if (is_dynamic_value(dv->static_offset())) { + params.emplace_back(dv->offset_ty()); + } + } + } + return params; + }()); + + // Function + auto const subgroup_size = fn.subgroup_size(); + auto const work_group_size = fn.work_group_size(); + tiling_[0] = work_group_size[0] / subgroup_size; + tiling_[1] = work_group_size[1]; + + matrix_impl_ = [&]() -> std::unique_ptr { + const auto gcd = gcd_analysis{info_->alignment()}.run_on_function(fn); + if (info_->matrix().have_dpas()) { + return std::make_unique(unique_, core_cfg_, std::move(gcd)); + } else if (info_->have_spirv_feature(spirv_feature::subgroup_buffer_block_io)) { + return std::make_unique(unique_, core_cfg_, std::move(gcd)); + } + return std::make_unique(unique_, core_cfg_, std::move(gcd)); + }(); + + auto void_ty = unique_.void_ty(); + auto fun = mod_->add(void_ty, FunctionControl::None, fun_ty); + for (auto const &p : fn.params()) { + declare(p, mod_->add(spv_ty(p.ty()))); + auto dv = get_dope_vector(p); + if (dv) { + auto const make_dope_par = [&](spv_inst *ty, std::int64_t s) { + return is_dynamic_value(s) ? mod_->add(ty) + : unique_.constant(s); + }; + for (std::int64_t i = 0; i < dv->dim(); ++i) { + dv->shape(i, make_dope_par(dv->ty(), dv->static_shape(i))); + } + for (std::int64_t i = 0; i < dv->dim(); ++i) { + dv->stride(i, make_dope_par(dv->ty(), dv->static_stride(i))); + } + if (dv->size_ty()) { + dv->size(make_dope_par(dv->size_ty(), dv->static_size())); + } + if (dv->offset_ty()) { + dv->offset(make_dope_par(dv->offset_ty(), dv->static_offset())); + } + } + } + + mod_->add(); + auto func_begin = --mod_->insts(section::function).end(); + run_on_region(fn.body()); + + auto func_end = mod_->insts(section::function).end(); + for (auto it = func_begin; it != func_end; ++it) { + if (auto ld = dyn_cast(it.get()); ld && isa(*ld->op0())) { + if (std::find(vars_used_by_function.begin(), vars_used_by_function.end(), ld->op0()) == + vars_used_by_function.end()) { + vars_used_by_function.push_back(ld->op0()); + } + } + } + + mod_->add(); + mod_->add(); + + tiling_ = {}; + + // Entry point + mod_->add_to(section::entry_point, ExecutionModel::Kernel, fun, + std::string{fn.name()}, std::move(vars_used_by_function)); + + // Execution mode + mod_->add_to( + section::execution_mode, fun, ExecutionMode::LocalSize, + ExecutionModeAttr{std::array{work_group_size[0], work_group_size[1], 1}}); + mod_->add_to(section::execution_mode, fun, ExecutionMode::SubgroupSize, + ExecutionModeAttr{subgroup_size}); +} + +} // namespace tinytc::spv + diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp new file mode 100644 index 00000000..e147c1d9 --- /dev/null +++ b/src/spv/converter.hpp @@ -0,0 +1,95 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONVERTER_20241111_HPP +#define CONVERTER_20241111_HPP + +#include "device_info.hpp" +#include "node/inst_node.hpp" +#include "node/region_node.hpp" +#include "spv/coopmatrix_impl.hpp" +#include "spv/defs.hpp" +#include "spv/dope_vector.hpp" +#include "spv/uniquifier.hpp" +#include "tiling.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto convert_prog_to_spirv(tinytc_prog const &p, tinytc_core_info const &info) -> ::tinytc::spv_mod; + +class inst_converter { + public: + inst_converter(tinytc_spv_mod &m, tinytc_core_info const &info); + + // Instruction nodes + void operator()(inst_node const &in); + void operator()(alloca_inst const &in); + void operator()(arith_inst const &in); + void operator()(arith_unary_inst const &in); + void operator()(barrier_inst const &in); + void operator()(builtin_inst const &in); + void operator()(cast_inst const &in); + void operator()(compare_inst const &in); + void operator()(constant_inst const &in); + void operator()(cooperative_matrix_extract_inst const &in); + void operator()(cooperative_matrix_insert_inst const &in); + void operator()(cooperative_matrix_load_inst const &in); + void operator()(cooperative_matrix_mul_add_inst const &in); + void operator()(cooperative_matrix_prefetch_inst const &in); + void operator()(cooperative_matrix_reduce_inst const &in); + void operator()(cooperative_matrix_scale_inst const &in); + void operator()(cooperative_matrix_store_inst const &in); + void operator()(expand_inst const &in); + void operator()(for_inst const &in); + void operator()(fuse_inst const &in); + void operator()(if_inst const &in); + void operator()(lifetime_stop_inst const &in); + void operator()(load_inst const &in); + void operator()(math_unary_inst const &in); + void operator()(parallel_inst const &in); + void operator()(size_inst const &in); + void operator()(subgroup_broadcast_inst const &in); + void operator()(subgroup_operation_inst const &in); + void operator()(store_inst const &in); + void operator()(subview_inst const &in); + void operator()(yield_inst const &in); + + void run_on_region(tinytc_region const ®); + auto run_on_region_with_yield(region_node const ®, std::int64_t num_results) + -> std::vector; + void run_on_function(tinytc_func const &fn); + + inline auto unique() -> uniquifier & { return unique_; } + + private: + auto get_dope_vector(tinytc_value const &v) -> dope_vector *; + auto declare(tinytc_value const &v, spv_inst *in); + auto val(tinytc_value const &v) -> spv_inst *; + auto spv_ty(const_tinytc_data_type_t ty) -> spv_inst *; + auto make_dope_vector(tinytc_value const &v) -> dope_vector *; + auto matrix_impl() -> coopmatrix_impl &; + auto make_matrix_impl() -> std::unique_ptr; + + tinytc_spv_mod_t mod_; + tinytc_core_info const *info_; + uniquifier unique_; + std::unique_ptr matrix_impl_ = nullptr; + std::unordered_map dope_vec_; + std::unordered_map vals_; + std::stack> yielded_vals_; + spv_inst *stack_ = nullptr; + core_config core_cfg_ = {}; + local_tiling tiling_ = {}; +}; + +} // namespace tinytc::spv + +#endif // CONVERTER_20241111_HPP diff --git a/src/spv/converter_aux.cpp b/src/spv/converter_aux.cpp new file mode 100644 index 00000000..6f1f8ed0 --- /dev/null +++ b/src/spv/converter_aux.cpp @@ -0,0 +1,675 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/converter_aux.hpp" +#include "error.hpp" +#include "scalar_type.hpp" +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/opencl.std.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto convert_group_operation(group_operation op) -> GroupOperation { + switch (op) { + case group_operation::exclusive_scan: + return GroupOperation::ExclusiveScan; + case group_operation::inclusive_scan: + return GroupOperation::InclusiveScan; + case group_operation::reduce: + return GroupOperation::Reduce; + } + throw status::internal_compiler_error; +} + +auto get_last_label(tinytc_spv_mod &mod) -> spv_inst * { + auto &insts = mod.insts(section::function); + auto it = insts.end(); + while (it != insts.begin()) { + auto in = (--it).get(); + if (isa(*in)) { + return in; + } + } + return nullptr; +} + +auto make_binary_op(uniquifier &unique, scalar_type sty, arithmetic op, spv_inst *a, spv_inst *b, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const make_int = [&](arithmetic op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod.add(ty, a, b); + case arithmetic::sub: + return mod.add(ty, a, b); + case arithmetic::mul: + return mod.add(ty, a, b); + case arithmetic::div: + return mod.add(ty, a, b); + case arithmetic::rem: + return mod.add(ty, a, b); + case arithmetic::shl: + return mod.add(ty, a, b); + case arithmetic::shr: + return mod.add(ty, a, b); + case arithmetic::and_: + return mod.add(ty, a, b); + case arithmetic::or_: + return mod.add(ty, a, b); + case arithmetic::xor_: + return mod.add(ty, a, b); + case arithmetic::min: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::s_min), + std::vector{a, b}); + case arithmetic::max: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::s_max), + std::vector{a, b}); + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_float = [&](arithmetic op, spv_inst *ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod.add(ty, a, b); + case arithmetic::sub: + return mod.add(ty, a, b); + case arithmetic::mul: + return mod.add(ty, a, b); + case arithmetic::div: + return mod.add(ty, a, b); + case arithmetic::rem: + return mod.add(ty, a, b); + case arithmetic::min: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::fmin), + std::vector{a, b}); + case arithmetic::max: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::fmax), + std::vector{a, b}); + default: + break; + } + throw compilation_error(loc, status::ir_fp_unsupported); + }; + auto const make_complex = [&](arithmetic op, spv_inst *ty, spv_inst *float_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case arithmetic::add: + return mod.add(ty, a, b); + case arithmetic::sub: + return mod.add(ty, a, b); + case arithmetic::mul: { + return make_complex_mul(unique, ty, a, b); + } + case arithmetic::div: { + auto a_times_conj_b = make_complex_mul(unique, ty, a, b, true); + + auto b_squared = mod.add(ty, b, b); + auto b_squared_0 = + mod.add(float_ty, b_squared, std::vector{0}); + auto b_squared_1 = + mod.add(float_ty, b_squared, std::vector{1}); + spv_inst *b_abs = mod.add(float_ty, b_squared_0, b_squared_1); + auto dummy = mod.add(ty); + b_abs = mod.add(ty, b_abs, dummy, std::vector{0}); + b_abs = mod.add(ty, b_abs, dummy, std::vector{0, 0}); + return mod.add(ty, a_times_conj_b, b_abs); + } + default: + break; + } + throw compilation_error(loc, status::ir_complex_unsupported); + }; + auto ty = unique.scalar_ty(sty); + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return make_int(op, ty, a, b); + case scalar_type::bf16: { + auto float_ty = unique.scalar_ty(scalar_type::f32); + auto af = mod.add(float_ty, a); + auto bf = mod.add(float_ty, b); + auto af_op_bf = make_float(op, float_ty, af, bf); + return mod.add(ty, af_op_bf); + } + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return make_float(op, ty, a, b); + case scalar_type::c32: + case scalar_type::c64: + return make_complex(op, ty, unique.scalar_ty(component_type(sty)), a, b); + } + throw compilation_error(loc, status::internal_compiler_error); +} + +auto make_binary_op_mixed_precision(uniquifier &unique, scalar_type result_ty, arithmetic op, + scalar_type a_ty, spv_inst *a, scalar_type b_ty, spv_inst *b, + location const &loc) -> spv_inst * { + if (!promotable(a_ty, result_ty) || !promotable(b_ty, result_ty)) { + throw compilation_error(loc, status::ir_forbidden_promotion); + } + if (a_ty != result_ty) { + a = make_cast(unique, result_ty, a_ty, a, loc); + } + if (b_ty != result_ty) { + b = make_cast(unique, result_ty, b_ty, b, loc); + } + return make_binary_op(unique, result_ty, op, a, b, loc); +} + +auto make_cast(uniquifier &unique, scalar_type to_ty, scalar_type a_ty, spv_inst *a, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto float_ty = unique.scalar_ty(scalar_type::f32); + auto const cast_from_int = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod.add(spv_to_ty, a); + case scalar_type::bf16: { + auto af = mod.add(float_ty, a); + return mod.add(spv_to_ty, af); + } + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return mod.add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto spv_float_ty = unique.scalar_ty(component_type(to_ty)); + auto re = mod.add(spv_float_ty, a); + return mod.add(spv_to_ty, re, unique.null_constant(spv_to_ty), + std::vector{0}); + } + } + throw compilation_error(loc, status::ir_forbidden_cast); + }; + auto const cast_from_float = [&](scalar_type to_ty, spv_inst *spv_to_ty, scalar_type a_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod.add(spv_to_ty, a); + case scalar_type::bf16: + return mod.add(spv_to_ty, a); + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return mod.add(spv_to_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + auto re = a; + if (component_type(to_ty) != a_ty) { + auto spv_float_ty = unique.scalar_ty(component_type(to_ty)); + re = mod.add(spv_float_ty, a); + } + // If the line below is change, adjust make_complex_mul as well + return mod.add(spv_to_ty, re, unique.null_constant(spv_to_ty), + std::vector{0}); + } + } + throw compilation_error(loc, status::ir_forbidden_cast); + }; + auto const cast_from_complex = [&](scalar_type to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + switch (to_ty) { + case scalar_type::c32: + case scalar_type::c64: + return mod.add(spv_to_ty, a); + default: + throw compilation_error(loc, status::ir_forbidden_cast); + } + }; + auto spv_to_ty = unique.scalar_ty(to_ty); + if (a_ty == to_ty) { + return mod.add(spv_to_ty, a); + } + switch (a_ty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return cast_from_int(to_ty, spv_to_ty, a); + case scalar_type::bf16: { + auto af = mod.add(float_ty, a); + return cast_from_float(to_ty, spv_to_ty, scalar_type::f32, af); + } + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return cast_from_float(to_ty, spv_to_ty, a_ty, a); + case scalar_type::c32: + case scalar_type::c64: { + return cast_from_complex(to_ty, spv_to_ty, a); + } + } + throw compilation_error(loc, status::internal_compiler_error); +} + +auto make_complex_mul(uniquifier &unique, spv_inst *ty, spv_inst *a, spv_inst *b, bool conj_b) + -> spv_inst * { + auto &mod = unique.mod(); + const auto is_imag_zero = [&](spv_inst *a) -> bool { + // We capture the case here if "a" stems from a non-complex -> complex cast + if (auto ci = dyn_cast(a); ci) { + return ci->type() == ty && ci->op1() == unique.null_constant(ty) && + ci->op2().size() == 1 && ci->op2()[0] == 0; + } + return false; + }; + + if (is_imag_zero(a)) { + a = mod.add(ty, a, a, std::vector{0, 0}); + return mod.add(ty, a, b); + } else if (is_imag_zero(b)) { + b = mod.add(ty, b, b, std::vector{0, 0}); + return mod.add(ty, a, b); + } + + auto neg_a = mod.add(ty, a); + auto a_times_i = + conj_b ? mod.add(ty, a, neg_a, std::vector{1, 2}) + : mod.add(ty, neg_a, a, std::vector{1, 2}); + auto b_1 = mod.add(ty, b, b, std::vector{1, 1}); + auto b_1_a_times_i = mod.add(ty, b_1, a_times_i); + auto b_0 = mod.add(ty, b, b, std::vector{0, 0}); + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::fma), + std::vector{a, b_0, b_1_a_times_i}); +} + +auto make_constant(uniquifier &unique, scalar_type sty, constant_inst::value_type const &val) + -> spv_inst * { + auto const add_constant_complex = [&](auto cst) -> spv_inst * { + auto c_re = unique.constant(cst.real()); + auto c_im = unique.constant(cst.imag()); + auto ty = unique.scalar_ty(sty); + return unique.mod().add_to(section::type_const_var, ty, + std::vector{c_re, c_im}); + }; + const auto visitor = overloaded{ + [&](bool) -> spv_inst * { return nullptr; }, + [&](std::int64_t i) -> spv_inst * { + switch (sty) { + case scalar_type::i8: + return unique.constant(static_cast(i)); + case scalar_type::i16: + return unique.constant(static_cast(i)); + case scalar_type::i32: + return unique.constant(static_cast(i)); + case scalar_type::i64: + case scalar_type::index: + return unique.constant(i); + default: + return nullptr; + } + }, + [&](double d) -> spv_inst * { + switch (sty) { + case scalar_type::bf16: + return unique.constant(bfloat16{static_cast(d)}.bits()); + case scalar_type::f16: + return unique.constant(half{static_cast(d)}); + case scalar_type::f32: + return unique.constant(static_cast(d)); + case scalar_type::f64: + return unique.constant(d); + default: + return nullptr; + } + }, + [&](std::complex d) -> spv_inst * { + switch (sty) { + case scalar_type::c32: + return add_constant_complex(static_cast>(d)); + case scalar_type::c64: + return add_constant_complex(d); + default: + return nullptr; + } + }, + }; + auto cst = std::visit(visitor, val); + if (!cst) { + throw status::internal_compiler_error; + } + return cst; +} + +void make_conditional_execution(uniquifier &unique, spv_inst *condition, + std::function then) { + auto then_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto &mod = unique.mod(); + mod.add(merge_label.get(), SelectionControl::None); + mod.add(condition, then_label.get(), merge_label.get(), + std::vector{}); + mod.insts(section::function).push_back(then_label.release()); + then(mod); + mod.add(merge_label.get()); + mod.insts(section::function).push_back(merge_label.release()); +} + +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst *condition, + std::function then, + spv_inst *otherwise, location const &loc) -> spv_inst * { + auto then_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto &mod = unique.mod(); + auto init_last_label = get_last_label(mod); + if (!init_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + + mod.add(merge_label.get(), SelectionControl::None); + mod.add(condition, then_label.get(), merge_label.get(), + std::vector{}); + mod.insts(section::function).push_back(then_label.release()); + spv_inst *yielded_then = then(mod); + mod.add(merge_label.get()); + auto then_last_label = get_last_label(mod); + if (!then_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + + mod.insts(section::function).push_back(merge_label.release()); + return mod.add(return_ty, + std::vector{PairIdRefIdRef{yielded_then, then_last_label}, + PairIdRefIdRef{otherwise, init_last_label}}); +} + +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst *condition, + std::function then, + std::function otherwise, + location const &loc) -> spv_inst * { + auto then_label = std::make_unique(); + auto otherwise_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto &mod = unique.mod(); + mod.add(merge_label.get(), SelectionControl::None); + mod.add(condition, then_label.get(), otherwise_label.get(), + std::vector{}); + mod.insts(section::function).push_back(then_label.release()); + spv_inst *yielded_then = then(mod); + mod.add(merge_label.get()); + auto then_last_label = get_last_label(mod); + if (!then_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + mod.insts(section::function).push_back(otherwise_label.release()); + spv_inst *yielded_otherwise = otherwise(mod); + mod.add(merge_label.get()); + auto otherwise_last_label = get_last_label(mod); + if (!otherwise_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + + mod.insts(section::function).push_back(merge_label.release()); + return mod.add(return_ty, std::vector{ + PairIdRefIdRef{yielded_then, then_last_label}, + PairIdRefIdRef{yielded_otherwise, otherwise_last_label}}); +} + +void make_store(uniquifier &unique, store_flag flag, scalar_type sty, address_space as, + spv_inst *pointer, spv_inst *value, location const &loc) { + auto &mod = unique.mod(); + auto const split_re_im = [&]() -> std::array, 2u> { + auto component_sty = component_type(sty); + auto float_ty = unique.scalar_ty(component_sty); + const auto storage_cls = address_space_to_storage_class(as); + auto pointer_ty = unique.pointer_ty(storage_cls, float_ty, alignment(component_sty)); + auto c0 = unique.constant(std::int32_t{0}); + auto c1 = unique.constant(std::int32_t{1}); + auto re_ptr = mod.add(pointer_ty, pointer, std::vector{c0}); + auto im_ptr = mod.add(pointer_ty, pointer, std::vector{c1}); + auto re_val = mod.add(float_ty, value, std::vector{0}); + auto im_val = mod.add(float_ty, value, std::vector{1}); + return {{{re_ptr, re_val}, {im_ptr, im_val}}}; + }; + auto const make_atomic_something = [&]() { + auto result_ty = unique.scalar_ty(sty); + auto scope = unique.constant(static_cast(Scope::Workgroup)); + auto semantics = unique.constant(static_cast(MemorySemantics::Relaxed)); + switch (sty) { + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + mod.add(result_ty, pointer, scope, semantics, value); + break; + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + mod.add(result_ty, pointer, scope, semantics, value); + break; + case scalar_type::c32: + case scalar_type::c64: { + auto re_im = split_re_im(); + auto component_sty = component_type(sty); + auto float_ty = unique.scalar_ty(component_sty); + mod.add(float_ty, re_im[0][0], scope, semantics, re_im[0][1]); + mod.add(float_ty, re_im[1][0], scope, semantics, re_im[1][1]); + break; + } + default: + throw compilation_error(loc, status::spirv_unsupported_atomic_data_type); + } + }; + switch (flag) { + case store_flag::regular: + mod.add(pointer, value); + break; + case store_flag::atomic: { + auto scope = unique.constant(static_cast(Scope::Workgroup)); + auto semantics = unique.constant(static_cast(MemorySemantics::Relaxed)); + switch (sty) { + case scalar_type::c32: + case scalar_type::c64: { + auto re_im = split_re_im(); + mod.add(re_im[0][0], scope, semantics, re_im[0][1]); + mod.add(re_im[1][0], scope, semantics, re_im[1][1]); + break; + } + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + mod.add(pointer, scope, semantics, value); + break; + default: + throw compilation_error(loc, status::spirv_unsupported_atomic_data_type); + } + break; + } + case store_flag::atomic_add: + make_atomic_something.template operator()(); + break; + case store_flag::atomic_max: + make_atomic_something.template operator()(); + break; + case store_flag::atomic_min: + make_atomic_something.template operator()(); + break; + } +} + +auto make_unary_op(uniquifier &unique, scalar_type sty, arithmetic_unary op, spv_inst *a, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const make_int = [&](arithmetic_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case arithmetic_unary::abs: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::s_abs), + std::vector{a}); + case arithmetic_unary::neg: + return mod.add(ty, a); + case arithmetic_unary::not_: + return mod.add(ty, a); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_float = [&](arithmetic_unary op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case arithmetic_unary::abs: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::fabs), + std::vector{a}); + case arithmetic_unary::neg: + return mod.add(ty, a); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_complex = [&](arithmetic_unary op, scalar_type sty, spv_inst *ty, + spv_inst *a) -> spv_inst * { + auto float_ty = unique.scalar_ty(component_type(sty)); + switch (op) { + case arithmetic_unary::abs: { + auto a2 = mod.add(ty, a, a); + auto a2_0 = mod.add(float_ty, a2, std::vector{0}); + auto a2_1 = mod.add(float_ty, a2, std::vector{1}); + auto a2_0p1 = mod.add(float_ty, a2_0, a2_1); + return mod.add(float_ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::sqrt), + std::vector{a2_0p1}); + } + case arithmetic_unary::neg: + return mod.add(ty, a); + case arithmetic_unary::conj: { + auto a_im = mod.add(float_ty, a, std::vector{1}); + auto neg_a_im = mod.add(float_ty, a_im); + return mod.add(ty, neg_a_im, a, std::vector{1}); + } + case arithmetic_unary::im: + return mod.add(float_ty, a, std::vector{1}); + case arithmetic_unary::re: + return mod.add(float_ty, a, std::vector{0}); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); + }; + + spv_inst *ty = unique.scalar_ty(sty); + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return make_int(op, ty, a); + case scalar_type::bf16: { + auto float_ty = unique.scalar_ty(scalar_type::f32); + auto af = mod.add(float_ty, a); + auto op_af = make_float(op, float_ty, af); + return mod.add(ty, op_af); + } + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return make_float(op, ty, a); + case scalar_type::c32: + case scalar_type::c64: { + return make_complex(op, sty, ty, a); + } + } + throw compilation_error(loc, status::internal_compiler_error); +} + +auto make_subgroup_op(uniquifier &unique, scalar_type sty, group_arithmetic arith, + group_operation op, spv_inst *a, location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const make_impl = [&](scalar_type sty, GroupOperation group_op, spv_inst *ty, + spv_inst *operand) -> spv_inst * { + auto scope = unique.constant(static_cast(Scope::Subgroup)); + switch (sty) { + case scalar_type::i8: + case scalar_type::i16: + case scalar_type::i32: + case scalar_type::i64: + case scalar_type::index: + return mod.add(ty, scope, group_op, operand); + case scalar_type::bf16: { + auto float_ty = unique.scalar_ty(scalar_type::f32); + auto operandf = mod.add(float_ty, operand); + auto resultf = mod.add(float_ty, scope, group_op, operandf); + return mod.add(ty, resultf); + } + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return mod.add(ty, scope, group_op, operand); + case scalar_type::c32: + case scalar_type::c64: + return mod.add(ty, scope, group_op, operand); + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto ty = unique.scalar_ty(sty); + struct add_ops { + using i = OpGroupIAdd; + using f = OpGroupFAdd; + }; + struct max_ops { + using i = OpGroupSMax; + using f = OpGroupFMax; + }; + struct min_ops { + using i = OpGroupSMin; + using f = OpGroupFMin; + }; + auto spv_op = convert_group_operation(op); + switch (arith) { + case group_arithmetic::add: + return make_impl.template operator()(sty, spv_op, ty, a); + case group_arithmetic::max: + return make_impl.template operator()(sty, spv_op, ty, a); + case group_arithmetic::min: + return make_impl.template operator()(sty, spv_op, ty, a); + } + throw compilation_error(loc, status::internal_compiler_error); +} + +} // namespace tinytc::spv + diff --git a/src/spv/converter_aux.hpp b/src/spv/converter_aux.hpp new file mode 100644 index 00000000..afa02748 --- /dev/null +++ b/src/spv/converter_aux.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONVERTER_AUX_20250416_HPP +#define CONVERTER_AUX_20250416_HPP + +#include "node/inst_node.hpp" +#include "spv/enums.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include + +namespace tinytc::spv { + +class spv_inst; +class uniquifier; + +auto convert_group_operation(group_operation op) -> GroupOperation; +auto get_last_label(tinytc_spv_mod &mod) -> spv_inst *; +auto make_binary_op(uniquifier &unique, scalar_type sty, arithmetic op, spv_inst *a, spv_inst *b, + location const &loc) -> spv_inst *; +auto make_binary_op_mixed_precision(uniquifier &unique, scalar_type result_ty, arithmetic op, + scalar_type a_ty, spv_inst *a, scalar_type b_ty, spv_inst *b, + location const &loc) -> spv_inst *; +auto make_cast(uniquifier &unique, scalar_type to_ty, scalar_type a_ty, spv_inst *a, + location const &loc) -> spv_inst *; +auto make_complex_mul(uniquifier &unique, spv_inst *ty, spv_inst *a, spv_inst *b, + bool conj_b = false) -> spv_inst *; +auto make_constant(uniquifier &unique, scalar_type sty, constant_inst::value_type const &val) + -> spv_inst *; +void make_conditional_execution(uniquifier &unique, spv_inst *condition, + std::function then); +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst *condition, + std::function then, + spv_inst *otherwise, location const &loc) -> spv_inst *; +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst *condition, + std::function then, + std::function otherwise, + location const &loc) -> spv_inst *; +void make_store(uniquifier &unique, store_flag flag, scalar_type sty, address_space as, + spv_inst *pointer, spv_inst *value, location const &loc); +auto make_unary_op(uniquifier &unique, scalar_type sty, arithmetic_unary op, spv_inst *a, + location const &loc) -> spv_inst *; +auto make_subgroup_op(uniquifier &unique, scalar_type sty, group_arithmetic arith, + group_operation op, spv_inst *a, location const &loc) -> spv_inst *; + +} // namespace tinytc::spv + +#endif // CONVERTER_AUX_20250416_HPP diff --git a/src/spv/coopmatrix_impl.cpp b/src/spv/coopmatrix_impl.cpp new file mode 100644 index 00000000..52f158de --- /dev/null +++ b/src/spv/coopmatrix_impl.cpp @@ -0,0 +1,668 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/coopmatrix_impl.hpp" +#include "codegen_tools.hpp" +#include "converter_aux.hpp" +#include "error.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "scalar_type.hpp" +#include "spv/dope_vector.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/matrix_walker.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +coopmatrix_impl::coopmatrix_impl(uniquifier &unique, core_config const &cfg, gcd_analysis_result g) + : unique_{&unique}, cfg_{cfg}, gcd_{std::move(g)} {} + +auto coopmatrix_impl::extract(cooperative_matrix_extract_inst const &in, spv_inst *mat) + -> spv_inst * { + auto matt = get_coopmatrix_type(in.mat()); + auto matl = get_layout(cfg(), matt); + const auto idx = in.index(); + if (idx < 0 || idx >= matl.length) { + throw compilation_error(in.loc(), status::ir_out_of_bounds); + } + return extract(matl, mat, idx); +} +auto coopmatrix_impl::insert(cooperative_matrix_insert_inst const &in, spv_inst *val, spv_inst *mat) + -> spv_inst * { + auto matt = get_coopmatrix_type(in.mat()); + auto matl = get_layout(cfg(), matt); + const auto idx = in.index(); + if (idx < 0 || idx >= matl.length) { + throw compilation_error(in.loc(), status::ir_out_of_bounds); + } + return insert(matl, val, mat, idx); +} + +auto coopmatrix_impl::load(cooperative_matrix_load_inst const &in, dope_vector const &odv, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) -> spv_inst * { + auto ot = get_memref_type(in.operand()); + auto rt = get_coopmatrix_type(in.result(0)); + auto pointer_ty = unique_->pointer_ty(ot); + + auto layout = get_layout(cfg(), rt); + auto matrix_ty = spv_ty(layout); + auto interface_ty = spv_interface_ty(layout); + + auto shape = std::array{odv.shape(0), odv.shape(1)}; + auto stride = std::array{odv.stride(0), odv.stride(1)}; + if (in.t() == transpose::T) { + std::swap(pos0, pos1); + std::swap(shape[0], shape[1]); + std::swap(stride[0], stride[1]); + } + + auto walker = matrix_walker(*unique_, cfg().subgroup_size, layout, pos0, pos1, shape[0], + shape[1], stride[0], stride[1], in.checked()); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(matrix_ty); + + const auto ld = [&](tinytc_spv_mod &mod) -> spv_inst * { + auto pointer = mod.add(pointer_ty, operand, walker.offset(), + std::vector{}); + return mod.add(interface_ty, pointer); + }; + const auto ld_chk = [&](tinytc_spv_mod &) { + return make_conditional_execution(*unique_, interface_ty, walker.col_ok(), ld, + unique_->null_constant(interface_ty), in.loc()); + }; + auto const ld_block = [&](tinytc_spv_mod &mod) { + spv_inst *block_result = result; + for (std::int64_t u = 0; u < layout.length / layout.blocks; ++u) { + spv_inst *val = walker.needs_mask() || walker.cols_checked() ? ld_chk(mod) : ld(mod); + block_result = insert(layout, val, block_result, walker.component_no()); + + if (u < layout.cols - 1) { + walker.advance_column(); + } + } + return block_result; + }; + auto const ld_block_chk = [&](tinytc_spv_mod &) { + auto const ld_block_zero = [&](tinytc_spv_mod &) { + spv_inst *block_result = result; + for (std::int64_t u = 0; u < layout.length / layout.blocks; ++u) { + block_result = insert(layout, unique_->null_constant(interface_ty), block_result, + walker.component_no(u)); + } + return block_result; + }; + return make_conditional_execution(*unique_, matrix_ty, walker.row_ok(), ld_block, + ld_block_zero, in.loc()); + }; + + for (std::int64_t w = 0; w < layout.blocks; ++w) { + result = walker.rows_checked() ? ld_block_chk(mod) : ld_block(mod); + + if (w < layout.blocks - 1) { + walker.advance_block(); + } + } + return result; +} + +void coopmatrix_impl::store(cooperative_matrix_store_inst const &in, dope_vector const &odv, + spv_inst *val, spv_inst *operand, spv_inst *pos0, spv_inst *pos1) { + auto ot = get_memref_type(in.operand()); + auto vt = get_coopmatrix_type(in.val()); + auto pointer_ty = unique_->pointer_ty(ot); + + auto layout = get_layout(cfg(), vt); + + auto walker = matrix_walker(*unique_, cfg().subgroup_size, layout, pos0, pos1, odv.shape(0), + odv.shape(1), odv.stride(0), odv.stride(1), in.checked()); + + auto &mod = unique_->mod(); + const auto st = [&](tinytc_spv_mod &mod) { + auto pointer = mod.add(pointer_ty, operand, walker.offset(), + std::vector{}); + auto val_ij = extract(layout, val, walker.component_no()); + + make_store(*unique_, in.flag(), ot->element_ty(), ot->addrspace(), pointer, val_ij, + in.loc()); + }; + auto const st_block = [&](tinytc_spv_mod &mod) { + for (std::int64_t u = 0; u < layout.length / layout.blocks; ++u) { + if (walker.needs_mask() || walker.cols_checked()) { + make_conditional_execution(*unique_, walker.col_ok(), st); + } else { + st(mod); + } + + if (u < layout.cols - 1) { + walker.advance_column(); + } + } + }; + + for (std::int64_t w = 0; w < layout.blocks; ++w) { + if (walker.rows_checked()) { + make_conditional_execution(*unique_, walker.row_ok(), st_block); + + } else { + st_block(mod); + } + + if (w < layout.blocks - 1) { + walker.advance_block(); + } + } +} + +auto coopmatrix_impl::mul_add(cooperative_matrix_mul_add_inst const &in, spv_inst *a, spv_inst *b, + spv_inst *c) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + auto bt = get_coopmatrix_type(in.b()); + auto ct = get_coopmatrix_type(in.c()); + auto rt = get_coopmatrix_type(in.result(0)); + + if (at->rows() % cfg().subgroup_size != 0) { + throw compilation_error(in.loc(), {&in.a()}, status::ir_unsupported_coopmatrix_shape); + } + if (ct->rows() % cfg().subgroup_size != 0) { + throw compilation_error(in.loc(), {&in.c()}, status::ir_unsupported_coopmatrix_shape); + } + if (rt->rows() % cfg().subgroup_size != 0) { + throw compilation_error(in.loc(), {}, status::ir_unsupported_coopmatrix_shape); + } + + auto al = get_layout(cfg(), at); + auto bl = get_layout(cfg(), bt); + auto cl = get_layout(cfg(), ct); + auto rl = get_layout(cfg(), rt); + + const auto a_ty = at->component_ty(); + const auto b_ty = bt->component_ty(); + const auto b_component_ty = component_type(b_ty); + const auto c_ty = ct->component_ty(); + const auto r_ty = rt->component_ty(); + const auto spv_b_ty = unique_->scalar_ty(b_ty); + const auto spv_b_component_ty = unique_->scalar_ty(b_component_ty); + const auto spv_c_ty = unique_->scalar_ty(c_ty); + const bool a_and_b_complex = is_complex_type(a_ty) && is_complex_type(b_ty); + + auto &mod = unique_->mod(); + auto result_ty = spv_ty(rl); + spv_inst *result = mod.add(result_ty); + spv_inst *imaginary_unit = + a_and_b_complex ? make_constant(*unique_, c_ty, std::complex{0.0, 1.0}) : nullptr; + + constexpr std::int64_t nbb = 4; + auto broadcast_scope = unique_->constant(static_cast(Scope::Subgroup)); + for (std::int64_t m_block = 0; m_block < rl.blocks; ++m_block) { + for (std::int64_t nb = 0; nb < rl.cols; nb += nbb) { + auto c_block = std::array{}; + auto c_im_block = std::array{}; + std::fill(std::begin(c_block), std::end(c_block), nullptr); + std::fill(std::begin(c_im_block), std::end(c_im_block), nullptr); + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + c_block[n - nb] = extract(cl, c, cl.component_no(n, m_block)); + if (a_and_b_complex) { + c_im_block[n - nb] = unique_->null_constant(spv_c_ty); + } + } + } + + for (std::int64_t k = 0; k < bl.rows * bl.blocks; ++k) { + auto a_mk = extract(al, a, al.component_no(k, m_block)); + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + /** For matrix B we have L(i,k_1,j,k_2) = i + k_1*I + j*I*K_1 + + * k_2*I*K_1*J. + * + * The n loop variable is equal to j and the k loop variable fuses + * iteration over indices i,k_1,k_2, such that k = i + k_1*I + + * k_2*I*K_1: The k loop variable fuses iteration over indices + * i,k_1,k_2, We recover i+k_1*I = k%(IK_1) k_2 = k/(IK_1) + * + * We have + * p + vS = L + * Therefore, + * p = L%S + * v = L/S + */ + const auto IK_1 = bl.rows * bl.blocks1; + const auto L = k % IK_1 + n * IK_1 + (k / IK_1) * IK_1 * bl.cols; + const auto p = + unique_->constant(static_cast(L % cfg().subgroup_size)); + const auto v = static_cast(L / cfg().subgroup_size); + + spv_inst *b_kn = extract(bl, b, v); + b_kn = mod.add(spv_b_ty, broadcast_scope, b_kn, p); + + auto &c_mn = c_block[n - nb]; + if (a_and_b_complex) { + auto &c_im_mn = c_im_block[n - nb]; + auto b_kn_re = mod.add( + spv_b_component_ty, b_kn, std::vector{0}); + auto b_kn_im = mod.add( + spv_b_component_ty, b_kn, std::vector{1}); + + auto ab_mn = make_binary_op_mixed_precision( + *unique_, c_ty, arithmetic::mul, a_ty, a_mk, b_component_ty, + b_kn_re, in.loc()); + c_mn = make_binary_op(*unique_, c_ty, arithmetic::add, ab_mn, c_mn, + in.loc()); + auto ab_im_mn = make_binary_op_mixed_precision( + *unique_, c_ty, arithmetic::mul, a_ty, a_mk, b_component_ty, + b_kn_im, in.loc()); + c_im_mn = make_binary_op(*unique_, c_ty, arithmetic::add, ab_im_mn, + c_im_mn, in.loc()); + } else { + auto ab_mn = make_binary_op_mixed_precision( + *unique_, c_ty, arithmetic::mul, a_ty, a_mk, b_ty, b_kn, in.loc()); + c_mn = make_binary_op(*unique_, c_ty, arithmetic::add, ab_mn, c_mn, + in.loc()); + } + } + } + } + if (a_and_b_complex) { + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + auto &c_mn = c_block[n - nb]; + auto &c_im_mn = c_im_block[n - nb]; + auto c_im_mn_times_i = make_binary_op(*unique_, c_ty, arithmetic::mul, + c_im_mn, imaginary_unit, in.loc()); + c_mn = make_binary_op(*unique_, c_ty, arithmetic::add, c_mn, + c_im_mn_times_i, in.loc()); + } + } + } + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + auto &c_mn = c_block[n - nb]; + if (c_ty != r_ty) { + c_mn = make_cast(*unique_, r_ty, c_ty, c_mn, in.loc()); + } + result = insert(cl, c_mn, result, n + m_block * rl.cols); + } + } + } + } + return result; +} + +void coopmatrix_impl::prefetch(cooperative_matrix_prefetch_inst const &, dope_vector const &, + spv_inst *, spv_inst *, spv_inst *) {} + +auto coopmatrix_impl::reduce(cooperative_matrix_reduce_inst const &in, spv_inst *a) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + const auto sgs = cfg().subgroup_size; + + if (at->rows() % sgs != 0) { + throw compilation_error(in.loc(), {&in.a()}, status::ir_unsupported_coopmatrix_shape); + } + + auto rt = get_coopmatrix_type(in.result(0)); + auto rl = get_layout(cfg(), rt); + auto al = get_layout(cfg(), at); + auto matrix_ty = spv_ty(rl); + const auto sty = rt->component_ty(); + auto ty = unique_->scalar_ty(sty); + auto bool_ty = unique_->bool_ty(); + auto i32_ty = unique_->scalar_ty(scalar_type::i32); + + auto const binary_arith = [&](group_arithmetic a) { + switch (a) { + case group_arithmetic::add: + return arithmetic::add; + case group_arithmetic::max: + return arithmetic::max; + case group_arithmetic::min: + return arithmetic::min; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }(in.arith()); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(matrix_ty); + if (in.mode() == reduce_mode::column) { + auto p = unique_->load_builtin(BuiltIn::SubgroupLocalInvocationId); + auto scope = unique_->constant(static_cast(Scope::Subgroup)); + auto c0 = unique_->constant(std::int32_t{0}); + auto cnull = unique_->null_constant(ty); + + for (std::int32_t j0 = 0; j0 < al.shape1; j0 += sgs) { + auto x = std::vector(sgs, nullptr); + for (std::int32_t b = 0; b < al.blocks; ++b) { + for (std::int32_t i = 0; i < sgs; ++i) { + if (j0 + i < al.shape1) { + auto a_i = extract(al, a, al.component_no(j0 + i, b)); + x[i] = + x[i] ? make_binary_op(*unique_, sty, binary_arith, x[i], a_i, in.loc()) + : a_i; + } else if (x[i] == nullptr) { + x[i] = cnull; + } + } + } + for (std::int32_t v = 1; v < sgs; v *= 2) { + auto cv = unique_->constant(v); + spv_inst *cond = mod.add(i32_ty, p, cv); + cond = mod.add(bool_ty, cond, c0); + for (int i = 0; i < sgs / v; i += 2) { + auto xip1_up = mod.add(ty, scope, x[i + 1], cv); + auto xi_down = mod.add(ty, scope, x[i], cv); + auto s1 = mod.add(ty, cond, x[i], xip1_up); + auto s2 = mod.add(ty, cond, xi_down, x[i + 1]); + x[i / 2] = make_binary_op(*unique_, sty, binary_arith, s1, s2, in.loc()); + } + } + result = insert(rl, x[0], result, j0 / sgs); + } + } else { + for (std::int32_t b = 0; b < al.blocks; ++b) { + spv_inst *sum = nullptr; + for (std::int32_t j = 0; j < al.length / al.blocks; ++j) { + auto a_i = extract(al, a, al.component_no(j, b)); + sum = sum ? make_binary_op(*unique_, sty, binary_arith, sum, a_i, in.loc()) : a_i; + } + result = insert(rl, sum, result, rl.component_no(0, b)); + } + } + + return result; +} + +auto coopmatrix_impl::scale(cooperative_matrix_scale_inst const &in, spv_inst *a, spv_inst *b) + -> spv_inst * { + auto rt = get_coopmatrix_type(in.result(0)); + auto rl = get_layout(cfg(), rt); + auto bl = get_layout(cfg(), get_coopmatrix_type(in.b())); + auto sty = rt->component_ty(); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto b_v = extract(bl, b, v); + auto r_v = make_binary_op(*unique_, sty, arithmetic::mul, a, b_v, in.loc()); + result = insert(rl, r_v, result, v); + } + + return result; +} + +auto coopmatrix_impl::arith(arith_inst const &in, spv_inst *a, spv_inst *b) -> spv_inst * { + auto rt = get_coopmatrix_type(in.result(0)); + auto rl = get_layout(cfg(), rt); + auto al = get_layout(cfg(), get_coopmatrix_type(in.a())); + auto bl = get_layout(cfg(), get_coopmatrix_type(in.b())); + auto sty = rt->component_ty(); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto a_v = extract(al, a, v); + auto b_v = extract(bl, b, v); + auto r_v = make_binary_op(*unique_, sty, in.operation(), a_v, b_v, in.loc()); + result = insert(rl, r_v, result, v); + } + + return result; +} + +auto coopmatrix_impl::arith_unary(arith_unary_inst const &in, spv_inst *a) -> spv_inst * { + auto al = get_layout(cfg(), get_coopmatrix_type(in.a())); + auto rt = get_coopmatrix_type(in.result(0)); + auto rl = get_layout(cfg(), rt); + auto sty = rt->component_ty(); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto a_v = extract(al, a, v); + auto r_v = make_unary_op(*unique_, sty, in.operation(), a_v, in.loc()); + result = insert(rl, r_v, result, v); + } + + return result; +} + +auto coopmatrix_impl::cast(cast_inst const &in, spv_inst *a) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + auto al = get_layout(cfg(), at); + auto a_ty = at->component_ty(); + auto rt = get_coopmatrix_type(in.result(0)); + auto rl = get_layout(cfg(), rt); + auto r_ty = rt->component_ty(); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + + const auto P = + rt->use() == matrix_use::b && at->use() == matrix_use::acc + ? std::function([&](LiteralInteger v) -> LiteralInteger { + /** + * Using that M >= S we have for matrix_b + * L_b(i,k_1,j,k_2) = i + k_1*S + j*S*K_1 + k_2*S*K_1*J. + * + * We have + * p_b + v_bS = L_b + * + * Recovering i,k_1,j,k_2 from L_b we have + * i = L_b%S = p_b + * k_1 = L_b/S%K_1 = v_b%K_1 + * j = L_b/(SK_1)%J = v_b/K_1%J + * k_2 = L_b/(SK_1J) = v_b/(K_1J) + * + * Let k=k_1 + k_2K_1, and L_1, L_2 be the block sizes of matrix + * acc. We have L_{acc} = i + (k%L_1)*S + j*S*L_1 + + * (k/L_1)*S*L_1*J. + * + * Recovering p_{acc}, v_{acc} from + * p_{acc} + v_{acc}S = L_{acc} + * we have + * p_{acc} = L_{acc}%S = p_b + * v_{acc} = L_{acc}/S = k%L_1 + j*L_1 + (k/L_1)*L_1*J + * + * If M < S, then we have K_1=K_2=L_1=L_2=1, and there is no layout + * transformation. The code below just returns v - the identity - + * if M < S. + */ + auto const k_1 = v % rl.blocks1; + auto const j = v / rl.blocks1 % rl.cols; + auto const k_2 = v / (rl.blocks1 * rl.cols); + auto const k = k_1 + k_2 * rl.blocks1; + return k % al.blocks1 + j * al.blocks1 + (k / al.blocks1) * al.blocks1 * al.cols; + }) + : std::function([](LiteralInteger v) -> LiteralInteger { return v; }); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto a_v = extract(al, a, P(v)); + auto r_v = make_cast(*unique_, r_ty, a_ty, a_v, in.loc()); + result = insert(rl, r_v, result, v); + } + + return result; +} + +auto coopmatrix_impl::constant(constant_inst const &in) -> spv_inst * { + auto rt = get_coopmatrix_type(in.result(0)); + auto rl = get_layout(cfg(), rt); + auto sty = rt->component_ty(); + auto spv_result_ty = spv_ty(rl); + + if (in.is_zero()) { + return unique_->null_constant(spv_result_ty); + } + if (rl.length == 1) { + return make_constant(*unique_, sty, in.value()); + } + + auto const init_vector = [&]() { + if (is_complex_type(sty)) { + const auto c = std::get>(in.value()); + auto cty = component_type(sty); + auto re = make_constant(*unique_, cty, c.real()); + auto im = make_constant(*unique_, cty, c.imag()); + auto vec = std::vector(2 * rl.length); + for (std::int64_t v = 0; v < rl.length; ++v) { + vec[2 * v] = re; + vec[2 * v + 1] = im; + } + return vec; + } else if (rl.ops_per_chan > 1) { + auto cst = std::visit( + overloaded{[&](std::int64_t i) -> spv_inst * { + switch (rl.sty) { + case scalar_type::i8: { + auto v8 = + std::bit_cast(static_cast(i)); + return unique_->constant( + std::int32_t{v8 | (v8 << 8) | (v8 << 16) | (v8 << 24)}); + } + case scalar_type::i16: { + auto v16 = + std::bit_cast(static_cast(i)); + return unique_->constant(std::int32_t{v16 | (v16 << 16)}); + } + default: + return nullptr; + } + }, + [&](double d) -> spv_inst * { + const float f = static_cast(d); + switch (rl.sty) { + case scalar_type::bf16: { + std::uint16_t v16 = bfloat16{f}.bits(); + return unique_->constant(std::int32_t{v16 | (v16 << 16)}); + } + case scalar_type::f16: { + std::uint16_t v16 = half{f}.bits(); + return unique_->constant(std::int32_t{v16 | (v16 << 16)}); + } + default: + return nullptr; + } + }, + [&](auto const &) -> spv_inst * { return nullptr; }}, + in.value()); + if (!cst) { + throw status::internal_compiler_error; + } + return std::vector(rl.length, cst); + } + return std::vector(rl.length, make_constant(*unique_, sty, in.value())); + }; + return unique_->mod().add_to(section::type_const_var, spv_result_ty, + init_vector()); +} + +auto coopmatrix_impl::spv_interface_ty(coopmatrix_layout const &layout) -> spv_inst * { + return unique_->scalar_ty(layout.sty); +} + +auto coopmatrix_impl::spv_storage_ty(coopmatrix_layout const &layout) -> spv_inst * { + if (layout.ops_per_chan > 1) { + if (layout.ops_per_chan * size(layout.sty) != 4) { + throw status::internal_compiler_error; + } + return unique_->scalar_ty(scalar_type::i32); + } + return unique_->scalar_ty(component_type(layout.sty)); +} + +auto coopmatrix_impl::spv_ty(coopmatrix_layout const &layout) -> spv_inst * { + if (layout.length == 1) { + return spv_interface_ty(layout); + } + const auto length = + static_cast(component_count(layout.sty)) * layout.length / layout.ops_per_chan; + auto storage_ty = spv_storage_ty(layout); + return length > 1 ? unique_->vec_ty(storage_ty, length) : storage_ty; +} + +auto coopmatrix_impl::spv_ty(coopmatrix_data_type const *ct) -> spv_inst * { + return spv_ty(get_layout(cfg(), ct)); +} + +auto coopmatrix_impl::extract(coopmatrix_layout const &layout, spv_inst *mat, LiteralInteger v) + -> spv_inst * { + if (layout.length == 1) { + return mat; + } + const auto ty = spv_interface_ty(layout); + auto &mod = unique_->mod(); + if (is_complex_type(layout.sty)) { + const auto storage_ty = spv_storage_ty(layout); + auto re = mod.add(storage_ty, mat, std::vector{2 * v}); + auto im = mod.add(storage_ty, mat, std::vector{2 * v + 1}); + return mod.add(ty, std::vector{re, im}); + } else if (layout.ops_per_chan > 1) { + if (layout.blocks1 != 1) { + throw status::internal_compiler_error; + } + const auto storage_ty = spv_storage_ty(layout); + const auto chan_ty = unique_->vec_ty(ty, layout.ops_per_chan); + spv_inst *val = + layout.length > layout.ops_per_chan + ? mod.add(storage_ty, mat, std::vector{v / layout.ops_per_chan}) + : mat; + val = mod.add(chan_ty, val); + return mod.add(ty, val, std::vector{v % layout.ops_per_chan}); + } + return mod.add(ty, mat, std::vector{v}); +} +auto coopmatrix_impl::insert(coopmatrix_layout const &layout, spv_inst *val, spv_inst *mat, + LiteralInteger v) -> spv_inst * { + if (layout.length == 1) { + return val; + } + auto matrix_ty = spv_ty(layout); + auto &mod = unique_->mod(); + if (is_complex_type(layout.sty)) { + const auto storage_ty = spv_storage_ty(layout); + auto re = mod.add(storage_ty, val, std::vector{0}); + auto im = mod.add(storage_ty, val, std::vector{1}); + auto tmp = mod.add(matrix_ty, re, mat, std::vector{2 * v}); + return mod.add(matrix_ty, im, tmp, std::vector{2 * v + 1}); + } else if (layout.ops_per_chan > 1) { + if (layout.blocks1 != 1) { + throw status::internal_compiler_error; + } + const auto storage_ty = spv_storage_ty(layout); + const auto channels_ty = unique_->vec_ty(spv_interface_ty(layout), layout.ops_per_chan); + const auto entry_no = std::vector{v / layout.ops_per_chan}; + spv_inst *channels = layout.length > layout.ops_per_chan + ? mod.add(storage_ty, mat, entry_no) + : mat; + channels = mod.add(channels_ty, channels); + channels = mod.add(channels_ty, val, channels, + std::vector{v % layout.ops_per_chan}); + channels = mod.add(storage_ty, channels); + return layout.length > layout.ops_per_chan + ? mod.add(matrix_ty, channels, mat, entry_no) + : channels; + } + return mod.add(matrix_ty, val, mat, std::vector{v}); +} + +} // namespace tinytc::spv diff --git a/src/spv/coopmatrix_impl.hpp b/src/spv/coopmatrix_impl.hpp new file mode 100644 index 00000000..3c1b8103 --- /dev/null +++ b/src/spv/coopmatrix_impl.hpp @@ -0,0 +1,86 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_IMPL_20250415_HPP +#define COOPMATRIX_IMPL_20250415_HPP + +#include "analysis/gcd.hpp" +#include "coopmatrix_layout.hpp" // IWYU pragma: keep +#include "device_info.hpp" +#include "spv/defs.hpp" + +#include + +namespace tinytc { +class arith_inst; +class arith_unary_inst; +class cast_inst; +class constant_inst; +class cooperative_matrix_extract_inst; +class cooperative_matrix_insert_inst; +class cooperative_matrix_load_inst; +class cooperative_matrix_mul_add_inst; +class cooperative_matrix_prefetch_inst; +class cooperative_matrix_reduce_inst; +class cooperative_matrix_scale_inst; +class cooperative_matrix_store_inst; +class coopmatrix_data_type; +} // namespace tinytc + +namespace tinytc::spv { + +class dope_vector; +class uniquifier; + +class coopmatrix_impl { + public: + coopmatrix_impl(uniquifier &unique, core_config const &cfg, gcd_analysis_result g); + virtual ~coopmatrix_impl() = default; + + inline auto gcd() const -> gcd_analysis_result const & { return gcd_; } + inline void gcd(gcd_analysis_result g) { gcd_ = std::move(g); } + + auto cfg() const -> core_config const & { return cfg_; } + inline void cfg(core_config const &cfg) { cfg_ = cfg; } + + virtual auto extract(cooperative_matrix_extract_inst const &in, spv_inst *mat) -> spv_inst *; + virtual auto insert(cooperative_matrix_insert_inst const &in, spv_inst *val, spv_inst *mat) + -> spv_inst *; + virtual auto load(cooperative_matrix_load_inst const &in, dope_vector const &odv, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) -> spv_inst *; + virtual auto mul_add(cooperative_matrix_mul_add_inst const &in, spv_inst *a, spv_inst *b, + spv_inst *c) -> spv_inst *; + virtual void prefetch(cooperative_matrix_prefetch_inst const &in, dope_vector const &odv, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1); + virtual auto reduce(cooperative_matrix_reduce_inst const &in, spv_inst *a) -> spv_inst *; + virtual auto scale(cooperative_matrix_scale_inst const &in, spv_inst *a, spv_inst *b) + -> spv_inst *; + virtual void store(cooperative_matrix_store_inst const &in, dope_vector const &odv, + spv_inst *val, spv_inst *operand, spv_inst *pos0, spv_inst *pos1); + + virtual auto arith(arith_inst const &in, spv_inst *a, spv_inst *b) -> spv_inst *; + virtual auto arith_unary(arith_unary_inst const &in, spv_inst *a) -> spv_inst *; + virtual auto cast(cast_inst const &in, spv_inst *a) -> spv_inst *; + virtual auto constant(constant_inst const &in) -> spv_inst *; + + virtual auto spv_ty(coopmatrix_data_type const *ct) -> spv_inst *; + + protected: + auto spv_interface_ty(coopmatrix_layout const &layout) -> spv_inst *; + auto spv_storage_ty(coopmatrix_layout const &layout) -> spv_inst *; + auto spv_ty(coopmatrix_layout const &layout) -> spv_inst *; + auto extract(coopmatrix_layout const &layout, spv_inst *mat, LiteralInteger v) -> spv_inst *; + auto insert(coopmatrix_layout const &layout, spv_inst *val, spv_inst *mat, LiteralInteger v) + -> spv_inst *; + + inline auto unique() -> uniquifier & { return *unique_; } + + private: + uniquifier *unique_; + core_config cfg_; + gcd_analysis_result gcd_; +}; + +} // namespace tinytc::spv + +#endif // COOPMATRIX_IMPL_20250415_HPP diff --git a/src/spv/coopmatrix_impl_block.cpp b/src/spv/coopmatrix_impl_block.cpp new file mode 100644 index 00000000..9a8f1878 --- /dev/null +++ b/src/spv/coopmatrix_impl_block.cpp @@ -0,0 +1,287 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/coopmatrix_impl_block.hpp" +#include "analysis/gcd.hpp" +#include "codegen_tools.hpp" +#include "converter_aux.hpp" +#include "coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "node/value_node.hpp" +#include "spv/defs.hpp" +#include "spv/dope_vector.hpp" +#include "spv/instructions.hpp" +#include "spv/matrix_walker.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/math.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto max_block_io_vec_size(scalar_type sty) -> std::int64_t { + return sty == scalar_type::i8 || sty == scalar_type::i16 ? 16 : 8; +} + +auto coopmatrix_impl_block::load(cooperative_matrix_load_inst const &in, dope_vector const &odv, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) -> spv_inst * { + const auto ot = get_memref_type(in.operand()); + const auto rt = get_coopmatrix_type(in.result(0)); + const auto layout = get_layout(cfg(), rt); + const auto sty = layout.sty; + + const std::int32_t required_alignment = ot->addrspace() == address_space::global ? 4 : 16; + + const bool layout_ok = layout.rows >= cfg().subgroup_size; + const bool transpose_ok = in.t() == transpose::N; + const bool alignment_ok = is_aligned(required_alignment, in.operand(), in.pos0()); + const bool checked_ok = + in.checked() == checked_flag::none || in.checked() == checked_flag::cols; + const bool sty_ok = sty != scalar_type::c64; // We do not have 16 byte/lane block loads + if (!layout_ok || !transpose_ok || !alignment_ok || !checked_ok || !sty_ok) { + return coopmatrix_impl::load(in, odv, operand, pos0, pos1); + } + + auto walker = matrix_walker(unique(), cfg().subgroup_size, layout, pos0, pos1, odv.shape(0), + odv.shape(1), odv.stride(0), odv.stride(1), in.checked(), 0); + + const auto io_sty = get_io_sty(sty); + const std::int32_t blocks_per_load = + is_positive_power_of_two(layout.blocks) + ? std::min(layout.blocks, max_block_io_vec_size(io_sty)) + : 1; + const std::int32_t cols_per_load = [&]() -> std::int32_t { + const auto ot = get_memref_type(in.operand()); + const bool is_contiguous = ot->dim() == 2 && ot->shape(0) == rt->shape(0) && + ot->stride(0) == 1 && ot->stride(1) == ot->shape(0); + std::int32_t cols_per_load = 1; + if (is_contiguous && !walker.cols_checked()) { + const std::int32_t max_cols_per_load = max_block_io_vec_size(io_sty) / blocks_per_load; + const std::int32_t num_cols = layout.length / layout.blocks; + while (2 * cols_per_load <= max_cols_per_load && num_cols % (2 * cols_per_load) == 0) { + cols_per_load *= 2; + } + } + return cols_per_load; + }(); + + const auto matrix_ty = spv_ty(layout); + const auto interface_ty = spv_interface_ty(layout); + auto io_ty = unique().scalar_ty(io_sty); + const auto io_vec_size = blocks_per_load * cols_per_load; + spv_inst *io_vec_ty = io_vec_size > 1 ? unique().vec_ty(io_ty, io_vec_size) : io_ty; + const auto pointer_ty = [&] { + auto ot = get_memref_type(in.operand()); + const auto storage_cls = address_space_to_storage_class(ot->addrspace()); + const auto align = ot->element_alignment(); + return unique().pointer_ty(storage_cls, io_ty, align); + }(); + + auto &mod = unique().mod(); + operand = mod.add(pointer_ty, operand); + spv_inst *result = mod.add(matrix_ty); + + const auto ld = [&](tinytc_spv_mod &mod) -> spv_inst * { + spv_inst *offset = walker.offset(); + auto pointer = mod.add(pointer_ty, operand, offset, + std::vector{}); + return mod.add(io_vec_ty, pointer); + }; + const auto ld_chk = [&](tinytc_spv_mod &) { + return make_conditional_execution(unique(), interface_ty, walker.col_ok(), ld, + unique().null_constant(io_vec_ty), in.loc()); + }; + auto const ld_block = [&](tinytc_spv_mod &mod) { + spv_inst *block_result = result; + for (std::int64_t u = 0; u < layout.length / layout.blocks; u += cols_per_load) { + spv_inst *val = walker.needs_mask() || walker.cols_checked() ? ld_chk(mod) : ld(mod); + if (io_vec_size > 1) { + for (std::int32_t c = 0; c < cols_per_load; ++c) { + for (std::int32_t b = 0; b < blocks_per_load; ++b) { + spv_inst *v = mod.add( + io_ty, val, std::vector{b + c * blocks_per_load}); + v = mod.add(interface_ty, v); + const auto comp_no = + layout.component_no(walker.col_no() + c, walker.block_no() + b); + block_result = insert(layout, v, block_result, comp_no); + } + } + } else { + val = mod.add(interface_ty, val); + block_result = insert(layout, val, block_result, walker.component_no()); + } + + if (u < layout.cols - 1) { + for (std::int32_t c = 0; c < cols_per_load; ++c) { + walker.advance_column(); + } + } + } + return block_result; + }; + + for (std::int64_t w = 0; w < layout.blocks; w += blocks_per_load) { + result = ld_block(mod); + + if (w < layout.blocks - blocks_per_load) { + for (std::int32_t b = 0; b < blocks_per_load; ++b) { + walker.advance_block(); + } + } + } + return result; +} + +void coopmatrix_impl_block::store(cooperative_matrix_store_inst const &in, dope_vector const &odv, + spv_inst *val, spv_inst *operand, spv_inst *pos0, + spv_inst *pos1) { + constexpr std::int32_t required_alignment = 16; + + auto vt = get_coopmatrix_type(in.val()); + auto layout = get_layout(cfg(), vt); + auto sty = vt->component_ty(); + + const bool layout_ok = layout.rows >= cfg().subgroup_size; + const bool flag_ok = in.flag() == store_flag::regular; + const bool alignment_ok = is_aligned(required_alignment, in.operand(), in.pos0()); + const bool checked_ok = + in.checked() == checked_flag::none || in.checked() == checked_flag::cols; + const bool sty_ok = sty != scalar_type::c64; // We do not have 16 byte/lane block writes + if (!layout_ok || !flag_ok || !alignment_ok || !checked_ok || !sty_ok) { + coopmatrix_impl::store(in, odv, val, operand, pos0, pos1); + return; + } + + auto walker = matrix_walker(unique(), cfg().subgroup_size, layout, pos0, pos1, odv.shape(0), + odv.shape(1), odv.stride(0), odv.stride(1), in.checked(), 0); + + const auto io_sty = get_io_sty(sty); + const std::int32_t blocks_per_store = + is_positive_power_of_two(layout.blocks) + ? std::min(layout.blocks, max_block_io_vec_size(io_sty)) + : 1; + const std::int32_t cols_per_store = [&]() -> std::int32_t { + const auto ot = get_memref_type(in.operand()); + const bool is_contiguous = ot->dim() == 2 && ot->shape(0) == vt->shape(0) && + ot->stride(0) == 1 && ot->stride(1) == ot->shape(0); + std::int32_t cols_per_store = 1; + if (is_contiguous && !walker.cols_checked()) { + const std::int32_t max_cols_per_store = + max_block_io_vec_size(io_sty) / blocks_per_store; + const std::int32_t num_cols = layout.length / layout.blocks; + while (2 * cols_per_store <= max_cols_per_store && + num_cols % (2 * cols_per_store) == 0) { + cols_per_store *= 2; + } + } + return cols_per_store; + }(); + + auto io_ty = unique().scalar_ty(io_sty); + auto const io_vec_size = blocks_per_store * cols_per_store; + spv_inst *io_vec_ty = io_vec_size > 1 ? unique().vec_ty(io_ty, io_vec_size) : io_ty; + const auto pointer_ty = [&] { + auto ot = get_memref_type(in.operand()); + const auto storage_cls = address_space_to_storage_class(ot->addrspace()); + const auto align = std::max(16, ot->element_alignment()); + return unique().pointer_ty(storage_cls, io_ty, align); + }(); + + auto &mod = unique().mod(); + operand = mod.add(pointer_ty, operand); + + for (std::int64_t w = 0; w < layout.blocks; w += blocks_per_store) { + auto const st_block = [&](tinytc_spv_mod &mod) { + for (std::int64_t u = 0; u < layout.length / layout.blocks; u += cols_per_store) { + const auto st = [&](tinytc_spv_mod &mod) { + spv_inst *offset = walker.offset(); + auto pointer = mod.add(pointer_ty, operand, offset, + std::vector{}); + spv_inst *val_ij = nullptr; + if (io_vec_size > 1) { + val_ij = mod.add(io_vec_ty); + for (std::int32_t c = 0; c < cols_per_store; ++c) { + for (std::int32_t b = 0; b < blocks_per_store; ++b) { + const auto comp_no = + layout.component_no(walker.col_no() + c, walker.block_no() + b); + spv_inst *v = extract(layout, val, comp_no); + v = mod.add(io_ty, v); + val_ij = mod.add( + io_vec_ty, v, val_ij, + std::vector{b + c * blocks_per_store}); + } + } + } else { + val_ij = extract(layout, val, walker.component_no()); + val_ij = mod.add(io_ty, val_ij); + } + mod.add(pointer, val_ij); + }; + if (walker.needs_mask() || walker.cols_checked()) { + make_conditional_execution(unique(), walker.col_ok(), st); + } else { + st(mod); + } + + if (u < layout.cols - 1) { + for (std::int32_t c = 0; c < cols_per_store; ++c) { + walker.advance_column(); + } + } + } + }; + st_block(mod); + + if (w < layout.blocks - blocks_per_store) { + for (std::int32_t b = 0; b < blocks_per_store; ++b) { + walker.advance_block(); + } + } + } +} + +auto coopmatrix_impl_block::get_io_sty(scalar_type sty) -> scalar_type { + switch (sty) { + case scalar_type::bf16: + case scalar_type::f16: + return scalar_type::i16; + case scalar_type::f32: + return scalar_type::i32; + case scalar_type::f64: + case scalar_type::c32: + return scalar_type::i64; + default: + break; + } + return sty; +} + +auto coopmatrix_impl_block::is_aligned(std::int32_t alignment, value_node const &operand, + value_node const &pos0) -> bool { + auto const mt = get_memref_type(operand); + const auto sty_size = size(mt->element_ty()); + if (sty_size >= static_cast(alignment)) { + return true; + } + if (auto mi = gcd().get_memref_if(operand); mi) { + const bool base_ok = (mi->offset_gcd() * sty_size) % alignment == 0; + const bool pos0_ok = (gcd().get(pos0) * sty_size) % alignment == 0; + const bool stride_ok = + mt->stride(0) == 1 && (mi->stride_gcd()[1] * sty_size) % alignment == 0; + + return base_ok && pos0_ok && stride_ok; + } + return false; +} + +} // namespace tinytc::spv diff --git a/src/spv/coopmatrix_impl_block.hpp b/src/spv/coopmatrix_impl_block.hpp new file mode 100644 index 00000000..3a11fcd4 --- /dev/null +++ b/src/spv/coopmatrix_impl_block.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_IMPL_BLOCK_20250428_HPP +#define COOPMATRIX_IMPL_BLOCK_20250428_HPP + +#include "spv/coopmatrix_impl.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include + +namespace tinytc::spv { + +class coopmatrix_impl_block : public coopmatrix_impl { + public: + using coopmatrix_impl::coopmatrix_impl; + + auto load(cooperative_matrix_load_inst const &in, dope_vector const &odv, spv_inst *operand, + spv_inst *pos0, spv_inst *pos1) -> spv_inst * override; + void store(cooperative_matrix_store_inst const &in, dope_vector const &odv, spv_inst *val, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) override; + + private: + auto get_io_sty(scalar_type sty) -> scalar_type; + auto is_aligned(std::int32_t alignment, tinytc_value const &operand, tinytc_value const &pos0) + -> bool; +}; + +} // namespace tinytc::spv + +#endif // COOPMATRIX_IMPL_BLOCK_20250428_HPP diff --git a/src/spv/coopmatrix_impl_dpas.cpp b/src/spv/coopmatrix_impl_dpas.cpp new file mode 100644 index 00000000..8d69722b --- /dev/null +++ b/src/spv/coopmatrix_impl_dpas.cpp @@ -0,0 +1,605 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/coopmatrix_impl_dpas.hpp" +#include "analysis/gcd.hpp" +#include "codegen_tools.hpp" +#include "coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "matrix_ext_info.hpp" +#include "node/data_type_node.hpp" +#include "node/inst_node.hpp" +#include "spv/block2d_diy.hpp" +#include "spv/coopmatrix_impl.hpp" +#include "spv/defs.hpp" +#include "spv/dope_vector.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/lut.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "spv/xe_constants.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/math.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto precision(scalar_type sty) -> char const * { + switch (sty) { + case scalar_type::f16: + return "hf"; + case scalar_type::bf16: + return "bf"; + case scalar_type::i8: + return "s8"; + default: + throw status::internal_compiler_error; + } +} + +auto coopmatrix_impl_dpas::max_rows_in_block(matrix_use use, std::int32_t element_size) const + -> std::int32_t { + if (use == matrix_use::b) { + std::int32_t ops_per_chan = xe::channel_size / element_size; + return ops_per_chan * xe::sdepth; + } + return xe::exec_size; +} + +auto coopmatrix_impl_dpas::check_2d_block_io(tinytc_value const &operand, tinytc_value const &pos0) + -> bool { + if (auto mi = gcd().get_memref_if(operand); mi) { + auto const mt = get_memref_type(operand); + const auto sty_size = size(mt->element_ty()); + auto const &block_io = cfg().matrix->block_io(); + + const bool sfid_ok = mt->addrspace() == address_space::global; + const bool base_address_alignment_ok = + (mi->offset_gcd() * sty_size) % block_io.base_address_alignment == 0; + const bool pos0_alignment_ok = (gcd().get(pos0) * sty_size) % block_io.pos0_alignment == 0; + const bool stride_ok = + mt->stride(0) == 1 && (mi->stride_gcd()[1] * sty_size) % block_io.stride_alignment == 0; + const bool width_ok = (mi->shape_gcd()[0] * sty_size) % block_io.width_alignment == 0; + + return sfid_ok && base_address_alignment_ok && pos0_alignment_ok && stride_ok && width_ok; + } + return false; +} + +auto coopmatrix_impl_dpas::load_config(scalar_type sty, std::int32_t rows, std::int32_t cols, + matrix_use use, transpose trans, int32_t cache_level) + -> block_config { + auto cfg = block_config{}; + cfg.sty = sty; + cfg.element_size = size(sty); + cfg.array_length = 1; + cfg.rows = rows; + cfg.cols = cols; + cfg.row_blocks = 1; + cfg.col_blocks = 1; + cfg.transpose = trans == transpose::T; + cfg.vnni = use == matrix_use::a; + cfg.pos0_shr = 0; + cfg.cache_level = cache_level; + + auto const adjust_rows = [&cfg](std::int32_t max_rows, std::int32_t max_array_length) { + if (cfg.rows > max_rows) { + const std::int32_t num_blocks = cfg.rows / max_rows; + if (num_blocks > max_array_length) { + cfg.array_length = max_array_length; + cfg.row_blocks = num_blocks / max_array_length; + } else { + cfg.array_length = num_blocks; + } + cfg.rows = max_rows; + } + }; + auto const adjust_cols = [&cfg](std::int32_t max_cols_in_block) { + if (cfg.cols > max_cols_in_block) { + cfg.col_blocks = cfg.cols / max_cols_in_block; + cfg.cols = max_cols_in_block; + } + }; + auto const max_array_length = [&cfg](std::int32_t max_rows) -> std::int32_t { + return 64 / (max_rows * cfg.element_size); + }; + + // transpose + vnni message is the same as transpose message on d32 + if (cfg.transpose && cfg.vnni) { + std::swap(cfg.rows, cfg.cols); + + const auto ops_per_chan = 4 / cfg.element_size; + cfg.rows /= ops_per_chan; + cfg.sty = scalar_type::i32; + cfg.element_size = 4; + cfg.pos0_shr = ilog2(ops_per_chan); + cfg.vnni = false; + + adjust_cols(xe::exec_size); + + const auto max_rows = xe::exec_size / 2; + adjust_rows(max_rows, 1); + } else if (cfg.transpose) { + std::swap(cfg.rows, cfg.cols); + // Enable VNNI as transpose loads for B matrix are missing, so we use VNNI + mov-based 8x8 + // transpose + cfg.vnni = true; + + const std::int32_t max_cols = max_rows_in_block(use, cfg.element_size); + const std::int32_t max_rows = 8; + adjust_cols(max_cols); + adjust_rows(max_rows, max_array_length(max_rows)); + } else { + const std::int32_t max_cols = 32; + const std::int32_t max_rows = max_rows_in_block(use, cfg.element_size); + + adjust_cols(max_cols); + adjust_rows(max_rows, max_array_length(max_rows)); + } + + return cfg; +} + +auto coopmatrix_impl_dpas::load_fun(coopmatrix_data_type const *result_ty, spv_inst *spv_operand_ty, + transpose trans) -> spv_inst * { + const auto key = load_key{result_ty, spv_operand_ty, trans}; + return lookup(load_funs_, key, [&](load_key const &key) { + const auto [result_ty, spv_operand_ty, trans] = key; + + const auto cfg = load_config(result_ty->component_ty(), result_ty->rows(), + result_ty->cols(), result_ty->use(), trans); + auto code = load_block2d_native(cfg, tmp_); + + auto spv_i32_ty = unique().scalar_ty(scalar_type::i32); + auto spv_result_ty = spv_ty(result_ty); + auto fun_ty = unique().function_ty( + spv_result_ty, array_view{spv_operand_ty, spv_i32_ty, spv_i32_ty, + spv_i32_ty, spv_i32_ty, spv_i32_ty}); + return unique().mod().add_to(section::type_const_var, spv_result_ty, fun_ty, + unique().asm_target(), code, + "=rw,rw.u,rw.u,rw.u,rw.u,rw.u,rw.u"); + }); +} + +auto coopmatrix_impl_dpas::prefetch_fun(std::int32_t cache_level, scalar_type sty, + spv_inst *spv_operand_ty, std::int32_t rows, + std::int32_t cols) -> spv_inst * { + const auto key = prefetch_key{cache_level, sty, spv_operand_ty, rows, cols}; + return lookup(prefetch_funs_, key, [&](prefetch_key const &key) -> spv_inst * { + const auto [cache_level, sty, spv_operand_ty, rows, cols] = key; + + const auto cfg = load_config(sty, rows, cols, matrix_use::acc, transpose::N, cache_level); + auto code = prefetch_block2d_native(cfg, tmp_); + + auto spv_i32_ty = unique().scalar_ty(scalar_type::i32); + auto spv_void_ty = unique().void_ty(); + auto fun_ty = unique().function_ty( + spv_void_ty, array_view{spv_operand_ty, spv_i32_ty, spv_i32_ty, spv_i32_ty, + spv_i32_ty, spv_i32_ty}); + return unique().mod().add_to(section::type_const_var, spv_void_ty, fun_ty, + unique().asm_target(), code, + "rw.u,rw.u,rw.u,rw.u,rw.u,rw.u"); + }); +} + +auto coopmatrix_impl_dpas::store_config(coopmatrix_data_type const *ct) -> block_config { + constexpr std::int32_t max_cols_in_block = 8; + + auto cfg = block_config{}; + cfg.sty = ct->component_ty(); + cfg.element_size = size(ct->component_ty()); + cfg.array_length = 1; + cfg.rows = ct->rows(); + cfg.cols = ct->cols(); + cfg.row_blocks = 1; + cfg.col_blocks = 1; + cfg.transpose = false; + cfg.vnni = false; + cfg.pos0_shr = 0; + cfg.cache_level = -1; + + if (cfg.cols > max_cols_in_block) { + cfg.col_blocks = cfg.cols / max_cols_in_block; + cfg.cols = max_cols_in_block; + } + + const auto max_rows = max_rows_in_block(ct->use(), cfg.element_size); + if (cfg.rows > max_rows) { + cfg.row_blocks = cfg.rows / max_rows; + cfg.rows = max_rows; + } + + return cfg; +} + +auto coopmatrix_impl_dpas::store_fun(coopmatrix_data_type const *val_ty, spv_inst *spv_operand_ty) + -> spv_inst * { + const auto key = store_key{val_ty, spv_operand_ty}; + return lookup(store_funs_, key, [&](store_key const &key) { + const auto [val_ty, spv_operand_ty] = key; + + const auto cfg = store_config(val_ty); + auto code = store_block2d_native(cfg, tmp_); + + auto spv_void_ty = unique().void_ty(); + auto spv_val_ty = spv_ty(val_ty); + auto spv_i32_ty = unique().scalar_ty(scalar_type::i32); + auto fun_ty = unique().function_ty( + spv_void_ty, array_view{spv_val_ty, spv_operand_ty, spv_i32_ty, spv_i32_ty, + spv_i32_ty, spv_i32_ty, spv_i32_ty}); + auto &mod = unique().mod(); + auto asmop = + mod.add_to(section::type_const_var, spv_void_ty, fun_ty, + unique().asm_target(), code, "rw,rw.u,rw.u,rw.u,rw.u,rw.u,rw.u"); + mod.add_to(section::decoration, asmop, Decoration::SideEffectsINTEL); + return asmop; + }); +} + +auto coopmatrix_impl_dpas::mul_add_fun(coopmatrix_data_type const *at, + coopmatrix_data_type const *bt, + coopmatrix_data_type const *ct, + coopmatrix_data_type const *rt, bool is_c_zero) + -> spv_inst * { + const auto key = mul_add_key{{at, bt, ct, rt}, is_c_zero}; + return lookup(mul_add_funs_, key, [&](mul_add_key const &key) { + const auto [at, bt, ct, rt] = key.op_ty; + + auto oasm = std::ostringstream{}; + + const std::int32_t ops_per_chan = xe::channel_size / size(at->component_ty()); + const std::int32_t K = ops_per_chan * xe::sdepth; + + oasm << "{\n"; + std::string result_placeholder = "$0"; + std::string temp = result_placeholder; + if (rt->component_ty() != ct->component_ty() && at->cols() / K > 1) { + temp = tmp_("temp"); + oasm << ".decl " << temp << " v_type=G type=" << visa_type(ct->component_ty()) + << " num_elts=" << ct->rows() * ct->cols() << " align=wordx32\n"; + } + const auto mat_A = tmp_("matrix_A"); + const auto mat_B = tmp_("matrix_B"); + oasm << ".decl " << mat_A + << " v_type=G type=d num_elts=" << at->rows() * at->cols() / ops_per_chan + << " align=wordx32 alias=<$1,0>\n"; + oasm << ".decl " << mat_B + << " v_type=G type=d num_elts=" << bt->rows() * bt->cols() / ops_per_chan + << " align=wordx32 alias=<$2,0>\n"; + + /** The GRF layout must follow the layout described in the following. + * + * Let CM, CN, CK be the size of the coopmatrices, where + * CM = ct->rows() = at->rows(), + * CN = ct->cols() = bt->cols(), + * CK = at->cols() = bt->rows(), + * and let M, N, K be the size expected by DPAS, where + * M = xe::exec_size, + * N = xe::rcount + * K = ops_per_chan * xe::sdepth + * Let BM:=CM/M, BN:=CN/N, BK:=CK/K be the number of blocks in the respective mode. + * + * The blocks are laid out in the GRF as following + * + * A[m,k,bk,bm] = m + k * M + bk * M * K + bm * M * K * BK + * B[k,n,bn,bk] = k + n * K + bn * K * N + bk * K * N * BN + * C[m,n,bn,bm] = m + n * M + bn * M * N + bm * M * N * BN + * + * where m \in [M], n \in [N], k \in [K], bm \in [BM], bn \in [BN], bk \in [BK]. + * + * The mapping of m,n,k,bm,bn,bk to memory address is given by + * + * MA[m,k,bk,bm] = m' + bm' * M + (k' + bk' * K) * A_stride1 + * MB[k,n,bn,bk] = k'' + bk'' * K + (n'' + bn'' * N) * B_stride1 + * MC[m,n,bn,bm] = m + bm * M + (n + bn * N) * C_stride1 + * + * where + * + * (m',k') = { (m%ops_per_chan + k*ops_per_chan, floor(m/ops_per_chan)) if A + * transposed { (floor(m/ops_per_chan) + k*(M/ops_per_chan), m%ops_per_chan) else (bm',bk') + * = { (bk,bm) if A transposed { (bm,bk) else + * + * and + * + * (k'',n'') = { (n,k) if B transposed + * { (k,n) else + * (bk'',bn'') = { (bn,bk) if B transposed + * { (bk,bn) else + * + */ + const auto precision_src1 = precision(at->component_ty()); + const auto precision_src2 = precision(bt->component_ty()); + for (std::int32_t k = 0; k < at->cols(); k += K) { + char const *src0 = k > 0 ? temp.c_str() : (!key.is_c_zero ? "$3" : "%null"); + char const *dst = k + K >= at->cols() ? result_placeholder.c_str() : temp.c_str(); + const auto rsize = + k + K >= at->cols() ? size(rt->component_ty()) : size(ct->component_ty()); + for (std::int32_t m = 0; m < ct->rows(); m += xe::exec_size) { + for (std::int32_t n = 0; n < ct->cols(); n += xe::rcount) { + const auto aoffset = + (k * xe::exec_size + m * at->cols()) * size(at->component_ty()); + const auto brow = + (k * bt->cols() + n * K) * size(bt->component_ty()) / xe::grf_size; + const auto coffset = + !key.is_c_zero || k > 0 + ? (m * ct->cols() + n * xe::exec_size) * size(ct->component_ty()) + : 0; + const auto roffset = (m * rt->cols() + n * xe::exec_size) * rsize; + oasm << "dpas." << precision_src1 << "." << precision_src2 << "." << xe::sdepth + << "." << xe::rcount << " (M1," << xe::exec_size << ") " << dst << "." + << roffset << " " << src0 << "." << coffset << " " << mat_A << "." + << aoffset << " " << mat_B << "(" << brow << ",0)\n"; + } + } + } + oasm << "}\n"; + + auto spv_a_ty = spv_ty(at); + auto spv_b_ty = spv_ty(bt); + auto spv_c_ty = spv_ty(ct); + auto spv_result_ty = spv_ty(rt); + auto fun_ty = unique().function_ty(spv_result_ty, + array_view{spv_a_ty, spv_b_ty, spv_c_ty}); + + return unique().mod().add_to(section::type_const_var, spv_result_ty, fun_ty, + unique().asm_target(), std::move(oasm).str(), + "=rw,rw,rw,rw"); + }); +} + +auto coopmatrix_impl_dpas::reduce_fun(std::int32_t sgs, group_arithmetic arith, + coopmatrix_data_type const *at, + coopmatrix_data_type const *rt) -> spv_inst * { + const auto key = std::make_tuple(sgs, arith, at, rt); + return lookup(reduce_funs_, key, [&](reduce_key const &key) { + auto [sgs, arith, at, rt] = key; + auto rl = get_layout(cfg(), rt); + auto al = get_layout(cfg(), at); + auto matrix_ty = spv_ty(rl); + const auto sty = rt->component_ty(); + const auto sty_size = size(sty); + + auto oasm = std::ostringstream{}; + + oasm << "{\n"; + auto aview = tmp_("aview"); + oasm << ".decl " << aview << " v_type=G type=" << visa_type(at->component_ty()) + << " num_elts=" << al.length * sgs << " align=wordx32 alias=<$1,0>\n"; + auto rview = tmp_("rview"); + oasm << ".decl " << rview << " v_type=G type=" << visa_type(rt->component_ty()) + << " num_elts=" << rl.length * sgs << " align=wordx32 alias=<$0,0>\n"; + auto predicate = tmp_("predicate"); + oasm << ".decl " << predicate << " v_type=P num_elts=" << sgs << "\n"; + + auto const reduce = [&]() -> char const * { + switch (arith) { + case group_arithmetic::add: + return "add"; + case group_arithmetic::max: + return "max"; + case group_arithmetic::min: + return "min"; + default: + break; + } + throw status::internal_compiler_error; + }; + + for (std::int32_t offset = 0; offset < al.shape1; offset += sgs) { + const std::int32_t remainder = + std::min(sgs, static_cast(al.shape1 - offset)); + std::string src = aview; + if (al.blocks > 1) { + auto tmp = tmp_("tmp"); + oasm << ".decl " << tmp << " v_type=G type=" << visa_type(at->component_ty()) + << " num_elts=" << sgs * sgs << " align=wordx32\n"; + for (std::int32_t j0 = offset; j0 < offset + remainder; ++j0) { + const auto t1 = region_origin(sty_size, sgs * (j0 - offset) * sty_size); + const auto a1 = + region_origin(sty_size, sgs * al.component_no(j0, 0) * sty_size); + const auto a2 = + region_origin(sty_size, sgs * al.component_no(j0, 1) * sty_size); + oasm << reduce() << " (M1," << sgs << ") " << tmp << "(" << t1[0] << "," + << t1[1] << ")<1> " << aview << "(" << a1[0] << "," << a1[1] << ")<1;1,0> " + << aview << "(" << a2[0] << "," << a2[1] << ")<1;1,0>\n"; + for (std::int32_t b = 2; b < al.blocks; ++b) { + const auto a2 = + region_origin(sty_size, sgs * al.component_no(j0, b) * sty_size); + oasm << reduce() << " (M1," << sgs << ") " << tmp << "(" << t1[0] << "," + << t1[1] << ")<1> " << tmp << "(" << t1[0] << "," << t1[1] + << ")<1;1,0> " << aview << "(" << a2[0] << "," << a2[1] + << ")<1;1,0>\n"; + } + } + src = tmp; + } + + for (std::int32_t v = 1; v < sgs; v *= 2) { + std::uint32_t pval = 0; + for (std::uint32_t j = 0; j < 32; j += 2 * v) { + pval |= (((1 << v) - 1) << j); + } + oasm << "setp (M1," << sgs << ") " << predicate << " " << pval << ":ud\n"; + + std::string dst = rview; + auto dst_offset = offset; + if (2 * v < sgs) { + auto tmp = tmp_("tmp"); + oasm << ".decl " << tmp << " v_type=G type=" << visa_type(at->component_ty()) + << " num_elts=" << sgs * sgs / (2 * v) << " align=wordx32\n"; + dst = tmp; + dst_offset = 0; + } + + for (int i = 0; i < sgs / v && i < remainder; i += 2) { + auto tmp1 = tmp_("tmp"); + oasm << ".decl " << tmp1 << " v_type=G type=" << visa_type(at->component_ty()) + << " num_elts=" << sgs << " align=wordx32\n"; + auto tmp2 = tmp_("tmp"); + oasm << ".decl " << tmp2 << " v_type=G type=" << visa_type(at->component_ty()) + << " num_elts=" << sgs << " align=wordx32\n"; + + const auto t0 = region_origin(sty_size, (dst_offset + sgs * i / 2) * sty_size); + const auto t1 = region_origin(sty_size, sgs * i * sty_size); + const auto t1down = region_origin(sty_size, (sgs * i + v) * sty_size); + const auto t2 = region_origin(sty_size, sgs * (i + 1) * sty_size); + const auto t2up = region_origin(sty_size, (sgs * (i + 1) - v) * sty_size); + oasm << "(!" << predicate << ") sel (M1," << sgs << ") " << tmp1 << "(0,0)<1> " + << src << "(" << t2up[0] << "," << t2up[1] << ")<1;1,0> " << src << "(" + << t1[0] << "," << t1[1] << ")<1;1,0>\n"; + oasm << "(" << predicate << ") sel (M1," << sgs << ") " << tmp2 << "(0,0)<1> " + << src << "(" << t1down[0] << "," << t1down[1] << ")<1;1,0> " << src << "(" + << t2[0] << "," << t2[1] << ")<1;1,0>\n"; + oasm << reduce() << " (M1," << sgs << ") " << dst << "(" << t0[0] << "," + << t0[1] << ")<1> " << tmp1 << "(0,0)<1;1,0> " << tmp2 << "(0,0)<1;1,0>\n"; + } + src = dst; + } + } + oasm << "}\n"; + + auto fun_ty = unique().function_ty(matrix_ty, array_view{spv_ty(al)}); + return unique().mod().add_to(section::type_const_var, matrix_ty, fun_ty, + unique().asm_target(), std::move(oasm).str(), + "=rw,rw"); + }); +} + +auto coopmatrix_impl_dpas::load(cooperative_matrix_load_inst const &in, dope_vector const &odv, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) -> spv_inst * { + auto rt = get_coopmatrix_type(in.result(0)); + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const auto type_ok = cfg().matrix->have_type(rt); + const auto block_io_ok = check_2d_block_io(in.operand(), in.pos0()); + const bool transpose_ok = in.t() == transpose::N || rt->use() == matrix_use::a; + + if (!sgs_ok || !type_ok || !block_io_ok || !transpose_ok) { + return coopmatrix_impl_block::load(in, odv, pointer, pos0, pos1); + } + + auto ot = get_memref_type(in.operand()); + auto ct = get_coopmatrix_type(in.result(0)); + auto fun = load_fun(ct, unique().pointer_ty(ot), in.t()); + + auto &mod = unique().mod(); + auto spv_i32_ty = unique().scalar_ty(scalar_type::i32); + auto csize = unique().constant(static_cast(size(ot->element_ty()))); + auto shape0_i32 = mod.add(spv_i32_ty, odv.shape(0)); + auto width_in_bytes = mod.add(spv_i32_ty, shape0_i32, csize); + auto height = mod.add(spv_i32_ty, odv.shape(1)); + auto stride1_i32 = mod.add(spv_i32_ty, odv.stride(1)); + auto stride_in_bytes = mod.add(spv_i32_ty, stride1_i32, csize); + auto pos0_i32 = mod.add(spv_i32_ty, pos0); + auto pos1_i32 = mod.add(spv_i32_ty, pos1); + + return mod.add(spv_ty(ct), fun, + array_view{pointer, width_in_bytes, height, + stride_in_bytes, pos0_i32, pos1_i32}); +} + +auto coopmatrix_impl_dpas::mul_add(cooperative_matrix_mul_add_inst const &in, spv_inst *a, + spv_inst *b, spv_inst *c) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + auto bt = get_coopmatrix_type(in.b()); + auto ct = get_coopmatrix_type(in.c()); + auto rt = get_coopmatrix_type(in.result(0)); + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const bool have_gemm = + cfg().matrix->have_gemm(at->component_ty(), bt->component_ty(), ct->component_ty(), + rt->component_ty(), rt->rows(), rt->cols(), at->cols()); + if (!sgs_ok || !have_gemm) { + return coopmatrix_impl_block::mul_add(in, a, b, c); + } + + auto fun = mul_add_fun(at, bt, ct, rt, in.is_c_zero()); + return unique().mod().add(spv_ty(rt), fun, array_view{a, b, c}); +} + +void coopmatrix_impl_dpas::prefetch(cooperative_matrix_prefetch_inst const &in, + dope_vector const &odv, spv_inst *pointer, spv_inst *pos0, + spv_inst *pos1) { + auto ot = get_memref_type(in.operand()); + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const auto type_ok = size(ot->element_ty()) <= 4; + const auto block_io_ok = check_2d_block_io(in.operand(), in.pos0()); + + if (!sgs_ok || !type_ok || !block_io_ok) { + coopmatrix_impl_block::prefetch(in, odv, pointer, pos0, pos1); + } else { + auto fun = prefetch_fun(in.cache_level(), ot->element_ty(), unique().pointer_ty(ot), + in.rows(), in.cols()); + + if (fun) { + auto &mod = unique().mod(); + auto spv_void_ty = unique().void_ty(); + auto spv_i32_ty = unique().scalar_ty(scalar_type::i32); + auto csize = unique().constant(static_cast(size(ot->element_ty()))); + auto shape0_i32 = mod.add(spv_i32_ty, odv.shape(0)); + auto width_in_bytes = mod.add(spv_i32_ty, shape0_i32, csize); + auto height = mod.add(spv_i32_ty, odv.shape(1)); + auto stride1_i32 = mod.add(spv_i32_ty, odv.stride(1)); + auto stride_in_bytes = mod.add(spv_i32_ty, stride1_i32, csize); + auto pos0_i32 = mod.add(spv_i32_ty, pos0); + auto pos1_i32 = mod.add(spv_i32_ty, pos1); + + mod.add(spv_void_ty, fun, + array_view{pointer, width_in_bytes, height, + stride_in_bytes, pos0_i32, pos1_i32}); + } + } +} + +void coopmatrix_impl_dpas::store(cooperative_matrix_store_inst const &in, dope_vector const &odv, + spv_inst *val, spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) { + auto ct = get_coopmatrix_type(in.val()); + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const auto type_ok = cfg().matrix->have_type(ct); + const auto block_io_ok = check_2d_block_io(in.operand(), in.pos0()); + + if (!sgs_ok || !type_ok || !block_io_ok) { + coopmatrix_impl_block::store(in, odv, val, pointer, pos0, pos1); + } else { + auto ot = get_memref_type(in.operand()); + auto fun = store_fun(ct, unique().pointer_ty(ot)); + + auto &mod = unique().mod(); + auto spv_void_ty = unique().void_ty(); + auto spv_i32_ty = unique().scalar_ty(scalar_type::i32); + auto csize = unique().constant(static_cast(size(ot->element_ty()))); + auto shape0_i32 = mod.add(spv_i32_ty, odv.shape(0)); + auto width_in_bytes = mod.add(spv_i32_ty, shape0_i32, csize); + auto height = mod.add(spv_i32_ty, odv.shape(1)); + auto stride1_i32 = mod.add(spv_i32_ty, odv.stride(1)); + auto stride_in_bytes = mod.add(spv_i32_ty, stride1_i32, csize); + auto pos0_i32 = mod.add(spv_i32_ty, pos0); + auto pos1_i32 = mod.add(spv_i32_ty, pos1); + + mod.add(spv_void_ty, fun, + array_view{val, pointer, width_in_bytes, height, + stride_in_bytes, pos0_i32, pos1_i32}); + } +} + +auto coopmatrix_impl_dpas::reduce(cooperative_matrix_reduce_inst const &in, spv_inst *a) + -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + const auto sgs = cfg().subgroup_size; + + if (in.mode() != reduce_mode::column || at->rows() % sgs != 0 || at->use() == matrix_use::a) { + return coopmatrix_impl::reduce(in, a); + } + + auto rt = get_coopmatrix_type(in.result(0)); + auto fun = reduce_fun(sgs, in.arith(), at, rt); + return unique().mod().add(spv_ty(rt), fun, array_view{a}); +} + +} // namespace tinytc::spv + diff --git a/src/spv/coopmatrix_impl_dpas.hpp b/src/spv/coopmatrix_impl_dpas.hpp new file mode 100644 index 00000000..8628fbda --- /dev/null +++ b/src/spv/coopmatrix_impl_dpas.hpp @@ -0,0 +1,96 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_IMPL_DPAS_20250428_HPP +#define COOPMATRIX_IMPL_DPAS_20250428_HPP + +#include "spv/block2d_diy.hpp" +#include "spv/coopmatrix_impl_block.hpp" +#include "support/temp_counter.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { +class coopmatrix_data_type; +} // namespace tinytc + +namespace tinytc::spv { + +class spv_inst; + +class coopmatrix_impl_dpas : public coopmatrix_impl_block { + public: + using coopmatrix_impl_block::coopmatrix_impl_block; + + auto load(cooperative_matrix_load_inst const &in, dope_vector const &odv, spv_inst *pointer, + spv_inst *pos0, spv_inst *pos1) -> spv_inst * override; + auto mul_add(cooperative_matrix_mul_add_inst const &in, spv_inst *a, spv_inst *b, spv_inst *c) + -> spv_inst * override; + void prefetch(cooperative_matrix_prefetch_inst const &in, dope_vector const &odv, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) override; + void store(cooperative_matrix_store_inst const &in, dope_vector const &odv, spv_inst *val, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) override; + auto reduce(cooperative_matrix_reduce_inst const &in, spv_inst *a) -> spv_inst * override; + + private: + struct mul_add_key { + std::array op_ty; + bool is_c_zero; + + auto operator==(mul_add_key const &other) const { + return op_ty == other.op_ty && is_c_zero == other.is_c_zero; + } + }; + struct mul_add_hash { + inline auto operator()(mul_add_key const &key) const -> std::size_t { + return fnv1a_combine(key.op_ty[0], key.op_ty[1], key.op_ty[2], key.op_ty[3], + key.is_c_zero); + } + }; + template struct tuple_hash { + inline auto operator()(Tuple const &key) const -> std::size_t { + return std::apply([](auto const &...args) { return fnv1a_combine(args...); }, key); + } + }; + + auto max_rows_in_block(matrix_use use, std::int32_t element_size) const -> std::int32_t; + auto check_2d_block_io(tinytc_value const &operand, tinytc_value const &pos0) -> bool; + auto load_config(scalar_type sty, std::int32_t rows, std::int32_t cols, matrix_use use, + transpose trans, std::int32_t cache_level = -1) -> block_config; + auto load_fun(coopmatrix_data_type const *result_ty, spv_inst *spv_operand_ty, transpose trans) + -> spv_inst *; + auto prefetch_fun(std::int32_t cache_level, scalar_type sty, spv_inst *spv_operand_ty, + std::int32_t rows, std::int32_t cols) -> spv_inst *; + auto store_config(coopmatrix_data_type const *ct) -> block_config; + auto store_fun(coopmatrix_data_type const *val_ty, spv_inst *spv_operand_ty) -> spv_inst *; + auto mul_add_fun(coopmatrix_data_type const *at, coopmatrix_data_type const *bt, + coopmatrix_data_type const *ct, coopmatrix_data_type const *rt, bool is_c_zero) + -> spv_inst *; + auto reduce_fun(std::int32_t sgs, group_arithmetic arith, coopmatrix_data_type const *at, + coopmatrix_data_type const *rt) -> spv_inst *; + + using load_key = std::tuple; + using prefetch_key = + std::tuple; + using store_key = std::tuple; + using reduce_key = std::tuple; + std::unordered_map> load_funs_; + std::unordered_map> prefetch_funs_; + std::unordered_map> store_funs_; + std::unordered_map mul_add_funs_; + std::unordered_map> reduce_funs_; + temp_counter tmp_; +}; + +} // namespace tinytc::spv + +#endif // COOPMATRIX_IMPL_DPAS_20250428_HPP diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp new file mode 100644 index 00000000..18a96a9b --- /dev/null +++ b/src/spv/defs.hpp @@ -0,0 +1,426 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_DEFS_20250605_HPP +#define GENERATED_DEFS_20250605_HPP + +#include "enums.hpp" +#include "tinytc/tinytc.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +class spv_inst : public ilist_node { + public: + inline spv_inst(Op opcode, bool has_result_id) + : opcode_{opcode}, id_{has_result_id ? 0 : std::numeric_limits::max()} {} + virtual ~spv_inst() = default; + + spv_inst(spv_inst const &other) = delete; + spv_inst(spv_inst &&other) = delete; + spv_inst &operator=(spv_inst const &other) = delete; + spv_inst &operator=(spv_inst &&other) = delete; + + inline auto opcode() const -> Op { return opcode_; } + // SPIR-V requires 0 < id < Bound, therefore, we can reserve 0 for encoding "produces result; id + // not yet assigned" and uint32_max for encoding "does not produce result" + inline auto has_result_id() const -> bool { + return id_ != std::numeric_limits::max(); + } + inline auto id() const -> std::uint32_t { return id_; } + inline void id(std::uint32_t id) { id_ = id; } + + private: + Op opcode_; + std::uint32_t id_; +}; + +using DecorationAttr = std::variant>; +using ExecutionModeAttr = std::variant>; +using LiteralContextDependentNumber = + std::variant; +using LiteralString = std::string; +using LiteralInteger = std::int32_t; +using LiteralExtInstInteger = std::int32_t; +using IdResultType = spv_inst *; +using IdRef = spv_inst *; +using IdScope = spv_inst *; +using IdMemorySemantics = spv_inst *; +using LoopControlAttr = std::int32_t; +using MemoryAccessAttr = std::int32_t; +using PairIdRefIdRef = std::pair; +using PairLiteralIntegerIdRef = + std::pair, spv_inst *>; +using PairIdRefLiteralInteger = std::pair; + +class OpNop; // IWYU pragma: export +class OpUndef; // IWYU pragma: export +class OpSourceContinued; // IWYU pragma: export +class OpSource; // IWYU pragma: export +class OpSourceExtension; // IWYU pragma: export +class OpName; // IWYU pragma: export +class OpMemberName; // IWYU pragma: export +class OpString; // IWYU pragma: export +class OpLine; // IWYU pragma: export +class OpExtension; // IWYU pragma: export +class OpExtInstImport; // IWYU pragma: export +class OpExtInst; // IWYU pragma: export +class OpMemoryModel; // IWYU pragma: export +class OpEntryPoint; // IWYU pragma: export +class OpExecutionMode; // IWYU pragma: export +class OpCapability; // IWYU pragma: export +class OpTypeVoid; // IWYU pragma: export +class OpTypeBool; // IWYU pragma: export +class OpTypeInt; // IWYU pragma: export +class OpTypeFloat; // IWYU pragma: export +class OpTypeVector; // IWYU pragma: export +class OpTypeMatrix; // IWYU pragma: export +class OpTypeImage; // IWYU pragma: export +class OpTypeSampler; // IWYU pragma: export +class OpTypeSampledImage; // IWYU pragma: export +class OpTypeArray; // IWYU pragma: export +class OpTypeRuntimeArray; // IWYU pragma: export +class OpTypeStruct; // IWYU pragma: export +class OpTypeOpaque; // IWYU pragma: export +class OpTypePointer; // IWYU pragma: export +class OpTypeFunction; // IWYU pragma: export +class OpTypeEvent; // IWYU pragma: export +class OpTypeDeviceEvent; // IWYU pragma: export +class OpTypeReserveId; // IWYU pragma: export +class OpTypeQueue; // IWYU pragma: export +class OpTypePipe; // IWYU pragma: export +class OpTypeForwardPointer; // IWYU pragma: export +class OpConstantTrue; // IWYU pragma: export +class OpConstantFalse; // IWYU pragma: export +class OpConstant; // IWYU pragma: export +class OpConstantComposite; // IWYU pragma: export +class OpConstantSampler; // IWYU pragma: export +class OpConstantNull; // IWYU pragma: export +class OpFunction; // IWYU pragma: export +class OpFunctionParameter; // IWYU pragma: export +class OpFunctionEnd; // IWYU pragma: export +class OpFunctionCall; // IWYU pragma: export +class OpVariable; // IWYU pragma: export +class OpImageTexelPointer; // IWYU pragma: export +class OpLoad; // IWYU pragma: export +class OpStore; // IWYU pragma: export +class OpCopyMemory; // IWYU pragma: export +class OpCopyMemorySized; // IWYU pragma: export +class OpAccessChain; // IWYU pragma: export +class OpInBoundsAccessChain; // IWYU pragma: export +class OpPtrAccessChain; // IWYU pragma: export +class OpArrayLength; // IWYU pragma: export +class OpGenericPtrMemSemantics; // IWYU pragma: export +class OpInBoundsPtrAccessChain; // IWYU pragma: export +class OpDecorate; // IWYU pragma: export +class OpMemberDecorate; // IWYU pragma: export +class OpDecorationGroup; // IWYU pragma: export +class OpGroupDecorate; // IWYU pragma: export +class OpGroupMemberDecorate; // IWYU pragma: export +class OpVectorExtractDynamic; // IWYU pragma: export +class OpVectorInsertDynamic; // IWYU pragma: export +class OpVectorShuffle; // IWYU pragma: export +class OpCompositeConstruct; // IWYU pragma: export +class OpCompositeExtract; // IWYU pragma: export +class OpCompositeInsert; // IWYU pragma: export +class OpCopyObject; // IWYU pragma: export +class OpTranspose; // IWYU pragma: export +class OpSampledImage; // IWYU pragma: export +class OpImageSampleImplicitLod; // IWYU pragma: export +class OpImageSampleExplicitLod; // IWYU pragma: export +class OpImageSampleDrefImplicitLod; // IWYU pragma: export +class OpImageSampleDrefExplicitLod; // IWYU pragma: export +class OpImageSampleProjImplicitLod; // IWYU pragma: export +class OpImageSampleProjExplicitLod; // IWYU pragma: export +class OpImageSampleProjDrefImplicitLod; // IWYU pragma: export +class OpImageSampleProjDrefExplicitLod; // IWYU pragma: export +class OpImageFetch; // IWYU pragma: export +class OpImageGather; // IWYU pragma: export +class OpImageDrefGather; // IWYU pragma: export +class OpImageRead; // IWYU pragma: export +class OpImageWrite; // IWYU pragma: export +class OpImage; // IWYU pragma: export +class OpImageQueryFormat; // IWYU pragma: export +class OpImageQueryOrder; // IWYU pragma: export +class OpImageQuerySizeLod; // IWYU pragma: export +class OpImageQuerySize; // IWYU pragma: export +class OpImageQueryLod; // IWYU pragma: export +class OpImageQueryLevels; // IWYU pragma: export +class OpImageQuerySamples; // IWYU pragma: export +class OpConvertFToU; // IWYU pragma: export +class OpConvertFToS; // IWYU pragma: export +class OpConvertSToF; // IWYU pragma: export +class OpConvertUToF; // IWYU pragma: export +class OpUConvert; // IWYU pragma: export +class OpSConvert; // IWYU pragma: export +class OpFConvert; // IWYU pragma: export +class OpQuantizeToF16; // IWYU pragma: export +class OpConvertPtrToU; // IWYU pragma: export +class OpSatConvertSToU; // IWYU pragma: export +class OpSatConvertUToS; // IWYU pragma: export +class OpConvertUToPtr; // IWYU pragma: export +class OpPtrCastToGeneric; // IWYU pragma: export +class OpGenericCastToPtr; // IWYU pragma: export +class OpGenericCastToPtrExplicit; // IWYU pragma: export +class OpBitcast; // IWYU pragma: export +class OpSNegate; // IWYU pragma: export +class OpFNegate; // IWYU pragma: export +class OpIAdd; // IWYU pragma: export +class OpFAdd; // IWYU pragma: export +class OpISub; // IWYU pragma: export +class OpFSub; // IWYU pragma: export +class OpIMul; // IWYU pragma: export +class OpFMul; // IWYU pragma: export +class OpUDiv; // IWYU pragma: export +class OpSDiv; // IWYU pragma: export +class OpFDiv; // IWYU pragma: export +class OpUMod; // IWYU pragma: export +class OpSRem; // IWYU pragma: export +class OpSMod; // IWYU pragma: export +class OpFRem; // IWYU pragma: export +class OpFMod; // IWYU pragma: export +class OpVectorTimesScalar; // IWYU pragma: export +class OpMatrixTimesScalar; // IWYU pragma: export +class OpVectorTimesMatrix; // IWYU pragma: export +class OpMatrixTimesVector; // IWYU pragma: export +class OpMatrixTimesMatrix; // IWYU pragma: export +class OpOuterProduct; // IWYU pragma: export +class OpDot; // IWYU pragma: export +class OpIAddCarry; // IWYU pragma: export +class OpISubBorrow; // IWYU pragma: export +class OpUMulExtended; // IWYU pragma: export +class OpSMulExtended; // IWYU pragma: export +class OpAny; // IWYU pragma: export +class OpAll; // IWYU pragma: export +class OpIsNan; // IWYU pragma: export +class OpIsInf; // IWYU pragma: export +class OpIsFinite; // IWYU pragma: export +class OpIsNormal; // IWYU pragma: export +class OpSignBitSet; // IWYU pragma: export +class OpLessOrGreater; // IWYU pragma: export +class OpOrdered; // IWYU pragma: export +class OpUnordered; // IWYU pragma: export +class OpLogicalEqual; // IWYU pragma: export +class OpLogicalNotEqual; // IWYU pragma: export +class OpLogicalOr; // IWYU pragma: export +class OpLogicalAnd; // IWYU pragma: export +class OpLogicalNot; // IWYU pragma: export +class OpSelect; // IWYU pragma: export +class OpIEqual; // IWYU pragma: export +class OpINotEqual; // IWYU pragma: export +class OpUGreaterThan; // IWYU pragma: export +class OpSGreaterThan; // IWYU pragma: export +class OpUGreaterThanEqual; // IWYU pragma: export +class OpSGreaterThanEqual; // IWYU pragma: export +class OpULessThan; // IWYU pragma: export +class OpSLessThan; // IWYU pragma: export +class OpULessThanEqual; // IWYU pragma: export +class OpSLessThanEqual; // IWYU pragma: export +class OpFOrdEqual; // IWYU pragma: export +class OpFUnordEqual; // IWYU pragma: export +class OpFOrdNotEqual; // IWYU pragma: export +class OpFUnordNotEqual; // IWYU pragma: export +class OpFOrdLessThan; // IWYU pragma: export +class OpFUnordLessThan; // IWYU pragma: export +class OpFOrdGreaterThan; // IWYU pragma: export +class OpFUnordGreaterThan; // IWYU pragma: export +class OpFOrdLessThanEqual; // IWYU pragma: export +class OpFUnordLessThanEqual; // IWYU pragma: export +class OpFOrdGreaterThanEqual; // IWYU pragma: export +class OpFUnordGreaterThanEqual; // IWYU pragma: export +class OpShiftRightLogical; // IWYU pragma: export +class OpShiftRightArithmetic; // IWYU pragma: export +class OpShiftLeftLogical; // IWYU pragma: export +class OpBitwiseOr; // IWYU pragma: export +class OpBitwiseXor; // IWYU pragma: export +class OpBitwiseAnd; // IWYU pragma: export +class OpNot; // IWYU pragma: export +class OpBitFieldInsert; // IWYU pragma: export +class OpBitFieldSExtract; // IWYU pragma: export +class OpBitFieldUExtract; // IWYU pragma: export +class OpBitReverse; // IWYU pragma: export +class OpBitCount; // IWYU pragma: export +class OpDPdx; // IWYU pragma: export +class OpDPdy; // IWYU pragma: export +class OpFwidth; // IWYU pragma: export +class OpDPdxFine; // IWYU pragma: export +class OpDPdyFine; // IWYU pragma: export +class OpFwidthFine; // IWYU pragma: export +class OpDPdxCoarse; // IWYU pragma: export +class OpDPdyCoarse; // IWYU pragma: export +class OpFwidthCoarse; // IWYU pragma: export +class OpEmitVertex; // IWYU pragma: export +class OpEndPrimitive; // IWYU pragma: export +class OpEmitStreamVertex; // IWYU pragma: export +class OpEndStreamPrimitive; // IWYU pragma: export +class OpControlBarrier; // IWYU pragma: export +class OpMemoryBarrier; // IWYU pragma: export +class OpAtomicLoad; // IWYU pragma: export +class OpAtomicStore; // IWYU pragma: export +class OpAtomicExchange; // IWYU pragma: export +class OpAtomicCompareExchange; // IWYU pragma: export +class OpAtomicCompareExchangeWeak; // IWYU pragma: export +class OpAtomicIIncrement; // IWYU pragma: export +class OpAtomicIDecrement; // IWYU pragma: export +class OpAtomicIAdd; // IWYU pragma: export +class OpAtomicISub; // IWYU pragma: export +class OpAtomicSMin; // IWYU pragma: export +class OpAtomicUMin; // IWYU pragma: export +class OpAtomicSMax; // IWYU pragma: export +class OpAtomicUMax; // IWYU pragma: export +class OpAtomicAnd; // IWYU pragma: export +class OpAtomicOr; // IWYU pragma: export +class OpAtomicXor; // IWYU pragma: export +class OpPhi; // IWYU pragma: export +class OpLoopMerge; // IWYU pragma: export +class OpSelectionMerge; // IWYU pragma: export +class OpLabel; // IWYU pragma: export +class OpBranch; // IWYU pragma: export +class OpBranchConditional; // IWYU pragma: export +class OpSwitch; // IWYU pragma: export +class OpKill; // IWYU pragma: export +class OpReturn; // IWYU pragma: export +class OpReturnValue; // IWYU pragma: export +class OpUnreachable; // IWYU pragma: export +class OpLifetimeStart; // IWYU pragma: export +class OpLifetimeStop; // IWYU pragma: export +class OpGroupAsyncCopy; // IWYU pragma: export +class OpGroupWaitEvents; // IWYU pragma: export +class OpGroupAll; // IWYU pragma: export +class OpGroupAny; // IWYU pragma: export +class OpGroupBroadcast; // IWYU pragma: export +class OpGroupIAdd; // IWYU pragma: export +class OpGroupFAdd; // IWYU pragma: export +class OpGroupFMin; // IWYU pragma: export +class OpGroupUMin; // IWYU pragma: export +class OpGroupSMin; // IWYU pragma: export +class OpGroupFMax; // IWYU pragma: export +class OpGroupUMax; // IWYU pragma: export +class OpGroupSMax; // IWYU pragma: export +class OpReadPipe; // IWYU pragma: export +class OpWritePipe; // IWYU pragma: export +class OpReservedReadPipe; // IWYU pragma: export +class OpReservedWritePipe; // IWYU pragma: export +class OpReserveReadPipePackets; // IWYU pragma: export +class OpReserveWritePipePackets; // IWYU pragma: export +class OpCommitReadPipe; // IWYU pragma: export +class OpCommitWritePipe; // IWYU pragma: export +class OpIsValidReserveId; // IWYU pragma: export +class OpGetNumPipePackets; // IWYU pragma: export +class OpGetMaxPipePackets; // IWYU pragma: export +class OpGroupReserveReadPipePackets; // IWYU pragma: export +class OpGroupReserveWritePipePackets; // IWYU pragma: export +class OpGroupCommitReadPipe; // IWYU pragma: export +class OpGroupCommitWritePipe; // IWYU pragma: export +class OpEnqueueMarker; // IWYU pragma: export +class OpEnqueueKernel; // IWYU pragma: export +class OpGetKernelNDrangeSubGroupCount; // IWYU pragma: export +class OpGetKernelNDrangeMaxSubGroupSize; // IWYU pragma: export +class OpGetKernelWorkGroupSize; // IWYU pragma: export +class OpGetKernelPreferredWorkGroupSizeMultiple; // IWYU pragma: export +class OpRetainEvent; // IWYU pragma: export +class OpReleaseEvent; // IWYU pragma: export +class OpCreateUserEvent; // IWYU pragma: export +class OpIsValidEvent; // IWYU pragma: export +class OpSetUserEventStatus; // IWYU pragma: export +class OpCaptureEventProfilingInfo; // IWYU pragma: export +class OpGetDefaultQueue; // IWYU pragma: export +class OpBuildNDRange; // IWYU pragma: export +class OpImageSparseSampleImplicitLod; // IWYU pragma: export +class OpImageSparseSampleExplicitLod; // IWYU pragma: export +class OpImageSparseSampleDrefImplicitLod; // IWYU pragma: export +class OpImageSparseSampleDrefExplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjImplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjExplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjDrefImplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjDrefExplicitLod; // IWYU pragma: export +class OpImageSparseFetch; // IWYU pragma: export +class OpImageSparseGather; // IWYU pragma: export +class OpImageSparseDrefGather; // IWYU pragma: export +class OpImageSparseTexelsResident; // IWYU pragma: export +class OpNoLine; // IWYU pragma: export +class OpAtomicFlagTestAndSet; // IWYU pragma: export +class OpAtomicFlagClear; // IWYU pragma: export +class OpImageSparseRead; // IWYU pragma: export +class OpSizeOf; // IWYU pragma: export +class OpTypePipeStorage; // IWYU pragma: export +class OpConstantPipeStorage; // IWYU pragma: export +class OpCreatePipeFromPipeStorage; // IWYU pragma: export +class OpGetKernelLocalSizeForSubgroupCount; // IWYU pragma: export +class OpGetKernelMaxNumSubgroups; // IWYU pragma: export +class OpTypeNamedBarrier; // IWYU pragma: export +class OpNamedBarrierInitialize; // IWYU pragma: export +class OpMemoryNamedBarrier; // IWYU pragma: export +class OpModuleProcessed; // IWYU pragma: export +class OpExecutionModeId; // IWYU pragma: export +class OpDecorateId; // IWYU pragma: export +class OpGroupNonUniformElect; // IWYU pragma: export +class OpGroupNonUniformAll; // IWYU pragma: export +class OpGroupNonUniformAny; // IWYU pragma: export +class OpGroupNonUniformAllEqual; // IWYU pragma: export +class OpGroupNonUniformBroadcast; // IWYU pragma: export +class OpGroupNonUniformBroadcastFirst; // IWYU pragma: export +class OpGroupNonUniformBallot; // IWYU pragma: export +class OpGroupNonUniformInverseBallot; // IWYU pragma: export +class OpGroupNonUniformBallotBitExtract; // IWYU pragma: export +class OpGroupNonUniformBallotBitCount; // IWYU pragma: export +class OpGroupNonUniformBallotFindLSB; // IWYU pragma: export +class OpGroupNonUniformBallotFindMSB; // IWYU pragma: export +class OpGroupNonUniformShuffle; // IWYU pragma: export +class OpGroupNonUniformShuffleXor; // IWYU pragma: export +class OpGroupNonUniformShuffleUp; // IWYU pragma: export +class OpGroupNonUniformShuffleDown; // IWYU pragma: export +class OpGroupNonUniformIAdd; // IWYU pragma: export +class OpGroupNonUniformFAdd; // IWYU pragma: export +class OpGroupNonUniformIMul; // IWYU pragma: export +class OpGroupNonUniformFMul; // IWYU pragma: export +class OpGroupNonUniformSMin; // IWYU pragma: export +class OpGroupNonUniformUMin; // IWYU pragma: export +class OpGroupNonUniformFMin; // IWYU pragma: export +class OpGroupNonUniformSMax; // IWYU pragma: export +class OpGroupNonUniformUMax; // IWYU pragma: export +class OpGroupNonUniformFMax; // IWYU pragma: export +class OpGroupNonUniformBitwiseAnd; // IWYU pragma: export +class OpGroupNonUniformBitwiseOr; // IWYU pragma: export +class OpGroupNonUniformBitwiseXor; // IWYU pragma: export +class OpGroupNonUniformLogicalAnd; // IWYU pragma: export +class OpGroupNonUniformLogicalOr; // IWYU pragma: export +class OpGroupNonUniformLogicalXor; // IWYU pragma: export +class OpGroupNonUniformQuadBroadcast; // IWYU pragma: export +class OpGroupNonUniformQuadSwap; // IWYU pragma: export +class OpCopyLogical; // IWYU pragma: export +class OpPtrEqual; // IWYU pragma: export +class OpPtrNotEqual; // IWYU pragma: export +class OpPtrDiff; // IWYU pragma: export +class OpTypeCooperativeMatrixKHR; // IWYU pragma: export +class OpCooperativeMatrixLoadKHR; // IWYU pragma: export +class OpCooperativeMatrixStoreKHR; // IWYU pragma: export +class OpCooperativeMatrixMulAddKHR; // IWYU pragma: export +class OpCooperativeMatrixLengthKHR; // IWYU pragma: export +class OpSubgroupBlockReadINTEL; // IWYU pragma: export +class OpSubgroupBlockWriteINTEL; // IWYU pragma: export +class OpAsmTargetINTEL; // IWYU pragma: export +class OpAsmINTEL; // IWYU pragma: export +class OpAsmCallINTEL; // IWYU pragma: export +class OpAtomicFMinEXT; // IWYU pragma: export +class OpAtomicFMaxEXT; // IWYU pragma: export +class OpAtomicFAddEXT; // IWYU pragma: export +class OpConvertFToBF16INTEL; // IWYU pragma: export +class OpConvertBF16ToFINTEL; // IWYU pragma: export +class OpControlBarrierArriveINTEL; // IWYU pragma: export +class OpControlBarrierWaitINTEL; // IWYU pragma: export +class OpCooperativeMatrixLoadCheckedINTEL; // IWYU pragma: export +class OpCooperativeMatrixStoreCheckedINTEL; // IWYU pragma: export + +} // namespace tinytc::spv + +#endif // GENERATED_DEFS_20250605_HPP diff --git a/src/spv/dope_vector.cpp b/src/spv/dope_vector.cpp new file mode 100644 index 00000000..d7ffc46c --- /dev/null +++ b/src/spv/dope_vector.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/dope_vector.hpp" +#include "spv/defs.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc::spv { + +dope_vector::dope_vector(spv_inst *ty, std::vector static_shape, + std::vector static_stride, spv_inst *size_ty, + std::int64_t static_size, spv_inst *offset_ty, std::int64_t static_offset) + : ty_(ty), static_shape_(std::move(static_shape)), static_stride_(std::move(static_stride)), + shape_(dim(), nullptr), stride_(dim(), nullptr), size_ty_(size_ty), offset_ty_(offset_ty), + static_size_(static_size), static_offset_(static_offset) { + if (static_shape_.size() != static_stride_.size()) { + throw status::internal_compiler_error; + } +} + +auto dope_vector::num_dynamic() const -> std::int64_t { + auto const sum_dynamic = [](std::vector const &vec) { + std::int64_t num_dynamic = 0; + for (auto &v : vec) { + if (is_dynamic_value(v)) { + ++num_dynamic; + } + } + return num_dynamic; + }; + return sum_dynamic(static_shape_) + sum_dynamic(static_stride_); +} + +} // namespace tinytc::spv + diff --git a/src/spv/dope_vector.hpp b/src/spv/dope_vector.hpp new file mode 100644 index 00000000..64bbec3f --- /dev/null +++ b/src/spv/dope_vector.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DOPE_VECTOR_20241213_HPP +#define DOPE_VECTOR_20241213_HPP + +#include +#include + +namespace tinytc::spv { + +class spv_inst; + +class dope_vector { + public: + dope_vector() = default; + dope_vector(spv_inst *ty, std::vector static_shape, + std::vector static_stride, spv_inst *size_ty = nullptr, + std::int64_t static_size = 0, spv_inst *offset_ty = nullptr, + std::int64_t static_offset = 0); + + inline auto dim() const -> std::int64_t { return static_shape_.size(); } + inline auto ty() const -> spv_inst * { return ty_; } + inline auto static_shape(std::int64_t i) const -> std::int64_t { return static_shape_[i]; } + inline auto static_stride(std::int64_t i) const -> std::int64_t { return static_stride_[i]; } + inline auto shape(std::int64_t i) const -> spv_inst * { return shape_[i]; } + inline auto stride(std::int64_t i) const -> spv_inst * { return stride_[i]; } + inline void shape(std::int64_t i, spv_inst *s) { shape_[i] = s; } + inline void stride(std::int64_t i, spv_inst *s) { stride_[i] = s; } + + inline auto size_ty() const -> spv_inst * { return size_ty_; } + inline auto static_size() const -> std::int64_t { return static_size_; } + inline auto size() -> spv_inst * { return size_; } + inline void size(spv_inst *size) { size_ = size; } + + inline auto offset_ty() const -> spv_inst * { return offset_ty_; } + inline auto static_offset() const -> std::int64_t { return static_offset_; } + inline auto offset() -> spv_inst * { return offset_; } + inline void offset(spv_inst *offset) { offset_ = offset; } + + auto num_dynamic() const -> std::int64_t; + + private: + spv_inst *ty_ = nullptr; + std::vector static_shape_, static_stride_; + std::vector shape_, stride_; + spv_inst *size_ty_ = nullptr, *offset_ty_ = nullptr; + std::int64_t static_size_, static_offset_; + spv_inst *size_ = nullptr; + spv_inst *offset_ = nullptr; +}; + +} // namespace tinytc::spv + +#endif // DOPE_VECTOR_20241213_HPP diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp new file mode 100644 index 00000000..874c177a --- /dev/null +++ b/src/spv/enums.hpp @@ -0,0 +1,1546 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_ENUMS_20250605_HPP +#define GENERATED_ENUMS_20250605_HPP + +#include + +namespace tinytc::spv { + +constexpr std::int32_t magic_number = 0x07230203; + +enum class Op { + Nop = 0, + Undef = 1, + SourceContinued = 2, + Source = 3, + SourceExtension = 4, + Name = 5, + MemberName = 6, + String = 7, + Line = 8, + Extension = 10, + ExtInstImport = 11, + ExtInst = 12, + MemoryModel = 14, + EntryPoint = 15, + ExecutionMode = 16, + Capability = 17, + TypeVoid = 19, + TypeBool = 20, + TypeInt = 21, + TypeFloat = 22, + TypeVector = 23, + TypeMatrix = 24, + TypeImage = 25, + TypeSampler = 26, + TypeSampledImage = 27, + TypeArray = 28, + TypeRuntimeArray = 29, + TypeStruct = 30, + TypeOpaque = 31, + TypePointer = 32, + TypeFunction = 33, + TypeEvent = 34, + TypeDeviceEvent = 35, + TypeReserveId = 36, + TypeQueue = 37, + TypePipe = 38, + TypeForwardPointer = 39, + ConstantTrue = 41, + ConstantFalse = 42, + Constant = 43, + ConstantComposite = 44, + ConstantSampler = 45, + ConstantNull = 46, + Function = 54, + FunctionParameter = 55, + FunctionEnd = 56, + FunctionCall = 57, + Variable = 59, + ImageTexelPointer = 60, + Load = 61, + Store = 62, + CopyMemory = 63, + CopyMemorySized = 64, + AccessChain = 65, + InBoundsAccessChain = 66, + PtrAccessChain = 67, + ArrayLength = 68, + GenericPtrMemSemantics = 69, + InBoundsPtrAccessChain = 70, + Decorate = 71, + MemberDecorate = 72, + DecorationGroup = 73, + GroupDecorate = 74, + GroupMemberDecorate = 75, + VectorExtractDynamic = 77, + VectorInsertDynamic = 78, + VectorShuffle = 79, + CompositeConstruct = 80, + CompositeExtract = 81, + CompositeInsert = 82, + CopyObject = 83, + Transpose = 84, + SampledImage = 86, + ImageSampleImplicitLod = 87, + ImageSampleExplicitLod = 88, + ImageSampleDrefImplicitLod = 89, + ImageSampleDrefExplicitLod = 90, + ImageSampleProjImplicitLod = 91, + ImageSampleProjExplicitLod = 92, + ImageSampleProjDrefImplicitLod = 93, + ImageSampleProjDrefExplicitLod = 94, + ImageFetch = 95, + ImageGather = 96, + ImageDrefGather = 97, + ImageRead = 98, + ImageWrite = 99, + Image = 100, + ImageQueryFormat = 101, + ImageQueryOrder = 102, + ImageQuerySizeLod = 103, + ImageQuerySize = 104, + ImageQueryLod = 105, + ImageQueryLevels = 106, + ImageQuerySamples = 107, + ConvertFToU = 109, + ConvertFToS = 110, + ConvertSToF = 111, + ConvertUToF = 112, + UConvert = 113, + SConvert = 114, + FConvert = 115, + QuantizeToF16 = 116, + ConvertPtrToU = 117, + SatConvertSToU = 118, + SatConvertUToS = 119, + ConvertUToPtr = 120, + PtrCastToGeneric = 121, + GenericCastToPtr = 122, + GenericCastToPtrExplicit = 123, + Bitcast = 124, + SNegate = 126, + FNegate = 127, + IAdd = 128, + FAdd = 129, + ISub = 130, + FSub = 131, + IMul = 132, + FMul = 133, + UDiv = 134, + SDiv = 135, + FDiv = 136, + UMod = 137, + SRem = 138, + SMod = 139, + FRem = 140, + FMod = 141, + VectorTimesScalar = 142, + MatrixTimesScalar = 143, + VectorTimesMatrix = 144, + MatrixTimesVector = 145, + MatrixTimesMatrix = 146, + OuterProduct = 147, + Dot = 148, + IAddCarry = 149, + ISubBorrow = 150, + UMulExtended = 151, + SMulExtended = 152, + Any = 154, + All = 155, + IsNan = 156, + IsInf = 157, + IsFinite = 158, + IsNormal = 159, + SignBitSet = 160, + LessOrGreater = 161, + Ordered = 162, + Unordered = 163, + LogicalEqual = 164, + LogicalNotEqual = 165, + LogicalOr = 166, + LogicalAnd = 167, + LogicalNot = 168, + Select = 169, + IEqual = 170, + INotEqual = 171, + UGreaterThan = 172, + SGreaterThan = 173, + UGreaterThanEqual = 174, + SGreaterThanEqual = 175, + ULessThan = 176, + SLessThan = 177, + ULessThanEqual = 178, + SLessThanEqual = 179, + FOrdEqual = 180, + FUnordEqual = 181, + FOrdNotEqual = 182, + FUnordNotEqual = 183, + FOrdLessThan = 184, + FUnordLessThan = 185, + FOrdGreaterThan = 186, + FUnordGreaterThan = 187, + FOrdLessThanEqual = 188, + FUnordLessThanEqual = 189, + FOrdGreaterThanEqual = 190, + FUnordGreaterThanEqual = 191, + ShiftRightLogical = 194, + ShiftRightArithmetic = 195, + ShiftLeftLogical = 196, + BitwiseOr = 197, + BitwiseXor = 198, + BitwiseAnd = 199, + Not = 200, + BitFieldInsert = 201, + BitFieldSExtract = 202, + BitFieldUExtract = 203, + BitReverse = 204, + BitCount = 205, + DPdx = 207, + DPdy = 208, + Fwidth = 209, + DPdxFine = 210, + DPdyFine = 211, + FwidthFine = 212, + DPdxCoarse = 213, + DPdyCoarse = 214, + FwidthCoarse = 215, + EmitVertex = 218, + EndPrimitive = 219, + EmitStreamVertex = 220, + EndStreamPrimitive = 221, + ControlBarrier = 224, + MemoryBarrier = 225, + AtomicLoad = 227, + AtomicStore = 228, + AtomicExchange = 229, + AtomicCompareExchange = 230, + AtomicCompareExchangeWeak = 231, + AtomicIIncrement = 232, + AtomicIDecrement = 233, + AtomicIAdd = 234, + AtomicISub = 235, + AtomicSMin = 236, + AtomicUMin = 237, + AtomicSMax = 238, + AtomicUMax = 239, + AtomicAnd = 240, + AtomicOr = 241, + AtomicXor = 242, + Phi = 245, + LoopMerge = 246, + SelectionMerge = 247, + Label = 248, + Branch = 249, + BranchConditional = 250, + Switch = 251, + Kill = 252, + Return = 253, + ReturnValue = 254, + Unreachable = 255, + LifetimeStart = 256, + LifetimeStop = 257, + GroupAsyncCopy = 259, + GroupWaitEvents = 260, + GroupAll = 261, + GroupAny = 262, + GroupBroadcast = 263, + GroupIAdd = 264, + GroupFAdd = 265, + GroupFMin = 266, + GroupUMin = 267, + GroupSMin = 268, + GroupFMax = 269, + GroupUMax = 270, + GroupSMax = 271, + ReadPipe = 274, + WritePipe = 275, + ReservedReadPipe = 276, + ReservedWritePipe = 277, + ReserveReadPipePackets = 278, + ReserveWritePipePackets = 279, + CommitReadPipe = 280, + CommitWritePipe = 281, + IsValidReserveId = 282, + GetNumPipePackets = 283, + GetMaxPipePackets = 284, + GroupReserveReadPipePackets = 285, + GroupReserveWritePipePackets = 286, + GroupCommitReadPipe = 287, + GroupCommitWritePipe = 288, + EnqueueMarker = 291, + EnqueueKernel = 292, + GetKernelNDrangeSubGroupCount = 293, + GetKernelNDrangeMaxSubGroupSize = 294, + GetKernelWorkGroupSize = 295, + GetKernelPreferredWorkGroupSizeMultiple = 296, + RetainEvent = 297, + ReleaseEvent = 298, + CreateUserEvent = 299, + IsValidEvent = 300, + SetUserEventStatus = 301, + CaptureEventProfilingInfo = 302, + GetDefaultQueue = 303, + BuildNDRange = 304, + ImageSparseSampleImplicitLod = 305, + ImageSparseSampleExplicitLod = 306, + ImageSparseSampleDrefImplicitLod = 307, + ImageSparseSampleDrefExplicitLod = 308, + ImageSparseSampleProjImplicitLod = 309, + ImageSparseSampleProjExplicitLod = 310, + ImageSparseSampleProjDrefImplicitLod = 311, + ImageSparseSampleProjDrefExplicitLod = 312, + ImageSparseFetch = 313, + ImageSparseGather = 314, + ImageSparseDrefGather = 315, + ImageSparseTexelsResident = 316, + NoLine = 317, + AtomicFlagTestAndSet = 318, + AtomicFlagClear = 319, + ImageSparseRead = 320, + SizeOf = 321, + TypePipeStorage = 322, + ConstantPipeStorage = 323, + CreatePipeFromPipeStorage = 324, + GetKernelLocalSizeForSubgroupCount = 325, + GetKernelMaxNumSubgroups = 326, + TypeNamedBarrier = 327, + NamedBarrierInitialize = 328, + MemoryNamedBarrier = 329, + ModuleProcessed = 330, + ExecutionModeId = 331, + DecorateId = 332, + GroupNonUniformElect = 333, + GroupNonUniformAll = 334, + GroupNonUniformAny = 335, + GroupNonUniformAllEqual = 336, + GroupNonUniformBroadcast = 337, + GroupNonUniformBroadcastFirst = 338, + GroupNonUniformBallot = 339, + GroupNonUniformInverseBallot = 340, + GroupNonUniformBallotBitExtract = 341, + GroupNonUniformBallotBitCount = 342, + GroupNonUniformBallotFindLSB = 343, + GroupNonUniformBallotFindMSB = 344, + GroupNonUniformShuffle = 345, + GroupNonUniformShuffleXor = 346, + GroupNonUniformShuffleUp = 347, + GroupNonUniformShuffleDown = 348, + GroupNonUniformIAdd = 349, + GroupNonUniformFAdd = 350, + GroupNonUniformIMul = 351, + GroupNonUniformFMul = 352, + GroupNonUniformSMin = 353, + GroupNonUniformUMin = 354, + GroupNonUniformFMin = 355, + GroupNonUniformSMax = 356, + GroupNonUniformUMax = 357, + GroupNonUniformFMax = 358, + GroupNonUniformBitwiseAnd = 359, + GroupNonUniformBitwiseOr = 360, + GroupNonUniformBitwiseXor = 361, + GroupNonUniformLogicalAnd = 362, + GroupNonUniformLogicalOr = 363, + GroupNonUniformLogicalXor = 364, + GroupNonUniformQuadBroadcast = 365, + GroupNonUniformQuadSwap = 366, + CopyLogical = 400, + PtrEqual = 401, + PtrNotEqual = 402, + PtrDiff = 403, + TypeCooperativeMatrixKHR = 4456, + CooperativeMatrixLoadKHR = 4457, + CooperativeMatrixStoreKHR = 4458, + CooperativeMatrixMulAddKHR = 4459, + CooperativeMatrixLengthKHR = 4460, + SubgroupBlockReadINTEL = 5575, + SubgroupBlockWriteINTEL = 5576, + AsmTargetINTEL = 5609, + AsmINTEL = 5610, + AsmCallINTEL = 5611, + AtomicFMinEXT = 5614, + AtomicFMaxEXT = 5615, + AtomicFAddEXT = 6035, + ConvertFToBF16INTEL = 6116, + ConvertBF16ToFINTEL = 6117, + ControlBarrierArriveINTEL = 6142, + ControlBarrierWaitINTEL = 6143, + CooperativeMatrixLoadCheckedINTEL = 6193, + CooperativeMatrixStoreCheckedINTEL = 6194, +}; +enum class ImageOperands { + None = 0x0000, + Bias = 0x0001, + Lod = 0x0002, + Grad = 0x0004, + ConstOffset = 0x0008, + Offset = 0x0010, + ConstOffsets = 0x0020, + Sample = 0x0040, + MinLod = 0x0080, + MakeTexelAvailable = 0x0100, + MakeTexelVisible = 0x0200, + NonPrivateTexel = 0x0400, + VolatileTexel = 0x0800, + SignExtend = 0x1000, + ZeroExtend = 0x2000, + Nontemporal = 0x4000, + Offsets = 0x10000, +}; +enum class FPFastMathMode { + None = 0x0000, + NotNaN = 0x0001, + NotInf = 0x0002, + NSZ = 0x0004, + AllowRecip = 0x0008, + Fast = 0x0010, + AllowContract = 0x10000, + AllowReassoc = 0x20000, + AllowTransform = 0x40000, +}; +enum class SelectionControl { + None = 0x0000, + Flatten = 0x0001, + DontFlatten = 0x0002, +}; +enum class LoopControl { + None = 0x0000, + Unroll = 0x0001, + DontUnroll = 0x0002, + DependencyInfinite = 0x0004, + DependencyLength = 0x0008, + MinIterations = 0x0010, + MaxIterations = 0x0020, + IterationMultiple = 0x0040, + PeelCount = 0x0080, + PartialCount = 0x0100, + InitiationIntervalINTEL = 0x10000, + MaxConcurrencyINTEL = 0x20000, + DependencyArrayINTEL = 0x40000, + PipelineEnableINTEL = 0x80000, + LoopCoalesceINTEL = 0x100000, + MaxInterleavingINTEL = 0x200000, + SpeculatedIterationsINTEL = 0x400000, + NoFusionINTEL = 0x800000, + LoopCountINTEL = 0x1000000, + MaxReinvocationDelayINTEL = 0x2000000, +}; +enum class FunctionControl { + None = 0x0000, + Inline = 0x0001, + DontInline = 0x0002, + Pure = 0x0004, + Const = 0x0008, + OptNoneEXT = 0x10000, +}; +enum class MemorySemantics { + Relaxed = 0x0000, + Acquire = 0x0002, + Release = 0x0004, + AcquireRelease = 0x0008, + SequentiallyConsistent = 0x0010, + UniformMemory = 0x0040, + SubgroupMemory = 0x0080, + WorkgroupMemory = 0x0100, + CrossWorkgroupMemory = 0x0200, + AtomicCounterMemory = 0x0400, + ImageMemory = 0x0800, + OutputMemory = 0x1000, + MakeAvailable = 0x2000, + MakeVisible = 0x4000, + Volatile = 0x8000, +}; +enum class MemoryAccess { + None = 0x0000, + Volatile = 0x0001, + Aligned = 0x0002, + Nontemporal = 0x0004, + MakePointerAvailable = 0x0008, + MakePointerVisible = 0x0010, + NonPrivatePointer = 0x0020, + AliasScopeINTELMask = 0x10000, + NoAliasINTELMask = 0x20000, +}; +enum class KernelProfilingInfo { + None = 0x0000, + CmdExecTime = 0x0001, +}; +enum class RayFlags { + NoneKHR = 0x0000, + OpaqueKHR = 0x0001, + NoOpaqueKHR = 0x0002, + TerminateOnFirstHitKHR = 0x0004, + SkipClosestHitShaderKHR = 0x0008, + CullBackFacingTrianglesKHR = 0x0010, + CullFrontFacingTrianglesKHR = 0x0020, + CullOpaqueKHR = 0x0040, + CullNoOpaqueKHR = 0x0080, + SkipTrianglesKHR = 0x0100, + SkipAABBsKHR = 0x0200, + ForceOpacityMicromap2StateEXT = 0x0400, +}; +enum class FragmentShadingRate { + Vertical2Pixels = 0x0001, + Vertical4Pixels = 0x0002, + Horizontal2Pixels = 0x0004, + Horizontal4Pixels = 0x0008, +}; +enum class RawAccessChainOperands { + None = 0x0000, + RobustnessPerComponentNV = 0x0001, + RobustnessPerElementNV = 0x0002, +}; +enum class SourceLanguage { + Unknown = 0, + ESSL = 1, + GLSL = 2, + OpenCL_C = 3, + OpenCL_CPP = 4, + HLSL = 5, + CPP_for_OpenCL = 6, + SYCL = 7, + HERO_C = 8, + NZSL = 9, + WGSL = 10, + Slang = 11, + Zig = 12, + Rust = 13, +}; +enum class ExecutionModel { + Vertex = 0, + TessellationControl = 1, + TessellationEvaluation = 2, + Geometry = 3, + Fragment = 4, + GLCompute = 5, + Kernel = 6, + TaskNV = 5267, + MeshNV = 5268, + RayGenerationKHR = 5313, + IntersectionKHR = 5314, + AnyHitKHR = 5315, + ClosestHitKHR = 5316, + MissKHR = 5317, + CallableKHR = 5318, + TaskEXT = 5364, + MeshEXT = 5365, +}; +enum class AddressingModel { + Logical = 0, + Physical32 = 1, + Physical64 = 2, + PhysicalStorageBuffer64 = 5348, +}; +enum class MemoryModel { + Simple = 0, + GLSL450 = 1, + OpenCL = 2, + Vulkan = 3, +}; +enum class ExecutionMode { + Invocations = 0, + SpacingEqual = 1, + SpacingFractionalEven = 2, + SpacingFractionalOdd = 3, + VertexOrderCw = 4, + VertexOrderCcw = 5, + PixelCenterInteger = 6, + OriginUpperLeft = 7, + OriginLowerLeft = 8, + EarlyFragmentTests = 9, + PointMode = 10, + Xfb = 11, + DepthReplacing = 12, + DepthGreater = 14, + DepthLess = 15, + DepthUnchanged = 16, + LocalSize = 17, + LocalSizeHint = 18, + InputPoints = 19, + InputLines = 20, + InputLinesAdjacency = 21, + Triangles = 22, + InputTrianglesAdjacency = 23, + Quads = 24, + Isolines = 25, + OutputVertices = 26, + OutputPoints = 27, + OutputLineStrip = 28, + OutputTriangleStrip = 29, + VecTypeHint = 30, + ContractionOff = 31, + Initializer = 33, + Finalizer = 34, + SubgroupSize = 35, + SubgroupsPerWorkgroup = 36, + SubgroupsPerWorkgroupId = 37, + LocalSizeId = 38, + LocalSizeHintId = 39, + NonCoherentColorAttachmentReadEXT = 4169, + NonCoherentDepthAttachmentReadEXT = 4170, + NonCoherentStencilAttachmentReadEXT = 4171, + SubgroupUniformControlFlowKHR = 4421, + PostDepthCoverage = 4446, + DenormPreserve = 4459, + DenormFlushToZero = 4460, + SignedZeroInfNanPreserve = 4461, + RoundingModeRTE = 4462, + RoundingModeRTZ = 4463, + NonCoherentTileAttachmentReadQCOM = 4489, + TileShadingRateQCOM = 4490, + EarlyAndLateFragmentTestsAMD = 5017, + StencilRefReplacingEXT = 5027, + CoalescingAMDX = 5069, + IsApiEntryAMDX = 5070, + MaxNodeRecursionAMDX = 5071, + StaticNumWorkgroupsAMDX = 5072, + ShaderIndexAMDX = 5073, + MaxNumWorkgroupsAMDX = 5077, + StencilRefUnchangedFrontAMD = 5079, + StencilRefGreaterFrontAMD = 5080, + StencilRefLessFrontAMD = 5081, + StencilRefUnchangedBackAMD = 5082, + StencilRefGreaterBackAMD = 5083, + StencilRefLessBackAMD = 5084, + QuadDerivativesKHR = 5088, + RequireFullQuadsKHR = 5089, + SharesInputWithAMDX = 5102, + OutputLinesEXT = 5269, + OutputPrimitivesEXT = 5270, + DerivativeGroupQuadsKHR = 5289, + DerivativeGroupLinearKHR = 5290, + OutputTrianglesEXT = 5298, + PixelInterlockOrderedEXT = 5366, + PixelInterlockUnorderedEXT = 5367, + SampleInterlockOrderedEXT = 5368, + SampleInterlockUnorderedEXT = 5369, + ShadingRateInterlockOrderedEXT = 5370, + ShadingRateInterlockUnorderedEXT = 5371, + SharedLocalMemorySizeINTEL = 5618, + RoundingModeRTPINTEL = 5620, + RoundingModeRTNINTEL = 5621, + FloatingPointModeALTINTEL = 5622, + FloatingPointModeIEEEINTEL = 5623, + MaxWorkgroupSizeINTEL = 5893, + MaxWorkDimINTEL = 5894, + NoGlobalOffsetINTEL = 5895, + NumSIMDWorkitemsINTEL = 5896, + SchedulerTargetFmaxMhzINTEL = 5903, + MaximallyReconvergesKHR = 6023, + FPFastMathDefault = 6028, + StreamingInterfaceINTEL = 6154, + RegisterMapInterfaceINTEL = 6160, + NamedBarrierCountINTEL = 6417, + MaximumRegistersINTEL = 6461, + MaximumRegistersIdINTEL = 6462, + NamedMaximumRegistersINTEL = 6463, +}; +enum class StorageClass { + UniformConstant = 0, + Input = 1, + Uniform = 2, + Output = 3, + Workgroup = 4, + CrossWorkgroup = 5, + Private = 6, + Function = 7, + Generic = 8, + PushConstant = 9, + AtomicCounter = 10, + Image = 11, + StorageBuffer = 12, + TileImageEXT = 4172, + TileAttachmentQCOM = 4491, + NodePayloadAMDX = 5068, + CallableDataKHR = 5328, + IncomingCallableDataKHR = 5329, + RayPayloadKHR = 5338, + HitAttributeKHR = 5339, + IncomingRayPayloadKHR = 5342, + ShaderRecordBufferKHR = 5343, + PhysicalStorageBuffer = 5349, + HitObjectAttributeNV = 5385, + TaskPayloadWorkgroupEXT = 5402, + CodeSectionINTEL = 5605, + DeviceOnlyINTEL = 5936, + HostOnlyINTEL = 5937, +}; +enum class Dim { + Dim1D = 0, + Dim2D = 1, + Dim3D = 2, + Cube = 3, + Rect = 4, + Buffer = 5, + SubpassData = 6, + TileImageDataEXT = 4173, +}; +enum class SamplerAddressingMode { + None = 0, + ClampToEdge = 1, + Clamp = 2, + Repeat = 3, + RepeatMirrored = 4, +}; +enum class SamplerFilterMode { + Nearest = 0, + Linear = 1, +}; +enum class ImageFormat { + Unknown = 0, + Rgba32f = 1, + Rgba16f = 2, + R32f = 3, + Rgba8 = 4, + Rgba8Snorm = 5, + Rg32f = 6, + Rg16f = 7, + R11fG11fB10f = 8, + R16f = 9, + Rgba16 = 10, + Rgb10A2 = 11, + Rg16 = 12, + Rg8 = 13, + R16 = 14, + R8 = 15, + Rgba16Snorm = 16, + Rg16Snorm = 17, + Rg8Snorm = 18, + R16Snorm = 19, + R8Snorm = 20, + Rgba32i = 21, + Rgba16i = 22, + Rgba8i = 23, + R32i = 24, + Rg32i = 25, + Rg16i = 26, + Rg8i = 27, + R16i = 28, + R8i = 29, + Rgba32ui = 30, + Rgba16ui = 31, + Rgba8ui = 32, + R32ui = 33, + Rgb10a2ui = 34, + Rg32ui = 35, + Rg16ui = 36, + Rg8ui = 37, + R16ui = 38, + R8ui = 39, + R64ui = 40, + R64i = 41, +}; +enum class ImageChannelOrder { + R = 0, + A = 1, + RG = 2, + RA = 3, + RGB = 4, + RGBA = 5, + BGRA = 6, + ARGB = 7, + Intensity = 8, + Luminance = 9, + Rx = 10, + RGx = 11, + RGBx = 12, + Depth = 13, + DepthStencil = 14, + sRGB = 15, + sRGBx = 16, + sRGBA = 17, + sBGRA = 18, + ABGR = 19, +}; +enum class ImageChannelDataType { + SnormInt8 = 0, + SnormInt16 = 1, + UnormInt8 = 2, + UnormInt16 = 3, + UnormShort565 = 4, + UnormShort555 = 5, + UnormInt101010 = 6, + SignedInt8 = 7, + SignedInt16 = 8, + SignedInt32 = 9, + UnsignedInt8 = 10, + UnsignedInt16 = 11, + UnsignedInt32 = 12, + HalfFloat = 13, + Float = 14, + UnormInt24 = 15, + UnormInt101010_2 = 16, + UnormInt10X6EXT = 17, + UnsignedIntRaw10EXT = 19, + UnsignedIntRaw12EXT = 20, + UnormInt2_101010EXT = 21, + UnsignedInt10X6EXT = 22, + UnsignedInt12X4EXT = 23, + UnsignedInt14X2EXT = 24, + UnormInt12X4EXT = 25, + UnormInt14X2EXT = 26, +}; +enum class FPRoundingMode { + RTE = 0, + RTZ = 1, + RTP = 2, + RTN = 3, +}; +enum class FPDenormMode { + Preserve = 0, + FlushToZero = 1, +}; +enum class QuantizationModes { + TRN = 0, + TRN_ZERO = 1, + RND = 2, + RND_ZERO = 3, + RND_INF = 4, + RND_MIN_INF = 5, + RND_CONV = 6, + RND_CONV_ODD = 7, +}; +enum class FPOperationMode { + IEEE = 0, + ALT = 1, +}; +enum class OverflowModes { + WRAP = 0, + SAT = 1, + SAT_ZERO = 2, + SAT_SYM = 3, +}; +enum class LinkageType { + Export = 0, + Import = 1, + LinkOnceODR = 2, +}; +enum class AccessQualifier { + ReadOnly = 0, + WriteOnly = 1, + ReadWrite = 2, +}; +enum class HostAccessQualifier { + NoneINTEL = 0, + ReadINTEL = 1, + WriteINTEL = 2, + ReadWriteINTEL = 3, +}; +enum class FunctionParameterAttribute { + Zext = 0, + Sext = 1, + ByVal = 2, + Sret = 3, + NoAlias = 4, + NoCapture = 5, + NoWrite = 6, + NoReadWrite = 7, + RuntimeAlignedINTEL = 5940, +}; +enum class Decoration { + RelaxedPrecision = 0, + SpecId = 1, + Block = 2, + BufferBlock = 3, + RowMajor = 4, + ColMajor = 5, + ArrayStride = 6, + MatrixStride = 7, + GLSLShared = 8, + GLSLPacked = 9, + CPacked = 10, + BuiltIn = 11, + NoPerspective = 13, + Flat = 14, + Patch = 15, + Centroid = 16, + Sample = 17, + Invariant = 18, + Restrict = 19, + Aliased = 20, + Volatile = 21, + Constant = 22, + Coherent = 23, + NonWritable = 24, + NonReadable = 25, + Uniform = 26, + UniformId = 27, + SaturatedConversion = 28, + Stream = 29, + Location = 30, + Component = 31, + Index = 32, + Binding = 33, + DescriptorSet = 34, + Offset = 35, + XfbBuffer = 36, + XfbStride = 37, + FuncParamAttr = 38, + FPRoundingMode = 39, + FPFastMathMode = 40, + LinkageAttributes = 41, + NoContraction = 42, + InputAttachmentIndex = 43, + Alignment = 44, + MaxByteOffset = 45, + AlignmentId = 46, + MaxByteOffsetId = 47, + SaturatedToLargestFloat8NormalConversionEXT = 4216, + NoSignedWrap = 4469, + NoUnsignedWrap = 4470, + WeightTextureQCOM = 4487, + BlockMatchTextureQCOM = 4488, + BlockMatchSamplerQCOM = 4499, + ExplicitInterpAMD = 4999, + NodeSharesPayloadLimitsWithAMDX = 5019, + NodeMaxPayloadsAMDX = 5020, + TrackFinishWritingAMDX = 5078, + PayloadNodeNameAMDX = 5091, + PayloadNodeBaseIndexAMDX = 5098, + PayloadNodeSparseArrayAMDX = 5099, + PayloadNodeArraySizeAMDX = 5100, + PayloadDispatchIndirectAMDX = 5105, + OverrideCoverageNV = 5248, + PassthroughNV = 5250, + ViewportRelativeNV = 5252, + SecondaryViewportRelativeNV = 5256, + PerPrimitiveEXT = 5271, + PerViewNV = 5272, + PerTaskNV = 5273, + PerVertexKHR = 5285, + NonUniform = 5300, + RestrictPointer = 5355, + AliasedPointer = 5356, + HitObjectShaderRecordBufferNV = 5386, + BindlessSamplerNV = 5398, + BindlessImageNV = 5399, + BoundSamplerNV = 5400, + BoundImageNV = 5401, + SIMTCallINTEL = 5599, + ReferencedIndirectlyINTEL = 5602, + ClobberINTEL = 5607, + SideEffectsINTEL = 5608, + VectorComputeVariableINTEL = 5624, + FuncParamIOKindINTEL = 5625, + VectorComputeFunctionINTEL = 5626, + StackCallINTEL = 5627, + GlobalVariableOffsetINTEL = 5628, + CounterBuffer = 5634, + UserSemantic = 5635, + UserTypeGOOGLE = 5636, + FunctionRoundingModeINTEL = 5822, + FunctionDenormModeINTEL = 5823, + RegisterINTEL = 5825, + MemoryINTEL = 5826, + NumbanksINTEL = 5827, + BankwidthINTEL = 5828, + MaxPrivateCopiesINTEL = 5829, + SinglepumpINTEL = 5830, + DoublepumpINTEL = 5831, + MaxReplicatesINTEL = 5832, + SimpleDualPortINTEL = 5833, + MergeINTEL = 5834, + BankBitsINTEL = 5835, + ForcePow2DepthINTEL = 5836, + StridesizeINTEL = 5883, + WordsizeINTEL = 5884, + TrueDualPortINTEL = 5885, + BurstCoalesceINTEL = 5899, + CacheSizeINTEL = 5900, + DontStaticallyCoalesceINTEL = 5901, + PrefetchINTEL = 5902, + StallEnableINTEL = 5905, + FuseLoopsInFunctionINTEL = 5907, + MathOpDSPModeINTEL = 5909, + AliasScopeINTEL = 5914, + NoAliasINTEL = 5915, + InitiationIntervalINTEL = 5917, + MaxConcurrencyINTEL = 5918, + PipelineEnableINTEL = 5919, + BufferLocationINTEL = 5921, + IOPipeStorageINTEL = 5944, + FunctionFloatingPointModeINTEL = 6080, + SingleElementVectorINTEL = 6085, + VectorComputeCallableFunctionINTEL = 6087, + MediaBlockIOINTEL = 6140, + StallFreeINTEL = 6151, + FPMaxErrorDecorationINTEL = 6170, + LatencyControlLabelINTEL = 6172, + LatencyControlConstraintINTEL = 6173, + ConduitKernelArgumentINTEL = 6175, + RegisterMapKernelArgumentINTEL = 6176, + MMHostInterfaceAddressWidthINTEL = 6177, + MMHostInterfaceDataWidthINTEL = 6178, + MMHostInterfaceLatencyINTEL = 6179, + MMHostInterfaceReadWriteModeINTEL = 6180, + MMHostInterfaceMaxBurstINTEL = 6181, + MMHostInterfaceWaitRequestINTEL = 6182, + StableKernelArgumentINTEL = 6183, + HostAccessINTEL = 6188, + InitModeINTEL = 6190, + ImplementInRegisterMapINTEL = 6191, + CacheControlLoadINTEL = 6442, + CacheControlStoreINTEL = 6443, +}; +enum class BuiltIn { + Position = 0, + PointSize = 1, + ClipDistance = 3, + CullDistance = 4, + VertexId = 5, + InstanceId = 6, + PrimitiveId = 7, + InvocationId = 8, + Layer = 9, + ViewportIndex = 10, + TessLevelOuter = 11, + TessLevelInner = 12, + TessCoord = 13, + PatchVertices = 14, + FragCoord = 15, + PointCoord = 16, + FrontFacing = 17, + SampleId = 18, + SamplePosition = 19, + SampleMask = 20, + FragDepth = 22, + HelperInvocation = 23, + NumWorkgroups = 24, + WorkgroupSize = 25, + WorkgroupId = 26, + LocalInvocationId = 27, + GlobalInvocationId = 28, + LocalInvocationIndex = 29, + WorkDim = 30, + GlobalSize = 31, + EnqueuedWorkgroupSize = 32, + GlobalOffset = 33, + GlobalLinearId = 34, + SubgroupSize = 36, + SubgroupMaxSize = 37, + NumSubgroups = 38, + NumEnqueuedSubgroups = 39, + SubgroupId = 40, + SubgroupLocalInvocationId = 41, + VertexIndex = 42, + InstanceIndex = 43, + CoreIDARM = 4160, + CoreCountARM = 4161, + CoreMaxIDARM = 4162, + WarpIDARM = 4163, + WarpMaxIDARM = 4164, + SubgroupEqMask = 4416, + SubgroupGeMask = 4417, + SubgroupGtMask = 4418, + SubgroupLeMask = 4419, + SubgroupLtMask = 4420, + BaseVertex = 4424, + BaseInstance = 4425, + DrawIndex = 4426, + PrimitiveShadingRateKHR = 4432, + DeviceIndex = 4438, + ViewIndex = 4440, + ShadingRateKHR = 4444, + TileOffsetQCOM = 4492, + TileDimensionQCOM = 4493, + TileApronSizeQCOM = 4494, + BaryCoordNoPerspAMD = 4992, + BaryCoordNoPerspCentroidAMD = 4993, + BaryCoordNoPerspSampleAMD = 4994, + BaryCoordSmoothAMD = 4995, + BaryCoordSmoothCentroidAMD = 4996, + BaryCoordSmoothSampleAMD = 4997, + BaryCoordPullModelAMD = 4998, + FragStencilRefEXT = 5014, + RemainingRecursionLevelsAMDX = 5021, + ShaderIndexAMDX = 5073, + ViewportMaskNV = 5253, + SecondaryPositionNV = 5257, + SecondaryViewportMaskNV = 5258, + PositionPerViewNV = 5261, + ViewportMaskPerViewNV = 5262, + FullyCoveredEXT = 5264, + TaskCountNV = 5274, + PrimitiveCountNV = 5275, + PrimitiveIndicesNV = 5276, + ClipDistancePerViewNV = 5277, + CullDistancePerViewNV = 5278, + LayerPerViewNV = 5279, + MeshViewCountNV = 5280, + MeshViewIndicesNV = 5281, + BaryCoordKHR = 5286, + BaryCoordNoPerspKHR = 5287, + FragSizeEXT = 5292, + FragInvocationCountEXT = 5293, + PrimitivePointIndicesEXT = 5294, + PrimitiveLineIndicesEXT = 5295, + PrimitiveTriangleIndicesEXT = 5296, + CullPrimitiveEXT = 5299, + LaunchIdKHR = 5319, + LaunchSizeKHR = 5320, + WorldRayOriginKHR = 5321, + WorldRayDirectionKHR = 5322, + ObjectRayOriginKHR = 5323, + ObjectRayDirectionKHR = 5324, + RayTminKHR = 5325, + RayTmaxKHR = 5326, + InstanceCustomIndexKHR = 5327, + ObjectToWorldKHR = 5330, + WorldToObjectKHR = 5331, + HitTNV = 5332, + HitKindKHR = 5333, + CurrentRayTimeNV = 5334, + HitTriangleVertexPositionsKHR = 5335, + HitMicroTriangleVertexPositionsNV = 5337, + HitMicroTriangleVertexBarycentricsNV = 5344, + IncomingRayFlagsKHR = 5351, + RayGeometryIndexKHR = 5352, + HitIsSphereNV = 5359, + HitIsLSSNV = 5360, + HitSpherePositionNV = 5361, + WarpsPerSMNV = 5374, + SMCountNV = 5375, + WarpIDNV = 5376, + SMIDNV = 5377, + HitLSSPositionsNV = 5396, + HitKindFrontFacingMicroTriangleNV = 5405, + HitKindBackFacingMicroTriangleNV = 5406, + HitSphereRadiusNV = 5420, + HitLSSRadiiNV = 5421, + ClusterIDNV = 5436, + CullMaskKHR = 6021, +}; +enum class Scope { + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + QueueFamily = 5, + ShaderCallKHR = 6, +}; +enum class GroupOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, + ClusteredReduce = 3, + PartitionedReduceNV = 6, + PartitionedInclusiveScanNV = 7, + PartitionedExclusiveScanNV = 8, +}; +enum class KernelEnqueueFlags { + NoWait = 0, + WaitKernel = 1, + WaitWorkGroup = 2, +}; +enum class Capability { + Matrix = 0, + Shader = 1, + Geometry = 2, + Tessellation = 3, + Addresses = 4, + Linkage = 5, + Kernel = 6, + Vector16 = 7, + Float16Buffer = 8, + Float16 = 9, + Float64 = 10, + Int64 = 11, + Int64Atomics = 12, + ImageBasic = 13, + ImageReadWrite = 14, + ImageMipmap = 15, + Pipes = 17, + Groups = 18, + DeviceEnqueue = 19, + LiteralSampler = 20, + AtomicStorage = 21, + Int16 = 22, + TessellationPointSize = 23, + GeometryPointSize = 24, + ImageGatherExtended = 25, + StorageImageMultisample = 27, + UniformBufferArrayDynamicIndexing = 28, + SampledImageArrayDynamicIndexing = 29, + StorageBufferArrayDynamicIndexing = 30, + StorageImageArrayDynamicIndexing = 31, + ClipDistance = 32, + CullDistance = 33, + ImageCubeArray = 34, + SampleRateShading = 35, + ImageRect = 36, + SampledRect = 37, + GenericPointer = 38, + Int8 = 39, + InputAttachment = 40, + SparseResidency = 41, + MinLod = 42, + Sampled1D = 43, + Image1D = 44, + SampledCubeArray = 45, + SampledBuffer = 46, + ImageBuffer = 47, + ImageMSArray = 48, + StorageImageExtendedFormats = 49, + ImageQuery = 50, + DerivativeControl = 51, + InterpolationFunction = 52, + TransformFeedback = 53, + GeometryStreams = 54, + StorageImageReadWithoutFormat = 55, + StorageImageWriteWithoutFormat = 56, + MultiViewport = 57, + SubgroupDispatch = 58, + NamedBarrier = 59, + PipeStorage = 60, + GroupNonUniform = 61, + GroupNonUniformVote = 62, + GroupNonUniformArithmetic = 63, + GroupNonUniformBallot = 64, + GroupNonUniformShuffle = 65, + GroupNonUniformShuffleRelative = 66, + GroupNonUniformClustered = 67, + GroupNonUniformQuad = 68, + ShaderLayer = 69, + ShaderViewportIndex = 70, + UniformDecoration = 71, + CoreBuiltinsARM = 4165, + TileImageColorReadAccessEXT = 4166, + TileImageDepthReadAccessEXT = 4167, + TileImageStencilReadAccessEXT = 4168, + TensorsARM = 4174, + StorageTensorArrayDynamicIndexingARM = 4175, + StorageTensorArrayNonUniformIndexingARM = 4176, + CooperativeMatrixLayoutsARM = 4201, + Float8EXT = 4212, + Float8CooperativeMatrixEXT = 4213, + FragmentShadingRateKHR = 4422, + SubgroupBallotKHR = 4423, + DrawParameters = 4427, + WorkgroupMemoryExplicitLayoutKHR = 4428, + WorkgroupMemoryExplicitLayout8BitAccessKHR = 4429, + WorkgroupMemoryExplicitLayout16BitAccessKHR = 4430, + SubgroupVoteKHR = 4431, + StorageBuffer16BitAccess = 4433, + UniformAndStorageBuffer16BitAccess = 4434, + StoragePushConstant16 = 4435, + StorageInputOutput16 = 4436, + DeviceGroup = 4437, + MultiView = 4439, + VariablePointersStorageBuffer = 4441, + VariablePointers = 4442, + AtomicStorageOps = 4445, + SampleMaskPostDepthCoverage = 4447, + StorageBuffer8BitAccess = 4448, + UniformAndStorageBuffer8BitAccess = 4449, + StoragePushConstant8 = 4450, + DenormPreserve = 4464, + DenormFlushToZero = 4465, + SignedZeroInfNanPreserve = 4466, + RoundingModeRTE = 4467, + RoundingModeRTZ = 4468, + RayQueryProvisionalKHR = 4471, + RayQueryKHR = 4472, + UntypedPointersKHR = 4473, + RayTraversalPrimitiveCullingKHR = 4478, + RayTracingKHR = 4479, + TextureSampleWeightedQCOM = 4484, + TextureBoxFilterQCOM = 4485, + TextureBlockMatchQCOM = 4486, + TileShadingQCOM = 4495, + TextureBlockMatch2QCOM = 4498, + Float16ImageAMD = 5008, + ImageGatherBiasLodAMD = 5009, + FragmentMaskAMD = 5010, + StencilExportEXT = 5013, + ImageReadWriteLodAMD = 5015, + Int64ImageEXT = 5016, + ShaderClockKHR = 5055, + ShaderEnqueueAMDX = 5067, + QuadControlKHR = 5087, + Int4TypeINTEL = 5112, + Int4CooperativeMatrixINTEL = 5114, + BFloat16TypeKHR = 5116, + BFloat16DotProductKHR = 5117, + BFloat16CooperativeMatrixKHR = 5118, + SampleMaskOverrideCoverageNV = 5249, + GeometryShaderPassthroughNV = 5251, + ShaderViewportIndexLayerEXT = 5254, + ShaderViewportMaskNV = 5255, + ShaderStereoViewNV = 5259, + PerViewAttributesNV = 5260, + FragmentFullyCoveredEXT = 5265, + MeshShadingNV = 5266, + ImageFootprintNV = 5282, + MeshShadingEXT = 5283, + FragmentBarycentricKHR = 5284, + ComputeDerivativeGroupQuadsKHR = 5288, + FragmentDensityEXT = 5291, + GroupNonUniformPartitionedNV = 5297, + ShaderNonUniform = 5301, + RuntimeDescriptorArray = 5302, + InputAttachmentArrayDynamicIndexing = 5303, + UniformTexelBufferArrayDynamicIndexing = 5304, + StorageTexelBufferArrayDynamicIndexing = 5305, + UniformBufferArrayNonUniformIndexing = 5306, + SampledImageArrayNonUniformIndexing = 5307, + StorageBufferArrayNonUniformIndexing = 5308, + StorageImageArrayNonUniformIndexing = 5309, + InputAttachmentArrayNonUniformIndexing = 5310, + UniformTexelBufferArrayNonUniformIndexing = 5311, + StorageTexelBufferArrayNonUniformIndexing = 5312, + RayTracingPositionFetchKHR = 5336, + RayTracingNV = 5340, + RayTracingMotionBlurNV = 5341, + VulkanMemoryModel = 5345, + VulkanMemoryModelDeviceScope = 5346, + PhysicalStorageBufferAddresses = 5347, + ComputeDerivativeGroupLinearKHR = 5350, + RayTracingProvisionalKHR = 5353, + CooperativeMatrixNV = 5357, + FragmentShaderSampleInterlockEXT = 5363, + FragmentShaderShadingRateInterlockEXT = 5372, + ShaderSMBuiltinsNV = 5373, + FragmentShaderPixelInterlockEXT = 5378, + DemoteToHelperInvocation = 5379, + DisplacementMicromapNV = 5380, + RayTracingOpacityMicromapEXT = 5381, + ShaderInvocationReorderNV = 5383, + BindlessTextureNV = 5390, + RayQueryPositionFetchKHR = 5391, + CooperativeVectorNV = 5394, + AtomicFloat16VectorNV = 5404, + RayTracingDisplacementMicromapNV = 5409, + RawAccessChainsNV = 5414, + RayTracingSpheresGeometryNV = 5418, + RayTracingLinearSweptSpheresGeometryNV = 5419, + CooperativeMatrixReductionsNV = 5430, + CooperativeMatrixConversionsNV = 5431, + CooperativeMatrixPerElementOperationsNV = 5432, + CooperativeMatrixTensorAddressingNV = 5433, + CooperativeMatrixBlockLoadsNV = 5434, + CooperativeVectorTrainingNV = 5435, + RayTracingClusterAccelerationStructureNV = 5437, + TensorAddressingNV = 5439, + SubgroupShuffleINTEL = 5568, + SubgroupBufferBlockIOINTEL = 5569, + SubgroupImageBlockIOINTEL = 5570, + SubgroupImageMediaBlockIOINTEL = 5579, + RoundToInfinityINTEL = 5582, + FloatingPointModeINTEL = 5583, + IntegerFunctions2INTEL = 5584, + FunctionPointersINTEL = 5603, + IndirectReferencesINTEL = 5604, + AsmINTEL = 5606, + AtomicFloat32MinMaxEXT = 5612, + AtomicFloat64MinMaxEXT = 5613, + AtomicFloat16MinMaxEXT = 5616, + VectorComputeINTEL = 5617, + VectorAnyINTEL = 5619, + ExpectAssumeKHR = 5629, + SubgroupAvcMotionEstimationINTEL = 5696, + SubgroupAvcMotionEstimationIntraINTEL = 5697, + SubgroupAvcMotionEstimationChromaINTEL = 5698, + VariableLengthArrayINTEL = 5817, + FunctionFloatControlINTEL = 5821, + FPGAMemoryAttributesINTEL = 5824, + FPFastMathModeINTEL = 5837, + ArbitraryPrecisionIntegersINTEL = 5844, + ArbitraryPrecisionFloatingPointINTEL = 5845, + UnstructuredLoopControlsINTEL = 5886, + FPGALoopControlsINTEL = 5888, + KernelAttributesINTEL = 5892, + FPGAKernelAttributesINTEL = 5897, + FPGAMemoryAccessesINTEL = 5898, + FPGAClusterAttributesINTEL = 5904, + LoopFuseINTEL = 5906, + FPGADSPControlINTEL = 5908, + MemoryAccessAliasingINTEL = 5910, + FPGAInvocationPipeliningAttributesINTEL = 5916, + FPGABufferLocationINTEL = 5920, + ArbitraryPrecisionFixedPointINTEL = 5922, + USMStorageClassesINTEL = 5935, + RuntimeAlignedAttributeINTEL = 5939, + IOPipesINTEL = 5943, + BlockingPipesINTEL = 5945, + FPGARegINTEL = 5948, + DotProductInputAll = 6016, + DotProductInput4x8Bit = 6017, + DotProductInput4x8BitPacked = 6018, + DotProduct = 6019, + RayCullMaskKHR = 6020, + CooperativeMatrixKHR = 6022, + ReplicatedCompositesEXT = 6024, + BitInstructions = 6025, + GroupNonUniformRotateKHR = 6026, + FloatControls2 = 6029, + AtomicFloat32AddEXT = 6033, + AtomicFloat64AddEXT = 6034, + LongCompositesINTEL = 6089, + OptNoneEXT = 6094, + AtomicFloat16AddEXT = 6095, + DebugInfoModuleINTEL = 6114, + BFloat16ConversionINTEL = 6115, + SplitBarrierINTEL = 6141, + ArithmeticFenceEXT = 6144, + FPGAClusterAttributesV2INTEL = 6150, + FPGAKernelAttributesv2INTEL = 6161, + TaskSequenceINTEL = 6162, + FPMaxErrorINTEL = 6169, + FPGALatencyControlINTEL = 6171, + FPGAArgumentInterfacesINTEL = 6174, + GlobalVariableHostAccessINTEL = 6187, + GlobalVariableFPGADecorationsINTEL = 6189, + SubgroupBufferPrefetchINTEL = 6220, + Subgroup2DBlockIOINTEL = 6228, + Subgroup2DBlockTransformINTEL = 6229, + Subgroup2DBlockTransposeINTEL = 6230, + SubgroupMatrixMultiplyAccumulateINTEL = 6236, + TernaryBitwiseFunctionINTEL = 6241, + GroupUniformArithmeticKHR = 6400, + TensorFloat32RoundingINTEL = 6425, + MaskedGatherScatterINTEL = 6427, + CacheControlsINTEL = 6441, + RegisterLimitsINTEL = 6460, + BindlessImagesINTEL = 6528, + PackedCooperativeMatrixINTEL = 6434, + CooperativeMatrixInvocationInstructionsINTEL = 6435, + CooperativeMatrixTF32ComponentTypeINTEL = 6436, + CooperativeMatrixBFloat16ComponentTypeINTEL = 6437, + CooperativeMatrixCheckedInstructionsINTEL = 6192, + CooperativeMatrixPrefetchINTEL = 6449, +}; +enum class RayQueryIntersection { + RayQueryCandidateIntersectionKHR = 0, + RayQueryCommittedIntersectionKHR = 1, +}; +enum class RayQueryCommittedIntersectionType { + RayQueryCommittedIntersectionNoneKHR = 0, + RayQueryCommittedIntersectionTriangleKHR = 1, + RayQueryCommittedIntersectionGeneratedKHR = 2, +}; +enum class RayQueryCandidateIntersectionType { + RayQueryCandidateIntersectionTriangleKHR = 0, + RayQueryCandidateIntersectionAABBKHR = 1, +}; +enum class PackedVectorFormat { + PackedVectorFormat4x8Bit = 0, +}; +enum class CooperativeMatrixOperands { + NoneKHR = 0x0000, + MatrixASignedComponentsKHR = 0x0001, + MatrixBSignedComponentsKHR = 0x0002, + MatrixCSignedComponentsKHR = 0x0004, + MatrixResultSignedComponentsKHR = 0x0008, + SaturatingAccumulationKHR = 0x0010, +}; +enum class CooperativeMatrixLayout { + RowMajorKHR = 0, + ColumnMajorKHR = 1, + RowBlockedInterleavedARM = 4202, + ColumnBlockedInterleavedARM = 4203, +}; +enum class CooperativeMatrixUse { + MatrixAKHR = 0, + MatrixBKHR = 1, + MatrixAccumulatorKHR = 2, +}; +enum class CooperativeMatrixReduce { + Row = 0x0001, + Column = 0x0002, + CooperativeMatrixReduce2x2 = 0x0004, +}; +enum class TensorClampMode { + Undefined = 0, + Constant = 1, + ClampToEdge = 2, + Repeat = 3, + RepeatMirrored = 4, +}; +enum class TensorAddressingOperands { + None = 0x0000, + TensorView = 0x0001, + DecodeFunc = 0x0002, +}; +enum class InitializationModeQualifier { + InitOnDeviceReprogramINTEL = 0, + InitOnDeviceResetINTEL = 1, +}; +enum class LoadCacheControl { + UncachedINTEL = 0, + CachedINTEL = 1, + StreamingINTEL = 2, + InvalidateAfterReadINTEL = 3, + ConstCachedINTEL = 4, +}; +enum class StoreCacheControl { + UncachedINTEL = 0, + WriteThroughINTEL = 1, + WriteBackINTEL = 2, + StreamingINTEL = 3, +}; +enum class NamedMaximumNumberOfRegisters { + AutoINTEL = 0, +}; +enum class MatrixMultiplyAccumulateOperands { + None = 0x0, + MatrixASignedComponentsINTEL = 0x1, + MatrixBSignedComponentsINTEL = 0x2, + MatrixCBFloat16INTEL = 0x4, + MatrixResultBFloat16INTEL = 0x8, + MatrixAPackedInt8INTEL = 0x10, + MatrixBPackedInt8INTEL = 0x20, + MatrixAPackedInt4INTEL = 0x40, + MatrixBPackedInt4INTEL = 0x80, + MatrixATF32INTEL = 0x100, + MatrixBTF32INTEL = 0x200, + MatrixAPackedFloat16INTEL = 0x400, + MatrixBPackedFloat16INTEL = 0x800, + MatrixAPackedBFloat16INTEL = 0x1000, + MatrixBPackedBFloat16INTEL = 0x2000, +}; +enum class FPEncoding { + BFloat16KHR = 0, + Float8E4M3EXT = 4214, + Float8E5M2EXT = 4215, +}; +enum class CooperativeVectorMatrixLayout { + RowMajorNV = 0, + ColumnMajorNV = 1, + InferencingOptimalNV = 2, + TrainingOptimalNV = 3, +}; +enum class ComponentType { + Float16NV = 0, + Float32NV = 1, + Float64NV = 2, + SignedInt8NV = 3, + SignedInt16NV = 4, + SignedInt32NV = 5, + SignedInt64NV = 6, + UnsignedInt8NV = 7, + UnsignedInt16NV = 8, + UnsignedInt32NV = 9, + UnsignedInt64NV = 10, + SignedInt8PackedNV = 1000491000, + UnsignedInt8PackedNV = 1000491001, + FloatE4M3NV = 1000491002, + FloatE5M2NV = 1000491003, +}; +enum class TensorOperands { + NoneARM = 0x0000, + NontemporalARM = 0x0001, + OutOfBoundsValueARM = 0x0002, + MakeElementAvailableARM = 0x0004, + MakeElementVisibleARM = 0x0008, + NonPrivateElementARM = 0x0010, +}; + +} // namespace tinytc::spv + +#endif // GENERATED_ENUMS_20250605_HPP diff --git a/src/spv/inst_assembler.cpp b/src/spv/inst_assembler.cpp new file mode 100644 index 00000000..ac09f3a5 --- /dev/null +++ b/src/spv/inst_assembler.cpp @@ -0,0 +1,68 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/inst_assembler.hpp" + +#include +#include +#include + +namespace tinytc::spv { + +enum class LinkageType; + +inst_assembler::inst_assembler(word_stream &stream) : stream_{&stream} {} + +void inst_assembler::operator()(DecorationAttr const &da) { + std::visit(overloaded{[&](auto const &a) { this->operator()(a); }, + [&](std::pair const &a) { + *stream_ << a.first; + this->operator()(a.second); + }}, + da); +} +void inst_assembler::operator()(ExecutionModeAttr const &ea) { + std::visit(overloaded{[&](std::int32_t const &a) { *stream_ << a; }, + [&](std::array const &a) { + for (auto const &s : a) { + *stream_ << s; + } + }}, + ea); +} +void inst_assembler::operator()(LiteralContextDependentNumber const &l) { + std::visit(overloaded{[&](auto const &l) { *stream_ << l; }}, l); +} +void inst_assembler::operator()(LiteralInteger const &l) { *stream_ << l; } +void inst_assembler::operator()(LiteralString const &l) { *stream_ << l; } + +void inst_assembler::operator()(PairIdRefIdRef const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void inst_assembler::operator()(PairIdRefLiteralInteger const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void inst_assembler::operator()(PairLiteralIntegerIdRef const &p) { + std::visit(overloaded{[&](auto const &l) { *stream_ << l; }}, p.first); + this->operator()(p.second); +} + +void inst_assembler::pre_visit(spv_inst const &) { + *stream_ << std::int32_t{0}; + last_opcode_pos_ = stream_->tell(); +} + +void inst_assembler::visit_result(spv_inst const &in) { *stream_ << in.id(); } + +void inst_assembler::post_visit(spv_inst const &in) { + const std::int32_t word_count = stream_->tell() - last_opcode_pos_ + 1; + const auto ophead = (word_count << 16) | static_cast(in.opcode()); + stream_->update(last_opcode_pos_, ophead); +} + +void inst_assembler::operator()(spv_inst *const &in) { *stream_ << in->id(); } + +} // namespace tinytc::spv + diff --git a/src/spv/inst_assembler.hpp b/src/spv/inst_assembler.hpp new file mode 100644 index 00000000..f86149a5 --- /dev/null +++ b/src/spv/inst_assembler.hpp @@ -0,0 +1,96 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INST_ASSEMBLER_20241111_HPP +#define INST_ASSEMBLER_20241111_HPP + +#include "spv/defs.hpp" +#include "spv/visit.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +enum class BuiltIn; +enum class LinkageType; + +template class word_stream { + public: + word_stream(std::vector &vec) : vec_{&vec} {} + + template auto operator<<(T const &t) -> word_stream & { + const std::size_t insert_pos = vec_->size() / sizeof(WordT); + vec_->resize(vec_->size() + word_count(t) * sizeof(WordT)); + update(insert_pos, t); + return *this; + } + + template static auto word_count(T const &) -> std::size_t { + return 1 + (sizeof(T) - 1) / sizeof(WordT); // ceil(sizeof(T)/sizeof(WordT)) + } + + static auto word_count(std::string const &s) -> std::size_t { + return 1 + s.size() / sizeof(WordT); // ceil((s.size()+1)/sizeof(WordT)) + } + + template auto update(std::size_t word, T const &t) -> word_stream & { + const std::size_t addr = word * sizeof(WordT); + std::memcpy(vec_->data() + addr, &t, sizeof(T)); + return *this; + } + + auto update(std::size_t word, std::string const &s) -> word_stream & { + const std::size_t addr = word * sizeof(WordT); + std::memcpy(vec_->data() + addr, s.c_str(), s.size() + 1); + return *this; + } + + //! Returns last word position + auto tell() const -> std::size_t { + const auto size = vec_->size(); + return size > 0 ? size / sizeof(WordT) - 1 : 0; + } + + private: + std::vector *vec_; +}; + +class inst_assembler : public default_visitor { + public: + using default_visitor::operator(); + + inst_assembler(word_stream &stream); + + void pre_visit(spv_inst const &in); + void visit_result(spv_inst const &in); + void post_visit(spv_inst const &in); + + template + requires std::is_enum_v + void operator()(T const &t) { + *stream_ << static_cast(t); + } + void operator()(DecorationAttr const &da); + void operator()(ExecutionModeAttr const &ea); + void operator()(LiteralContextDependentNumber const &l); + void operator()(LiteralInteger const &l); + void operator()(LiteralString const &l); + + void operator()(PairIdRefIdRef const &p); + void operator()(PairIdRefLiteralInteger const &p); + void operator()(PairLiteralIntegerIdRef const &p); + + void operator()(spv_inst *const &in); + + private: + word_stream *stream_; + std::size_t last_opcode_pos_ = 0; +}; + +} // namespace tinytc::spv + +#endif // INST_ASSEMBLER_20241111_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp new file mode 100644 index 00000000..09c20b7e --- /dev/null +++ b/src/spv/instructions.hpp @@ -0,0 +1,7008 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_INSTRUCTIONS_20250605_HPP +#define GENERATED_INSTRUCTIONS_20250605_HPP + +#include "defs.hpp" +#include "enums.hpp" +#include "error.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +class OpNop : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Nop; } + OpNop() : spv_inst{Op::Nop, false} {} + + private: +}; +class OpUndef : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Undef; } + OpUndef(IdResultType type) : spv_inst{Op::Undef, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpSourceContinued : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceContinued; } + OpSourceContinued(LiteralString op0) + : spv_inst{Op::SourceContinued, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpSource : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Source; } + OpSource(SourceLanguage op0, LiteralInteger op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt) + : spv_inst{Op::Source, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> SourceLanguage & { return op0_; } + inline auto op0() const -> SourceLanguage const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + SourceLanguage op0_; + LiteralInteger op1_; + std::optional op2_; + std::optional op3_; +}; +class OpSourceExtension : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceExtension; } + OpSourceExtension(LiteralString op0) + : spv_inst{Op::SourceExtension, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpName : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Name; } + OpName(IdRef op0, LiteralString op1) + : spv_inst{Op::Name, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralString & { return op1_; } + inline auto op1() const -> LiteralString const & { return op1_; } + + private: + IdRef op0_; + LiteralString op1_; +}; +class OpMemberName : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemberName; } + OpMemberName(IdRef op0, LiteralInteger op1, LiteralString op2) + : spv_inst{Op::MemberName, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } + inline auto op2() const -> LiteralString const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + LiteralString op2_; +}; +class OpString : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::String; } + OpString(LiteralString op0) : spv_inst{Op::String, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpLine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Line; } + OpLine(IdRef op0, LiteralInteger op1, LiteralInteger op2) + : spv_inst{Op::Line, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + LiteralInteger op2_; +}; +class OpExtension : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Extension; } + OpExtension(LiteralString op0) : spv_inst{Op::Extension, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExtInstImport : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExtInstImport; } + OpExtInstImport(LiteralString op0) : spv_inst{Op::ExtInstImport, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExtInst : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExtInst; } + OpExtInst(IdResultType type, IdRef op0, LiteralExtInstInteger op1, std::vector op2) + : spv_inst{Op::ExtInst, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralExtInstInteger & { return op1_; } + inline auto op1() const -> LiteralExtInstInteger const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + LiteralExtInstInteger op1_; + std::vector op2_; +}; +class OpMemoryModel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryModel; } + OpMemoryModel(AddressingModel op0, MemoryModel op1) + : spv_inst{Op::MemoryModel, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> AddressingModel & { return op0_; } + inline auto op0() const -> AddressingModel const & { return op0_; } + inline auto op1() -> MemoryModel & { return op1_; } + inline auto op1() const -> MemoryModel const & { return op1_; } + + private: + AddressingModel op0_; + MemoryModel op1_; +}; +class OpEntryPoint : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EntryPoint; } + OpEntryPoint(ExecutionModel op0, IdRef op1, LiteralString op2, std::vector op3) + : spv_inst{Op::EntryPoint, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> ExecutionModel & { return op0_; } + inline auto op0() const -> ExecutionModel const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } + inline auto op2() const -> LiteralString const & { return op2_; } + inline auto op3() -> std::vector & { return op3_; } + inline auto op3() const -> std::vector const & { return op3_; } + + private: + ExecutionModel op0_; + IdRef op1_; + LiteralString op2_; + std::vector op3_; +}; +class OpExecutionMode : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExecutionMode; } + OpExecutionMode(IdRef op0, ExecutionMode op1, ExecutionModeAttr op2) + : spv_inst{Op::ExecutionMode, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> ExecutionMode & { return op1_; } + inline auto op1() const -> ExecutionMode const & { return op1_; } + inline auto op2() -> ExecutionModeAttr & { return op2_; } + inline auto op2() const -> ExecutionModeAttr const & { return op2_; } + + private: + IdRef op0_; + ExecutionMode op1_; + ExecutionModeAttr op2_; +}; +class OpCapability : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Capability; } + OpCapability(Capability op0) : spv_inst{Op::Capability, false}, op0_(std::move(op0)) {} + inline auto op0() -> Capability & { return op0_; } + inline auto op0() const -> Capability const & { return op0_; } + + private: + Capability op0_; +}; +class OpTypeVoid : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeVoid; } + OpTypeVoid() : spv_inst{Op::TypeVoid, true} {} + + private: +}; +class OpTypeBool : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeBool; } + OpTypeBool() : spv_inst{Op::TypeBool, true} {} + + private: +}; +class OpTypeInt : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeInt; } + OpTypeInt(LiteralInteger op0, LiteralInteger op1) + : spv_inst{Op::TypeInt, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> LiteralInteger & { return op0_; } + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + LiteralInteger op0_; + LiteralInteger op1_; +}; +class OpTypeFloat : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFloat; } + OpTypeFloat(LiteralInteger op0, std::optional op1 = std::nullopt) + : spv_inst{Op::TypeFloat, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> LiteralInteger & { return op0_; } + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } + inline auto op1() const -> std::optional const & { return op1_; } + + private: + LiteralInteger op0_; + std::optional op1_; +}; +class OpTypeVector : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeVector; } + OpTypeVector(IdRef op0, LiteralInteger op1) + : spv_inst{Op::TypeVector, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpTypeMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpTypeMatrix(IdRef op0, LiteralInteger op1) + : spv_inst{Op::TypeMatrix, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpTypeImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeImage; } + OpTypeImage(IdRef op0, Dim op1, LiteralInteger op2, LiteralInteger op3, LiteralInteger op4, + LiteralInteger op5, ImageFormat op6, + std::optional op7 = std::nullopt) + : spv_inst{Op::TypeImage, true}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)), + op6_(std::move(op6)), op7_(std::move(op7)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Dim & { return op1_; } + inline auto op1() const -> Dim const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + inline auto op3() -> LiteralInteger & { return op3_; } + inline auto op3() const -> LiteralInteger const & { return op3_; } + inline auto op4() -> LiteralInteger & { return op4_; } + inline auto op4() const -> LiteralInteger const & { return op4_; } + inline auto op5() -> LiteralInteger & { return op5_; } + inline auto op5() const -> LiteralInteger const & { return op5_; } + inline auto op6() -> ImageFormat & { return op6_; } + inline auto op6() const -> ImageFormat const & { return op6_; } + inline auto op7() -> std::optional & { return op7_; } + inline auto op7() const -> std::optional const & { return op7_; } + + private: + IdRef op0_; + Dim op1_; + LiteralInteger op2_; + LiteralInteger op3_; + LiteralInteger op4_; + LiteralInteger op5_; + ImageFormat op6_; + std::optional op7_; +}; +class OpTypeSampler : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeSampler; } + OpTypeSampler() : spv_inst{Op::TypeSampler, true} {} + + private: +}; +class OpTypeSampledImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeSampledImage; } + OpTypeSampledImage(IdRef op0) : spv_inst{Op::TypeSampledImage, true}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpTypeArray : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeArray; } + OpTypeArray(IdRef op0, IdRef op1) + : spv_inst{Op::TypeArray, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpTypeRuntimeArray : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeRuntimeArray; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpTypeRuntimeArray(IdRef op0) : spv_inst{Op::TypeRuntimeArray, true}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpTypeStruct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeStruct; } + OpTypeStruct(std::vector op0) : spv_inst{Op::TypeStruct, true}, op0_(std::move(op0)) {} + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + std::vector op0_; +}; +class OpTypeOpaque : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeOpaque; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpTypeOpaque(LiteralString op0) : spv_inst{Op::TypeOpaque, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpTypePointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePointer; } + OpTypePointer(StorageClass op0, IdRef op1) + : spv_inst{Op::TypePointer, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> StorageClass & { return op0_; } + inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + StorageClass op0_; + IdRef op1_; +}; +class OpTypeFunction : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFunction; } + OpTypeFunction(IdRef op0, std::vector op1) + : spv_inst{Op::TypeFunction, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpTypeEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeEvent; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpTypeEvent() : spv_inst{Op::TypeEvent, true} {} + + private: +}; +class OpTypeDeviceEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeDeviceEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpTypeDeviceEvent() : spv_inst{Op::TypeDeviceEvent, true} {} + + private: +}; +class OpTypeReserveId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeReserveId; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpTypeReserveId() : spv_inst{Op::TypeReserveId, true} {} + + private: +}; +class OpTypeQueue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeQueue; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpTypeQueue() : spv_inst{Op::TypeQueue, true} {} + + private: +}; +class OpTypePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpTypePipe(AccessQualifier op0) : spv_inst{Op::TypePipe, true}, op0_(std::move(op0)) {} + inline auto op0() -> AccessQualifier & { return op0_; } + inline auto op0() const -> AccessQualifier const & { return op0_; } + + private: + AccessQualifier op0_; +}; +class OpTypeForwardPointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeForwardPointer; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpTypeForwardPointer(IdRef op0, StorageClass op1) + : spv_inst{Op::TypeForwardPointer, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> StorageClass & { return op1_; } + inline auto op1() const -> StorageClass const & { return op1_; } + + private: + IdRef op0_; + StorageClass op1_; +}; +class OpConstantTrue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantTrue; } + OpConstantTrue(IdResultType type) : spv_inst{Op::ConstantTrue, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpConstantFalse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantFalse; } + OpConstantFalse(IdResultType type) + : spv_inst{Op::ConstantFalse, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpConstant : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Constant; } + OpConstant(IdResultType type, LiteralContextDependentNumber op0) + : spv_inst{Op::Constant, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> LiteralContextDependentNumber & { return op0_; } + inline auto op0() const -> LiteralContextDependentNumber const & { return op0_; } + + private: + IdResultType type_; + LiteralContextDependentNumber op0_; +}; +class OpConstantComposite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantComposite; } + OpConstantComposite(IdResultType type, std::vector op0) + : spv_inst{Op::ConstantComposite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpConstantSampler : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantSampler; } + constexpr static std::array required_capabilities = {Capability::LiteralSampler}; + OpConstantSampler(IdResultType type, SamplerAddressingMode op0, LiteralInteger op1, + SamplerFilterMode op2) + : spv_inst{Op::ConstantSampler, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> SamplerAddressingMode & { return op0_; } + inline auto op0() const -> SamplerAddressingMode const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> SamplerFilterMode & { return op2_; } + inline auto op2() const -> SamplerFilterMode const & { return op2_; } + + private: + IdResultType type_; + SamplerAddressingMode op0_; + LiteralInteger op1_; + SamplerFilterMode op2_; +}; +class OpConstantNull : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantNull; } + OpConstantNull(IdResultType type) : spv_inst{Op::ConstantNull, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpFunction : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Function; } + OpFunction(IdResultType type, FunctionControl op0, IdRef op1) + : spv_inst{Op::Function, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> FunctionControl & { return op0_; } + inline auto op0() const -> FunctionControl const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + FunctionControl op0_; + IdRef op1_; +}; +class OpFunctionParameter : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionParameter; } + OpFunctionParameter(IdResultType type) + : spv_inst{Op::FunctionParameter, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpFunctionEnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionEnd; } + OpFunctionEnd() : spv_inst{Op::FunctionEnd, false} {} + + private: +}; +class OpFunctionCall : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionCall; } + OpFunctionCall(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::FunctionCall, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpVariable : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Variable; } + OpVariable(IdResultType type, StorageClass op0, std::optional op1 = std::nullopt) + : spv_inst{Op::Variable, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> StorageClass & { return op0_; } + inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } + inline auto op1() const -> std::optional const & { return op1_; } + + private: + IdResultType type_; + StorageClass op0_; + std::optional op1_; +}; +class OpImageTexelPointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageTexelPointer; } + OpImageTexelPointer(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::ImageTexelPointer, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpLoad : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Load; } + OpLoad(IdResultType type, IdRef op0, std::optional op1 = std::nullopt, + std::optional op2 = std::nullopt) + : spv_inst{Op::Load, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } + inline auto op1() const -> std::optional const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + std::optional op1_; + std::optional op2_; +}; +class OpStore : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Store; } + OpStore(IdRef op0, IdRef op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt) + : spv_inst{Op::Store, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; +}; +class OpCopyMemory : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemory; } + OpCopyMemory(IdRef op0, IdRef op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt) + : spv_inst{Op::CopyMemory, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + + private: + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; + std::optional op4_; +}; +class OpCopyMemorySized : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemorySized; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::UntypedPointersKHR}; + OpCopyMemorySized(IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt, + std::optional op5 = std::nullopt) + : spv_inst{Op::CopyMemorySized, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() -> std::optional & { return op5_; } + inline auto op5() const -> std::optional const & { return op5_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; + std::optional op4_; + std::optional op5_; +}; +class OpAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AccessChain; } + OpAccessChain(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::AccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpInBoundsAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::InBoundsAccessChain; } + OpInBoundsAccessChain(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::InBoundsAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpPtrAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrAccessChain; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::VariablePointers, + Capability::VariablePointersStorageBuffer, Capability::PhysicalStorageBufferAddresses}; + OpPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::PtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpArrayLength : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ArrayLength; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpArrayLength(IdResultType type, IdRef op0, LiteralInteger op1) + : spv_inst{Op::ArrayLength, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + LiteralInteger op1_; +}; +class OpGenericPtrMemSemantics : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GenericPtrMemSemantics; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericPtrMemSemantics(IdResultType type, IdRef op0) + : spv_inst{Op::GenericPtrMemSemantics, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpInBoundsPtrAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::InBoundsPtrAccessChain; + } + constexpr static std::array required_capabilities = {Capability::Addresses}; + OpInBoundsPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::InBoundsPtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Decorate; } + OpDecorate(IdRef op0, Decoration op1, std::optional op2 = std::nullopt) + : spv_inst{Op::Decorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Decoration & { return op1_; } + inline auto op1() const -> Decoration const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdRef op0_; + Decoration op1_; + std::optional op2_; +}; +class OpMemberDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemberDecorate; } + OpMemberDecorate(IdRef op0, LiteralInteger op1, Decoration op2) + : spv_inst{Op::MemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> Decoration & { return op2_; } + inline auto op2() const -> Decoration const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + Decoration op2_; +}; +class OpDecorationGroup : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorationGroup; } + OpDecorationGroup() : spv_inst{Op::DecorationGroup, true} {} + + private: +}; +class OpGroupDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupDecorate; } + OpGroupDecorate(IdRef op0, std::vector op1) + : spv_inst{Op::GroupDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpGroupMemberDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupMemberDecorate; } + OpGroupMemberDecorate(IdRef op0, std::vector op1) + : spv_inst{Op::GroupMemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpVectorExtractDynamic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorExtractDynamic; } + OpVectorExtractDynamic(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorExtractDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorInsertDynamic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorInsertDynamic; } + OpVectorInsertDynamic(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::VectorInsertDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpVectorShuffle : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorShuffle; } + OpVectorShuffle(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::VectorShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpCompositeConstruct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeConstruct; } + OpCompositeConstruct(IdResultType type, std::vector op0) + : spv_inst{Op::CompositeConstruct, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpCompositeExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeExtract; } + OpCompositeExtract(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::CompositeExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpCompositeInsert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeInsert; } + OpCompositeInsert(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::CompositeInsert, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpCopyObject : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyObject; } + OpCopyObject(IdResultType type, IdRef op0) + : spv_inst{Op::CopyObject, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpTranspose : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Transpose; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpTranspose(IdResultType type, IdRef op0) + : spv_inst{Op::Transpose, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSampledImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SampledImage; } + OpSampledImage(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SampledImage, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageSampleImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSampleImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSampleExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleExplicitLod; + } + OpImageSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSampleExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSampleDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleDrefImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSampleDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSampleDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleDrefExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSampleDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSampleProjImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSampleProjImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSampleProjExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSampleProjExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSampleProjDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjDrefImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSampleProjDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSampleProjDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjDrefExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSampleProjDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageFetch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageFetch; } + OpImageFetch(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageFetch, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageGather; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageDrefGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageDrefGather; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageRead : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageRead; } + OpImageRead(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageRead, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageWrite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageWrite; } + OpImageWrite(IdRef op0, IdRef op1, IdRef op2, std::optional op3 = std::nullopt) + : spv_inst{Op::ImageWrite, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Image; } + OpImage(IdResultType type, IdRef op0) + : spv_inst{Op::Image, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryFormat : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryFormat; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpImageQueryFormat(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryFormat, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryOrder : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryOrder; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpImageQueryOrder(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryOrder, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQuerySizeLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySizeLod; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySizeLod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ImageQuerySizeLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageQuerySize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySize; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySize(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQuerySize, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryLod; } + constexpr static std::array required_capabilities = {Capability::ImageQuery}; + OpImageQueryLod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ImageQueryLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageQueryLevels : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryLevels; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQueryLevels(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryLevels, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQuerySamples : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySamples; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySamples(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQuerySamples, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertFToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToU; } + OpConvertFToU(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertFToS : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToS; } + OpConvertFToS(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertSToF : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertSToF; } + OpConvertSToF(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertSToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertUToF : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertUToF; } + OpConvertUToF(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertUToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpUConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UConvert; } + OpUConvert(IdResultType type, IdRef op0) + : spv_inst{Op::UConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SConvert; } + OpSConvert(IdResultType type, IdRef op0) + : spv_inst{Op::SConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FConvert; } + OpFConvert(IdResultType type, IdRef op0) + : spv_inst{Op::FConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpQuantizeToF16 : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::QuantizeToF16; } + OpQuantizeToF16(IdResultType type, IdRef op0) + : spv_inst{Op::QuantizeToF16, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertPtrToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertPtrToU; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpConvertPtrToU(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertPtrToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSatConvertSToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SatConvertSToU; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSatConvertSToU(IdResultType type, IdRef op0) + : spv_inst{Op::SatConvertSToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSatConvertUToS : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SatConvertUToS; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSatConvertUToS(IdResultType type, IdRef op0) + : spv_inst{Op::SatConvertUToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertUToPtr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertUToPtr; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpConvertUToPtr(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertUToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpPtrCastToGeneric : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrCastToGeneric; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpPtrCastToGeneric(IdResultType type, IdRef op0) + : spv_inst{Op::PtrCastToGeneric, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGenericCastToPtr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GenericCastToPtr; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericCastToPtr(IdResultType type, IdRef op0) + : spv_inst{Op::GenericCastToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGenericCastToPtrExplicit : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GenericCastToPtrExplicit; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericCastToPtrExplicit(IdResultType type, IdRef op0, StorageClass op1) + : spv_inst{Op::GenericCastToPtrExplicit, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> StorageClass & { return op1_; } + inline auto op1() const -> StorageClass const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + StorageClass op1_; +}; +class OpBitcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Bitcast; } + OpBitcast(IdResultType type, IdRef op0) + : spv_inst{Op::Bitcast, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSNegate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SNegate; } + OpSNegate(IdResultType type, IdRef op0) + : spv_inst{Op::SNegate, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFNegate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FNegate; } + OpFNegate(IdResultType type, IdRef op0) + : spv_inst{Op::FNegate, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IAdd; } + OpIAdd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FAdd; } + OpFAdd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpISub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ISub; } + OpISub(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ISub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFSub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FSub; } + OpFSub(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FSub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpIMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IMul; } + OpIMul(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FMul; } + OpFMul(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UDiv; } + OpUDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SDiv; } + OpSDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FDiv; } + OpFDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UMod; } + OpUMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSRem : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SRem; } + OpSRem(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SRem, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SMod; } + OpSMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFRem : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FRem; } + OpFRem(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FRem, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FMod; } + OpFMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorTimesScalar : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorTimesScalar; } + OpVectorTimesScalar(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesScalar : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesScalar; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesScalar(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorTimesMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorTimesMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpVectorTimesMatrix(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesVector : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesVector; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesVector(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesVector, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesMatrix(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpOuterProduct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::OuterProduct; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpOuterProduct(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::OuterProduct, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpDot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Dot; } + OpDot(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Dot, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpIAddCarry : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IAddCarry; } + OpIAddCarry(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IAddCarry, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpISubBorrow : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ISubBorrow; } + OpISubBorrow(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ISubBorrow, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUMulExtended : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UMulExtended; } + OpUMulExtended(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSMulExtended : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SMulExtended; } + OpSMulExtended(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Any; } + OpAny(IdResultType type, IdRef op0) + : spv_inst{Op::Any, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::All; } + OpAll(IdResultType type, IdRef op0) + : spv_inst{Op::All, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsNan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsNan; } + OpIsNan(IdResultType type, IdRef op0) + : spv_inst{Op::IsNan, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsInf : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsInf; } + OpIsInf(IdResultType type, IdRef op0) + : spv_inst{Op::IsInf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsFinite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsFinite; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpIsFinite(IdResultType type, IdRef op0) + : spv_inst{Op::IsFinite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsNormal : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsNormal; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpIsNormal(IdResultType type, IdRef op0) + : spv_inst{Op::IsNormal, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSignBitSet : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SignBitSet; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSignBitSet(IdResultType type, IdRef op0) + : spv_inst{Op::SignBitSet, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpLessOrGreater : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LessOrGreater; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLessOrGreater(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LessOrGreater, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpOrdered : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Ordered; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpOrdered(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Ordered, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUnordered : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Unordered; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpUnordered(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Unordered, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalEqual; } + OpLogicalEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalNotEqual; } + OpLogicalNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalOr; } + OpLogicalOr(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalAnd; } + OpLogicalAnd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalNot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalNot; } + OpLogicalNot(IdResultType type, IdRef op0) + : spv_inst{Op::LogicalNot, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSelect : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Select; } + OpSelect(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::Select, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpIEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IEqual; } + OpIEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpINotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::INotEqual; } + OpINotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::INotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UGreaterThan; } + OpUGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SGreaterThan; } + OpSGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UGreaterThanEqual; } + OpUGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SGreaterThanEqual; } + OpSGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpULessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ULessThan; } + OpULessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ULessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SLessThan; } + OpSLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpULessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ULessThanEqual; } + OpULessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ULessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SLessThanEqual; } + OpSLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdEqual; } + OpFOrdEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordEqual; } + OpFUnordEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdNotEqual; } + OpFOrdNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordNotEqual; } + OpFUnordNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdLessThan; } + OpFOrdLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordLessThan; } + OpFUnordLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdGreaterThan; } + OpFOrdGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordGreaterThan; } + OpFUnordGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdLessThanEqual; } + OpFOrdLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordLessThanEqual; } + OpFUnordLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdGreaterThanEqual; } + OpFOrdGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::FUnordGreaterThanEqual; + } + OpFUnordGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftRightLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftRightLogical; } + OpShiftRightLogical(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftRightLogical, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftRightArithmetic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftRightArithmetic; } + OpShiftRightArithmetic(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftRightArithmetic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftLeftLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftLeftLogical; } + OpShiftLeftLogical(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftLeftLogical, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseOr; } + OpBitwiseOr(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseXor; } + OpBitwiseXor(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseXor, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseAnd; } + OpBitwiseAnd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpNot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Not; } + OpNot(IdResultType type, IdRef op0) + : spv_inst{Op::Not, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpBitFieldInsert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldInsert; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldInsert(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::BitFieldInsert, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpBitFieldSExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldSExtract; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldSExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BitFieldSExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpBitFieldUExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldUExtract; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldUExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BitFieldUExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpBitReverse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitReverse; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitReverse(IdResultType type, IdRef op0) + : spv_inst{Op::BitReverse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpBitCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitCount; } + OpBitCount(IdResultType type, IdRef op0) + : spv_inst{Op::BitCount, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdx : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdx; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpDPdx(IdResultType type, IdRef op0) + : spv_inst{Op::DPdx, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdy : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdy; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpDPdy(IdResultType type, IdRef op0) + : spv_inst{Op::DPdy, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidth : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Fwidth; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpFwidth(IdResultType type, IdRef op0) + : spv_inst{Op::Fwidth, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdxFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdxFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdxFine(IdResultType type, IdRef op0) + : spv_inst{Op::DPdxFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdyFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdyFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdyFine(IdResultType type, IdRef op0) + : spv_inst{Op::DPdyFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidthFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FwidthFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpFwidthFine(IdResultType type, IdRef op0) + : spv_inst{Op::FwidthFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdxCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdxCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdxCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::DPdxCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdyCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdyCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdyCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::DPdyCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidthCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FwidthCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpFwidthCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::FwidthCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpEmitVertex : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EmitVertex; } + constexpr static std::array required_capabilities = {Capability::Geometry}; + OpEmitVertex() : spv_inst{Op::EmitVertex, false} {} + + private: +}; +class OpEndPrimitive : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EndPrimitive; } + constexpr static std::array required_capabilities = {Capability::Geometry}; + OpEndPrimitive() : spv_inst{Op::EndPrimitive, false} {} + + private: +}; +class OpEmitStreamVertex : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EmitStreamVertex; } + constexpr static std::array required_capabilities = { + Capability::GeometryStreams}; + OpEmitStreamVertex(IdRef op0) : spv_inst{Op::EmitStreamVertex, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpEndStreamPrimitive : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EndStreamPrimitive; } + constexpr static std::array required_capabilities = { + Capability::GeometryStreams}; + OpEndStreamPrimitive(IdRef op0) + : spv_inst{Op::EndStreamPrimitive, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpControlBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ControlBarrier; } + OpControlBarrier(IdScope op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::ControlBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdScope op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpMemoryBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryBarrier; } + OpMemoryBarrier(IdScope op0, IdMemorySemantics op1) + : spv_inst{Op::MemoryBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdMemorySemantics & { return op1_; } + inline auto op1() const -> IdMemorySemantics const & { return op1_; } + + private: + IdScope op0_; + IdMemorySemantics op1_; +}; +class OpAtomicLoad : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicLoad; } + OpAtomicLoad(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicLoad, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicStore : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicStore; } + OpAtomicStore(IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicStore, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicExchange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicExchange; } + OpAtomicExchange(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicExchange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicCompareExchange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::AtomicCompareExchange; + } + OpAtomicCompareExchange(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, + IdMemorySemantics op3, IdRef op4, IdRef op5) + : spv_inst{Op::AtomicCompareExchange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdMemorySemantics & { return op3_; } + inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdMemorySemantics op3_; + IdRef op4_; + IdRef op5_; +}; +class OpAtomicCompareExchangeWeak : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::AtomicCompareExchangeWeak; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicCompareExchangeWeak(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, + IdMemorySemantics op3, IdRef op4, IdRef op5) + : spv_inst{Op::AtomicCompareExchangeWeak, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdMemorySemantics & { return op3_; } + inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdMemorySemantics op3_; + IdRef op4_; + IdRef op5_; +}; +class OpAtomicIIncrement : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIIncrement; } + OpAtomicIIncrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicIIncrement, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicIDecrement : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIDecrement; } + OpAtomicIDecrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicIDecrement, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIAdd; } + OpAtomicIAdd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicISub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicISub; } + OpAtomicISub(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicISub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicSMin; } + OpAtomicSMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicUMin; } + OpAtomicUMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicSMax; } + OpAtomicSMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicUMax; } + OpAtomicUMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicAnd; } + OpAtomicAnd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicOr; } + OpAtomicOr(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicXor; } + OpAtomicXor(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicXor, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpPhi : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Phi; } + OpPhi(IdResultType type, std::vector op0) + : spv_inst{Op::Phi, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpLoopMerge : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LoopMerge; } + OpLoopMerge(IdRef op0, IdRef op1, LoopControl op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::LoopMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LoopControl & { return op2_; } + inline auto op2() const -> LoopControl const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + LoopControl op2_; + std::optional op3_; +}; +class OpSelectionMerge : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SelectionMerge; } + OpSelectionMerge(IdRef op0, SelectionControl op1) + : spv_inst{Op::SelectionMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> SelectionControl & { return op1_; } + inline auto op1() const -> SelectionControl const & { return op1_; } + + private: + IdRef op0_; + SelectionControl op1_; +}; +class OpLabel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Label; } + OpLabel() : spv_inst{Op::Label, true} {} + + private: +}; +class OpBranch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Branch; } + OpBranch(IdRef op0) : spv_inst{Op::Branch, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpBranchConditional : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BranchConditional; } + OpBranchConditional(IdRef op0, IdRef op1, IdRef op2, std::vector op3) + : spv_inst{Op::BranchConditional, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::vector & { return op3_; } + inline auto op3() const -> std::vector const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::vector op3_; +}; +class OpSwitch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Switch; } + OpSwitch(IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::Switch, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpKill : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Kill; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpKill() : spv_inst{Op::Kill, false} {} + + private: +}; +class OpReturn : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Return; } + OpReturn() : spv_inst{Op::Return, false} {} + + private: +}; +class OpReturnValue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReturnValue; } + OpReturnValue(IdRef op0) : spv_inst{Op::ReturnValue, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpUnreachable : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Unreachable; } + OpUnreachable() : spv_inst{Op::Unreachable, false} {} + + private: +}; +class OpLifetimeStart : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LifetimeStart; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLifetimeStart(IdRef op0, LiteralInteger op1) + : spv_inst{Op::LifetimeStart, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpLifetimeStop : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LifetimeStop; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLifetimeStop(IdRef op0, LiteralInteger op1) + : spv_inst{Op::LifetimeStop, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpGroupAsyncCopy : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAsyncCopy; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGroupAsyncCopy(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::GroupAsyncCopy, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpGroupWaitEvents : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupWaitEvents; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGroupWaitEvents(IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupWaitEvents, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAll; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupAll(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupAll, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAny; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupAny(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupAny, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupBroadcast; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupBroadcast, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupIAdd; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupIAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFAdd; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupUMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupSMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupUMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupSMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReadPipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReadPipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::WritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpWritePipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::WritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpReservedReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReservedReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReservedReadPipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::ReservedReadPipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpReservedWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReservedWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReservedWritePipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::ReservedWritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpReserveReadPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ReserveReadPipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReserveReadPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReserveReadPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpReserveWritePipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ReserveWritePipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReserveWritePipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReserveWritePipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpCommitReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CommitReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpCommitReadPipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::CommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpCommitWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CommitWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpCommitWritePipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::CommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpIsValidReserveId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsValidReserveId; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpIsValidReserveId(IdResultType type, IdRef op0) + : spv_inst{Op::IsValidReserveId, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGetNumPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetNumPipePackets; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGetNumPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::GetNumPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGetMaxPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetMaxPipePackets; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGetMaxPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::GetMaxPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupReserveReadPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupReserveReadPipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupReserveReadPipePackets(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GroupReserveReadPipePackets, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupReserveWritePipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupReserveWritePipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupReserveWritePipePackets(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GroupReserveWritePipePackets, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupCommitReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupCommitReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupCommitReadPipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::GroupCommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupCommitWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupCommitWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupCommitWritePipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::GroupCommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpEnqueueMarker : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EnqueueMarker; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpEnqueueMarker(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::EnqueueMarker, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpEnqueueKernel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EnqueueKernel; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpEnqueueKernel(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5, IdRef op6, IdRef op7, IdRef op8, IdRef op9, std::vector op10) + : spv_inst{Op::EnqueueKernel, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), op8_(std::move(op8)), + op9_(std::move(op9)), op10_(std::move(op10)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() -> IdRef & { return op6_; } + inline auto op6() const -> IdRef const & { return op6_; } + inline auto op7() -> IdRef & { return op7_; } + inline auto op7() const -> IdRef const & { return op7_; } + inline auto op8() -> IdRef & { return op8_; } + inline auto op8() const -> IdRef const & { return op8_; } + inline auto op9() -> IdRef & { return op9_; } + inline auto op9() const -> IdRef const & { return op9_; } + inline auto op10() -> std::vector & { return op10_; } + inline auto op10() const -> std::vector const & { return op10_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; + IdRef op6_; + IdRef op7_; + IdRef op8_; + IdRef op9_; + std::vector op10_; +}; +class OpGetKernelNDrangeSubGroupCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelNDrangeSubGroupCount; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelNDrangeSubGroupCount(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GetKernelNDrangeSubGroupCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelNDrangeMaxSubGroupSize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelNDrangeMaxSubGroupSize; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelNDrangeMaxSubGroupSize(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GetKernelNDrangeMaxSubGroupSize, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelWorkGroupSize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelWorkGroupSize; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelWorkGroupSize(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::GetKernelWorkGroupSize, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpGetKernelPreferredWorkGroupSizeMultiple : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelPreferredWorkGroupSizeMultiple; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelPreferredWorkGroupSizeMultiple(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + IdRef op3) + : spv_inst{Op::GetKernelPreferredWorkGroupSizeMultiple, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpRetainEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::RetainEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpRetainEvent(IdRef op0) : spv_inst{Op::RetainEvent, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpReleaseEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReleaseEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpReleaseEvent(IdRef op0) : spv_inst{Op::ReleaseEvent, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpCreateUserEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CreateUserEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpCreateUserEvent(IdResultType type) + : spv_inst{Op::CreateUserEvent, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpIsValidEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsValidEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpIsValidEvent(IdResultType type, IdRef op0) + : spv_inst{Op::IsValidEvent, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSetUserEventStatus : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SetUserEventStatus; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpSetUserEventStatus(IdRef op0, IdRef op1) + : spv_inst{Op::SetUserEventStatus, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpCaptureEventProfilingInfo : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CaptureEventProfilingInfo; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpCaptureEventProfilingInfo(IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::CaptureEventProfilingInfo, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGetDefaultQueue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetDefaultQueue; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetDefaultQueue(IdResultType type) + : spv_inst{Op::GetDefaultQueue, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpBuildNDRange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BuildNDRange; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpBuildNDRange(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BuildNDRange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpImageSparseSampleImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseSampleImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseSampleExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSparseSampleExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSparseSampleDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleDrefImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseSampleDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseSampleDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleDrefExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSparseSampleDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSparseSampleProjImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseSampleProjImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseSampleProjExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSparseSampleProjExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSparseSampleProjDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjDrefImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseSampleProjDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseSampleProjDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjDrefExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSparseSampleProjDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSparseFetch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseFetch; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseFetch(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseFetch, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseGather; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseDrefGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseDrefGather; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseTexelsResident : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseTexelsResident; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseTexelsResident(IdResultType type, IdRef op0) + : spv_inst{Op::ImageSparseTexelsResident, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpNoLine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::NoLine; } + OpNoLine() : spv_inst{Op::NoLine, false} {} + + private: +}; +class OpAtomicFlagTestAndSet : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFlagTestAndSet; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicFlagTestAndSet(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicFlagTestAndSet, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicFlagClear : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFlagClear; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicFlagClear(IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicFlagClear, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpImageSparseRead : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseRead; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseRead(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseRead, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpSizeOf : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SizeOf; } + constexpr static std::array required_capabilities = {Capability::Addresses}; + OpSizeOf(IdResultType type, IdRef op0) + : spv_inst{Op::SizeOf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpTypePipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePipeStorage; } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpTypePipeStorage() : spv_inst{Op::TypePipeStorage, true} {} + + private: +}; +class OpConstantPipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantPipeStorage; } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpConstantPipeStorage(IdResultType type, LiteralInteger op0, LiteralInteger op1, + LiteralInteger op2) + : spv_inst{Op::ConstantPipeStorage, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> LiteralInteger & { return op0_; } + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + + private: + IdResultType type_; + LiteralInteger op0_; + LiteralInteger op1_; + LiteralInteger op2_; +}; +class OpCreatePipeFromPipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CreatePipeFromPipeStorage; + } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpCreatePipeFromPipeStorage(IdResultType type, IdRef op0) + : spv_inst{Op::CreatePipeFromPipeStorage, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGetKernelLocalSizeForSubgroupCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelLocalSizeForSubgroupCount; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupDispatch}; + OpGetKernelLocalSizeForSubgroupCount(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + IdRef op3, IdRef op4) + : spv_inst{Op::GetKernelLocalSizeForSubgroupCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelMaxNumSubgroups : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelMaxNumSubgroups; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupDispatch}; + OpGetKernelMaxNumSubgroups(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::GetKernelMaxNumSubgroups, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpTypeNamedBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeNamedBarrier; } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpTypeNamedBarrier() : spv_inst{Op::TypeNamedBarrier, true} {} + + private: +}; +class OpNamedBarrierInitialize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::NamedBarrierInitialize; + } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpNamedBarrierInitialize(IdResultType type, IdRef op0) + : spv_inst{Op::NamedBarrierInitialize, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpMemoryNamedBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryNamedBarrier; } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpMemoryNamedBarrier(IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::MemoryNamedBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpModuleProcessed : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ModuleProcessed; } + OpModuleProcessed(LiteralString op0) + : spv_inst{Op::ModuleProcessed, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExecutionModeId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExecutionModeId; } + OpExecutionModeId(IdRef op0, ExecutionMode op1) + : spv_inst{Op::ExecutionModeId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> ExecutionMode & { return op1_; } + inline auto op1() const -> ExecutionMode const & { return op1_; } + + private: + IdRef op0_; + ExecutionMode op1_; +}; +class OpDecorateId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorateId; } + constexpr static std::array required_extensions = { + "SPV_GOOGLE_hlsl_functionality1"}; + OpDecorateId(IdRef op0, Decoration op1) + : spv_inst{Op::DecorateId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Decoration & { return op1_; } + inline auto op1() const -> Decoration const & { return op1_; } + + private: + IdRef op0_; + Decoration op1_; +}; +class OpGroupNonUniformElect : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformElect; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniform}; + OpGroupNonUniformElect(IdResultType type, IdScope op0) + : spv_inst{Op::GroupNonUniformElect, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + + private: + IdResultType type_; + IdScope op0_; +}; +class OpGroupNonUniformAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformAll; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAll(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAll, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformAny; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAny(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAny, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformAllEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformAllEqual; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAllEqual(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAllEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBroadcast; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBroadcast, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformBroadcastFirst : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBroadcastFirst; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBroadcastFirst(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBroadcastFirst, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallot; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallot(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallot, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformInverseBallot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformInverseBallot; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformInverseBallot(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformInverseBallot, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallotBitExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotBitExtract; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotBitExtract(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBallotBitExtract, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformBallotBitCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotBitCount; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotBitCount(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBallotBitCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupNonUniformBallotFindLSB : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotFindLSB; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotFindLSB(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallotFindLSB, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallotFindMSB : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotFindMSB; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotFindMSB(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallotFindMSB, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformShuffle : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffle; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffle}; + OpGroupNonUniformShuffle(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffle}; + OpGroupNonUniformShuffleXor(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleUp : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleUp; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffleRelative}; + OpGroupNonUniformShuffleUp(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleUp, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleDown : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleDown; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffleRelative}; + OpGroupNonUniformShuffleDown(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleDown, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformIAdd; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformIAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFAdd; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformIMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformIMul; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformIMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformIMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMul; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformSMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformUMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformSMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformUMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseAnd; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseAnd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformBitwiseAnd, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseOr; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformBitwiseOr, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformBitwiseXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalAnd; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalAnd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformLogicalAnd, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalOr; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformLogicalOr, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformLogicalXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformQuadBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformQuadBroadcast; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformQuad}; + OpGroupNonUniformQuadBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformQuadBroadcast, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformQuadSwap : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformQuadSwap; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformQuad}; + OpGroupNonUniformQuadSwap(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformQuadSwap, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpCopyLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyLogical; } + OpCopyLogical(IdResultType type, IdRef op0) + : spv_inst{Op::CopyLogical, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpPtrEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrEqual; } + OpPtrEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpPtrNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrNotEqual; } + OpPtrNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpPtrDiff : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrDiff; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::VariablePointers, + Capability::VariablePointersStorageBuffer}; + OpPtrDiff(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrDiff, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpTypeCooperativeMatrixKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::TypeCooperativeMatrixKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpTypeCooperativeMatrixKHR(IdRef op0, IdScope op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::TypeCooperativeMatrixKHR, true}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdRef op0_; + IdScope op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpCooperativeMatrixLoadKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLoadKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixLoadKHR(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt) + : spv_inst{Op::CooperativeMatrixLoadKHR, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; + std::optional op4_; +}; +class OpCooperativeMatrixStoreKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixStoreKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixStoreKHR(IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt, + std::optional op5 = std::nullopt) + : spv_inst{Op::CooperativeMatrixStoreKHR, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() -> std::optional & { return op5_; } + inline auto op5() const -> std::optional const & { return op5_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; + std::optional op4_; + std::optional op5_; +}; +class OpCooperativeMatrixMulAddKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixMulAddKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixMulAddKHR(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::CooperativeMatrixMulAddKHR, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpCooperativeMatrixLengthKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLengthKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixLengthKHR(IdResultType type, IdRef op0) + : spv_inst{Op::CooperativeMatrixLengthKHR, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSubgroupBlockReadINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::SubgroupBlockReadINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupBufferBlockIOINTEL}; + OpSubgroupBlockReadINTEL(IdResultType type, IdRef op0) + : spv_inst{Op::SubgroupBlockReadINTEL, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSubgroupBlockWriteINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::SubgroupBlockWriteINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupBufferBlockIOINTEL}; + OpSubgroupBlockWriteINTEL(IdRef op0, IdRef op1) + : spv_inst{Op::SubgroupBlockWriteINTEL, false}, op0_(std::move(op0)), op1_(std::move(op1)) { + } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpAsmTargetINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AsmTargetINTEL; } + constexpr static std::array required_capabilities = {Capability::AsmINTEL}; + OpAsmTargetINTEL(LiteralString op0) + : spv_inst{Op::AsmTargetINTEL, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpAsmINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AsmINTEL; } + constexpr static std::array required_capabilities = {Capability::AsmINTEL}; + OpAsmINTEL(IdResultType type, IdRef op0, IdRef op1, LiteralString op2, LiteralString op3) + : spv_inst{Op::AsmINTEL, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } + inline auto op2() const -> LiteralString const & { return op2_; } + inline auto op3() -> LiteralString & { return op3_; } + inline auto op3() const -> LiteralString const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + LiteralString op2_; + LiteralString op3_; +}; +class OpAsmCallINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AsmCallINTEL; } + constexpr static std::array required_capabilities = {Capability::AsmINTEL}; + OpAsmCallINTEL(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::AsmCallINTEL, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpAtomicFMinEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFMinEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16MinMaxEXT, Capability::AtomicFloat32MinMaxEXT, + Capability::AtomicFloat64MinMaxEXT, Capability::AtomicFloat16VectorNV}; + OpAtomicFMinEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFMinEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicFMaxEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFMaxEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16MinMaxEXT, Capability::AtomicFloat32MinMaxEXT, + Capability::AtomicFloat64MinMaxEXT, Capability::AtomicFloat16VectorNV}; + OpAtomicFMaxEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFMaxEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicFAddEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFAddEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16AddEXT, Capability::AtomicFloat32AddEXT, + Capability::AtomicFloat64AddEXT, Capability::AtomicFloat16VectorNV}; + constexpr static std::array required_extensions = { + "SPV_EXT_shader_atomic_float_add"}; + OpAtomicFAddEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFAddEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpConvertFToBF16INTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToBF16INTEL; } + constexpr static std::array required_capabilities = { + Capability::BFloat16ConversionINTEL}; + OpConvertFToBF16INTEL(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToBF16INTEL, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertBF16ToFINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertBF16ToFINTEL; } + constexpr static std::array required_capabilities = { + Capability::BFloat16ConversionINTEL}; + OpConvertBF16ToFINTEL(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertBF16ToFINTEL, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpControlBarrierArriveINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ControlBarrierArriveINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SplitBarrierINTEL}; + OpControlBarrierArriveINTEL(IdScope op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::ControlBarrierArriveINTEL, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdScope op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpControlBarrierWaitINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ControlBarrierWaitINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SplitBarrierINTEL}; + OpControlBarrierWaitINTEL(IdScope op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::ControlBarrierWaitINTEL, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdScope op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpCooperativeMatrixLoadCheckedINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLoadCheckedINTEL; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixCheckedInstructionsINTEL}; + OpCooperativeMatrixLoadCheckedINTEL(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + IdRef op3, IdRef op4, IdRef op5, + std::optional op6 = std::nullopt, + std::optional op7 = std::nullopt, + std::optional op8 = std::nullopt) + : spv_inst{Op::CooperativeMatrixLoadCheckedINTEL, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)), op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), + op8_(std::move(op8)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() -> std::optional & { return op6_; } + inline auto op6() const -> std::optional const & { return op6_; } + inline auto op7() -> std::optional & { return op7_; } + inline auto op7() const -> std::optional const & { return op7_; } + inline auto op8() -> std::optional & { return op8_; } + inline auto op8() const -> std::optional const & { return op8_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; + std::optional op6_; + std::optional op7_; + std::optional op8_; +}; +class OpCooperativeMatrixStoreCheckedINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixStoreCheckedINTEL; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixCheckedInstructionsINTEL}; + OpCooperativeMatrixStoreCheckedINTEL(IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5, IdRef op6, + std::optional op7 = std::nullopt, + std::optional op8 = std::nullopt, + std::optional op9 = std::nullopt) + : spv_inst{Op::CooperativeMatrixStoreCheckedINTEL, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), op8_(std::move(op8)), + op9_(std::move(op9)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() -> IdRef & { return op6_; } + inline auto op6() const -> IdRef const & { return op6_; } + inline auto op7() -> std::optional & { return op7_; } + inline auto op7() const -> std::optional const & { return op7_; } + inline auto op8() -> std::optional & { return op8_; } + inline auto op8() const -> std::optional const & { return op8_; } + inline auto op9() -> std::optional & { return op9_; } + inline auto op9() const -> std::optional const & { return op9_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; + IdRef op6_; + std::optional op7_; + std::optional op8_; + std::optional op9_; +}; + +} // namespace tinytc::spv + +#endif // GENERATED_INSTRUCTIONS_20250605_HPP diff --git a/src/spv/lut.hpp b/src/spv/lut.hpp new file mode 100644 index 00000000..4af4b86b --- /dev/null +++ b/src/spv/lut.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LUT_20241216_HPP +#define LUT_20241216_HPP + +namespace tinytc::spv { + +class spv_inst; + +template +auto lookup(Map &map, Key &&key, Maker &&maker) { + auto it = map.find(key); + if (it == map.end()) { + map[key] = maker(key); + return map[key]; + } + return it->second; +} +template auto lookup(spv_inst *&var, Maker &&maker) -> spv_inst * { + if (!var) { + var = maker(); + } + return var; +} + +} // namespace tinytc::spv + +#endif // LUT_20241216_HPP diff --git a/src/spv/matrix_walker.cpp b/src/spv/matrix_walker.cpp new file mode 100644 index 00000000..b9f6de62 --- /dev/null +++ b/src/spv/matrix_walker.cpp @@ -0,0 +1,98 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/matrix_walker.hpp" +#include "coopmatrix_layout.hpp" +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/types.hpp" + +namespace tinytc::spv { + +matrix_walker::matrix_walker(uniquifier &unique, std::int32_t sgs, coopmatrix_layout const &layout, + spv_inst *pos0, spv_inst *pos1, spv_inst *shape0, spv_inst *shape1, + spv_inst *stride0, spv_inst *stride1, checked_flag chk, + std::int32_t constant_p) + : unique_{unique}, layout_{layout}, chk_{chk} { + index_ty_ = unique.scalar_ty(scalar_type::index); + + auto &mod = unique_.mod(); + auto crows = unique.constant(layout_.rows); + row_inc_ = mod.add(index_ty_, crows, stride0); + col_inc_factor_ = sgs / layout.rows; + col_inc_ = mod.add(index_ty_, unique.constant(col_inc_factor_), stride1); + + spv_inst *p = constant_p >= 0 ? unique.constant(constant_p) + : unique.load_builtin(BuiltIn::SubgroupLocalInvocationId); + p = mod.add(index_ty_, p); + + row_ = layout.rows < sgs ? mod.add(index_ty_, p, crows) : p; + row_ = mod.add(index_ty_, row_, pos0); + row_ = mod.add(index_ty_, row_, stride0); + + auto c0 = unique.null_constant(index_ty_); + col0_ = layout.rows < sgs ? mod.add(index_ty_, p, crows) : c0; + col0_ = mod.add(index_ty_, col0_, pos1); + col0_ = mod.add(index_ty_, col0_, stride1); + col_ = col0_; + + if (rows_checked()) { + row_max_ = mod.add(index_ty_, shape0, stride0); + } + if (may_need_mask() || cols_checked()) { + col_max_ = mod.add(index_ty_, shape1, stride1); + } +} + +void matrix_walker::advance_block() { + col_ = col0_; + col_no_ = 0; + row_ = unique_.mod().add(index_ty_, row_, row_inc_); + ++block_no_; +} +void matrix_walker::advance_column() { + col_ = unique_.mod().add(index_ty_, col_, col_inc_); + ++col_no_; +} + +auto matrix_walker::component_no(std::int32_t col_no) const -> std::int32_t { + return layout_.component_no(col_no, block_no_); +} +auto matrix_walker::component_no() const -> std::int32_t { return component_no(col_no_); } +auto matrix_walker::offset() const -> spv_inst * { + return unique_.mod().add(index_ty_, row_, col_); +}; +auto matrix_walker::rows_checked() const -> bool { + return chk_ == checked_flag::both || chk_ == checked_flag::rows; +} +auto matrix_walker::cols_checked() const -> bool { + return chk_ == checked_flag::both || chk_ == checked_flag::cols; +} +auto matrix_walker::needs_mask() const -> bool { + return (col_no_ + 1) * col_inc_factor_ > layout_.shape1; +} +auto matrix_walker::may_need_mask() const -> bool { + return layout_.cols * col_inc_factor_ > layout_.shape1; +} + +auto matrix_walker::col_ok() const -> spv_inst * { + auto c0 = unique_.null_constant(index_ty_); + auto bool_ty = unique_.bool_ty(); + auto &mod = unique_.mod(); + auto check1 = mod.add(bool_ty, c0, col_); + auto check2 = mod.add(bool_ty, col_, col_max_); + return mod.add(bool_ty, check1, check2); +} +auto matrix_walker::row_ok() const -> spv_inst * { + auto c0 = unique_.null_constant(index_ty_); + auto bool_ty = unique_.bool_ty(); + auto &mod = unique_.mod(); + auto check1 = mod.add(bool_ty, c0, row_); + auto check2 = mod.add(bool_ty, row_, row_max_); + return mod.add(bool_ty, check1, check2); +} + +} // namespace tinytc::spv diff --git a/src/spv/matrix_walker.hpp b/src/spv/matrix_walker.hpp new file mode 100644 index 00000000..d4b36579 --- /dev/null +++ b/src/spv/matrix_walker.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef MATRIX_WALKER_20250428_HPP +#define MATRIX_WALKER_20250428_HPP + +#include + +namespace tinytc { +enum class checked_flag; +struct coopmatrix_layout; +} // namespace tinytc + +namespace tinytc::spv { +class spv_inst; +class uniquifier; + +class matrix_walker { + public: + matrix_walker(uniquifier &unique, std::int32_t sgs, coopmatrix_layout const &layout, + spv_inst *pos0, spv_inst *pos1, spv_inst *shape0, spv_inst *shape1, + spv_inst *stride0, spv_inst *stride1, checked_flag chk, + std::int32_t constant_p = -1); + + void advance_block(); + void advance_column(); + + auto component_no(std::int32_t col_no) const -> std::int32_t; + auto component_no() const -> std::int32_t; + auto offset() const -> spv_inst *; + auto rows_checked() const -> bool; + auto cols_checked() const -> bool; + auto needs_mask() const -> bool; + auto may_need_mask() const -> bool; + auto col_ok() const -> spv_inst *; + auto row_ok() const -> spv_inst *; + inline auto block_no() const { return block_no_; } + inline auto col_no() const { return col_no_; } + + private: + uniquifier &unique_; + coopmatrix_layout const &layout_; + checked_flag chk_; + spv_inst *index_ty_; + spv_inst *row_inc_; + std::int64_t col_inc_factor_; + spv_inst *col_inc_; + spv_inst *row_; + spv_inst *col0_; + spv_inst *col_; + spv_inst *row_max_ = nullptr; + spv_inst *col_max_ = nullptr; + std::int32_t block_no_ = 0; + std::int32_t col_no_ = 0; +}; + +} // namespace tinytc::spv + +#endif // MATRIX_WALKER_20250428_HPP diff --git a/src/spv/module.cpp b/src/spv/module.cpp new file mode 100644 index 00000000..59c53c42 --- /dev/null +++ b/src/spv/module.cpp @@ -0,0 +1,105 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "tinytc/types.h" + +#include "error.hpp" +#include "spv/defs.hpp" +#include "spv/module.hpp" +#include "spv/pass/dump_asm.hpp" +#include "tinytc/tinytc.h" +#include "tinytc/types.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { +void ilist_callbacks::node_added(spv::spv_inst *) {} +void ilist_callbacks::node_moved(spv::spv_inst *) {} +void ilist_callbacks::node_removed(spv::spv_inst *node) { delete node; } +} // namespace tinytc + +using namespace tinytc; + +tinytc_spv_mod::tinytc_spv_mod(compiler_context ctx, tinytc_core_feature_flags_t core_features, + std::int32_t major_version, std::int32_t minor_version) + : ctx_{std::move(ctx)}, core_features_(core_features), major_version_{major_version}, + minor_version_{minor_version} {} +tinytc_spv_mod::~tinytc_spv_mod() {} + +auto tinytc_spv_mod::bound() const -> std::uint32_t { + std::uint32_t bnd = 0; + for (auto const &sec : insts_) { + for (auto const &i : sec) { + if (i.has_result_id()) { + bnd = std::max(bnd, i.id()); + } + } + } + return bnd + 1; +} + +extern "C" { + +tinytc_status_t tinytc_spv_mod_dump(const_tinytc_spv_mod_t mod) { + if (mod == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { spv::dump_asm_pass{std::cerr}.run_on_module(*mod); }); +} + +tinytc_status_t tinytc_spv_mod_print_to_file(const_tinytc_spv_mod_t mod, char const *filename) { + if (mod == nullptr || filename == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto stream = std::ofstream(filename); + if (!stream.good()) { + throw status::file_io_error; + } + spv::dump_asm_pass{stream}.run_on_module(*mod); + }); +} + +tinytc_status_t tinytc_spv_mod_print_to_string(const_tinytc_spv_mod_t mod, char **str) { + if (mod == nullptr || str == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const text = [&] { + auto oss = std::ostringstream{}; + spv::dump_asm_pass{oss}.run_on_module(*mod); + return std::move(oss).str(); + }(); + auto const length = text.size() + 1; // Need to include terminating null character + *str = (char *)malloc(length * sizeof(char)); + if (!str) { + throw status::bad_alloc; + } + std::strncpy(*str, text.c_str(), length); + }); +} +tinytc_status_t tinytc_spv_mod_release(tinytc_spv_mod_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + auto ref_count = obj->dec_ref(); + if (ref_count == 0) { + delete obj; + } + return tinytc_status_success; +} + +tinytc_status_t tinytc_spv_mod_retain(tinytc_spv_mod_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + obj->inc_ref(); + return tinytc_status_success; +} +} diff --git a/src/spv/module.hpp b/src/spv/module.hpp new file mode 100644 index 00000000..536ee294 --- /dev/null +++ b/src/spv/module.hpp @@ -0,0 +1,94 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef MODULE_20241029_HPP +#define MODULE_20241029_HPP + +#include "reference_counted.hpp" +#include "spv/defs.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "util/ilist.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +template <> struct ilist_callbacks { + void node_added(spv::spv_inst *node); + void node_moved(spv::spv_inst *node); + void node_removed(spv::spv_inst *node); +}; + +namespace spv { + +enum class section { + capability = 0, + extension = 1, + ext_inst = 2, + memory_model = 3, + entry_point = 4, + execution_mode = 5, + decoration = 6, + type_const_var = 7, + function = 8 +}; +inline constexpr std::int32_t num_module_sections = 9; + +} // namespace spv +} // namespace tinytc + +struct tinytc_spv_mod final : tinytc::reference_counted { + public: + using iterator = tinytc::ilist::iterator; + using const_iterator = tinytc::ilist::const_iterator; + + tinytc_spv_mod(tinytc::compiler_context ctx, tinytc_core_feature_flags_t core_features, + std::int32_t major_version = 1, std::int32_t minor_version = 6); + ~tinytc_spv_mod(); + + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::compiler_context { return ctx_; } + inline auto core_features() const -> tinytc_core_feature_flags_t { return core_features_; } + + auto bound() const -> std::uint32_t; + + inline auto insts(tinytc::spv::section s) -> tinytc::ilist & { + return insts_[static_cast(s)]; + } + inline auto insts(tinytc::spv::section s) const + -> tinytc::ilist const & { + return insts_[static_cast(s)]; + } + inline auto empty(tinytc::spv::section s) const -> bool { + return insts_[static_cast(s)].empty(); + } + + inline auto major_version() const -> std::int32_t { return major_version_; } + inline auto minor_version() const -> std::int32_t { return minor_version_; } + + template + auto add_to(tinytc::spv::section s, Args &&...args) -> T * { + auto ptr = std::make_unique(std::forward(args)...).release(); + insts(s).push_back(ptr); + return ptr; + } + template auto add(Args &&...args) -> T * { + return add_to(tinytc::spv::section::function, std::forward(args)...); + } + + private: + tinytc::compiler_context ctx_; + tinytc_core_feature_flags_t core_features_; + std::array, tinytc::spv::num_module_sections> insts_; + std::int32_t major_version_, minor_version_; +}; + +namespace tinytc::spv { +// using mod = ::tinytc_spv_mod; +} // namespace tinytc::spv + +#endif // MODULE_20241029_HPP diff --git a/src/spv/names.cpp b/src/spv/names.cpp new file mode 100644 index 00000000..4f82ceeb --- /dev/null +++ b/src/spv/names.cpp @@ -0,0 +1,3125 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "names.hpp" +#include "enums.hpp" + +namespace tinytc::spv { + +auto to_string(Op op) -> char const * { + switch (op) { + case Op::Nop: + return "Nop"; + case Op::Undef: + return "Undef"; + case Op::SourceContinued: + return "SourceContinued"; + case Op::Source: + return "Source"; + case Op::SourceExtension: + return "SourceExtension"; + case Op::Name: + return "Name"; + case Op::MemberName: + return "MemberName"; + case Op::String: + return "String"; + case Op::Line: + return "Line"; + case Op::Extension: + return "Extension"; + case Op::ExtInstImport: + return "ExtInstImport"; + case Op::ExtInst: + return "ExtInst"; + case Op::MemoryModel: + return "MemoryModel"; + case Op::EntryPoint: + return "EntryPoint"; + case Op::ExecutionMode: + return "ExecutionMode"; + case Op::Capability: + return "Capability"; + case Op::TypeVoid: + return "TypeVoid"; + case Op::TypeBool: + return "TypeBool"; + case Op::TypeInt: + return "TypeInt"; + case Op::TypeFloat: + return "TypeFloat"; + case Op::TypeVector: + return "TypeVector"; + case Op::TypeMatrix: + return "TypeMatrix"; + case Op::TypeImage: + return "TypeImage"; + case Op::TypeSampler: + return "TypeSampler"; + case Op::TypeSampledImage: + return "TypeSampledImage"; + case Op::TypeArray: + return "TypeArray"; + case Op::TypeRuntimeArray: + return "TypeRuntimeArray"; + case Op::TypeStruct: + return "TypeStruct"; + case Op::TypeOpaque: + return "TypeOpaque"; + case Op::TypePointer: + return "TypePointer"; + case Op::TypeFunction: + return "TypeFunction"; + case Op::TypeEvent: + return "TypeEvent"; + case Op::TypeDeviceEvent: + return "TypeDeviceEvent"; + case Op::TypeReserveId: + return "TypeReserveId"; + case Op::TypeQueue: + return "TypeQueue"; + case Op::TypePipe: + return "TypePipe"; + case Op::TypeForwardPointer: + return "TypeForwardPointer"; + case Op::ConstantTrue: + return "ConstantTrue"; + case Op::ConstantFalse: + return "ConstantFalse"; + case Op::Constant: + return "Constant"; + case Op::ConstantComposite: + return "ConstantComposite"; + case Op::ConstantSampler: + return "ConstantSampler"; + case Op::ConstantNull: + return "ConstantNull"; + case Op::Function: + return "Function"; + case Op::FunctionParameter: + return "FunctionParameter"; + case Op::FunctionEnd: + return "FunctionEnd"; + case Op::FunctionCall: + return "FunctionCall"; + case Op::Variable: + return "Variable"; + case Op::ImageTexelPointer: + return "ImageTexelPointer"; + case Op::Load: + return "Load"; + case Op::Store: + return "Store"; + case Op::CopyMemory: + return "CopyMemory"; + case Op::CopyMemorySized: + return "CopyMemorySized"; + case Op::AccessChain: + return "AccessChain"; + case Op::InBoundsAccessChain: + return "InBoundsAccessChain"; + case Op::PtrAccessChain: + return "PtrAccessChain"; + case Op::ArrayLength: + return "ArrayLength"; + case Op::GenericPtrMemSemantics: + return "GenericPtrMemSemantics"; + case Op::InBoundsPtrAccessChain: + return "InBoundsPtrAccessChain"; + case Op::Decorate: + return "Decorate"; + case Op::MemberDecorate: + return "MemberDecorate"; + case Op::DecorationGroup: + return "DecorationGroup"; + case Op::GroupDecorate: + return "GroupDecorate"; + case Op::GroupMemberDecorate: + return "GroupMemberDecorate"; + case Op::VectorExtractDynamic: + return "VectorExtractDynamic"; + case Op::VectorInsertDynamic: + return "VectorInsertDynamic"; + case Op::VectorShuffle: + return "VectorShuffle"; + case Op::CompositeConstruct: + return "CompositeConstruct"; + case Op::CompositeExtract: + return "CompositeExtract"; + case Op::CompositeInsert: + return "CompositeInsert"; + case Op::CopyObject: + return "CopyObject"; + case Op::Transpose: + return "Transpose"; + case Op::SampledImage: + return "SampledImage"; + case Op::ImageSampleImplicitLod: + return "ImageSampleImplicitLod"; + case Op::ImageSampleExplicitLod: + return "ImageSampleExplicitLod"; + case Op::ImageSampleDrefImplicitLod: + return "ImageSampleDrefImplicitLod"; + case Op::ImageSampleDrefExplicitLod: + return "ImageSampleDrefExplicitLod"; + case Op::ImageSampleProjImplicitLod: + return "ImageSampleProjImplicitLod"; + case Op::ImageSampleProjExplicitLod: + return "ImageSampleProjExplicitLod"; + case Op::ImageSampleProjDrefImplicitLod: + return "ImageSampleProjDrefImplicitLod"; + case Op::ImageSampleProjDrefExplicitLod: + return "ImageSampleProjDrefExplicitLod"; + case Op::ImageFetch: + return "ImageFetch"; + case Op::ImageGather: + return "ImageGather"; + case Op::ImageDrefGather: + return "ImageDrefGather"; + case Op::ImageRead: + return "ImageRead"; + case Op::ImageWrite: + return "ImageWrite"; + case Op::Image: + return "Image"; + case Op::ImageQueryFormat: + return "ImageQueryFormat"; + case Op::ImageQueryOrder: + return "ImageQueryOrder"; + case Op::ImageQuerySizeLod: + return "ImageQuerySizeLod"; + case Op::ImageQuerySize: + return "ImageQuerySize"; + case Op::ImageQueryLod: + return "ImageQueryLod"; + case Op::ImageQueryLevels: + return "ImageQueryLevels"; + case Op::ImageQuerySamples: + return "ImageQuerySamples"; + case Op::ConvertFToU: + return "ConvertFToU"; + case Op::ConvertFToS: + return "ConvertFToS"; + case Op::ConvertSToF: + return "ConvertSToF"; + case Op::ConvertUToF: + return "ConvertUToF"; + case Op::UConvert: + return "UConvert"; + case Op::SConvert: + return "SConvert"; + case Op::FConvert: + return "FConvert"; + case Op::QuantizeToF16: + return "QuantizeToF16"; + case Op::ConvertPtrToU: + return "ConvertPtrToU"; + case Op::SatConvertSToU: + return "SatConvertSToU"; + case Op::SatConvertUToS: + return "SatConvertUToS"; + case Op::ConvertUToPtr: + return "ConvertUToPtr"; + case Op::PtrCastToGeneric: + return "PtrCastToGeneric"; + case Op::GenericCastToPtr: + return "GenericCastToPtr"; + case Op::GenericCastToPtrExplicit: + return "GenericCastToPtrExplicit"; + case Op::Bitcast: + return "Bitcast"; + case Op::SNegate: + return "SNegate"; + case Op::FNegate: + return "FNegate"; + case Op::IAdd: + return "IAdd"; + case Op::FAdd: + return "FAdd"; + case Op::ISub: + return "ISub"; + case Op::FSub: + return "FSub"; + case Op::IMul: + return "IMul"; + case Op::FMul: + return "FMul"; + case Op::UDiv: + return "UDiv"; + case Op::SDiv: + return "SDiv"; + case Op::FDiv: + return "FDiv"; + case Op::UMod: + return "UMod"; + case Op::SRem: + return "SRem"; + case Op::SMod: + return "SMod"; + case Op::FRem: + return "FRem"; + case Op::FMod: + return "FMod"; + case Op::VectorTimesScalar: + return "VectorTimesScalar"; + case Op::MatrixTimesScalar: + return "MatrixTimesScalar"; + case Op::VectorTimesMatrix: + return "VectorTimesMatrix"; + case Op::MatrixTimesVector: + return "MatrixTimesVector"; + case Op::MatrixTimesMatrix: + return "MatrixTimesMatrix"; + case Op::OuterProduct: + return "OuterProduct"; + case Op::Dot: + return "Dot"; + case Op::IAddCarry: + return "IAddCarry"; + case Op::ISubBorrow: + return "ISubBorrow"; + case Op::UMulExtended: + return "UMulExtended"; + case Op::SMulExtended: + return "SMulExtended"; + case Op::Any: + return "Any"; + case Op::All: + return "All"; + case Op::IsNan: + return "IsNan"; + case Op::IsInf: + return "IsInf"; + case Op::IsFinite: + return "IsFinite"; + case Op::IsNormal: + return "IsNormal"; + case Op::SignBitSet: + return "SignBitSet"; + case Op::LessOrGreater: + return "LessOrGreater"; + case Op::Ordered: + return "Ordered"; + case Op::Unordered: + return "Unordered"; + case Op::LogicalEqual: + return "LogicalEqual"; + case Op::LogicalNotEqual: + return "LogicalNotEqual"; + case Op::LogicalOr: + return "LogicalOr"; + case Op::LogicalAnd: + return "LogicalAnd"; + case Op::LogicalNot: + return "LogicalNot"; + case Op::Select: + return "Select"; + case Op::IEqual: + return "IEqual"; + case Op::INotEqual: + return "INotEqual"; + case Op::UGreaterThan: + return "UGreaterThan"; + case Op::SGreaterThan: + return "SGreaterThan"; + case Op::UGreaterThanEqual: + return "UGreaterThanEqual"; + case Op::SGreaterThanEqual: + return "SGreaterThanEqual"; + case Op::ULessThan: + return "ULessThan"; + case Op::SLessThan: + return "SLessThan"; + case Op::ULessThanEqual: + return "ULessThanEqual"; + case Op::SLessThanEqual: + return "SLessThanEqual"; + case Op::FOrdEqual: + return "FOrdEqual"; + case Op::FUnordEqual: + return "FUnordEqual"; + case Op::FOrdNotEqual: + return "FOrdNotEqual"; + case Op::FUnordNotEqual: + return "FUnordNotEqual"; + case Op::FOrdLessThan: + return "FOrdLessThan"; + case Op::FUnordLessThan: + return "FUnordLessThan"; + case Op::FOrdGreaterThan: + return "FOrdGreaterThan"; + case Op::FUnordGreaterThan: + return "FUnordGreaterThan"; + case Op::FOrdLessThanEqual: + return "FOrdLessThanEqual"; + case Op::FUnordLessThanEqual: + return "FUnordLessThanEqual"; + case Op::FOrdGreaterThanEqual: + return "FOrdGreaterThanEqual"; + case Op::FUnordGreaterThanEqual: + return "FUnordGreaterThanEqual"; + case Op::ShiftRightLogical: + return "ShiftRightLogical"; + case Op::ShiftRightArithmetic: + return "ShiftRightArithmetic"; + case Op::ShiftLeftLogical: + return "ShiftLeftLogical"; + case Op::BitwiseOr: + return "BitwiseOr"; + case Op::BitwiseXor: + return "BitwiseXor"; + case Op::BitwiseAnd: + return "BitwiseAnd"; + case Op::Not: + return "Not"; + case Op::BitFieldInsert: + return "BitFieldInsert"; + case Op::BitFieldSExtract: + return "BitFieldSExtract"; + case Op::BitFieldUExtract: + return "BitFieldUExtract"; + case Op::BitReverse: + return "BitReverse"; + case Op::BitCount: + return "BitCount"; + case Op::DPdx: + return "DPdx"; + case Op::DPdy: + return "DPdy"; + case Op::Fwidth: + return "Fwidth"; + case Op::DPdxFine: + return "DPdxFine"; + case Op::DPdyFine: + return "DPdyFine"; + case Op::FwidthFine: + return "FwidthFine"; + case Op::DPdxCoarse: + return "DPdxCoarse"; + case Op::DPdyCoarse: + return "DPdyCoarse"; + case Op::FwidthCoarse: + return "FwidthCoarse"; + case Op::EmitVertex: + return "EmitVertex"; + case Op::EndPrimitive: + return "EndPrimitive"; + case Op::EmitStreamVertex: + return "EmitStreamVertex"; + case Op::EndStreamPrimitive: + return "EndStreamPrimitive"; + case Op::ControlBarrier: + return "ControlBarrier"; + case Op::MemoryBarrier: + return "MemoryBarrier"; + case Op::AtomicLoad: + return "AtomicLoad"; + case Op::AtomicStore: + return "AtomicStore"; + case Op::AtomicExchange: + return "AtomicExchange"; + case Op::AtomicCompareExchange: + return "AtomicCompareExchange"; + case Op::AtomicCompareExchangeWeak: + return "AtomicCompareExchangeWeak"; + case Op::AtomicIIncrement: + return "AtomicIIncrement"; + case Op::AtomicIDecrement: + return "AtomicIDecrement"; + case Op::AtomicIAdd: + return "AtomicIAdd"; + case Op::AtomicISub: + return "AtomicISub"; + case Op::AtomicSMin: + return "AtomicSMin"; + case Op::AtomicUMin: + return "AtomicUMin"; + case Op::AtomicSMax: + return "AtomicSMax"; + case Op::AtomicUMax: + return "AtomicUMax"; + case Op::AtomicAnd: + return "AtomicAnd"; + case Op::AtomicOr: + return "AtomicOr"; + case Op::AtomicXor: + return "AtomicXor"; + case Op::Phi: + return "Phi"; + case Op::LoopMerge: + return "LoopMerge"; + case Op::SelectionMerge: + return "SelectionMerge"; + case Op::Label: + return "Label"; + case Op::Branch: + return "Branch"; + case Op::BranchConditional: + return "BranchConditional"; + case Op::Switch: + return "Switch"; + case Op::Kill: + return "Kill"; + case Op::Return: + return "Return"; + case Op::ReturnValue: + return "ReturnValue"; + case Op::Unreachable: + return "Unreachable"; + case Op::LifetimeStart: + return "LifetimeStart"; + case Op::LifetimeStop: + return "LifetimeStop"; + case Op::GroupAsyncCopy: + return "GroupAsyncCopy"; + case Op::GroupWaitEvents: + return "GroupWaitEvents"; + case Op::GroupAll: + return "GroupAll"; + case Op::GroupAny: + return "GroupAny"; + case Op::GroupBroadcast: + return "GroupBroadcast"; + case Op::GroupIAdd: + return "GroupIAdd"; + case Op::GroupFAdd: + return "GroupFAdd"; + case Op::GroupFMin: + return "GroupFMin"; + case Op::GroupUMin: + return "GroupUMin"; + case Op::GroupSMin: + return "GroupSMin"; + case Op::GroupFMax: + return "GroupFMax"; + case Op::GroupUMax: + return "GroupUMax"; + case Op::GroupSMax: + return "GroupSMax"; + case Op::ReadPipe: + return "ReadPipe"; + case Op::WritePipe: + return "WritePipe"; + case Op::ReservedReadPipe: + return "ReservedReadPipe"; + case Op::ReservedWritePipe: + return "ReservedWritePipe"; + case Op::ReserveReadPipePackets: + return "ReserveReadPipePackets"; + case Op::ReserveWritePipePackets: + return "ReserveWritePipePackets"; + case Op::CommitReadPipe: + return "CommitReadPipe"; + case Op::CommitWritePipe: + return "CommitWritePipe"; + case Op::IsValidReserveId: + return "IsValidReserveId"; + case Op::GetNumPipePackets: + return "GetNumPipePackets"; + case Op::GetMaxPipePackets: + return "GetMaxPipePackets"; + case Op::GroupReserveReadPipePackets: + return "GroupReserveReadPipePackets"; + case Op::GroupReserveWritePipePackets: + return "GroupReserveWritePipePackets"; + case Op::GroupCommitReadPipe: + return "GroupCommitReadPipe"; + case Op::GroupCommitWritePipe: + return "GroupCommitWritePipe"; + case Op::EnqueueMarker: + return "EnqueueMarker"; + case Op::EnqueueKernel: + return "EnqueueKernel"; + case Op::GetKernelNDrangeSubGroupCount: + return "GetKernelNDrangeSubGroupCount"; + case Op::GetKernelNDrangeMaxSubGroupSize: + return "GetKernelNDrangeMaxSubGroupSize"; + case Op::GetKernelWorkGroupSize: + return "GetKernelWorkGroupSize"; + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return "GetKernelPreferredWorkGroupSizeMultiple"; + case Op::RetainEvent: + return "RetainEvent"; + case Op::ReleaseEvent: + return "ReleaseEvent"; + case Op::CreateUserEvent: + return "CreateUserEvent"; + case Op::IsValidEvent: + return "IsValidEvent"; + case Op::SetUserEventStatus: + return "SetUserEventStatus"; + case Op::CaptureEventProfilingInfo: + return "CaptureEventProfilingInfo"; + case Op::GetDefaultQueue: + return "GetDefaultQueue"; + case Op::BuildNDRange: + return "BuildNDRange"; + case Op::ImageSparseSampleImplicitLod: + return "ImageSparseSampleImplicitLod"; + case Op::ImageSparseSampleExplicitLod: + return "ImageSparseSampleExplicitLod"; + case Op::ImageSparseSampleDrefImplicitLod: + return "ImageSparseSampleDrefImplicitLod"; + case Op::ImageSparseSampleDrefExplicitLod: + return "ImageSparseSampleDrefExplicitLod"; + case Op::ImageSparseSampleProjImplicitLod: + return "ImageSparseSampleProjImplicitLod"; + case Op::ImageSparseSampleProjExplicitLod: + return "ImageSparseSampleProjExplicitLod"; + case Op::ImageSparseSampleProjDrefImplicitLod: + return "ImageSparseSampleProjDrefImplicitLod"; + case Op::ImageSparseSampleProjDrefExplicitLod: + return "ImageSparseSampleProjDrefExplicitLod"; + case Op::ImageSparseFetch: + return "ImageSparseFetch"; + case Op::ImageSparseGather: + return "ImageSparseGather"; + case Op::ImageSparseDrefGather: + return "ImageSparseDrefGather"; + case Op::ImageSparseTexelsResident: + return "ImageSparseTexelsResident"; + case Op::NoLine: + return "NoLine"; + case Op::AtomicFlagTestAndSet: + return "AtomicFlagTestAndSet"; + case Op::AtomicFlagClear: + return "AtomicFlagClear"; + case Op::ImageSparseRead: + return "ImageSparseRead"; + case Op::SizeOf: + return "SizeOf"; + case Op::TypePipeStorage: + return "TypePipeStorage"; + case Op::ConstantPipeStorage: + return "ConstantPipeStorage"; + case Op::CreatePipeFromPipeStorage: + return "CreatePipeFromPipeStorage"; + case Op::GetKernelLocalSizeForSubgroupCount: + return "GetKernelLocalSizeForSubgroupCount"; + case Op::GetKernelMaxNumSubgroups: + return "GetKernelMaxNumSubgroups"; + case Op::TypeNamedBarrier: + return "TypeNamedBarrier"; + case Op::NamedBarrierInitialize: + return "NamedBarrierInitialize"; + case Op::MemoryNamedBarrier: + return "MemoryNamedBarrier"; + case Op::ModuleProcessed: + return "ModuleProcessed"; + case Op::ExecutionModeId: + return "ExecutionModeId"; + case Op::DecorateId: + return "DecorateId"; + case Op::GroupNonUniformElect: + return "GroupNonUniformElect"; + case Op::GroupNonUniformAll: + return "GroupNonUniformAll"; + case Op::GroupNonUniformAny: + return "GroupNonUniformAny"; + case Op::GroupNonUniformAllEqual: + return "GroupNonUniformAllEqual"; + case Op::GroupNonUniformBroadcast: + return "GroupNonUniformBroadcast"; + case Op::GroupNonUniformBroadcastFirst: + return "GroupNonUniformBroadcastFirst"; + case Op::GroupNonUniformBallot: + return "GroupNonUniformBallot"; + case Op::GroupNonUniformInverseBallot: + return "GroupNonUniformInverseBallot"; + case Op::GroupNonUniformBallotBitExtract: + return "GroupNonUniformBallotBitExtract"; + case Op::GroupNonUniformBallotBitCount: + return "GroupNonUniformBallotBitCount"; + case Op::GroupNonUniformBallotFindLSB: + return "GroupNonUniformBallotFindLSB"; + case Op::GroupNonUniformBallotFindMSB: + return "GroupNonUniformBallotFindMSB"; + case Op::GroupNonUniformShuffle: + return "GroupNonUniformShuffle"; + case Op::GroupNonUniformShuffleXor: + return "GroupNonUniformShuffleXor"; + case Op::GroupNonUniformShuffleUp: + return "GroupNonUniformShuffleUp"; + case Op::GroupNonUniformShuffleDown: + return "GroupNonUniformShuffleDown"; + case Op::GroupNonUniformIAdd: + return "GroupNonUniformIAdd"; + case Op::GroupNonUniformFAdd: + return "GroupNonUniformFAdd"; + case Op::GroupNonUniformIMul: + return "GroupNonUniformIMul"; + case Op::GroupNonUniformFMul: + return "GroupNonUniformFMul"; + case Op::GroupNonUniformSMin: + return "GroupNonUniformSMin"; + case Op::GroupNonUniformUMin: + return "GroupNonUniformUMin"; + case Op::GroupNonUniformFMin: + return "GroupNonUniformFMin"; + case Op::GroupNonUniformSMax: + return "GroupNonUniformSMax"; + case Op::GroupNonUniformUMax: + return "GroupNonUniformUMax"; + case Op::GroupNonUniformFMax: + return "GroupNonUniformFMax"; + case Op::GroupNonUniformBitwiseAnd: + return "GroupNonUniformBitwiseAnd"; + case Op::GroupNonUniformBitwiseOr: + return "GroupNonUniformBitwiseOr"; + case Op::GroupNonUniformBitwiseXor: + return "GroupNonUniformBitwiseXor"; + case Op::GroupNonUniformLogicalAnd: + return "GroupNonUniformLogicalAnd"; + case Op::GroupNonUniformLogicalOr: + return "GroupNonUniformLogicalOr"; + case Op::GroupNonUniformLogicalXor: + return "GroupNonUniformLogicalXor"; + case Op::GroupNonUniformQuadBroadcast: + return "GroupNonUniformQuadBroadcast"; + case Op::GroupNonUniformQuadSwap: + return "GroupNonUniformQuadSwap"; + case Op::CopyLogical: + return "CopyLogical"; + case Op::PtrEqual: + return "PtrEqual"; + case Op::PtrNotEqual: + return "PtrNotEqual"; + case Op::PtrDiff: + return "PtrDiff"; + case Op::TypeCooperativeMatrixKHR: + return "TypeCooperativeMatrixKHR"; + case Op::CooperativeMatrixLoadKHR: + return "CooperativeMatrixLoadKHR"; + case Op::CooperativeMatrixStoreKHR: + return "CooperativeMatrixStoreKHR"; + case Op::CooperativeMatrixMulAddKHR: + return "CooperativeMatrixMulAddKHR"; + case Op::CooperativeMatrixLengthKHR: + return "CooperativeMatrixLengthKHR"; + case Op::SubgroupBlockReadINTEL: + return "SubgroupBlockReadINTEL"; + case Op::SubgroupBlockWriteINTEL: + return "SubgroupBlockWriteINTEL"; + case Op::AsmTargetINTEL: + return "AsmTargetINTEL"; + case Op::AsmINTEL: + return "AsmINTEL"; + case Op::AsmCallINTEL: + return "AsmCallINTEL"; + case Op::AtomicFMinEXT: + return "AtomicFMinEXT"; + case Op::AtomicFMaxEXT: + return "AtomicFMaxEXT"; + case Op::AtomicFAddEXT: + return "AtomicFAddEXT"; + case Op::ConvertFToBF16INTEL: + return "ConvertFToBF16INTEL"; + case Op::ConvertBF16ToFINTEL: + return "ConvertBF16ToFINTEL"; + case Op::ControlBarrierArriveINTEL: + return "ControlBarrierArriveINTEL"; + case Op::ControlBarrierWaitINTEL: + return "ControlBarrierWaitINTEL"; + case Op::CooperativeMatrixLoadCheckedINTEL: + return "CooperativeMatrixLoadCheckedINTEL"; + case Op::CooperativeMatrixStoreCheckedINTEL: + return "CooperativeMatrixStoreCheckedINTEL"; + } + return "unknown"; +} +auto to_string(ImageOperands e) -> char const * { + switch (e) { + case ImageOperands::None: + return "None"; + case ImageOperands::Bias: + return "Bias"; + case ImageOperands::Lod: + return "Lod"; + case ImageOperands::Grad: + return "Grad"; + case ImageOperands::ConstOffset: + return "ConstOffset"; + case ImageOperands::Offset: + return "Offset"; + case ImageOperands::ConstOffsets: + return "ConstOffsets"; + case ImageOperands::Sample: + return "Sample"; + case ImageOperands::MinLod: + return "MinLod"; + case ImageOperands::MakeTexelAvailable: + return "MakeTexelAvailable"; + case ImageOperands::MakeTexelVisible: + return "MakeTexelVisible"; + case ImageOperands::NonPrivateTexel: + return "NonPrivateTexel"; + case ImageOperands::VolatileTexel: + return "VolatileTexel"; + case ImageOperands::SignExtend: + return "SignExtend"; + case ImageOperands::ZeroExtend: + return "ZeroExtend"; + case ImageOperands::Nontemporal: + return "Nontemporal"; + case ImageOperands::Offsets: + return "Offsets"; + } + return "unknown"; +} +auto to_string(FPFastMathMode e) -> char const * { + switch (e) { + case FPFastMathMode::None: + return "None"; + case FPFastMathMode::NotNaN: + return "NotNaN"; + case FPFastMathMode::NotInf: + return "NotInf"; + case FPFastMathMode::NSZ: + return "NSZ"; + case FPFastMathMode::AllowRecip: + return "AllowRecip"; + case FPFastMathMode::Fast: + return "Fast"; + case FPFastMathMode::AllowContract: + return "AllowContract"; + case FPFastMathMode::AllowReassoc: + return "AllowReassoc"; + case FPFastMathMode::AllowTransform: + return "AllowTransform"; + } + return "unknown"; +} +auto to_string(SelectionControl e) -> char const * { + switch (e) { + case SelectionControl::None: + return "None"; + case SelectionControl::Flatten: + return "Flatten"; + case SelectionControl::DontFlatten: + return "DontFlatten"; + } + return "unknown"; +} +auto to_string(LoopControl e) -> char const * { + switch (e) { + case LoopControl::None: + return "None"; + case LoopControl::Unroll: + return "Unroll"; + case LoopControl::DontUnroll: + return "DontUnroll"; + case LoopControl::DependencyInfinite: + return "DependencyInfinite"; + case LoopControl::DependencyLength: + return "DependencyLength"; + case LoopControl::MinIterations: + return "MinIterations"; + case LoopControl::MaxIterations: + return "MaxIterations"; + case LoopControl::IterationMultiple: + return "IterationMultiple"; + case LoopControl::PeelCount: + return "PeelCount"; + case LoopControl::PartialCount: + return "PartialCount"; + case LoopControl::InitiationIntervalINTEL: + return "InitiationIntervalINTEL"; + case LoopControl::MaxConcurrencyINTEL: + return "MaxConcurrencyINTEL"; + case LoopControl::DependencyArrayINTEL: + return "DependencyArrayINTEL"; + case LoopControl::PipelineEnableINTEL: + return "PipelineEnableINTEL"; + case LoopControl::LoopCoalesceINTEL: + return "LoopCoalesceINTEL"; + case LoopControl::MaxInterleavingINTEL: + return "MaxInterleavingINTEL"; + case LoopControl::SpeculatedIterationsINTEL: + return "SpeculatedIterationsINTEL"; + case LoopControl::NoFusionINTEL: + return "NoFusionINTEL"; + case LoopControl::LoopCountINTEL: + return "LoopCountINTEL"; + case LoopControl::MaxReinvocationDelayINTEL: + return "MaxReinvocationDelayINTEL"; + } + return "unknown"; +} +auto to_string(FunctionControl e) -> char const * { + switch (e) { + case FunctionControl::None: + return "None"; + case FunctionControl::Inline: + return "Inline"; + case FunctionControl::DontInline: + return "DontInline"; + case FunctionControl::Pure: + return "Pure"; + case FunctionControl::Const: + return "Const"; + case FunctionControl::OptNoneEXT: + return "OptNoneEXT"; + } + return "unknown"; +} +auto to_string(MemorySemantics e) -> char const * { + switch (e) { + case MemorySemantics::Relaxed: + return "Relaxed"; + case MemorySemantics::Acquire: + return "Acquire"; + case MemorySemantics::Release: + return "Release"; + case MemorySemantics::AcquireRelease: + return "AcquireRelease"; + case MemorySemantics::SequentiallyConsistent: + return "SequentiallyConsistent"; + case MemorySemantics::UniformMemory: + return "UniformMemory"; + case MemorySemantics::SubgroupMemory: + return "SubgroupMemory"; + case MemorySemantics::WorkgroupMemory: + return "WorkgroupMemory"; + case MemorySemantics::CrossWorkgroupMemory: + return "CrossWorkgroupMemory"; + case MemorySemantics::AtomicCounterMemory: + return "AtomicCounterMemory"; + case MemorySemantics::ImageMemory: + return "ImageMemory"; + case MemorySemantics::OutputMemory: + return "OutputMemory"; + case MemorySemantics::MakeAvailable: + return "MakeAvailable"; + case MemorySemantics::MakeVisible: + return "MakeVisible"; + case MemorySemantics::Volatile: + return "Volatile"; + } + return "unknown"; +} +auto to_string(MemoryAccess e) -> char const * { + switch (e) { + case MemoryAccess::None: + return "None"; + case MemoryAccess::Volatile: + return "Volatile"; + case MemoryAccess::Aligned: + return "Aligned"; + case MemoryAccess::Nontemporal: + return "Nontemporal"; + case MemoryAccess::MakePointerAvailable: + return "MakePointerAvailable"; + case MemoryAccess::MakePointerVisible: + return "MakePointerVisible"; + case MemoryAccess::NonPrivatePointer: + return "NonPrivatePointer"; + case MemoryAccess::AliasScopeINTELMask: + return "AliasScopeINTELMask"; + case MemoryAccess::NoAliasINTELMask: + return "NoAliasINTELMask"; + } + return "unknown"; +} +auto to_string(KernelProfilingInfo e) -> char const * { + switch (e) { + case KernelProfilingInfo::None: + return "None"; + case KernelProfilingInfo::CmdExecTime: + return "CmdExecTime"; + } + return "unknown"; +} +auto to_string(RayFlags e) -> char const * { + switch (e) { + case RayFlags::NoneKHR: + return "NoneKHR"; + case RayFlags::OpaqueKHR: + return "OpaqueKHR"; + case RayFlags::NoOpaqueKHR: + return "NoOpaqueKHR"; + case RayFlags::TerminateOnFirstHitKHR: + return "TerminateOnFirstHitKHR"; + case RayFlags::SkipClosestHitShaderKHR: + return "SkipClosestHitShaderKHR"; + case RayFlags::CullBackFacingTrianglesKHR: + return "CullBackFacingTrianglesKHR"; + case RayFlags::CullFrontFacingTrianglesKHR: + return "CullFrontFacingTrianglesKHR"; + case RayFlags::CullOpaqueKHR: + return "CullOpaqueKHR"; + case RayFlags::CullNoOpaqueKHR: + return "CullNoOpaqueKHR"; + case RayFlags::SkipTrianglesKHR: + return "SkipTrianglesKHR"; + case RayFlags::SkipAABBsKHR: + return "SkipAABBsKHR"; + case RayFlags::ForceOpacityMicromap2StateEXT: + return "ForceOpacityMicromap2StateEXT"; + } + return "unknown"; +} +auto to_string(FragmentShadingRate e) -> char const * { + switch (e) { + case FragmentShadingRate::Vertical2Pixels: + return "Vertical2Pixels"; + case FragmentShadingRate::Vertical4Pixels: + return "Vertical4Pixels"; + case FragmentShadingRate::Horizontal2Pixels: + return "Horizontal2Pixels"; + case FragmentShadingRate::Horizontal4Pixels: + return "Horizontal4Pixels"; + } + return "unknown"; +} +auto to_string(RawAccessChainOperands e) -> char const * { + switch (e) { + case RawAccessChainOperands::None: + return "None"; + case RawAccessChainOperands::RobustnessPerComponentNV: + return "RobustnessPerComponentNV"; + case RawAccessChainOperands::RobustnessPerElementNV: + return "RobustnessPerElementNV"; + } + return "unknown"; +} +auto to_string(SourceLanguage e) -> char const * { + switch (e) { + case SourceLanguage::Unknown: + return "Unknown"; + case SourceLanguage::ESSL: + return "ESSL"; + case SourceLanguage::GLSL: + return "GLSL"; + case SourceLanguage::OpenCL_C: + return "OpenCL_C"; + case SourceLanguage::OpenCL_CPP: + return "OpenCL_CPP"; + case SourceLanguage::HLSL: + return "HLSL"; + case SourceLanguage::CPP_for_OpenCL: + return "CPP_for_OpenCL"; + case SourceLanguage::SYCL: + return "SYCL"; + case SourceLanguage::HERO_C: + return "HERO_C"; + case SourceLanguage::NZSL: + return "NZSL"; + case SourceLanguage::WGSL: + return "WGSL"; + case SourceLanguage::Slang: + return "Slang"; + case SourceLanguage::Zig: + return "Zig"; + case SourceLanguage::Rust: + return "Rust"; + } + return "unknown"; +} +auto to_string(ExecutionModel e) -> char const * { + switch (e) { + case ExecutionModel::Vertex: + return "Vertex"; + case ExecutionModel::TessellationControl: + return "TessellationControl"; + case ExecutionModel::TessellationEvaluation: + return "TessellationEvaluation"; + case ExecutionModel::Geometry: + return "Geometry"; + case ExecutionModel::Fragment: + return "Fragment"; + case ExecutionModel::GLCompute: + return "GLCompute"; + case ExecutionModel::Kernel: + return "Kernel"; + case ExecutionModel::TaskNV: + return "TaskNV"; + case ExecutionModel::MeshNV: + return "MeshNV"; + case ExecutionModel::RayGenerationKHR: + return "RayGenerationKHR"; + case ExecutionModel::IntersectionKHR: + return "IntersectionKHR"; + case ExecutionModel::AnyHitKHR: + return "AnyHitKHR"; + case ExecutionModel::ClosestHitKHR: + return "ClosestHitKHR"; + case ExecutionModel::MissKHR: + return "MissKHR"; + case ExecutionModel::CallableKHR: + return "CallableKHR"; + case ExecutionModel::TaskEXT: + return "TaskEXT"; + case ExecutionModel::MeshEXT: + return "MeshEXT"; + } + return "unknown"; +} +auto to_string(AddressingModel e) -> char const * { + switch (e) { + case AddressingModel::Logical: + return "Logical"; + case AddressingModel::Physical32: + return "Physical32"; + case AddressingModel::Physical64: + return "Physical64"; + case AddressingModel::PhysicalStorageBuffer64: + return "PhysicalStorageBuffer64"; + } + return "unknown"; +} +auto to_string(MemoryModel e) -> char const * { + switch (e) { + case MemoryModel::Simple: + return "Simple"; + case MemoryModel::GLSL450: + return "GLSL450"; + case MemoryModel::OpenCL: + return "OpenCL"; + case MemoryModel::Vulkan: + return "Vulkan"; + } + return "unknown"; +} +auto to_string(ExecutionMode e) -> char const * { + switch (e) { + case ExecutionMode::Invocations: + return "Invocations"; + case ExecutionMode::SpacingEqual: + return "SpacingEqual"; + case ExecutionMode::SpacingFractionalEven: + return "SpacingFractionalEven"; + case ExecutionMode::SpacingFractionalOdd: + return "SpacingFractionalOdd"; + case ExecutionMode::VertexOrderCw: + return "VertexOrderCw"; + case ExecutionMode::VertexOrderCcw: + return "VertexOrderCcw"; + case ExecutionMode::PixelCenterInteger: + return "PixelCenterInteger"; + case ExecutionMode::OriginUpperLeft: + return "OriginUpperLeft"; + case ExecutionMode::OriginLowerLeft: + return "OriginLowerLeft"; + case ExecutionMode::EarlyFragmentTests: + return "EarlyFragmentTests"; + case ExecutionMode::PointMode: + return "PointMode"; + case ExecutionMode::Xfb: + return "Xfb"; + case ExecutionMode::DepthReplacing: + return "DepthReplacing"; + case ExecutionMode::DepthGreater: + return "DepthGreater"; + case ExecutionMode::DepthLess: + return "DepthLess"; + case ExecutionMode::DepthUnchanged: + return "DepthUnchanged"; + case ExecutionMode::LocalSize: + return "LocalSize"; + case ExecutionMode::LocalSizeHint: + return "LocalSizeHint"; + case ExecutionMode::InputPoints: + return "InputPoints"; + case ExecutionMode::InputLines: + return "InputLines"; + case ExecutionMode::InputLinesAdjacency: + return "InputLinesAdjacency"; + case ExecutionMode::Triangles: + return "Triangles"; + case ExecutionMode::InputTrianglesAdjacency: + return "InputTrianglesAdjacency"; + case ExecutionMode::Quads: + return "Quads"; + case ExecutionMode::Isolines: + return "Isolines"; + case ExecutionMode::OutputVertices: + return "OutputVertices"; + case ExecutionMode::OutputPoints: + return "OutputPoints"; + case ExecutionMode::OutputLineStrip: + return "OutputLineStrip"; + case ExecutionMode::OutputTriangleStrip: + return "OutputTriangleStrip"; + case ExecutionMode::VecTypeHint: + return "VecTypeHint"; + case ExecutionMode::ContractionOff: + return "ContractionOff"; + case ExecutionMode::Initializer: + return "Initializer"; + case ExecutionMode::Finalizer: + return "Finalizer"; + case ExecutionMode::SubgroupSize: + return "SubgroupSize"; + case ExecutionMode::SubgroupsPerWorkgroup: + return "SubgroupsPerWorkgroup"; + case ExecutionMode::SubgroupsPerWorkgroupId: + return "SubgroupsPerWorkgroupId"; + case ExecutionMode::LocalSizeId: + return "LocalSizeId"; + case ExecutionMode::LocalSizeHintId: + return "LocalSizeHintId"; + case ExecutionMode::NonCoherentColorAttachmentReadEXT: + return "NonCoherentColorAttachmentReadEXT"; + case ExecutionMode::NonCoherentDepthAttachmentReadEXT: + return "NonCoherentDepthAttachmentReadEXT"; + case ExecutionMode::NonCoherentStencilAttachmentReadEXT: + return "NonCoherentStencilAttachmentReadEXT"; + case ExecutionMode::SubgroupUniformControlFlowKHR: + return "SubgroupUniformControlFlowKHR"; + case ExecutionMode::PostDepthCoverage: + return "PostDepthCoverage"; + case ExecutionMode::DenormPreserve: + return "DenormPreserve"; + case ExecutionMode::DenormFlushToZero: + return "DenormFlushToZero"; + case ExecutionMode::SignedZeroInfNanPreserve: + return "SignedZeroInfNanPreserve"; + case ExecutionMode::RoundingModeRTE: + return "RoundingModeRTE"; + case ExecutionMode::RoundingModeRTZ: + return "RoundingModeRTZ"; + case ExecutionMode::NonCoherentTileAttachmentReadQCOM: + return "NonCoherentTileAttachmentReadQCOM"; + case ExecutionMode::TileShadingRateQCOM: + return "TileShadingRateQCOM"; + case ExecutionMode::EarlyAndLateFragmentTestsAMD: + return "EarlyAndLateFragmentTestsAMD"; + case ExecutionMode::StencilRefReplacingEXT: + return "StencilRefReplacingEXT"; + case ExecutionMode::CoalescingAMDX: + return "CoalescingAMDX"; + case ExecutionMode::IsApiEntryAMDX: + return "IsApiEntryAMDX"; + case ExecutionMode::MaxNodeRecursionAMDX: + return "MaxNodeRecursionAMDX"; + case ExecutionMode::StaticNumWorkgroupsAMDX: + return "StaticNumWorkgroupsAMDX"; + case ExecutionMode::ShaderIndexAMDX: + return "ShaderIndexAMDX"; + case ExecutionMode::MaxNumWorkgroupsAMDX: + return "MaxNumWorkgroupsAMDX"; + case ExecutionMode::StencilRefUnchangedFrontAMD: + return "StencilRefUnchangedFrontAMD"; + case ExecutionMode::StencilRefGreaterFrontAMD: + return "StencilRefGreaterFrontAMD"; + case ExecutionMode::StencilRefLessFrontAMD: + return "StencilRefLessFrontAMD"; + case ExecutionMode::StencilRefUnchangedBackAMD: + return "StencilRefUnchangedBackAMD"; + case ExecutionMode::StencilRefGreaterBackAMD: + return "StencilRefGreaterBackAMD"; + case ExecutionMode::StencilRefLessBackAMD: + return "StencilRefLessBackAMD"; + case ExecutionMode::QuadDerivativesKHR: + return "QuadDerivativesKHR"; + case ExecutionMode::RequireFullQuadsKHR: + return "RequireFullQuadsKHR"; + case ExecutionMode::SharesInputWithAMDX: + return "SharesInputWithAMDX"; + case ExecutionMode::OutputLinesEXT: + return "OutputLinesEXT"; + case ExecutionMode::OutputPrimitivesEXT: + return "OutputPrimitivesEXT"; + case ExecutionMode::DerivativeGroupQuadsKHR: + return "DerivativeGroupQuadsKHR"; + case ExecutionMode::DerivativeGroupLinearKHR: + return "DerivativeGroupLinearKHR"; + case ExecutionMode::OutputTrianglesEXT: + return "OutputTrianglesEXT"; + case ExecutionMode::PixelInterlockOrderedEXT: + return "PixelInterlockOrderedEXT"; + case ExecutionMode::PixelInterlockUnorderedEXT: + return "PixelInterlockUnorderedEXT"; + case ExecutionMode::SampleInterlockOrderedEXT: + return "SampleInterlockOrderedEXT"; + case ExecutionMode::SampleInterlockUnorderedEXT: + return "SampleInterlockUnorderedEXT"; + case ExecutionMode::ShadingRateInterlockOrderedEXT: + return "ShadingRateInterlockOrderedEXT"; + case ExecutionMode::ShadingRateInterlockUnorderedEXT: + return "ShadingRateInterlockUnorderedEXT"; + case ExecutionMode::SharedLocalMemorySizeINTEL: + return "SharedLocalMemorySizeINTEL"; + case ExecutionMode::RoundingModeRTPINTEL: + return "RoundingModeRTPINTEL"; + case ExecutionMode::RoundingModeRTNINTEL: + return "RoundingModeRTNINTEL"; + case ExecutionMode::FloatingPointModeALTINTEL: + return "FloatingPointModeALTINTEL"; + case ExecutionMode::FloatingPointModeIEEEINTEL: + return "FloatingPointModeIEEEINTEL"; + case ExecutionMode::MaxWorkgroupSizeINTEL: + return "MaxWorkgroupSizeINTEL"; + case ExecutionMode::MaxWorkDimINTEL: + return "MaxWorkDimINTEL"; + case ExecutionMode::NoGlobalOffsetINTEL: + return "NoGlobalOffsetINTEL"; + case ExecutionMode::NumSIMDWorkitemsINTEL: + return "NumSIMDWorkitemsINTEL"; + case ExecutionMode::SchedulerTargetFmaxMhzINTEL: + return "SchedulerTargetFmaxMhzINTEL"; + case ExecutionMode::MaximallyReconvergesKHR: + return "MaximallyReconvergesKHR"; + case ExecutionMode::FPFastMathDefault: + return "FPFastMathDefault"; + case ExecutionMode::StreamingInterfaceINTEL: + return "StreamingInterfaceINTEL"; + case ExecutionMode::RegisterMapInterfaceINTEL: + return "RegisterMapInterfaceINTEL"; + case ExecutionMode::NamedBarrierCountINTEL: + return "NamedBarrierCountINTEL"; + case ExecutionMode::MaximumRegistersINTEL: + return "MaximumRegistersINTEL"; + case ExecutionMode::MaximumRegistersIdINTEL: + return "MaximumRegistersIdINTEL"; + case ExecutionMode::NamedMaximumRegistersINTEL: + return "NamedMaximumRegistersINTEL"; + } + return "unknown"; +} +auto to_string(StorageClass e) -> char const * { + switch (e) { + case StorageClass::UniformConstant: + return "UniformConstant"; + case StorageClass::Input: + return "Input"; + case StorageClass::Uniform: + return "Uniform"; + case StorageClass::Output: + return "Output"; + case StorageClass::Workgroup: + return "Workgroup"; + case StorageClass::CrossWorkgroup: + return "CrossWorkgroup"; + case StorageClass::Private: + return "Private"; + case StorageClass::Function: + return "Function"; + case StorageClass::Generic: + return "Generic"; + case StorageClass::PushConstant: + return "PushConstant"; + case StorageClass::AtomicCounter: + return "AtomicCounter"; + case StorageClass::Image: + return "Image"; + case StorageClass::StorageBuffer: + return "StorageBuffer"; + case StorageClass::TileImageEXT: + return "TileImageEXT"; + case StorageClass::TileAttachmentQCOM: + return "TileAttachmentQCOM"; + case StorageClass::NodePayloadAMDX: + return "NodePayloadAMDX"; + case StorageClass::CallableDataKHR: + return "CallableDataKHR"; + case StorageClass::IncomingCallableDataKHR: + return "IncomingCallableDataKHR"; + case StorageClass::RayPayloadKHR: + return "RayPayloadKHR"; + case StorageClass::HitAttributeKHR: + return "HitAttributeKHR"; + case StorageClass::IncomingRayPayloadKHR: + return "IncomingRayPayloadKHR"; + case StorageClass::ShaderRecordBufferKHR: + return "ShaderRecordBufferKHR"; + case StorageClass::PhysicalStorageBuffer: + return "PhysicalStorageBuffer"; + case StorageClass::HitObjectAttributeNV: + return "HitObjectAttributeNV"; + case StorageClass::TaskPayloadWorkgroupEXT: + return "TaskPayloadWorkgroupEXT"; + case StorageClass::CodeSectionINTEL: + return "CodeSectionINTEL"; + case StorageClass::DeviceOnlyINTEL: + return "DeviceOnlyINTEL"; + case StorageClass::HostOnlyINTEL: + return "HostOnlyINTEL"; + } + return "unknown"; +} +auto to_string(Dim e) -> char const * { + switch (e) { + case Dim::Dim1D: + return "Dim1D"; + case Dim::Dim2D: + return "Dim2D"; + case Dim::Dim3D: + return "Dim3D"; + case Dim::Cube: + return "Cube"; + case Dim::Rect: + return "Rect"; + case Dim::Buffer: + return "Buffer"; + case Dim::SubpassData: + return "SubpassData"; + case Dim::TileImageDataEXT: + return "TileImageDataEXT"; + } + return "unknown"; +} +auto to_string(SamplerAddressingMode e) -> char const * { + switch (e) { + case SamplerAddressingMode::None: + return "None"; + case SamplerAddressingMode::ClampToEdge: + return "ClampToEdge"; + case SamplerAddressingMode::Clamp: + return "Clamp"; + case SamplerAddressingMode::Repeat: + return "Repeat"; + case SamplerAddressingMode::RepeatMirrored: + return "RepeatMirrored"; + } + return "unknown"; +} +auto to_string(SamplerFilterMode e) -> char const * { + switch (e) { + case SamplerFilterMode::Nearest: + return "Nearest"; + case SamplerFilterMode::Linear: + return "Linear"; + } + return "unknown"; +} +auto to_string(ImageFormat e) -> char const * { + switch (e) { + case ImageFormat::Unknown: + return "Unknown"; + case ImageFormat::Rgba32f: + return "Rgba32f"; + case ImageFormat::Rgba16f: + return "Rgba16f"; + case ImageFormat::R32f: + return "R32f"; + case ImageFormat::Rgba8: + return "Rgba8"; + case ImageFormat::Rgba8Snorm: + return "Rgba8Snorm"; + case ImageFormat::Rg32f: + return "Rg32f"; + case ImageFormat::Rg16f: + return "Rg16f"; + case ImageFormat::R11fG11fB10f: + return "R11fG11fB10f"; + case ImageFormat::R16f: + return "R16f"; + case ImageFormat::Rgba16: + return "Rgba16"; + case ImageFormat::Rgb10A2: + return "Rgb10A2"; + case ImageFormat::Rg16: + return "Rg16"; + case ImageFormat::Rg8: + return "Rg8"; + case ImageFormat::R16: + return "R16"; + case ImageFormat::R8: + return "R8"; + case ImageFormat::Rgba16Snorm: + return "Rgba16Snorm"; + case ImageFormat::Rg16Snorm: + return "Rg16Snorm"; + case ImageFormat::Rg8Snorm: + return "Rg8Snorm"; + case ImageFormat::R16Snorm: + return "R16Snorm"; + case ImageFormat::R8Snorm: + return "R8Snorm"; + case ImageFormat::Rgba32i: + return "Rgba32i"; + case ImageFormat::Rgba16i: + return "Rgba16i"; + case ImageFormat::Rgba8i: + return "Rgba8i"; + case ImageFormat::R32i: + return "R32i"; + case ImageFormat::Rg32i: + return "Rg32i"; + case ImageFormat::Rg16i: + return "Rg16i"; + case ImageFormat::Rg8i: + return "Rg8i"; + case ImageFormat::R16i: + return "R16i"; + case ImageFormat::R8i: + return "R8i"; + case ImageFormat::Rgba32ui: + return "Rgba32ui"; + case ImageFormat::Rgba16ui: + return "Rgba16ui"; + case ImageFormat::Rgba8ui: + return "Rgba8ui"; + case ImageFormat::R32ui: + return "R32ui"; + case ImageFormat::Rgb10a2ui: + return "Rgb10a2ui"; + case ImageFormat::Rg32ui: + return "Rg32ui"; + case ImageFormat::Rg16ui: + return "Rg16ui"; + case ImageFormat::Rg8ui: + return "Rg8ui"; + case ImageFormat::R16ui: + return "R16ui"; + case ImageFormat::R8ui: + return "R8ui"; + case ImageFormat::R64ui: + return "R64ui"; + case ImageFormat::R64i: + return "R64i"; + } + return "unknown"; +} +auto to_string(ImageChannelOrder e) -> char const * { + switch (e) { + case ImageChannelOrder::R: + return "R"; + case ImageChannelOrder::A: + return "A"; + case ImageChannelOrder::RG: + return "RG"; + case ImageChannelOrder::RA: + return "RA"; + case ImageChannelOrder::RGB: + return "RGB"; + case ImageChannelOrder::RGBA: + return "RGBA"; + case ImageChannelOrder::BGRA: + return "BGRA"; + case ImageChannelOrder::ARGB: + return "ARGB"; + case ImageChannelOrder::Intensity: + return "Intensity"; + case ImageChannelOrder::Luminance: + return "Luminance"; + case ImageChannelOrder::Rx: + return "Rx"; + case ImageChannelOrder::RGx: + return "RGx"; + case ImageChannelOrder::RGBx: + return "RGBx"; + case ImageChannelOrder::Depth: + return "Depth"; + case ImageChannelOrder::DepthStencil: + return "DepthStencil"; + case ImageChannelOrder::sRGB: + return "sRGB"; + case ImageChannelOrder::sRGBx: + return "sRGBx"; + case ImageChannelOrder::sRGBA: + return "sRGBA"; + case ImageChannelOrder::sBGRA: + return "sBGRA"; + case ImageChannelOrder::ABGR: + return "ABGR"; + } + return "unknown"; +} +auto to_string(ImageChannelDataType e) -> char const * { + switch (e) { + case ImageChannelDataType::SnormInt8: + return "SnormInt8"; + case ImageChannelDataType::SnormInt16: + return "SnormInt16"; + case ImageChannelDataType::UnormInt8: + return "UnormInt8"; + case ImageChannelDataType::UnormInt16: + return "UnormInt16"; + case ImageChannelDataType::UnormShort565: + return "UnormShort565"; + case ImageChannelDataType::UnormShort555: + return "UnormShort555"; + case ImageChannelDataType::UnormInt101010: + return "UnormInt101010"; + case ImageChannelDataType::SignedInt8: + return "SignedInt8"; + case ImageChannelDataType::SignedInt16: + return "SignedInt16"; + case ImageChannelDataType::SignedInt32: + return "SignedInt32"; + case ImageChannelDataType::UnsignedInt8: + return "UnsignedInt8"; + case ImageChannelDataType::UnsignedInt16: + return "UnsignedInt16"; + case ImageChannelDataType::UnsignedInt32: + return "UnsignedInt32"; + case ImageChannelDataType::HalfFloat: + return "HalfFloat"; + case ImageChannelDataType::Float: + return "Float"; + case ImageChannelDataType::UnormInt24: + return "UnormInt24"; + case ImageChannelDataType::UnormInt101010_2: + return "UnormInt101010_2"; + case ImageChannelDataType::UnormInt10X6EXT: + return "UnormInt10X6EXT"; + case ImageChannelDataType::UnsignedIntRaw10EXT: + return "UnsignedIntRaw10EXT"; + case ImageChannelDataType::UnsignedIntRaw12EXT: + return "UnsignedIntRaw12EXT"; + case ImageChannelDataType::UnormInt2_101010EXT: + return "UnormInt2_101010EXT"; + case ImageChannelDataType::UnsignedInt10X6EXT: + return "UnsignedInt10X6EXT"; + case ImageChannelDataType::UnsignedInt12X4EXT: + return "UnsignedInt12X4EXT"; + case ImageChannelDataType::UnsignedInt14X2EXT: + return "UnsignedInt14X2EXT"; + case ImageChannelDataType::UnormInt12X4EXT: + return "UnormInt12X4EXT"; + case ImageChannelDataType::UnormInt14X2EXT: + return "UnormInt14X2EXT"; + } + return "unknown"; +} +auto to_string(FPRoundingMode e) -> char const * { + switch (e) { + case FPRoundingMode::RTE: + return "RTE"; + case FPRoundingMode::RTZ: + return "RTZ"; + case FPRoundingMode::RTP: + return "RTP"; + case FPRoundingMode::RTN: + return "RTN"; + } + return "unknown"; +} +auto to_string(FPDenormMode e) -> char const * { + switch (e) { + case FPDenormMode::Preserve: + return "Preserve"; + case FPDenormMode::FlushToZero: + return "FlushToZero"; + } + return "unknown"; +} +auto to_string(QuantizationModes e) -> char const * { + switch (e) { + case QuantizationModes::TRN: + return "TRN"; + case QuantizationModes::TRN_ZERO: + return "TRN_ZERO"; + case QuantizationModes::RND: + return "RND"; + case QuantizationModes::RND_ZERO: + return "RND_ZERO"; + case QuantizationModes::RND_INF: + return "RND_INF"; + case QuantizationModes::RND_MIN_INF: + return "RND_MIN_INF"; + case QuantizationModes::RND_CONV: + return "RND_CONV"; + case QuantizationModes::RND_CONV_ODD: + return "RND_CONV_ODD"; + } + return "unknown"; +} +auto to_string(FPOperationMode e) -> char const * { + switch (e) { + case FPOperationMode::IEEE: + return "IEEE"; + case FPOperationMode::ALT: + return "ALT"; + } + return "unknown"; +} +auto to_string(OverflowModes e) -> char const * { + switch (e) { + case OverflowModes::WRAP: + return "WRAP"; + case OverflowModes::SAT: + return "SAT"; + case OverflowModes::SAT_ZERO: + return "SAT_ZERO"; + case OverflowModes::SAT_SYM: + return "SAT_SYM"; + } + return "unknown"; +} +auto to_string(LinkageType e) -> char const * { + switch (e) { + case LinkageType::Export: + return "Export"; + case LinkageType::Import: + return "Import"; + case LinkageType::LinkOnceODR: + return "LinkOnceODR"; + } + return "unknown"; +} +auto to_string(AccessQualifier e) -> char const * { + switch (e) { + case AccessQualifier::ReadOnly: + return "ReadOnly"; + case AccessQualifier::WriteOnly: + return "WriteOnly"; + case AccessQualifier::ReadWrite: + return "ReadWrite"; + } + return "unknown"; +} +auto to_string(HostAccessQualifier e) -> char const * { + switch (e) { + case HostAccessQualifier::NoneINTEL: + return "NoneINTEL"; + case HostAccessQualifier::ReadINTEL: + return "ReadINTEL"; + case HostAccessQualifier::WriteINTEL: + return "WriteINTEL"; + case HostAccessQualifier::ReadWriteINTEL: + return "ReadWriteINTEL"; + } + return "unknown"; +} +auto to_string(FunctionParameterAttribute e) -> char const * { + switch (e) { + case FunctionParameterAttribute::Zext: + return "Zext"; + case FunctionParameterAttribute::Sext: + return "Sext"; + case FunctionParameterAttribute::ByVal: + return "ByVal"; + case FunctionParameterAttribute::Sret: + return "Sret"; + case FunctionParameterAttribute::NoAlias: + return "NoAlias"; + case FunctionParameterAttribute::NoCapture: + return "NoCapture"; + case FunctionParameterAttribute::NoWrite: + return "NoWrite"; + case FunctionParameterAttribute::NoReadWrite: + return "NoReadWrite"; + case FunctionParameterAttribute::RuntimeAlignedINTEL: + return "RuntimeAlignedINTEL"; + } + return "unknown"; +} +auto to_string(Decoration e) -> char const * { + switch (e) { + case Decoration::RelaxedPrecision: + return "RelaxedPrecision"; + case Decoration::SpecId: + return "SpecId"; + case Decoration::Block: + return "Block"; + case Decoration::BufferBlock: + return "BufferBlock"; + case Decoration::RowMajor: + return "RowMajor"; + case Decoration::ColMajor: + return "ColMajor"; + case Decoration::ArrayStride: + return "ArrayStride"; + case Decoration::MatrixStride: + return "MatrixStride"; + case Decoration::GLSLShared: + return "GLSLShared"; + case Decoration::GLSLPacked: + return "GLSLPacked"; + case Decoration::CPacked: + return "CPacked"; + case Decoration::BuiltIn: + return "BuiltIn"; + case Decoration::NoPerspective: + return "NoPerspective"; + case Decoration::Flat: + return "Flat"; + case Decoration::Patch: + return "Patch"; + case Decoration::Centroid: + return "Centroid"; + case Decoration::Sample: + return "Sample"; + case Decoration::Invariant: + return "Invariant"; + case Decoration::Restrict: + return "Restrict"; + case Decoration::Aliased: + return "Aliased"; + case Decoration::Volatile: + return "Volatile"; + case Decoration::Constant: + return "Constant"; + case Decoration::Coherent: + return "Coherent"; + case Decoration::NonWritable: + return "NonWritable"; + case Decoration::NonReadable: + return "NonReadable"; + case Decoration::Uniform: + return "Uniform"; + case Decoration::UniformId: + return "UniformId"; + case Decoration::SaturatedConversion: + return "SaturatedConversion"; + case Decoration::Stream: + return "Stream"; + case Decoration::Location: + return "Location"; + case Decoration::Component: + return "Component"; + case Decoration::Index: + return "Index"; + case Decoration::Binding: + return "Binding"; + case Decoration::DescriptorSet: + return "DescriptorSet"; + case Decoration::Offset: + return "Offset"; + case Decoration::XfbBuffer: + return "XfbBuffer"; + case Decoration::XfbStride: + return "XfbStride"; + case Decoration::FuncParamAttr: + return "FuncParamAttr"; + case Decoration::FPRoundingMode: + return "FPRoundingMode"; + case Decoration::FPFastMathMode: + return "FPFastMathMode"; + case Decoration::LinkageAttributes: + return "LinkageAttributes"; + case Decoration::NoContraction: + return "NoContraction"; + case Decoration::InputAttachmentIndex: + return "InputAttachmentIndex"; + case Decoration::Alignment: + return "Alignment"; + case Decoration::MaxByteOffset: + return "MaxByteOffset"; + case Decoration::AlignmentId: + return "AlignmentId"; + case Decoration::MaxByteOffsetId: + return "MaxByteOffsetId"; + case Decoration::SaturatedToLargestFloat8NormalConversionEXT: + return "SaturatedToLargestFloat8NormalConversionEXT"; + case Decoration::NoSignedWrap: + return "NoSignedWrap"; + case Decoration::NoUnsignedWrap: + return "NoUnsignedWrap"; + case Decoration::WeightTextureQCOM: + return "WeightTextureQCOM"; + case Decoration::BlockMatchTextureQCOM: + return "BlockMatchTextureQCOM"; + case Decoration::BlockMatchSamplerQCOM: + return "BlockMatchSamplerQCOM"; + case Decoration::ExplicitInterpAMD: + return "ExplicitInterpAMD"; + case Decoration::NodeSharesPayloadLimitsWithAMDX: + return "NodeSharesPayloadLimitsWithAMDX"; + case Decoration::NodeMaxPayloadsAMDX: + return "NodeMaxPayloadsAMDX"; + case Decoration::TrackFinishWritingAMDX: + return "TrackFinishWritingAMDX"; + case Decoration::PayloadNodeNameAMDX: + return "PayloadNodeNameAMDX"; + case Decoration::PayloadNodeBaseIndexAMDX: + return "PayloadNodeBaseIndexAMDX"; + case Decoration::PayloadNodeSparseArrayAMDX: + return "PayloadNodeSparseArrayAMDX"; + case Decoration::PayloadNodeArraySizeAMDX: + return "PayloadNodeArraySizeAMDX"; + case Decoration::PayloadDispatchIndirectAMDX: + return "PayloadDispatchIndirectAMDX"; + case Decoration::OverrideCoverageNV: + return "OverrideCoverageNV"; + case Decoration::PassthroughNV: + return "PassthroughNV"; + case Decoration::ViewportRelativeNV: + return "ViewportRelativeNV"; + case Decoration::SecondaryViewportRelativeNV: + return "SecondaryViewportRelativeNV"; + case Decoration::PerPrimitiveEXT: + return "PerPrimitiveEXT"; + case Decoration::PerViewNV: + return "PerViewNV"; + case Decoration::PerTaskNV: + return "PerTaskNV"; + case Decoration::PerVertexKHR: + return "PerVertexKHR"; + case Decoration::NonUniform: + return "NonUniform"; + case Decoration::RestrictPointer: + return "RestrictPointer"; + case Decoration::AliasedPointer: + return "AliasedPointer"; + case Decoration::HitObjectShaderRecordBufferNV: + return "HitObjectShaderRecordBufferNV"; + case Decoration::BindlessSamplerNV: + return "BindlessSamplerNV"; + case Decoration::BindlessImageNV: + return "BindlessImageNV"; + case Decoration::BoundSamplerNV: + return "BoundSamplerNV"; + case Decoration::BoundImageNV: + return "BoundImageNV"; + case Decoration::SIMTCallINTEL: + return "SIMTCallINTEL"; + case Decoration::ReferencedIndirectlyINTEL: + return "ReferencedIndirectlyINTEL"; + case Decoration::ClobberINTEL: + return "ClobberINTEL"; + case Decoration::SideEffectsINTEL: + return "SideEffectsINTEL"; + case Decoration::VectorComputeVariableINTEL: + return "VectorComputeVariableINTEL"; + case Decoration::FuncParamIOKindINTEL: + return "FuncParamIOKindINTEL"; + case Decoration::VectorComputeFunctionINTEL: + return "VectorComputeFunctionINTEL"; + case Decoration::StackCallINTEL: + return "StackCallINTEL"; + case Decoration::GlobalVariableOffsetINTEL: + return "GlobalVariableOffsetINTEL"; + case Decoration::CounterBuffer: + return "CounterBuffer"; + case Decoration::UserSemantic: + return "UserSemantic"; + case Decoration::UserTypeGOOGLE: + return "UserTypeGOOGLE"; + case Decoration::FunctionRoundingModeINTEL: + return "FunctionRoundingModeINTEL"; + case Decoration::FunctionDenormModeINTEL: + return "FunctionDenormModeINTEL"; + case Decoration::RegisterINTEL: + return "RegisterINTEL"; + case Decoration::MemoryINTEL: + return "MemoryINTEL"; + case Decoration::NumbanksINTEL: + return "NumbanksINTEL"; + case Decoration::BankwidthINTEL: + return "BankwidthINTEL"; + case Decoration::MaxPrivateCopiesINTEL: + return "MaxPrivateCopiesINTEL"; + case Decoration::SinglepumpINTEL: + return "SinglepumpINTEL"; + case Decoration::DoublepumpINTEL: + return "DoublepumpINTEL"; + case Decoration::MaxReplicatesINTEL: + return "MaxReplicatesINTEL"; + case Decoration::SimpleDualPortINTEL: + return "SimpleDualPortINTEL"; + case Decoration::MergeINTEL: + return "MergeINTEL"; + case Decoration::BankBitsINTEL: + return "BankBitsINTEL"; + case Decoration::ForcePow2DepthINTEL: + return "ForcePow2DepthINTEL"; + case Decoration::StridesizeINTEL: + return "StridesizeINTEL"; + case Decoration::WordsizeINTEL: + return "WordsizeINTEL"; + case Decoration::TrueDualPortINTEL: + return "TrueDualPortINTEL"; + case Decoration::BurstCoalesceINTEL: + return "BurstCoalesceINTEL"; + case Decoration::CacheSizeINTEL: + return "CacheSizeINTEL"; + case Decoration::DontStaticallyCoalesceINTEL: + return "DontStaticallyCoalesceINTEL"; + case Decoration::PrefetchINTEL: + return "PrefetchINTEL"; + case Decoration::StallEnableINTEL: + return "StallEnableINTEL"; + case Decoration::FuseLoopsInFunctionINTEL: + return "FuseLoopsInFunctionINTEL"; + case Decoration::MathOpDSPModeINTEL: + return "MathOpDSPModeINTEL"; + case Decoration::AliasScopeINTEL: + return "AliasScopeINTEL"; + case Decoration::NoAliasINTEL: + return "NoAliasINTEL"; + case Decoration::InitiationIntervalINTEL: + return "InitiationIntervalINTEL"; + case Decoration::MaxConcurrencyINTEL: + return "MaxConcurrencyINTEL"; + case Decoration::PipelineEnableINTEL: + return "PipelineEnableINTEL"; + case Decoration::BufferLocationINTEL: + return "BufferLocationINTEL"; + case Decoration::IOPipeStorageINTEL: + return "IOPipeStorageINTEL"; + case Decoration::FunctionFloatingPointModeINTEL: + return "FunctionFloatingPointModeINTEL"; + case Decoration::SingleElementVectorINTEL: + return "SingleElementVectorINTEL"; + case Decoration::VectorComputeCallableFunctionINTEL: + return "VectorComputeCallableFunctionINTEL"; + case Decoration::MediaBlockIOINTEL: + return "MediaBlockIOINTEL"; + case Decoration::StallFreeINTEL: + return "StallFreeINTEL"; + case Decoration::FPMaxErrorDecorationINTEL: + return "FPMaxErrorDecorationINTEL"; + case Decoration::LatencyControlLabelINTEL: + return "LatencyControlLabelINTEL"; + case Decoration::LatencyControlConstraintINTEL: + return "LatencyControlConstraintINTEL"; + case Decoration::ConduitKernelArgumentINTEL: + return "ConduitKernelArgumentINTEL"; + case Decoration::RegisterMapKernelArgumentINTEL: + return "RegisterMapKernelArgumentINTEL"; + case Decoration::MMHostInterfaceAddressWidthINTEL: + return "MMHostInterfaceAddressWidthINTEL"; + case Decoration::MMHostInterfaceDataWidthINTEL: + return "MMHostInterfaceDataWidthINTEL"; + case Decoration::MMHostInterfaceLatencyINTEL: + return "MMHostInterfaceLatencyINTEL"; + case Decoration::MMHostInterfaceReadWriteModeINTEL: + return "MMHostInterfaceReadWriteModeINTEL"; + case Decoration::MMHostInterfaceMaxBurstINTEL: + return "MMHostInterfaceMaxBurstINTEL"; + case Decoration::MMHostInterfaceWaitRequestINTEL: + return "MMHostInterfaceWaitRequestINTEL"; + case Decoration::StableKernelArgumentINTEL: + return "StableKernelArgumentINTEL"; + case Decoration::HostAccessINTEL: + return "HostAccessINTEL"; + case Decoration::InitModeINTEL: + return "InitModeINTEL"; + case Decoration::ImplementInRegisterMapINTEL: + return "ImplementInRegisterMapINTEL"; + case Decoration::CacheControlLoadINTEL: + return "CacheControlLoadINTEL"; + case Decoration::CacheControlStoreINTEL: + return "CacheControlStoreINTEL"; + } + return "unknown"; +} +auto to_string(BuiltIn e) -> char const * { + switch (e) { + case BuiltIn::Position: + return "Position"; + case BuiltIn::PointSize: + return "PointSize"; + case BuiltIn::ClipDistance: + return "ClipDistance"; + case BuiltIn::CullDistance: + return "CullDistance"; + case BuiltIn::VertexId: + return "VertexId"; + case BuiltIn::InstanceId: + return "InstanceId"; + case BuiltIn::PrimitiveId: + return "PrimitiveId"; + case BuiltIn::InvocationId: + return "InvocationId"; + case BuiltIn::Layer: + return "Layer"; + case BuiltIn::ViewportIndex: + return "ViewportIndex"; + case BuiltIn::TessLevelOuter: + return "TessLevelOuter"; + case BuiltIn::TessLevelInner: + return "TessLevelInner"; + case BuiltIn::TessCoord: + return "TessCoord"; + case BuiltIn::PatchVertices: + return "PatchVertices"; + case BuiltIn::FragCoord: + return "FragCoord"; + case BuiltIn::PointCoord: + return "PointCoord"; + case BuiltIn::FrontFacing: + return "FrontFacing"; + case BuiltIn::SampleId: + return "SampleId"; + case BuiltIn::SamplePosition: + return "SamplePosition"; + case BuiltIn::SampleMask: + return "SampleMask"; + case BuiltIn::FragDepth: + return "FragDepth"; + case BuiltIn::HelperInvocation: + return "HelperInvocation"; + case BuiltIn::NumWorkgroups: + return "NumWorkgroups"; + case BuiltIn::WorkgroupSize: + return "WorkgroupSize"; + case BuiltIn::WorkgroupId: + return "WorkgroupId"; + case BuiltIn::LocalInvocationId: + return "LocalInvocationId"; + case BuiltIn::GlobalInvocationId: + return "GlobalInvocationId"; + case BuiltIn::LocalInvocationIndex: + return "LocalInvocationIndex"; + case BuiltIn::WorkDim: + return "WorkDim"; + case BuiltIn::GlobalSize: + return "GlobalSize"; + case BuiltIn::EnqueuedWorkgroupSize: + return "EnqueuedWorkgroupSize"; + case BuiltIn::GlobalOffset: + return "GlobalOffset"; + case BuiltIn::GlobalLinearId: + return "GlobalLinearId"; + case BuiltIn::SubgroupSize: + return "SubgroupSize"; + case BuiltIn::SubgroupMaxSize: + return "SubgroupMaxSize"; + case BuiltIn::NumSubgroups: + return "NumSubgroups"; + case BuiltIn::NumEnqueuedSubgroups: + return "NumEnqueuedSubgroups"; + case BuiltIn::SubgroupId: + return "SubgroupId"; + case BuiltIn::SubgroupLocalInvocationId: + return "SubgroupLocalInvocationId"; + case BuiltIn::VertexIndex: + return "VertexIndex"; + case BuiltIn::InstanceIndex: + return "InstanceIndex"; + case BuiltIn::CoreIDARM: + return "CoreIDARM"; + case BuiltIn::CoreCountARM: + return "CoreCountARM"; + case BuiltIn::CoreMaxIDARM: + return "CoreMaxIDARM"; + case BuiltIn::WarpIDARM: + return "WarpIDARM"; + case BuiltIn::WarpMaxIDARM: + return "WarpMaxIDARM"; + case BuiltIn::SubgroupEqMask: + return "SubgroupEqMask"; + case BuiltIn::SubgroupGeMask: + return "SubgroupGeMask"; + case BuiltIn::SubgroupGtMask: + return "SubgroupGtMask"; + case BuiltIn::SubgroupLeMask: + return "SubgroupLeMask"; + case BuiltIn::SubgroupLtMask: + return "SubgroupLtMask"; + case BuiltIn::BaseVertex: + return "BaseVertex"; + case BuiltIn::BaseInstance: + return "BaseInstance"; + case BuiltIn::DrawIndex: + return "DrawIndex"; + case BuiltIn::PrimitiveShadingRateKHR: + return "PrimitiveShadingRateKHR"; + case BuiltIn::DeviceIndex: + return "DeviceIndex"; + case BuiltIn::ViewIndex: + return "ViewIndex"; + case BuiltIn::ShadingRateKHR: + return "ShadingRateKHR"; + case BuiltIn::TileOffsetQCOM: + return "TileOffsetQCOM"; + case BuiltIn::TileDimensionQCOM: + return "TileDimensionQCOM"; + case BuiltIn::TileApronSizeQCOM: + return "TileApronSizeQCOM"; + case BuiltIn::BaryCoordNoPerspAMD: + return "BaryCoordNoPerspAMD"; + case BuiltIn::BaryCoordNoPerspCentroidAMD: + return "BaryCoordNoPerspCentroidAMD"; + case BuiltIn::BaryCoordNoPerspSampleAMD: + return "BaryCoordNoPerspSampleAMD"; + case BuiltIn::BaryCoordSmoothAMD: + return "BaryCoordSmoothAMD"; + case BuiltIn::BaryCoordSmoothCentroidAMD: + return "BaryCoordSmoothCentroidAMD"; + case BuiltIn::BaryCoordSmoothSampleAMD: + return "BaryCoordSmoothSampleAMD"; + case BuiltIn::BaryCoordPullModelAMD: + return "BaryCoordPullModelAMD"; + case BuiltIn::FragStencilRefEXT: + return "FragStencilRefEXT"; + case BuiltIn::RemainingRecursionLevelsAMDX: + return "RemainingRecursionLevelsAMDX"; + case BuiltIn::ShaderIndexAMDX: + return "ShaderIndexAMDX"; + case BuiltIn::ViewportMaskNV: + return "ViewportMaskNV"; + case BuiltIn::SecondaryPositionNV: + return "SecondaryPositionNV"; + case BuiltIn::SecondaryViewportMaskNV: + return "SecondaryViewportMaskNV"; + case BuiltIn::PositionPerViewNV: + return "PositionPerViewNV"; + case BuiltIn::ViewportMaskPerViewNV: + return "ViewportMaskPerViewNV"; + case BuiltIn::FullyCoveredEXT: + return "FullyCoveredEXT"; + case BuiltIn::TaskCountNV: + return "TaskCountNV"; + case BuiltIn::PrimitiveCountNV: + return "PrimitiveCountNV"; + case BuiltIn::PrimitiveIndicesNV: + return "PrimitiveIndicesNV"; + case BuiltIn::ClipDistancePerViewNV: + return "ClipDistancePerViewNV"; + case BuiltIn::CullDistancePerViewNV: + return "CullDistancePerViewNV"; + case BuiltIn::LayerPerViewNV: + return "LayerPerViewNV"; + case BuiltIn::MeshViewCountNV: + return "MeshViewCountNV"; + case BuiltIn::MeshViewIndicesNV: + return "MeshViewIndicesNV"; + case BuiltIn::BaryCoordKHR: + return "BaryCoordKHR"; + case BuiltIn::BaryCoordNoPerspKHR: + return "BaryCoordNoPerspKHR"; + case BuiltIn::FragSizeEXT: + return "FragSizeEXT"; + case BuiltIn::FragInvocationCountEXT: + return "FragInvocationCountEXT"; + case BuiltIn::PrimitivePointIndicesEXT: + return "PrimitivePointIndicesEXT"; + case BuiltIn::PrimitiveLineIndicesEXT: + return "PrimitiveLineIndicesEXT"; + case BuiltIn::PrimitiveTriangleIndicesEXT: + return "PrimitiveTriangleIndicesEXT"; + case BuiltIn::CullPrimitiveEXT: + return "CullPrimitiveEXT"; + case BuiltIn::LaunchIdKHR: + return "LaunchIdKHR"; + case BuiltIn::LaunchSizeKHR: + return "LaunchSizeKHR"; + case BuiltIn::WorldRayOriginKHR: + return "WorldRayOriginKHR"; + case BuiltIn::WorldRayDirectionKHR: + return "WorldRayDirectionKHR"; + case BuiltIn::ObjectRayOriginKHR: + return "ObjectRayOriginKHR"; + case BuiltIn::ObjectRayDirectionKHR: + return "ObjectRayDirectionKHR"; + case BuiltIn::RayTminKHR: + return "RayTminKHR"; + case BuiltIn::RayTmaxKHR: + return "RayTmaxKHR"; + case BuiltIn::InstanceCustomIndexKHR: + return "InstanceCustomIndexKHR"; + case BuiltIn::ObjectToWorldKHR: + return "ObjectToWorldKHR"; + case BuiltIn::WorldToObjectKHR: + return "WorldToObjectKHR"; + case BuiltIn::HitTNV: + return "HitTNV"; + case BuiltIn::HitKindKHR: + return "HitKindKHR"; + case BuiltIn::CurrentRayTimeNV: + return "CurrentRayTimeNV"; + case BuiltIn::HitTriangleVertexPositionsKHR: + return "HitTriangleVertexPositionsKHR"; + case BuiltIn::HitMicroTriangleVertexPositionsNV: + return "HitMicroTriangleVertexPositionsNV"; + case BuiltIn::HitMicroTriangleVertexBarycentricsNV: + return "HitMicroTriangleVertexBarycentricsNV"; + case BuiltIn::IncomingRayFlagsKHR: + return "IncomingRayFlagsKHR"; + case BuiltIn::RayGeometryIndexKHR: + return "RayGeometryIndexKHR"; + case BuiltIn::HitIsSphereNV: + return "HitIsSphereNV"; + case BuiltIn::HitIsLSSNV: + return "HitIsLSSNV"; + case BuiltIn::HitSpherePositionNV: + return "HitSpherePositionNV"; + case BuiltIn::WarpsPerSMNV: + return "WarpsPerSMNV"; + case BuiltIn::SMCountNV: + return "SMCountNV"; + case BuiltIn::WarpIDNV: + return "WarpIDNV"; + case BuiltIn::SMIDNV: + return "SMIDNV"; + case BuiltIn::HitLSSPositionsNV: + return "HitLSSPositionsNV"; + case BuiltIn::HitKindFrontFacingMicroTriangleNV: + return "HitKindFrontFacingMicroTriangleNV"; + case BuiltIn::HitKindBackFacingMicroTriangleNV: + return "HitKindBackFacingMicroTriangleNV"; + case BuiltIn::HitSphereRadiusNV: + return "HitSphereRadiusNV"; + case BuiltIn::HitLSSRadiiNV: + return "HitLSSRadiiNV"; + case BuiltIn::ClusterIDNV: + return "ClusterIDNV"; + case BuiltIn::CullMaskKHR: + return "CullMaskKHR"; + } + return "unknown"; +} +auto to_string(Scope e) -> char const * { + switch (e) { + case Scope::CrossDevice: + return "CrossDevice"; + case Scope::Device: + return "Device"; + case Scope::Workgroup: + return "Workgroup"; + case Scope::Subgroup: + return "Subgroup"; + case Scope::Invocation: + return "Invocation"; + case Scope::QueueFamily: + return "QueueFamily"; + case Scope::ShaderCallKHR: + return "ShaderCallKHR"; + } + return "unknown"; +} +auto to_string(GroupOperation e) -> char const * { + switch (e) { + case GroupOperation::Reduce: + return "Reduce"; + case GroupOperation::InclusiveScan: + return "InclusiveScan"; + case GroupOperation::ExclusiveScan: + return "ExclusiveScan"; + case GroupOperation::ClusteredReduce: + return "ClusteredReduce"; + case GroupOperation::PartitionedReduceNV: + return "PartitionedReduceNV"; + case GroupOperation::PartitionedInclusiveScanNV: + return "PartitionedInclusiveScanNV"; + case GroupOperation::PartitionedExclusiveScanNV: + return "PartitionedExclusiveScanNV"; + } + return "unknown"; +} +auto to_string(KernelEnqueueFlags e) -> char const * { + switch (e) { + case KernelEnqueueFlags::NoWait: + return "NoWait"; + case KernelEnqueueFlags::WaitKernel: + return "WaitKernel"; + case KernelEnqueueFlags::WaitWorkGroup: + return "WaitWorkGroup"; + } + return "unknown"; +} +auto to_string(Capability e) -> char const * { + switch (e) { + case Capability::Matrix: + return "Matrix"; + case Capability::Shader: + return "Shader"; + case Capability::Geometry: + return "Geometry"; + case Capability::Tessellation: + return "Tessellation"; + case Capability::Addresses: + return "Addresses"; + case Capability::Linkage: + return "Linkage"; + case Capability::Kernel: + return "Kernel"; + case Capability::Vector16: + return "Vector16"; + case Capability::Float16Buffer: + return "Float16Buffer"; + case Capability::Float16: + return "Float16"; + case Capability::Float64: + return "Float64"; + case Capability::Int64: + return "Int64"; + case Capability::Int64Atomics: + return "Int64Atomics"; + case Capability::ImageBasic: + return "ImageBasic"; + case Capability::ImageReadWrite: + return "ImageReadWrite"; + case Capability::ImageMipmap: + return "ImageMipmap"; + case Capability::Pipes: + return "Pipes"; + case Capability::Groups: + return "Groups"; + case Capability::DeviceEnqueue: + return "DeviceEnqueue"; + case Capability::LiteralSampler: + return "LiteralSampler"; + case Capability::AtomicStorage: + return "AtomicStorage"; + case Capability::Int16: + return "Int16"; + case Capability::TessellationPointSize: + return "TessellationPointSize"; + case Capability::GeometryPointSize: + return "GeometryPointSize"; + case Capability::ImageGatherExtended: + return "ImageGatherExtended"; + case Capability::StorageImageMultisample: + return "StorageImageMultisample"; + case Capability::UniformBufferArrayDynamicIndexing: + return "UniformBufferArrayDynamicIndexing"; + case Capability::SampledImageArrayDynamicIndexing: + return "SampledImageArrayDynamicIndexing"; + case Capability::StorageBufferArrayDynamicIndexing: + return "StorageBufferArrayDynamicIndexing"; + case Capability::StorageImageArrayDynamicIndexing: + return "StorageImageArrayDynamicIndexing"; + case Capability::ClipDistance: + return "ClipDistance"; + case Capability::CullDistance: + return "CullDistance"; + case Capability::ImageCubeArray: + return "ImageCubeArray"; + case Capability::SampleRateShading: + return "SampleRateShading"; + case Capability::ImageRect: + return "ImageRect"; + case Capability::SampledRect: + return "SampledRect"; + case Capability::GenericPointer: + return "GenericPointer"; + case Capability::Int8: + return "Int8"; + case Capability::InputAttachment: + return "InputAttachment"; + case Capability::SparseResidency: + return "SparseResidency"; + case Capability::MinLod: + return "MinLod"; + case Capability::Sampled1D: + return "Sampled1D"; + case Capability::Image1D: + return "Image1D"; + case Capability::SampledCubeArray: + return "SampledCubeArray"; + case Capability::SampledBuffer: + return "SampledBuffer"; + case Capability::ImageBuffer: + return "ImageBuffer"; + case Capability::ImageMSArray: + return "ImageMSArray"; + case Capability::StorageImageExtendedFormats: + return "StorageImageExtendedFormats"; + case Capability::ImageQuery: + return "ImageQuery"; + case Capability::DerivativeControl: + return "DerivativeControl"; + case Capability::InterpolationFunction: + return "InterpolationFunction"; + case Capability::TransformFeedback: + return "TransformFeedback"; + case Capability::GeometryStreams: + return "GeometryStreams"; + case Capability::StorageImageReadWithoutFormat: + return "StorageImageReadWithoutFormat"; + case Capability::StorageImageWriteWithoutFormat: + return "StorageImageWriteWithoutFormat"; + case Capability::MultiViewport: + return "MultiViewport"; + case Capability::SubgroupDispatch: + return "SubgroupDispatch"; + case Capability::NamedBarrier: + return "NamedBarrier"; + case Capability::PipeStorage: + return "PipeStorage"; + case Capability::GroupNonUniform: + return "GroupNonUniform"; + case Capability::GroupNonUniformVote: + return "GroupNonUniformVote"; + case Capability::GroupNonUniformArithmetic: + return "GroupNonUniformArithmetic"; + case Capability::GroupNonUniformBallot: + return "GroupNonUniformBallot"; + case Capability::GroupNonUniformShuffle: + return "GroupNonUniformShuffle"; + case Capability::GroupNonUniformShuffleRelative: + return "GroupNonUniformShuffleRelative"; + case Capability::GroupNonUniformClustered: + return "GroupNonUniformClustered"; + case Capability::GroupNonUniformQuad: + return "GroupNonUniformQuad"; + case Capability::ShaderLayer: + return "ShaderLayer"; + case Capability::ShaderViewportIndex: + return "ShaderViewportIndex"; + case Capability::UniformDecoration: + return "UniformDecoration"; + case Capability::CoreBuiltinsARM: + return "CoreBuiltinsARM"; + case Capability::TileImageColorReadAccessEXT: + return "TileImageColorReadAccessEXT"; + case Capability::TileImageDepthReadAccessEXT: + return "TileImageDepthReadAccessEXT"; + case Capability::TileImageStencilReadAccessEXT: + return "TileImageStencilReadAccessEXT"; + case Capability::TensorsARM: + return "TensorsARM"; + case Capability::StorageTensorArrayDynamicIndexingARM: + return "StorageTensorArrayDynamicIndexingARM"; + case Capability::StorageTensorArrayNonUniformIndexingARM: + return "StorageTensorArrayNonUniformIndexingARM"; + case Capability::CooperativeMatrixLayoutsARM: + return "CooperativeMatrixLayoutsARM"; + case Capability::Float8EXT: + return "Float8EXT"; + case Capability::Float8CooperativeMatrixEXT: + return "Float8CooperativeMatrixEXT"; + case Capability::FragmentShadingRateKHR: + return "FragmentShadingRateKHR"; + case Capability::SubgroupBallotKHR: + return "SubgroupBallotKHR"; + case Capability::DrawParameters: + return "DrawParameters"; + case Capability::WorkgroupMemoryExplicitLayoutKHR: + return "WorkgroupMemoryExplicitLayoutKHR"; + case Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR: + return "WorkgroupMemoryExplicitLayout8BitAccessKHR"; + case Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR: + return "WorkgroupMemoryExplicitLayout16BitAccessKHR"; + case Capability::SubgroupVoteKHR: + return "SubgroupVoteKHR"; + case Capability::StorageBuffer16BitAccess: + return "StorageBuffer16BitAccess"; + case Capability::UniformAndStorageBuffer16BitAccess: + return "UniformAndStorageBuffer16BitAccess"; + case Capability::StoragePushConstant16: + return "StoragePushConstant16"; + case Capability::StorageInputOutput16: + return "StorageInputOutput16"; + case Capability::DeviceGroup: + return "DeviceGroup"; + case Capability::MultiView: + return "MultiView"; + case Capability::VariablePointersStorageBuffer: + return "VariablePointersStorageBuffer"; + case Capability::VariablePointers: + return "VariablePointers"; + case Capability::AtomicStorageOps: + return "AtomicStorageOps"; + case Capability::SampleMaskPostDepthCoverage: + return "SampleMaskPostDepthCoverage"; + case Capability::StorageBuffer8BitAccess: + return "StorageBuffer8BitAccess"; + case Capability::UniformAndStorageBuffer8BitAccess: + return "UniformAndStorageBuffer8BitAccess"; + case Capability::StoragePushConstant8: + return "StoragePushConstant8"; + case Capability::DenormPreserve: + return "DenormPreserve"; + case Capability::DenormFlushToZero: + return "DenormFlushToZero"; + case Capability::SignedZeroInfNanPreserve: + return "SignedZeroInfNanPreserve"; + case Capability::RoundingModeRTE: + return "RoundingModeRTE"; + case Capability::RoundingModeRTZ: + return "RoundingModeRTZ"; + case Capability::RayQueryProvisionalKHR: + return "RayQueryProvisionalKHR"; + case Capability::RayQueryKHR: + return "RayQueryKHR"; + case Capability::UntypedPointersKHR: + return "UntypedPointersKHR"; + case Capability::RayTraversalPrimitiveCullingKHR: + return "RayTraversalPrimitiveCullingKHR"; + case Capability::RayTracingKHR: + return "RayTracingKHR"; + case Capability::TextureSampleWeightedQCOM: + return "TextureSampleWeightedQCOM"; + case Capability::TextureBoxFilterQCOM: + return "TextureBoxFilterQCOM"; + case Capability::TextureBlockMatchQCOM: + return "TextureBlockMatchQCOM"; + case Capability::TileShadingQCOM: + return "TileShadingQCOM"; + case Capability::TextureBlockMatch2QCOM: + return "TextureBlockMatch2QCOM"; + case Capability::Float16ImageAMD: + return "Float16ImageAMD"; + case Capability::ImageGatherBiasLodAMD: + return "ImageGatherBiasLodAMD"; + case Capability::FragmentMaskAMD: + return "FragmentMaskAMD"; + case Capability::StencilExportEXT: + return "StencilExportEXT"; + case Capability::ImageReadWriteLodAMD: + return "ImageReadWriteLodAMD"; + case Capability::Int64ImageEXT: + return "Int64ImageEXT"; + case Capability::ShaderClockKHR: + return "ShaderClockKHR"; + case Capability::ShaderEnqueueAMDX: + return "ShaderEnqueueAMDX"; + case Capability::QuadControlKHR: + return "QuadControlKHR"; + case Capability::Int4TypeINTEL: + return "Int4TypeINTEL"; + case Capability::Int4CooperativeMatrixINTEL: + return "Int4CooperativeMatrixINTEL"; + case Capability::BFloat16TypeKHR: + return "BFloat16TypeKHR"; + case Capability::BFloat16DotProductKHR: + return "BFloat16DotProductKHR"; + case Capability::BFloat16CooperativeMatrixKHR: + return "BFloat16CooperativeMatrixKHR"; + case Capability::SampleMaskOverrideCoverageNV: + return "SampleMaskOverrideCoverageNV"; + case Capability::GeometryShaderPassthroughNV: + return "GeometryShaderPassthroughNV"; + case Capability::ShaderViewportIndexLayerEXT: + return "ShaderViewportIndexLayerEXT"; + case Capability::ShaderViewportMaskNV: + return "ShaderViewportMaskNV"; + case Capability::ShaderStereoViewNV: + return "ShaderStereoViewNV"; + case Capability::PerViewAttributesNV: + return "PerViewAttributesNV"; + case Capability::FragmentFullyCoveredEXT: + return "FragmentFullyCoveredEXT"; + case Capability::MeshShadingNV: + return "MeshShadingNV"; + case Capability::ImageFootprintNV: + return "ImageFootprintNV"; + case Capability::MeshShadingEXT: + return "MeshShadingEXT"; + case Capability::FragmentBarycentricKHR: + return "FragmentBarycentricKHR"; + case Capability::ComputeDerivativeGroupQuadsKHR: + return "ComputeDerivativeGroupQuadsKHR"; + case Capability::FragmentDensityEXT: + return "FragmentDensityEXT"; + case Capability::GroupNonUniformPartitionedNV: + return "GroupNonUniformPartitionedNV"; + case Capability::ShaderNonUniform: + return "ShaderNonUniform"; + case Capability::RuntimeDescriptorArray: + return "RuntimeDescriptorArray"; + case Capability::InputAttachmentArrayDynamicIndexing: + return "InputAttachmentArrayDynamicIndexing"; + case Capability::UniformTexelBufferArrayDynamicIndexing: + return "UniformTexelBufferArrayDynamicIndexing"; + case Capability::StorageTexelBufferArrayDynamicIndexing: + return "StorageTexelBufferArrayDynamicIndexing"; + case Capability::UniformBufferArrayNonUniformIndexing: + return "UniformBufferArrayNonUniformIndexing"; + case Capability::SampledImageArrayNonUniformIndexing: + return "SampledImageArrayNonUniformIndexing"; + case Capability::StorageBufferArrayNonUniformIndexing: + return "StorageBufferArrayNonUniformIndexing"; + case Capability::StorageImageArrayNonUniformIndexing: + return "StorageImageArrayNonUniformIndexing"; + case Capability::InputAttachmentArrayNonUniformIndexing: + return "InputAttachmentArrayNonUniformIndexing"; + case Capability::UniformTexelBufferArrayNonUniformIndexing: + return "UniformTexelBufferArrayNonUniformIndexing"; + case Capability::StorageTexelBufferArrayNonUniformIndexing: + return "StorageTexelBufferArrayNonUniformIndexing"; + case Capability::RayTracingPositionFetchKHR: + return "RayTracingPositionFetchKHR"; + case Capability::RayTracingNV: + return "RayTracingNV"; + case Capability::RayTracingMotionBlurNV: + return "RayTracingMotionBlurNV"; + case Capability::VulkanMemoryModel: + return "VulkanMemoryModel"; + case Capability::VulkanMemoryModelDeviceScope: + return "VulkanMemoryModelDeviceScope"; + case Capability::PhysicalStorageBufferAddresses: + return "PhysicalStorageBufferAddresses"; + case Capability::ComputeDerivativeGroupLinearKHR: + return "ComputeDerivativeGroupLinearKHR"; + case Capability::RayTracingProvisionalKHR: + return "RayTracingProvisionalKHR"; + case Capability::CooperativeMatrixNV: + return "CooperativeMatrixNV"; + case Capability::FragmentShaderSampleInterlockEXT: + return "FragmentShaderSampleInterlockEXT"; + case Capability::FragmentShaderShadingRateInterlockEXT: + return "FragmentShaderShadingRateInterlockEXT"; + case Capability::ShaderSMBuiltinsNV: + return "ShaderSMBuiltinsNV"; + case Capability::FragmentShaderPixelInterlockEXT: + return "FragmentShaderPixelInterlockEXT"; + case Capability::DemoteToHelperInvocation: + return "DemoteToHelperInvocation"; + case Capability::DisplacementMicromapNV: + return "DisplacementMicromapNV"; + case Capability::RayTracingOpacityMicromapEXT: + return "RayTracingOpacityMicromapEXT"; + case Capability::ShaderInvocationReorderNV: + return "ShaderInvocationReorderNV"; + case Capability::BindlessTextureNV: + return "BindlessTextureNV"; + case Capability::RayQueryPositionFetchKHR: + return "RayQueryPositionFetchKHR"; + case Capability::CooperativeVectorNV: + return "CooperativeVectorNV"; + case Capability::AtomicFloat16VectorNV: + return "AtomicFloat16VectorNV"; + case Capability::RayTracingDisplacementMicromapNV: + return "RayTracingDisplacementMicromapNV"; + case Capability::RawAccessChainsNV: + return "RawAccessChainsNV"; + case Capability::RayTracingSpheresGeometryNV: + return "RayTracingSpheresGeometryNV"; + case Capability::RayTracingLinearSweptSpheresGeometryNV: + return "RayTracingLinearSweptSpheresGeometryNV"; + case Capability::CooperativeMatrixReductionsNV: + return "CooperativeMatrixReductionsNV"; + case Capability::CooperativeMatrixConversionsNV: + return "CooperativeMatrixConversionsNV"; + case Capability::CooperativeMatrixPerElementOperationsNV: + return "CooperativeMatrixPerElementOperationsNV"; + case Capability::CooperativeMatrixTensorAddressingNV: + return "CooperativeMatrixTensorAddressingNV"; + case Capability::CooperativeMatrixBlockLoadsNV: + return "CooperativeMatrixBlockLoadsNV"; + case Capability::CooperativeVectorTrainingNV: + return "CooperativeVectorTrainingNV"; + case Capability::RayTracingClusterAccelerationStructureNV: + return "RayTracingClusterAccelerationStructureNV"; + case Capability::TensorAddressingNV: + return "TensorAddressingNV"; + case Capability::SubgroupShuffleINTEL: + return "SubgroupShuffleINTEL"; + case Capability::SubgroupBufferBlockIOINTEL: + return "SubgroupBufferBlockIOINTEL"; + case Capability::SubgroupImageBlockIOINTEL: + return "SubgroupImageBlockIOINTEL"; + case Capability::SubgroupImageMediaBlockIOINTEL: + return "SubgroupImageMediaBlockIOINTEL"; + case Capability::RoundToInfinityINTEL: + return "RoundToInfinityINTEL"; + case Capability::FloatingPointModeINTEL: + return "FloatingPointModeINTEL"; + case Capability::IntegerFunctions2INTEL: + return "IntegerFunctions2INTEL"; + case Capability::FunctionPointersINTEL: + return "FunctionPointersINTEL"; + case Capability::IndirectReferencesINTEL: + return "IndirectReferencesINTEL"; + case Capability::AsmINTEL: + return "AsmINTEL"; + case Capability::AtomicFloat32MinMaxEXT: + return "AtomicFloat32MinMaxEXT"; + case Capability::AtomicFloat64MinMaxEXT: + return "AtomicFloat64MinMaxEXT"; + case Capability::AtomicFloat16MinMaxEXT: + return "AtomicFloat16MinMaxEXT"; + case Capability::VectorComputeINTEL: + return "VectorComputeINTEL"; + case Capability::VectorAnyINTEL: + return "VectorAnyINTEL"; + case Capability::ExpectAssumeKHR: + return "ExpectAssumeKHR"; + case Capability::SubgroupAvcMotionEstimationINTEL: + return "SubgroupAvcMotionEstimationINTEL"; + case Capability::SubgroupAvcMotionEstimationIntraINTEL: + return "SubgroupAvcMotionEstimationIntraINTEL"; + case Capability::SubgroupAvcMotionEstimationChromaINTEL: + return "SubgroupAvcMotionEstimationChromaINTEL"; + case Capability::VariableLengthArrayINTEL: + return "VariableLengthArrayINTEL"; + case Capability::FunctionFloatControlINTEL: + return "FunctionFloatControlINTEL"; + case Capability::FPGAMemoryAttributesINTEL: + return "FPGAMemoryAttributesINTEL"; + case Capability::FPFastMathModeINTEL: + return "FPFastMathModeINTEL"; + case Capability::ArbitraryPrecisionIntegersINTEL: + return "ArbitraryPrecisionIntegersINTEL"; + case Capability::ArbitraryPrecisionFloatingPointINTEL: + return "ArbitraryPrecisionFloatingPointINTEL"; + case Capability::UnstructuredLoopControlsINTEL: + return "UnstructuredLoopControlsINTEL"; + case Capability::FPGALoopControlsINTEL: + return "FPGALoopControlsINTEL"; + case Capability::KernelAttributesINTEL: + return "KernelAttributesINTEL"; + case Capability::FPGAKernelAttributesINTEL: + return "FPGAKernelAttributesINTEL"; + case Capability::FPGAMemoryAccessesINTEL: + return "FPGAMemoryAccessesINTEL"; + case Capability::FPGAClusterAttributesINTEL: + return "FPGAClusterAttributesINTEL"; + case Capability::LoopFuseINTEL: + return "LoopFuseINTEL"; + case Capability::FPGADSPControlINTEL: + return "FPGADSPControlINTEL"; + case Capability::MemoryAccessAliasingINTEL: + return "MemoryAccessAliasingINTEL"; + case Capability::FPGAInvocationPipeliningAttributesINTEL: + return "FPGAInvocationPipeliningAttributesINTEL"; + case Capability::FPGABufferLocationINTEL: + return "FPGABufferLocationINTEL"; + case Capability::ArbitraryPrecisionFixedPointINTEL: + return "ArbitraryPrecisionFixedPointINTEL"; + case Capability::USMStorageClassesINTEL: + return "USMStorageClassesINTEL"; + case Capability::RuntimeAlignedAttributeINTEL: + return "RuntimeAlignedAttributeINTEL"; + case Capability::IOPipesINTEL: + return "IOPipesINTEL"; + case Capability::BlockingPipesINTEL: + return "BlockingPipesINTEL"; + case Capability::FPGARegINTEL: + return "FPGARegINTEL"; + case Capability::DotProductInputAll: + return "DotProductInputAll"; + case Capability::DotProductInput4x8Bit: + return "DotProductInput4x8Bit"; + case Capability::DotProductInput4x8BitPacked: + return "DotProductInput4x8BitPacked"; + case Capability::DotProduct: + return "DotProduct"; + case Capability::RayCullMaskKHR: + return "RayCullMaskKHR"; + case Capability::CooperativeMatrixKHR: + return "CooperativeMatrixKHR"; + case Capability::ReplicatedCompositesEXT: + return "ReplicatedCompositesEXT"; + case Capability::BitInstructions: + return "BitInstructions"; + case Capability::GroupNonUniformRotateKHR: + return "GroupNonUniformRotateKHR"; + case Capability::FloatControls2: + return "FloatControls2"; + case Capability::AtomicFloat32AddEXT: + return "AtomicFloat32AddEXT"; + case Capability::AtomicFloat64AddEXT: + return "AtomicFloat64AddEXT"; + case Capability::LongCompositesINTEL: + return "LongCompositesINTEL"; + case Capability::OptNoneEXT: + return "OptNoneEXT"; + case Capability::AtomicFloat16AddEXT: + return "AtomicFloat16AddEXT"; + case Capability::DebugInfoModuleINTEL: + return "DebugInfoModuleINTEL"; + case Capability::BFloat16ConversionINTEL: + return "BFloat16ConversionINTEL"; + case Capability::SplitBarrierINTEL: + return "SplitBarrierINTEL"; + case Capability::ArithmeticFenceEXT: + return "ArithmeticFenceEXT"; + case Capability::FPGAClusterAttributesV2INTEL: + return "FPGAClusterAttributesV2INTEL"; + case Capability::FPGAKernelAttributesv2INTEL: + return "FPGAKernelAttributesv2INTEL"; + case Capability::TaskSequenceINTEL: + return "TaskSequenceINTEL"; + case Capability::FPMaxErrorINTEL: + return "FPMaxErrorINTEL"; + case Capability::FPGALatencyControlINTEL: + return "FPGALatencyControlINTEL"; + case Capability::FPGAArgumentInterfacesINTEL: + return "FPGAArgumentInterfacesINTEL"; + case Capability::GlobalVariableHostAccessINTEL: + return "GlobalVariableHostAccessINTEL"; + case Capability::GlobalVariableFPGADecorationsINTEL: + return "GlobalVariableFPGADecorationsINTEL"; + case Capability::SubgroupBufferPrefetchINTEL: + return "SubgroupBufferPrefetchINTEL"; + case Capability::Subgroup2DBlockIOINTEL: + return "Subgroup2DBlockIOINTEL"; + case Capability::Subgroup2DBlockTransformINTEL: + return "Subgroup2DBlockTransformINTEL"; + case Capability::Subgroup2DBlockTransposeINTEL: + return "Subgroup2DBlockTransposeINTEL"; + case Capability::SubgroupMatrixMultiplyAccumulateINTEL: + return "SubgroupMatrixMultiplyAccumulateINTEL"; + case Capability::TernaryBitwiseFunctionINTEL: + return "TernaryBitwiseFunctionINTEL"; + case Capability::GroupUniformArithmeticKHR: + return "GroupUniformArithmeticKHR"; + case Capability::TensorFloat32RoundingINTEL: + return "TensorFloat32RoundingINTEL"; + case Capability::MaskedGatherScatterINTEL: + return "MaskedGatherScatterINTEL"; + case Capability::CacheControlsINTEL: + return "CacheControlsINTEL"; + case Capability::RegisterLimitsINTEL: + return "RegisterLimitsINTEL"; + case Capability::BindlessImagesINTEL: + return "BindlessImagesINTEL"; + case Capability::PackedCooperativeMatrixINTEL: + return "PackedCooperativeMatrixINTEL"; + case Capability::CooperativeMatrixInvocationInstructionsINTEL: + return "CooperativeMatrixInvocationInstructionsINTEL"; + case Capability::CooperativeMatrixTF32ComponentTypeINTEL: + return "CooperativeMatrixTF32ComponentTypeINTEL"; + case Capability::CooperativeMatrixBFloat16ComponentTypeINTEL: + return "CooperativeMatrixBFloat16ComponentTypeINTEL"; + case Capability::CooperativeMatrixCheckedInstructionsINTEL: + return "CooperativeMatrixCheckedInstructionsINTEL"; + case Capability::CooperativeMatrixPrefetchINTEL: + return "CooperativeMatrixPrefetchINTEL"; + } + return "unknown"; +} +auto to_string(RayQueryIntersection e) -> char const * { + switch (e) { + case RayQueryIntersection::RayQueryCandidateIntersectionKHR: + return "RayQueryCandidateIntersectionKHR"; + case RayQueryIntersection::RayQueryCommittedIntersectionKHR: + return "RayQueryCommittedIntersectionKHR"; + } + return "unknown"; +} +auto to_string(RayQueryCommittedIntersectionType e) -> char const * { + switch (e) { + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR: + return "RayQueryCommittedIntersectionNoneKHR"; + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR: + return "RayQueryCommittedIntersectionTriangleKHR"; + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionGeneratedKHR: + return "RayQueryCommittedIntersectionGeneratedKHR"; + } + return "unknown"; +} +auto to_string(RayQueryCandidateIntersectionType e) -> char const * { + switch (e) { + case RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR: + return "RayQueryCandidateIntersectionTriangleKHR"; + case RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR: + return "RayQueryCandidateIntersectionAABBKHR"; + } + return "unknown"; +} +auto to_string(PackedVectorFormat e) -> char const * { + switch (e) { + case PackedVectorFormat::PackedVectorFormat4x8Bit: + return "PackedVectorFormat4x8Bit"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixOperands e) -> char const * { + switch (e) { + case CooperativeMatrixOperands::NoneKHR: + return "NoneKHR"; + case CooperativeMatrixOperands::MatrixASignedComponentsKHR: + return "MatrixASignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixBSignedComponentsKHR: + return "MatrixBSignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixCSignedComponentsKHR: + return "MatrixCSignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixResultSignedComponentsKHR: + return "MatrixResultSignedComponentsKHR"; + case CooperativeMatrixOperands::SaturatingAccumulationKHR: + return "SaturatingAccumulationKHR"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixLayout e) -> char const * { + switch (e) { + case CooperativeMatrixLayout::RowMajorKHR: + return "RowMajorKHR"; + case CooperativeMatrixLayout::ColumnMajorKHR: + return "ColumnMajorKHR"; + case CooperativeMatrixLayout::RowBlockedInterleavedARM: + return "RowBlockedInterleavedARM"; + case CooperativeMatrixLayout::ColumnBlockedInterleavedARM: + return "ColumnBlockedInterleavedARM"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixUse e) -> char const * { + switch (e) { + case CooperativeMatrixUse::MatrixAKHR: + return "MatrixAKHR"; + case CooperativeMatrixUse::MatrixBKHR: + return "MatrixBKHR"; + case CooperativeMatrixUse::MatrixAccumulatorKHR: + return "MatrixAccumulatorKHR"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixReduce e) -> char const * { + switch (e) { + case CooperativeMatrixReduce::Row: + return "Row"; + case CooperativeMatrixReduce::Column: + return "Column"; + case CooperativeMatrixReduce::CooperativeMatrixReduce2x2: + return "CooperativeMatrixReduce2x2"; + } + return "unknown"; +} +auto to_string(TensorClampMode e) -> char const * { + switch (e) { + case TensorClampMode::Undefined: + return "Undefined"; + case TensorClampMode::Constant: + return "Constant"; + case TensorClampMode::ClampToEdge: + return "ClampToEdge"; + case TensorClampMode::Repeat: + return "Repeat"; + case TensorClampMode::RepeatMirrored: + return "RepeatMirrored"; + } + return "unknown"; +} +auto to_string(TensorAddressingOperands e) -> char const * { + switch (e) { + case TensorAddressingOperands::None: + return "None"; + case TensorAddressingOperands::TensorView: + return "TensorView"; + case TensorAddressingOperands::DecodeFunc: + return "DecodeFunc"; + } + return "unknown"; +} +auto to_string(InitializationModeQualifier e) -> char const * { + switch (e) { + case InitializationModeQualifier::InitOnDeviceReprogramINTEL: + return "InitOnDeviceReprogramINTEL"; + case InitializationModeQualifier::InitOnDeviceResetINTEL: + return "InitOnDeviceResetINTEL"; + } + return "unknown"; +} +auto to_string(LoadCacheControl e) -> char const * { + switch (e) { + case LoadCacheControl::UncachedINTEL: + return "UncachedINTEL"; + case LoadCacheControl::CachedINTEL: + return "CachedINTEL"; + case LoadCacheControl::StreamingINTEL: + return "StreamingINTEL"; + case LoadCacheControl::InvalidateAfterReadINTEL: + return "InvalidateAfterReadINTEL"; + case LoadCacheControl::ConstCachedINTEL: + return "ConstCachedINTEL"; + } + return "unknown"; +} +auto to_string(StoreCacheControl e) -> char const * { + switch (e) { + case StoreCacheControl::UncachedINTEL: + return "UncachedINTEL"; + case StoreCacheControl::WriteThroughINTEL: + return "WriteThroughINTEL"; + case StoreCacheControl::WriteBackINTEL: + return "WriteBackINTEL"; + case StoreCacheControl::StreamingINTEL: + return "StreamingINTEL"; + } + return "unknown"; +} +auto to_string(NamedMaximumNumberOfRegisters e) -> char const * { + switch (e) { + case NamedMaximumNumberOfRegisters::AutoINTEL: + return "AutoINTEL"; + } + return "unknown"; +} +auto to_string(MatrixMultiplyAccumulateOperands e) -> char const * { + switch (e) { + case MatrixMultiplyAccumulateOperands::None: + return "None"; + case MatrixMultiplyAccumulateOperands::MatrixASignedComponentsINTEL: + return "MatrixASignedComponentsINTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBSignedComponentsINTEL: + return "MatrixBSignedComponentsINTEL"; + case MatrixMultiplyAccumulateOperands::MatrixCBFloat16INTEL: + return "MatrixCBFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixResultBFloat16INTEL: + return "MatrixResultBFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedInt8INTEL: + return "MatrixAPackedInt8INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedInt8INTEL: + return "MatrixBPackedInt8INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedInt4INTEL: + return "MatrixAPackedInt4INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedInt4INTEL: + return "MatrixBPackedInt4INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixATF32INTEL: + return "MatrixATF32INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBTF32INTEL: + return "MatrixBTF32INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedFloat16INTEL: + return "MatrixAPackedFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedFloat16INTEL: + return "MatrixBPackedFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedBFloat16INTEL: + return "MatrixAPackedBFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedBFloat16INTEL: + return "MatrixBPackedBFloat16INTEL"; + } + return "unknown"; +} +auto to_string(FPEncoding e) -> char const * { + switch (e) { + case FPEncoding::BFloat16KHR: + return "BFloat16KHR"; + case FPEncoding::Float8E4M3EXT: + return "Float8E4M3EXT"; + case FPEncoding::Float8E5M2EXT: + return "Float8E5M2EXT"; + } + return "unknown"; +} +auto to_string(CooperativeVectorMatrixLayout e) -> char const * { + switch (e) { + case CooperativeVectorMatrixLayout::RowMajorNV: + return "RowMajorNV"; + case CooperativeVectorMatrixLayout::ColumnMajorNV: + return "ColumnMajorNV"; + case CooperativeVectorMatrixLayout::InferencingOptimalNV: + return "InferencingOptimalNV"; + case CooperativeVectorMatrixLayout::TrainingOptimalNV: + return "TrainingOptimalNV"; + } + return "unknown"; +} +auto to_string(ComponentType e) -> char const * { + switch (e) { + case ComponentType::Float16NV: + return "Float16NV"; + case ComponentType::Float32NV: + return "Float32NV"; + case ComponentType::Float64NV: + return "Float64NV"; + case ComponentType::SignedInt8NV: + return "SignedInt8NV"; + case ComponentType::SignedInt16NV: + return "SignedInt16NV"; + case ComponentType::SignedInt32NV: + return "SignedInt32NV"; + case ComponentType::SignedInt64NV: + return "SignedInt64NV"; + case ComponentType::UnsignedInt8NV: + return "UnsignedInt8NV"; + case ComponentType::UnsignedInt16NV: + return "UnsignedInt16NV"; + case ComponentType::UnsignedInt32NV: + return "UnsignedInt32NV"; + case ComponentType::UnsignedInt64NV: + return "UnsignedInt64NV"; + case ComponentType::SignedInt8PackedNV: + return "SignedInt8PackedNV"; + case ComponentType::UnsignedInt8PackedNV: + return "UnsignedInt8PackedNV"; + case ComponentType::FloatE4M3NV: + return "FloatE4M3NV"; + case ComponentType::FloatE5M2NV: + return "FloatE5M2NV"; + } + return "unknown"; +} +auto to_string(TensorOperands e) -> char const * { + switch (e) { + case TensorOperands::NoneARM: + return "NoneARM"; + case TensorOperands::NontemporalARM: + return "NontemporalARM"; + case TensorOperands::OutOfBoundsValueARM: + return "OutOfBoundsValueARM"; + case TensorOperands::MakeElementAvailableARM: + return "MakeElementAvailableARM"; + case TensorOperands::MakeElementVisibleARM: + return "MakeElementVisibleARM"; + case TensorOperands::NonPrivateElementARM: + return "NonPrivateElementARM"; + } + return "unknown"; +} + +} // namespace tinytc::spv diff --git a/src/spv/names.hpp b/src/spv/names.hpp new file mode 100644 index 00000000..31679c42 --- /dev/null +++ b/src/spv/names.hpp @@ -0,0 +1,131 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_NAMES_20250605_HPP +#define GENERATED_NAMES_20250605_HPP + +namespace tinytc::spv { + +enum class Op; +auto to_string(Op op) -> char const *; +enum class ImageOperands; +auto to_string(ImageOperands e) -> char const *; +enum class FPFastMathMode; +auto to_string(FPFastMathMode e) -> char const *; +enum class SelectionControl; +auto to_string(SelectionControl e) -> char const *; +enum class LoopControl; +auto to_string(LoopControl e) -> char const *; +enum class FunctionControl; +auto to_string(FunctionControl e) -> char const *; +enum class MemorySemantics; +auto to_string(MemorySemantics e) -> char const *; +enum class MemoryAccess; +auto to_string(MemoryAccess e) -> char const *; +enum class KernelProfilingInfo; +auto to_string(KernelProfilingInfo e) -> char const *; +enum class RayFlags; +auto to_string(RayFlags e) -> char const *; +enum class FragmentShadingRate; +auto to_string(FragmentShadingRate e) -> char const *; +enum class RawAccessChainOperands; +auto to_string(RawAccessChainOperands e) -> char const *; +enum class SourceLanguage; +auto to_string(SourceLanguage e) -> char const *; +enum class ExecutionModel; +auto to_string(ExecutionModel e) -> char const *; +enum class AddressingModel; +auto to_string(AddressingModel e) -> char const *; +enum class MemoryModel; +auto to_string(MemoryModel e) -> char const *; +enum class ExecutionMode; +auto to_string(ExecutionMode e) -> char const *; +enum class StorageClass; +auto to_string(StorageClass e) -> char const *; +enum class Dim; +auto to_string(Dim e) -> char const *; +enum class SamplerAddressingMode; +auto to_string(SamplerAddressingMode e) -> char const *; +enum class SamplerFilterMode; +auto to_string(SamplerFilterMode e) -> char const *; +enum class ImageFormat; +auto to_string(ImageFormat e) -> char const *; +enum class ImageChannelOrder; +auto to_string(ImageChannelOrder e) -> char const *; +enum class ImageChannelDataType; +auto to_string(ImageChannelDataType e) -> char const *; +enum class FPRoundingMode; +auto to_string(FPRoundingMode e) -> char const *; +enum class FPDenormMode; +auto to_string(FPDenormMode e) -> char const *; +enum class QuantizationModes; +auto to_string(QuantizationModes e) -> char const *; +enum class FPOperationMode; +auto to_string(FPOperationMode e) -> char const *; +enum class OverflowModes; +auto to_string(OverflowModes e) -> char const *; +enum class LinkageType; +auto to_string(LinkageType e) -> char const *; +enum class AccessQualifier; +auto to_string(AccessQualifier e) -> char const *; +enum class HostAccessQualifier; +auto to_string(HostAccessQualifier e) -> char const *; +enum class FunctionParameterAttribute; +auto to_string(FunctionParameterAttribute e) -> char const *; +enum class Decoration; +auto to_string(Decoration e) -> char const *; +enum class BuiltIn; +auto to_string(BuiltIn e) -> char const *; +enum class Scope; +auto to_string(Scope e) -> char const *; +enum class GroupOperation; +auto to_string(GroupOperation e) -> char const *; +enum class KernelEnqueueFlags; +auto to_string(KernelEnqueueFlags e) -> char const *; +enum class Capability; +auto to_string(Capability e) -> char const *; +enum class RayQueryIntersection; +auto to_string(RayQueryIntersection e) -> char const *; +enum class RayQueryCommittedIntersectionType; +auto to_string(RayQueryCommittedIntersectionType e) -> char const *; +enum class RayQueryCandidateIntersectionType; +auto to_string(RayQueryCandidateIntersectionType e) -> char const *; +enum class PackedVectorFormat; +auto to_string(PackedVectorFormat e) -> char const *; +enum class CooperativeMatrixOperands; +auto to_string(CooperativeMatrixOperands e) -> char const *; +enum class CooperativeMatrixLayout; +auto to_string(CooperativeMatrixLayout e) -> char const *; +enum class CooperativeMatrixUse; +auto to_string(CooperativeMatrixUse e) -> char const *; +enum class CooperativeMatrixReduce; +auto to_string(CooperativeMatrixReduce e) -> char const *; +enum class TensorClampMode; +auto to_string(TensorClampMode e) -> char const *; +enum class TensorAddressingOperands; +auto to_string(TensorAddressingOperands e) -> char const *; +enum class InitializationModeQualifier; +auto to_string(InitializationModeQualifier e) -> char const *; +enum class LoadCacheControl; +auto to_string(LoadCacheControl e) -> char const *; +enum class StoreCacheControl; +auto to_string(StoreCacheControl e) -> char const *; +enum class NamedMaximumNumberOfRegisters; +auto to_string(NamedMaximumNumberOfRegisters e) -> char const *; +enum class MatrixMultiplyAccumulateOperands; +auto to_string(MatrixMultiplyAccumulateOperands e) -> char const *; +enum class FPEncoding; +auto to_string(FPEncoding e) -> char const *; +enum class CooperativeVectorMatrixLayout; +auto to_string(CooperativeVectorMatrixLayout e) -> char const *; +enum class ComponentType; +auto to_string(ComponentType e) -> char const *; +enum class TensorOperands; +auto to_string(TensorOperands e) -> char const *; + +} // namespace tinytc::spv + +#endif // GENERATED_NAMES_20250605_HPP diff --git a/src/spv/opencl.std.cpp b/src/spv/opencl.std.cpp new file mode 100644 index 00000000..1284448b --- /dev/null +++ b/src/spv/opencl.std.cpp @@ -0,0 +1,341 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "opencl.std.hpp" + +namespace tinytc::spv { + +auto to_string(OpenCLEntrypoint ep) -> char const * { + switch (ep) { + case OpenCLEntrypoint::acos: + return "acos"; + case OpenCLEntrypoint::acosh: + return "acosh"; + case OpenCLEntrypoint::acospi: + return "acospi"; + case OpenCLEntrypoint::asin: + return "asin"; + case OpenCLEntrypoint::asinh: + return "asinh"; + case OpenCLEntrypoint::asinpi: + return "asinpi"; + case OpenCLEntrypoint::atan: + return "atan"; + case OpenCLEntrypoint::atan2: + return "atan2"; + case OpenCLEntrypoint::atanh: + return "atanh"; + case OpenCLEntrypoint::atanpi: + return "atanpi"; + case OpenCLEntrypoint::atan2pi: + return "atan2pi"; + case OpenCLEntrypoint::cbrt: + return "cbrt"; + case OpenCLEntrypoint::ceil: + return "ceil"; + case OpenCLEntrypoint::copysign: + return "copysign"; + case OpenCLEntrypoint::cos: + return "cos"; + case OpenCLEntrypoint::cosh: + return "cosh"; + case OpenCLEntrypoint::cospi: + return "cospi"; + case OpenCLEntrypoint::erfc: + return "erfc"; + case OpenCLEntrypoint::erf: + return "erf"; + case OpenCLEntrypoint::exp: + return "exp"; + case OpenCLEntrypoint::exp2: + return "exp2"; + case OpenCLEntrypoint::exp10: + return "exp10"; + case OpenCLEntrypoint::expm1: + return "expm1"; + case OpenCLEntrypoint::fabs: + return "fabs"; + case OpenCLEntrypoint::fdim: + return "fdim"; + case OpenCLEntrypoint::floor: + return "floor"; + case OpenCLEntrypoint::fma: + return "fma"; + case OpenCLEntrypoint::fmax: + return "fmax"; + case OpenCLEntrypoint::fmin: + return "fmin"; + case OpenCLEntrypoint::fmod: + return "fmod"; + case OpenCLEntrypoint::fract: + return "fract"; + case OpenCLEntrypoint::frexp: + return "frexp"; + case OpenCLEntrypoint::hypot: + return "hypot"; + case OpenCLEntrypoint::ilogb: + return "ilogb"; + case OpenCLEntrypoint::ldexp: + return "ldexp"; + case OpenCLEntrypoint::lgamma: + return "lgamma"; + case OpenCLEntrypoint::lgamma_r: + return "lgamma_r"; + case OpenCLEntrypoint::log: + return "log"; + case OpenCLEntrypoint::log2: + return "log2"; + case OpenCLEntrypoint::log10: + return "log10"; + case OpenCLEntrypoint::log1p: + return "log1p"; + case OpenCLEntrypoint::logb: + return "logb"; + case OpenCLEntrypoint::mad: + return "mad"; + case OpenCLEntrypoint::maxmag: + return "maxmag"; + case OpenCLEntrypoint::minmag: + return "minmag"; + case OpenCLEntrypoint::modf: + return "modf"; + case OpenCLEntrypoint::nan: + return "nan"; + case OpenCLEntrypoint::nextafter: + return "nextafter"; + case OpenCLEntrypoint::pow: + return "pow"; + case OpenCLEntrypoint::pown: + return "pown"; + case OpenCLEntrypoint::powr: + return "powr"; + case OpenCLEntrypoint::remainder: + return "remainder"; + case OpenCLEntrypoint::remquo: + return "remquo"; + case OpenCLEntrypoint::rint: + return "rint"; + case OpenCLEntrypoint::rootn: + return "rootn"; + case OpenCLEntrypoint::round: + return "round"; + case OpenCLEntrypoint::rsqrt: + return "rsqrt"; + case OpenCLEntrypoint::sin: + return "sin"; + case OpenCLEntrypoint::sincos: + return "sincos"; + case OpenCLEntrypoint::sinh: + return "sinh"; + case OpenCLEntrypoint::sinpi: + return "sinpi"; + case OpenCLEntrypoint::sqrt: + return "sqrt"; + case OpenCLEntrypoint::tan: + return "tan"; + case OpenCLEntrypoint::tanh: + return "tanh"; + case OpenCLEntrypoint::tanpi: + return "tanpi"; + case OpenCLEntrypoint::tgamma: + return "tgamma"; + case OpenCLEntrypoint::trunc: + return "trunc"; + case OpenCLEntrypoint::half_cos: + return "half_cos"; + case OpenCLEntrypoint::half_divide: + return "half_divide"; + case OpenCLEntrypoint::half_exp: + return "half_exp"; + case OpenCLEntrypoint::half_exp2: + return "half_exp2"; + case OpenCLEntrypoint::half_exp10: + return "half_exp10"; + case OpenCLEntrypoint::half_log: + return "half_log"; + case OpenCLEntrypoint::half_log2: + return "half_log2"; + case OpenCLEntrypoint::half_log10: + return "half_log10"; + case OpenCLEntrypoint::half_powr: + return "half_powr"; + case OpenCLEntrypoint::half_recip: + return "half_recip"; + case OpenCLEntrypoint::half_rsqrt: + return "half_rsqrt"; + case OpenCLEntrypoint::half_sin: + return "half_sin"; + case OpenCLEntrypoint::half_sqrt: + return "half_sqrt"; + case OpenCLEntrypoint::half_tan: + return "half_tan"; + case OpenCLEntrypoint::native_cos: + return "native_cos"; + case OpenCLEntrypoint::native_divide: + return "native_divide"; + case OpenCLEntrypoint::native_exp: + return "native_exp"; + case OpenCLEntrypoint::native_exp2: + return "native_exp2"; + case OpenCLEntrypoint::native_exp10: + return "native_exp10"; + case OpenCLEntrypoint::native_log: + return "native_log"; + case OpenCLEntrypoint::native_log2: + return "native_log2"; + case OpenCLEntrypoint::native_log10: + return "native_log10"; + case OpenCLEntrypoint::native_powr: + return "native_powr"; + case OpenCLEntrypoint::native_recip: + return "native_recip"; + case OpenCLEntrypoint::native_rsqrt: + return "native_rsqrt"; + case OpenCLEntrypoint::native_sin: + return "native_sin"; + case OpenCLEntrypoint::native_sqrt: + return "native_sqrt"; + case OpenCLEntrypoint::native_tan: + return "native_tan"; + case OpenCLEntrypoint::s_abs: + return "s_abs"; + case OpenCLEntrypoint::s_abs_diff: + return "s_abs_diff"; + case OpenCLEntrypoint::s_add_sat: + return "s_add_sat"; + case OpenCLEntrypoint::u_add_sat: + return "u_add_sat"; + case OpenCLEntrypoint::s_hadd: + return "s_hadd"; + case OpenCLEntrypoint::u_hadd: + return "u_hadd"; + case OpenCLEntrypoint::s_rhadd: + return "s_rhadd"; + case OpenCLEntrypoint::u_rhadd: + return "u_rhadd"; + case OpenCLEntrypoint::s_clamp: + return "s_clamp"; + case OpenCLEntrypoint::u_clamp: + return "u_clamp"; + case OpenCLEntrypoint::clz: + return "clz"; + case OpenCLEntrypoint::ctz: + return "ctz"; + case OpenCLEntrypoint::s_mad_hi: + return "s_mad_hi"; + case OpenCLEntrypoint::u_mad_sat: + return "u_mad_sat"; + case OpenCLEntrypoint::s_mad_sat: + return "s_mad_sat"; + case OpenCLEntrypoint::s_max: + return "s_max"; + case OpenCLEntrypoint::u_max: + return "u_max"; + case OpenCLEntrypoint::s_min: + return "s_min"; + case OpenCLEntrypoint::u_min: + return "u_min"; + case OpenCLEntrypoint::s_mul_hi: + return "s_mul_hi"; + case OpenCLEntrypoint::rotate: + return "rotate"; + case OpenCLEntrypoint::s_sub_sat: + return "s_sub_sat"; + case OpenCLEntrypoint::u_sub_sat: + return "u_sub_sat"; + case OpenCLEntrypoint::u_upsample: + return "u_upsample"; + case OpenCLEntrypoint::s_upsample: + return "s_upsample"; + case OpenCLEntrypoint::popcount: + return "popcount"; + case OpenCLEntrypoint::s_mad24: + return "s_mad24"; + case OpenCLEntrypoint::u_mad24: + return "u_mad24"; + case OpenCLEntrypoint::s_mul24: + return "s_mul24"; + case OpenCLEntrypoint::u_mul24: + return "u_mul24"; + case OpenCLEntrypoint::u_abs: + return "u_abs"; + case OpenCLEntrypoint::u_abs_diff: + return "u_abs_diff"; + case OpenCLEntrypoint::u_mul_hi: + return "u_mul_hi"; + case OpenCLEntrypoint::u_mad_hi: + return "u_mad_hi"; + case OpenCLEntrypoint::fclamp: + return "fclamp"; + case OpenCLEntrypoint::degrees: + return "degrees"; + case OpenCLEntrypoint::fmax_common: + return "fmax_common"; + case OpenCLEntrypoint::fmin_common: + return "fmin_common"; + case OpenCLEntrypoint::mix: + return "mix"; + case OpenCLEntrypoint::radians: + return "radians"; + case OpenCLEntrypoint::step: + return "step"; + case OpenCLEntrypoint::smoothstep: + return "smoothstep"; + case OpenCLEntrypoint::sign: + return "sign"; + case OpenCLEntrypoint::cross: + return "cross"; + case OpenCLEntrypoint::distance: + return "distance"; + case OpenCLEntrypoint::length: + return "length"; + case OpenCLEntrypoint::normalize: + return "normalize"; + case OpenCLEntrypoint::fast_distance: + return "fast_distance"; + case OpenCLEntrypoint::fast_length: + return "fast_length"; + case OpenCLEntrypoint::fast_normalize: + return "fast_normalize"; + case OpenCLEntrypoint::bitselect: + return "bitselect"; + case OpenCLEntrypoint::select: + return "select"; + case OpenCLEntrypoint::vloadn: + return "vloadn"; + case OpenCLEntrypoint::vstoren: + return "vstoren"; + case OpenCLEntrypoint::vload_half: + return "vload_half"; + case OpenCLEntrypoint::vload_halfn: + return "vload_halfn"; + case OpenCLEntrypoint::vstore_half: + return "vstore_half"; + case OpenCLEntrypoint::vstore_half_r: + return "vstore_half_r"; + case OpenCLEntrypoint::vstore_halfn: + return "vstore_halfn"; + case OpenCLEntrypoint::vstore_halfn_r: + return "vstore_halfn_r"; + case OpenCLEntrypoint::vloada_halfn: + return "vloada_halfn"; + case OpenCLEntrypoint::vstorea_halfn: + return "vstorea_halfn"; + case OpenCLEntrypoint::vstorea_halfn_r: + return "vstorea_halfn_r"; + case OpenCLEntrypoint::shuffle: + return "shuffle"; + case OpenCLEntrypoint::shuffle2: + return "shuffle2"; + case OpenCLEntrypoint::printf: + return "printf"; + case OpenCLEntrypoint::prefetch: + return "prefetch"; + } + return "unknown"; +} + +} // namespace tinytc::spv diff --git a/src/spv/opencl.std.hpp b/src/spv/opencl.std.hpp new file mode 100644 index 00000000..b662101f --- /dev/null +++ b/src/spv/opencl.std.hpp @@ -0,0 +1,183 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_OPENCL_STD_2024115_HPP +#define GENERATED_OPENCL_STD_2024115_HPP + +namespace tinytc::spv { + +constexpr char const *OpenCLExt = "OpenCL.std"; + +enum class OpenCLEntrypoint { + acos = 0, + acosh = 1, + acospi = 2, + asin = 3, + asinh = 4, + asinpi = 5, + atan = 6, + atan2 = 7, + atanh = 8, + atanpi = 9, + atan2pi = 10, + cbrt = 11, + ceil = 12, + copysign = 13, + cos = 14, + cosh = 15, + cospi = 16, + erfc = 17, + erf = 18, + exp = 19, + exp2 = 20, + exp10 = 21, + expm1 = 22, + fabs = 23, + fdim = 24, + floor = 25, + fma = 26, + fmax = 27, + fmin = 28, + fmod = 29, + fract = 30, + frexp = 31, + hypot = 32, + ilogb = 33, + ldexp = 34, + lgamma = 35, + lgamma_r = 36, + log = 37, + log2 = 38, + log10 = 39, + log1p = 40, + logb = 41, + mad = 42, + maxmag = 43, + minmag = 44, + modf = 45, + nan = 46, + nextafter = 47, + pow = 48, + pown = 49, + powr = 50, + remainder = 51, + remquo = 52, + rint = 53, + rootn = 54, + round = 55, + rsqrt = 56, + sin = 57, + sincos = 58, + sinh = 59, + sinpi = 60, + sqrt = 61, + tan = 62, + tanh = 63, + tanpi = 64, + tgamma = 65, + trunc = 66, + half_cos = 67, + half_divide = 68, + half_exp = 69, + half_exp2 = 70, + half_exp10 = 71, + half_log = 72, + half_log2 = 73, + half_log10 = 74, + half_powr = 75, + half_recip = 76, + half_rsqrt = 77, + half_sin = 78, + half_sqrt = 79, + half_tan = 80, + native_cos = 81, + native_divide = 82, + native_exp = 83, + native_exp2 = 84, + native_exp10 = 85, + native_log = 86, + native_log2 = 87, + native_log10 = 88, + native_powr = 89, + native_recip = 90, + native_rsqrt = 91, + native_sin = 92, + native_sqrt = 93, + native_tan = 94, + s_abs = 141, + s_abs_diff = 142, + s_add_sat = 143, + u_add_sat = 144, + s_hadd = 145, + u_hadd = 146, + s_rhadd = 147, + u_rhadd = 148, + s_clamp = 149, + u_clamp = 150, + clz = 151, + ctz = 152, + s_mad_hi = 153, + u_mad_sat = 154, + s_mad_sat = 155, + s_max = 156, + u_max = 157, + s_min = 158, + u_min = 159, + s_mul_hi = 160, + rotate = 161, + s_sub_sat = 162, + u_sub_sat = 163, + u_upsample = 164, + s_upsample = 165, + popcount = 166, + s_mad24 = 167, + u_mad24 = 168, + s_mul24 = 169, + u_mul24 = 170, + u_abs = 201, + u_abs_diff = 202, + u_mul_hi = 203, + u_mad_hi = 204, + fclamp = 95, + degrees = 96, + fmax_common = 97, + fmin_common = 98, + mix = 99, + radians = 100, + step = 101, + smoothstep = 102, + sign = 103, + cross = 104, + distance = 105, + length = 106, + normalize = 107, + fast_distance = 108, + fast_length = 109, + fast_normalize = 110, + bitselect = 186, + select = 187, + vloadn = 171, + vstoren = 172, + vload_half = 173, + vload_halfn = 174, + vstore_half = 175, + vstore_half_r = 176, + vstore_halfn = 177, + vstore_halfn_r = 178, + vloada_halfn = 179, + vstorea_halfn = 180, + vstorea_halfn_r = 181, + shuffle = 182, + shuffle2 = 183, + printf = 184, + prefetch = 185, +}; + +auto to_string(OpenCLEntrypoint op) -> char const *; + +} // namespace tinytc::spv + +#endif // GENERATED_OPENCL_STD_2024115_HPP diff --git a/src/spv/pass/assemble.cpp b/src/spv/pass/assemble.cpp new file mode 100644 index 00000000..4a485513 --- /dev/null +++ b/src/spv/pass/assemble.cpp @@ -0,0 +1,50 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/assemble.hpp" +#include "spv/enums.hpp" +#include "spv/inst_assembler.hpp" +#include "spv/module.hpp" +#include "spv/visit.hpp" +#include "tinytc/tinytc.h" +#include "tinytc/types.h" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include + +namespace tinytc::spv { + +auto assembler::run_on_module(tinytc_spv_mod const &mod) -> binary { + auto data = std::vector{}; + auto stream = word_stream{data}; + + const std::int32_t bound = mod.bound(); + // Guess instruction stream by using 5 words per instruction that produces a result + // Not really important, but could be improved + data.reserve(5 * sizeof(std::int32_t) * bound); + + // Make header + const std::int32_t version = (mod.major_version() << 16) | (mod.minor_version() << 8); + const std::int32_t generator_number = 0; + stream << magic_number << version << generator_number << bound << std::int32_t{0}; + + // Assemble instructions + auto ia = inst_assembler{stream}; + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : mod.insts(enum_cast
(s))) { + visit(ia, i); + } + } + + // Create binary + tinytc_binary_t bin; + CHECK_STATUS(tinytc_binary_create(&bin, mod.context(), tinytc_bundle_format_spirv, data.size(), + data.data(), mod.core_features())); + return binary{bin}; +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/assemble.hpp b/src/spv/pass/assemble.hpp new file mode 100644 index 00000000..e4f6172b --- /dev/null +++ b/src/spv/pass/assemble.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ASSEMBLE_20241111_HPP +#define ASSEMBLE_20241111_HPP + +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" + +namespace tinytc::spv { + +class assembler { + public: + auto run_on_module(tinytc_spv_mod const &mod) -> binary; +}; + +} // namespace tinytc::spv + +#endif // ASSEMBLE_20241111_HPP diff --git a/src/spv/pass/assign_ids.cpp b/src/spv/pass/assign_ids.cpp new file mode 100644 index 00000000..6295faae --- /dev/null +++ b/src/spv/pass/assign_ids.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/assign_ids.hpp" +#include "spv/defs.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include + +namespace tinytc::spv { + +void id_assigner::declare(spv_inst *in) { + if (!slot_map_.contains(in)) { + const auto slot = slot_++; + slot_map_[in] = slot; + in->id(slot); + } +} + +void id_assigner::visit_result(spv_inst &in) { declare(&in); } + +void id_assigner::operator()(spv_inst *&in) { + if (!slot_map_.contains(in)) { + if (isa(*in) || isa(*in) || isa(*in) || + isa(*in) || isa(*in)) { + declare(in); + } else { + throw status::spirv_forbidden_forward_declaration; + } + } +} + +void id_assigner::operator()(OpPhi &in) { + pre_visit(in); + this->operator()(in.type()); + this->visit_result(in); + for (auto &op : in.op0()) { + // Forward references are allowed in phi instructions + declare(op.first); + this->operator()(op); + } + post_visit(in); +} + +void id_assigner::run_on_module(tinytc_spv_mod &m) { + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto &i : m.insts(enum_cast
(s))) { + visit(*this, i); + } + } +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/assign_ids.hpp b/src/spv/pass/assign_ids.hpp new file mode 100644 index 00000000..a6e16c8c --- /dev/null +++ b/src/spv/pass/assign_ids.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ASSIGN_IDS_20241111_HPP +#define ASSIGN_IDS_20241111_HPP + +#include "spv/defs.hpp" +#include "spv/visit.hpp" +#include "tinytc/types.h" + +#include +#include + +namespace tinytc::spv { + +class id_assigner : public default_visitor { + public: + using default_visitor::operator(); + + void visit_result(spv_inst &in); + + // Do nothing by default + template void operator()(T &) {} + + void operator()(spv_inst *&in); + void operator()(OpPhi &in); + + void run_on_module(tinytc_spv_mod &m); + + private: + void declare(spv_inst *in); + + std::uint32_t slot_ = 1; + std::unordered_map slot_map_; +}; + +} // namespace tinytc::spv + +#endif // ASSIGN_IDS_20241111_HPP diff --git a/src/spv/pass/capex.cpp b/src/spv/pass/capex.cpp new file mode 100644 index 00000000..aa1f0431 --- /dev/null +++ b/src/spv/pass/capex.cpp @@ -0,0 +1,259 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/capex.hpp" +#include "spv/capex_util.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "spv/visit.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include + +namespace tinytc::spv { + +template +concept inst_with_return_type = requires(T &t) { + { t.type() } -> std::same_as; +}; + +capex::capex(uniquifier &unique) : unique_{&unique} {} + +void capex::operator()(spv_inst const &) {} +void capex::operator()(OpAtomicStore const &in) { + auto ty = visit(overloaded{[](inst_with_return_type auto &a) -> spv_inst * { return a.type(); }, + [](auto &) -> spv_inst * { return nullptr; }}, + *in.op3()); + if (!ty) { + throw status::internal_compiler_error; + } + auto ity = dyn_cast(ty); + if (ity && ity->op0() == 64) { + unique_->capability(Capability::Int64Atomics); + required_features_[tinytc_spirv_feature_int64_atomics] = true; + } +} +auto capex::float_atomic_class(spv_inst *raw_ty, spv_inst *op0) + -> std::pair { + auto ty = dyn_cast(raw_ty); + if (!ty) { + throw status::internal_compiler_error; + } + + auto pointer_ty = visit(overloaded{[](inst_with_return_type auto &a) -> OpTypePointer * { + return dyn_cast(a.type()); + }, + [](auto &) -> OpTypePointer * { return nullptr; }}, + *op0); + if (!pointer_ty) { + throw status::internal_compiler_error; + } + return {ty->op0(), pointer_ty->op0()}; +} +void capex::operator()(OpAtomicFAddEXT const &in) { + const auto [bits, storage_cls] = float_atomic_class(in.type(), in.op0()); + switch (bits) { + case 16: + unique_->capability(Capability::AtomicFloat16AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float16_add"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float16_add_local + : tinytc_spirv_feature_atomic_float16_add_global] = true; + break; + case 32: + unique_->capability(Capability::AtomicFloat32AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float_add"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float32_add_local + : tinytc_spirv_feature_atomic_float32_add_global] = true; + break; + case 64: + unique_->capability(Capability::AtomicFloat64AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float_add"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float64_add_local + : tinytc_spirv_feature_atomic_float64_add_global] = true; + break; + default: + break; + } +} +void capex::check_float_min_max_atomic(spv_inst *ty, spv_inst *op0) { + const auto [bits, storage_cls] = float_atomic_class(ty, op0); + switch (bits) { + case 16: + unique_->capability(Capability::AtomicFloat16MinMaxEXT); + unique_->extension("SPV_EXT_shader_atomic_float16_min_max"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float16_min_max_local + : tinytc_spirv_feature_atomic_float16_min_max_global] = true; + break; + case 32: + unique_->capability(Capability::AtomicFloat32MinMaxEXT); + unique_->extension("SPV_EXT_shader_atomic_float_min_max"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float32_min_max_local + : tinytc_spirv_feature_atomic_float32_min_max_global] = true; + break; + case 64: + unique_->capability(Capability::AtomicFloat64MinMaxEXT); + unique_->extension("SPV_EXT_shader_atomic_float_min_max"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float64_min_max_local + : tinytc_spirv_feature_atomic_float64_min_max_global] = true; + break; + default: + break; + } +} +void capex::operator()(OpAtomicFMaxEXT const &in) { + check_float_min_max_atomic(in.type(), in.op0()); +} +void capex::operator()(OpAtomicFMinEXT const &in) { + check_float_min_max_atomic(in.type(), in.op0()); +} +void capex::check_int_atomic(spv_inst *raw_ty) { + auto ty = dyn_cast(raw_ty); + if (!ty) { + throw status::internal_compiler_error; + } + if (ty && ty->op0() == 64) { + unique_->capability(Capability::Int64Atomics); + required_features_[tinytc_spirv_feature_int64_atomics] = true; + } +} +void capex::operator()(OpAtomicIAdd const &in) { check_int_atomic(in.type()); } +void capex::operator()(OpAtomicSMax const &in) { check_int_atomic(in.type()); } +void capex::operator()(OpAtomicSMin const &in) { check_int_atomic(in.type()); } +void capex::operator()(OpAsmTargetINTEL const &) { + unique_->capability(Capability::AsmINTEL); + unique_->extension("SPV_INTEL_inline_assembly"); +} +void capex::operator()(OpAsmINTEL const &) { + unique_->capability(Capability::AsmINTEL); + unique_->extension("SPV_INTEL_inline_assembly"); +} +void capex::operator()(OpAsmCallINTEL const &) { + unique_->capability(Capability::AsmINTEL); + unique_->extension("SPV_INTEL_inline_assembly"); +} +void capex::operator()(OpConvertBF16ToFINTEL const &) { + unique_->capability(Capability::BFloat16ConversionINTEL); + unique_->extension("SPV_INTEL_bfloat16_conversion"); + required_features_[tinytc_spirv_feature_bfloat16_conversion] = true; +} +void capex::operator()(OpConvertFToBF16INTEL const &) { + unique_->capability(Capability::BFloat16ConversionINTEL); + unique_->extension("SPV_INTEL_bfloat16_conversion"); + required_features_[tinytc_spirv_feature_bfloat16_conversion] = true; +} +void capex::operator()(OpCooperativeMatrixLoadKHR const &) { + unique_->capability(Capability::CooperativeMatrixKHR); + unique_->extension("SPV_KHR_cooperative_matrix"); +} +void capex::operator()(OpCooperativeMatrixMulAddKHR const &) { + unique_->capability(Capability::CooperativeMatrixKHR); + unique_->extension("SPV_KHR_cooperative_matrix"); +} +void capex::operator()(OpCooperativeMatrixStoreKHR const &) { + unique_->capability(Capability::CooperativeMatrixKHR); + unique_->extension("SPV_KHR_cooperative_matrix"); +} +void capex::operator()(OpEntryPoint const &in) { + for (auto const &cap : capabilities(in.op0())) { + unique_->capability(cap); + } +} +void capex::operator()(OpExecutionMode const &in) { + for (auto const &cap : capabilities(in.op1())) { + unique_->capability(cap); + if (cap == Capability::SubgroupDispatch) { + required_features_[tinytc_spirv_feature_subgroup_dispatch] = true; + } + } +} +void capex::operator()(OpGroupBroadcast const &) { + unique_->capability(Capability::Groups); + required_features_[tinytc_spirv_feature_groups] = true; +} +void capex::operator()(OpGroupFAdd const &) { + unique_->capability(Capability::Groups); + required_features_[tinytc_spirv_feature_groups] = true; +} +void capex::operator()(OpGroupIAdd const &) { + unique_->capability(Capability::Groups); + required_features_[tinytc_spirv_feature_groups] = true; +} +void capex::operator()(OpInBoundsPtrAccessChain const &) { + unique_->capability(Capability::Addresses); +} +void capex::operator()(OpMemoryModel const &in) { + for (auto const &cap : capabilities(in.op0())) { + unique_->capability(cap); + } + for (auto const &cap : capabilities(in.op1())) { + unique_->capability(cap); + } +} +void capex::operator()(OpSubgroupBlockReadINTEL const &) { + unique_->capability(Capability::SubgroupBufferBlockIOINTEL); + unique_->extension("SPV_INTEL_subgroups"); + required_features_[tinytc_spirv_feature_subgroup_buffer_block_io] = true; +} +void capex::operator()(OpSubgroupBlockWriteINTEL const &) { + unique_->capability(Capability::SubgroupBufferBlockIOINTEL); + unique_->extension("SPV_INTEL_subgroups"); + required_features_[tinytc_spirv_feature_subgroup_buffer_block_io] = true; +} +void capex::operator()(OpTypeFloat const &in) { + switch (in.op0()) { + case 16: + unique_->capability(Capability::Float16); + required_features_[tinytc_spirv_feature_float16] = true; + break; + case 64: + unique_->capability(Capability::Float64); + required_features_[tinytc_spirv_feature_float64] = true; + break; + default: + break; + } +} +void capex::operator()(OpTypeInt const &in) { + switch (in.op0()) { + case 8: + unique_->capability(Capability::Int8); + break; + case 16: + unique_->capability(Capability::Int16); + break; + case 64: + unique_->capability(Capability::Int64); + break; + default: + break; + } +} +void capex::operator()(OpTypeVector const &in) { + if (in.op1() > 4) { + unique_->capability(Capability::Vector16); + } +} + +void capex::run_on_module(tinytc_spv_mod const &mod) { + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : mod.insts(enum_cast
(s))) { + visit(*this, i); + } + } +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/capex.hpp b/src/spv/pass/capex.hpp new file mode 100644 index 00000000..6a2723f8 --- /dev/null +++ b/src/spv/pass/capex.hpp @@ -0,0 +1,72 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CAPEX_20241113_HPP +#define CAPEX_20241113_HPP + +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "tinytc/types.h" + +#include +#include + +namespace tinytc { +enum class spirv_feature; +} + +namespace tinytc::spv { + +class uniquifier; + +class capex { + public: + capex(uniquifier &unique); + + void operator()(spv_inst const &in); + void operator()(OpAtomicStore const &in); + void operator()(OpAtomicFAddEXT const &in); + void operator()(OpAtomicFMaxEXT const &in); + void operator()(OpAtomicFMinEXT const &in); + void operator()(OpAtomicIAdd const &in); + void operator()(OpAtomicSMax const &in); + void operator()(OpAtomicSMin const &in); + void operator()(OpAsmTargetINTEL const &in); + void operator()(OpAsmINTEL const &in); + void operator()(OpAsmCallINTEL const &in); + void operator()(OpConvertBF16ToFINTEL const &in); + void operator()(OpConvertFToBF16INTEL const &in); + void operator()(OpCooperativeMatrixLoadKHR const &in); + void operator()(OpCooperativeMatrixMulAddKHR const &in); + void operator()(OpCooperativeMatrixStoreKHR const &in); + void operator()(OpEntryPoint const &in); + void operator()(OpExecutionMode const &in); + void operator()(OpGroupBroadcast const &in); + void operator()(OpGroupFAdd const &in); + void operator()(OpGroupIAdd const &in); + void operator()(OpInBoundsPtrAccessChain const &in); + void operator()(OpMemoryModel const &in); + void operator()(OpSubgroupBlockReadINTEL const &in); + void operator()(OpSubgroupBlockWriteINTEL const &in); + void operator()(OpTypeFloat const &in); + void operator()(OpTypeInt const &in); + void operator()(OpTypeVector const &in); + + void run_on_module(tinytc_spv_mod const &mod); + + inline auto requires_feature(spirv_feature f) const -> bool { + return required_features_[static_cast(f)]; + } + + private: + auto float_atomic_class(spv_inst *ty, spv_inst *op0) -> std::pair; + void check_float_min_max_atomic(spv_inst *ty, spv_inst *op0); + void check_int_atomic(spv_inst *ty); + + uniquifier *unique_; + std::array required_features_ = {}; +}; + +} // namespace tinytc::spv + +#endif // CAPEX_20241113_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp new file mode 100644 index 00000000..693b0739 --- /dev/null +++ b/src/spv/pass/dump_asm.cpp @@ -0,0 +1,150 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/dump_asm.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/opencl.std.hpp" +#include "tinytc/tinytc.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +enum class LinkageType; + +dump_asm_pass::dump_asm_pass(std::ostream &os) : os_(&os) {} + +void dump_asm_pass::pre_visit(spv_inst const &in) { + auto const num_digits = [](std::int64_t number) { + std::int64_t d = 1; + while (number /= 10) { + ++d; + } + return d; + }; + *os_ << std::endl; + if (in.has_result_id()) { + const auto id = in.id(); + + for (int i = 0; i < rhs_indent - 4 - num_digits(id); ++i) { + *os_ << ' '; + } + *os_ << "%" << id << " = "; + } else { + for (int i = 0; i < rhs_indent; ++i) { + *os_ << ' '; + } + } + *os_ << "Op" << to_string(in.opcode()); +} + +void dump_asm_pass::operator()(DecorationAttr const &da) { + std::visit(overloaded{[&](auto const &a) { this->operator()(a); }, + [&](std::pair const &a) { + *os_ << " \"" << a.first << '"'; + this->operator()(a.second); + }}, + da); +} +void dump_asm_pass::operator()(ExecutionModeAttr const &ea) { + std::visit( + overloaded{[&](std::int32_t const &a) { *os_ << " " << static_cast(a); }, + [&](std::array const &a) { + for (auto const &s : a) { + *os_ << " " << static_cast(s); + } + }}, + ea); +} +void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { + std::visit(overloaded{[&](std::int8_t const &l) { + *os_ << " " + << static_cast(static_cast(l)); + }, + [&](std::signed_integral auto const &l) { + using unsigned_t = std::make_unsigned_t>; + *os_ << " " << static_cast(l); + }, + [&](std::floating_point auto const &l) { + auto flags = os_->flags(); + *os_ << " " << std::hexfloat << l; + os_->flags(flags); + }, + [&](half const &l) { + auto flags = os_->flags(); + *os_ << " " << std::hexfloat << l; + os_->flags(flags); + }}, + l); +} +void dump_asm_pass::operator()(LiteralInteger const &l) { + *os_ << " " << static_cast>(l); +} +void dump_asm_pass::operator()(LiteralString const &l) { *os_ << " \"" << l << '"'; } + +void dump_asm_pass::operator()(PairIdRefIdRef const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void dump_asm_pass::operator()(PairIdRefLiteralInteger const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void dump_asm_pass::operator()(PairLiteralIntegerIdRef const &p) { + std::visit(overloaded{[&](auto const &l) { + using unsigned_t = std::make_unsigned_t>; + *os_ << " " << static_cast(l); + }}, + p.first); + this->operator()(p.second); +} + +void dump_asm_pass::operator()(spv_inst *const &in) { *os_ << " %" << in->id(); } +void dump_asm_pass::operator()(OpExtInst const &in) { + pre_visit(in); + this->operator()(in.type()); + visit_result(in); + this->operator()(in.op0()); + + if (auto extimport = dyn_cast(in.op0()); + extimport && extimport->op0() == OpenCLExt) { + this->operator()(static_cast(in.op1())); + } else { + this->operator()(in.op1()); + } + + for (auto const &op : in.op2()) { + this->operator()(op); + } + post_visit(in); +} + +void dump_asm_pass::run_on_module(tinytc_spv_mod const &m) { + auto const visit_section = [&](section s) { + for (auto const &i : m.insts(s)) { + visit(*this, i); + } + }; + *os_ << "; SPIR-V" << std::endl + << "; Version " << m.major_version() << '.' << m.minor_version() << std::endl + << "; Generator: Tiny Tensor Compiler" << std::endl + << "; Bound: " << m.bound() << std::endl + << "; Schema: 0"; + for (std::int32_t s = 0; s < num_module_sections; ++s) { + visit_section(enum_cast
(s)); + } + *os_ << std::endl; +} + +} // namespace tinytc::spv diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp new file mode 100644 index 00000000..a1bbad29 --- /dev/null +++ b/src/spv/pass/dump_asm.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_ASM_20241029_HPP +#define DUMP_ASM_20241029_HPP + +#include "spv/defs.hpp" +#include "spv/names.hpp" +#include "spv/visit.hpp" +#include "tinytc/types.h" + +#include + +namespace tinytc::spv { + +class dump_asm_pass : public default_visitor { + public: + using default_visitor::operator(); + constexpr static int rhs_indent = 15; + + dump_asm_pass(std::ostream &os); + + void pre_visit(spv_inst const &in); + + template + requires requires(T const &e) { to_string(e); } + void operator()(T const &e) { + *os_ << " " << to_string(e); + } + void operator()(DecorationAttr const &da); + void operator()(ExecutionModeAttr const &ea); + void operator()(LiteralContextDependentNumber const &l); + void operator()(LiteralInteger const &l); + void operator()(LiteralString const &l); + + void operator()(PairIdRefIdRef const &p); + void operator()(PairIdRefLiteralInteger const &p); + void operator()(PairLiteralIntegerIdRef const &p); + + void operator()(spv_inst *const &in); + void operator()(OpExtInst const &in); + + void run_on_module(tinytc_spv_mod const &m); + + private: + std::ostream *os_; +}; + +} // namespace tinytc::spv + +#endif // DUMP_ASM_20241029_HPP diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp new file mode 100644 index 00000000..48e8a8da --- /dev/null +++ b/src/spv/uniquifier.cpp @@ -0,0 +1,281 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/uniquifier.hpp" +#include "node/data_type_node.hpp" +#include "scalar_type.hpp" +#include "spv/instructions.hpp" +#include "spv/lut.hpp" +#include "spv/module.hpp" +#include "spv/opencl.std.hpp" +#include "support/fnv1a_array_view.hpp" +#include "tinytc/types.hpp" +#include "util/visit.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto address_space_to_storage_class(address_space as) -> StorageClass { + return as == address_space::local ? StorageClass::Workgroup : StorageClass::CrossWorkgroup; +} + +uniquifier::uniquifier(tinytc_spv_mod &m) : mod_(&m) { + std::fill(std::begin(scalar_tys_), std::end(scalar_tys_), nullptr); +} + +auto uniquifier::asm_target() -> spv_inst * { + if (!asm_target_) { + asm_target_ = + mod_->add_to(section::type_const_var, "spirv64-unknown-unknown"); + } + return asm_target_; +} + +auto uniquifier::bool_constant(bool b) -> spv_inst * { + if (b) { + return lookup(bool_true_, [&] { + return mod_->add_to(section::type_const_var, bool_ty()); + }); + } + return lookup(bool_false_, [&] { + return mod_->add_to(section::type_const_var, bool_ty()); + }); +} + +auto uniquifier::builtin_alignment(BuiltIn b) -> std::int32_t { + switch (b) { + case BuiltIn::WorkDim: + case BuiltIn::SubgroupSize: + case BuiltIn::SubgroupMaxSize: + case BuiltIn::NumSubgroups: + case BuiltIn::NumEnqueuedSubgroups: + case BuiltIn::SubgroupId: + case BuiltIn::SubgroupLocalInvocationId: + return alignment(scalar_type::i32); + case BuiltIn::GlobalLinearId: + case BuiltIn::LocalInvocationIndex: + return alignment(scalar_type::index); + case BuiltIn::GlobalSize: + case BuiltIn::GlobalInvocationId: + case BuiltIn::WorkgroupSize: + case BuiltIn::EnqueuedWorkgroupSize: + case BuiltIn::LocalInvocationId: + case BuiltIn::NumWorkgroups: + case BuiltIn::WorkgroupId: + case BuiltIn::GlobalOffset: + return alignment(scalar_type::index, vector_size::v3); + break; + default: + throw status::internal_compiler_error; + } +} + +auto uniquifier::builtin_pointee_ty(BuiltIn b) -> spv_inst * { + switch (b) { + case BuiltIn::WorkDim: + case BuiltIn::SubgroupSize: + case BuiltIn::SubgroupMaxSize: + case BuiltIn::NumSubgroups: + case BuiltIn::NumEnqueuedSubgroups: + case BuiltIn::SubgroupId: + case BuiltIn::SubgroupLocalInvocationId: + return scalar_ty(scalar_type::i32); + case BuiltIn::GlobalLinearId: + case BuiltIn::LocalInvocationIndex: + return scalar_ty(scalar_type::index); + case BuiltIn::GlobalSize: + case BuiltIn::GlobalInvocationId: + case BuiltIn::WorkgroupSize: + case BuiltIn::EnqueuedWorkgroupSize: + case BuiltIn::LocalInvocationId: + case BuiltIn::NumWorkgroups: + case BuiltIn::WorkgroupId: + case BuiltIn::GlobalOffset: + return index3_ty(); + break; + default: + throw status::internal_compiler_error; + } +} + +auto uniquifier::builtin_var(BuiltIn b) -> spv_inst * { + return lookup(builtin_, b, [&](BuiltIn b) { + auto ty = pointer_ty(StorageClass::Input, builtin_pointee_ty(b), builtin_alignment(b)); + auto var = mod_->add_to(section::type_const_var, ty, StorageClass::Input, + std::nullopt); + mod_->add_to(section::decoration, var, Decoration::Constant); + mod_->add_to(section::decoration, var, Decoration::BuiltIn, b); + return var; + }); +} + +void uniquifier::capability(Capability cap) { + if (!capabilities_.contains(cap)) { + mod_->add_to(section::capability, cap); + capabilities_.insert(cap); + } +} + +auto uniquifier::constant(LiteralContextDependentNumber cst) -> spv_inst * { + return lookup(cst_map_, cst, [&](LiteralContextDependentNumber const &cst) { + scalar_type sty = std::visit( + overloaded{[](auto const &c) { return to_scalar_type_v>; }}, + cst); + return mod_->add_to(section::type_const_var, scalar_ty(sty), cst); + }); +} + +void uniquifier::extension(char const *ext_name) { + if (!extensions_.contains(ext_name)) { + mod_->add_to(section::extension, ext_name); + extensions_.insert(ext_name); + } +} + +auto uniquifier::null_constant(spv_inst *spv_ty) -> spv_inst * { + return lookup(null_cst_, spv_ty, [&](spv_inst *spv_ty) { + return mod_->add_to(section::type_const_var, spv_ty); + }); +} + +auto uniquifier::opencl_ext() -> spv_inst * { + return lookup(opencl_ext_, + [&] { return mod_->add_to(section::ext_inst, OpenCLExt); }); +} + +auto uniquifier::array_ty(spv_inst *element_ty, std::int32_t length) -> spv_inst * { + auto key = std::make_pair(element_ty, length); + return lookup(array_tys_, key, [&](std::pair const &key) { + return mod_->add_to(section::type_const_var, key.first, constant(key.second)); + }); +} + +auto uniquifier::bool_ty() -> spv_inst * { + if (!bool_ty_) { + bool_ty_ = mod_->add_to(section::type_const_var); + } + return bool_ty_; +} + +auto uniquifier::bool2_ty() -> spv_inst * { return vec_ty(bool_ty(), vector_size::v2); } + +auto uniquifier::function_ty(spv_inst *return_ty, array_view params) -> spv_inst * { + const auto map_key = fnv1a_step(fnv1a_step(fnv1a0(), return_ty), params); + auto range = function_tys_.equal_range(map_key); + for (auto it = range.first; it != range.second; ++it) { + if (return_ty == it->second->op0() && + std::equal(params.begin(), params.end(), it->second->op1().begin(), + it->second->op1().end())) { + return it->second; + } + } + return function_tys_ + .emplace(map_key, mod_->add_to(section::type_const_var, return_ty, + std::move(params))) + ->second; +} + +auto uniquifier::index3_ty() -> spv_inst * { + return vec_ty(scalar_ty(scalar_type::index), vector_size::v3); +} + +auto uniquifier::pointer_ty(StorageClass cls, spv_inst *pointee_ty, std::int32_t alignment) + -> spv_inst * { + auto key = std::make_tuple(cls, pointee_ty, alignment); + return lookup( + pointer_tys_, key, [&](std::tuple const &key) { + auto pointer_ty = mod_->add_to(section::type_const_var, std::get<0>(key), + std::get<1>(key)); + if (std::get<2>(key) > 0) { + mod_->add_to(section::decoration, pointer_ty, Decoration::Alignment, + DecorationAttr{std::get<2>(key)}); + } + return pointer_ty; + }); +} + +auto uniquifier::pointer_ty(memref_data_type const *mt) -> spv_inst * { + const auto storage_cls = address_space_to_storage_class(mt->addrspace()); + auto ty = scalar_ty(mt->element_ty()); + const auto align = mt->element_alignment(); + return pointer_ty(storage_cls, ty, align); +} + +auto uniquifier::scalar_ty(scalar_type sty) -> spv_inst * { + auto const make_ty = [this](scalar_type sty) -> spv_inst * { + switch (sty) { + case scalar_type::i8: + return mod_->add_to(section::type_const_var, 8, 0); + case scalar_type::i16: + return mod_->add_to(section::type_const_var, 16, 0); + case scalar_type::i32: + return mod_->add_to(section::type_const_var, 32, 0); + case scalar_type::i64: + return mod_->add_to(section::type_const_var, 64, 0); + case scalar_type::index: { + const auto sz = size(scalar_type::index); + if (sz == 8) { + return scalar_ty(scalar_type::i64); + } + return scalar_ty(scalar_type::i32); + } + case scalar_type::bf16: + return scalar_ty(scalar_type::i16); + case scalar_type::f16: + case scalar_type::f32: + case scalar_type::f64: + return mod_->add_to(section::type_const_var, size(sty) * 8); + case scalar_type::c32: { + auto f32_ty = scalar_ty(scalar_type::f32); + return vec_ty(f32_ty, vector_size::v2); + } + case scalar_type::c64: { + auto f64_ty = scalar_ty(scalar_type::f64); + return vec_ty(f64_ty, vector_size::v2); + } + } + throw status::internal_compiler_error; + }; + + const auto index = static_cast(sty); + if (index < 0 || index >= scalar_tys_.size()) { + throw status::internal_compiler_error; + } + if (!scalar_tys_[index]) { + scalar_tys_[index] = make_ty(sty); + } + return scalar_tys_[index]; +} + +auto uniquifier::vec_ty(spv_inst *component_ty, std::int32_t length) -> spv_inst * { + auto key = std::make_pair(component_ty, length); + return lookup(vec_tys_, key, [&](std::pair const &key) { + return mod_->add_to(section::type_const_var, key.first, key.second); + }); +} +auto uniquifier::vec_ty(spv_inst *component_ty, vector_size length) -> spv_inst * { + return vec_ty(component_ty, static_cast(length)); +} + +auto uniquifier::void_ty() -> spv_inst * { + if (!void_ty_) { + void_ty_ = mod_->add_to(section::type_const_var); + } + return void_ty_; +} + +auto uniquifier::load_builtin(BuiltIn b) -> spv_inst * { + auto builtin = builtin_var(b); + return mod_->add(builtin_pointee_ty(b), builtin, MemoryAccess::Aligned, + builtin_alignment(b)); +} + +} // namespace tinytc::spv + diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp new file mode 100644 index 00000000..f7262ce9 --- /dev/null +++ b/src/spv/uniquifier.hpp @@ -0,0 +1,103 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef UNIQUIFIER_20241107_HPP +#define UNIQUIFIER_20241107_HPP + +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { +enum class address_space; +enum class scalar_type; +enum class vector_size; +class memref_data_type; +} // namespace tinytc + +namespace tinytc::spv { + +auto address_space_to_storage_class(address_space as) -> StorageClass; + +class uniquifier { + public: + uniquifier(tinytc_spv_mod &m); + + inline auto mod() -> tinytc_spv_mod & { return *mod_; } + + auto asm_target() -> spv_inst *; + auto bool_constant(bool b) -> spv_inst *; + auto builtin_alignment(BuiltIn b) -> std::int32_t; + auto builtin_pointee_ty(BuiltIn b) -> spv_inst *; + auto builtin_var(BuiltIn b) -> spv_inst *; + void capability(Capability cap); + auto constant(LiteralContextDependentNumber cst) -> spv_inst *; + void extension(char const *ext_name); + auto null_constant(spv_inst *spv_ty) -> spv_inst *; + auto opencl_ext() -> spv_inst *; + + // types + auto array_ty(spv_inst *element_ty, std::int32_t length) -> spv_inst *; + auto bool_ty() -> spv_inst *; + auto bool2_ty() -> spv_inst *; + auto function_ty(spv_inst *return_ty, array_view params) -> spv_inst *; + auto index3_ty() -> spv_inst *; + auto pointer_ty(StorageClass cls, spv_inst *pointee_ty, std::int32_t alignment) -> spv_inst *; + auto pointer_ty(memref_data_type const *mt) -> spv_inst *; + auto scalar_ty(scalar_type sty) -> spv_inst *; + auto vec_ty(spv_inst *component_ty, std::int32_t length) -> spv_inst *; + auto vec_ty(spv_inst *component_ty, vector_size length) -> spv_inst *; + auto void_ty() -> spv_inst *; + + // util + auto load_builtin(BuiltIn b) -> spv_inst *; + + private: + struct array_key_hash { + inline auto operator()(std::pair const &key) const + -> std::size_t { + return fnv1a_combine(key.first, key.second); + } + }; + struct pointer_key_hash { + inline auto operator()(std::tuple const &key) const + -> std::size_t { + return fnv1a_combine(std::get<0>(key), std::get<1>(key), std::get<2>(key)); + } + }; + + tinytc_spv_mod_t mod_; + spv_inst *asm_target_ = nullptr; + spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; + spv_inst *opencl_ext_ = nullptr; + std::unordered_map builtin_; + std::unordered_set capabilities_; + std::unordered_map cst_map_; + std::unordered_set extensions_; + std::unordered_map null_cst_; + + std::unordered_map, spv_inst *, array_key_hash> array_tys_; + spv_inst *bool_ty_ = nullptr; + std::unordered_multimap function_tys_; + std::unordered_map, spv_inst *, + pointer_key_hash> + pointer_tys_; + std::array scalar_tys_; + std::unordered_map, spv_inst *, array_key_hash> vec_tys_; + spv_inst *void_ty_ = nullptr; +}; + +} // namespace tinytc::spv + +#endif // UNIQUIFIER_20241107_HPP diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp new file mode 100644 index 00000000..6a33720f --- /dev/null +++ b/src/spv/visit.hpp @@ -0,0 +1,4582 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_VISIT_20250605_HPP +#define GENERATED_VISIT_20250605_HPP + +#include "defs.hpp" +#include "enums.hpp" +#include "instructions.hpp" + +namespace tinytc::spv { + +template struct overloaded : Ts... { + using Ts::operator()...; +}; +template overloaded(Ts...) -> overloaded; +template auto visit(Visitor &&visitor, spv_inst &inst) { + switch (inst.opcode()) { + case Op::Nop: + return visitor(static_cast(inst)); + case Op::Undef: + return visitor(static_cast(inst)); + case Op::SourceContinued: + return visitor(static_cast(inst)); + case Op::Source: + return visitor(static_cast(inst)); + case Op::SourceExtension: + return visitor(static_cast(inst)); + case Op::Name: + return visitor(static_cast(inst)); + case Op::MemberName: + return visitor(static_cast(inst)); + case Op::String: + return visitor(static_cast(inst)); + case Op::Line: + return visitor(static_cast(inst)); + case Op::Extension: + return visitor(static_cast(inst)); + case Op::ExtInstImport: + return visitor(static_cast(inst)); + case Op::ExtInst: + return visitor(static_cast(inst)); + case Op::MemoryModel: + return visitor(static_cast(inst)); + case Op::EntryPoint: + return visitor(static_cast(inst)); + case Op::ExecutionMode: + return visitor(static_cast(inst)); + case Op::Capability: + return visitor(static_cast(inst)); + case Op::TypeVoid: + return visitor(static_cast(inst)); + case Op::TypeBool: + return visitor(static_cast(inst)); + case Op::TypeInt: + return visitor(static_cast(inst)); + case Op::TypeFloat: + return visitor(static_cast(inst)); + case Op::TypeVector: + return visitor(static_cast(inst)); + case Op::TypeMatrix: + return visitor(static_cast(inst)); + case Op::TypeImage: + return visitor(static_cast(inst)); + case Op::TypeSampler: + return visitor(static_cast(inst)); + case Op::TypeSampledImage: + return visitor(static_cast(inst)); + case Op::TypeArray: + return visitor(static_cast(inst)); + case Op::TypeRuntimeArray: + return visitor(static_cast(inst)); + case Op::TypeStruct: + return visitor(static_cast(inst)); + case Op::TypeOpaque: + return visitor(static_cast(inst)); + case Op::TypePointer: + return visitor(static_cast(inst)); + case Op::TypeFunction: + return visitor(static_cast(inst)); + case Op::TypeEvent: + return visitor(static_cast(inst)); + case Op::TypeDeviceEvent: + return visitor(static_cast(inst)); + case Op::TypeReserveId: + return visitor(static_cast(inst)); + case Op::TypeQueue: + return visitor(static_cast(inst)); + case Op::TypePipe: + return visitor(static_cast(inst)); + case Op::TypeForwardPointer: + return visitor(static_cast(inst)); + case Op::ConstantTrue: + return visitor(static_cast(inst)); + case Op::ConstantFalse: + return visitor(static_cast(inst)); + case Op::Constant: + return visitor(static_cast(inst)); + case Op::ConstantComposite: + return visitor(static_cast(inst)); + case Op::ConstantSampler: + return visitor(static_cast(inst)); + case Op::ConstantNull: + return visitor(static_cast(inst)); + case Op::Function: + return visitor(static_cast(inst)); + case Op::FunctionParameter: + return visitor(static_cast(inst)); + case Op::FunctionEnd: + return visitor(static_cast(inst)); + case Op::FunctionCall: + return visitor(static_cast(inst)); + case Op::Variable: + return visitor(static_cast(inst)); + case Op::ImageTexelPointer: + return visitor(static_cast(inst)); + case Op::Load: + return visitor(static_cast(inst)); + case Op::Store: + return visitor(static_cast(inst)); + case Op::CopyMemory: + return visitor(static_cast(inst)); + case Op::CopyMemorySized: + return visitor(static_cast(inst)); + case Op::AccessChain: + return visitor(static_cast(inst)); + case Op::InBoundsAccessChain: + return visitor(static_cast(inst)); + case Op::PtrAccessChain: + return visitor(static_cast(inst)); + case Op::ArrayLength: + return visitor(static_cast(inst)); + case Op::GenericPtrMemSemantics: + return visitor(static_cast(inst)); + case Op::InBoundsPtrAccessChain: + return visitor(static_cast(inst)); + case Op::Decorate: + return visitor(static_cast(inst)); + case Op::MemberDecorate: + return visitor(static_cast(inst)); + case Op::DecorationGroup: + return visitor(static_cast(inst)); + case Op::GroupDecorate: + return visitor(static_cast(inst)); + case Op::GroupMemberDecorate: + return visitor(static_cast(inst)); + case Op::VectorExtractDynamic: + return visitor(static_cast(inst)); + case Op::VectorInsertDynamic: + return visitor(static_cast(inst)); + case Op::VectorShuffle: + return visitor(static_cast(inst)); + case Op::CompositeConstruct: + return visitor(static_cast(inst)); + case Op::CompositeExtract: + return visitor(static_cast(inst)); + case Op::CompositeInsert: + return visitor(static_cast(inst)); + case Op::CopyObject: + return visitor(static_cast(inst)); + case Op::Transpose: + return visitor(static_cast(inst)); + case Op::SampledImage: + return visitor(static_cast(inst)); + case Op::ImageSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageFetch: + return visitor(static_cast(inst)); + case Op::ImageGather: + return visitor(static_cast(inst)); + case Op::ImageDrefGather: + return visitor(static_cast(inst)); + case Op::ImageRead: + return visitor(static_cast(inst)); + case Op::ImageWrite: + return visitor(static_cast(inst)); + case Op::Image: + return visitor(static_cast(inst)); + case Op::ImageQueryFormat: + return visitor(static_cast(inst)); + case Op::ImageQueryOrder: + return visitor(static_cast(inst)); + case Op::ImageQuerySizeLod: + return visitor(static_cast(inst)); + case Op::ImageQuerySize: + return visitor(static_cast(inst)); + case Op::ImageQueryLod: + return visitor(static_cast(inst)); + case Op::ImageQueryLevels: + return visitor(static_cast(inst)); + case Op::ImageQuerySamples: + return visitor(static_cast(inst)); + case Op::ConvertFToU: + return visitor(static_cast(inst)); + case Op::ConvertFToS: + return visitor(static_cast(inst)); + case Op::ConvertSToF: + return visitor(static_cast(inst)); + case Op::ConvertUToF: + return visitor(static_cast(inst)); + case Op::UConvert: + return visitor(static_cast(inst)); + case Op::SConvert: + return visitor(static_cast(inst)); + case Op::FConvert: + return visitor(static_cast(inst)); + case Op::QuantizeToF16: + return visitor(static_cast(inst)); + case Op::ConvertPtrToU: + return visitor(static_cast(inst)); + case Op::SatConvertSToU: + return visitor(static_cast(inst)); + case Op::SatConvertUToS: + return visitor(static_cast(inst)); + case Op::ConvertUToPtr: + return visitor(static_cast(inst)); + case Op::PtrCastToGeneric: + return visitor(static_cast(inst)); + case Op::GenericCastToPtr: + return visitor(static_cast(inst)); + case Op::GenericCastToPtrExplicit: + return visitor(static_cast(inst)); + case Op::Bitcast: + return visitor(static_cast(inst)); + case Op::SNegate: + return visitor(static_cast(inst)); + case Op::FNegate: + return visitor(static_cast(inst)); + case Op::IAdd: + return visitor(static_cast(inst)); + case Op::FAdd: + return visitor(static_cast(inst)); + case Op::ISub: + return visitor(static_cast(inst)); + case Op::FSub: + return visitor(static_cast(inst)); + case Op::IMul: + return visitor(static_cast(inst)); + case Op::FMul: + return visitor(static_cast(inst)); + case Op::UDiv: + return visitor(static_cast(inst)); + case Op::SDiv: + return visitor(static_cast(inst)); + case Op::FDiv: + return visitor(static_cast(inst)); + case Op::UMod: + return visitor(static_cast(inst)); + case Op::SRem: + return visitor(static_cast(inst)); + case Op::SMod: + return visitor(static_cast(inst)); + case Op::FRem: + return visitor(static_cast(inst)); + case Op::FMod: + return visitor(static_cast(inst)); + case Op::VectorTimesScalar: + return visitor(static_cast(inst)); + case Op::MatrixTimesScalar: + return visitor(static_cast(inst)); + case Op::VectorTimesMatrix: + return visitor(static_cast(inst)); + case Op::MatrixTimesVector: + return visitor(static_cast(inst)); + case Op::MatrixTimesMatrix: + return visitor(static_cast(inst)); + case Op::OuterProduct: + return visitor(static_cast(inst)); + case Op::Dot: + return visitor(static_cast(inst)); + case Op::IAddCarry: + return visitor(static_cast(inst)); + case Op::ISubBorrow: + return visitor(static_cast(inst)); + case Op::UMulExtended: + return visitor(static_cast(inst)); + case Op::SMulExtended: + return visitor(static_cast(inst)); + case Op::Any: + return visitor(static_cast(inst)); + case Op::All: + return visitor(static_cast(inst)); + case Op::IsNan: + return visitor(static_cast(inst)); + case Op::IsInf: + return visitor(static_cast(inst)); + case Op::IsFinite: + return visitor(static_cast(inst)); + case Op::IsNormal: + return visitor(static_cast(inst)); + case Op::SignBitSet: + return visitor(static_cast(inst)); + case Op::LessOrGreater: + return visitor(static_cast(inst)); + case Op::Ordered: + return visitor(static_cast(inst)); + case Op::Unordered: + return visitor(static_cast(inst)); + case Op::LogicalEqual: + return visitor(static_cast(inst)); + case Op::LogicalNotEqual: + return visitor(static_cast(inst)); + case Op::LogicalOr: + return visitor(static_cast(inst)); + case Op::LogicalAnd: + return visitor(static_cast(inst)); + case Op::LogicalNot: + return visitor(static_cast(inst)); + case Op::Select: + return visitor(static_cast(inst)); + case Op::IEqual: + return visitor(static_cast(inst)); + case Op::INotEqual: + return visitor(static_cast(inst)); + case Op::UGreaterThan: + return visitor(static_cast(inst)); + case Op::SGreaterThan: + return visitor(static_cast(inst)); + case Op::UGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::SGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ULessThan: + return visitor(static_cast(inst)); + case Op::SLessThan: + return visitor(static_cast(inst)); + case Op::ULessThanEqual: + return visitor(static_cast(inst)); + case Op::SLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdEqual: + return visitor(static_cast(inst)); + case Op::FUnordEqual: + return visitor(static_cast(inst)); + case Op::FOrdNotEqual: + return visitor(static_cast(inst)); + case Op::FUnordNotEqual: + return visitor(static_cast(inst)); + case Op::FOrdLessThan: + return visitor(static_cast(inst)); + case Op::FUnordLessThan: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThan: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThan: + return visitor(static_cast(inst)); + case Op::FOrdLessThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ShiftRightLogical: + return visitor(static_cast(inst)); + case Op::ShiftRightArithmetic: + return visitor(static_cast(inst)); + case Op::ShiftLeftLogical: + return visitor(static_cast(inst)); + case Op::BitwiseOr: + return visitor(static_cast(inst)); + case Op::BitwiseXor: + return visitor(static_cast(inst)); + case Op::BitwiseAnd: + return visitor(static_cast(inst)); + case Op::Not: + return visitor(static_cast(inst)); + case Op::BitFieldInsert: + return visitor(static_cast(inst)); + case Op::BitFieldSExtract: + return visitor(static_cast(inst)); + case Op::BitFieldUExtract: + return visitor(static_cast(inst)); + case Op::BitReverse: + return visitor(static_cast(inst)); + case Op::BitCount: + return visitor(static_cast(inst)); + case Op::DPdx: + return visitor(static_cast(inst)); + case Op::DPdy: + return visitor(static_cast(inst)); + case Op::Fwidth: + return visitor(static_cast(inst)); + case Op::DPdxFine: + return visitor(static_cast(inst)); + case Op::DPdyFine: + return visitor(static_cast(inst)); + case Op::FwidthFine: + return visitor(static_cast(inst)); + case Op::DPdxCoarse: + return visitor(static_cast(inst)); + case Op::DPdyCoarse: + return visitor(static_cast(inst)); + case Op::FwidthCoarse: + return visitor(static_cast(inst)); + case Op::EmitVertex: + return visitor(static_cast(inst)); + case Op::EndPrimitive: + return visitor(static_cast(inst)); + case Op::EmitStreamVertex: + return visitor(static_cast(inst)); + case Op::EndStreamPrimitive: + return visitor(static_cast(inst)); + case Op::ControlBarrier: + return visitor(static_cast(inst)); + case Op::MemoryBarrier: + return visitor(static_cast(inst)); + case Op::AtomicLoad: + return visitor(static_cast(inst)); + case Op::AtomicStore: + return visitor(static_cast(inst)); + case Op::AtomicExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchangeWeak: + return visitor(static_cast(inst)); + case Op::AtomicIIncrement: + return visitor(static_cast(inst)); + case Op::AtomicIDecrement: + return visitor(static_cast(inst)); + case Op::AtomicIAdd: + return visitor(static_cast(inst)); + case Op::AtomicISub: + return visitor(static_cast(inst)); + case Op::AtomicSMin: + return visitor(static_cast(inst)); + case Op::AtomicUMin: + return visitor(static_cast(inst)); + case Op::AtomicSMax: + return visitor(static_cast(inst)); + case Op::AtomicUMax: + return visitor(static_cast(inst)); + case Op::AtomicAnd: + return visitor(static_cast(inst)); + case Op::AtomicOr: + return visitor(static_cast(inst)); + case Op::AtomicXor: + return visitor(static_cast(inst)); + case Op::Phi: + return visitor(static_cast(inst)); + case Op::LoopMerge: + return visitor(static_cast(inst)); + case Op::SelectionMerge: + return visitor(static_cast(inst)); + case Op::Label: + return visitor(static_cast(inst)); + case Op::Branch: + return visitor(static_cast(inst)); + case Op::BranchConditional: + return visitor(static_cast(inst)); + case Op::Switch: + return visitor(static_cast(inst)); + case Op::Kill: + return visitor(static_cast(inst)); + case Op::Return: + return visitor(static_cast(inst)); + case Op::ReturnValue: + return visitor(static_cast(inst)); + case Op::Unreachable: + return visitor(static_cast(inst)); + case Op::LifetimeStart: + return visitor(static_cast(inst)); + case Op::LifetimeStop: + return visitor(static_cast(inst)); + case Op::GroupAsyncCopy: + return visitor(static_cast(inst)); + case Op::GroupWaitEvents: + return visitor(static_cast(inst)); + case Op::GroupAll: + return visitor(static_cast(inst)); + case Op::GroupAny: + return visitor(static_cast(inst)); + case Op::GroupBroadcast: + return visitor(static_cast(inst)); + case Op::GroupIAdd: + return visitor(static_cast(inst)); + case Op::GroupFAdd: + return visitor(static_cast(inst)); + case Op::GroupFMin: + return visitor(static_cast(inst)); + case Op::GroupUMin: + return visitor(static_cast(inst)); + case Op::GroupSMin: + return visitor(static_cast(inst)); + case Op::GroupFMax: + return visitor(static_cast(inst)); + case Op::GroupUMax: + return visitor(static_cast(inst)); + case Op::GroupSMax: + return visitor(static_cast(inst)); + case Op::ReadPipe: + return visitor(static_cast(inst)); + case Op::WritePipe: + return visitor(static_cast(inst)); + case Op::ReservedReadPipe: + return visitor(static_cast(inst)); + case Op::ReservedWritePipe: + return visitor(static_cast(inst)); + case Op::ReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::ReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::CommitReadPipe: + return visitor(static_cast(inst)); + case Op::CommitWritePipe: + return visitor(static_cast(inst)); + case Op::IsValidReserveId: + return visitor(static_cast(inst)); + case Op::GetNumPipePackets: + return visitor(static_cast(inst)); + case Op::GetMaxPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::GroupCommitReadPipe: + return visitor(static_cast(inst)); + case Op::GroupCommitWritePipe: + return visitor(static_cast(inst)); + case Op::EnqueueMarker: + return visitor(static_cast(inst)); + case Op::EnqueueKernel: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeSubGroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeMaxSubGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelWorkGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return visitor(static_cast(inst)); + case Op::RetainEvent: + return visitor(static_cast(inst)); + case Op::ReleaseEvent: + return visitor(static_cast(inst)); + case Op::CreateUserEvent: + return visitor(static_cast(inst)); + case Op::IsValidEvent: + return visitor(static_cast(inst)); + case Op::SetUserEventStatus: + return visitor(static_cast(inst)); + case Op::CaptureEventProfilingInfo: + return visitor(static_cast(inst)); + case Op::GetDefaultQueue: + return visitor(static_cast(inst)); + case Op::BuildNDRange: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseFetch: + return visitor(static_cast(inst)); + case Op::ImageSparseGather: + return visitor(static_cast(inst)); + case Op::ImageSparseDrefGather: + return visitor(static_cast(inst)); + case Op::ImageSparseTexelsResident: + return visitor(static_cast(inst)); + case Op::NoLine: + return visitor(static_cast(inst)); + case Op::AtomicFlagTestAndSet: + return visitor(static_cast(inst)); + case Op::AtomicFlagClear: + return visitor(static_cast(inst)); + case Op::ImageSparseRead: + return visitor(static_cast(inst)); + case Op::SizeOf: + return visitor(static_cast(inst)); + case Op::TypePipeStorage: + return visitor(static_cast(inst)); + case Op::ConstantPipeStorage: + return visitor(static_cast(inst)); + case Op::CreatePipeFromPipeStorage: + return visitor(static_cast(inst)); + case Op::GetKernelLocalSizeForSubgroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelMaxNumSubgroups: + return visitor(static_cast(inst)); + case Op::TypeNamedBarrier: + return visitor(static_cast(inst)); + case Op::NamedBarrierInitialize: + return visitor(static_cast(inst)); + case Op::MemoryNamedBarrier: + return visitor(static_cast(inst)); + case Op::ModuleProcessed: + return visitor(static_cast(inst)); + case Op::ExecutionModeId: + return visitor(static_cast(inst)); + case Op::DecorateId: + return visitor(static_cast(inst)); + case Op::GroupNonUniformElect: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAll: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAny: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAllEqual: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcastFirst: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformInverseBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitExtract: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitCount: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindLSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindMSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffle: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleUp: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleDown: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadSwap: + return visitor(static_cast(inst)); + case Op::CopyLogical: + return visitor(static_cast(inst)); + case Op::PtrEqual: + return visitor(static_cast(inst)); + case Op::PtrNotEqual: + return visitor(static_cast(inst)); + case Op::PtrDiff: + return visitor(static_cast(inst)); + case Op::TypeCooperativeMatrixKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixMulAddKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLengthKHR: + return visitor(static_cast(inst)); + case Op::SubgroupBlockReadINTEL: + return visitor(static_cast(inst)); + case Op::SubgroupBlockWriteINTEL: + return visitor(static_cast(inst)); + case Op::AsmTargetINTEL: + return visitor(static_cast(inst)); + case Op::AsmINTEL: + return visitor(static_cast(inst)); + case Op::AsmCallINTEL: + return visitor(static_cast(inst)); + case Op::AtomicFMinEXT: + return visitor(static_cast(inst)); + case Op::AtomicFMaxEXT: + return visitor(static_cast(inst)); + case Op::AtomicFAddEXT: + return visitor(static_cast(inst)); + case Op::ConvertFToBF16INTEL: + return visitor(static_cast(inst)); + case Op::ConvertBF16ToFINTEL: + return visitor(static_cast(inst)); + case Op::ControlBarrierArriveINTEL: + return visitor(static_cast(inst)); + case Op::ControlBarrierWaitINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadCheckedINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreCheckedINTEL: + return visitor(static_cast(inst)); + } + throw internal_compiler_error(); +} + +template auto visit(Visitor &&visitor, spv_inst const &inst) { + switch (inst.opcode()) { + case Op::Nop: + return visitor(static_cast(inst)); + case Op::Undef: + return visitor(static_cast(inst)); + case Op::SourceContinued: + return visitor(static_cast(inst)); + case Op::Source: + return visitor(static_cast(inst)); + case Op::SourceExtension: + return visitor(static_cast(inst)); + case Op::Name: + return visitor(static_cast(inst)); + case Op::MemberName: + return visitor(static_cast(inst)); + case Op::String: + return visitor(static_cast(inst)); + case Op::Line: + return visitor(static_cast(inst)); + case Op::Extension: + return visitor(static_cast(inst)); + case Op::ExtInstImport: + return visitor(static_cast(inst)); + case Op::ExtInst: + return visitor(static_cast(inst)); + case Op::MemoryModel: + return visitor(static_cast(inst)); + case Op::EntryPoint: + return visitor(static_cast(inst)); + case Op::ExecutionMode: + return visitor(static_cast(inst)); + case Op::Capability: + return visitor(static_cast(inst)); + case Op::TypeVoid: + return visitor(static_cast(inst)); + case Op::TypeBool: + return visitor(static_cast(inst)); + case Op::TypeInt: + return visitor(static_cast(inst)); + case Op::TypeFloat: + return visitor(static_cast(inst)); + case Op::TypeVector: + return visitor(static_cast(inst)); + case Op::TypeMatrix: + return visitor(static_cast(inst)); + case Op::TypeImage: + return visitor(static_cast(inst)); + case Op::TypeSampler: + return visitor(static_cast(inst)); + case Op::TypeSampledImage: + return visitor(static_cast(inst)); + case Op::TypeArray: + return visitor(static_cast(inst)); + case Op::TypeRuntimeArray: + return visitor(static_cast(inst)); + case Op::TypeStruct: + return visitor(static_cast(inst)); + case Op::TypeOpaque: + return visitor(static_cast(inst)); + case Op::TypePointer: + return visitor(static_cast(inst)); + case Op::TypeFunction: + return visitor(static_cast(inst)); + case Op::TypeEvent: + return visitor(static_cast(inst)); + case Op::TypeDeviceEvent: + return visitor(static_cast(inst)); + case Op::TypeReserveId: + return visitor(static_cast(inst)); + case Op::TypeQueue: + return visitor(static_cast(inst)); + case Op::TypePipe: + return visitor(static_cast(inst)); + case Op::TypeForwardPointer: + return visitor(static_cast(inst)); + case Op::ConstantTrue: + return visitor(static_cast(inst)); + case Op::ConstantFalse: + return visitor(static_cast(inst)); + case Op::Constant: + return visitor(static_cast(inst)); + case Op::ConstantComposite: + return visitor(static_cast(inst)); + case Op::ConstantSampler: + return visitor(static_cast(inst)); + case Op::ConstantNull: + return visitor(static_cast(inst)); + case Op::Function: + return visitor(static_cast(inst)); + case Op::FunctionParameter: + return visitor(static_cast(inst)); + case Op::FunctionEnd: + return visitor(static_cast(inst)); + case Op::FunctionCall: + return visitor(static_cast(inst)); + case Op::Variable: + return visitor(static_cast(inst)); + case Op::ImageTexelPointer: + return visitor(static_cast(inst)); + case Op::Load: + return visitor(static_cast(inst)); + case Op::Store: + return visitor(static_cast(inst)); + case Op::CopyMemory: + return visitor(static_cast(inst)); + case Op::CopyMemorySized: + return visitor(static_cast(inst)); + case Op::AccessChain: + return visitor(static_cast(inst)); + case Op::InBoundsAccessChain: + return visitor(static_cast(inst)); + case Op::PtrAccessChain: + return visitor(static_cast(inst)); + case Op::ArrayLength: + return visitor(static_cast(inst)); + case Op::GenericPtrMemSemantics: + return visitor(static_cast(inst)); + case Op::InBoundsPtrAccessChain: + return visitor(static_cast(inst)); + case Op::Decorate: + return visitor(static_cast(inst)); + case Op::MemberDecorate: + return visitor(static_cast(inst)); + case Op::DecorationGroup: + return visitor(static_cast(inst)); + case Op::GroupDecorate: + return visitor(static_cast(inst)); + case Op::GroupMemberDecorate: + return visitor(static_cast(inst)); + case Op::VectorExtractDynamic: + return visitor(static_cast(inst)); + case Op::VectorInsertDynamic: + return visitor(static_cast(inst)); + case Op::VectorShuffle: + return visitor(static_cast(inst)); + case Op::CompositeConstruct: + return visitor(static_cast(inst)); + case Op::CompositeExtract: + return visitor(static_cast(inst)); + case Op::CompositeInsert: + return visitor(static_cast(inst)); + case Op::CopyObject: + return visitor(static_cast(inst)); + case Op::Transpose: + return visitor(static_cast(inst)); + case Op::SampledImage: + return visitor(static_cast(inst)); + case Op::ImageSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageFetch: + return visitor(static_cast(inst)); + case Op::ImageGather: + return visitor(static_cast(inst)); + case Op::ImageDrefGather: + return visitor(static_cast(inst)); + case Op::ImageRead: + return visitor(static_cast(inst)); + case Op::ImageWrite: + return visitor(static_cast(inst)); + case Op::Image: + return visitor(static_cast(inst)); + case Op::ImageQueryFormat: + return visitor(static_cast(inst)); + case Op::ImageQueryOrder: + return visitor(static_cast(inst)); + case Op::ImageQuerySizeLod: + return visitor(static_cast(inst)); + case Op::ImageQuerySize: + return visitor(static_cast(inst)); + case Op::ImageQueryLod: + return visitor(static_cast(inst)); + case Op::ImageQueryLevels: + return visitor(static_cast(inst)); + case Op::ImageQuerySamples: + return visitor(static_cast(inst)); + case Op::ConvertFToU: + return visitor(static_cast(inst)); + case Op::ConvertFToS: + return visitor(static_cast(inst)); + case Op::ConvertSToF: + return visitor(static_cast(inst)); + case Op::ConvertUToF: + return visitor(static_cast(inst)); + case Op::UConvert: + return visitor(static_cast(inst)); + case Op::SConvert: + return visitor(static_cast(inst)); + case Op::FConvert: + return visitor(static_cast(inst)); + case Op::QuantizeToF16: + return visitor(static_cast(inst)); + case Op::ConvertPtrToU: + return visitor(static_cast(inst)); + case Op::SatConvertSToU: + return visitor(static_cast(inst)); + case Op::SatConvertUToS: + return visitor(static_cast(inst)); + case Op::ConvertUToPtr: + return visitor(static_cast(inst)); + case Op::PtrCastToGeneric: + return visitor(static_cast(inst)); + case Op::GenericCastToPtr: + return visitor(static_cast(inst)); + case Op::GenericCastToPtrExplicit: + return visitor(static_cast(inst)); + case Op::Bitcast: + return visitor(static_cast(inst)); + case Op::SNegate: + return visitor(static_cast(inst)); + case Op::FNegate: + return visitor(static_cast(inst)); + case Op::IAdd: + return visitor(static_cast(inst)); + case Op::FAdd: + return visitor(static_cast(inst)); + case Op::ISub: + return visitor(static_cast(inst)); + case Op::FSub: + return visitor(static_cast(inst)); + case Op::IMul: + return visitor(static_cast(inst)); + case Op::FMul: + return visitor(static_cast(inst)); + case Op::UDiv: + return visitor(static_cast(inst)); + case Op::SDiv: + return visitor(static_cast(inst)); + case Op::FDiv: + return visitor(static_cast(inst)); + case Op::UMod: + return visitor(static_cast(inst)); + case Op::SRem: + return visitor(static_cast(inst)); + case Op::SMod: + return visitor(static_cast(inst)); + case Op::FRem: + return visitor(static_cast(inst)); + case Op::FMod: + return visitor(static_cast(inst)); + case Op::VectorTimesScalar: + return visitor(static_cast(inst)); + case Op::MatrixTimesScalar: + return visitor(static_cast(inst)); + case Op::VectorTimesMatrix: + return visitor(static_cast(inst)); + case Op::MatrixTimesVector: + return visitor(static_cast(inst)); + case Op::MatrixTimesMatrix: + return visitor(static_cast(inst)); + case Op::OuterProduct: + return visitor(static_cast(inst)); + case Op::Dot: + return visitor(static_cast(inst)); + case Op::IAddCarry: + return visitor(static_cast(inst)); + case Op::ISubBorrow: + return visitor(static_cast(inst)); + case Op::UMulExtended: + return visitor(static_cast(inst)); + case Op::SMulExtended: + return visitor(static_cast(inst)); + case Op::Any: + return visitor(static_cast(inst)); + case Op::All: + return visitor(static_cast(inst)); + case Op::IsNan: + return visitor(static_cast(inst)); + case Op::IsInf: + return visitor(static_cast(inst)); + case Op::IsFinite: + return visitor(static_cast(inst)); + case Op::IsNormal: + return visitor(static_cast(inst)); + case Op::SignBitSet: + return visitor(static_cast(inst)); + case Op::LessOrGreater: + return visitor(static_cast(inst)); + case Op::Ordered: + return visitor(static_cast(inst)); + case Op::Unordered: + return visitor(static_cast(inst)); + case Op::LogicalEqual: + return visitor(static_cast(inst)); + case Op::LogicalNotEqual: + return visitor(static_cast(inst)); + case Op::LogicalOr: + return visitor(static_cast(inst)); + case Op::LogicalAnd: + return visitor(static_cast(inst)); + case Op::LogicalNot: + return visitor(static_cast(inst)); + case Op::Select: + return visitor(static_cast(inst)); + case Op::IEqual: + return visitor(static_cast(inst)); + case Op::INotEqual: + return visitor(static_cast(inst)); + case Op::UGreaterThan: + return visitor(static_cast(inst)); + case Op::SGreaterThan: + return visitor(static_cast(inst)); + case Op::UGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::SGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ULessThan: + return visitor(static_cast(inst)); + case Op::SLessThan: + return visitor(static_cast(inst)); + case Op::ULessThanEqual: + return visitor(static_cast(inst)); + case Op::SLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdEqual: + return visitor(static_cast(inst)); + case Op::FUnordEqual: + return visitor(static_cast(inst)); + case Op::FOrdNotEqual: + return visitor(static_cast(inst)); + case Op::FUnordNotEqual: + return visitor(static_cast(inst)); + case Op::FOrdLessThan: + return visitor(static_cast(inst)); + case Op::FUnordLessThan: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThan: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThan: + return visitor(static_cast(inst)); + case Op::FOrdLessThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ShiftRightLogical: + return visitor(static_cast(inst)); + case Op::ShiftRightArithmetic: + return visitor(static_cast(inst)); + case Op::ShiftLeftLogical: + return visitor(static_cast(inst)); + case Op::BitwiseOr: + return visitor(static_cast(inst)); + case Op::BitwiseXor: + return visitor(static_cast(inst)); + case Op::BitwiseAnd: + return visitor(static_cast(inst)); + case Op::Not: + return visitor(static_cast(inst)); + case Op::BitFieldInsert: + return visitor(static_cast(inst)); + case Op::BitFieldSExtract: + return visitor(static_cast(inst)); + case Op::BitFieldUExtract: + return visitor(static_cast(inst)); + case Op::BitReverse: + return visitor(static_cast(inst)); + case Op::BitCount: + return visitor(static_cast(inst)); + case Op::DPdx: + return visitor(static_cast(inst)); + case Op::DPdy: + return visitor(static_cast(inst)); + case Op::Fwidth: + return visitor(static_cast(inst)); + case Op::DPdxFine: + return visitor(static_cast(inst)); + case Op::DPdyFine: + return visitor(static_cast(inst)); + case Op::FwidthFine: + return visitor(static_cast(inst)); + case Op::DPdxCoarse: + return visitor(static_cast(inst)); + case Op::DPdyCoarse: + return visitor(static_cast(inst)); + case Op::FwidthCoarse: + return visitor(static_cast(inst)); + case Op::EmitVertex: + return visitor(static_cast(inst)); + case Op::EndPrimitive: + return visitor(static_cast(inst)); + case Op::EmitStreamVertex: + return visitor(static_cast(inst)); + case Op::EndStreamPrimitive: + return visitor(static_cast(inst)); + case Op::ControlBarrier: + return visitor(static_cast(inst)); + case Op::MemoryBarrier: + return visitor(static_cast(inst)); + case Op::AtomicLoad: + return visitor(static_cast(inst)); + case Op::AtomicStore: + return visitor(static_cast(inst)); + case Op::AtomicExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchangeWeak: + return visitor(static_cast(inst)); + case Op::AtomicIIncrement: + return visitor(static_cast(inst)); + case Op::AtomicIDecrement: + return visitor(static_cast(inst)); + case Op::AtomicIAdd: + return visitor(static_cast(inst)); + case Op::AtomicISub: + return visitor(static_cast(inst)); + case Op::AtomicSMin: + return visitor(static_cast(inst)); + case Op::AtomicUMin: + return visitor(static_cast(inst)); + case Op::AtomicSMax: + return visitor(static_cast(inst)); + case Op::AtomicUMax: + return visitor(static_cast(inst)); + case Op::AtomicAnd: + return visitor(static_cast(inst)); + case Op::AtomicOr: + return visitor(static_cast(inst)); + case Op::AtomicXor: + return visitor(static_cast(inst)); + case Op::Phi: + return visitor(static_cast(inst)); + case Op::LoopMerge: + return visitor(static_cast(inst)); + case Op::SelectionMerge: + return visitor(static_cast(inst)); + case Op::Label: + return visitor(static_cast(inst)); + case Op::Branch: + return visitor(static_cast(inst)); + case Op::BranchConditional: + return visitor(static_cast(inst)); + case Op::Switch: + return visitor(static_cast(inst)); + case Op::Kill: + return visitor(static_cast(inst)); + case Op::Return: + return visitor(static_cast(inst)); + case Op::ReturnValue: + return visitor(static_cast(inst)); + case Op::Unreachable: + return visitor(static_cast(inst)); + case Op::LifetimeStart: + return visitor(static_cast(inst)); + case Op::LifetimeStop: + return visitor(static_cast(inst)); + case Op::GroupAsyncCopy: + return visitor(static_cast(inst)); + case Op::GroupWaitEvents: + return visitor(static_cast(inst)); + case Op::GroupAll: + return visitor(static_cast(inst)); + case Op::GroupAny: + return visitor(static_cast(inst)); + case Op::GroupBroadcast: + return visitor(static_cast(inst)); + case Op::GroupIAdd: + return visitor(static_cast(inst)); + case Op::GroupFAdd: + return visitor(static_cast(inst)); + case Op::GroupFMin: + return visitor(static_cast(inst)); + case Op::GroupUMin: + return visitor(static_cast(inst)); + case Op::GroupSMin: + return visitor(static_cast(inst)); + case Op::GroupFMax: + return visitor(static_cast(inst)); + case Op::GroupUMax: + return visitor(static_cast(inst)); + case Op::GroupSMax: + return visitor(static_cast(inst)); + case Op::ReadPipe: + return visitor(static_cast(inst)); + case Op::WritePipe: + return visitor(static_cast(inst)); + case Op::ReservedReadPipe: + return visitor(static_cast(inst)); + case Op::ReservedWritePipe: + return visitor(static_cast(inst)); + case Op::ReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::ReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::CommitReadPipe: + return visitor(static_cast(inst)); + case Op::CommitWritePipe: + return visitor(static_cast(inst)); + case Op::IsValidReserveId: + return visitor(static_cast(inst)); + case Op::GetNumPipePackets: + return visitor(static_cast(inst)); + case Op::GetMaxPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::GroupCommitReadPipe: + return visitor(static_cast(inst)); + case Op::GroupCommitWritePipe: + return visitor(static_cast(inst)); + case Op::EnqueueMarker: + return visitor(static_cast(inst)); + case Op::EnqueueKernel: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeSubGroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeMaxSubGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelWorkGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return visitor(static_cast(inst)); + case Op::RetainEvent: + return visitor(static_cast(inst)); + case Op::ReleaseEvent: + return visitor(static_cast(inst)); + case Op::CreateUserEvent: + return visitor(static_cast(inst)); + case Op::IsValidEvent: + return visitor(static_cast(inst)); + case Op::SetUserEventStatus: + return visitor(static_cast(inst)); + case Op::CaptureEventProfilingInfo: + return visitor(static_cast(inst)); + case Op::GetDefaultQueue: + return visitor(static_cast(inst)); + case Op::BuildNDRange: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseFetch: + return visitor(static_cast(inst)); + case Op::ImageSparseGather: + return visitor(static_cast(inst)); + case Op::ImageSparseDrefGather: + return visitor(static_cast(inst)); + case Op::ImageSparseTexelsResident: + return visitor(static_cast(inst)); + case Op::NoLine: + return visitor(static_cast(inst)); + case Op::AtomicFlagTestAndSet: + return visitor(static_cast(inst)); + case Op::AtomicFlagClear: + return visitor(static_cast(inst)); + case Op::ImageSparseRead: + return visitor(static_cast(inst)); + case Op::SizeOf: + return visitor(static_cast(inst)); + case Op::TypePipeStorage: + return visitor(static_cast(inst)); + case Op::ConstantPipeStorage: + return visitor(static_cast(inst)); + case Op::CreatePipeFromPipeStorage: + return visitor(static_cast(inst)); + case Op::GetKernelLocalSizeForSubgroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelMaxNumSubgroups: + return visitor(static_cast(inst)); + case Op::TypeNamedBarrier: + return visitor(static_cast(inst)); + case Op::NamedBarrierInitialize: + return visitor(static_cast(inst)); + case Op::MemoryNamedBarrier: + return visitor(static_cast(inst)); + case Op::ModuleProcessed: + return visitor(static_cast(inst)); + case Op::ExecutionModeId: + return visitor(static_cast(inst)); + case Op::DecorateId: + return visitor(static_cast(inst)); + case Op::GroupNonUniformElect: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAll: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAny: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAllEqual: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcastFirst: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformInverseBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitExtract: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitCount: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindLSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindMSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffle: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleUp: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleDown: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadSwap: + return visitor(static_cast(inst)); + case Op::CopyLogical: + return visitor(static_cast(inst)); + case Op::PtrEqual: + return visitor(static_cast(inst)); + case Op::PtrNotEqual: + return visitor(static_cast(inst)); + case Op::PtrDiff: + return visitor(static_cast(inst)); + case Op::TypeCooperativeMatrixKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixMulAddKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLengthKHR: + return visitor(static_cast(inst)); + case Op::SubgroupBlockReadINTEL: + return visitor(static_cast(inst)); + case Op::SubgroupBlockWriteINTEL: + return visitor(static_cast(inst)); + case Op::AsmTargetINTEL: + return visitor(static_cast(inst)); + case Op::AsmINTEL: + return visitor(static_cast(inst)); + case Op::AsmCallINTEL: + return visitor(static_cast(inst)); + case Op::AtomicFMinEXT: + return visitor(static_cast(inst)); + case Op::AtomicFMaxEXT: + return visitor(static_cast(inst)); + case Op::AtomicFAddEXT: + return visitor(static_cast(inst)); + case Op::ConvertFToBF16INTEL: + return visitor(static_cast(inst)); + case Op::ConvertBF16ToFINTEL: + return visitor(static_cast(inst)); + case Op::ControlBarrierArriveINTEL: + return visitor(static_cast(inst)); + case Op::ControlBarrierWaitINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadCheckedINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreCheckedINTEL: + return visitor(static_cast(inst)); + } + throw internal_compiler_error(); +} + +template class default_visitor { + public: + template using const_t = std::conditional_t, T>; + auto pre_visit(const_t &) {} + auto visit_result(const_t &) {} + auto post_visit(const_t &) {} + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + for (auto &op : in.op3()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + if (in.op7()) { + static_cast(this)->operator()(*in.op7()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + if (in.op5()) { + static_cast(this)->operator()(*in.op5()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + for (auto &op : in.op3()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + static_cast(this)->operator()(in.op7()); + static_cast(this)->operator()(in.op8()); + static_cast(this)->operator()(in.op9()); + for (auto &op : in.op10()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + if (in.op5()) { + static_cast(this)->operator()(*in.op5()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + if (in.op6()) { + static_cast(this)->operator()(*in.op6()); + } + + if (in.op7()) { + static_cast(this)->operator()(*in.op7()); + } + + if (in.op8()) { + static_cast(this)->operator()(*in.op8()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + if (in.op7()) { + static_cast(this)->operator()(*in.op7()); + } + + if (in.op8()) { + static_cast(this)->operator()(*in.op8()); + } + + if (in.op9()) { + static_cast(this)->operator()(*in.op9()); + } + + static_cast(this)->post_visit(in); + } +}; + +} // namespace tinytc::spv + +#endif // GENERATED_VISIT_20250605_HPP diff --git a/src/spv/xe_constants.hpp b/src/spv/xe_constants.hpp new file mode 100644 index 00000000..6da62431 --- /dev/null +++ b/src/spv/xe_constants.hpp @@ -0,0 +1,16 @@ +#ifndef XE_CONSTANTS_20250219_HPP +#define XE_CONSTANTS_20250219_HPP + +#include + +namespace tinytc::spv::xe { +constexpr static std::int32_t grf_size = 64; +constexpr static std::int32_t exec_size = 16; +constexpr static std::int32_t channel_size = 4; +constexpr static std::int32_t sdepth = 8; +constexpr static std::int32_t rcount = 8; +constexpr static std::int32_t load_batch_size = 4; +constexpr static std::int32_t store_batch_size = 1; +} // namespace tinytc::spv::xe + +#endif // XE_CONSTANTS_20250219_HPP diff --git a/src/support/fnv1a_array_view.hpp b/src/support/fnv1a_array_view.hpp new file mode 100644 index 00000000..9b254d1a --- /dev/null +++ b/src/support/fnv1a_array_view.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FNV1A_ARRAY_VIEW_20241010_HPP +#define FNV1A_ARRAY_VIEW_20241010_HPP + +#include "tinytc/tinytc.hpp" +#include "util/fnv1a.hpp" + +namespace tinytc { + +template +constexpr auto fnv1a_step(std::uint64_t hash, array_view const &data) -> std::uint64_t { + for (auto const &i : data) { + hash = fnv1a_step(hash, i); + } + return hash; +} + +} // namespace tinytc + +#endif // FNV1A_ARRAY_VIEW_20241010_HPP diff --git a/src/support/fp_util.hpp b/src/support/fp_util.hpp new file mode 100644 index 00000000..44492dab --- /dev/null +++ b/src/support/fp_util.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FP_UTIL_20241126_HPP +#define FP_UTIL_20241126_HPP + +#include "tinytc/tinytc.hpp" + +#include +#include + +namespace tinytc { + +template class U> +struct is_instance_of : public std::false_type {}; +template