diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c2fe9f5..58463426 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,55 @@ # Changelog +## [0.4.0] - 2025-XX-XX + +### Major changes +* Support for complex numbers (c32 and c64) with optimized complex GEMM +* Support for half precision and bfloat16 numbers +* (Experimental) Support for accelerating low precision GEMMs with tensor cores (XMX) +* Compiler directly generates SPIR-V binaries (instead of OpenCL-C); JIT compilation speed-up of about 2x +* Removed dependency on clir and ocloc libraries + +### Tiny Tensor Language +* Clarified execution model +* Added align attribute +* Added unroll attribute +* Added boolean, f16 (half precision), bf16 (bfloat16), c32 (single complex), and c64 (double complex) types +* Added address space to memref type +* Added cooperative matrix type +* Added new instructions: barrier, builtin, constant, cooperative_matrix_load/mul_add/scale/store, parallel, + subgroup_broadcast, work_group +* Extended foreach for iteration spaces with dim > 1 +* For-loops can yield values +* Removed immediate operands in favour of constant instruction +* Instructions annotate the return type instead of operand types +* Added cumulative sum instruction +* Changed 1d group id to 3d group id; introduced 2d subgroup id + +### Core API +* Introduced compiler_context class for compilation flags and control of error reporting +* Removed functions to create kernel bundle from source text +* Removed source_context in favour of compiler_context +* Added functions to compile tinytc prog to SPIR-V +* Added API to run function passes +* Added float <-> half and float <-> bfloat16 conversion functions +* Added array_view class in C++-API + +### Builder API +* Changed ownership model for data type, inst, region, and func +* Normalized function names in builder API +* Instruction creation functions always include return type + +### Compiler +* Overhauled infrastructure for writing advanced compiler passes +* Introduced 
data flow analysis to properly insert barriers automatically when for-loops are present +* Added constant folding, constant propagation, and dead code elimination passes +* Added memref alignment analysis +* Added analysis to check applicability of tensor core acceleration +* Added SPIR-V support; Tiny Tensor Language to SPIR-V conversion and SPIR-V binary generation +* Introduced GEMM to cooperative matrix conversion (instead of direct GEMM to OpenCL-C conversion) +* Added fast cooperative matrix mul add implementation for complex numbers +* Added attributes + ## [0.3.1] - 2024-05-22 * Bugfix: Add alias analysis for stack; needed to correctly insert barriers * Bugfix: Disable block writes as alignment analysis is missing diff --git a/CMakeLists.txt b/CMakeLists.txt index ca7f1f1e..7ea9a109 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.23) -project(tiny_tensor_compiler VERSION 0.3.1 LANGUAGES C CXX) +project(tiny-tensor-compiler VERSION 0.3.1 LANGUAGES C CXX) include(CMakeDependentOption) option(BUILD_DOCUMENTATION "Build documentation" OFF) @@ -14,6 +14,8 @@ cmake_dependent_option(BUILD_LEVEL_ZERO cmake_dependent_option(BUILD_OPENCL "Build support for OpenCL run-time; required when SYCL is enabled" ON "NOT BUILD_SYCL" ON) +option(BUILD_DOUBLE_PRECISION_TESTS "Build double precision unit tests" ON) +option(BUILD_MOCHI "Build mochi; can only be disabled if building from release tarball" ON) include(CTest) @@ -40,6 +42,13 @@ if(BUILD_DOCUMENTATION) add_subdirectory(docs) endif() +# cpack + +include(CPackSetup) +cpack_setup() + +# tests + enable_testing() if(BUILD_TESTING) add_subdirectory(test) diff --git a/cmake/CPackGeneratedFiles.cmake.in b/cmake/CPackGeneratedFiles.cmake.in new file mode 100644 index 00000000..afd3eaef --- /dev/null +++ b/cmake/CPackGeneratedFiles.cmake.in @@ -0,0 +1,15 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +if(CPACK_SOURCE_INSTALLED_DIRECTORIES) + 
file(GLOB FILE_LISTS "@GENERATED_FILE_LISTS_DIR@/*.list") + foreach(_file_list ${FILE_LISTS}) + file(READ ${_file_list} _files) + foreach(_src ${_files}) + file(RELATIVE_PATH _src_rel @PROJECT_BINARY_DIR@ ${_src}) + get_filename_component(_src_dir ${_src_rel} DIRECTORY) + message(STATUS "Adding generated file ${_src_rel} to ${_src_dir}") + file(INSTALL ${_src} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/${_src_dir}") + endforeach() + endforeach() +endif() diff --git a/cmake/CPackSetup.cmake b/cmake/CPackSetup.cmake new file mode 100644 index 00000000..f5a1a51e --- /dev/null +++ b/cmake/CPackSetup.cmake @@ -0,0 +1,46 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +include(GitVersion) + +function(cpack_setup) + git_version() + set(CPACK_SOURCE_GENERATOR "TXZ") + set(CPACK_SOURCE_PACKAGE_FILE_NAME + "${CMAKE_PROJECT_NAME}-${GIT_MAJOR_VERSION}.${GIT_MINOR_VERSION}.${GIT_PATCH_VERSION}" + ) + if(NOT "${GIT_COMMITS_SINCE_RELEASE}" STREQUAL "0") + set(CPACK_SOURCE_PACKAGE_FILE_NAME + "${CPACK_SOURCE_PACKAGE_FILE_NAME}-${GIT_COMMITS_SINCE_RELEASE}-${GIT_COMMIT}" + ) + endif() + set(CPACK_SOURCE_IGNORE_FILES + "/\\\\.git/" + "/\\\\.gitignore$" + "\\\\.swp$" + "/build/" + "/build_debug/" + "/build_iwyu/" + "/coverity_scan/" + "/__pycache__/" + ) + + set(GENERATED_FILE_LISTS_DIR "${PROJECT_BINARY_DIR}/generated_file_lists") + + configure_file("${PROJECT_SOURCE_DIR}/cmake/CPackGeneratedFiles.cmake.in" "CPackGeneratedFiles.cmake" @ONLY) + set(CPACK_INSTALL_SCRIPT "${CMAKE_CURRENT_BINARY_DIR}/CPackGeneratedFiles.cmake") + include(CPack) + add_custom_target(dist + COMMAND "${CMAKE_COMMAND}" --build "${PROJECT_BINARY_DIR}" --target package_source + DEPENDS tinytc-objects + VERBATIM + USES_TERMINAL + ) + if(BUILD_OPENCL) + add_dependencies(dist tinytc_cl-objects) + endif() + if(BUILD_MOCHI) + add_dependencies(dist mochi) + endif() +endfunction() + diff --git a/cmake/CommonOptions.cmake b/cmake/CommonOptions.cmake index 1d5ba3a1..2eabf540 100644 
--- a/cmake/CommonOptions.cmake +++ b/cmake/CommonOptions.cmake @@ -7,6 +7,7 @@ include(CheckCompilerFlag) include(CheckLinkerFlag) +include(GitVersion) function(add_flag_if_available lang target flag) string(MAKE_C_IDENTIFIER ${flag} flag_c) @@ -54,6 +55,11 @@ function(set_common_options lang target) SOVERSION ${tiny_tensor_compiler_VERSION_MAJOR}) set_property(TARGET ${target} PROPERTY POSITION_INDEPENDENT_CODE ON) target_compile_definitions(${target} PRIVATE -D_FORTIFY_SOURCE=2) + + git_version() + set_target_properties(${target} PROPERTIES + VERSION ${GIT_MAJOR_VERSION}.${GIT_MINOR_VERSION}.${GIT_PATCH_VERSION} + SOVERSION ${GIT_MAJOR_VERSION}) endfunction() function(set_c_common_options target) diff --git a/cmake/FindClangFormat.cmake b/cmake/FindClangFormat.cmake new file mode 100644 index 00000000..936fe2d1 --- /dev/null +++ b/cmake/FindClangFormat.cmake @@ -0,0 +1,56 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +# Try to find clang-format + +# The following definitions are added on success +# +# ClangFormat_FOUND - ClangFormat was found +# ClangFormat_EXECUTABLE - ClangFormat executable +# ClangFormat_VERSION - ClangFormat version +# +# The following hints may be passed in the environment: +# +# CLANGFORMAT_ROOT +# + +if(ClangFormat_EXECUTABLE) + set(ClangFormat_FOUND TRUE) +else() + find_program(ClangFormat_EXECUTABLE NAMES clang-format + HINTS + ENV CLANGFORMAT_ROOT + ENV PATH + ) + + execute_process(COMMAND ${ClangFormat_EXECUTABLE} --version + OUTPUT_VARIABLE ClangFormat_version_output + ERROR_VARIABLE ClangFormat_version_error + RESULT_VARIABLE ClangFormat_version_result + OUTPUT_STRIP_TRAILING_WHITESPACE) + + macro(ClangFormat_version_failed) + set(ClangFormat_message "Command \"${ClangFormat_EXECUTABLE} --version\" failed:\n${ClangFormat_version_output}\n${ClangFormat_version_error}") + if(ClangFormat_FIND_REQUIRED) + message(SEND_ERROR ${ClangFormat_message}) + else() + message(${ClangFormat_message}) + 
endif() + endmacro() + + if(NOT ${ClangFormat_version_result} EQUAL 0) + ClangFormat_version_failed() + else() + if(${ClangFormat_version_output} MATCHES "[a-zA-Z-_ ]*([0-9\.]+)") + set(ClangFormat_VERSION ${CMAKE_MATCH_1}) + else() + ClangFormat_version_failed() + endif() + endif() + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(ClangFormat DEFAULT_MSG + ClangFormat_EXECUTABLE ClangFormat_VERSION) + + mark_as_advanced(ClangFormat_EXECUTABLE ClangFormat_VERSION) +endif() diff --git a/cmake/FindSPIRVTools.cmake b/cmake/FindSPIRVTools.cmake new file mode 100644 index 00000000..b1f4de66 --- /dev/null +++ b/cmake/FindSPIRVTools.cmake @@ -0,0 +1,47 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +# Try to find SPIR-V Tools + +# The following definitions are added on success +# +# SPIRVTools_FOUND - SPIR-V Tools was found +# SPIRVTools_SPIRV_VAL - spirv-val executable +# SPIRVTools_VERSION - SPIR-V Tools version +# +# The following hints may be passed in the environment: +# +# SPIRVTools_ROOT +# + +if(SPIRVTools_SPIRV_VAL) + set(SPIRVTools_FOUND TRUE) +else() + find_program(SPIRVTools_SPIRV_VAL NAMES spirv-val + HINTS + ENV SPIRVTools_ROOT + ENV PATH + ) + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(SPIRVTools DEFAULT_MSG SPIRVTools_SPIRV_VAL) + + execute_process(COMMAND ${SPIRVTools_SPIRV_VAL} --version + OUTPUT_VARIABLE SPIRVTools_version_output + ERROR_VARIABLE SPIRVTools_version_error + RESULT_VARIABLE SPIRVTools_version_result + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT ${SPIRVTools_version_result} EQUAL 0) + set(SPIRVTools_message "Command \"${SPIRVTools_SPIRV_VAL} --version\" failed:\n${SPIRVTools_version_output}\n${SPIRVTools_version_error}") + if(SPIRVTools_FIND_REQUIRED) + message(SEND_ERROR ${SPIRVTools_message}) + else() + message(${SPIRVTools_message}) + endif() + else() + string(REGEX REPLACE "SPIRV-Tools ([v0-9\.]+) .*" "\\1" SPIRVTools_VERSION 
"${SPIRVTools_version_output}") + endif() + + mark_as_advanced(SPIRVTools_SPIRV_VAL) +endif() diff --git a/cmake/Findre2c.cmake b/cmake/Findre2c.cmake index 6ca51db7..996171c1 100644 --- a/cmake/Findre2c.cmake +++ b/cmake/Findre2c.cmake @@ -14,6 +14,8 @@ # RE2C_ROOT # +include(GeneratedPath) + if(re2c_EXECUTABLE) set(re2c_FOUND TRUE) else() @@ -23,9 +25,6 @@ else() ENV PATH ) - include(FindPackageHandleStandardArgs) - find_package_handle_standard_args(re2c DEFAULT_MSG re2c_EXECUTABLE) - execute_process(COMMAND ${re2c_EXECUTABLE} --version OUTPUT_VARIABLE re2c_version_output ERROR_VARIABLE re2c_version_error @@ -33,7 +32,7 @@ else() OUTPUT_STRIP_TRAILING_WHITESPACE) if(NOT ${re2c_version_result} EQUAL 0) - set(re2c_message "Command \"{re2c_EXECUTABLE} --version\" failed:\n${re2c_version_output}\n${re2c_version_error}") + set(re2c_message "Command \"${re2c_EXECUTABLE} --version\" failed:\n${re2c_version_output}\n${re2c_version_error}") if(re2c_FIND_REQUIRED) message(SEND_ERROR ${re2c_message}) else() @@ -43,31 +42,24 @@ else() string(REGEX REPLACE "re2c ([0-9\.]+)" "\\1" re2c_VERSION "${re2c_version_output}") endif() - mark_as_advanced(re2c_EXECUTABLE) + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(re2c DEFAULT_MSG re2c_EXECUTABLE re2c_VERSION) + + mark_as_advanced(re2c_EXECUTABLE re2c_VERSION) endif() if(re2c_EXECUTABLE) function(add_re2c_to_target) cmake_parse_arguments(PARSE_ARGV 0 ARG "" "TARGET;FLAGS" "SOURCES") foreach(SOURCE IN LISTS ARG_SOURCES) - set(INPUT_FILE ${SOURCE}) - if(NOT IS_ABSOLUTE "${INPUT_FILE}") - set(INPUT_FILE ${CMAKE_CURRENT_SOURCE_DIR}/${INPUT_FILE}) - endif() - file(RELATIVE_PATH INPUT_REL_PATH ${PROJECT_SOURCE_DIR} ${INPUT_FILE}) - get_filename_component(INPUT_REL_PATH ${INPUT_REL_PATH} DIRECTORY) - get_filename_component(INPUT_NAME ${INPUT_FILE} NAME) - string(REGEX REPLACE "[.]re$" ".cpp" OUTPUT_NAME ${INPUT_NAME}) - set(OUTPUT_PATH ${PROJECT_BINARY_DIR}/${INPUT_REL_PATH}) - file(MAKE_DIRECTORY 
${OUTPUT_PATH}) - set(OUTPUT_FILE ${OUTPUT_PATH}/${OUTPUT_NAME}) + get_path_of_generated_file(SOURCE ${SOURCE} EXT "cpp") add_custom_command( - OUTPUT ${OUTPUT_FILE} + OUTPUT ${OUTPUT_PATH} DEPENDS ${SOURCE} - COMMAND ${re2c_EXECUTABLE} ${ARG_FLAGS} -o ${OUTPUT_FILE} ${INPUT_FILE} - COMMENT "Generating lexer ${OUTPUT_FILE}" + COMMAND ${re2c_EXECUTABLE} ${ARG_FLAGS} -o ${OUTPUT_PATH} ${INPUT_PATH} + COMMENT "Generating lexer ${OUTPUT_REL_PATH}" ) - target_sources(${ARG_TARGET} PRIVATE ${OUTPUT_FILE}) + target_sources(${ARG_TARGET} PRIVATE ${OUTPUT_PATH}) endforeach() endfunction() endif() diff --git a/cmake/GeneratedFiles.cmake b/cmake/GeneratedFiles.cmake new file mode 100644 index 00000000..d73d298b --- /dev/null +++ b/cmake/GeneratedFiles.cmake @@ -0,0 +1,114 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +include(CommonOptions) +include(GeneratedPath) + +function(get_generated_files VAR target) + get_target_property(_sources ${target} SOURCES) + set(_generated_files "") + foreach(_src ${_sources}) + get_source_file_property(_generated "${_src}" GENERATED) + if(${_generated} GREATER 0) + list(APPEND _generated_files ${_src}) + endif() + endforeach() + set(${VAR} ${_generated_files} PARENT_SCOPE) +endfunction() + +function(write_generated_files target) + get_generated_files(GENERATED_FILES ${target}) + file(WRITE "${PROJECT_BINARY_DIR}/generated_file_lists/${target}.list" "${GENERATED_FILES}") +endfunction() + +function(add_re2c_or_pregenerated_to_target) + cmake_parse_arguments(PARSE_ARGV 0 ARG "" "TARGET;FLAGS" "SOURCES") + + foreach(SOURCE IN LISTS ARG_SOURCES) + get_path_of_generated_file(SOURCE ${SOURCE} EXT "cpp") + + if(EXISTS ${PROJECT_SOURCE_DIR}/${OUTPUT_REL_PATH}) + message(STATUS "Pre-generated ${OUTPUT_REL_PATH} available -- skipping re2c dependency") + target_sources(${ARG_TARGET} PRIVATE ${PROJECT_SOURCE_DIR}/${OUTPUT_REL_PATH}) + else() + find_package(re2c REQUIRED) + add_re2c_to_target(TARGET ${ARG_TARGET} 
SOURCES ${SOURCE} FLAGS "${ARG_FLAGS}") + endif() + endforeach() +endfunction() + +function(add_bison_or_pregenerated_to_target) + cmake_parse_arguments(PARSE_ARGV 0 ARG "HAVE_LOCATION" "TARGET" "SOURCES") + + foreach(SOURCE IN LISTS ARG_SOURCES) + get_path_of_generated_file(SOURCE ${SOURCE} EXT "cpp") + + string(REGEX REPLACE "[.]cpp$" ".hpp" header_rel_path ${OUTPUT_REL_PATH}) + if(ARG_HAVE_LOCATION) + get_filename_component(rel_path ${OUTPUT_REL_PATH} DIRECTORY) + set(location_hh "${rel_path}/location.hh") + endif() + + if(EXISTS ${PROJECT_SOURCE_DIR}/${OUTPUT_REL_PATH} AND + EXISTS ${PROJECT_SOURCE_DIR}/${header_rel_path} AND + (NOT location_hh OR EXISTS ${PROJECT_SOURCE_DIR}/${location_hh})) + message(STATUS "Pre-generated ${OUTPUT_REL_PATH},${header_rel_path} available -- skipping bison dependency") + target_sources(${ARG_TARGET} PRIVATE ${PROJECT_SOURCE_DIR}/${OUTPUT_REL_PATH}) + set(BISON_parser_OUTPUTS "${PROJECT_SOURCE_DIR}/${OUTPUT_REL_PATH}") + if(location_hh) + set(location_hh "${PROJECT_SOURCE_DIR}/${location_hh}") + endif() + else() + find_package(BISON 3.8.2 REQUIRED) + BISON_TARGET(parser ${INPUT_PATH} ${OUTPUT_PATH} + DEFINES_FILE ${PROJECT_BINARY_DIR}/${header_rel_path}) + if(location_hh) + set(location_hh "${PROJECT_BINARY_DIR}/${location_hh}") + endif() + endif() + target_sources(${ARG_TARGET} PRIVATE ${BISON_parser_OUTPUTS}) + if(location_hh) + set_property(SOURCE "${location_hh}" PROPERTY GENERATED 1) + target_sources(${ARG_TARGET} PRIVATE ${location_hh}) + endif() + add_flag_if_available_to_source_files(CXX "${BISON_parser_OUTPUTS}" "-Wno-unused-but-set-variable") + endforeach() +endfunction() + +function(add_mochi_or_pregenerated_to_target) + cmake_parse_arguments(PARSE_ARGV 0 ARG "" "FLAGS;TARGET" "DEPENDS;SEARCH_PATHS;SOURCES") + + find_package(ClangFormat) + + set(search_paths "") + foreach(search_path IN LISTS ARG_SEARCH_PATHS) + list(APPEND search_paths "-I\"${search_path}\"") + endforeach() + + foreach(SOURCE IN LISTS ARG_SOURCES) + 
get_path_of_generated_file(SOURCE ${SOURCE}) + + if(EXISTS ${PROJECT_SOURCE_DIR}/${OUTPUT_REL_PATH}) + message(STATUS "Pre-generated ${OUTPUT_REL_PATH} available -- skipping mochi dependency") + target_sources(${ARG_TARGET} PRIVATE ${PROJECT_SOURCE_DIR}/${OUTPUT_REL_PATH}) + else() + if(ClangFormat_FOUND) + add_custom_command( + OUTPUT ${OUTPUT_PATH} + DEPENDS mochi ${SOURCE} ${ARG_DEPENDS} + COMMAND mochi ${ARG_FLAGS} -o ${OUTPUT_PATH} ${INPUT_PATH} ${search_paths} + COMMAND ${ClangFormat_EXECUTABLE} -i ${OUTPUT_PATH} + COMMENT "Generating code ${OUTPUT_REL_PATH}" + ) + else() + add_custom_command( + OUTPUT ${OUTPUT_PATH} + DEPENDS mochi ${SOURCE} ${ARG_DEPENDS} + COMMAND mochi ${ARG_FLAGS} -o ${OUTPUT_PATH} ${INPUT_PATH} ${search_paths} + COMMENT "Generating code ${OUTPUT_REL_PATH}" + ) + endif() + target_sources(${ARG_TARGET} PRIVATE ${OUTPUT_PATH}) + endif() + endforeach() +endfunction() diff --git a/cmake/GeneratedPath.cmake b/cmake/GeneratedPath.cmake new file mode 100644 index 00000000..fe66726b --- /dev/null +++ b/cmake/GeneratedPath.cmake @@ -0,0 +1,24 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +function(get_path_of_generated_file) + cmake_parse_arguments(PARSE_ARGV 0 ARG "" "SOURCE;EXT" "") + + set(INPUT_FILE ${ARG_SOURCE}) + if(NOT IS_ABSOLUTE "${INPUT_FILE}") + set(INPUT_FILE ${CMAKE_CURRENT_SOURCE_DIR}/${INPUT_FILE}) + endif() + file(RELATIVE_PATH INPUT_REL_PATH ${PROJECT_SOURCE_DIR} ${INPUT_FILE}) + get_filename_component(INPUT_REL_PATH ${INPUT_REL_PATH} DIRECTORY) + get_filename_component(INPUT_NAME_WLE ${INPUT_FILE} NAME_WLE) + set(OUTPUT_DIR ${PROJECT_BINARY_DIR}/${INPUT_REL_PATH}) + file(MAKE_DIRECTORY ${OUTPUT_DIR}) + if(ARG_EXT) + set(OUTPUT_NAME "${INPUT_NAME_WLE}.${ARG_EXT}") + else() + set(OUTPUT_NAME ${INPUT_NAME_WLE}) + endif() + set(INPUT_PATH "${INPUT_FILE}" PARENT_SCOPE) + set(OUTPUT_REL_PATH "${INPUT_REL_PATH}/${OUTPUT_NAME}" PARENT_SCOPE) + set(OUTPUT_PATH "${OUTPUT_DIR}/${OUTPUT_NAME}" 
PARENT_SCOPE) +endfunction() diff --git a/cmake/GitVersion.cmake b/cmake/GitVersion.cmake index 0ab27061..bc39bf84 100644 --- a/cmake/GitVersion.cmake +++ b/cmake/GitVersion.cmake @@ -13,7 +13,7 @@ function(git_version) if(GIT_FOUND) execute_process( - COMMAND ${GIT_EXECUTABLE} describe --tags --long + COMMAND ${GIT_EXECUTABLE} describe --tags --long --match "v[0-9]*" WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" RESULT_VARIABLE status_code OUTPUT_VARIABLE output diff --git a/cmake/tinytc-config.cmake b/cmake/tinytc-config.cmake index a7a1126b..72ca66f0 100644 --- a/cmake/tinytc-config.cmake +++ b/cmake/tinytc-config.cmake @@ -10,10 +10,6 @@ function(load_dependencies type) if ("${type}" STREQUAL "shared") set(is_shared ON) endif () - if (NOT DEFINED clir_SHARED_LIBS) - set(clir_SHARED_LIBS ${is_shared}) - endif () - find_dependency(clir REQUIRED) endfunction() @SHARED_STATIC_TEMPLATE@ diff --git a/docs/CMakeLists.txt b/docs/CMakeLists.txt index 388abf9d..0aa8b9bf 100644 --- a/docs/CMakeLists.txt +++ b/docs/CMakeLists.txt @@ -51,6 +51,7 @@ set(DOC_FILES api/ze/capi.rst api/ze/cxxapi.rst api/ze/index.rst + dev/coopmatrix_layout.rst manual/build.rst manual/builder.rst manual/calling_convention.rst diff --git a/docs/Doxyfile.in b/docs/Doxyfile.in index faae40d4..bbedf22e 100644 --- a/docs/Doxyfile.in +++ b/docs/Doxyfile.in @@ -1370,15 +1370,6 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_TIMESTAMP = NO - # If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML # documentation will contain a main index with vertical navigation menus that # are dynamically created via JavaScript. 
If disabled, the navigation index will @@ -1703,17 +1694,6 @@ HTML_FORMULA_FORMAT = png FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANSPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. - -FORMULA_TRANSPARENT = YES - # The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands # to create new LaTeX commands to be used in formulas as building blocks. See # the section "Including formulas" for details. @@ -2050,14 +2030,6 @@ LATEX_HIDE_INDICES = NO LATEX_BIB_STYLE = plain -# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated -# page will contain the date and time when the page was generated. Setting this -# to NO can help when comparing the output of multiple runs. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_TIMESTAMP = NO - # The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) # path from which the emoji images will be read. If a relative path is entered, # it will be relative to the LATEX_OUTPUT directory. If left blank the @@ -2429,23 +2401,6 @@ HAVE_DOT = NO DOT_NUM_THREADS = 0 -# When you want a differently looking font in the dot files that doxygen -# generates you can specify the font name using DOT_FONTNAME. You need to make -# sure dot is able to find the font, which can be done by putting it in a -# standard location or by setting the DOTFONTPATH environment variable or by -# setting DOT_FONTPATH to the directory containing the font. -# The default value is: Helvetica. -# This tag requires that the tag HAVE_DOT is set to YES. 
- -DOT_FONTNAME = Helvetica - -# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of -# dot graphs. -# Minimum value: 4, maximum value: 24, default value: 10. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_FONTSIZE = 10 - # By default doxygen will tell dot to use the default font as specified with # DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set # the path where dot can find it using this tag. @@ -2692,18 +2647,6 @@ DOT_GRAPH_MAX_NODES = 50 MAX_DOT_GRAPH_DEPTH = 0 -# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent -# background. This is disabled by default, because dot on Windows does not seem -# to support this out of the box. -# -# Warning: Depending on the platform used, enabling this option may lead to -# badly anti-aliased labels on the edges of a graph (i.e. they become hard to -# read). -# The default value is: NO. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_TRANSPARENT = NO - # Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output # files in one run (i.e. multiple -o and -T options on the command line). This # makes dot run faster, but since only newer versions of dot (>1.8.10) support diff --git a/docs/api/api_gen.py b/docs/api/api_gen.py index 01a7057a..939e38c5 100755 --- a/docs/api/api_gen.py +++ b/docs/api/api_gen.py @@ -4,6 +4,7 @@ from argparse import ArgumentParser from yaml import load, dump, Loader, Dumper +import re parser = ArgumentParser() parser.add_argument('input_yaml') @@ -35,6 +36,13 @@ def escape_ref(symbol): symbol = symbol.replace('>', '\\>') return symbol.replace('*', '\\*') +def get_label_and_title(rst_title): + m = re.match('([^<]+) <([^>]+)>', rst_title) + if m: + return m.groups() + return (rst_title, rst_title) + + api = dict() with open(args.input_yaml, 'r') as y: api = load(y, Loader) @@ -42,17 +50,21 @@ def escape_ref(symbol): with open(args.output_rst, 'w') as f: f.write('.. 
Copyright (C) 2024 Intel Corporation\n') f.write(' SPDX-License-Identifier: BSD-3-Clause\n\n') + for rst_title, categories in api.items(): - f.write('=' * len(rst_title) + '\n') - write_underline(f, rst_title, '=') + title_text, title_label = get_label_and_title(rst_title) + f.write(f'.. _{title_label}:\n\n') + f.write('=' * len(title_text) + '\n') + write_underline(f, title_text, '=') for category_name, category in categories.items(): write_underline(f, category_name, '=') for symbol_type, symbol_list in category.items(): f.write(f'* {title(symbol_type).title()}s\n\n') for symbol in symbol_list: - f.write(f' * :ref:`{escape_ref(strip_symbol_name(symbol))}`\n\n') + f.write(f' * :ref:`{escape_ref(symbol)}`\n\n') for symbol_type, symbol_list in category.items(): write_underline(f, f'{category_name} {title(symbol_type).title()}s', '-') for symbol in symbol_list: + f.write(f'.. _{escape_ref(symbol)}:\n\n') write_underline(f, strip_symbol_name(symbol), '.') f.write(f'.. doxygen{symbol_type}:: {symbol}\n\n') diff --git a/docs/api/builder_capi.rst b/docs/api/builder_capi.rst index 6af8efdd..b6e45985 100644 --- a/docs/api/builder_capi.rst +++ b/docs/api/builder_capi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. 
_Builder C-API: + ============= Builder C-API ============= @@ -10,13 +12,17 @@ Common * Enumerations - * :ref:`tinytc_arithmetic_t` + * :ref:`tinytc_address_space_t` + + * :ref:`tinytc_checked_flag_t` - * :ref:`tinytc_arithmetic_unary_t` + * :ref:`tinytc_comp3_t` - * :ref:`tinytc_cmp_condition_t` + * :ref:`tinytc_matrix_use_t` - * :ref:`tinytc_scalar_type_t` + * :ref:`tinytc_reduce_mode_t` + + * :ref:`tinytc_store_flag_t` * :ref:`tinytc_transpose_t` @@ -26,74 +32,110 @@ Common * Functions - * :ref:`tinytc_arithmetic_to_string` + * :ref:`tinytc_address_space_to_string` + + * :ref:`tinytc_checked_flag_to_string` - * :ref:`tinytc_arithmetic_unary_to_string` + * :ref:`tinytc_comp3_to_string` - * :ref:`tinytc_cmp_condition_to_string` + * :ref:`tinytc_matrix_use_to_string` - * :ref:`tinytc_scalar_type_size` + * :ref:`tinytc_reduce_mode_to_string` - * :ref:`tinytc_scalar_type_to_string` + * :ref:`tinytc_store_flag_to_string` * :ref:`tinytc_transpose_to_string` * Structures - * :ref:`tinytc_position` + * :ref:`tinytc_named_attr` * :ref:`tinytc_location` + * :ref:`tinytc_position` + * Typedefs - * :ref:`tinytc_data_type_t` + * :ref:`tinytc_address_spaces_t` + + * :ref:`tinytc_attr_t` * :ref:`tinytc_func_t` + * :ref:`tinytc_named_attr_t` + * :ref:`tinytc_location_t` * :ref:`tinytc_position_t` - * :ref:`tinytc_prog_t` - * :ref:`tinytc_inst_t` + * :ref:`tinytc_inst_iterator_t` + * :ref:`tinytc_region_t` + * :ref:`tinytc_type_t` + * :ref:`tinytc_value_t` - * :ref:`const_tinytc_data_type_t` + * :ref:`const_tinytc_attr_t` * :ref:`const_tinytc_func_t` * :ref:`const_tinytc_inst_t` - * :ref:`const_tinytc_prog_t` - * :ref:`const_tinytc_region_t` + * :ref:`const_tinytc_type_t` + + * :ref:`const_tinytc_value_t` + Common Enumerations ------------------- -tinytc_arithmetic_t -................... +.. _tinytc_address_space_t: -.. doxygenenum:: tinytc_arithmetic_t +tinytc_address_space_t +...................... -tinytc_arithmetic_unary_t -......................... +.. 
doxygenenum:: tinytc_address_space_t -.. doxygenenum:: tinytc_arithmetic_unary_t +.. _tinytc_checked_flag_t: -tinytc_cmp_condition_t -...................... +tinytc_checked_flag_t +..................... + +.. doxygenenum:: tinytc_checked_flag_t + +.. _tinytc_comp3_t: + +tinytc_comp3_t +.............. -.. doxygenenum:: tinytc_cmp_condition_t +.. doxygenenum:: tinytc_comp3_t -tinytc_scalar_type_t +.. _tinytc_matrix_use_t: + +tinytc_matrix_use_t +................... + +.. doxygenenum:: tinytc_matrix_use_t + +.. _tinytc_reduce_mode_t: + +tinytc_reduce_mode_t .................... -.. doxygenenum:: tinytc_scalar_type_t +.. doxygenenum:: tinytc_reduce_mode_t + +.. _tinytc_store_flag_t: + +tinytc_store_flag_t +................... + +.. doxygenenum:: tinytc_store_flag_t + +.. _tinytc_transpose_t: tinytc_transpose_t .................. @@ -103,6 +145,8 @@ tinytc_transpose_t Common Definitions ------------------ +.. _TINYTC_DYNAMIC: + TINYTC_DYNAMIC .............. @@ -111,30 +155,49 @@ TINYTC_DYNAMIC Common Functions ---------------- -tinytc_arithmetic_to_string -........................... +.. _tinytc_address_space_to_string: -.. doxygenfunction:: tinytc_arithmetic_to_string +tinytc_address_space_to_string +.............................. -tinytc_arithmetic_unary_to_string -................................. +.. doxygenfunction:: tinytc_address_space_to_string -.. doxygenfunction:: tinytc_arithmetic_unary_to_string +.. _tinytc_checked_flag_to_string: -tinytc_cmp_condition_to_string -.............................. +tinytc_checked_flag_to_string +............................. -.. doxygenfunction:: tinytc_cmp_condition_to_string +.. doxygenfunction:: tinytc_checked_flag_to_string -tinytc_scalar_type_size -....................... +.. _tinytc_comp3_to_string: + +tinytc_comp3_to_string +...................... + +.. doxygenfunction:: tinytc_comp3_to_string + +.. _tinytc_matrix_use_to_string: + +tinytc_matrix_use_to_string +........................... + +.. 
doxygenfunction:: tinytc_matrix_use_to_string -.. doxygenfunction:: tinytc_scalar_type_size +.. _tinytc_reduce_mode_to_string: -tinytc_scalar_type_to_string +tinytc_reduce_mode_to_string ............................ -.. doxygenfunction:: tinytc_scalar_type_to_string +.. doxygenfunction:: tinytc_reduce_mode_to_string + +.. _tinytc_store_flag_to_string: + +tinytc_store_flag_to_string +........................... + +.. doxygenfunction:: tinytc_store_flag_to_string + +.. _tinytc_transpose_to_string: tinytc_transpose_to_string .......................... @@ -144,507 +207,1426 @@ tinytc_transpose_to_string Common Structures ----------------- -tinytc_position -............... +.. _tinytc_named_attr: -.. doxygenstruct:: tinytc_position +tinytc_named_attr +................. + +.. doxygenstruct:: tinytc_named_attr + +.. _tinytc_location: tinytc_location ............... .. doxygenstruct:: tinytc_location +.. _tinytc_position: + +tinytc_position +............... + +.. doxygenstruct:: tinytc_position + Common Typedefs --------------- -tinytc_data_type_t -.................. +.. _tinytc_address_spaces_t: + +tinytc_address_spaces_t +....................... + +.. doxygentypedef:: tinytc_address_spaces_t + +.. _tinytc_attr_t: + +tinytc_attr_t +............. -.. doxygentypedef:: tinytc_data_type_t +.. doxygentypedef:: tinytc_attr_t + +.. _tinytc_func_t: tinytc_func_t ............. .. doxygentypedef:: tinytc_func_t +.. _tinytc_named_attr_t: + +tinytc_named_attr_t +................... + +.. doxygentypedef:: tinytc_named_attr_t + +.. _tinytc_location_t: + tinytc_location_t ................. .. doxygentypedef:: tinytc_location_t +.. _tinytc_position_t: + tinytc_position_t ................. .. doxygentypedef:: tinytc_position_t -tinytc_prog_t -............. - -.. doxygentypedef:: tinytc_prog_t +.. _tinytc_inst_t: tinytc_inst_t ............. .. doxygentypedef:: tinytc_inst_t +.. _tinytc_inst_iterator_t: + +tinytc_inst_iterator_t +...................... + +.. 
doxygentypedef:: tinytc_inst_iterator_t + +.. _tinytc_region_t: + tinytc_region_t ............... .. doxygentypedef:: tinytc_region_t +.. _tinytc_type_t: + +tinytc_type_t +............. + +.. doxygentypedef:: tinytc_type_t + +.. _tinytc_value_t: + tinytc_value_t .............. .. doxygentypedef:: tinytc_value_t -const_tinytc_data_type_t -........................ +.. _const_tinytc_attr_t: -.. doxygentypedef:: const_tinytc_data_type_t +const_tinytc_attr_t +................... + +.. doxygentypedef:: const_tinytc_attr_t + +.. _const_tinytc_func_t: const_tinytc_func_t ................... .. doxygentypedef:: const_tinytc_func_t +.. _const_tinytc_inst_t: + const_tinytc_inst_t ................... .. doxygentypedef:: const_tinytc_inst_t -const_tinytc_prog_t -................... - -.. doxygentypedef:: const_tinytc_prog_t +.. _const_tinytc_region_t: const_tinytc_region_t ..................... .. doxygentypedef:: const_tinytc_region_t -Data Type -========= - -* Functions - - * :ref:`tinytc_group_type_create` +.. _const_tinytc_type_t: - * :ref:`tinytc_memref_type_create` +const_tinytc_type_t +................... - * :ref:`tinytc_scalar_type_create` +.. doxygentypedef:: const_tinytc_type_t - * :ref:`tinytc_data_type_release` +.. _const_tinytc_value_t: - * :ref:`tinytc_data_type_retain` +const_tinytc_value_t +.................... -Data Type Functions -------------------- +.. doxygentypedef:: const_tinytc_value_t -tinytc_group_type_create -........................ +Attribute +========= -.. doxygenfunction:: tinytc_group_type_create +* Functions -tinytc_memref_type_create -......................... + * :ref:`tinytc_array_attr_get` -.. doxygenfunction:: tinytc_memref_type_create + * :ref:`tinytc_boolean_attr_get` -tinytc_scalar_type_create -......................... + * :ref:`tinytc_dictionary_attr_get` -.. doxygenfunction:: tinytc_scalar_type_create + * :ref:`tinytc_dictionary_attr_get_with_sorted` -tinytc_data_type_release -........................ 
+ * :ref:`tinytc_dictionary_attr_sort` -.. doxygenfunction:: tinytc_data_type_release + * :ref:`tinytc_integer_attr_get` -tinytc_data_type_retain -....................... + * :ref:`tinytc_string_attr_get` -.. doxygenfunction:: tinytc_data_type_retain +Attribute Functions +------------------- -Function -======== +.. _tinytc_array_attr_get: -* Functions +tinytc_array_attr_get +..................... - * :ref:`tinytc_function_create` +.. doxygenfunction:: tinytc_array_attr_get - * :ref:`tinytc_function_prototype_create` +.. _tinytc_boolean_attr_get: - * :ref:`tinytc_function_set_subgroup_size` +tinytc_boolean_attr_get +....................... - * :ref:`tinytc_function_set_work_group_size` +.. doxygenfunction:: tinytc_boolean_attr_get - * :ref:`tinytc_func_release` +.. _tinytc_dictionary_attr_get: - * :ref:`tinytc_func_retain` +tinytc_dictionary_attr_get +.......................... -Function Functions ------------------- +.. doxygenfunction:: tinytc_dictionary_attr_get -tinytc_function_create -...................... +.. _tinytc_dictionary_attr_get_with_sorted: -.. doxygenfunction:: tinytc_function_create +tinytc_dictionary_attr_get_with_sorted +...................................... -tinytc_function_prototype_create -................................ +.. doxygenfunction:: tinytc_dictionary_attr_get_with_sorted -.. doxygenfunction:: tinytc_function_prototype_create +.. _tinytc_dictionary_attr_sort: -tinytc_function_set_subgroup_size -................................. +tinytc_dictionary_attr_sort +........................... -.. doxygenfunction:: tinytc_function_set_subgroup_size +.. doxygenfunction:: tinytc_dictionary_attr_sort -tinytc_function_set_work_group_size -................................... +.. _tinytc_integer_attr_get: -.. doxygenfunction:: tinytc_function_set_work_group_size +tinytc_integer_attr_get +....................... -tinytc_func_release -................... +.. doxygenfunction:: tinytc_integer_attr_get -.. doxygenfunction:: tinytc_func_release +.. 
_tinytc_string_attr_get: -tinytc_func_retain -.................. +tinytc_string_attr_get +...................... -.. doxygenfunction:: tinytc_func_retain +.. doxygenfunction:: tinytc_string_attr_get -Instruction -=========== +Data Type +========= * Functions - * :ref:`tinytc_alloca_inst_create` - - * :ref:`tinytc_axpby_inst_create` - - * :ref:`tinytc_arith_inst_create` + * :ref:`tinytc_boolean_type_get` - * :ref:`tinytc_arith_unary_inst_create` + * :ref:`tinytc_i8_type_get` - * :ref:`tinytc_cast_inst_create` + * :ref:`tinytc_i16_type_get` - * :ref:`tinytc_cmp_inst_create` + * :ref:`tinytc_i32_type_get` - * :ref:`tinytc_expand_inst_create` + * :ref:`tinytc_i64_type_get` - * :ref:`tinytc_for_inst_create` + * :ref:`tinytc_index_type_get` - * :ref:`tinytc_foreach_inst_create` + * :ref:`tinytc_f16_type_get` - * :ref:`tinytc_fuse_inst_create` + * :ref:`tinytc_bf16_type_get` - * :ref:`tinytc_gemm_inst_create` + * :ref:`tinytc_f32_type_get` - * :ref:`tinytc_gemv_inst_create` + * :ref:`tinytc_f64_type_get` - * :ref:`tinytc_ger_inst_create` + * :ref:`tinytc_c32_type_get` - * :ref:`tinytc_group_id_inst_create` + * :ref:`tinytc_c64_type_get` - * :ref:`tinytc_group_size_inst_create` + * :ref:`tinytc_coopmatrix_type_get` - * :ref:`tinytc_hadamard_inst_create` + * :ref:`tinytc_group_type_get` - * :ref:`tinytc_if_inst_create` + * :ref:`tinytc_memref_type_get` - * :ref:`tinytc_load_inst_create` + * :ref:`tinytc_void_type_get` - * :ref:`tinytc_size_inst_create` + * :ref:`tinytc_type_get_compiler_context` - * :ref:`tinytc_store_inst_create` +Data Type Functions +------------------- - * :ref:`tinytc_subview_inst_create` +.. _tinytc_boolean_type_get: - * :ref:`tinytc_sum_inst_create` +tinytc_boolean_type_get +....................... - * :ref:`tinytc_yield_inst_create` +.. doxygenfunction:: tinytc_boolean_type_get - * :ref:`tinytc_inst_get_value` +.. _tinytc_i8_type_get: - * :ref:`tinytc_inst_get_values` +tinytc_i8_type_get +.................. - * :ref:`tinytc_inst_release` +.. 
doxygenfunction:: tinytc_i8_type_get - * :ref:`tinytc_inst_retain` +.. _tinytc_i16_type_get: -Instruction Functions ---------------------- +tinytc_i16_type_get +................... -tinytc_alloca_inst_create -......................... +.. doxygenfunction:: tinytc_i16_type_get -.. doxygenfunction:: tinytc_alloca_inst_create +.. _tinytc_i32_type_get: -tinytc_axpby_inst_create -........................ +tinytc_i32_type_get +................... -.. doxygenfunction:: tinytc_axpby_inst_create +.. doxygenfunction:: tinytc_i32_type_get -tinytc_arith_inst_create -........................ +.. _tinytc_i64_type_get: -.. doxygenfunction:: tinytc_arith_inst_create +tinytc_i64_type_get +................... -tinytc_arith_unary_inst_create -.............................. +.. doxygenfunction:: tinytc_i64_type_get -.. doxygenfunction:: tinytc_arith_unary_inst_create +.. _tinytc_index_type_get: -tinytc_cast_inst_create -....................... +tinytc_index_type_get +..................... -.. doxygenfunction:: tinytc_cast_inst_create +.. doxygenfunction:: tinytc_index_type_get -tinytc_cmp_inst_create -...................... +.. _tinytc_f16_type_get: -.. doxygenfunction:: tinytc_cmp_inst_create +tinytc_f16_type_get +................... -tinytc_expand_inst_create -......................... +.. doxygenfunction:: tinytc_f16_type_get -.. doxygenfunction:: tinytc_expand_inst_create +.. _tinytc_bf16_type_get: -tinytc_for_inst_create -...................... +tinytc_bf16_type_get +.................... -.. doxygenfunction:: tinytc_for_inst_create +.. doxygenfunction:: tinytc_bf16_type_get -tinytc_foreach_inst_create -.......................... +.. _tinytc_f32_type_get: -.. doxygenfunction:: tinytc_foreach_inst_create +tinytc_f32_type_get +................... -tinytc_fuse_inst_create -....................... +.. doxygenfunction:: tinytc_f32_type_get -.. doxygenfunction:: tinytc_fuse_inst_create +.. _tinytc_f64_type_get: -tinytc_gemm_inst_create -....................... 
+tinytc_f64_type_get +................... -.. doxygenfunction:: tinytc_gemm_inst_create +.. doxygenfunction:: tinytc_f64_type_get -tinytc_gemv_inst_create -....................... +.. _tinytc_c32_type_get: -.. doxygenfunction:: tinytc_gemv_inst_create +tinytc_c32_type_get +................... -tinytc_ger_inst_create -...................... +.. doxygenfunction:: tinytc_c32_type_get -.. doxygenfunction:: tinytc_ger_inst_create +.. _tinytc_c64_type_get: -tinytc_group_id_inst_create -........................... +tinytc_c64_type_get +................... -.. doxygenfunction:: tinytc_group_id_inst_create +.. doxygenfunction:: tinytc_c64_type_get -tinytc_group_size_inst_create -............................. +.. _tinytc_coopmatrix_type_get: -.. doxygenfunction:: tinytc_group_size_inst_create +tinytc_coopmatrix_type_get +.......................... -tinytc_hadamard_inst_create -........................... +.. doxygenfunction:: tinytc_coopmatrix_type_get -.. doxygenfunction:: tinytc_hadamard_inst_create +.. _tinytc_group_type_get: -tinytc_if_inst_create +tinytc_group_type_get ..................... -.. doxygenfunction:: tinytc_if_inst_create +.. doxygenfunction:: tinytc_group_type_get -tinytc_load_inst_create -....................... +.. _tinytc_memref_type_get: -.. doxygenfunction:: tinytc_load_inst_create +tinytc_memref_type_get +...................... -tinytc_size_inst_create -....................... +.. doxygenfunction:: tinytc_memref_type_get -.. doxygenfunction:: tinytc_size_inst_create +.. _tinytc_void_type_get: -tinytc_store_inst_create -........................ +tinytc_void_type_get +.................... -.. doxygenfunction:: tinytc_store_inst_create +.. doxygenfunction:: tinytc_void_type_get -tinytc_subview_inst_create -.......................... +.. _tinytc_type_get_compiler_context: -.. doxygenfunction:: tinytc_subview_inst_create +tinytc_type_get_compiler_context +................................ -tinytc_sum_inst_create -...................... +.. 
doxygenfunction:: tinytc_type_get_compiler_context -.. doxygenfunction:: tinytc_sum_inst_create +Function +======== -tinytc_yield_inst_create -........................ +* Functions -.. doxygenfunction:: tinytc_yield_inst_create + * :ref:`tinytc_func_create` -tinytc_inst_get_value -..................... + * :ref:`tinytc_func_destroy` -.. doxygenfunction:: tinytc_inst_get_value + * :ref:`tinytc_func_get_body` -tinytc_inst_get_values -...................... + * :ref:`tinytc_func_set_attr` -.. doxygenfunction:: tinytc_inst_get_values + * :ref:`tinytc_func_set_parameter_attr` -tinytc_inst_release -................... +Function Functions +------------------ -.. doxygenfunction:: tinytc_inst_release +.. _tinytc_func_create: -tinytc_inst_retain +tinytc_func_create .................. -.. doxygenfunction:: tinytc_inst_retain +.. doxygenfunction:: tinytc_func_create -Program -======= +.. _tinytc_func_destroy: -* Functions +tinytc_func_destroy +................... - * :ref:`tinytc_program_create` +.. doxygenfunction:: tinytc_func_destroy - * :ref:`tinytc_prog_dump` +.. _tinytc_func_get_body: - * :ref:`tinytc_prog_print_to_file` +tinytc_func_get_body +.................... - * :ref:`tinytc_prog_print_to_string` +.. doxygenfunction:: tinytc_func_get_body - * :ref:`tinytc_prog_release` +.. _tinytc_func_set_attr: - * :ref:`tinytc_prog_retain` +tinytc_func_set_attr +.................... -Program Functions ------------------ +.. doxygenfunction:: tinytc_func_set_attr -tinytc_program_create +.. _tinytc_func_set_parameter_attr: + +tinytc_func_set_parameter_attr +.............................. + +.. 
doxygenfunction:: tinytc_func_set_parameter_attr + +Instruction +=========== + +* Functions + + * :ref:`tinytc_alloca_inst_create` + + * :ref:`tinytc_barrier_inst_create` + + * :ref:`tinytc_cast_inst_create` + + * :ref:`tinytc_constant_inst_create_boolean` + + * :ref:`tinytc_constant_inst_create_complex` + + * :ref:`tinytc_constant_inst_create_float` + + * :ref:`tinytc_constant_inst_create_int` + + * :ref:`tinytc_constant_inst_create_one` + + * :ref:`tinytc_constant_inst_create_zero` + + * :ref:`tinytc_cooperative_matrix_apply_inst_create` + + * :ref:`tinytc_cooperative_matrix_extract_inst_create` + + * :ref:`tinytc_cooperative_matrix_insert_inst_create` + + * :ref:`tinytc_cooperative_matrix_load_inst_create` + + * :ref:`tinytc_cooperative_matrix_mul_add_inst_create` + + * :ref:`tinytc_cooperative_matrix_reduce_add_inst_create` + + * :ref:`tinytc_cooperative_matrix_reduce_max_inst_create` + + * :ref:`tinytc_cooperative_matrix_reduce_min_inst_create` + + * :ref:`tinytc_cooperative_matrix_prefetch_inst_create` + + * :ref:`tinytc_cooperative_matrix_scale_inst_create` + + * :ref:`tinytc_cooperative_matrix_store_inst_create` + + * :ref:`tinytc_expand_inst_create` + + * :ref:`tinytc_fuse_inst_create` + + * :ref:`tinytc_if_inst_create` + + * :ref:`tinytc_lifetime_stop_inst_create` + + * :ref:`tinytc_load_inst_create` + + * :ref:`tinytc_parallel_inst_create` + + * :ref:`tinytc_size_inst_create` + + * :ref:`tinytc_subgroup_broadcast_inst_create` + + * :ref:`tinytc_store_inst_create` + + * :ref:`tinytc_subview_inst_create` + + * :ref:`tinytc_yield_inst_create` + + * :ref:`tinytc_add_inst_create` + + * :ref:`tinytc_sub_inst_create` + + * :ref:`tinytc_mul_inst_create` + + * :ref:`tinytc_div_inst_create` + + * :ref:`tinytc_rem_inst_create` + + * :ref:`tinytc_shl_inst_create` + + * :ref:`tinytc_shr_inst_create` + + * :ref:`tinytc_and_inst_create` + + * :ref:`tinytc_or_inst_create` + + * :ref:`tinytc_xor_inst_create` + + * :ref:`tinytc_min_inst_create` + + * 
:ref:`tinytc_max_inst_create` + + * :ref:`tinytc_abs_inst_create` + + * :ref:`tinytc_neg_inst_create` + + * :ref:`tinytc_not_inst_create` + + * :ref:`tinytc_conj_inst_create` + + * :ref:`tinytc_im_inst_create` + + * :ref:`tinytc_re_inst_create` + + * :ref:`tinytc_axpby_inst_create` + + * :ref:`tinytc_cumsum_inst_create` + + * :ref:`tinytc_sum_inst_create` + + * :ref:`tinytc_gemm_inst_create` + + * :ref:`tinytc_gemv_inst_create` + + * :ref:`tinytc_ger_inst_create` + + * :ref:`tinytc_hadamard_inst_create` + + * :ref:`tinytc_group_id_inst_create` + + * :ref:`tinytc_num_groups_inst_create` + + * :ref:`tinytc_num_subgroups_inst_create` + + * :ref:`tinytc_subgroup_size_inst_create` + + * :ref:`tinytc_subgroup_id_inst_create` + + * :ref:`tinytc_subgroup_linear_id_inst_create` + + * :ref:`tinytc_subgroup_local_id_inst_create` + + * :ref:`tinytc_cos_inst_create` + + * :ref:`tinytc_sin_inst_create` + + * :ref:`tinytc_exp_inst_create` + + * :ref:`tinytc_exp2_inst_create` + + * :ref:`tinytc_native_cos_inst_create` + + * :ref:`tinytc_native_sin_inst_create` + + * :ref:`tinytc_native_exp_inst_create` + + * :ref:`tinytc_native_exp2_inst_create` + + * :ref:`tinytc_subgroup_exclusive_scan_add_inst_create` + + * :ref:`tinytc_subgroup_exclusive_scan_max_inst_create` + + * :ref:`tinytc_subgroup_exclusive_scan_min_inst_create` + + * :ref:`tinytc_subgroup_inclusive_scan_add_inst_create` + + * :ref:`tinytc_subgroup_inclusive_scan_max_inst_create` + + * :ref:`tinytc_subgroup_inclusive_scan_min_inst_create` + + * :ref:`tinytc_subgroup_reduce_add_inst_create` + + * :ref:`tinytc_subgroup_reduce_max_inst_create` + + * :ref:`tinytc_subgroup_reduce_min_inst_create` + + * :ref:`tinytc_equal_inst_create` + + * :ref:`tinytc_not_equal_inst_create` + + * :ref:`tinytc_greater_than_inst_create` + + * :ref:`tinytc_greater_than_equal_inst_create` + + * :ref:`tinytc_less_than_inst_create` + + * :ref:`tinytc_less_than_equal_inst_create` + + * :ref:`tinytc_for_inst_create` + + * 
:ref:`tinytc_foreach_inst_create` + + * :ref:`tinytc_inst_get_parent_region` + + * :ref:`tinytc_inst_get_regions` + + * :ref:`tinytc_inst_get_values` + + * :ref:`tinytc_inst_destroy` + + * :ref:`tinytc_inst_set_attr` + +Instruction Functions +--------------------- + +.. _tinytc_alloca_inst_create: + +tinytc_alloca_inst_create +......................... + +.. doxygenfunction:: tinytc_alloca_inst_create + +.. _tinytc_barrier_inst_create: + +tinytc_barrier_inst_create +.......................... + +.. doxygenfunction:: tinytc_barrier_inst_create + +.. _tinytc_cast_inst_create: + +tinytc_cast_inst_create +....................... + +.. doxygenfunction:: tinytc_cast_inst_create + +.. _tinytc_constant_inst_create_boolean: + +tinytc_constant_inst_create_boolean +................................... + +.. doxygenfunction:: tinytc_constant_inst_create_boolean + +.. _tinytc_constant_inst_create_complex: + +tinytc_constant_inst_create_complex +................................... + +.. doxygenfunction:: tinytc_constant_inst_create_complex + +.. _tinytc_constant_inst_create_float: + +tinytc_constant_inst_create_float +................................. + +.. doxygenfunction:: tinytc_constant_inst_create_float + +.. _tinytc_constant_inst_create_int: + +tinytc_constant_inst_create_int +............................... + +.. doxygenfunction:: tinytc_constant_inst_create_int + +.. _tinytc_constant_inst_create_one: + +tinytc_constant_inst_create_one +............................... + +.. doxygenfunction:: tinytc_constant_inst_create_one + +.. _tinytc_constant_inst_create_zero: + +tinytc_constant_inst_create_zero +................................ + +.. doxygenfunction:: tinytc_constant_inst_create_zero + +.. _tinytc_cooperative_matrix_apply_inst_create: + +tinytc_cooperative_matrix_apply_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_apply_inst_create + +.. 
_tinytc_cooperative_matrix_extract_inst_create: + +tinytc_cooperative_matrix_extract_inst_create +............................................. + +.. doxygenfunction:: tinytc_cooperative_matrix_extract_inst_create + +.. _tinytc_cooperative_matrix_insert_inst_create: + +tinytc_cooperative_matrix_insert_inst_create +............................................ + +.. doxygenfunction:: tinytc_cooperative_matrix_insert_inst_create + +.. _tinytc_cooperative_matrix_load_inst_create: + +tinytc_cooperative_matrix_load_inst_create +.......................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_load_inst_create + +.. _tinytc_cooperative_matrix_mul_add_inst_create: + +tinytc_cooperative_matrix_mul_add_inst_create +............................................. + +.. doxygenfunction:: tinytc_cooperative_matrix_mul_add_inst_create + +.. _tinytc_cooperative_matrix_reduce_add_inst_create: + +tinytc_cooperative_matrix_reduce_add_inst_create +................................................ + +.. doxygenfunction:: tinytc_cooperative_matrix_reduce_add_inst_create + +.. _tinytc_cooperative_matrix_reduce_max_inst_create: + +tinytc_cooperative_matrix_reduce_max_inst_create +................................................ + +.. doxygenfunction:: tinytc_cooperative_matrix_reduce_max_inst_create + +.. _tinytc_cooperative_matrix_reduce_min_inst_create: + +tinytc_cooperative_matrix_reduce_min_inst_create +................................................ + +.. doxygenfunction:: tinytc_cooperative_matrix_reduce_min_inst_create + +.. _tinytc_cooperative_matrix_prefetch_inst_create: + +tinytc_cooperative_matrix_prefetch_inst_create +.............................................. + +.. doxygenfunction:: tinytc_cooperative_matrix_prefetch_inst_create + +.. _tinytc_cooperative_matrix_scale_inst_create: + +tinytc_cooperative_matrix_scale_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_scale_inst_create + +.. 
_tinytc_cooperative_matrix_store_inst_create: + +tinytc_cooperative_matrix_store_inst_create +........................................... + +.. doxygenfunction:: tinytc_cooperative_matrix_store_inst_create + +.. _tinytc_expand_inst_create: + +tinytc_expand_inst_create +......................... + +.. doxygenfunction:: tinytc_expand_inst_create + +.. _tinytc_fuse_inst_create: + +tinytc_fuse_inst_create +....................... + +.. doxygenfunction:: tinytc_fuse_inst_create + +.. _tinytc_if_inst_create: + +tinytc_if_inst_create +..................... + +.. doxygenfunction:: tinytc_if_inst_create + +.. _tinytc_lifetime_stop_inst_create: + +tinytc_lifetime_stop_inst_create +................................ + +.. doxygenfunction:: tinytc_lifetime_stop_inst_create + +.. _tinytc_load_inst_create: + +tinytc_load_inst_create +....................... + +.. doxygenfunction:: tinytc_load_inst_create + +.. _tinytc_parallel_inst_create: + +tinytc_parallel_inst_create +........................... + +.. doxygenfunction:: tinytc_parallel_inst_create + +.. _tinytc_size_inst_create: + +tinytc_size_inst_create +....................... + +.. doxygenfunction:: tinytc_size_inst_create + +.. _tinytc_subgroup_broadcast_inst_create: + +tinytc_subgroup_broadcast_inst_create +..................................... + +.. doxygenfunction:: tinytc_subgroup_broadcast_inst_create + +.. _tinytc_store_inst_create: + +tinytc_store_inst_create +........................ + +.. doxygenfunction:: tinytc_store_inst_create + +.. _tinytc_subview_inst_create: + +tinytc_subview_inst_create +.......................... + +.. doxygenfunction:: tinytc_subview_inst_create + +.. _tinytc_yield_inst_create: + +tinytc_yield_inst_create +........................ + +.. doxygenfunction:: tinytc_yield_inst_create + +.. _tinytc_add_inst_create: + +tinytc_add_inst_create +...................... + +.. doxygenfunction:: tinytc_add_inst_create + +.. _tinytc_sub_inst_create: + +tinytc_sub_inst_create +...................... 
+ +.. doxygenfunction:: tinytc_sub_inst_create + +.. _tinytc_mul_inst_create: + +tinytc_mul_inst_create +...................... + +.. doxygenfunction:: tinytc_mul_inst_create + +.. _tinytc_div_inst_create: + +tinytc_div_inst_create +...................... + +.. doxygenfunction:: tinytc_div_inst_create + +.. _tinytc_rem_inst_create: + +tinytc_rem_inst_create +...................... + +.. doxygenfunction:: tinytc_rem_inst_create + +.. _tinytc_shl_inst_create: + +tinytc_shl_inst_create +...................... + +.. doxygenfunction:: tinytc_shl_inst_create + +.. _tinytc_shr_inst_create: + +tinytc_shr_inst_create +...................... + +.. doxygenfunction:: tinytc_shr_inst_create + +.. _tinytc_and_inst_create: + +tinytc_and_inst_create +...................... + +.. doxygenfunction:: tinytc_and_inst_create + +.. _tinytc_or_inst_create: + +tinytc_or_inst_create +..................... + +.. doxygenfunction:: tinytc_or_inst_create + +.. _tinytc_xor_inst_create: + +tinytc_xor_inst_create +...................... + +.. doxygenfunction:: tinytc_xor_inst_create + +.. _tinytc_min_inst_create: + +tinytc_min_inst_create +...................... + +.. doxygenfunction:: tinytc_min_inst_create + +.. _tinytc_max_inst_create: + +tinytc_max_inst_create +...................... + +.. doxygenfunction:: tinytc_max_inst_create + +.. _tinytc_abs_inst_create: + +tinytc_abs_inst_create +...................... + +.. doxygenfunction:: tinytc_abs_inst_create + +.. _tinytc_neg_inst_create: + +tinytc_neg_inst_create +...................... + +.. doxygenfunction:: tinytc_neg_inst_create + +.. _tinytc_not_inst_create: + +tinytc_not_inst_create +...................... + +.. doxygenfunction:: tinytc_not_inst_create + +.. _tinytc_conj_inst_create: + +tinytc_conj_inst_create +....................... + +.. doxygenfunction:: tinytc_conj_inst_create + +.. _tinytc_im_inst_create: + +tinytc_im_inst_create ..................... -.. doxygenfunction:: tinytc_program_create +.. 
doxygenfunction:: tinytc_im_inst_create -tinytc_prog_dump -................ +.. _tinytc_re_inst_create: + +tinytc_re_inst_create +..................... + +.. doxygenfunction:: tinytc_re_inst_create + +.. _tinytc_axpby_inst_create: + +tinytc_axpby_inst_create +........................ + +.. doxygenfunction:: tinytc_axpby_inst_create -.. doxygenfunction:: tinytc_prog_dump +.. _tinytc_cumsum_inst_create: -tinytc_prog_print_to_file +tinytc_cumsum_inst_create ......................... -.. doxygenfunction:: tinytc_prog_print_to_file +.. doxygenfunction:: tinytc_cumsum_inst_create + +.. _tinytc_sum_inst_create: + +tinytc_sum_inst_create +...................... + +.. doxygenfunction:: tinytc_sum_inst_create + +.. _tinytc_gemm_inst_create: + +tinytc_gemm_inst_create +....................... + +.. doxygenfunction:: tinytc_gemm_inst_create + +.. _tinytc_gemv_inst_create: + +tinytc_gemv_inst_create +....................... + +.. doxygenfunction:: tinytc_gemv_inst_create + +.. _tinytc_ger_inst_create: + +tinytc_ger_inst_create +...................... + +.. doxygenfunction:: tinytc_ger_inst_create + +.. _tinytc_hadamard_inst_create: + +tinytc_hadamard_inst_create +........................... + +.. doxygenfunction:: tinytc_hadamard_inst_create + +.. _tinytc_group_id_inst_create: -tinytc_prog_print_to_string +tinytc_group_id_inst_create ........................... -.. doxygenfunction:: tinytc_prog_print_to_string +.. doxygenfunction:: tinytc_group_id_inst_create + +.. _tinytc_num_groups_inst_create: + +tinytc_num_groups_inst_create +............................. + +.. doxygenfunction:: tinytc_num_groups_inst_create + +.. _tinytc_num_subgroups_inst_create: + +tinytc_num_subgroups_inst_create +................................ + +.. doxygenfunction:: tinytc_num_subgroups_inst_create + +.. _tinytc_subgroup_size_inst_create: + +tinytc_subgroup_size_inst_create +................................ + +.. doxygenfunction:: tinytc_subgroup_size_inst_create + +.. 
_tinytc_subgroup_id_inst_create: + +tinytc_subgroup_id_inst_create +.............................. + +.. doxygenfunction:: tinytc_subgroup_id_inst_create + +.. _tinytc_subgroup_linear_id_inst_create: + +tinytc_subgroup_linear_id_inst_create +..................................... + +.. doxygenfunction:: tinytc_subgroup_linear_id_inst_create + +.. _tinytc_subgroup_local_id_inst_create: + +tinytc_subgroup_local_id_inst_create +.................................... + +.. doxygenfunction:: tinytc_subgroup_local_id_inst_create + +.. _tinytc_cos_inst_create: + +tinytc_cos_inst_create +...................... + +.. doxygenfunction:: tinytc_cos_inst_create + +.. _tinytc_sin_inst_create: + +tinytc_sin_inst_create +...................... + +.. doxygenfunction:: tinytc_sin_inst_create + +.. _tinytc_exp_inst_create: + +tinytc_exp_inst_create +...................... + +.. doxygenfunction:: tinytc_exp_inst_create + +.. _tinytc_exp2_inst_create: + +tinytc_exp2_inst_create +....................... + +.. doxygenfunction:: tinytc_exp2_inst_create + +.. _tinytc_native_cos_inst_create: + +tinytc_native_cos_inst_create +............................. + +.. doxygenfunction:: tinytc_native_cos_inst_create + +.. _tinytc_native_sin_inst_create: + +tinytc_native_sin_inst_create +............................. + +.. doxygenfunction:: tinytc_native_sin_inst_create + +.. _tinytc_native_exp_inst_create: + +tinytc_native_exp_inst_create +............................. + +.. doxygenfunction:: tinytc_native_exp_inst_create + +.. _tinytc_native_exp2_inst_create: + +tinytc_native_exp2_inst_create +.............................. + +.. doxygenfunction:: tinytc_native_exp2_inst_create + +.. _tinytc_subgroup_exclusive_scan_add_inst_create: + +tinytc_subgroup_exclusive_scan_add_inst_create +.............................................. + +.. doxygenfunction:: tinytc_subgroup_exclusive_scan_add_inst_create + +.. 
_tinytc_subgroup_exclusive_scan_max_inst_create: + +tinytc_subgroup_exclusive_scan_max_inst_create +.............................................. + +.. doxygenfunction:: tinytc_subgroup_exclusive_scan_max_inst_create + +.. _tinytc_subgroup_exclusive_scan_min_inst_create: + +tinytc_subgroup_exclusive_scan_min_inst_create +.............................................. + +.. doxygenfunction:: tinytc_subgroup_exclusive_scan_min_inst_create + +.. _tinytc_subgroup_inclusive_scan_add_inst_create: + +tinytc_subgroup_inclusive_scan_add_inst_create +.............................................. + +.. doxygenfunction:: tinytc_subgroup_inclusive_scan_add_inst_create + +.. _tinytc_subgroup_inclusive_scan_max_inst_create: + +tinytc_subgroup_inclusive_scan_max_inst_create +.............................................. -tinytc_prog_release +.. doxygenfunction:: tinytc_subgroup_inclusive_scan_max_inst_create + +.. _tinytc_subgroup_inclusive_scan_min_inst_create: + +tinytc_subgroup_inclusive_scan_min_inst_create +.............................................. + +.. doxygenfunction:: tinytc_subgroup_inclusive_scan_min_inst_create + +.. _tinytc_subgroup_reduce_add_inst_create: + +tinytc_subgroup_reduce_add_inst_create +...................................... + +.. doxygenfunction:: tinytc_subgroup_reduce_add_inst_create + +.. _tinytc_subgroup_reduce_max_inst_create: + +tinytc_subgroup_reduce_max_inst_create +...................................... + +.. doxygenfunction:: tinytc_subgroup_reduce_max_inst_create + +.. _tinytc_subgroup_reduce_min_inst_create: + +tinytc_subgroup_reduce_min_inst_create +...................................... + +.. doxygenfunction:: tinytc_subgroup_reduce_min_inst_create + +.. _tinytc_equal_inst_create: + +tinytc_equal_inst_create +........................ + +.. doxygenfunction:: tinytc_equal_inst_create + +.. _tinytc_not_equal_inst_create: + +tinytc_not_equal_inst_create +............................ + +.. 
doxygenfunction:: tinytc_not_equal_inst_create + +.. _tinytc_greater_than_inst_create: + +tinytc_greater_than_inst_create +............................... + +.. doxygenfunction:: tinytc_greater_than_inst_create + +.. _tinytc_greater_than_equal_inst_create: + +tinytc_greater_than_equal_inst_create +..................................... + +.. doxygenfunction:: tinytc_greater_than_equal_inst_create + +.. _tinytc_less_than_inst_create: + +tinytc_less_than_inst_create +............................ + +.. doxygenfunction:: tinytc_less_than_inst_create + +.. _tinytc_less_than_equal_inst_create: + +tinytc_less_than_equal_inst_create +.................................. + +.. doxygenfunction:: tinytc_less_than_equal_inst_create + +.. _tinytc_for_inst_create: + +tinytc_for_inst_create +...................... + +.. doxygenfunction:: tinytc_for_inst_create + +.. _tinytc_foreach_inst_create: + +tinytc_foreach_inst_create +.......................... + +.. doxygenfunction:: tinytc_foreach_inst_create + +.. _tinytc_inst_get_parent_region: + +tinytc_inst_get_parent_region +............................. + +.. doxygenfunction:: tinytc_inst_get_parent_region + +.. _tinytc_inst_get_regions: + +tinytc_inst_get_regions +....................... + +.. doxygenfunction:: tinytc_inst_get_regions + +.. _tinytc_inst_get_values: + +tinytc_inst_get_values +...................... + +.. doxygenfunction:: tinytc_inst_get_values + +.. _tinytc_inst_destroy: + +tinytc_inst_destroy ................... -.. doxygenfunction:: tinytc_prog_release +.. doxygenfunction:: tinytc_inst_destroy + +.. _tinytc_inst_set_attr: + +tinytc_inst_set_attr +.................... + +.. doxygenfunction:: tinytc_inst_set_attr + +Program +======= + +* Functions + + * :ref:`tinytc_prog_create` + + * :ref:`tinytc_prog_add_function` -tinytc_prog_retain +Program Functions +----------------- + +.. _tinytc_prog_create: + +tinytc_prog_create .................. -.. doxygenfunction:: tinytc_prog_retain +.. 
doxygenfunction:: tinytc_prog_create + +.. _tinytc_prog_add_function: + +tinytc_prog_add_function +........................ + +.. doxygenfunction:: tinytc_prog_add_function Region ====== * Functions - * :ref:`tinytc_region_create` + * :ref:`tinytc_region_append` - * :ref:`tinytc_region_release` + * :ref:`tinytc_region_begin` - * :ref:`tinytc_region_retain` + * :ref:`tinytc_region_end` + + * :ref:`tinytc_region_erase` + + * :ref:`tinytc_region_insert` + + * :ref:`tinytc_next_inst` + + * :ref:`tinytc_prev_inst` + + * :ref:`tinytc_region_get_parameters` Region Functions ---------------- -tinytc_region_create +.. _tinytc_region_append: + +tinytc_region_append .................... -.. doxygenfunction:: tinytc_region_create +.. doxygenfunction:: tinytc_region_append -tinytc_region_release -..................... +.. _tinytc_region_begin: + +tinytc_region_begin +................... + +.. doxygenfunction:: tinytc_region_begin + +.. _tinytc_region_end: + +tinytc_region_end +................. + +.. doxygenfunction:: tinytc_region_end + +.. _tinytc_region_erase: + +tinytc_region_erase +................... + +.. doxygenfunction:: tinytc_region_erase -.. doxygenfunction:: tinytc_region_release +.. _tinytc_region_insert: -tinytc_region_retain +tinytc_region_insert .................... -.. doxygenfunction:: tinytc_region_retain +.. doxygenfunction:: tinytc_region_insert -Value -===== +.. _tinytc_next_inst: -* Functions +tinytc_next_inst +................ - * :ref:`tinytc_float_imm_create` +.. doxygenfunction:: tinytc_next_inst - * :ref:`tinytc_int_imm_create` +.. _tinytc_prev_inst: - * :ref:`tinytc_value_create` +tinytc_prev_inst +................ - * :ref:`tinytc_value_get_name` +.. doxygenfunction:: tinytc_prev_inst - * :ref:`tinytc_value_set_name` +.. _tinytc_region_get_parameters: - * :ref:`tinytc_value_release` +tinytc_region_get_parameters +............................ - * :ref:`tinytc_value_retain` +.. 
doxygenfunction:: tinytc_region_get_parameters -Value Functions ---------------- +Value +===== -tinytc_float_imm_create -....................... +* Functions -.. doxygenfunction:: tinytc_float_imm_create + * :ref:`tinytc_value_get_name` -tinytc_int_imm_create -..................... + * :ref:`tinytc_value_get_type` -.. doxygenfunction:: tinytc_int_imm_create + * :ref:`tinytc_value_set_name` -tinytc_value_create -................... + * :ref:`tinytc_value_set_name_n` -.. doxygenfunction:: tinytc_value_create +Value Functions +--------------- + +.. _tinytc_value_get_name: tinytc_value_get_name ..................... .. doxygenfunction:: tinytc_value_get_name +.. _tinytc_value_get_type: + +tinytc_value_get_type +..................... + +.. doxygenfunction:: tinytc_value_get_type + +.. _tinytc_value_set_name: + tinytc_value_set_name ..................... .. doxygenfunction:: tinytc_value_set_name -tinytc_value_release -.................... - -.. doxygenfunction:: tinytc_value_release +.. _tinytc_value_set_name_n: -tinytc_value_retain -................... +tinytc_value_set_name_n +....................... -.. doxygenfunction:: tinytc_value_retain +.. 
doxygenfunction:: tinytc_value_set_name_n diff --git a/docs/api/builder_capi.yaml b/docs/api/builder_capi.yaml index 4412c745..3033bf68 100644 --- a/docs/api/builder_capi.yaml +++ b/docs/api/builder_capi.yaml @@ -3,100 +3,192 @@ Builder C-API: Common: enum: - - tinytc_arithmetic_t - - tinytc_arithmetic_unary_t - - tinytc_cmp_condition_t - - tinytc_scalar_type_t + - tinytc_address_space_t + - tinytc_checked_flag_t + - tinytc_comp3_t + - tinytc_matrix_use_t + - tinytc_reduce_mode_t + - tinytc_store_flag_t - tinytc_transpose_t define: - TINYTC_DYNAMIC function: - - tinytc_arithmetic_to_string - - tinytc_arithmetic_unary_to_string - - tinytc_cmp_condition_to_string - - tinytc_scalar_type_size - - tinytc_scalar_type_to_string + - tinytc_address_space_to_string + - tinytc_checked_flag_to_string + - tinytc_comp3_to_string + - tinytc_matrix_use_to_string + - tinytc_reduce_mode_to_string + - tinytc_store_flag_to_string - tinytc_transpose_to_string struct: - - tinytc_position + - tinytc_named_attr - tinytc_location + - tinytc_position typedef: - - tinytc_data_type_t + - tinytc_address_spaces_t + - tinytc_attr_t - tinytc_func_t + - tinytc_named_attr_t - tinytc_location_t - tinytc_position_t - - tinytc_prog_t - tinytc_inst_t + - tinytc_inst_iterator_t - tinytc_region_t + - tinytc_type_t - tinytc_value_t - - const_tinytc_data_type_t + - const_tinytc_attr_t - const_tinytc_func_t - const_tinytc_inst_t - - const_tinytc_prog_t - const_tinytc_region_t + - const_tinytc_type_t + - const_tinytc_value_t + Attribute: + function: + - tinytc_array_attr_get + - tinytc_boolean_attr_get + - tinytc_dictionary_attr_get + - tinytc_dictionary_attr_get_with_sorted + - tinytc_dictionary_attr_sort + - tinytc_integer_attr_get + - tinytc_string_attr_get Data Type: function: - - tinytc_group_type_create - - tinytc_memref_type_create - - tinytc_scalar_type_create - - tinytc_data_type_release - - tinytc_data_type_retain + - tinytc_boolean_type_get + - tinytc_i8_type_get + - tinytc_i16_type_get + - 
tinytc_i32_type_get + - tinytc_i64_type_get + - tinytc_index_type_get + - tinytc_f16_type_get + - tinytc_bf16_type_get + - tinytc_f32_type_get + - tinytc_f64_type_get + - tinytc_c32_type_get + - tinytc_c64_type_get + - tinytc_coopmatrix_type_get + - tinytc_group_type_get + - tinytc_memref_type_get + - tinytc_void_type_get + - tinytc_type_get_compiler_context Function: function: - - tinytc_function_create - - tinytc_function_prototype_create - - tinytc_function_set_subgroup_size - - tinytc_function_set_work_group_size - - tinytc_func_release - - tinytc_func_retain + - tinytc_func_create + - tinytc_func_destroy + - tinytc_func_get_body + - tinytc_func_set_attr + - tinytc_func_set_parameter_attr Instruction: function: - tinytc_alloca_inst_create - - tinytc_axpby_inst_create - - tinytc_arith_inst_create - - tinytc_arith_unary_inst_create + - tinytc_barrier_inst_create - tinytc_cast_inst_create - - tinytc_cmp_inst_create + - tinytc_constant_inst_create_boolean + - tinytc_constant_inst_create_complex + - tinytc_constant_inst_create_float + - tinytc_constant_inst_create_int + - tinytc_constant_inst_create_one + - tinytc_constant_inst_create_zero + - tinytc_cooperative_matrix_apply_inst_create + - tinytc_cooperative_matrix_extract_inst_create + - tinytc_cooperative_matrix_insert_inst_create + - tinytc_cooperative_matrix_load_inst_create + - tinytc_cooperative_matrix_mul_add_inst_create + - tinytc_cooperative_matrix_reduce_add_inst_create + - tinytc_cooperative_matrix_reduce_max_inst_create + - tinytc_cooperative_matrix_reduce_min_inst_create + - tinytc_cooperative_matrix_prefetch_inst_create + - tinytc_cooperative_matrix_scale_inst_create + - tinytc_cooperative_matrix_store_inst_create - tinytc_expand_inst_create - - tinytc_for_inst_create - - tinytc_foreach_inst_create - tinytc_fuse_inst_create - - tinytc_gemm_inst_create - - tinytc_gemv_inst_create - - tinytc_ger_inst_create - - tinytc_group_id_inst_create - - tinytc_group_size_inst_create - - tinytc_hadamard_inst_create 
- tinytc_if_inst_create + - tinytc_lifetime_stop_inst_create - tinytc_load_inst_create + - tinytc_parallel_inst_create - tinytc_size_inst_create + - tinytc_subgroup_broadcast_inst_create - tinytc_store_inst_create - tinytc_subview_inst_create - - tinytc_sum_inst_create - tinytc_yield_inst_create - - tinytc_inst_get_value + - tinytc_add_inst_create + - tinytc_sub_inst_create + - tinytc_mul_inst_create + - tinytc_div_inst_create + - tinytc_rem_inst_create + - tinytc_shl_inst_create + - tinytc_shr_inst_create + - tinytc_and_inst_create + - tinytc_or_inst_create + - tinytc_xor_inst_create + - tinytc_min_inst_create + - tinytc_max_inst_create + - tinytc_abs_inst_create + - tinytc_neg_inst_create + - tinytc_not_inst_create + - tinytc_conj_inst_create + - tinytc_im_inst_create + - tinytc_re_inst_create + - tinytc_axpby_inst_create + - tinytc_cumsum_inst_create + - tinytc_sum_inst_create + - tinytc_gemm_inst_create + - tinytc_gemv_inst_create + - tinytc_ger_inst_create + - tinytc_hadamard_inst_create + - tinytc_group_id_inst_create + - tinytc_num_groups_inst_create + - tinytc_num_subgroups_inst_create + - tinytc_subgroup_size_inst_create + - tinytc_subgroup_id_inst_create + - tinytc_subgroup_linear_id_inst_create + - tinytc_subgroup_local_id_inst_create + - tinytc_cos_inst_create + - tinytc_sin_inst_create + - tinytc_exp_inst_create + - tinytc_exp2_inst_create + - tinytc_native_cos_inst_create + - tinytc_native_sin_inst_create + - tinytc_native_exp_inst_create + - tinytc_native_exp2_inst_create + - tinytc_subgroup_exclusive_scan_add_inst_create + - tinytc_subgroup_exclusive_scan_max_inst_create + - tinytc_subgroup_exclusive_scan_min_inst_create + - tinytc_subgroup_inclusive_scan_add_inst_create + - tinytc_subgroup_inclusive_scan_max_inst_create + - tinytc_subgroup_inclusive_scan_min_inst_create + - tinytc_subgroup_reduce_add_inst_create + - tinytc_subgroup_reduce_max_inst_create + - tinytc_subgroup_reduce_min_inst_create + - tinytc_equal_inst_create + - 
tinytc_not_equal_inst_create + - tinytc_greater_than_inst_create + - tinytc_greater_than_equal_inst_create + - tinytc_less_than_inst_create + - tinytc_less_than_equal_inst_create + - tinytc_for_inst_create + - tinytc_foreach_inst_create + - tinytc_inst_get_parent_region + - tinytc_inst_get_regions - tinytc_inst_get_values - - tinytc_inst_release - - tinytc_inst_retain + - tinytc_inst_destroy + - tinytc_inst_set_attr Program: function: - - tinytc_program_create - - tinytc_prog_dump - - tinytc_prog_print_to_file - - tinytc_prog_print_to_string - - tinytc_prog_release - - tinytc_prog_retain + - tinytc_prog_create + - tinytc_prog_add_function Region: function: - - tinytc_region_create - - tinytc_region_release - - tinytc_region_retain + - tinytc_region_append + - tinytc_region_begin + - tinytc_region_end + - tinytc_region_erase + - tinytc_region_insert + - tinytc_next_inst + - tinytc_prev_inst + - tinytc_region_get_parameters Value: function: - - tinytc_float_imm_create - - tinytc_int_imm_create - - tinytc_value_create - tinytc_value_get_name + - tinytc_value_get_type - tinytc_value_set_name - - tinytc_value_release - - tinytc_value_retain + - tinytc_value_set_name_n diff --git a/docs/api/builder_cxxapi.rst b/docs/api/builder_cxxapi.rst index 57668d4a..a7c6c8c4 100644 --- a/docs/api/builder_cxxapi.rst +++ b/docs/api/builder_cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. 
_Builder C++-API: + =============== Builder C++-API =============== @@ -10,68 +12,98 @@ Common * Enumerations - * :ref:`arithmetic` + * :ref:`tinytc::address_space` + + * :ref:`tinytc::checked_flag` - * :ref:`arithmetic_unary` + * :ref:`tinytc::comp3` - * :ref:`cmp_condition` + * :ref:`tinytc::matrix_use` - * :ref:`scalar_type` + * :ref:`tinytc::reduce_mode` - * :ref:`transpose` + * :ref:`tinytc::store_flag` + + * :ref:`tinytc::transpose` * Functions - * :ref:`is_dynamic_value` + * :ref:`tinytc::is_dynamic_value` + + * :ref:`tinytc::to_string(address_space)` - * :ref:`to_string(arithmetic)` + * :ref:`tinytc::to_string(checked_flag)` - * :ref:`to_string(arithmetic_unary)` + * :ref:`tinytc::to_string(comp3)` - * :ref:`to_string(cmp_condition)` + * :ref:`tinytc::to_string(matrix_use)` - * :ref:`to_string(scalar_type)` + * :ref:`tinytc::to_string(reduce_mode)` - * :ref:`to_string(transpose)` + * :ref:`tinytc::to_string(store_flag)` - * :ref:`size` + * :ref:`tinytc::to_string(transpose)` * Classes - * :ref:`builder_error` + * :ref:`tinytc::builder_error` * Typedefs - * :ref:`position` + * :ref:`tinytc::location` - * :ref:`location` + * :ref:`tinytc::position` * Variables - * :ref:`dynamic` + * :ref:`tinytc::dynamic` Common Enumerations ------------------- -arithmetic -.......... +.. _tinytc::address_space: -.. doxygenenum:: tinytc::arithmetic +address_space +............. -arithmetic_unary -................ +.. doxygenenum:: tinytc::address_space -.. doxygenenum:: tinytc::arithmetic_unary +.. _tinytc::checked_flag: -cmp_condition -............. +checked_flag +............ + +.. doxygenenum:: tinytc::checked_flag + +.. _tinytc::comp3: + +comp3 +..... + +.. doxygenenum:: tinytc::comp3 + +.. _tinytc::matrix_use: + +matrix_use +.......... + +.. doxygenenum:: tinytc::matrix_use -.. doxygenenum:: tinytc::cmp_condition +.. _tinytc::reduce_mode: -scalar_type +reduce_mode ........... -.. doxygenenum:: tinytc::scalar_type +.. doxygenenum:: tinytc::reduce_mode + +.. 
_tinytc::store_flag: + +store_flag +.......... + +.. doxygenenum:: tinytc::store_flag + +.. _tinytc::transpose: transpose ......... @@ -81,44 +113,67 @@ transpose Common Functions ---------------- +.. _tinytc::is_dynamic_value: + is_dynamic_value ................ .. doxygenfunction:: tinytc::is_dynamic_value -to_string(arithmetic) -..................... +.. _tinytc::to_string(address_space): -.. doxygenfunction:: tinytc::to_string(arithmetic) +to_string(address_space) +........................ -to_string(arithmetic_unary) -........................... +.. doxygenfunction:: tinytc::to_string(address_space) -.. doxygenfunction:: tinytc::to_string(arithmetic_unary) +.. _tinytc::to_string(checked_flag): -to_string(cmp_condition) -........................ +to_string(checked_flag) +....................... + +.. doxygenfunction:: tinytc::to_string(checked_flag) + +.. _tinytc::to_string(comp3): + +to_string(comp3) +................ + +.. doxygenfunction:: tinytc::to_string(comp3) + +.. _tinytc::to_string(matrix_use): + +to_string(matrix_use) +..................... + +.. doxygenfunction:: tinytc::to_string(matrix_use) -.. doxygenfunction:: tinytc::to_string(cmp_condition) +.. _tinytc::to_string(reduce_mode): -to_string(scalar_type) +to_string(reduce_mode) ...................... -.. doxygenfunction:: tinytc::to_string(scalar_type) +.. doxygenfunction:: tinytc::to_string(reduce_mode) + +.. _tinytc::to_string(store_flag): + +to_string(store_flag) +..................... + +.. doxygenfunction:: tinytc::to_string(store_flag) + +.. _tinytc::to_string(transpose): to_string(transpose) .................... .. doxygenfunction:: tinytc::to_string(transpose) -size -.... - -.. doxygenfunction:: tinytc::size - Common Classes -------------- +.. _tinytc::builder_error: + builder_error ............. @@ -127,487 +182,1309 @@ builder_error Common Typedefs --------------- -position -........ - -.. doxygentypedef:: tinytc::position +.. _tinytc::location: location ........ .. 
doxygentypedef:: tinytc::location +.. _tinytc::position: + +position +........ + +.. doxygentypedef:: tinytc::position + Common Variables ---------------- +.. _tinytc::dynamic: + dynamic ....... .. doxygenvariable:: tinytc::dynamic -Data Type +Attribute ========= * Functions - * :ref:`make_memref` + * :ref:`get_dictionary_attr_with_sorted` + + * :ref:`sort_items` - * :ref:`make_group` +* Structures - * :ref:`make_scalar` + * :ref:`tinytc::getter\< array_attr \>` -* Classes + * :ref:`tinytc::getter\< boolean_attr \>` + + * :ref:`tinytc::getter\< dictionary_attr \>` + + * :ref:`tinytc::getter\< integer_attr \>` - * :ref:`data_type` + * :ref:`tinytc::getter\< string_attr \>` + +Attribute Functions +------------------- + +.. _get_dictionary_attr_with_sorted: + +get_dictionary_attr_with_sorted +............................... + +.. doxygenfunction:: get_dictionary_attr_with_sorted + +.. _sort_items: + +sort_items +.......... + +.. doxygenfunction:: sort_items + +Attribute Structures +-------------------- + +.. _tinytc::getter\< array_attr \>: + +getter +.................. + +.. doxygenstruct:: tinytc::getter< array_attr > + +.. _tinytc::getter\< boolean_attr \>: + +getter +.................... + +.. doxygenstruct:: tinytc::getter< boolean_attr > + +.. _tinytc::getter\< dictionary_attr \>: + +getter +....................... + +.. doxygenstruct:: tinytc::getter< dictionary_attr > + +.. _tinytc::getter\< integer_attr \>: + +getter +.................... + +.. doxygenstruct:: tinytc::getter< integer_attr > + +.. _tinytc::getter\< string_attr \>: + +getter +................... + +.. 
doxygenstruct:: tinytc::getter< string_attr > + +Data Type +========= + +* Functions + + * :ref:`tinytc::get` + + * :ref:`tinytc::get_compiler_context(const_tinytc_type_t)` + + * :ref:`tinytc::to_type` * Structures - * :ref:`to_scalar_type` + * :ref:`tinytc::getter\< boolean_type \>` -* Variables + * :ref:`tinytc::getter\< i8_type \>` + + * :ref:`tinytc::getter\< i16_type \>` + + * :ref:`tinytc::getter\< i32_type \>` + + * :ref:`tinytc::getter\< i64_type \>` + + * :ref:`tinytc::getter\< index_type \>` + + * :ref:`tinytc::getter\< f16_type \>` + + * :ref:`tinytc::getter\< bf16_type \>` + + * :ref:`tinytc::getter\< f32_type \>` + + * :ref:`tinytc::getter\< f64_type \>` - * :ref:`to_scalar_type_v` + * :ref:`tinytc::getter\< c32_type \>` + + * :ref:`tinytc::getter\< c64_type \>` + + * :ref:`tinytc::getter\< coopmatrix_type \>` + + * :ref:`tinytc::getter\< group_type \>` + + * :ref:`tinytc::getter\< memref_type \>` + + * :ref:`tinytc::getter\< void_type \>` Data Type Functions ------------------- -make_memref -........... +.. _tinytc::get: -.. doxygenfunction:: tinytc::make_memref +get +... -make_group -.......... +.. doxygenfunction:: tinytc::get -.. doxygenfunction:: tinytc::make_group +.. _tinytc::get_compiler_context(const_tinytc_type_t): -make_scalar -........... +get_compiler_context(const_tinytc_type_t) +......................................... -.. doxygenfunction:: tinytc::make_scalar +.. doxygenfunction:: tinytc::get_compiler_context(const_tinytc_type_t) -Data Type Classes ------------------ +.. _tinytc::to_type: -data_type -......... +to_type +....... -.. doxygenclass:: tinytc::data_type +.. doxygenfunction:: tinytc::to_type Data Type Structures -------------------- -to_scalar_type -.............. +.. _tinytc::getter\< boolean_type \>: -.. doxygenstruct:: tinytc::to_scalar_type +getter +.................... -Data Type Variables -------------------- +.. doxygenstruct:: tinytc::getter< boolean_type > -to_scalar_type_v -................ +.. 
_tinytc::getter\< i8_type \>: + +getter +............... -.. doxygenvariable:: tinytc::to_scalar_type_v +.. doxygenstruct:: tinytc::getter< i8_type > -Function -======== +.. _tinytc::getter\< i16_type \>: -* Functions +getter +................ - * :ref:`make_function` +.. doxygenstruct:: tinytc::getter< i16_type > - * :ref:`make_function_prototype` +.. _tinytc::getter\< i32_type \>: - * :ref:`set_work_group_size` +getter +................ - * :ref:`set_subgroup_size` +.. doxygenstruct:: tinytc::getter< i32_type > -* Classes +.. _tinytc::getter\< i64_type \>: - * :ref:`func` +getter +................ - * :ref:`function_builder` +.. doxygenstruct:: tinytc::getter< i64_type > -Function Functions ------------------- +.. _tinytc::getter\< index_type \>: -make_function -............. +getter +.................. -.. doxygenfunction:: tinytc::make_function +.. doxygenstruct:: tinytc::getter< index_type > -make_function_prototype -....................... +.. _tinytc::getter\< f16_type \>: -.. doxygenfunction:: tinytc::make_function_prototype +getter +................ -set_work_group_size -................... +.. doxygenstruct:: tinytc::getter< f16_type > -.. doxygenfunction:: tinytc::set_work_group_size +.. _tinytc::getter\< bf16_type \>: -set_subgroup_size +getter ................. -.. doxygenfunction:: tinytc::set_subgroup_size +.. doxygenstruct:: tinytc::getter< bf16_type > -Function Classes ----------------- +.. _tinytc::getter\< f32_type \>: -func -.... +getter +................ + +.. doxygenstruct:: tinytc::getter< f32_type > -.. doxygenclass:: tinytc::func +.. _tinytc::getter\< f64_type \>: -function_builder +getter ................ -.. doxygenclass:: tinytc::function_builder +.. doxygenstruct:: tinytc::getter< f64_type > -Instruction -=========== +.. _tinytc::getter\< c32_type \>: -* Functions +getter +................ - * :ref:`make_alloca` +.. doxygenstruct:: tinytc::getter< c32_type > - * :ref:`make_axpby` +.. 
_tinytc::getter\< c64_type \>: - * :ref:`make_arith(arithmetic,value const&,value const&,location const&)` +getter +................ - * :ref:`make_arith(arithmetic_unary,value const&,location const&)` +.. doxygenstruct:: tinytc::getter< c64_type > - * :ref:`make_cast` +.. _tinytc::getter\< coopmatrix_type \>: - * :ref:`make_cmp` +getter +....................... - * :ref:`make_expand` +.. doxygenstruct:: tinytc::getter< coopmatrix_type > - * :ref:`make_for` +.. _tinytc::getter\< group_type \>: - * :ref:`make_foreach` +getter +.................. - * :ref:`make_fuse` +.. doxygenstruct:: tinytc::getter< group_type > - * :ref:`make_gemm` +.. _tinytc::getter\< memref_type \>: - * :ref:`make_gemv` +getter +................... - * :ref:`make_ger` +.. doxygenstruct:: tinytc::getter< memref_type > - * :ref:`make_group_id` +.. _tinytc::getter\< void_type \>: - * :ref:`make_group_size` +getter +................. - * :ref:`make_hadamard` +.. doxygenstruct:: tinytc::getter< void_type > - * :ref:`make_if` +Function +======== - * :ref:`make_load` +* Functions - * :ref:`make_size` + * :ref:`tinytc::create_func` - * :ref:`make_store` + * :ref:`tinytc::get_body` - * :ref:`make_subview` + * :ref:`tinytc::set_attr(tinytc_func_t,tinytc_attr_t)` - * :ref:`make_sum` + * :ref:`tinytc::set_parameter_attr` - * :ref:`make_yield` +Function Functions +------------------ -* Classes +.. _tinytc::create_func: - * :ref:`inst` +create_func +........... -Instruction Functions ---------------------- +.. doxygenfunction:: tinytc::create_func -make_alloca -........... +.. _tinytc::get_body: -.. doxygenfunction:: tinytc::make_alloca +get_body +........ -make_axpby -.......... +.. doxygenfunction:: tinytc::get_body -.. doxygenfunction:: tinytc::make_axpby +.. _tinytc::set_attr(tinytc_func_t,tinytc_attr_t): -make_arith(arithmetic,value const&,value const&,location const&) -................................................................ 
+set_attr(tinytc_func_t,tinytc_attr_t) +..................................... -.. doxygenfunction:: tinytc::make_arith(arithmetic,value const&,value const&,location const&) +.. doxygenfunction:: tinytc::set_attr(tinytc_func_t,tinytc_attr_t) -make_arith(arithmetic_unary,value const&,location const&) -......................................................... +.. _tinytc::set_parameter_attr: -.. doxygenfunction:: tinytc::make_arith(arithmetic_unary,value const&,location const&) +set_parameter_attr +.................. -make_cast -......... +.. doxygenfunction:: tinytc::set_parameter_attr -.. doxygenfunction:: tinytc::make_cast +Instruction +=========== -make_cmp -........ +* Functions -.. doxygenfunction:: tinytc::make_cmp + * :ref:`tinytc::create` -make_expand -........... + * :ref:`tinytc::set_attr(tinytc_inst_t,tinytc_attr_t)` -.. doxygenfunction:: tinytc::make_expand +* Structures -make_for -........ + * :ref:`tinytc::creator\< alloca_inst \>` -.. doxygenfunction:: tinytc::make_for + * :ref:`tinytc::creator\< barrier_inst \>` -make_foreach -............ + * :ref:`tinytc::creator\< cast_inst \>` -.. doxygenfunction:: tinytc::make_foreach + * :ref:`tinytc::creator\< constant_inst \>` -make_fuse -......... + * :ref:`tinytc::creator\< cooperative_matrix_apply_inst \>` -.. doxygenfunction:: tinytc::make_fuse + * :ref:`tinytc::creator\< cooperative_matrix_extract_inst \>` -make_gemm -......... + * :ref:`tinytc::creator\< cooperative_matrix_insert_inst \>` -.. doxygenfunction:: tinytc::make_gemm + * :ref:`tinytc::creator\< cooperative_matrix_load_inst \>` -make_gemv -......... + * :ref:`tinytc::creator\< cooperative_matrix_mul_add_inst \>` -.. doxygenfunction:: tinytc::make_gemv + * :ref:`tinytc::creator\< cooperative_matrix_prefetch_inst \>` -make_ger -........ + * :ref:`tinytc::creator\< cooperative_matrix_reduce_add_inst \>` -.. doxygenfunction:: tinytc::make_ger + * :ref:`tinytc::creator\< cooperative_matrix_reduce_max_inst \>` -make_group_id -............. 
+ * :ref:`tinytc::creator\< cooperative_matrix_reduce_min_inst \>` -.. doxygenfunction:: tinytc::make_group_id + * :ref:`tinytc::creator\< cooperative_matrix_scale_inst \>` -make_group_size -............... + * :ref:`tinytc::creator\< cooperative_matrix_store_inst \>` -.. doxygenfunction:: tinytc::make_group_size + * :ref:`tinytc::creator\< expand_inst \>` -make_hadamard -............. + * :ref:`tinytc::creator\< fuse_inst \>` -.. doxygenfunction:: tinytc::make_hadamard + * :ref:`tinytc::creator\< if_inst \>` -make_if -....... + * :ref:`tinytc::creator\< lifetime_stop_inst \>` -.. doxygenfunction:: tinytc::make_if + * :ref:`tinytc::creator\< load_inst \>` -make_load -......... + * :ref:`tinytc::creator\< parallel_inst \>` -.. doxygenfunction:: tinytc::make_load + * :ref:`tinytc::creator\< size_inst \>` -make_size -......... + * :ref:`tinytc::creator\< subgroup_broadcast_inst \>` -.. doxygenfunction:: tinytc::make_size + * :ref:`tinytc::creator\< subview_inst \>` -make_store -.......... + * :ref:`tinytc::creator\< store_inst \>` -.. doxygenfunction:: tinytc::make_store + * :ref:`tinytc::creator\< yield_inst \>` -make_subview -............ + * :ref:`tinytc::creator\< add_inst \>` -.. doxygenfunction:: tinytc::make_subview + * :ref:`tinytc::creator\< sub_inst \>` -make_sum -........ + * :ref:`tinytc::creator\< mul_inst \>` -.. doxygenfunction:: tinytc::make_sum + * :ref:`tinytc::creator\< div_inst \>` -make_yield -.......... + * :ref:`tinytc::creator\< rem_inst \>` -.. doxygenfunction:: tinytc::make_yield + * :ref:`tinytc::creator\< shl_inst \>` -Instruction Classes -------------------- + * :ref:`tinytc::creator\< shr_inst \>` -inst -.... + * :ref:`tinytc::creator\< and_inst \>` -.. 
doxygenclass:: tinytc::inst + * :ref:`tinytc::creator\< or_inst \>` -Program -======= + * :ref:`tinytc::creator\< xor_inst \>` -* Functions + * :ref:`tinytc::creator\< min_inst \>` - * :ref:`make_program` + * :ref:`tinytc::creator\< max_inst \>` -* Classes + * :ref:`tinytc::creator\< abs_inst \>` - * :ref:`prog` + * :ref:`tinytc::creator\< neg_inst \>` - * :ref:`program_builder` + * :ref:`tinytc::creator\< not_inst \>` -Program Functions ------------------ + * :ref:`tinytc::creator\< conj_inst \>` -make_program -............ + * :ref:`tinytc::creator\< im_inst \>` -.. doxygenfunction:: tinytc::make_program + * :ref:`tinytc::creator\< re_inst \>` -Program Classes ---------------- + * :ref:`tinytc::creator\< axpby_inst \>` -prog -.... + * :ref:`tinytc::creator\< cumsum_inst \>` -.. doxygenclass:: tinytc::prog + * :ref:`tinytc::creator\< sum_inst \>` -program_builder -............... + * :ref:`tinytc::creator\< gemm_inst \>` -.. doxygenclass:: tinytc::program_builder + * :ref:`tinytc::creator\< gemv_inst \>` -Region -====== + * :ref:`tinytc::creator\< ger_inst \>` -* Functions + * :ref:`tinytc::creator\< hadamard_inst \>` - * :ref:`make_region` + * :ref:`tinytc::creator\< group_id_inst \>` -* Classes + * :ref:`tinytc::creator\< num_groups_inst \>` - * :ref:`region` + * :ref:`tinytc::creator\< num_subgroups_inst \>` - * :ref:`region_builder` + * :ref:`tinytc::creator\< subgroup_size_inst \>` -Region Functions ----------------- + * :ref:`tinytc::creator\< subgroup_id_inst \>` -make_region -........... + * :ref:`tinytc::creator\< subgroup_linear_id_inst \>` -.. doxygenfunction:: tinytc::make_region + * :ref:`tinytc::creator\< subgroup_local_id_inst \>` -Region Classes --------------- + * :ref:`tinytc::creator\< equal_inst \>` -region -...... + * :ref:`tinytc::creator\< not_equal_inst \>` -.. doxygenclass:: tinytc::region + * :ref:`tinytc::creator\< greater_than_inst \>` -region_builder -.............. + * :ref:`tinytc::creator\< greater_than_equal_inst \>` -.. 
doxygenclass:: tinytc::region_builder + * :ref:`tinytc::creator\< less_than_inst \>` -Value -===== + * :ref:`tinytc::creator\< less_than_equal_inst \>` -* Functions + * :ref:`tinytc::creator\< for_inst \>` - * :ref:`make_dynamic(location const&)` + * :ref:`tinytc::creator\< foreach_inst \>` - * :ref:`make_imm(float,location const&)` + * :ref:`tinytc::creator\< cos_inst \>` - * :ref:`make_imm(double,scalar_type,location const&)` + * :ref:`tinytc::creator\< sin_inst \>` - * :ref:`make_imm(std::int8_t,location const&)` + * :ref:`tinytc::creator\< exp_inst \>` - * :ref:`make_imm(std::int16_t,location const&)` + * :ref:`tinytc::creator\< exp2_inst \>` - * :ref:`make_imm(std::int32_t,location const&)` + * :ref:`tinytc::creator\< native_cos_inst \>` - * :ref:`make_imm(std::int64_t,scalar_type,location const&)` + * :ref:`tinytc::creator\< native_sin_inst \>` - * :ref:`make_index(std::int32_t,location const&)` + * :ref:`tinytc::creator\< native_exp_inst \>` - * :ref:`make_index(std::int64_t,location const&)` + * :ref:`tinytc::creator\< native_exp2_inst \>` - * :ref:`make_value(data_type const&,location const&)` + * :ref:`tinytc::creator\< subgroup_exclusive_scan_add_inst \>` - * :ref:`make_value(scalar_type,location const&)` + * :ref:`tinytc::creator\< subgroup_exclusive_scan_max_inst \>` -* Classes + * :ref:`tinytc::creator\< subgroup_exclusive_scan_min_inst \>` - * :ref:`value` + * :ref:`tinytc::creator\< subgroup_inclusive_scan_add_inst \>` -Value Functions ---------------- + * :ref:`tinytc::creator\< subgroup_inclusive_scan_max_inst \>` -make_dynamic(location const&) -............................. + * :ref:`tinytc::creator\< subgroup_inclusive_scan_min_inst \>` -.. doxygenfunction:: tinytc::make_dynamic(location const&) + * :ref:`tinytc::creator\< subgroup_reduce_add_inst \>` -make_imm(float,location const&) -............................... 
+ * :ref:`tinytc::creator\< subgroup_reduce_max_inst \>` + + * :ref:`tinytc::creator\< subgroup_reduce_min_inst \>` + +Instruction Functions +--------------------- -.. doxygenfunction:: tinytc::make_imm(float,location const&) +.. _tinytc::create: -make_imm(double,scalar_type,location const&) -............................................ +create +...... + +.. doxygenfunction:: tinytc::create -.. doxygenfunction:: tinytc::make_imm(double,scalar_type,location const&) +.. _tinytc::set_attr(tinytc_inst_t,tinytc_attr_t): -make_imm(std::int8_t,location const&) +set_attr(tinytc_inst_t,tinytc_attr_t) ..................................... -.. doxygenfunction:: tinytc::make_imm(std::int8_t,location const&) +.. doxygenfunction:: tinytc::set_attr(tinytc_inst_t,tinytc_attr_t) -make_imm(std::int16_t,location const&) -...................................... +Instruction Structures +---------------------- -.. doxygenfunction:: tinytc::make_imm(std::int16_t,location const&) +.. _tinytc::creator\< alloca_inst \>: -make_imm(std::int32_t,location const&) -...................................... +creator +.................... + +.. doxygenstruct:: tinytc::creator< alloca_inst > + +.. _tinytc::creator\< barrier_inst \>: + +creator +..................... + +.. doxygenstruct:: tinytc::creator< barrier_inst > + +.. _tinytc::creator\< cast_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< cast_inst > + +.. _tinytc::creator\< constant_inst \>: + +creator +...................... -.. doxygenfunction:: tinytc::make_imm(std::int32_t,location const&) +.. doxygenstruct:: tinytc::creator< constant_inst > -make_imm(std::int64_t,scalar_type,location const&) -.................................................. +.. _tinytc::creator\< cooperative_matrix_apply_inst \>: -.. doxygenfunction:: tinytc::make_imm(std::int64_t,scalar_type,location const&) +creator +...................................... + +.. 
doxygenstruct:: tinytc::creator< cooperative_matrix_apply_inst > -make_index(std::int32_t,location const&) +.. _tinytc::creator\< cooperative_matrix_extract_inst \>: + +creator ........................................ -.. doxygenfunction:: tinytc::make_index(std::int32_t,location const&) +.. doxygenstruct:: tinytc::creator< cooperative_matrix_extract_inst > + +.. _tinytc::creator\< cooperative_matrix_insert_inst \>: + +creator +....................................... + +.. doxygenstruct:: tinytc::creator< cooperative_matrix_insert_inst > + +.. _tinytc::creator\< cooperative_matrix_load_inst \>: -make_index(std::int64_t,location const&) +creator +..................................... + +.. doxygenstruct:: tinytc::creator< cooperative_matrix_load_inst > + +.. _tinytc::creator\< cooperative_matrix_mul_add_inst \>: + +creator ........................................ -.. doxygenfunction:: tinytc::make_index(std::int64_t,location const&) +.. doxygenstruct:: tinytc::creator< cooperative_matrix_mul_add_inst > -make_value(data_type const&,location const&) -............................................ +.. _tinytc::creator\< cooperative_matrix_prefetch_inst \>: -.. doxygenfunction:: tinytc::make_value(data_type const&,location const&) +creator +......................................... -make_value(scalar_type,location const&) -....................................... +.. doxygenstruct:: tinytc::creator< cooperative_matrix_prefetch_inst > -.. doxygenfunction:: tinytc::make_value(scalar_type,location const&) +.. _tinytc::creator\< cooperative_matrix_reduce_add_inst \>: -Value Classes -------------- +creator +........................................... -value -..... +.. doxygenstruct:: tinytc::creator< cooperative_matrix_reduce_add_inst > + +.. _tinytc::creator\< cooperative_matrix_reduce_max_inst \>: + +creator +........................................... -.. doxygenclass:: tinytc::value +.. doxygenstruct:: tinytc::creator< cooperative_matrix_reduce_max_inst > + +.. 
_tinytc::creator\< cooperative_matrix_reduce_min_inst \>: + +creator +........................................... + +.. doxygenstruct:: tinytc::creator< cooperative_matrix_reduce_min_inst > + +.. _tinytc::creator\< cooperative_matrix_scale_inst \>: + +creator +...................................... + +.. doxygenstruct:: tinytc::creator< cooperative_matrix_scale_inst > + +.. _tinytc::creator\< cooperative_matrix_store_inst \>: + +creator +...................................... + +.. doxygenstruct:: tinytc::creator< cooperative_matrix_store_inst > + +.. _tinytc::creator\< expand_inst \>: + +creator +.................... + +.. doxygenstruct:: tinytc::creator< expand_inst > + +.. _tinytc::creator\< fuse_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< fuse_inst > + +.. _tinytc::creator\< if_inst \>: + +creator +................ + +.. doxygenstruct:: tinytc::creator< if_inst > + +.. _tinytc::creator\< lifetime_stop_inst \>: + +creator +........................... + +.. doxygenstruct:: tinytc::creator< lifetime_stop_inst > + +.. _tinytc::creator\< load_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< load_inst > + +.. _tinytc::creator\< parallel_inst \>: + +creator +...................... + +.. doxygenstruct:: tinytc::creator< parallel_inst > + +.. _tinytc::creator\< size_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< size_inst > + +.. _tinytc::creator\< subgroup_broadcast_inst \>: + +creator +................................ + +.. doxygenstruct:: tinytc::creator< subgroup_broadcast_inst > + +.. _tinytc::creator\< subview_inst \>: + +creator +..................... + +.. doxygenstruct:: tinytc::creator< subview_inst > + +.. _tinytc::creator\< store_inst \>: + +creator +................... + +.. doxygenstruct:: tinytc::creator< store_inst > + +.. _tinytc::creator\< yield_inst \>: + +creator +................... + +.. doxygenstruct:: tinytc::creator< yield_inst > + +.. 
_tinytc::creator\< add_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< add_inst > + +.. _tinytc::creator\< sub_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< sub_inst > + +.. _tinytc::creator\< mul_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< mul_inst > + +.. _tinytc::creator\< div_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< div_inst > + +.. _tinytc::creator\< rem_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< rem_inst > + +.. _tinytc::creator\< shl_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< shl_inst > + +.. _tinytc::creator\< shr_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< shr_inst > + +.. _tinytc::creator\< and_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< and_inst > + +.. _tinytc::creator\< or_inst \>: + +creator +................ + +.. doxygenstruct:: tinytc::creator< or_inst > + +.. _tinytc::creator\< xor_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< xor_inst > + +.. _tinytc::creator\< min_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< min_inst > + +.. _tinytc::creator\< max_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< max_inst > + +.. _tinytc::creator\< abs_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< abs_inst > + +.. _tinytc::creator\< neg_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< neg_inst > + +.. _tinytc::creator\< not_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< not_inst > + +.. _tinytc::creator\< conj_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< conj_inst > + +.. _tinytc::creator\< im_inst \>: + +creator +................ + +.. 
doxygenstruct:: tinytc::creator< im_inst > + +.. _tinytc::creator\< re_inst \>: + +creator +................ + +.. doxygenstruct:: tinytc::creator< re_inst > + +.. _tinytc::creator\< axpby_inst \>: + +creator +................... + +.. doxygenstruct:: tinytc::creator< axpby_inst > + +.. _tinytc::creator\< cumsum_inst \>: + +creator +.................... + +.. doxygenstruct:: tinytc::creator< cumsum_inst > + +.. _tinytc::creator\< sum_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< sum_inst > + +.. _tinytc::creator\< gemm_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< gemm_inst > + +.. _tinytc::creator\< gemv_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< gemv_inst > + +.. _tinytc::creator\< ger_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< ger_inst > + +.. _tinytc::creator\< hadamard_inst \>: + +creator +...................... + +.. doxygenstruct:: tinytc::creator< hadamard_inst > + +.. _tinytc::creator\< group_id_inst \>: + +creator +...................... + +.. doxygenstruct:: tinytc::creator< group_id_inst > + +.. _tinytc::creator\< num_groups_inst \>: + +creator +........................ + +.. doxygenstruct:: tinytc::creator< num_groups_inst > + +.. _tinytc::creator\< num_subgroups_inst \>: + +creator +........................... + +.. doxygenstruct:: tinytc::creator< num_subgroups_inst > + +.. _tinytc::creator\< subgroup_size_inst \>: + +creator +........................... + +.. doxygenstruct:: tinytc::creator< subgroup_size_inst > + +.. _tinytc::creator\< subgroup_id_inst \>: + +creator +......................... + +.. doxygenstruct:: tinytc::creator< subgroup_id_inst > + +.. _tinytc::creator\< subgroup_linear_id_inst \>: + +creator +................................ + +.. doxygenstruct:: tinytc::creator< subgroup_linear_id_inst > + +.. _tinytc::creator\< subgroup_local_id_inst \>: + +creator +............................... + +.. 
doxygenstruct:: tinytc::creator< subgroup_local_id_inst > + +.. _tinytc::creator\< equal_inst \>: + +creator +................... + +.. doxygenstruct:: tinytc::creator< equal_inst > + +.. _tinytc::creator\< not_equal_inst \>: + +creator +....................... + +.. doxygenstruct:: tinytc::creator< not_equal_inst > + +.. _tinytc::creator\< greater_than_inst \>: + +creator +.......................... + +.. doxygenstruct:: tinytc::creator< greater_than_inst > + +.. _tinytc::creator\< greater_than_equal_inst \>: + +creator +................................ + +.. doxygenstruct:: tinytc::creator< greater_than_equal_inst > + +.. _tinytc::creator\< less_than_inst \>: + +creator +....................... + +.. doxygenstruct:: tinytc::creator< less_than_inst > + +.. _tinytc::creator\< less_than_equal_inst \>: + +creator +............................. + +.. doxygenstruct:: tinytc::creator< less_than_equal_inst > + +.. _tinytc::creator\< for_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< for_inst > + +.. _tinytc::creator\< foreach_inst \>: + +creator +..................... + +.. doxygenstruct:: tinytc::creator< foreach_inst > + +.. _tinytc::creator\< cos_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< cos_inst > + +.. _tinytc::creator\< sin_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< sin_inst > + +.. _tinytc::creator\< exp_inst \>: + +creator +................. + +.. doxygenstruct:: tinytc::creator< exp_inst > + +.. _tinytc::creator\< exp2_inst \>: + +creator +.................. + +.. doxygenstruct:: tinytc::creator< exp2_inst > + +.. _tinytc::creator\< native_cos_inst \>: + +creator +........................ + +.. doxygenstruct:: tinytc::creator< native_cos_inst > + +.. _tinytc::creator\< native_sin_inst \>: + +creator +........................ + +.. doxygenstruct:: tinytc::creator< native_sin_inst > + +.. _tinytc::creator\< native_exp_inst \>: + +creator +........................ 
+ +.. doxygenstruct:: tinytc::creator< native_exp_inst > + +.. _tinytc::creator\< native_exp2_inst \>: + +creator +......................... + +.. doxygenstruct:: tinytc::creator< native_exp2_inst > + +.. _tinytc::creator\< subgroup_exclusive_scan_add_inst \>: + +creator +......................................... + +.. doxygenstruct:: tinytc::creator< subgroup_exclusive_scan_add_inst > + +.. _tinytc::creator\< subgroup_exclusive_scan_max_inst \>: + +creator +......................................... + +.. doxygenstruct:: tinytc::creator< subgroup_exclusive_scan_max_inst > + +.. _tinytc::creator\< subgroup_exclusive_scan_min_inst \>: + +creator +......................................... + +.. doxygenstruct:: tinytc::creator< subgroup_exclusive_scan_min_inst > + +.. _tinytc::creator\< subgroup_inclusive_scan_add_inst \>: + +creator +......................................... + +.. doxygenstruct:: tinytc::creator< subgroup_inclusive_scan_add_inst > + +.. _tinytc::creator\< subgroup_inclusive_scan_max_inst \>: + +creator +......................................... + +.. doxygenstruct:: tinytc::creator< subgroup_inclusive_scan_max_inst > + +.. _tinytc::creator\< subgroup_inclusive_scan_min_inst \>: + +creator +......................................... + +.. doxygenstruct:: tinytc::creator< subgroup_inclusive_scan_min_inst > + +.. _tinytc::creator\< subgroup_reduce_add_inst \>: + +creator +................................. + +.. doxygenstruct:: tinytc::creator< subgroup_reduce_add_inst > + +.. _tinytc::creator\< subgroup_reduce_max_inst \>: + +creator +................................. + +.. doxygenstruct:: tinytc::creator< subgroup_reduce_max_inst > + +.. _tinytc::creator\< subgroup_reduce_min_inst \>: + +creator +................................. + +.. doxygenstruct:: tinytc::creator< subgroup_reduce_min_inst > + +Program +======= + +* Functions + + * :ref:`tinytc::add_function` + + * :ref:`tinytc::create_prog` + +Program Functions +----------------- + +.. 
_tinytc::add_function: + +add_function +............ + +.. doxygenfunction:: tinytc::add_function + +.. _tinytc::create_prog: + +create_prog +........... + +.. doxygenfunction:: tinytc::create_prog + +Recipe +====== + +* Functions + + * :ref:`tinytc::create_small_gemm_batched` + + * :ref:`tinytc::create_tall_and_skinny` + + * :ref:`tinytc::create_tall_and_skinny_specialized` + + * :ref:`tinytc::get_prog` + + * :ref:`tinytc::get_binary` + + * :ref:`tinytc::get_recipe` + + * :ref:`tinytc::set_small_gemm_batched_args` + + * :ref:`tinytc::set_tall_and_skinny_args` + + * :ref:`tinytc::to_string(mem_type)` + +Recipe Functions +---------------- + +.. _tinytc::create_small_gemm_batched: + +create_small_gemm_batched +......................... + +.. doxygenfunction:: tinytc::create_small_gemm_batched + +.. _tinytc::create_tall_and_skinny: + +create_tall_and_skinny +...................... + +.. doxygenfunction:: tinytc::create_tall_and_skinny + +.. _tinytc::create_tall_and_skinny_specialized: + +create_tall_and_skinny_specialized +.................................. + +.. doxygenfunction:: tinytc::create_tall_and_skinny_specialized + +.. _tinytc::get_prog: + +get_prog +........ + +.. doxygenfunction:: tinytc::get_prog + +.. _tinytc::get_binary: + +get_binary +.......... + +.. doxygenfunction:: tinytc::get_binary + +.. _tinytc::get_recipe: + +get_recipe +.......... + +.. doxygenfunction:: tinytc::get_recipe + +.. _tinytc::set_small_gemm_batched_args: + +set_small_gemm_batched_args +........................... + +.. doxygenfunction:: tinytc::set_small_gemm_batched_args + +.. _tinytc::set_tall_and_skinny_args: + +set_tall_and_skinny_args +........................ + +.. doxygenfunction:: tinytc::set_tall_and_skinny_args + +.. _tinytc::to_string(mem_type): + +to_string(mem_type) +................... + +.. 
doxygenfunction:: tinytc::to_string(mem_type) + +Region +====== + +* Functions + + * :ref:`tinytc::append` + + * :ref:`tinytc::begin` + + * :ref:`tinytc::end` + + * :ref:`tinytc::get_parameters` + + * :ref:`tinytc::insert` + + * :ref:`tinytc::next` + + * :ref:`tinytc::prev` + +* Classes + + * :ref:`tinytc::region_builder` + +Region Functions +---------------- + +.. _tinytc::append: + +append +...... + +.. doxygenfunction:: tinytc::append + +.. _tinytc::begin: + +begin +..... + +.. doxygenfunction:: tinytc::begin + +.. _tinytc::end: + +end +... + +.. doxygenfunction:: tinytc::end + +.. _tinytc::get_parameters: + +get_parameters +.............. + +.. doxygenfunction:: tinytc::get_parameters + +.. _tinytc::insert: + +insert +...... + +.. doxygenfunction:: tinytc::insert + +.. _tinytc::next: + +next +.... + +.. doxygenfunction:: tinytc::next + +.. _tinytc::prev: + +prev +.... + +.. doxygenfunction:: tinytc::prev + +Region Classes +-------------- + +.. _tinytc::region_builder: + +region_builder +.............. + +.. 
doxygenclass:: tinytc::region_builder diff --git a/docs/api/builder_cxxapi.yaml b/docs/api/builder_cxxapi.yaml index b2358d2a..b36bca9d 100644 --- a/docs/api/builder_cxxapi.yaml +++ b/docs/api/builder_cxxapi.yaml @@ -3,97 +3,178 @@ Builder C++-API: Common: enum: - - tinytc::arithmetic - - tinytc::arithmetic_unary - - tinytc::cmp_condition - - tinytc::scalar_type + - tinytc::address_space + - tinytc::checked_flag + - tinytc::comp3 + - tinytc::matrix_use + - tinytc::reduce_mode + - tinytc::store_flag - tinytc::transpose function: - tinytc::is_dynamic_value - - tinytc::to_string(arithmetic) - - tinytc::to_string(arithmetic_unary) - - tinytc::to_string(cmp_condition) - - tinytc::to_string(scalar_type) + - tinytc::to_string(address_space) + - tinytc::to_string(checked_flag) + - tinytc::to_string(comp3) + - tinytc::to_string(matrix_use) + - tinytc::to_string(reduce_mode) + - tinytc::to_string(store_flag) - tinytc::to_string(transpose) - - tinytc::size class: - tinytc::builder_error typedef: - - tinytc::position - tinytc::location + - tinytc::position variable: - tinytc::dynamic + Attribute: + function: + - get_dictionary_attr_with_sorted + - sort_items + struct: + - tinytc::getter< array_attr > + - tinytc::getter< boolean_attr > + - tinytc::getter< dictionary_attr > + - tinytc::getter< integer_attr > + - tinytc::getter< string_attr > Data Type: function: - - tinytc::make_memref - - tinytc::make_group - - tinytc::make_scalar - class: - - tinytc::data_type + - tinytc::get + - tinytc::get_compiler_context(const_tinytc_type_t) + - tinytc::to_type struct: - - tinytc::to_scalar_type - variable: - - tinytc::to_scalar_type_v + - tinytc::getter< boolean_type > + - tinytc::getter< i8_type > + - tinytc::getter< i16_type > + - tinytc::getter< i32_type > + - tinytc::getter< i64_type > + - tinytc::getter< index_type > + - tinytc::getter< f16_type > + - tinytc::getter< bf16_type > + - tinytc::getter< f32_type > + - tinytc::getter< f64_type > + - tinytc::getter< c32_type > + - 
tinytc::getter< c64_type > + - tinytc::getter< coopmatrix_type > + - tinytc::getter< group_type > + - tinytc::getter< memref_type > + - tinytc::getter< void_type > Function: function: - - tinytc::make_function - - tinytc::make_function_prototype - - tinytc::set_work_group_size - - tinytc::set_subgroup_size - class: - - tinytc::func - - tinytc::function_builder + - tinytc::create_func + - tinytc::get_body + - tinytc::set_attr(tinytc_func_t,tinytc_attr_t) + - tinytc::set_parameter_attr Instruction: function: - - tinytc::make_alloca - - tinytc::make_axpby - - tinytc::make_arith(arithmetic,value const&,value const&,location const&) - - tinytc::make_arith(arithmetic_unary,value const&,location const&) - - tinytc::make_cast - - tinytc::make_cmp - - tinytc::make_expand - - tinytc::make_for - - tinytc::make_foreach - - tinytc::make_fuse - - tinytc::make_gemm - - tinytc::make_gemv - - tinytc::make_ger - - tinytc::make_group_id - - tinytc::make_group_size - - tinytc::make_hadamard - - tinytc::make_if - - tinytc::make_load - - tinytc::make_size - - tinytc::make_store - - tinytc::make_subview - - tinytc::make_sum - - tinytc::make_yield - class: - - tinytc::inst + - tinytc::create + - tinytc::set_attr(tinytc_inst_t,tinytc_attr_t) + struct: + - tinytc::creator< alloca_inst > + - tinytc::creator< barrier_inst > + - tinytc::creator< cast_inst > + - tinytc::creator< constant_inst > + - tinytc::creator< cooperative_matrix_apply_inst > + - tinytc::creator< cooperative_matrix_extract_inst > + - tinytc::creator< cooperative_matrix_insert_inst > + - tinytc::creator< cooperative_matrix_load_inst > + - tinytc::creator< cooperative_matrix_mul_add_inst > + - tinytc::creator< cooperative_matrix_prefetch_inst > + - tinytc::creator< cooperative_matrix_reduce_add_inst > + - tinytc::creator< cooperative_matrix_reduce_max_inst > + - tinytc::creator< cooperative_matrix_reduce_min_inst > + - tinytc::creator< cooperative_matrix_scale_inst > + - tinytc::creator< cooperative_matrix_store_inst > + - 
tinytc::creator< expand_inst > + - tinytc::creator< fuse_inst > + - tinytc::creator< if_inst > + - tinytc::creator< lifetime_stop_inst > + - tinytc::creator< load_inst > + - tinytc::creator< parallel_inst > + - tinytc::creator< size_inst > + - tinytc::creator< subgroup_broadcast_inst > + - tinytc::creator< subview_inst > + - tinytc::creator< store_inst > + - tinytc::creator< yield_inst > + - tinytc::creator< add_inst > + - tinytc::creator< sub_inst > + - tinytc::creator< mul_inst > + - tinytc::creator< div_inst > + - tinytc::creator< rem_inst > + - tinytc::creator< shl_inst > + - tinytc::creator< shr_inst > + - tinytc::creator< and_inst > + - tinytc::creator< or_inst > + - tinytc::creator< xor_inst > + - tinytc::creator< min_inst > + - tinytc::creator< max_inst > + - tinytc::creator< abs_inst > + - tinytc::creator< neg_inst > + - tinytc::creator< not_inst > + - tinytc::creator< conj_inst > + - tinytc::creator< im_inst > + - tinytc::creator< re_inst > + - tinytc::creator< axpby_inst > + - tinytc::creator< cumsum_inst > + - tinytc::creator< sum_inst > + - tinytc::creator< gemm_inst > + - tinytc::creator< gemv_inst > + - tinytc::creator< ger_inst > + - tinytc::creator< hadamard_inst > + - tinytc::creator< group_id_inst > + - tinytc::creator< num_groups_inst > + - tinytc::creator< num_subgroups_inst > + - tinytc::creator< subgroup_size_inst > + - tinytc::creator< subgroup_id_inst > + - tinytc::creator< subgroup_linear_id_inst > + - tinytc::creator< subgroup_local_id_inst > + - tinytc::creator< equal_inst > + - tinytc::creator< not_equal_inst > + - tinytc::creator< greater_than_inst > + - tinytc::creator< greater_than_equal_inst > + - tinytc::creator< less_than_inst > + - tinytc::creator< less_than_equal_inst > + - tinytc::creator< for_inst > + - tinytc::creator< foreach_inst > + - tinytc::creator< cos_inst > + - tinytc::creator< sin_inst > + - tinytc::creator< exp_inst > + - tinytc::creator< exp2_inst > + - tinytc::creator< native_cos_inst > + - tinytc::creator< 
native_sin_inst > + - tinytc::creator< native_exp_inst > + - tinytc::creator< native_exp2_inst > + - tinytc::creator< subgroup_exclusive_scan_add_inst > + - tinytc::creator< subgroup_exclusive_scan_max_inst > + - tinytc::creator< subgroup_exclusive_scan_min_inst > + - tinytc::creator< subgroup_inclusive_scan_add_inst > + - tinytc::creator< subgroup_inclusive_scan_max_inst > + - tinytc::creator< subgroup_inclusive_scan_min_inst > + - tinytc::creator< subgroup_reduce_add_inst > + - tinytc::creator< subgroup_reduce_max_inst > + - tinytc::creator< subgroup_reduce_min_inst > Program: function: - - tinytc::make_program - class: - - tinytc::prog - - tinytc::program_builder + - tinytc::add_function + - tinytc::create_prog + Recipe: + function: + - tinytc::create_small_gemm_batched + - tinytc::create_tall_and_skinny + - tinytc::create_tall_and_skinny_specialized + - tinytc::get_prog + - tinytc::get_binary + - tinytc::get_recipe + - tinytc::set_small_gemm_batched_args + - tinytc::set_tall_and_skinny_args + - tinytc::to_string(mem_type) Region: function: - - tinytc::make_region + - tinytc::append + - tinytc::begin + - tinytc::end + - tinytc::get_parameters + - tinytc::insert + - tinytc::next + - tinytc::prev class: - - tinytc::region - tinytc::region_builder - Value: - function: - - tinytc::make_dynamic(location const&) - - tinytc::make_imm(float,location const&) - - tinytc::make_imm(double,scalar_type,location const&) - - tinytc::make_imm(std::int8_t,location const&) - - tinytc::make_imm(std::int16_t,location const&) - - tinytc::make_imm(std::int32_t,location const&) - - tinytc::make_imm(std::int64_t,scalar_type,location const&) - - tinytc::make_index(std::int32_t,location const&) - - tinytc::make_index(std::int64_t,location const&) - - tinytc::make_value(data_type const&,location const&) - - tinytc::make_value(scalar_type,location const&) - class: - - tinytc::value diff --git a/docs/api/cl/capi.rst b/docs/api/cl/capi.rst index f4d2d8a2..502129f3 100644 --- 
a/docs/api/cl/capi.rst +++ b/docs/api/cl/capi.rst @@ -1,25 +1,12 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _OpenCL C-API: + ===== C-API ===== -Common -====== - -* Functions - - * :ref:`tinytc_cl_convert_status` - -Common Functions ----------------- - -tinytc_cl_convert_status -........................ - -.. doxygenfunction:: tinytc_cl_convert_status - Device Info =========== @@ -32,11 +19,15 @@ Device Info Device Info Functions --------------------- +.. _tinytc_cl_core_info_create: + tinytc_cl_core_info_create .......................... .. doxygenfunction:: tinytc_cl_core_info_create +.. _tinytc_cl_get_support_level: + tinytc_cl_get_support_level ........................... @@ -51,8 +42,6 @@ Kernel * :ref:`tinytc_cl_get_group_size` - * :ref:`tinytc_cl_kernel_bundle_create_with_source` - * :ref:`tinytc_cl_kernel_bundle_create_with_program` * :ref:`tinytc_cl_kernel_bundle_create_with_binary` @@ -60,26 +49,29 @@ Kernel Kernel Functions ---------------- +.. _tinytc_cl_get_global_size: + tinytc_cl_get_global_size ......................... .. doxygenfunction:: tinytc_cl_get_global_size +.. _tinytc_cl_get_group_size: + tinytc_cl_get_group_size ........................ .. doxygenfunction:: tinytc_cl_get_group_size -tinytc_cl_kernel_bundle_create_with_source -.......................................... - -.. doxygenfunction:: tinytc_cl_kernel_bundle_create_with_source +.. _tinytc_cl_kernel_bundle_create_with_program: tinytc_cl_kernel_bundle_create_with_program ........................................... .. doxygenfunction:: tinytc_cl_kernel_bundle_create_with_program +.. _tinytc_cl_kernel_bundle_create_with_binary: + tinytc_cl_kernel_bundle_create_with_binary .......................................... @@ -97,11 +89,15 @@ Recipe Recipe Functions ---------------- +.. _tinytc_cl_recipe_handler_create: + tinytc_cl_recipe_handler_create ............................... .. doxygenfunction:: tinytc_cl_recipe_handler_create +.. 
_tinytc_cl_recipe_handler_submit: + tinytc_cl_recipe_handler_submit ............................... diff --git a/docs/api/cl/capi.yaml b/docs/api/cl/capi.yaml index 67ed7e61..66b436fa 100644 --- a/docs/api/cl/capi.yaml +++ b/docs/api/cl/capi.yaml @@ -1,9 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C-API: - Common: - function: - - tinytc_cl_convert_status +C-API : Device Info: function: - tinytc_cl_core_info_create @@ -12,7 +9,6 @@ C-API: function: - tinytc_cl_get_global_size - tinytc_cl_get_group_size - - tinytc_cl_kernel_bundle_create_with_source - tinytc_cl_kernel_bundle_create_with_program - tinytc_cl_kernel_bundle_create_with_binary Recipe: diff --git a/docs/api/cl/cxxapi.rst b/docs/api/cl/cxxapi.rst index 2245e883..31279e51 100644 --- a/docs/api/cl/cxxapi.rst +++ b/docs/api/cl/cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _OpenCL C++-API: + ======= C++-API ======= @@ -10,11 +12,13 @@ Common * Functions - * :ref:`CL_CHECK_STATUS` + * :ref:`tinytc::CL_CHECK_STATUS` Common Functions ---------------- +.. _tinytc::CL_CHECK_STATUS: + CL_CHECK_STATUS ............... @@ -25,18 +29,22 @@ Device Info * Functions - * :ref:`get_support_level(cl_device_id)` + * :ref:`tinytc::get_support_level(cl_device_id)` - * :ref:`make_core_info(cl_device_id)` + * :ref:`tinytc::make_core_info(cl_device_id)` Device Info Functions --------------------- +.. _tinytc::get_support_level(cl_device_id): + get_support_level(cl_device_id) ............................... .. doxygenfunction:: tinytc::get_support_level(cl_device_id) +.. _tinytc::make_core_info(cl_device_id): + make_core_info(cl_device_id) ............................ 
@@ -47,85 +55,98 @@ Kernel * Functions - * :ref:`get_global_size(std::int64_t,std::array\ const &)` - - * :ref:`get_group_size(cl_kernel)` + * :ref:`tinytc::get_global_size(std::array\ const &,std::array\ const &)` - * :ref:`make_kernel(cl_program,char const\\*)` + * :ref:`tinytc::get_group_size(cl_kernel)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context)` + * :ref:`tinytc::make_kernel(cl_program,char const\*)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`tinytc::make_kernel_bundle(cl_context,cl_device_id,const_tinytc_binary_t)` - * :ref:`make_kernel_bundle(cl_context,cl_device_id,source const&,source_context)` + * :ref:`tinytc::make_kernel_bundle(cl_context,cl_device_id,tinytc_prog_t,tinytc_core_feature_flags_t)` Kernel Functions ---------------- -get_global_size(std::int64_t,std::array const &) -................................................................. +.. _tinytc::get_global_size(std::array\ const &,std::array\ const &): + +get_global_size(std::array const &,std::array const &) +...................................................................................... -.. doxygenfunction:: tinytc::get_global_size(std::int64_t,std::array const &) +.. doxygenfunction:: tinytc::get_global_size(std::array const &,std::array const &) + +.. _tinytc::get_group_size(cl_kernel): get_group_size(cl_kernel) ......................... .. doxygenfunction:: tinytc::get_group_size(cl_kernel) +.. _tinytc::make_kernel(cl_program,char const\*): + make_kernel(cl_program,char const\*) .................................... .. doxygenfunction:: tinytc::make_kernel(cl_program,char const*) -make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) -........................................................................ +.. _tinytc::make_kernel_bundle(cl_context,cl_device_id,const_tinytc_binary_t): -.. 
doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) +make_kernel_bundle(cl_context,cl_device_id,const_tinytc_binary_t) +................................................................. -make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) -........................................................................................... +.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,const_tinytc_binary_t) -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) +.. _tinytc::make_kernel_bundle(cl_context,cl_device_id,tinytc_prog_t,tinytc_core_feature_flags_t): -make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) -........................................................................ +make_kernel_bundle(cl_context,cl_device_id,tinytc_prog_t,tinytc_core_feature_flags_t) +..................................................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(cl_context,cl_device_id,tinytc_prog_t,tinytc_core_feature_flags_t) Recipe ====== * Functions - * :ref:`make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context)` + * :ref:`tinytc::make_recipe_handler(cl_context,cl_device_id,tinytc_recipe_t)` -* Classes + * :ref:`tinytc::submit(tinytc_recipe_handler_t,cl_command_queue,uint32_t,cl_event\*)` - * :ref:`opencl_recipe_handler` + * :ref:`tinytc::submit_no_event` * Structures - * :ref:`auto_mem_type\` + * :ref:`tinytc::auto_mem_type\< cl_mem \>` Recipe Functions ---------------- -make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) -......................................................................... +.. _tinytc::make_recipe_handler(cl_context,cl_device_id,tinytc_recipe_t): -.. 
doxygenfunction:: tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) +make_recipe_handler(cl_context,cl_device_id,tinytc_recipe_t) +............................................................ -Recipe Classes --------------- +.. doxygenfunction:: tinytc::make_recipe_handler(cl_context,cl_device_id,tinytc_recipe_t) -opencl_recipe_handler -..................... +.. _tinytc::submit(tinytc_recipe_handler_t,cl_command_queue,uint32_t,cl_event\*): + +submit(tinytc_recipe_handler_t,cl_command_queue,uint32_t,cl_event\*) +.................................................................... + +.. doxygenfunction:: tinytc::submit(tinytc_recipe_handler_t,cl_command_queue,uint32_t,cl_event*) -.. doxygenclass:: tinytc::opencl_recipe_handler +.. _tinytc::submit_no_event: + +submit_no_event +............... + +.. doxygenfunction:: tinytc::submit_no_event Recipe Structures ----------------- +.. _tinytc::auto_mem_type\< cl_mem \>: + auto_mem_type ..................... diff --git a/docs/api/cl/cxxapi.yaml b/docs/api/cl/cxxapi.yaml index a61f974d..53273db6 100644 --- a/docs/api/cl/cxxapi.yaml +++ b/docs/api/cl/cxxapi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C++-API: +C++-API : Common: function: - tinytc::CL_CHECK_STATUS @@ -10,16 +10,15 @@ C++-API: - tinytc::make_core_info(cl_device_id) Kernel: function: - - tinytc::get_global_size(std::int64_t,std::array const &) + - tinytc::get_global_size(std::array const &,std::array const &) - tinytc::get_group_size(cl_kernel) - tinytc::make_kernel(cl_program,char const*) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,binary const&,source_context) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(cl_context,cl_device_id,source const&,source_context) + - tinytc::make_kernel_bundle(cl_context,cl_device_id,const_tinytc_binary_t) + - 
tinytc::make_kernel_bundle(cl_context,cl_device_id,tinytc_prog_t,tinytc_core_feature_flags_t) Recipe: function: - - tinytc::make_recipe_handler(cl_context,cl_device_id,recipe const&,source_context) - class: - - tinytc::opencl_recipe_handler + - tinytc::make_recipe_handler(cl_context,cl_device_id,tinytc_recipe_t) + - tinytc::submit(tinytc_recipe_handler_t,cl_command_queue,uint32_t,cl_event*) + - tinytc::submit_no_event struct: - tinytc::auto_mem_type< cl_mem > diff --git a/docs/api/core_capi.rst b/docs/api/core_capi.rst index 81c1901d..8915453d 100644 --- a/docs/api/core_capi.rst +++ b/docs/api/core_capi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Core C-API: + ========== Core C-API ========== @@ -30,10 +32,12 @@ Common * Functions - * :ref:`tinytc_error_string` + * :ref:`tinytc_status_to_string` * :ref:`tinytc_string_destroy` + * :ref:`tinytc_support_level_to_string` + * Typedefs * :ref:`tinytc_binary_t` @@ -42,34 +46,44 @@ Common * :ref:`tinytc_core_info_t` + * :ref:`tinytc_prog_t` + * :ref:`tinytc_recipe_t` * :ref:`tinytc_recipe_handler_t` - * :ref:`tinytc_source_t` + * :ref:`tinytc_spv_mod_t` - * :ref:`tinytc_source_context_t` + * :ref:`tinytc_compiler_context_t` * :ref:`const_tinytc_binary_t` * :ref:`const_tinytc_core_info_t` + * :ref:`const_tinytc_prog_t` + * :ref:`const_tinytc_recipe_t` * :ref:`const_tinytc_recipe_handler_t` - * :ref:`const_tinytc_source_t` + * :ref:`const_tinytc_spv_mod_t` + + * :ref:`const_tinytc_compiler_context_t` - * :ref:`const_tinytc_source_context_t` + * :ref:`tinytc_error_reporter_t` Common Enumerations ------------------- +.. _tinytc_status_t: + tinytc_status_t ............... .. doxygenenum:: tinytc_status_t +.. _tinytc_support_level_t: + tinytc_support_level_t ...................... @@ -78,31 +92,43 @@ tinytc_support_level_t Common Definitions ------------------ +.. _TINYTC_VERSION_MAJOR: + TINYTC_VERSION_MAJOR .................... .. 
doxygendefine:: TINYTC_VERSION_MAJOR +.. _TINYTC_VERSION_MINOR: + TINYTC_VERSION_MINOR .................... .. doxygendefine:: TINYTC_VERSION_MINOR +.. _TINYTC_VERSION_PATCH: + TINYTC_VERSION_PATCH .................... .. doxygendefine:: TINYTC_VERSION_PATCH +.. _TINYTC_VERSION_HASH: + TINYTC_VERSION_HASH ................... .. doxygendefine:: TINYTC_VERSION_HASH +.. _TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE: + TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE .............................................. .. doxygendefine:: TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE +.. _TINYTC_VERSION_DESCRIPTION: + TINYTC_VERSION_DESCRIPTION .......................... @@ -111,91 +137,155 @@ TINYTC_VERSION_DESCRIPTION Common Functions ---------------- -tinytc_error_string -................... +.. _tinytc_status_to_string: + +tinytc_status_to_string +....................... + +.. doxygenfunction:: tinytc_status_to_string -.. doxygenfunction:: tinytc_error_string +.. _tinytc_string_destroy: tinytc_string_destroy ..................... .. doxygenfunction:: tinytc_string_destroy +.. _tinytc_support_level_to_string: + +tinytc_support_level_to_string +.............................. + +.. doxygenfunction:: tinytc_support_level_to_string + Common Typedefs --------------- +.. _tinytc_binary_t: + tinytc_binary_t ............... .. doxygentypedef:: tinytc_binary_t +.. _tinytc_bool_t: + tinytc_bool_t ............. .. doxygentypedef:: tinytc_bool_t +.. _tinytc_core_info_t: + tinytc_core_info_t .................. .. doxygentypedef:: tinytc_core_info_t +.. _tinytc_prog_t: + +tinytc_prog_t +............. + +.. doxygentypedef:: tinytc_prog_t + +.. _tinytc_recipe_t: + tinytc_recipe_t ............... .. doxygentypedef:: tinytc_recipe_t +.. _tinytc_recipe_handler_t: + tinytc_recipe_handler_t ....................... .. doxygentypedef:: tinytc_recipe_handler_t -tinytc_source_t -............... +.. _tinytc_spv_mod_t: -.. doxygentypedef:: tinytc_source_t +tinytc_spv_mod_t +................ 
-tinytc_source_context_t -....................... +.. doxygentypedef:: tinytc_spv_mod_t + +.. _tinytc_compiler_context_t: + +tinytc_compiler_context_t +......................... -.. doxygentypedef:: tinytc_source_context_t +.. doxygentypedef:: tinytc_compiler_context_t + +.. _const_tinytc_binary_t: const_tinytc_binary_t ..................... .. doxygentypedef:: const_tinytc_binary_t +.. _const_tinytc_core_info_t: + const_tinytc_core_info_t ........................ .. doxygentypedef:: const_tinytc_core_info_t +.. _const_tinytc_prog_t: + +const_tinytc_prog_t +................... + +.. doxygentypedef:: const_tinytc_prog_t + +.. _const_tinytc_recipe_t: + const_tinytc_recipe_t ..................... .. doxygentypedef:: const_tinytc_recipe_t +.. _const_tinytc_recipe_handler_t: + const_tinytc_recipe_handler_t ............................. .. doxygentypedef:: const_tinytc_recipe_handler_t -const_tinytc_source_t -..................... +.. _const_tinytc_spv_mod_t: + +const_tinytc_spv_mod_t +...................... -.. doxygentypedef:: const_tinytc_source_t +.. doxygentypedef:: const_tinytc_spv_mod_t -const_tinytc_source_context_t -............................. +.. _const_tinytc_compiler_context_t: + +const_tinytc_compiler_context_t +............................... + +.. doxygentypedef:: const_tinytc_compiler_context_t -.. doxygentypedef:: const_tinytc_source_context_t +.. _tinytc_error_reporter_t: + +tinytc_error_reporter_t +....................... + +.. doxygentypedef:: tinytc_error_reporter_t Binary ====== +* Enumerations + + * :ref:`tinytc_bundle_format_t` + * Functions * :ref:`tinytc_binary_create` + * :ref:`tinytc_binary_get_compiler_context` + * :ref:`tinytc_binary_get_core_features` * :ref:`tinytc_binary_get_raw` @@ -204,60 +294,225 @@ Binary * :ref:`tinytc_binary_retain` + * :ref:`tinytc_optflag_to_string` + +Binary Enumerations +------------------- + +.. _tinytc_bundle_format_t: + +tinytc_bundle_format_t +...................... + +.. 
doxygenenum:: tinytc_bundle_format_t + Binary Functions ---------------- +.. _tinytc_binary_create: + tinytc_binary_create .................... .. doxygenfunction:: tinytc_binary_create +.. _tinytc_binary_get_compiler_context: + +tinytc_binary_get_compiler_context +.................................. + +.. doxygenfunction:: tinytc_binary_get_compiler_context + +.. _tinytc_binary_get_core_features: + tinytc_binary_get_core_features ............................... .. doxygenfunction:: tinytc_binary_get_core_features +.. _tinytc_binary_get_raw: + tinytc_binary_get_raw ..................... .. doxygenfunction:: tinytc_binary_get_raw +.. _tinytc_binary_release: + tinytc_binary_release ..................... .. doxygenfunction:: tinytc_binary_release +.. _tinytc_binary_retain: + tinytc_binary_retain .................... .. doxygenfunction:: tinytc_binary_retain +.. _tinytc_optflag_to_string: + +tinytc_optflag_to_string +........................ + +.. doxygenfunction:: tinytc_optflag_to_string + Compiler ======== * Enumerations - * :ref:`tinytc_bundle_format_t` + * :ref:`tinytc_optflag_t` * Functions - * :ref:`tinytc_prog_compile_to_opencl` + * :ref:`tinytc_bundle_format_to_string` + + * :ref:`tinytc_list_function_passes` + + * :ref:`tinytc_prog_compile_to_spirv` + + * :ref:`tinytc_prog_compile_to_spirv_and_assemble` + + * :ref:`tinytc_run_function_pass` + + * :ref:`tinytc_spirv_assemble` Compiler Enumerations --------------------- -tinytc_bundle_format_t -...................... +.. _tinytc_optflag_t: -.. doxygenenum:: tinytc_bundle_format_t +tinytc_optflag_t +................ + +.. doxygenenum:: tinytc_optflag_t Compiler Functions ------------------ -tinytc_prog_compile_to_opencl -............................. +.. _tinytc_bundle_format_to_string: + +tinytc_bundle_format_to_string +.............................. + +.. doxygenfunction:: tinytc_bundle_format_to_string + +.. _tinytc_list_function_passes: + +tinytc_list_function_passes +........................... + +.. 
doxygenfunction:: tinytc_list_function_passes + +.. _tinytc_prog_compile_to_spirv: + +tinytc_prog_compile_to_spirv +............................ + +.. doxygenfunction:: tinytc_prog_compile_to_spirv + +.. _tinytc_prog_compile_to_spirv_and_assemble: + +tinytc_prog_compile_to_spirv_and_assemble +......................................... + +.. doxygenfunction:: tinytc_prog_compile_to_spirv_and_assemble + +.. _tinytc_run_function_pass: + +tinytc_run_function_pass +........................ + +.. doxygenfunction:: tinytc_run_function_pass + +.. _tinytc_spirv_assemble: + +tinytc_spirv_assemble +..................... + +.. doxygenfunction:: tinytc_spirv_assemble + +Compiler Context +================ + +* Functions + + * :ref:`tinytc_compiler_context_create` + + * :ref:`tinytc_compiler_context_add_source` + + * :ref:`tinytc_compiler_context_set_error_reporter` + + * :ref:`tinytc_compiler_context_set_optimization_flag` + + * :ref:`tinytc_compiler_context_set_optimization_level` + + * :ref:`tinytc_compiler_context_report_error` + + * :ref:`tinytc_compiler_context_release` + + * :ref:`tinytc_compiler_context_retain` + +Compiler Context Functions +-------------------------- + +.. _tinytc_compiler_context_create: + +tinytc_compiler_context_create +.............................. + +.. doxygenfunction:: tinytc_compiler_context_create + +.. _tinytc_compiler_context_add_source: + +tinytc_compiler_context_add_source +.................................. -.. doxygenfunction:: tinytc_prog_compile_to_opencl +.. doxygenfunction:: tinytc_compiler_context_add_source + +.. _tinytc_compiler_context_set_error_reporter: + +tinytc_compiler_context_set_error_reporter +.......................................... + +.. doxygenfunction:: tinytc_compiler_context_set_error_reporter + +.. _tinytc_compiler_context_set_optimization_flag: + +tinytc_compiler_context_set_optimization_flag +............................................. + +.. doxygenfunction:: tinytc_compiler_context_set_optimization_flag + +.. 
_tinytc_compiler_context_set_optimization_level: + +tinytc_compiler_context_set_optimization_level +.............................................. + +.. doxygenfunction:: tinytc_compiler_context_set_optimization_level + +.. _tinytc_compiler_context_report_error: + +tinytc_compiler_context_report_error +.................................... + +.. doxygenfunction:: tinytc_compiler_context_report_error + +.. _tinytc_compiler_context_release: + +tinytc_compiler_context_release +............................... + +.. doxygenfunction:: tinytc_compiler_context_release + +.. _tinytc_compiler_context_retain: + +tinytc_compiler_context_retain +.............................. + +.. doxygenfunction:: tinytc_compiler_context_retain Device Info =========== @@ -268,26 +523,44 @@ Device Info * :ref:`tinytc_intel_gpu_architecture_t` + * :ref:`tinytc_spirv_feature_t` + * Functions + * :ref:`tinytc_core_feature_flag_to_string` + + * :ref:`tinytc_core_info_generic_create` + * :ref:`tinytc_core_info_get_core_features` + * :ref:`tinytc_core_info_get_default_alignment` + * :ref:`tinytc_core_info_get_register_space` * :ref:`tinytc_core_info_get_subgroup_sizes` - * :ref:`tinytc_core_info_set_core_features` - - * :ref:`tinytc_core_info_generic_create` + * :ref:`tinytc_core_info_have_spirv_feature` * :ref:`tinytc_core_info_intel_create` * :ref:`tinytc_core_info_intel_create_from_arch` + * :ref:`tinytc_core_info_intel_create_from_name` + * :ref:`tinytc_core_info_release` * :ref:`tinytc_core_info_retain` + * :ref:`tinytc_core_info_set_core_features` + + * :ref:`tinytc_core_info_set_default_alignment` + + * :ref:`tinytc_core_info_set_spirv_feature` + + * :ref:`tinytc_intel_gpu_architecture_to_string` + + * :ref:`tinytc_spirv_feature_to_string` + * Typedefs * :ref:`tinytc_core_feature_flags_t` @@ -295,72 +568,203 @@ Device Info Device Info Enumerations ------------------------ +.. _tinytc_core_feature_flag_t: + tinytc_core_feature_flag_t .......................... .. 
doxygenenum:: tinytc_core_feature_flag_t +.. _tinytc_intel_gpu_architecture_t: + tinytc_intel_gpu_architecture_t ............................... .. doxygenenum:: tinytc_intel_gpu_architecture_t +.. _tinytc_spirv_feature_t: + +tinytc_spirv_feature_t +...................... + +.. doxygenenum:: tinytc_spirv_feature_t + Device Info Functions --------------------- +.. _tinytc_core_feature_flag_to_string: + +tinytc_core_feature_flag_to_string +.................................. + +.. doxygenfunction:: tinytc_core_feature_flag_to_string + +.. _tinytc_core_info_generic_create: + +tinytc_core_info_generic_create +............................... + +.. doxygenfunction:: tinytc_core_info_generic_create + +.. _tinytc_core_info_get_core_features: + tinytc_core_info_get_core_features .................................. .. doxygenfunction:: tinytc_core_info_get_core_features +.. _tinytc_core_info_get_default_alignment: + +tinytc_core_info_get_default_alignment +...................................... + +.. doxygenfunction:: tinytc_core_info_get_default_alignment + +.. _tinytc_core_info_get_register_space: + tinytc_core_info_get_register_space ................................... .. doxygenfunction:: tinytc_core_info_get_register_space +.. _tinytc_core_info_get_subgroup_sizes: + tinytc_core_info_get_subgroup_sizes ................................... .. doxygenfunction:: tinytc_core_info_get_subgroup_sizes -tinytc_core_info_set_core_features -.................................. +.. _tinytc_core_info_have_spirv_feature: -.. doxygenfunction:: tinytc_core_info_set_core_features +tinytc_core_info_have_spirv_feature +................................... -tinytc_core_info_generic_create -............................... +.. doxygenfunction:: tinytc_core_info_have_spirv_feature -.. doxygenfunction:: tinytc_core_info_generic_create +.. _tinytc_core_info_intel_create: tinytc_core_info_intel_create ............................. .. doxygenfunction:: tinytc_core_info_intel_create +.. 
_tinytc_core_info_intel_create_from_arch: + tinytc_core_info_intel_create_from_arch ....................................... .. doxygenfunction:: tinytc_core_info_intel_create_from_arch +.. _tinytc_core_info_intel_create_from_name: + +tinytc_core_info_intel_create_from_name +....................................... + +.. doxygenfunction:: tinytc_core_info_intel_create_from_name + +.. _tinytc_core_info_release: + tinytc_core_info_release ........................ .. doxygenfunction:: tinytc_core_info_release +.. _tinytc_core_info_retain: + tinytc_core_info_retain ....................... .. doxygenfunction:: tinytc_core_info_retain +.. _tinytc_core_info_set_core_features: + +tinytc_core_info_set_core_features +.................................. + +.. doxygenfunction:: tinytc_core_info_set_core_features + +.. _tinytc_core_info_set_default_alignment: + +tinytc_core_info_set_default_alignment +...................................... + +.. doxygenfunction:: tinytc_core_info_set_default_alignment + +.. _tinytc_core_info_set_spirv_feature: + +tinytc_core_info_set_spirv_feature +.................................. + +.. doxygenfunction:: tinytc_core_info_set_spirv_feature + +.. _tinytc_intel_gpu_architecture_to_string: + +tinytc_intel_gpu_architecture_to_string +....................................... + +.. doxygenfunction:: tinytc_intel_gpu_architecture_to_string + +.. _tinytc_spirv_feature_to_string: + +tinytc_spirv_feature_to_string +.............................. + +.. doxygenfunction:: tinytc_spirv_feature_to_string + Device Info Typedefs -------------------- +.. _tinytc_core_feature_flags_t: + tinytc_core_feature_flags_t ........................... .. doxygentypedef:: tinytc_core_feature_flags_t +FP math +======= + +* Functions + + * :ref:`tinytc_f32_to_bf16_as_ui16` + + * :ref:`tinytc_f32_to_f16_as_ui16` + + * :ref:`tinytc_f16_as_ui16_to_f32` + + * :ref:`tinytc_bf16_as_ui16_to_f32` + +FP math Functions +----------------- + +.. 
_tinytc_f32_to_bf16_as_ui16: + +tinytc_f32_to_bf16_as_ui16 +.......................... + +.. doxygenfunction:: tinytc_f32_to_bf16_as_ui16 + +.. _tinytc_f32_to_f16_as_ui16: + +tinytc_f32_to_f16_as_ui16 +......................... + +.. doxygenfunction:: tinytc_f32_to_f16_as_ui16 + +.. _tinytc_f16_as_ui16_to_f32: + +tinytc_f16_as_ui16_to_f32 +......................... + +.. doxygenfunction:: tinytc_f16_as_ui16_to_f32 + +.. _tinytc_bf16_as_ui16_to_f32: + +tinytc_bf16_as_ui16_to_f32 +.......................... + +.. doxygenfunction:: tinytc_bf16_as_ui16_to_f32 + Parser ====== @@ -375,21 +779,89 @@ Parser Parser Functions ---------------- +.. _tinytc_parse_file: + tinytc_parse_file ................. .. doxygenfunction:: tinytc_parse_file +.. _tinytc_parse_stdin: + tinytc_parse_stdin .................. .. doxygenfunction:: tinytc_parse_stdin +.. _tinytc_parse_string: + tinytc_parse_string ................... .. doxygenfunction:: tinytc_parse_string +Program +======= + +* Functions + + * :ref:`tinytc_prog_dump` + + * :ref:`tinytc_prog_get_compiler_context` + + * :ref:`tinytc_prog_print_to_file` + + * :ref:`tinytc_prog_print_to_string` + + * :ref:`tinytc_prog_release` + + * :ref:`tinytc_prog_retain` + +Program Functions +----------------- + +.. _tinytc_prog_dump: + +tinytc_prog_dump +................ + +.. doxygenfunction:: tinytc_prog_dump + +.. _tinytc_prog_get_compiler_context: + +tinytc_prog_get_compiler_context +................................ + +.. doxygenfunction:: tinytc_prog_get_compiler_context + +.. _tinytc_prog_print_to_file: + +tinytc_prog_print_to_file +......................... + +.. doxygenfunction:: tinytc_prog_print_to_file + +.. _tinytc_prog_print_to_string: + +tinytc_prog_print_to_string +........................... + +.. doxygenfunction:: tinytc_prog_print_to_string + +.. _tinytc_prog_release: + +tinytc_prog_release +................... + +.. doxygenfunction:: tinytc_prog_release + +.. _tinytc_prog_retain: + +tinytc_prog_retain +.................. 
+ +.. doxygenfunction:: tinytc_prog_retain + Recipe ====== @@ -399,9 +871,11 @@ Recipe * Functions - * :ref:`tinytc_recipe_get_prog` + * :ref:`tinytc_mem_type_to_string` - * :ref:`tinytc_recipe_get_source` + * :ref:`tinytc_recipe_get_binary` + + * :ref:`tinytc_recipe_get_prog` * :ref:`tinytc_recipe_handler_get_recipe` @@ -428,6 +902,8 @@ Recipe Recipe Enumerations ------------------- +.. _tinytc_mem_type_t: + tinytc_mem_type_t ................. @@ -436,168 +912,154 @@ tinytc_mem_type_t Recipe Functions ---------------- +.. _tinytc_mem_type_to_string: + +tinytc_mem_type_to_string +......................... + +.. doxygenfunction:: tinytc_mem_type_to_string + +.. _tinytc_recipe_get_binary: + +tinytc_recipe_get_binary +........................ + +.. doxygenfunction:: tinytc_recipe_get_binary + +.. _tinytc_recipe_get_prog: + tinytc_recipe_get_prog ...................... .. doxygenfunction:: tinytc_recipe_get_prog -tinytc_recipe_get_source -........................ - -.. doxygenfunction:: tinytc_recipe_get_source +.. _tinytc_recipe_handler_get_recipe: tinytc_recipe_handler_get_recipe ................................ .. doxygenfunction:: tinytc_recipe_handler_get_recipe +.. _tinytc_recipe_small_gemm_batched_create: + tinytc_recipe_small_gemm_batched_create ....................................... .. doxygenfunction:: tinytc_recipe_small_gemm_batched_create +.. _tinytc_recipe_small_gemm_batched_set_args: + tinytc_recipe_small_gemm_batched_set_args ......................................... .. doxygenfunction:: tinytc_recipe_small_gemm_batched_set_args +.. _tinytc_recipe_tall_and_skinny_create: + tinytc_recipe_tall_and_skinny_create .................................... .. doxygenfunction:: tinytc_recipe_tall_and_skinny_create +.. _tinytc_recipe_tall_and_skinny_create_specialized: + tinytc_recipe_tall_and_skinny_create_specialized ................................................ .. doxygenfunction:: tinytc_recipe_tall_and_skinny_create_specialized +.. 
_tinytc_recipe_tall_and_skinny_set_args: + tinytc_recipe_tall_and_skinny_set_args ...................................... .. doxygenfunction:: tinytc_recipe_tall_and_skinny_set_args +.. _tinytc_recipe_tall_and_skinny_suggest_block_size: + tinytc_recipe_tall_and_skinny_suggest_block_size ................................................ .. doxygenfunction:: tinytc_recipe_tall_and_skinny_suggest_block_size +.. _tinytc_recipe_release: + tinytc_recipe_release ..................... .. doxygenfunction:: tinytc_recipe_release +.. _tinytc_recipe_retain: + tinytc_recipe_retain .................... .. doxygenfunction:: tinytc_recipe_retain +.. _tinytc_recipe_handler_release: + tinytc_recipe_handler_release ............................. .. doxygenfunction:: tinytc_recipe_handler_release +.. _tinytc_recipe_handler_retain: + tinytc_recipe_handler_retain ............................ .. doxygenfunction:: tinytc_recipe_handler_retain -Source -====== +SPIR-V module +============= * Functions - * :ref:`tinytc_source_get_code` - - * :ref:`tinytc_source_get_core_features` - - * :ref:`tinytc_source_get_location` - - * :ref:`tinytc_source_get_extensions` - - * :ref:`tinytc_source_release` - - * :ref:`tinytc_source_retain` - -Source Functions ----------------- - -tinytc_source_get_code -...................... - -.. doxygenfunction:: tinytc_source_get_code - -tinytc_source_get_core_features -............................... - -.. doxygenfunction:: tinytc_source_get_core_features - -tinytc_source_get_location -.......................... - -.. doxygenfunction:: tinytc_source_get_location - -tinytc_source_get_extensions -............................ - -.. doxygenfunction:: tinytc_source_get_extensions - -tinytc_source_release -..................... + * :ref:`tinytc_spv_mod_dump` -.. doxygenfunction:: tinytc_source_release + * :ref:`tinytc_spv_mod_print_to_file` -tinytc_source_retain -.................... - -.. 
doxygenfunction:: tinytc_source_retain - -Source Context -============== - -* Functions + * :ref:`tinytc_spv_mod_print_to_string` - * :ref:`tinytc_source_context_create` + * :ref:`tinytc_spv_mod_release` - * :ref:`tinytc_source_context_add_source` + * :ref:`tinytc_spv_mod_retain` - * :ref:`tinytc_source_context_get_error_log` +SPIR-V module Functions +----------------------- - * :ref:`tinytc_source_context_report_error` +.. _tinytc_spv_mod_dump: - * :ref:`tinytc_source_context_release` +tinytc_spv_mod_dump +................... - * :ref:`tinytc_source_context_retain` +.. doxygenfunction:: tinytc_spv_mod_dump -Source Context Functions ------------------------- +.. _tinytc_spv_mod_print_to_file: -tinytc_source_context_create +tinytc_spv_mod_print_to_file ............................ -.. doxygenfunction:: tinytc_source_context_create +.. doxygenfunction:: tinytc_spv_mod_print_to_file -tinytc_source_context_add_source -................................ +.. _tinytc_spv_mod_print_to_string: -.. doxygenfunction:: tinytc_source_context_add_source - -tinytc_source_context_get_error_log -................................... +tinytc_spv_mod_print_to_string +.............................. -.. doxygenfunction:: tinytc_source_context_get_error_log +.. doxygenfunction:: tinytc_spv_mod_print_to_string -tinytc_source_context_report_error -.................................. +.. _tinytc_spv_mod_release: -.. doxygenfunction:: tinytc_source_context_report_error +tinytc_spv_mod_release +...................... -tinytc_source_context_release -............................. +.. doxygenfunction:: tinytc_spv_mod_release -.. doxygenfunction:: tinytc_source_context_release +.. _tinytc_spv_mod_retain: -tinytc_source_context_retain -............................ +tinytc_spv_mod_retain +..................... -.. doxygenfunction:: tinytc_source_context_retain +.. 
doxygenfunction:: tinytc_spv_mod_retain diff --git a/docs/api/core_capi.yaml b/docs/api/core_capi.yaml index 0a2e97c6..9140e862 100644 --- a/docs/api/core_capi.yaml +++ b/docs/api/core_capi.yaml @@ -13,61 +13,108 @@ Core C-API: - TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE - TINYTC_VERSION_DESCRIPTION function: - - tinytc_error_string + - tinytc_status_to_string - tinytc_string_destroy + - tinytc_support_level_to_string typedef: - tinytc_binary_t - tinytc_bool_t - tinytc_core_info_t + - tinytc_prog_t - tinytc_recipe_t - tinytc_recipe_handler_t - - tinytc_source_t - - tinytc_source_context_t + - tinytc_spv_mod_t + - tinytc_compiler_context_t - const_tinytc_binary_t - const_tinytc_core_info_t + - const_tinytc_prog_t - const_tinytc_recipe_t - const_tinytc_recipe_handler_t - - const_tinytc_source_t - - const_tinytc_source_context_t + - const_tinytc_spv_mod_t + - const_tinytc_compiler_context_t + - tinytc_error_reporter_t Binary: + enum: + - tinytc_bundle_format_t function: - tinytc_binary_create + - tinytc_binary_get_compiler_context - tinytc_binary_get_core_features - tinytc_binary_get_raw - tinytc_binary_release - tinytc_binary_retain + - tinytc_optflag_to_string Compiler: enum: - - tinytc_bundle_format_t + - tinytc_optflag_t function: - - tinytc_prog_compile_to_opencl + - tinytc_bundle_format_to_string + - tinytc_list_function_passes + - tinytc_prog_compile_to_spirv + - tinytc_prog_compile_to_spirv_and_assemble + - tinytc_run_function_pass + - tinytc_spirv_assemble + Compiler Context: + function: + - tinytc_compiler_context_create + - tinytc_compiler_context_add_source + - tinytc_compiler_context_set_error_reporter + - tinytc_compiler_context_set_optimization_flag + - tinytc_compiler_context_set_optimization_level + - tinytc_compiler_context_report_error + - tinytc_compiler_context_release + - tinytc_compiler_context_retain Device Info: enum: - tinytc_core_feature_flag_t - tinytc_intel_gpu_architecture_t + - tinytc_spirv_feature_t function: + - 
tinytc_core_feature_flag_to_string + - tinytc_core_info_generic_create - tinytc_core_info_get_core_features + - tinytc_core_info_get_default_alignment - tinytc_core_info_get_register_space - tinytc_core_info_get_subgroup_sizes - - tinytc_core_info_set_core_features - - tinytc_core_info_generic_create + - tinytc_core_info_have_spirv_feature - tinytc_core_info_intel_create - tinytc_core_info_intel_create_from_arch + - tinytc_core_info_intel_create_from_name - tinytc_core_info_release - tinytc_core_info_retain + - tinytc_core_info_set_core_features + - tinytc_core_info_set_default_alignment + - tinytc_core_info_set_spirv_feature + - tinytc_intel_gpu_architecture_to_string + - tinytc_spirv_feature_to_string typedef: - tinytc_core_feature_flags_t + FP math: + function: + - tinytc_f32_to_bf16_as_ui16 + - tinytc_f32_to_f16_as_ui16 + - tinytc_f16_as_ui16_to_f32 + - tinytc_bf16_as_ui16_to_f32 Parser: function: - tinytc_parse_file - tinytc_parse_stdin - tinytc_parse_string + Program: + function: + - tinytc_prog_dump + - tinytc_prog_get_compiler_context + - tinytc_prog_print_to_file + - tinytc_prog_print_to_string + - tinytc_prog_release + - tinytc_prog_retain Recipe: enum: - tinytc_mem_type_t function: + - tinytc_mem_type_to_string + - tinytc_recipe_get_binary - tinytc_recipe_get_prog - - tinytc_recipe_get_source - tinytc_recipe_handler_get_recipe - tinytc_recipe_small_gemm_batched_create - tinytc_recipe_small_gemm_batched_set_args @@ -79,19 +126,10 @@ Core C-API: - tinytc_recipe_retain - tinytc_recipe_handler_release - tinytc_recipe_handler_retain - Source: - function: - - tinytc_source_get_code - - tinytc_source_get_core_features - - tinytc_source_get_location - - tinytc_source_get_extensions - - tinytc_source_release - - tinytc_source_retain - Source Context: + SPIR-V module: function: - - tinytc_source_context_create - - tinytc_source_context_add_source - - tinytc_source_context_get_error_log - - tinytc_source_context_report_error - - tinytc_source_context_release - - 
tinytc_source_context_retain + - tinytc_spv_mod_dump + - tinytc_spv_mod_print_to_file + - tinytc_spv_mod_print_to_string + - tinytc_spv_mod_release + - tinytc_spv_mod_retain diff --git a/docs/api/core_cxxapi.rst b/docs/api/core_cxxapi.rst index ba68c04e..fc5e9d83 100644 --- a/docs/api/core_cxxapi.rst +++ b/docs/api/core_cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Core C++-API: + ============ Core C++-API ============ @@ -10,32 +12,69 @@ Common * Enumerations - * :ref:`status` + * :ref:`tinytc::mem_type` + + * :ref:`tinytc::status` - * :ref:`support_level` + * :ref:`tinytc::support_level` * Functions - * :ref:`error_string` + * :ref:`tinytc::CHECK_STATUS` + + * :ref:`tinytc::CHECK_STATUS_LOC` - * :ref:`CHECK_STATUS` + * :ref:`tinytc::to_string(status)` - * :ref:`CHECK_STATUS_LOC` + * :ref:`tinytc::to_string(support_level)` * Classes - * :ref:`shared_handle` + * :ref:`tinytc::array_view_base` + + * :ref:`tinytc::array_view` + + * :ref:`tinytc::mutable_array_view` + + * :ref:`tinytc::shared_handle` + + * :ref:`tinytc::unique_handle` + +* Structures + + * :ref:`tinytc::auto_mem_type` + + * :ref:`tinytc::auto_mem_type\< T, std::enable_if_t\< is_usm_pointer_type\< T \> \> \>` - * :ref:`unique_handle` + * :ref:`tinytc::mem` + +* Variables + + * :ref:`tinytc::auto_mem_type_v` + + * :ref:`tinytc::is_supported_scalar_type` + + * :ref:`tinytc::is_usm_pointer_type` Common Enumerations ------------------- +.. _tinytc::mem_type: + +mem_type +........ + +.. doxygenenum:: tinytc::mem_type + +.. _tinytc::status: + status ...... .. doxygenenum:: tinytc::status +.. _tinytc::support_level: + support_level ............. @@ -44,52 +83,148 @@ support_level Common Functions ---------------- -error_string -............ - -.. doxygenfunction:: tinytc::error_string +.. _tinytc::CHECK_STATUS: CHECK_STATUS ............ .. doxygenfunction:: tinytc::CHECK_STATUS +.. _tinytc::CHECK_STATUS_LOC: + CHECK_STATUS_LOC ................ .. 
doxygenfunction:: tinytc::CHECK_STATUS_LOC +.. _tinytc::to_string(status): + +to_string(status) +................. + +.. doxygenfunction:: tinytc::to_string(status) + +.. _tinytc::to_string(support_level): + +to_string(support_level) +........................ + +.. doxygenfunction:: tinytc::to_string(support_level) + Common Classes -------------- +.. _tinytc::array_view_base: + +array_view_base +............... + +.. doxygenclass:: tinytc::array_view_base + +.. _tinytc::array_view: + +array_view +.......... + +.. doxygenclass:: tinytc::array_view + +.. _tinytc::mutable_array_view: + +mutable_array_view +.................. + +.. doxygenclass:: tinytc::mutable_array_view + +.. _tinytc::shared_handle: + shared_handle ............. .. doxygenclass:: tinytc::shared_handle +.. _tinytc::unique_handle: + unique_handle ............. .. doxygenclass:: tinytc::unique_handle +Common Structures +----------------- + +.. _tinytc::auto_mem_type: + +auto_mem_type +............. + +.. doxygenstruct:: tinytc::auto_mem_type + +.. _tinytc::auto_mem_type\< T, std::enable_if_t\< is_usm_pointer_type\< T \> \> \>: + +auto_mem_type<T, std::enable_if_t<is_usm_pointer_type<T>>> +.......................................................... + +.. doxygenstruct:: tinytc::auto_mem_type< T, std::enable_if_t< is_usm_pointer_type< T > > > + +.. _tinytc::mem: + +mem +... + +.. doxygenstruct:: tinytc::mem + +Common Variables +---------------- + +.. _tinytc::auto_mem_type_v: + +auto_mem_type_v +............... + +.. doxygenvariable:: tinytc::auto_mem_type_v + +.. _tinytc::is_supported_scalar_type: + +is_supported_scalar_type +........................ + +.. doxygenvariable:: tinytc::is_supported_scalar_type + +.. _tinytc::is_usm_pointer_type: + +is_usm_pointer_type +................... + +.. 
doxygenvariable:: tinytc::is_usm_pointer_type + Binary ====== * Enumerations - * :ref:`bundle_format` + * :ref:`tinytc::bundle_format` * Functions - * :ref:`make_binary` + * :ref:`tinytc::create_binary` -* Classes + * :ref:`tinytc::get_compiler_context(const_tinytc_binary_t)` - * :ref:`binary` + * :ref:`tinytc::get_core_features(const_tinytc_binary_t)` + + * :ref:`tinytc::get_raw` + + * :ref:`tinytc::to_string(bundle_format)` + +* Structures + + * :ref:`tinytc::raw_binary` Binary Enumerations ------------------- +.. _tinytc::bundle_format: + bundle_format ............. @@ -98,281 +233,570 @@ bundle_format Binary Functions ---------------- -make_binary -........... +.. _tinytc::create_binary: -.. doxygenfunction:: tinytc::make_binary +create_binary +............. -Binary Classes --------------- +.. doxygenfunction:: tinytc::create_binary -binary -...... +.. _tinytc::get_compiler_context(const_tinytc_binary_t): + +get_compiler_context(const_tinytc_binary_t) +........................................... -.. doxygenclass:: tinytc::binary +.. doxygenfunction:: tinytc::get_compiler_context(const_tinytc_binary_t) + +.. _tinytc::get_core_features(const_tinytc_binary_t): + +get_core_features(const_tinytc_binary_t) +........................................ + +.. doxygenfunction:: tinytc::get_core_features(const_tinytc_binary_t) + +.. _tinytc::get_raw: + +get_raw +....... + +.. doxygenfunction:: tinytc::get_raw + +.. _tinytc::to_string(bundle_format): + +to_string(bundle_format) +........................ + +.. doxygenfunction:: tinytc::to_string(bundle_format) + +Binary Structures +----------------- + +.. _tinytc::raw_binary: + +raw_binary +.......... + +.. 
doxygenstruct:: tinytc::raw_binary Compiler ======== * Functions - * :ref:`compile_to_opencl` + * :ref:`tinytc::run_function_pass` + + * :ref:`tinytc::list_function_passes` + + * :ref:`tinytc::compile_to_spirv` + + * :ref:`tinytc::compile_to_spirv_and_assemble` + + * :ref:`tinytc::spirv_assemble` Compiler Functions ------------------ -compile_to_opencl +.. _tinytc::run_function_pass: + +run_function_pass ................. -.. doxygenfunction:: tinytc::compile_to_opencl +.. doxygenfunction:: tinytc::run_function_pass + +.. _tinytc::list_function_passes: + +list_function_passes +.................... + +.. doxygenfunction:: tinytc::list_function_passes + +.. _tinytc::compile_to_spirv: + +compile_to_spirv +................ + +.. doxygenfunction:: tinytc::compile_to_spirv + +.. _tinytc::compile_to_spirv_and_assemble: + +compile_to_spirv_and_assemble +............................. + +.. doxygenfunction:: tinytc::compile_to_spirv_and_assemble + +.. _tinytc::spirv_assemble: + +spirv_assemble +.............. + +.. doxygenfunction:: tinytc::spirv_assemble + +Compiler Context +================ + +* Enumerations + + * :ref:`tinytc::optflag` + +* Functions + + * :ref:`tinytc::add_source` + + * :ref:`tinytc::create_compiler_context` + + * :ref:`tinytc::set_error_reporter` + + * :ref:`tinytc::set_optimization_flag` + + * :ref:`tinytc::set_optimization_level` + + * :ref:`tinytc::report_error` + + * :ref:`tinytc::to_string(optflag)` + +Compiler Context Enumerations +----------------------------- + +.. _tinytc::optflag: + +optflag +....... + +.. doxygenenum:: tinytc::optflag + +Compiler Context Functions +-------------------------- + +.. _tinytc::add_source: + +add_source +.......... + +.. doxygenfunction:: tinytc::add_source + +.. _tinytc::create_compiler_context: + +create_compiler_context +....................... + +.. doxygenfunction:: tinytc::create_compiler_context + +.. _tinytc::set_error_reporter: + +set_error_reporter +.................. + +.. 
doxygenfunction:: tinytc::set_error_reporter + +.. _tinytc::set_optimization_flag: + +set_optimization_flag +..................... + +.. doxygenfunction:: tinytc::set_optimization_flag + +.. _tinytc::set_optimization_level: + +set_optimization_level +...................... + +.. doxygenfunction:: tinytc::set_optimization_level + +.. _tinytc::report_error: + +report_error +............ + +.. doxygenfunction:: tinytc::report_error + +.. _tinytc::to_string(optflag): + +to_string(optflag) +.................. + +.. doxygenfunction:: tinytc::to_string(optflag) Device Info =========== * Enumerations - * :ref:`core_feature_flag` + * :ref:`tinytc::core_feature_flag` + + * :ref:`tinytc::intel_gpu_architecture` - * :ref:`intel_gpu_architecture` + * :ref:`tinytc::spirv_feature` * Functions - * :ref:`make_core_info_generic` + * :ref:`tinytc::create_core_info_generic` - * :ref:`make_core_info_intel` + * :ref:`tinytc::create_core_info_intel` - * :ref:`make_core_info_intel_from_arch` + * :ref:`tinytc::create_core_info_intel_from_arch` -* Classes + * :ref:`tinytc::create_core_info_intel_from_name` + + * :ref:`tinytc::get_core_features(const_tinytc_core_info_t)` + + * :ref:`tinytc::get_subgroup_sizes` + + * :ref:`tinytc::get_register_space` + + * :ref:`tinytc::have_spirv_feature` + + * :ref:`tinytc::set_core_features` + + * :ref:`tinytc::set_default_alignment` + + * :ref:`tinytc::set_spirv_feature` + + * :ref:`tinytc::to_string(core_feature_flag)` + + * :ref:`tinytc::to_string(intel_gpu_architecture)` - * :ref:`core_info` + * :ref:`tinytc::to_string(spirv_feature)` Device Info Enumerations ------------------------ +.. _tinytc::core_feature_flag: + core_feature_flag ................. .. doxygenenum:: tinytc::core_feature_flag +.. _tinytc::intel_gpu_architecture: + intel_gpu_architecture ...................... .. doxygenenum:: tinytc::intel_gpu_architecture +.. _tinytc::spirv_feature: + +spirv_feature +............. + +.. 
doxygenenum:: tinytc::spirv_feature + Device Info Functions --------------------- -make_core_info_generic +.. _tinytc::create_core_info_generic: + +create_core_info_generic +........................ + +.. doxygenfunction:: tinytc::create_core_info_generic + +.. _tinytc::create_core_info_intel: + +create_core_info_intel ...................... -.. doxygenfunction:: tinytc::make_core_info_generic +.. doxygenfunction:: tinytc::create_core_info_intel -make_core_info_intel -.................... +.. _tinytc::create_core_info_intel_from_arch: -.. doxygenfunction:: tinytc::make_core_info_intel +create_core_info_intel_from_arch +................................ -make_core_info_intel_from_arch -.............................. +.. doxygenfunction:: tinytc::create_core_info_intel_from_arch -.. doxygenfunction:: tinytc::make_core_info_intel_from_arch +.. _tinytc::create_core_info_intel_from_name: -Device Info Classes -------------------- +create_core_info_intel_from_name +................................ -core_info -......... +.. doxygenfunction:: tinytc::create_core_info_intel_from_name -.. doxygenclass:: tinytc::core_info +.. _tinytc::get_core_features(const_tinytc_core_info_t): -Parser -====== +get_core_features(const_tinytc_core_info_t) +........................................... -* Functions +.. doxygenfunction:: tinytc::get_core_features(const_tinytc_core_info_t) - * :ref:`parse_file` +.. _tinytc::get_subgroup_sizes: - * :ref:`parse_stdin` +get_subgroup_sizes +.................. - * :ref:`parse_string` +.. doxygenfunction:: tinytc::get_subgroup_sizes -Parser Functions ----------------- +.. _tinytc::get_register_space: -parse_file -.......... +get_register_space +.................. -.. doxygenfunction:: tinytc::parse_file +.. doxygenfunction:: tinytc::get_register_space -parse_stdin -........... +.. _tinytc::have_spirv_feature: -.. doxygenfunction:: tinytc::parse_stdin +have_spirv_feature +.................. -parse_string -............ +.. 
doxygenfunction:: tinytc::have_spirv_feature -.. doxygenfunction:: tinytc::parse_string +.. _tinytc::set_core_features: -Recipe -====== +set_core_features +................. -* Enumerations +.. doxygenfunction:: tinytc::set_core_features - * :ref:`mem_type` +.. _tinytc::set_default_alignment: -* Functions +set_default_alignment +..................... - * :ref:`make_small_gemm_batched` +.. doxygenfunction:: tinytc::set_default_alignment - * :ref:`make_tall_and_skinny` +.. _tinytc::set_spirv_feature: - * :ref:`make_tall_and_skinny_specialized` +set_spirv_feature +................. -* Classes +.. doxygenfunction:: tinytc::set_spirv_feature - * :ref:`recipe` +.. _tinytc::to_string(core_feature_flag): - * :ref:`recipe_handler` +to_string(core_feature_flag) +............................ - * :ref:`small_gemm_batched` +.. doxygenfunction:: tinytc::to_string(core_feature_flag) - * :ref:`tall_and_skinny` +.. _tinytc::to_string(intel_gpu_architecture): -* Structures +to_string(intel_gpu_architecture) +................................. - * :ref:`auto_mem_type` +.. doxygenfunction:: tinytc::to_string(intel_gpu_architecture) - * :ref:`auto_mem_type\\>\>` +.. _tinytc::to_string(spirv_feature): - * :ref:`mem` +to_string(spirv_feature) +........................ -* Variables +.. doxygenfunction:: tinytc::to_string(spirv_feature) - * :ref:`auto_mem_type_v` +FP math +======= - * :ref:`usm_pointer_type` +* Functions -Recipe Enumerations -------------------- + * :ref:`tinytc::ieee754_extend` -mem_type -........ + * :ref:`tinytc::ieee754_truncate` -.. doxygenenum:: tinytc::mem_type +* Classes -Recipe Functions ----------------- + * :ref:`tinytc::lp_float` -make_small_gemm_batched -....................... +* Structures -.. doxygenfunction:: tinytc::make_small_gemm_batched + * :ref:`tinytc::ieee754_format` -make_tall_and_skinny -.................... +* Typedefs -.. 
doxygenfunction:: tinytc::make_tall_and_skinny + * :ref:`tinytc::bf16_format` -make_tall_and_skinny_specialized -................................ + * :ref:`tinytc::bfloat16` -.. doxygenfunction:: tinytc::make_tall_and_skinny_specialized + * :ref:`tinytc::f16_format` -Recipe Classes --------------- + * :ref:`tinytc::f32_format` -recipe -...... + * :ref:`tinytc::half` + +FP math Functions +----------------- -.. doxygenclass:: tinytc::recipe +.. _tinytc::ieee754_extend: -recipe_handler +ieee754_extend .............. -.. doxygenclass:: tinytc::recipe_handler +.. doxygenfunction:: tinytc::ieee754_extend -small_gemm_batched -.................. +.. _tinytc::ieee754_truncate: -.. doxygenclass:: tinytc::small_gemm_batched +ieee754_truncate +................ -tall_and_skinny -............... +.. doxygenfunction:: tinytc::ieee754_truncate -.. doxygenclass:: tinytc::tall_and_skinny +FP math Classes +--------------- -Recipe Structures ------------------ +.. _tinytc::lp_float: -auto_mem_type -............. +lp_float +........ -.. doxygenstruct:: tinytc::auto_mem_type +.. doxygenclass:: tinytc::lp_float -auto_mem_type>> -....................................................... +FP math Structures +------------------ -.. doxygenstruct:: tinytc::auto_mem_type< T, std::enable_if_t< usm_pointer_type< T > > > +.. _tinytc::ieee754_format: -mem -... +ieee754_format +.............. -.. doxygenstruct:: tinytc::mem +.. doxygenstruct:: tinytc::ieee754_format -Recipe Variables +FP math Typedefs ---------------- -auto_mem_type_v -............... +.. _tinytc::bf16_format: -.. doxygenvariable:: tinytc::auto_mem_type_v +bf16_format +........... -usm_pointer_type -................ +.. doxygentypedef:: tinytc::bf16_format + +.. _tinytc::bfloat16: + +bfloat16 +........ + +.. doxygentypedef:: tinytc::bfloat16 + +.. _tinytc::f16_format: + +f16_format +.......... + +.. doxygentypedef:: tinytc::f16_format -.. doxygenvariable:: tinytc::usm_pointer_type +.. 
_tinytc::f32_format: -Source +f32_format +.......... + +.. doxygentypedef:: tinytc::f32_format + +.. _tinytc::half: + +half +.... + +.. doxygentypedef:: tinytc::half + +Parser ====== -* Classes +* Functions - * :ref:`source` + * :ref:`tinytc::parse_file` -Source Classes --------------- + * :ref:`tinytc::parse_stdin` -source -...... + * :ref:`tinytc::parse_string` + +Parser Functions +---------------- + +.. _tinytc::parse_file: -.. doxygenclass:: tinytc::source +parse_file +.......... -Source Context -============== +.. doxygenfunction:: tinytc::parse_file + +.. _tinytc::parse_stdin: + +parse_stdin +........... + +.. doxygenfunction:: tinytc::parse_stdin + +.. _tinytc::parse_string: + +parse_string +............ + +.. doxygenfunction:: tinytc::parse_string + +Program +======= * Functions - * :ref:`make_source_context` + * :ref:`tinytc::dump(tinytc_prog_t)` -* Classes + * :ref:`tinytc::get_compiler_context(const_tinytc_prog_t)` - * :ref:`source_context` + * :ref:`tinytc::print_to_file(tinytc_prog_t, char const\*)` -Source Context Functions ------------------------- + * :ref:`tinytc::print_to_string(tinytc_prog_t)` + +Program Functions +----------------- -make_source_context +.. _tinytc::dump(tinytc_prog_t): + +dump(tinytc_prog_t) ................... -.. doxygenfunction:: tinytc::make_source_context +.. doxygenfunction:: tinytc::dump(tinytc_prog_t) -Source Context Classes ----------------------- +.. _tinytc::get_compiler_context(const_tinytc_prog_t): -source_context -.............. +get_compiler_context(const_tinytc_prog_t) +......................................... + +.. doxygenfunction:: tinytc::get_compiler_context(const_tinytc_prog_t) + +.. _tinytc::print_to_file(tinytc_prog_t, char const\*): + +print_to_file(tinytc_prog_t, char const\*) +.......................................... + +.. doxygenfunction:: tinytc::print_to_file(tinytc_prog_t, char const*) + +.. _tinytc::print_to_string(tinytc_prog_t): + +print_to_string(tinytc_prog_t) +.............................. 
+ +.. doxygenfunction:: tinytc::print_to_string(tinytc_prog_t) + +SPIR-V module +============= + +* Functions + + * :ref:`tinytc::dump(const_tinytc_spv_mod_t)` + + * :ref:`tinytc::print_to_file(const_tinytc_spv_mod_t, char const\*)` + + * :ref:`tinytc::print_to_string(const_tinytc_spv_mod_t)` + +SPIR-V module Functions +----------------------- + +.. _tinytc::dump(const_tinytc_spv_mod_t): + +dump(const_tinytc_spv_mod_t) +............................ + +.. doxygenfunction:: tinytc::dump(const_tinytc_spv_mod_t) + +.. _tinytc::print_to_file(const_tinytc_spv_mod_t, char const\*): + +print_to_file(const_tinytc_spv_mod_t, char const\*) +................................................... + +.. doxygenfunction:: tinytc::print_to_file(const_tinytc_spv_mod_t, char const*) + +.. _tinytc::print_to_string(const_tinytc_spv_mod_t): + +print_to_string(const_tinytc_spv_mod_t) +....................................... -.. doxygenclass:: tinytc::source_context +.. doxygenfunction:: tinytc::print_to_string(const_tinytc_spv_mod_t) diff --git a/docs/api/core_cxxapi.yaml b/docs/api/core_cxxapi.yaml index 990f9cf6..7801022e 100644 --- a/docs/api/core_cxxapi.yaml +++ b/docs/api/core_cxxapi.yaml @@ -3,64 +3,104 @@ Core C++-API: Common: enum: + - tinytc::mem_type - tinytc::status - tinytc::support_level function: - - tinytc::error_string - tinytc::CHECK_STATUS - tinytc::CHECK_STATUS_LOC + - tinytc::to_string(status) + - tinytc::to_string(support_level) class: + - tinytc::array_view_base + - tinytc::array_view + - tinytc::mutable_array_view - tinytc::shared_handle - tinytc::unique_handle + struct: + - tinytc::auto_mem_type + - tinytc::auto_mem_type< T, std::enable_if_t< is_usm_pointer_type< T > > > + - tinytc::mem + variable: + - tinytc::auto_mem_type_v + - tinytc::is_supported_scalar_type + - tinytc::is_usm_pointer_type Binary: enum: - tinytc::bundle_format function: - - tinytc::make_binary - class: - - tinytc::binary + - tinytc::create_binary + - 
tinytc::get_compiler_context(const_tinytc_binary_t) + - tinytc::get_core_features(const_tinytc_binary_t) + - tinytc::get_raw + - tinytc::to_string(bundle_format) + struct: + - tinytc::raw_binary Compiler: function: - - tinytc::compile_to_opencl + - tinytc::run_function_pass + - tinytc::list_function_passes + - tinytc::compile_to_spirv + - tinytc::compile_to_spirv_and_assemble + - tinytc::spirv_assemble + Compiler Context: + enum: + - tinytc::optflag + function: + - tinytc::add_source + - tinytc::create_compiler_context + - tinytc::set_error_reporter + - tinytc::set_optimization_flag + - tinytc::set_optimization_level + - tinytc::report_error + - tinytc::to_string(optflag) Device Info: enum: - tinytc::core_feature_flag - tinytc::intel_gpu_architecture + - tinytc::spirv_feature + function: + - tinytc::create_core_info_generic + - tinytc::create_core_info_intel + - tinytc::create_core_info_intel_from_arch + - tinytc::create_core_info_intel_from_name + - tinytc::get_core_features(const_tinytc_core_info_t) + - tinytc::get_subgroup_sizes + - tinytc::get_register_space + - tinytc::have_spirv_feature + - tinytc::set_core_features + - tinytc::set_default_alignment + - tinytc::set_spirv_feature + - tinytc::to_string(core_feature_flag) + - tinytc::to_string(intel_gpu_architecture) + - tinytc::to_string(spirv_feature) + FP math: function: - - tinytc::make_core_info_generic - - tinytc::make_core_info_intel - - tinytc::make_core_info_intel_from_arch + - tinytc::ieee754_extend + - tinytc::ieee754_truncate class: - - tinytc::core_info + - tinytc::lp_float + struct: + - tinytc::ieee754_format + typedef: + - tinytc::bf16_format + - tinytc::bfloat16 + - tinytc::f16_format + - tinytc::f32_format + - tinytc::half Parser: function: - tinytc::parse_file - tinytc::parse_stdin - tinytc::parse_string - Recipe: - enum: - - tinytc::mem_type + Program: function: - - tinytc::make_small_gemm_batched - - tinytc::make_tall_and_skinny - - tinytc::make_tall_and_skinny_specialized - class: - - 
tinytc::recipe - - tinytc::recipe_handler - - tinytc::small_gemm_batched - - tinytc::tall_and_skinny - struct: - - tinytc::auto_mem_type - - tinytc::auto_mem_type< T, std::enable_if_t< usm_pointer_type< T > > > - - tinytc::mem - variable: - - tinytc::auto_mem_type_v - - tinytc::usm_pointer_type - Source: - class: - - tinytc::source - Source Context: + - tinytc::dump(tinytc_prog_t) + - tinytc::get_compiler_context(const_tinytc_prog_t) + - tinytc::print_to_file(tinytc_prog_t, char const*) + - tinytc::print_to_string(tinytc_prog_t) + SPIR-V module: function: - - tinytc::make_source_context - class: - - tinytc::source_context + - tinytc::dump(const_tinytc_spv_mod_t) + - tinytc::print_to_file(const_tinytc_spv_mod_t, char const*) + - tinytc::print_to_string(const_tinytc_spv_mod_t) diff --git a/docs/api/sycl/cxxapi.rst b/docs/api/sycl/cxxapi.rst index 82417e6c..396a6352 100644 --- a/docs/api/sycl/cxxapi.rst +++ b/docs/api/sycl/cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _SYCL C++-API: + ======= C++-API ======= @@ -10,18 +12,22 @@ Device Info * Functions - * :ref:`get_support_level(sycl::device const&)` + * :ref:`tinytc::get_support_level(sycl::device const&)` - * :ref:`make_core_info(sycl::device const&)` + * :ref:`tinytc::make_core_info(sycl::device const&)` Device Info Functions --------------------- +.. _tinytc::get_support_level(sycl::device const&): + get_support_level(sycl::device const&) ...................................... .. doxygenfunction:: tinytc::get_support_level(sycl::device const&) +.. _tinytc::make_core_info(sycl::device const&): + make_core_info(sycl::device const&) ................................... 
@@ -32,89 +38,122 @@ Kernel * Functions - * :ref:`get_execution_range` + * :ref:`tinytc::get_execution_range` - * :ref:`get_global_size(std::int64_t,sycl::range\<3u\> const &)` + * :ref:`tinytc::get_global_size(sycl::range\<3u\> const &,sycl::range\<3u\> const &)` - * :ref:`get_group_size(sycl::kernel const &)` + * :ref:`tinytc::get_group_size(sycl::kernel const &)` - * :ref:`make_kernel(sycl::kernel_bundle\ const &,char const \\*)` + * :ref:`tinytc::make_kernel(sycl::kernel_bundle\ const &,char const \*)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context)` + * :ref:`tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,const_tinytc_binary_t)` - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context)` - - * :ref:`make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context)` + * :ref:`tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,tinytc_prog_t,tinytc_core_feature_flags_t)` Kernel Functions ---------------- +.. _tinytc::get_execution_range: + get_execution_range ................... .. doxygenfunction:: tinytc::get_execution_range -get_global_size(std::int64_t,sycl::range<3u> const &) -..................................................... +.. _tinytc::get_global_size(sycl::range\<3u\> const &,sycl::range\<3u\> const &): + +get_global_size(sycl::range<3u> const &,sycl::range<3u> const &) +................................................................ -.. doxygenfunction:: tinytc::get_global_size(std::int64_t,sycl::range<3u> const &) +.. doxygenfunction:: tinytc::get_global_size(sycl::range<3u> const &,sycl::range<3u> const &) + +.. _tinytc::get_group_size(sycl::kernel const &): get_group_size(sycl::kernel const &) .................................... .. doxygenfunction:: tinytc::get_group_size(sycl::kernel const &) +.. 
_tinytc::make_kernel(sycl::kernel_bundle\ const &,char const \*): + make_kernel(sycl::kernel_bundle const &,char const \*) ...................................................................................... .. doxygenfunction:: tinytc::make_kernel(sycl::kernel_bundle const &,char const *) -make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) -............................................................................................ +.. _tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,const_tinytc_binary_t): -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) +make_kernel_bundle(sycl::context const &,sycl::device const &,const_tinytc_binary_t) +.................................................................................... -make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) -.............................................................................................................. +.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,const_tinytc_binary_t) -.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) +.. _tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,tinytc_prog_t,tinytc_core_feature_flags_t): -make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) -............................................................................................ +make_kernel_bundle(sycl::context const &,sycl::device const &,tinytc_prog_t,tinytc_core_feature_flags_t) +........................................................................................................ -.. 
doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) +.. doxygenfunction:: tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,tinytc_prog_t,tinytc_core_feature_flags_t) Recipe ====== * Functions - * :ref:`make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context)` + * :ref:`tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,tinytc_recipe_t)` + + * :ref:`tinytc::make_recipe_handler(sycl::queue const&,tinytc_recipe_t)` - * :ref:`make_recipe_handler(sycl::queue const&,recipe const&,source_context)` + * :ref:`tinytc::parallel_for` -* Classes + * :ref:`tinytc::submit(tinytc_recipe_handler_t,sycl::queue)` - * :ref:`sycl_recipe_handler` + * :ref:`tinytc::submit(tinytc_recipe_handler_t,sycl::queue,std::vector\ const &)` + + * :ref:`tinytc::submit(tinytc_recipe_handler_t,sycl::queue,sycl::event const &)` Recipe Functions ---------------- -make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) -............................................................................................. +.. _tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,tinytc_recipe_t): -.. doxygenfunction:: tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) +make_recipe_handler(sycl::context const &,sycl::device const &,tinytc_recipe_t) +............................................................................... -make_recipe_handler(sycl::queue const&,recipe const&,source_context) -.................................................................... +.. doxygenfunction:: tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,tinytc_recipe_t) -.. doxygenfunction:: tinytc::make_recipe_handler(sycl::queue const&,recipe const&,source_context) +.. 
_tinytc::make_recipe_handler(sycl::queue const&,tinytc_recipe_t): -Recipe Classes --------------- +make_recipe_handler(sycl::queue const&,tinytc_recipe_t) +....................................................... -sycl_recipe_handler -................... +.. doxygenfunction:: tinytc::make_recipe_handler(sycl::queue const&,tinytc_recipe_t) + +.. _tinytc::parallel_for: + +parallel_for +............ + +.. doxygenfunction:: tinytc::parallel_for + +.. _tinytc::submit(tinytc_recipe_handler_t,sycl::queue): + +submit(tinytc_recipe_handler_t,sycl::queue) +........................................... + +.. doxygenfunction:: tinytc::submit(tinytc_recipe_handler_t,sycl::queue) + +.. _tinytc::submit(tinytc_recipe_handler_t,sycl::queue,std::vector\ const &): + +submit(tinytc_recipe_handler_t,sycl::queue,std::vector const &) +............................................................................ + +.. doxygenfunction:: tinytc::submit(tinytc_recipe_handler_t,sycl::queue,std::vector const &) + +.. _tinytc::submit(tinytc_recipe_handler_t,sycl::queue,sycl::event const &): + +submit(tinytc_recipe_handler_t,sycl::queue,sycl::event const &) +............................................................... -.. doxygenclass:: tinytc::sycl_recipe_handler +.. 
doxygenfunction:: tinytc::submit(tinytc_recipe_handler_t,sycl::queue,sycl::event const &) diff --git a/docs/api/sycl/cxxapi.yaml b/docs/api/sycl/cxxapi.yaml index 2d3416f6..5323dc7c 100644 --- a/docs/api/sycl/cxxapi.yaml +++ b/docs/api/sycl/cxxapi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C++-API: +C++-API : Device Info: function: - tinytc::get_support_level(sycl::device const&) @@ -8,15 +8,16 @@ C++-API: Kernel: function: - tinytc::get_execution_range - - tinytc::get_global_size(std::int64_t,sycl::range<3u> const &) + - tinytc::get_global_size(sycl::range<3u> const &,sycl::range<3u> const &) - tinytc::get_group_size(sycl::kernel const &) - tinytc::make_kernel(sycl::kernel_bundle const &,char const *) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,binary const &,source_context) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,source const &,source_context) + - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,const_tinytc_binary_t) + - tinytc::make_kernel_bundle(sycl::context const &,sycl::device const &,tinytc_prog_t,tinytc_core_feature_flags_t) Recipe: function: - - tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,recipe const &,source_context) - - tinytc::make_recipe_handler(sycl::queue const&,recipe const&,source_context) - class: - - tinytc::sycl_recipe_handler + - tinytc::make_recipe_handler(sycl::context const &,sycl::device const &,tinytc_recipe_t) + - tinytc::make_recipe_handler(sycl::queue const&,tinytc_recipe_t) + - tinytc::parallel_for + - tinytc::submit(tinytc_recipe_handler_t,sycl::queue) + - tinytc::submit(tinytc_recipe_handler_t,sycl::queue,std::vector const &) + - tinytc::submit(tinytc_recipe_handler_t,sycl::queue,sycl::event const &) diff --git a/docs/api/ze/capi.rst 
b/docs/api/ze/capi.rst index d0c0e6ca..d406401f 100644 --- a/docs/api/ze/capi.rst +++ b/docs/api/ze/capi.rst @@ -1,25 +1,12 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Level Zero C-API: + ===== C-API ===== -Common -====== - -* Functions - - * :ref:`tinytc_ze_convert_status` - -Common Functions ----------------- - -tinytc_ze_convert_status -........................ - -.. doxygenfunction:: tinytc_ze_convert_status - Device Info =========== @@ -32,11 +19,15 @@ Device Info Device Info Functions --------------------- +.. _tinytc_ze_core_info_create: + tinytc_ze_core_info_create .......................... .. doxygenfunction:: tinytc_ze_core_info_create +.. _tinytc_ze_get_support_level: + tinytc_ze_get_support_level ........................... @@ -47,8 +38,6 @@ Kernel * Functions - * :ref:`tinytc_ze_get_group_count` - * :ref:`tinytc_ze_get_group_size` * :ref:`tinytc_ze_kernel_create` @@ -57,48 +46,37 @@ Kernel * :ref:`tinytc_ze_kernel_bundle_create_with_program` - * :ref:`tinytc_ze_kernel_bundle_create_with_source` - - * :ref:`tinytc_ze_source_compile_to_binary` - Kernel Functions ---------------- -tinytc_ze_get_group_count -......................... - -.. doxygenfunction:: tinytc_ze_get_group_count +.. _tinytc_ze_get_group_size: tinytc_ze_get_group_size ........................ .. doxygenfunction:: tinytc_ze_get_group_size +.. _tinytc_ze_kernel_create: + tinytc_ze_kernel_create ....................... .. doxygenfunction:: tinytc_ze_kernel_create +.. _tinytc_ze_kernel_bundle_create_with_binary: + tinytc_ze_kernel_bundle_create_with_binary .......................................... .. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_binary +.. _tinytc_ze_kernel_bundle_create_with_program: + tinytc_ze_kernel_bundle_create_with_program ........................................... .. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_program -tinytc_ze_kernel_bundle_create_with_source -.......................................... 
- -.. doxygenfunction:: tinytc_ze_kernel_bundle_create_with_source - -tinytc_ze_source_compile_to_binary -.................................. - -.. doxygenfunction:: tinytc_ze_source_compile_to_binary - Recipe ====== @@ -111,11 +89,15 @@ Recipe Recipe Functions ---------------- +.. _tinytc_ze_recipe_handler_create: + tinytc_ze_recipe_handler_create ............................... .. doxygenfunction:: tinytc_ze_recipe_handler_create +.. _tinytc_ze_recipe_handler_submit: + tinytc_ze_recipe_handler_submit ............................... diff --git a/docs/api/ze/capi.yaml b/docs/api/ze/capi.yaml index acfde096..2eca2023 100644 --- a/docs/api/ze/capi.yaml +++ b/docs/api/ze/capi.yaml @@ -1,22 +1,16 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C-API: - Common: - function: - - tinytc_ze_convert_status +C-API : Device Info: function: - tinytc_ze_core_info_create - tinytc_ze_get_support_level Kernel: function: - - tinytc_ze_get_group_count - tinytc_ze_get_group_size - tinytc_ze_kernel_create - tinytc_ze_kernel_bundle_create_with_binary - tinytc_ze_kernel_bundle_create_with_program - - tinytc_ze_kernel_bundle_create_with_source - - tinytc_ze_source_compile_to_binary Recipe: function: - tinytc_ze_recipe_handler_create diff --git a/docs/api/ze/cxxapi.rst b/docs/api/ze/cxxapi.rst index d11b47b5..1c3d907a 100644 --- a/docs/api/ze/cxxapi.rst +++ b/docs/api/ze/cxxapi.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _Level Zero C++-API: + ======= C++-API ======= @@ -10,11 +12,13 @@ Common * Functions - * :ref:`ZE_CHECK_STATUS` + * :ref:`tinytc::ZE_CHECK_STATUS` Common Functions ---------------- +.. _tinytc::ZE_CHECK_STATUS: + ZE_CHECK_STATUS ............... 
@@ -25,18 +29,22 @@ Device Info * Functions - * :ref:`get_support_level(ze_device_handle_t)` + * :ref:`tinytc::get_support_level(ze_device_handle_t)` - * :ref:`make_core_info(ze_device_handle_t)` + * :ref:`tinytc::make_core_info(ze_device_handle_t)` Device Info Functions --------------------- +.. _tinytc::get_support_level(ze_device_handle_t): + get_support_level(ze_device_handle_t) ..................................... .. doxygenfunction:: tinytc::get_support_level(ze_device_handle_t) +.. _tinytc::make_core_info(ze_device_handle_t): + make_core_info(ze_device_handle_t) .................................. @@ -47,82 +55,68 @@ Kernel * Functions - * :ref:`compile_to_binary` - - * :ref:`get_group_count` - - * :ref:`get_group_size(ze_kernel_handle_t)` - - * :ref:`make_kernel(ze_module_handle_t,char const \\*)` + * :ref:`tinytc::get_group_size(ze_kernel_handle_t)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context)` + * :ref:`tinytc::make_kernel(ze_module_handle_t,char const \*)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context)` + * :ref:`tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,const_tinytc_binary_t)` - * :ref:`make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context)` + * :ref:`tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,tinytc_prog_t,tinytc_core_feature_flags_t)` Kernel Functions ---------------- -compile_to_binary -................. - -.. doxygenfunction:: tinytc::compile_to_binary - -get_group_count -............... - -.. doxygenfunction:: tinytc::get_group_count +.. _tinytc::get_group_size(ze_kernel_handle_t): get_group_size(ze_kernel_handle_t) .................................. .. doxygenfunction:: tinytc::get_group_size(ze_kernel_handle_t) +.. 
_tinytc::make_kernel(ze_module_handle_t,char const \*): + make_kernel(ze_module_handle_t,char const \*) ............................................. .. doxygenfunction:: tinytc::make_kernel(ze_module_handle_t,char const *) -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) -....................................................................................... +.. _tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,const_tinytc_binary_t): -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) +make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,const_tinytc_binary_t) +................................................................................ -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) -.......................................................................................................... +.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,const_tinytc_binary_t) -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) +.. _tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,tinytc_prog_t,tinytc_core_feature_flags_t): -make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) -....................................................................................... +make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,tinytc_prog_t,tinytc_core_feature_flags_t) +.................................................................................................... -.. doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) +.. 
doxygenfunction:: tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,tinytc_prog_t,tinytc_core_feature_flags_t) Recipe ====== * Functions - * :ref:`make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context)` - -* Classes + * :ref:`tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,tinytc_recipe_t)` - * :ref:`level_zero_recipe_handler` + * :ref:`tinytc::submit(tinytc_recipe_handler_t,ze_command_list_handle_t,ze_event_handle_t,uint32_t,ze_event_handle_t\*)` Recipe Functions ---------------- -make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) -........................................................................................ +.. _tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,tinytc_recipe_t): + +make_recipe_handler(ze_context_handle_t,ze_device_handle_t,tinytc_recipe_t) +........................................................................... -.. doxygenfunction:: tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) +.. doxygenfunction:: tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,tinytc_recipe_t) -Recipe Classes --------------- +.. _tinytc::submit(tinytc_recipe_handler_t,ze_command_list_handle_t,ze_event_handle_t,uint32_t,ze_event_handle_t\*): -level_zero_recipe_handler -......................... +submit(tinytc_recipe_handler_t,ze_command_list_handle_t,ze_event_handle_t,uint32_t,ze_event_handle_t\*) +....................................................................................................... -.. doxygenclass:: tinytc::level_zero_recipe_handler +.. 
doxygenfunction:: tinytc::submit(tinytc_recipe_handler_t,ze_command_list_handle_t,ze_event_handle_t,uint32_t,ze_event_handle_t*) diff --git a/docs/api/ze/cxxapi.yaml b/docs/api/ze/cxxapi.yaml index 4308b3ed..ccf94198 100644 --- a/docs/api/ze/cxxapi.yaml +++ b/docs/api/ze/cxxapi.yaml @@ -1,6 +1,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: BSD-3-Clause -C++-API: +C++-API : Common: function: - tinytc::ZE_CHECK_STATUS @@ -10,15 +10,11 @@ C++-API: - tinytc::make_core_info(ze_device_handle_t) Kernel: function: - - tinytc::compile_to_binary - - tinytc::get_group_count - tinytc::get_group_size(ze_kernel_handle_t) - tinytc::make_kernel(ze_module_handle_t,char const *) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,binary const&,source_context) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,prog,tinytc_core_feature_flags_t,source_context) - - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,source const&,source_context) + - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,const_tinytc_binary_t) + - tinytc::make_kernel_bundle(ze_context_handle_t,ze_device_handle_t,tinytc_prog_t,tinytc_core_feature_flags_t) Recipe: function: - - tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,recipe const&,source_context) - class: - - tinytc::level_zero_recipe_handler + - tinytc::make_recipe_handler(ze_context_handle_t,ze_device_handle_t,tinytc_recipe_t) + - tinytc::submit(tinytc_recipe_handler_t,ze_command_list_handle_t,ze_event_handle_t,uint32_t,ze_event_handle_t*) diff --git a/docs/conf.py b/docs/conf.py index 3d0d7c8e..e04c88c2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -2,10 +2,10 @@ # SPDX-License-Identifier: BSD-3-Clause project = 'Tiny Tensor Compiler' -copyright = '2024, Intel Corporation' +copyright = '2025, Intel Corporation' author = 'Intel' -extensions = ['breathe', 'sphinx.ext.autosectionlabel', 'sphinx.ext.mathjax', 'sphinx_tabs.tabs'] +extensions = 
['breathe', 'sphinx.ext.mathjax', 'sphinx_tabs.tabs'] templates_path = ['_templates'] exclude_patterns = [] diff --git a/docs/dev/coopmatrix_layout.rst b/docs/dev/coopmatrix_layout.rst new file mode 100644 index 00000000..3b4f23e1 --- /dev/null +++ b/docs/dev/coopmatrix_layout.rst @@ -0,0 +1,219 @@ +.. Copyright (C) 2025 Intel Corporation + SPDX-License-Identifier: BSD-3-Clause + +.. _coopmatrix layout: + +================= +Coopmatrix layout +================= + +A cooperative matrix is distributed to work-items such that each work-item holds an equal +share of the matrix, potentially with zero-padding if the matrix size is not divisible by +the subgroup size. + +General layout +============== + +Let a coopmatrix :math:`A` of size :math:`M\times N` be given and let the subgroup size +be given by :math:`S`. +We require :math:`M,S` to be powers of two and to be greater or equal than 1. +Internally, we represent the :math:`A` matrix as the :math:`I \times K_1\times J \times K_2` tensor :math:`A^*`. +The mapping of :math:`A` to :math:`A^*` is + +.. math:: + + A^*_{i,k_1,j,k_2} = \left\{\begin{array}{rcl} + A_{i+k_1I+k_2IK_1,j} & \text{ if } & i+k_1I+k_2IK_1 < M \wedge j < N, \\ + 0 & \text{ else.} + \end{array}\right. + +The shape of :math:`A^*` is given by :math:`(I,K_1,J,K_2)`, where + +.. math:: + + \begin{aligned} + I &:= \min(M, S),\\ + J &:= \min(\{n\in\mathbb N : n \geq N \wedge (In) \bmod S = 0\})\\ + K &:= K_1K_2 = M/I.\\ + \end{aligned} + +As both :math:`S` and :math:`I` are powers of two, an explicit formula for :math:`J` is given by +:math:`J = (\lceil IN/S\rceil S) / I`. + +Work-item mapping +----------------- + +We linearize the index of the :math:`A^*` tensor canonically: + +.. math:: + + L(i,j,k) := i + k_1I + j IK_1 + k_2 IK_1J + +Every work-item stores a vector :math:`v` with :math:`V:=IKJ/S` components. +We define the per work-item vector as + +.. math:: + + W^p := (v \in [V] : A^*[L^{-1}(p+vS)]), + +where :math:`p=0,\dots,S-1`. 
+ +An example is helpful at this point. +Say we have a :math:`4\times 15` matrix and subgroup size :math:`S=16`, then the following table +shows how the work-item id maps to the 2d matrix index (the per work-item vectors are given by the +columns): + +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +p 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +.x 0,0 1,0 2,0 3,0 0,1 1,1 2,1 3,1 0,2 1,2 2,2 3,2 0,3 1,3 2,3 3,3 +.y 0,4 1,4 2,4 3,4 0,5 1,5 2,5 3,5 0,6 1,6 2,6 3,6 0,7 1,7 2,7 3,7 +.z 0,8 1,8 2,8 3,8 0,9 1,9 2,9 3,9 0,10 1,10 2,10 3,10 0,11 1,11 2,11 3,11 +.w 0,12 1,12 2,12 3,12 0,13 1,13 2,13 3,13 0,14 1,14 2,14 3,14 -, - -, - -, - -, - +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + +For a :math:`1\times 17` coopmatrix we have + +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +p 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +.x 0,0 0,1 0,2 0,3 0,4 0,5 0,6 0,7 0,8 0,9 0,10 0,11 0,12 0,13 0,14 0,15 +.y 0,16 -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- -,- +== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + +Mapping properties +------------------ + +The inverse of :math:`L` is + +.. math:: + + \begin{aligned} + i &= L \bmod I, \\ + k_1 &= \lfloor L / I \rfloor \bmod K_1, \\ + j &= \lfloor L / (IK_1)\rfloor \bmod J, \\ + k_2 &= \lfloor L / (IK_1J)\rfloor. \\ + \end{aligned} + +Let :math:`v=w_1 + uK_1 + w_2K_1(V/K)`, with :math:`u=0,\dots,V/K-1`, +:math:`w_1=0,\dots,K_1-1`, and :math:`w=0,\dots,K_2-1`. +(Note that :math:`V/K=IJ/S`.) + +We first assume that :math:`I=S`. Then + +.. 
math:: + + \begin{aligned} + i &= (p + (w_1 + uK_1 + w_2K_1J)S) \bmod S = p, \\ + k_1 &= \lfloor (p + (w_1 + uK_1 + wK_1J)S) / S \rfloor \bmod K_1 = w_1, \\ + j &= \lfloor (p + (w_1 + uK_1 + wK_1J)S) / (SK_1) \rfloor \bmod J = u, \\ + k_2 &= \lfloor (p + (w_1 + uK_1 + wK_1J)S) / (SK_1J)\rfloor = w_2, \\ + \end{aligned} + +Now we assume :math:`I + vector load_coopmatrix(RealT* B, int pos0, int pos1, int shape0, int shape1, + int stride0, int stride1) { + constexpr int S = get_sub_group_size(); + constexpr int I = min(M,S); + constexpr int J = ceil(I*N/S)*S/I; + constexpr int K = M/I; + static_assert(K%K1 == 0); + constexpr int K2 = K/K1; + constexpr bool needs_mask = J*S/I > N; + + if (Transpose) { + std::swap(shape0, shape1); + std::swap(stride0, stride1); + } + + constexpr int V = I*K*J/S; + array R; + int p = get_sub_group_local_id(); + int i0 = p % I; + int j0 = p / I; + for (int w1 = 0; w1 < K1; ++w1) { + for (int w2 = 0; w2 < K2; ++w2) { + int k1 = w1, k2 = w2; + int row = pos0 + i0 + (k1 + k2*K1)*I; + bool row_ok = !RowsChecked || (row >= 0 && row < shape0); + if (row_ok) { + for (int u = 0; u < V/K; ++u) { + int j = j0 + u*(S/I); + int col = pos1 + j; + bool col_ok = !ColsChecked || (col >= 0 && col < shape1); + bool mask_ok = !needs_mask || j < N; + R[w1 + u*K1 + w2*K1*(V/K)] = mask_ok && col_ok ? A[row * stride0 + col * stride1] : 0; + } + } else { + for (int u = 0; u < V/K; ++u) { + R[w1 + u*K1 + w2*K1*(V/K)] = 0; + } + } + } + } + return R; + } + +Matrix Accumulator +================== + +For cooperative matrices with use *matrix_acc* we always have :math:`K_1=1`. + +Matrix A +======== + +For cooperative matrices with use *matrix_a* we always have :math:`K_1=1`. +Moreover, low precision matrices are VNNI transformed if :math:`N` is a multiple of :math:`\omega`, +where :math:`\omega=\max(1, \max(1,4/\text{size}(\text{ty}))` is the number of operands per channel. 
+We store :math:`A^*` tensors internally +as the :math:`\omega\times I \times K_1\times \lceil J/\omega\rceil \times K_2` tensor :math:`A^{**}`. +The mapping from :math:`A^*` to :math:`A^{**}` is given by
toctree:: + :maxdepth: 2 + :caption: Developer guide + + dev/coopmatrix_layout + Index ----- diff --git a/docs/manual/build.rst b/docs/manual/build.rst index a1fd4f0b..8001aef4 100644 --- a/docs/manual/build.rst +++ b/docs/manual/build.rst @@ -8,20 +8,19 @@ Building and linking Dependencies ============ -- `CMake `_ >= 3.23 -- `Intel oneAPI Base Toolkit `_ -- OpenCL -- Level Zero -- `ocloc `_ (OpenCL offline compiler from the Intel Compute Runtime) -- `Double-Batched FFT Library `_ >= 0.5.1 -- `re2c `_ >= 3.0 -- `bison `_ >= 3.8.2 +- Mandatory + - `CMake `_ >= 3.23 + - `re2c `_ >= 3.0 + - `bison `_ >= 3.8.2 +- Optional + - `OpenCL ICD loader `_ [OpenCL, SYCL] + - `Level Zero Loader `_ [Level Zero, SYCL] + - `Intel oneAPI Base Toolkit `_ [SYCL] Build from source using oneAPI ============================== -Install CMake, the oneAPI Base Toolkit, the Intel compute runtime, re2c, and bison using your system -package manager. +Install the dependencies via your favourite package manager. Initialize the oneAPI environment. @@ -29,51 +28,38 @@ Initialize the oneAPI environment. . /opt/intel/oneapi/setvars.sh -Clone the Double-Batched FFT library to your filesystem. -.. code:: console - - git clone https://github.com/intel/double-batched-fft-library.git - -We only need libclir.so that we build and install using the following steps. - -.. code:: console - - cd double-batched-fft-library/clir/ - cmake -Bbuild -GNinja -DCMAKE_CXX_COMPILER=icpx -DCMAKE_INSTALL_PREFIX=$(pwd)/../../install/ \ - -DCMAKE_CXX_FLAGS="-ffp-model=precise" -DBUILD_SHARED_LIBS=YES - cmake --build build - cmake --install build - cd ../.. - -Then, build and install Tiny Tensor Compiler with the following steps +Build and install Tiny Tensor Compiler with the following steps .. 
code:: console git clone https://github.com/intel/tiny-tensor-compiler.git tinytc cd tinytc cmake -Bbuild -GNinja -DCMAKE_CXX_COMPILER=icpx -DCMAKE_INSTALL_PREFIX=$(pwd)/../install/ \ - -DCMAKE_PREFIX_PATH=$(pwd)/../install -DBUILD_SHARED_LIBS=YES + -DBUILD_SHARED_LIBS=YES cmake --build build cmake --install build cd .. -If you need a static library, set `-DBUILD_SHARED_LIBS=NO` when compiling libclir and the Tiny Tensor Compiler. +If you need a static library, set `-DBUILD_SHARED_LIBS=NO` when compiling the Tiny Tensor Compiler. +Adjust `CMAKE_INSTALL_PREFIX` to control the installation directory. +(Can be left empty for a system install; needs sudo.) Build options ============= The following CMake option options are supported. -====================== ============ -Option Description -====================== ============ -BUILD_DOCUMENTATION Build the documentation -BUILD_TESTING Build unit tests -BUILD_LEVEL_ZERO Build libtinytc_ze for Level Zero support (enforced if BUILD_SYCL=ON) -BUILD_OPENCL Build libtinytc_cl for OpenCL support (enforced if BUILD_SYCL=ON) -BUILD_SYCL Build libtinytc_sycl for SYCL support -====================== ============ +============================ ========================================================================= +Option Description +============================ ========================================================================= +BUILD_DOCUMENTATION Build the documentation +BUILD_DOUBLE_PRECISION_TESTS Build unit tests for double precision (e.g. 
for iGPUs without DP support) +BUILD_TESTING Build unit tests +BUILD_LEVEL_ZERO Build libtinytc_ze for Level Zero support (enforced if BUILD_SYCL=ON) +BUILD_OPENCL Build libtinytc_cl for OpenCL support (enforced if BUILD_SYCL=ON) +BUILD_SYCL Build libtinytc_sycl for SYCL support +============================ ========================================================================= Linking in a CMake project ========================== @@ -91,7 +77,7 @@ in your CMakeLists.txt to find the Tiny Tensor Compiler. For non-standard installation directories add -DCMAKE_PREFIX_PATH=/path/to/installation when invoking cmake. -Runtime support is split in the three library libtinytc_ze, libtinytc_cl, and libtinytc_sycl. +Runtime support is split in the three libraries libtinytc_ze, libtinytc_cl, and libtinytc_sycl. The BUILD_(LEVEL_ZERO, OPENCL, SYCL) options control which libraries are built, respectively. For example, when using OpenCL only, you can set BUILD_SYCL=OFF such that you do not need a C++ compiler with SYCL support. diff --git a/docs/manual/builder.rst b/docs/manual/builder.rst index f05a67f4..9cb2cf7f 100644 --- a/docs/manual/builder.rst +++ b/docs/manual/builder.rst @@ -21,7 +21,9 @@ Consider the following simple copy kernel .. code-block:: func @copy(%A: memref<${type}x${M}x${N}>, %B: memref<${type}x${M}x${M}>) { - axpby.n 1.0, %A, 0.0, %B : ${type}, memref<${type}x${M}x${N}>, ${type}, memref<${type}x${M}x${N}> + %c0 = constant 0.0 : ${type} + %c1 = constant 1.0 : ${type} + axpby %c1, %A, %c0, %B } In the following example we build the above code programmatically and replace the place-holders (${.}) @@ -33,63 +35,91 @@ by actual values: .. 
code:: C - tinytc_scalar_type_t type = ...; int64_t M = ...; int64_t N = ...; - tinytc_data_type_t dt; + char const *copy_fun_name = "copy"; + size_t num_results; + size_t num_params; + tinytc_compiler_context_t ctx; + tinytc_prog_t program; + tinytc_type_t void_ty, element_ty, ty; + tinytc_func_t copy_fun; + tinytc_region_t copy_body; + tinytc_inst_t tmp; + tinytc_value_t params[2]; + tinytc_value_t alpha, beta; + + tinytc_compiler_context_create(&ctx); + + // Create program + tinytc_prog_create(&program, ctx, NULL); + + // Get types + tinytc_f32_type_get(&element_ty, ctx); int64_t shape[2] = {M, N}; - tinytc_memref_type_create(&dt, type, 2, shape, 0, NULL, NULL); - - tinytc_value_t A, B, alpha, beta; - tinytc_value_create(&A, dt, NULL); - tinytc_value_create(&B, dt, NULL); - tinytc_float_imm_create(&alpha, 1.0, type, NULL); - tinytc_float_imm_create(&beta, 0.0, type, NULL); - tinytc_data_type_release(dt); - - tinytc_inst_t copy_inst; - tinytc_axpby_inst_create(©_inst, tinytc_transpose_N, 0, alpha, A, beta, B, NULL); - tinytc_value_release(alpha); - tinytc_value_release(beta); - - tinytc_func_t copy_proto; - tinytc_value_t args[2] = {A, B}; - tinytc_function_prototype_create(©_proto, "copy", 2, args, NULL); - tinytc_value_release(A); - tinytc_value_release(B); + tinytc_memref_type_get(&ty, element_ty, 2, shape, 0, NULL, tinytc_address_space_global); - tinytc_region_t copy_body; - tinytc_region_create(©_body, 1, ©_inst, NULL); - tinytc_inst_release(copy_inst); + // Get void type + tinytc_void_type_get(&void_ty, ctx); - tinytc_func_t copy_fun; - tinytc_function_create(©_fun, copy_proto, copy_body, NULL); - tinytc_func_release(copy_proto); - tinytc_region_release(copy_body); + // Create function + tinytc_type_t param_types[2] = {ty, ty}; + tinytc_func_create(©_fun, sizeof(copy_fun_name) - 1, copy_fun_name, 2, param_types, void_ty, + NULL); + tinytc_prog_add_function(program, copy_fun); - tinytc_prog_t program; - tinytc_program_create(&program, 1, ©_fun, NULL); - 
tinytc_func_release(copy_fun); + // Get body + tinytc_func_get_body(copy_fun, ©_body); + num_params = 2; + tinytc_region_get_parameters(copy_body, &num_params, params); + + // Create instructions + tinytc_constant_inst_create_one(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &alpha); + tinytc_region_append(copy_body, tmp); + + tinytc_constant_inst_create_zero(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &beta); + tinytc_region_append(copy_body, tmp); + + tinytc_axpby_inst_create(&tmp, 0, tinytc_transpose_N, alpha, params[0], beta, params[1], NULL); + tinytc_region_append(copy_body, tmp); + + // Dump program + tinytc_prog_dump(program); + + // Clean-up + tinytc_prog_release(program); + tinytc_compiler_context_release(ctx); .. tab:: C++ .. code:: C++ - scalar_type type = ...; int64_t M = ...; int64_t N = ...; - auto pb = program_builder{}; - pb.create("copy", [&](function_builder &fb) { - auto dt = make_memref(type, {M, N}); - auto A = fb.argument(dt); - auto B = fb.argument(dt); - fb.body([&](region_builder &bb) { - auto alpha = make_imm(1.0, type); - auto beta = make_imm(0.0, type); - bb.add(make_axpby(transpose::N, false, alpha, A, beta, B)); - }); - }); - auto program = pb.get_product(); + auto ctx = create_compiler_context(); + auto element_ty = get(ctx.get()); + auto ty = get(element_ty, array_view{M, N}, array_view{}, + address_space::global); + + auto void_ty = get(ctx.get()); + auto f = create_func("copy", {ty, ty}, void_ty); + + auto body = get_body(f.get()); + std::array params; + get_parameters(body, params); + + auto bb = region_builder{body}; + auto alpha = bb.constant_one(element_ty); + auto beta = bb.constant_zero(element_ty); + bb.create(false, transpose::N, alpha, params[0], beta, params[1]); + + auto p = create_prog(ctx.get()); + add_function(p.get(), std::move(f)); + dump(p.get()); diff --git a/docs/manual/calling_convention.rst 
b/docs/manual/calling_convention.rst index f023e5fb..106130cc 100644 --- a/docs/manual/calling_convention.rst +++ b/docs/manual/calling_convention.rst @@ -1,6 +1,8 @@ .. Copyright (C) 2024 Intel Corporation SPDX-License-Identifier: BSD-3-Clause +.. _calling convention: + ================== Calling convention ================== @@ -70,6 +72,17 @@ leads to Note that `memref_example3` and `memref_example4` have the same signature, because `memref` has the canonical stride `strided<1,5,?>`. +.. _memref alignment requirements: + +**Memory alignment:** The base pointer must be sufficiently aligned. +The required alignment depends on the core info and may be queried with +:ref:`tinytc_core_info_get_default_alignment`. +Using :ref:`tinytc_core_info_set_default_alignment` the alignment requirements may be overriden. +The alignment requirement may also be overriden per memref using the +:ref:`"alignment" attribute `. +When the core info object is created using :ref:`tinytc_cl_core_info_create` the default alignment +is queried from `CL_DEVICE_MEM_BASE_ADDR_ALIGN `_. + Group types =========== @@ -77,23 +90,28 @@ A group argument might require multiple arguments in the OpenCL-C code. The rule is that the first argument in the OpenCL kernel is a global pointer to a global pointer to the underlying scalar type of the memref. Then a global pointer argument follows for every '?' in the memref's shape or stride, ordered from left-to-right. -If an dynamic offset is given, the offset is the last argument. +Afterwards, the dynamic group size and the dynamic offset follow if necessary. .. code:: - func @group_example1(%a: group) {} - func @group_example2(%a: group>) {} - func @group_example3(%a: group, offset: ?>) {} + func @group_example1(%a: groupx42) {} + func @group_example2(%a: groupx?>) {} + func @group_example3(%a: groupx?, offset: ?>) {} + func @group_example4(%a: groupx42, offset: ?>) {} leads to .. 
code:: c kernel void group_example1(global short*global* a) {} - kernel void group_example2(global int*global* a, global long* a_shape1, global long* a_stride2) {} - kernel void group_example3(global float*global* a, global long* a_shape0, long a_offset) {} + kernel void group_example2(global int*global* a, global long* a_shape1, global long* a_stride2, long a_size) {} + kernel void group_example3(global float*global* a, global long* a_shape0, long a_size, long a_offset) {} + kernel void group_example4(global float*global* a, long a_offset) {} -Note that `a_shape_0`, `a_shape1`, and `a_stride2` must contain at least as many values as the group size. -That is, if a is accessed with `load %a[%id] : group>`, then +Note that `a_shape_0`, `a_shape1`, and `a_stride2` must contain at least as many values as the group size (`a_size`). +That is, if a is accessed with `load %a[%id] : memref`, then `*(a_shape0 + id)`, `*(a_shape1 + id)`, and `*(a_stride2 + id)` must not lead to out-of-bounds memory access. + +**Memory alignment:** The memrefs the group points to are subject to the same alignment requirements as a +:ref:`regular memref argument (see above) `. diff --git a/docs/manual/core.rst b/docs/manual/core.rst index 716f752a..2d033193 100644 --- a/docs/manual/core.rst +++ b/docs/manual/core.rst @@ -5,8 +5,14 @@ Core programming guide ====================== -Memory -====== +Memory management +================= + +Objects are either shared, or unique, or managed. +We detail the memory policies in the following. + +Shared objects +-------------- Objects are created, retained, and released. At creation, objects are constructed and the reference count is set to 1. @@ -30,14 +36,60 @@ does not need the object anymore, or it increased the reference count by one. .. tab:: C++ - C handles are wrapped in the :ref:`shared_handle` class. + C handles are wrapped in the :ref:`tinytc::shared_handle` class. 
The wrapper implements copy constructor, move constructor, copy operator, and move operator, correctly handling the reference count. - Objects of *type* are created using the make\_\ *type* functions. + Objects of *type* are created using the create\_\ *type* functions. + + **Important:** The default constructor of a :ref:`tinytc::shared_handle` or any of its + derivatives always gives an invalid object, wrapping a nullptr. + Always use the create\_\ *type* function unless you know what you are doing. + +Unique objects +-------------- + +Unique objects always have a single owner. In the C-API, when the object is passed to a function, the ownership +may be passed to another object when the function's documentation says so. +For example, when adding an instruction to a region, the region takes ownership of the instruction and the user must not +destroy the instruction as that would lead to a double free. +In C++, the copy constructor is deleted and unique objects must be moved when a function transfers ownership. + +.. tabs:: + + .. tab:: C + + An object is created and deleted with + + * tinytc\_\ *type*\ _create + + * tinytc\_\ *type*\ _destroy + + .. tab:: C++ + + C handles are wrapped in the :ref:`tinytc::unique_handle` class. + The wrapper deletes the copy constructor and copy operator, and implements the move constructor and move operator. + Objects of *type* are created using the create\_\ *type* functions. + +Managed ojects +-------------- + +Some objects are never created or deleted but looked up in a parent object. +In that case the user never needs to destroy the object but only the parent object. +Care must be taken that the parent object is not deleted while the managed object is still in use. + +.. tabs:: + + .. tab:: C + + The object is obtained with + + * tinytc\_\ *type*\ _get - **Important:** The default constructor of a :ref:`shared_handle` or any of its derivatives - always gives an invalid object, wrapping a nullptr. 
- Always use the make\_\ *type* function unless you know what you are doing. + .. tab:: C++ + + The object is obtained with + + * tinytc::get\_\ *type* Error ===== @@ -53,7 +105,7 @@ The C-API returns error codes, the C++-API throws exceptions. .. tab:: C++ - Functions throw the :ref:`status` enum. + Functions throw the :ref:`tinytc::status` enum. The following minimum error handling code is recommended: .. code:: C++ @@ -61,30 +113,38 @@ The C-API returns error codes, the C++-API throws exceptions. try { ... } catch (tinytc::status const& st) { - std::cerr << static_cast(st) << ": " << tinytc::error_string(st) << std::endl; + std::cerr << static_cast(st) << ": " << tinytc::to_string(st) << std::endl; } catch (std::exception const& e) { std::cerr << e.what() << std::endl; } - **Hint:** The IR builder API throws the :ref:`builder_error` (deriving from std::exception) - instead of the status enum for better source code location tracking. + **Hint:** The IR builder API throws the :ref:`tinytc::builder_error` + (deriving from std::exception) instead of the status enum for better + source code location tracking. Parser ====== Programs written in the :ref:`tensor language ` are parsed from a file, stdin, or a string. -A :ref:`tinytc_source_context_t` (:ref:`source_context`) object can be attached any of the parse functions. -The source context stores the file name and source text and enhances error messages with source code context. -For example, if a parse error occurs, -then the error log of the source context contains the following error message: +The :ref:`tinytc_compiler_context_t` (:ref:`tinytc::compiler_context`) object controls optimization level, optimization flags, +and error logging. (The default compiler context does not print or log errors.) +When an error reporter is installed via :ref:`tinytc_compiler_context_set_error_reporter`, +then errors are printed along with source code locations and source context. +For example: .. 
code-block:: + test/lit/opt/check-ir/type_mismatch0.ir:6.8-23: Type of operand must match return type + func @kernel(%K0: memref) { - %0 = load %K0[] : memref - ~~~~~~~~~~~~~~~~~~~ - test/codegen/type_mismatch0.ir:6.13-31: Type of SSA value does not match operand type + + %0 = load %K0[] : f64 + ~~~~~~~~~~~~~~~~ + test/lit/opt/check-ir/type_mismatch0.ir:5.14-16: value defined here + + func @kernel(%K0: memref) { + ~~~ .. tabs:: @@ -96,21 +156,20 @@ then the error log of the source context contains the following error message: .. code:: C tinytc_status_t status; - tinytc_source_context_t source_ctx = NULL; + tinytc_compiler_context_t ctx = NULL; tinytc_prog_t program = NULL; - status = tinytc_source_context_create(&source_ctx); + status = tinytc_compiler_context_create(&ctx); // ... check status ... - status = tinytc_parse_file(&program, "test/codegen/type_mismatch0.ir"), source_ctx) + status = tinytc_compiler_context_set_error_reporter(ctx, error_callback, NULL); + // ... check status ... + status = tinytc_parse_file(&program, "test/lit/opt/check-ir/type_mismatch0.ir", ctx) if (status != tinytc_status_success) { printf("Error: %d\n", status); - char const* error_log; - status = tinytc_source_context_get_error_log(source_ctx, &error_log); - // ... check status ... - printf("Error log:\n%s\n", error_log); } // ... + err: tinytc_prog_release(program); - tinytc_source_context_release(source_ctx); + tinytc_compiler_context_release(ctx); .. tab:: C++ @@ -119,29 +178,31 @@ then the error log of the source context contains the following error message: .. 
code:: C++ try { - auto source_ctx = tinytc::make_source_context(); - auto program = tinytc::parse_file("test/codegen/type_mismatch0.ir", source_ctx); + auto ctx = tinytc::create_compiler_context(); + ctx.set_error_reporter([](char const *what, const tinytc_location_t *, + void *) { std::cerr << what << std::endl; }, + nullptr); + auto program = tinytc::parse_file("test/lit/opt/check-ir/type_mismatch0.ir", ctx); } catch (tinytc::status const& st) { - std::cerr << "Error: " << tinytc::error_string(st) << std::endl; - std::cerr << "Error log: " << source_ctx.get_error_log() << std::endl; + std::cerr << "Error: " << tinytc::to_string(st) << std::endl; + } catch (std::exception const &e) { + std::cerr << e.what() << std::endl; } Compiler ======== -Program objects (:ref:`tinytc_prog_t`, :ref:`prog`) are online-compiled -using the :ref:`tinytc_prog_compile_to_opencl` (:ref:`compile_to_opencl`) function. +Program objects (:ref:`tinytc_prog_t`, :ref:`tinytc::prog`) are online-compiled +using the :ref:`tinytc_prog_compile_to_spirv_and_assemble` (:ref:`tinytc::compile_to_spirv_and_assemble`) function. The program object is hereby modified as compiler passes are necessary. -A source object is returned that contains OpenCL-C source text. +A binary object is returned that contains the SPIR-V binary. Some compiler passes specialize the code based on properties of the GPU device. -Therefore, a :ref:`tinytc_core_info_t` (:ref:`core_info`) object is required. +Therefore, a :ref:`tinytc_core_info_t` (:ref:`tinytc::core_info`) object is required. It is recommend to query the core info from the runtime using any of the tinytc\_\ *runtime*\ _core_info_create -functions (make_core_info in C++), but one may also look up the core info from a table, +functions (create_core_info in C++), but one may also look up the core info from a table, as done in the example code below. -A source context can be added to capture potential errors in the optimizer. - .. tabs:: .. 
tab:: C @@ -152,12 +213,12 @@ A source context can be added to capture potential errors in the optimizer. tinytc_status_t status; tinytc_core_info_t info = NULL; - tinytc_source_t source = NULL; + tinytc_binary_t bin = NULL; status = tinytc_core_info_intel_create_from_arch(&info, tinytc_intel_gpu_architecture_pvc); // ... check status ... - status = tinytc_prog_compile_to_opencl(&source, program, info, source_ctx); + status = tinytc_prog_compile_to_spirv_and_assemble(&bin, program, info); // ... - tinytc_source_release(source); + tinytc_binary_release(source); tinytc_core_info_release(info); .. tab:: C++ @@ -167,19 +228,28 @@ A source context can be added to capture potential errors in the optimizer. .. code:: C++ try { - auto info = tinytc::make_core_info_intel_from_arch(tinytc::intel_gpu_architecture::pvc); - auto source = tinytc::compile_to_opencl(program, info, source_ctx); + auto info = tinytc::create_core_info_intel_from_arch(tinytc::intel_gpu_architecture::pvc); + auto source = tinytc::compile_to_spirv_and_assemble(program, info); } catch (tinytc::status const& st) { ... } .. note:: - Code generation targets OpenCL-C. Currently, the library requires the - `cl_intel_required_subgroup_size `_ extension, - the `cl_intel_subgroups `_ extension, - the `cl_intel_subgroups_long `_ extension, - and the `cl_intel_subgroups_short `_ extension. + Code generation targets SPIR-V. + As a minimum, the Addresses, SubgroupDispatch, and Int64 capability must be supported by the runtime. + + + Further capabilites are required for specific functionality: + + * Int(8|16) for i8, i16 ints + * Float(16|64) for f16, f64 floats + * Int64Atomics for atomics on i64 + * Groups for work group operations (e.g. 
broadcast) + * AtomicFloat(16|32|64)AddExt for atomics on f16, f32, f64 (SPV_EXT_shader_atomic_float[16]_add extensions) + * BFloat16ConversionINTEL for bf16 support (SPV_INTEL_bfloat16_conversion extension) + * SubgroupBufferBlockIOINTEL for efficient block loads and stores (SPV_INTEL_subgroups extension) + Device info =========== @@ -222,18 +292,18 @@ run-time. .. code:: C++ if (tinytc::get_support_level(device) >= tinytc::support_level::basic) { - auto info = tinytc::make_core_info(device); + auto info = tinytc::create_core_info(device); // ... } Runtime ======= -The JIT compiler compiles tensor programs into OpenCL-C code. +The JIT compiler compiles tensor programs into SPIR-V binaries. The libray provides functions to create the runtime's kernel bundle object -(cl_program, sycl::kernel_bundle, ze_module_handle_t) from a source object. +(cl_program, sycl::kernel_bundle, ze_module_handle_t) from a binary object. The runtime's kernel objects are obtained using the native API or the Tiny Tensor Compiler API (if applicable). -Setting the kernel arguments should following the :ref:`calling convention`. +Setting the kernel arguments should follow the :ref:`calling convention `. The Tiny Tensor Compiler should be used to translate the 2D work-group size of the tensor language to a 3D work-group size, and to translate the group size to the global size that is passed to the runtime. @@ -245,29 +315,29 @@ Example for "func @foo(%a: i32, ...) { ... }" (without error handling code): .. 
code:: C - ze_module_handle_t module = NULL; + ze_module_handle_t bundle = NULL; ze_kernel_handle_t kernel = NULL; int a = 42; - tinytc_ze_kernel_bundle_create_with_source(&module, context, device, source, source_ctx); - tinytc_ze_kernel_create(&kernel, module, "foo"); // Sets the work-group size + tinytc_ze_kernel_bundle_create_with_binary(&bundle, context, device, bin); + tinytc_ze_kernel_create(&kernel, bundle, "foo"); // Sets the work-group size zeKernelSetArgumentValue(kernel, 0, sizeof(a), &a); // ... ze_group_count_t group_count = tinytc_ze_get_group_count(howmany); zeCommandListAppendLaunchKernel(command_list, kernel, &group_count, NULL, 0, NULL); // ... zeKernelDestroy(kernel); - zeModuleDestroy(module); + zeModuleDestroy(bundle); .. tab:: OpenCL (C) .. code:: C - cl_program program = NULL; + cl_program bundle = NULL; cl_kernel kernel; cl_int err; int a = 42; - tinytc_cl_kernel_bundle_create_with_source(&program, context, device, binary, source_ctx); - kernel = clCreateKernel(program, "foo", &err); + tinytc_cl_kernel_bundle_create_with_binary(&bundle, context, device, bin); + kernel = clCreateKernel(bundle, "foo", &err); clSetKernelArg(kernel, 0, sizeof(a), &a); // ... size_t ls[3], gs[3]; @@ -276,14 +346,14 @@ Example for "func @foo(%a: i32, ...) { ... }" (without error handling code): clEnqueueNDRangeKernel(command_list, kernel, 3u, NULL, gs, ls, 0, NULL, NULL); // ... clReleaseKernel(kernel); - clReleaseProgram(program); + clReleaseProgram(bundle); .. tab:: SYCL (C++) .. 
code:: C++ - auto bundle = tinytc::make_kernel_bundle(context, device, source, source_ctx); - auto kernel = tinytc::make_kernel(bundle, "foo"); + auto bundle = tinytc::create_kernel_bundle(context, device, bin); + auto kernel = tinytc::create_kernel(bundle, "foo"); auto exe_range = tinytc::get_execution_range(kernel, howmany); queue.submit([&](sycl::handler &h) { h.set_args(42, ...); @@ -314,8 +384,8 @@ The general usage of a recipe is as following: tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; - tinytc_recipe__create(&recipe, info, , source_ctx); - tinytc_ze_recipe_handler_create(&handler, context, device, recipe, source_ctx); + tinytc_recipe__create(&recipe, info, , ctx); + tinytc_ze_recipe_handler_create(&handler, context, device, recipe, ctx); tinytc_recipe__set_args(handler, ); tinytc_ze_recipe_handler_submit(handler, command_list, NULL, 0, NULL); // ... @@ -328,8 +398,8 @@ The general usage of a recipe is as following: tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; - tinytc_recipe__create(&recipe, info, , source_ctx); - tinytc_cl_recipe_handler_create(&handler, context, device, recipe, source_ctx); + tinytc_recipe__create(&recipe, info, , ctx); + tinytc_cl_recipe_handler_create(&handler, context, device, recipe, ctx); tinytc_recipe__set_args(handler, ); tinytc_cl_recipe_handler_submit(handler, queue, 0, NULL, NULL); // ... @@ -340,8 +410,8 @@ The general usage of a recipe is as following: .. code:: C++ - auto handler = tinytc::make_recipe_handler(queue, - tinytc::make_(info, , source_ctx), source_ctx); + auto handler = tinytc::create_recipe_handler(queue, + tinytc::create_(info, , ctx), ctx); ::set_args(handler, ); handler.submit(queue); @@ -362,7 +432,7 @@ For example: tinytc_recipe__set_args(..., A, tinytc_mem_type_usm_pointer, ...); In C++, one only needs to pass the memory object. 
-The memory object is implicitly converted to the :ref:`mem` type that +The memory object is implicitly converted to the :ref:`tinytc::mem` type that automatically determines whether a pointer or a cl_mem object is given. A pointer maps to tinytc_mem_type_usm_pointer and a cl_mem object maps to tinytc_mem_type_buffer. diff --git a/docs/manual/tensor-ir.rst b/docs/manual/tensor-ir.rst index 3f7bcdeb..cdb4158c 100644 --- a/docs/manual/tensor-ir.rst +++ b/docs/manual/tensor-ir.rst @@ -15,23 +15,26 @@ Execution model The unit of execution described by a function written in the tensor language is called a **kernel**. Kernels are launched in batches, where each instance of the kernel is called a work-group. -The kernel has access to its group id that is used to select the work done in the work group. -Each work group consists of a fixed number of work-items that execute concurrently. -The language distinguishes between two kinds of instructions: *replicated* and *collective* instructions. -It is distinguished between *mixed* and *spmd* regions. -Mixed regions may contain replicated and collective instructions whereas spmd regions -may only contain replicated instructions. - -A collective instruction distributes the work among the work-items. -The instruction is responsible to distribute the work in a sensible manner. - -A replicated instruction replicates the work across all work-items. -In a mixed region, the replicated instructions always operate on the same data. -In spmd regions, the replicated instructions can operate on multiple data, -but in these regions collective instructions are prohibited. - -Mixed regions can be nested whereas spmd regions must not be nested. -A mixed region may be nested in a spmd region. +The kernel has access to a three dimensional group id that is used to select the work done in the work group. +Each work group consists of a fixed number of subgroups that execute concurrently. 
+Subgroups can be further divided into work-items, where the number of work-items per subgroup +is given by the subgroup size. + +The language distinguishes between *collective*, *SPMD*, and *mixed* instructions. +A collective instruction distributes the work among the work-items in an implementation-defined manner. +Local variables passed to or returned from a collective instruction are always uniform, meaning +that each work-item holds the same value. +An SPMD instruction follows the OpenCL execution model, where local variables may have a different value +for each work-item. +Mixed instructions accept both varying and uniform local variables. + +In an SPMD region, we call an argument *dynamically uniform* if all work-items in a subgroup have +the same value. + +Regions come in two different kinds: collective and SPMD. +A collective instruction must only appear in a collective region, and an SPMD instruction +must only appear in an SPMD region. Mixed instructions might appear in both kinds of regions. +SPMD regions may be nested in collective regions but collective regions must not be nested in SPMD regions. Core rules ========== @@ -54,7 +57,9 @@ are prefixed with ``@``. .. code:: abnf - identifier = 1*DIGIT / (ALPHA *(ALPHA / DIGIT / "_")) + identifier = unnamed-identifier / named-identifier + unnamed-identifier = 1*DIGIT + named-identifier = ALPHA *(ALPHA / DIGIT / "_") local-identifier = "%" identifier global-identifier = "@" identifier @@ -63,15 +68,18 @@ Constants .. code:: abnf + constant = boolean-constant / integer-constant / floating-constant / complex-constant + boolean-constant = "true" / "false" + integer-constant = [sign] 1*DIGIT sign = "-" / "+" - integer-constant = "true" / "false" / [sign] 1*DIGIT - floating-constant = [sign] *DIGIT "." 1*DIGIT ["e" [sign] 1*DIGIT] + floating-constant = [sign] (*DIGIT "." 1*DIGIT ["e" [sign] 1*DIGIT] / "inf" / "nan") mantissa-dec = *DIGIT "." 1*DIGIT / 1*DIGIT "." mantissa-hex = *HEXDIG "." 
1*HEXDIG / 1*HEXDIG "." exponent = [sign] 1*DIGIT floating-constant-dec = [sign] (mantissa-dec ["e" exponent] / 1*DIGIT "e" exponent) floating-constant-hex = [sign] "0x" (mantissa-hex ["p" exponent] / 1*HEXDIG "p" exponent) floating-constant = floating-constant-dec / floating-constant-hex + complex-constant = "[" floating-constant "," floating-constant "]" Integer constants must lie in the range :math:`-2^{63}+1,\dots,2^{63}-1`. @@ -80,6 +88,32 @@ The hexadecimal floating point syntax is supported, too. `strtod `_ can be used for parsing floating point numbers. +Attributes +========== + +.. code:: abnf + + attribute = array-attribute / + boolean-attribute / + dictionary-attribute / + integer-attribute / + string-attribute + array-attribute = "[" [attribute *(", " attribute)] "]" + boolean-attribute = boolean-constant + dictionary-attribute = "{" [named-attribute *("," named-attribute)] "}" + named-attribute = attribute-name "=" attribute + attribute-name = "alignment" / + "shape_gcd" / + "stride_gcd" / + "subgroup_size" / + "unroll" / + "work_group_size" / + string-attribute + integer-attribute = integer-constant + string-attribute = %x22 *(%x20-21 / %x23-7E) %x22 + +Attributes add information about an operation, for example to assert properties or to direct the compiler. + .. _tensor language functions: Functions @@ -87,16 +121,30 @@ Functions .. code:: abnf - function-definition = "func" global-identifier "(" [argument-list] ")" *attribute region + function-definition = "func" global-identifier "(" [argument-list] ")" + ["attributes" dictionary-attribute] region argument-list = argument *("," argument) - argument = local-identifier ":" type - attribute = work-group-size-attribute / subgroup-size-attribute - work-group-size-attribute = "work_group_size" "(" 1*DIGIT "," 1*DIGIT ")" - subgroup-size-attribute = "subgroup_size" "(" 1*DIGIT ")" + argument = local-identifier ":" type [dictionary-attribute] Defines a function that is callable from the host. 
-Attributes are optional and autoatically determined if omitted. +Attributes +---------- + +Subgroup size and work-group size are determined automatically by the compiler, but can be overridden +using the function's attribute dictionary: + +.. list-table:: + + * - Name + - Type + - Description + * - subgroup_size + - integer-attribute + - Subgroup size; valid values depend on the target device (typically 16 or 32) + * - work_group_size + - array-attribute with 2 integer-attribute entries + - Two dimensional work-group size in number of work-items The work-group size attribute defines the size of the local work group. Due to the focus on matrix operations, the work-group size is always two-dimensional, @@ -108,11 +156,35 @@ the subgroup sizes supported by the device. The product of the work-group size modes must be smaller or equal than the maximum work-group size of device. -The work-group is divided into full subgroups, therefore the work-group size -is always a multiple of the subgroup size. -The subgroup size attribute enforces a particular subgroup device supported by +The subgroup size attribute enforces a particular subgroup size that must be supported by the device. +Parameter attributes +-------------------- + +Parameters with memref or group type accept the following named attributes: + +.. list-table:: + + * - Name + - Type + - Description + * - alignment + - integer-attribute + - Minimum pointer alignment + * - shape_gcd + - array-attribute of integer-attribute + - Greatest common divisors of shape + * - stride_gcd + - array-attribute of integer-attribute + - Greatest common divisors of stride + +Cf. the documentation of the :ref:`memref type ` and the :ref:`group type `. + +Restrictions +------------ + +* Arguments must not have coopmatrix type. Regions ======= @@ -131,31 +203,88 @@ Types .. 
code:: abnf - type = void-type / scalar-type / memref-type / group-type + type = void-type / boolean-type / scalar-type / memref-type / group-type void-type = "void" +Boolean type +------------ + +.. code:: abnf + + boolean-type = "bool" + +Boolean type that only has two states (true or false). + Scalar types ------------ .. code:: abnf - scalar-type = integer-type / floating-type - integer-type = "i" ("1" / "8" / "16" / "32" / "64") / "index" - floating-type = "f" ("32" / "64") + scalar-type = integer-type / floating-type / complex-type + integer-type = "i8" / "i16" / "i32" / "i64" / "index" + floating-type = "bf16" / "f16" / "f32" / "f64" + complex-type = "c32" / "c64" -Scalar types are either signless integer ("i") or floating point ("f"). +Scalar types are either signless integer ("i"), floating point ("f"), +or complex floating point ("c"). The number behind the scalar type prefix denotes the number of bits, e.g. "f64" are double precision floating point numbers. +The "bf16" type encodes bfloat16 floating point numbers. The "index" type is an integer type whose width is platform-specific. +Type sizes in bytes are given by + +=========================== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +:math:`\alpha` i8 i16 i32 i64 bf16 f16 f32 f64 c32 c64 +=========================== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +:math:`\text{size}(\alpha)` 1 2 4 8 2 2 4 8 8 16 +=========================== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + + +Mixed precision operands might be allowed in instructions if the operands' types are *promotable*. +The scalar type :math:`\alpha` may be promoted to the scalar type :math:`\beta` if all values an operand +of type :math:`\alpha` may take can be exactly represented in type :math:`\beta`. 
+Formally, :math:`\alpha` is promotable to :math:`\beta` if :math:`\alpha \preceq \beta`, +where the partial order :math:`\preceq` is defined by the following relation matrix: + +=============== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +:math:`\preceq` i8 i16 i32 i64 bf16 f16 f32 f64 c32 c64 +=============== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== +i8 1 1 1 1 1 1 1 1 1 1 +i16 1 1 1 1 1 1 1 +i32 1 1 1 1 1 +i64 1 +bf16 1 1 1 1 1 +f16 1 1 1 1 1 +f32 1 1 1 1 +f64 1 1 +c32 1 1 +c64 1 +=============== ==== ==== ==== ==== ==== ==== ==== ==== ==== ==== + +Moreover, for scalar types :math:`\alpha,\beta` we define + +.. math:: + + \text{promote}(\alpha, \beta) = \left\{\begin{array}{rcl} + \beta & \text{ if } & \alpha \preceq \beta, \\ + \alpha & \text{ if } & \beta \preceq \alpha, \\ + \text{fail} & \text{ else.} + \end{array}\right. + +Here, "fail" means that the promotion is not allowed and the compiler should throw an error. + + + Memref type ----------- .. code:: abnf - memref-type = "memref<" scalar-type tensor-shape ["," memory-layout] ">" + memref-type = "memref<" scalar-type tensor-shape ["," memory-layout] ["," address-space] ">" constant-or-dynamic = integer-constant / "?" tensor-shape = *("x" constant-or-dynamic) + address-space = "global" / "local" A memref is a reference to a region of memory. In analogy to the C/C++-language, the memref can be thought of as a pointer, @@ -182,6 +311,31 @@ E.g. the memory layout of ``memref`` is ``strided<1,5,30>``. We note that ``memref`` and ``memref>`` are the same type. +Memrefs have an optional address space attribute. +The global address space refers to memory objects allocated from the global memory pool +that is shared by all work groups. +The local address space is shared by all work-items of the work-group but inaccessible to another work-group. +The default address space is "global"; memrefs with "local" address space are returned by +the alloca instruction. + +Definitions +........... 
+ +Let V be a value of memref type. +The :math:`\text{order}(V)` operation returns the memref's order. +The :math:`\text{shape}(V)` returns the tensor shape as tuple. +:math:`\text{rows}(V)` and :math:`\text{columns}(V)` return the size of the first +and second mode, respectively. +The :math:`\text{element_type}(V)` operation gives the underlying scalar type. + +For example, let B be a value of memref type, then + +* :math:`\text{order}(B) = 3` +* :math:`\text{shape}(B) = (8,16,4)` +* :math:`\text{rows}(B) = 8` +* :math:`\text{columns}(B) = 16` +* :math:`\text{element_type}(B) = \text{f32}` + Memory layout ............. @@ -190,6 +344,8 @@ Memory layout memory-layout = strided-layout +.. _strided layout: + Strided layout ~~~~~~~~~~~~~~ @@ -219,825 +375,1806 @@ The default packed dense layout is given by Stride modes might be dynamic as well, indicated by a question mark. +.. _memref attributes: + +Alignment attribute +................... + +The *alignment=X* attribute gives the alignment X of the memref's base pointer in bytes. +That is, for the pointer P pointing to the first element of the memref we must have :math:`P = 0 \pmod{X}`. + +**Restriction:** The alignment must be a multiple of the size of the memref's element type. + + +Greatest common divisor (GCD) attributes +........................................ + +The *shape_gcd=[d_1,...,d_k]* attribute asserts that :math:`s_i = 0 \pmod{d_i}, i=1,\dots,k`, where k is +smaller or equal than the order of the tensor n and :math:`s_i` is the i-th entry of the shape vector. +The divisors are understood to be the greatest common divisors for the set of shapes that the kernel is used for. +For example, if we know that :math:`s_1` is always a multiple of 4 then we can set *shape_gcd=[4]*. + +The *stride_gcd=[D_1,...,D_m]* attribute asserts that :math:`S_i = 0 \pmod{D_i}, i=1,\dots,m`, where m is +smaller or equal than the order of the tensor n and :math:`S_i` is the i-th entry of the stride vector. 
+The divisors are understood to be the greatest common divisors for the set of strides that the kernel is used for. +For example, if we know that :math:`S_2` is always a multiple of 4 then we can set *stride_gcd=[1,4]*. + Group type ---------- .. code:: abnf - group-type = "group<" memref-type ["," "offset" ":" constant-or-dynamic] ">" + group-type = "group<" memref-type "x" constant-or-dynamic ["," "offset" ":" constant-or-dynamic] ">" The group type collects unstructured pointers to memref's with potentially different dynamic mode sizes. The C-analogy of a group is a pointer-to-a-pointer. -For example, the C-analogue of a ``group>`` is a ``float**``. +For example, the C-analogue of a ``groupx?>`` is a ``float**``. + +The group shape is always one-dimensional and may be queried using the +:ref:`size instruction `. The optional offset parameter is used to offset each pointer by the given number of elements. Given the C-analogue ``float** group``, loading element ``i`` with offset ``off`` gives the pointer ``float* tmp = group[i] + off``. The default offset is 0. -Dynamic values ('?') may appear in the memref-type and in the offset. +Dynamic values ('?') may appear in the memref-type, in the group shape, and in the offset. These values are stored in the dope vector; the calling convention for groups is implementation-defined. -Instructions -============ +.. _group attributes: -.. code:: abnf +Attributes +.......... 
- value-instruction = local-identifier "=" (alloca-instruction - / arith-binary-instruction - / arith-unary-instruction - / cast-instruction - / comparison-instruction - / expand-instruction - / fuse-instruction - / group-id-instruction - / group-size-instruction - / load-instruction - / size-instruction - / subview-instruction) - multi-value-instruction = [local-identifier-list "="] if-instruction - local-identifier-list = local-identifier *("," local-identifier) - instruction = value-instruction - / multi-value-instruction - / axpby-instruction - / barrier-instruction - / for-instruction - / foreach-instruction - / lifetime-stop-instruction - / gemm-instruction - / gemv-instruction - / ger-instruction - / hadamard-product-instruction - / store-instruction - / sum-instruction - / yield-instruction +Attributes applied on a group type are passed through to the memrefs. +That is, when a memref is loaded from the group then the :ref:`memref attributes ` +are equal to the attributes of the group. -Alloca ------- +Cooperative matrix type +----------------------- .. code:: abnf - alloca-instruction = "alloca" "->" memref-type - -Overview -........ - -*Collective instruction.* -The alloca instruction allocates temporary memory that is freed automatically at the end of the block that contains the alloca. - -Returns -....... + coopmatrix-type = "coopmatrix<" scalar-type 2*2("x" integer-constant) "," matrix-use ">" + matrix-use = "matrix_a" / "matrix_b" / "matrix_acc" -A memref of the memref-type. +The coopmatrix represents a matrix distributed across a subgroup, where each work-item in a subgroup +stores a part of the matrix. +The scalar-type specifies the matrix element type, the first integer-constant the number of rows, +and the second integer-constant the number of columns. +The matrix-use may affect the distribution of the matrix in the subgroup, and the name refers to the +position of the matrix in a matrix multiplication. -Restrictions -............ 
+Not all matrix shapes need to be supported in the implementation. +The supported matrix shapes may depend on data type, matrix use, and target hardware. -The memref's size must known at compile-time, i.e. the tensor shape must not contain any dynamic modes. +An argument to any instruction that has coopmatrix type **must** be dynamically uniform. -Arithmetic (binary) -------------------- +Definitions +........... -.. code:: abnf +Let V be a value of coopmatrix type. +The :math:`\text{rows}(V)` and :math:`\text{columns}(V)` functions return the size of the first +and second mode, respectively, and :math:`\text{shape}(V)` returns rows and cols as tuple. +The :math:`\text{component_type}(V)` operation gives the underlying scalar type +and :math:`\text{use}(V)` returns the use. - identifier-or-constant = local-identifier / integer-constant / floating-constant - arith-binary-type = ".add" / - ".sub" / - ".mul" / - ".div" / - ".rem" / - ".shl" / - ".shr" / - ".and" / - ".or" / - ".xor" - arith-binary-instruction = "arith" arith-binary-type - identifier-or-constant "," identifier-or-constant ":" scalar-type +For example, let B be a value of coopmatrix type, then -Overview -........ +* :math:`\text{shape}(B) = (8,16)` +* :math:`\text{rows}(B) = 8` +* :math:`\text{columns}(B) = 16` +* :math:`\text{component_type}(B) = \text{f32}` +* :math:`\text{use}(B) = \text{matrix_acc}` -*Replicated instruction.* -Binary arithmetic operation on scalars. -Both operands, as well as the returned type, have the same scalar type. 
- -==== ============ ============================================================================== -Op Allowed type Description -==== ============ ============================================================================== -.add scalar-type Sum of operands -.sub scalar-type Difference of operands -.mul scalar-type Product of operands -.div scalar-type Quotient of operands -.rem scalar-type Remainder from the division of operands -.shl integer-type Left shift first operand by number of bits given by second operand -.shr integer-type Arithmetic right shift first operand by number of bits given by second operand -.and integer-type Bitwise and -.or integer-type Bitwise or -.xor integer-type Bitwise xor -==== ============ ============================================================================== +Instructions +============ -Arithmetic (unary) ------------------- +Instructions may return zero, one, or multiple values, and have the following format: .. code:: abnf - arith-unary-type = ".neg" / ".not" - arith-unary-instruction = "arith" arith-unary-type identifier-or-constant ":" scalar-type + value-instruction-assignment = local-identifier "=" value-instruction + multi-value-instruction-assignment = [local-identifier-list "="] multi-value-instruction + local-identifier-list = local-identifier *("," local-identifier) + instruction = value-instruction-assignment + / multi-value-instruction-assignment -Overview -........ +That is, on the left-hand side we have a list of values that are produced by the instruction followed by an equals sign, +or an empty string, if the instruction does not produce values. +On the right-hand side, after the equals sign or empty string, the name of the instruction is written, e.g. "ger", optionally followed by instruction modifiers, e.g. "ger.atomic". +Then, a list of operands follows that is usually comma-separated but might also be printed in a custom format +(e.g. for "load", "store", "subview", etc.). 
+If the instruction produces values, then the types of the returned values must be annotated after a colon. -*Replicated instruction.* -Unary arithmetic operation on scalars. -The returned value has the same type as the operand. -==== ============ ============================================================================== -Op Allowed type Description -==== ============ ============================================================================== -.neg scalar-type Negation -.not integer-type Bitwise not -==== ============ ============================================================================== -Cast ----- +Collective instructions +----------------------- + +Alloca +...... .. code:: abnf - cast-instruction = "cast" identifier-or-constant ":" scalar-type "->" scalar-type + value-instruction = "alloca" [dictionary-attribute] ":" memref-type Overview -........ +~~~~~~~~ -*Replicated instruction.* -Cast scalar values. +The alloca instruction allocates temporary memory that is freed automatically at the end of the block that contains the alloca. -Comparison ----------- +Attributes +~~~~~~~~~~ -.. code:: abnf +Alloca accepts the following named attributes: - comparison-instruction = "cmp" (".eq" / ".ne" / ".gt" / ".ge" / ".lt" / ".le") - identifier-or-constant "," identifier-or-constant ":" scalar-type +.. list-table:: -Overview -........ + * - Name + - Type + - Description + * - alignment + - integer-attribute + - Base pointer alignment; must not be larger than the :ref:`default alignment `. -*Replicated instruction.* -Scalar comparison. -Both operands must have the same scalar type and the returned value is boolean. -==== ===================== -Cond Description -==== ===================== -.eq Equal -.ne Not equal -.gt Greater than -.ge Greater than or equal -.lt Less than -.le Less than or equal -==== ===================== +Restrictions +~~~~~~~~~~~~ -Expand ------- +* The memref's size must be known at compile-time, i.e. 
the tensor shape must not contain any dynamic modes. +* The address space must be "local". + +Axpby +..... .. code:: abnf - expand-instruction = "expand" local-identifier "[" integer-constant "->" expand-shape "]" ":" memref-type - expand-shape = constant-or-dynamic-or-identifier 1*("x" constant-or-dynamic-or-identifier) - constant-or-dynamic-or-identifier = integer-constant / "?" / local-identifier + transpose = ".t" / ".n" + instruction =/ "axpby" [".atomic"] [transpose] local-identifier "," local-identifier "," + local-identifier "," local-identifier Overview -........ - -*Replicated instruction.* -The expand instruction returns a view on a tensor with a mode viewed as higher-order mode. - -Arguments -......... - -The first argument must point to a value of memref type. -The integer constant in square brackets gives the mode that shall be expanded. -The expand shape gives the new shape of the mode. -Values in the expand shape must have index type. +~~~~~~~~ -The output type is a memref type according to the following rules: - -#. **Shape:** The mode size is replaced with the expand shape. If one entry in expand shape is dynamic, - then either its size is inferred automatically if the mode size is known, or it determined automatically - at run-time if the mode size is dynamic. +Axpby implements - .. code:: +.. math:: - expand %0[1 -> 2x8] : memref ; -> memref - expand %0[1 -> 2x?] : memref ; -> memref - expand %0[1 -> ?x8] : memref ; -> memref - expand %0[1 -> 2x?] : memref ; -> memref - expand %0[1 -> ?x8] : memref ; -> memref + B := \alpha \text{op}(A) + \beta B -#. **Identifiers:** Local identifiers in the expand shape are dynamic in the resulting memref type. +for vectors and matrices, where :math:`\text{op}(X)` is defined as - .. code:: +.. math:: - expand %0[1 -> %1 x ?] : memref ; -> memref - expand %0[1 -> %1 x ?] 
: memref ; -> memref - expand %0[1 -> %1 x %2] : memref ; -> memref - expand %0[1 -> 4 x %1] : memref ; -> memref + \text{op}(X) := \left\{ + \begin{array}{rcl} + X^T & \text{ if } & \text{transpose} = \text{".t"} \wedge \text{order}(X) = 2,\\ + X & \text{ else. } + \end{array} + \right. -#. **Stride:** A new stride entry is entered that follows the canonical stride computation. +If the atomic flag is set, B is updated atomically. - .. code:: +Operands +~~~~~~~~ - expand %0[0->4x8] : memref> ; -> memref> - expand %0[0->4x?] : memref> ; -> memref> - expand %0[0->?x4] : memref> ; -> memref> - expand %0[0->4x?] : memref> ; -> memref> +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 scalar-type :math:`\beta` +4 memref-type B +======= =========== ============== Restrictions -............ - -At most one mode in expand-shape must be dynamic. +~~~~~~~~~~~~ -The product of the expand shape must be the same as the mode size. -If one entry in the expand shape is dynamic then the other must evenly divide the mode size. +* :math:`\text{shape}(B) = \text{shape}(\text{op}(A))` +* :math:`\text{order}(B) = 0 \lor \text{order}(B) = 1 \lor \text{order}(B) = 2` +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A) \preceq \text{element_type}(B)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(B)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -Fuse ----- +Cumulative sum +.............. .. code:: abnf - fuse-instruction = "fuse" local-identifier "[" integer-constant "," integer-constant "]" ":" memref-type + instruction =/ "cumsum" [".atomic"] local-identifier "," local-identifier "," integer-constant "," + local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Replicated instruction.* -The fuse instruction returns a view on a tensor with two or more adjacent modes viewed as a single mode. 
+Computes the n-mode cumulative sum -Arguments -......... +.. math:: -The first argument must point to a value of memref type. -The fused modes are specified as the interval [from, to], where from is given -by the first integer and to is given by the second integer. -Counting starts from 0 so we have + B := \alpha A \times_{n} L_{s_n} + \beta B, + +where :math:`L_{s_n}` is the lower triangular matrix of ones of size :math:`s_n\times s_n` and +:math:`s_n` is the n-th entry of the shape vector of A. +In index notation, we have equivalently .. math:: - - 0 \leq from < to < order(memref) -The local identifier must have the memref type specified last. -The output type is a memref type according to the following rules: + B_{i_1\dots i_{n-1}ji_{n+1}\dots i_M} + := \alpha \sum_{i_n=1}^{j}A_{i_1\dots i_{n-1}i_ni_{n+1}\dots i_M} + + \beta B_{i_1\dots i_{n-1}ji_{n+1}\dots i_M}, -#. **Shape:** The mode size of the fused modes is the product of the mode sizes. If one mode is dynamic the fused mode size is dynamic. +If the atomic flag is set, B is updated atomically. - .. code:: - fuse %0[1,3] : memref ; -> memref - fuse %0[1,3] : memref> ; -> memref> +Operands +~~~~~~~~ -#. **Stride:** Strides remain unchanged. +======= ================ ================== +Op.-No. Type Description +======= ================ ================== +1 scalar-type :math:`\alpha` +2 memref-type A +3 integer-constant n (summation mode) +4 scalar-type :math:`\beta` +5 memref-type B +======= ================ ================== - .. code:: +Restrictions +~~~~~~~~~~~~ - fuse %0[1,2] : memref> ; -> memref> - fuse %0[0,1] : memref> ; -> memref> +* :math:`\text{order}(A) \geq 1` +* :math:`\text{shape}(A) = \text{shape}(B)` +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A) \preceq \text{element_type}(B)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(B)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -Restrictions -............ +Foreach +....... 
-Let i be the first mode and j the last mode. -The stride vector S and the shape vector s must satisify the following compatibility condition: +.. code:: abnf -:math:`\forall k \in [i,j): S_{k}s_{k} = S_{k+1}` + instruction =/ "foreach" "(" local-identifier-list ")" "=" + "(" local-identifier-list ")" "," "(" local-identifier-list ")" region -If S(i:j) and s(i:j) are known at compile time, the fuse instruction is illegal if the compatibility -condition is not satisfied. -If a single entry in S(i:j) or s(i:j) is dynamic, then fusing modes that violate the compatbility condition -is undefined beheaviour. +Overview +~~~~~~~~ -.. code:: +A foreach loop that executes the loop's range without any sequence guarantee. +The region of a foreach is a *SPMD region*. + +The three local identifier lists define the loop range and the local identifiers that +make the trip count available within the loop body. +All three lists must have the same length and have the following format: - fuse %0[0,1] : memref> ; Illegal, modes cannot be fused - fuse %0[0,1] : memref> ; Undefined behaviour if dynamic stride != 8 +.. math:: + (\text{var}_1, \dots, \text{var}_N) = (\text{from}_1, \dots, \text{from}_N), + (\text{to}_1, \dots, \text{to}_N), -Group id --------- +where :math:`N` is the common length of each of the three lists. +The loop range is defined as the Cartesian product of the half-open intervals +:math:`[\text{from}_i; \text{to}_i)` such that the trip count takes the values -.. code:: abnf +.. math:: - group-id-instruction = "group_id" + (\text{var}_1, \dots, \text{var}_N) \in [\text{from}_1; \text{to}_1) \times \dots \times + [\text{from}_N; \text{to}_N) -Overview -........ +The integer type of a "from" and "to" pair must match. +The integer type of a loop variable follows the integer type of its corresponding "from" and "to" pair. -*Replicated instruction.* -Returns the group id, an integer of type "index" inbetween 0 and the group size - 1. 
+The mapping of trip count to work-item is implementation-defined. -Group size ----------- +GEMM +.... .. code:: abnf - group-size-instruction = "group_size" + instruction =/ "gemm" [".atomic"] [transpose] [transpose] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Replicated instruction.* -Returns the group size, an integer of type "index". +GEMM implements the well-known GEMM BLAS-3 operation. +.. math:: -Load ----- + C := \alpha \text{op}_1(A) \text{op}_2(B) + \beta C -.. code:: abnf +The functions :math:`\text{op}_1` and :math:`\text{op}_2` are defined as - load-instruction = "load" local-identifier "[" [index-list] "]" ":" memref-or-group-type - index-list = identifier-or-int-constant *("," identifier-or-int-constant) - identifier-or-int-constant = integer-constant / local-identifier - memref-or-group-type = memref-type / group-type +.. math:: -Overview -........ + \text{op}_i(X) := \left\{ + \begin{array}{rcl} + X^T & \text{ if } & \text{transpose}_i = \text{".t"},\\ + X & \text{ else. } + \end{array} + \right. -Load the element given by the index list from a memref or group. -The number of indices must match the order of the memref -and a single index must be given for a group. +where transpose\ :sub:`1` and transpose\ :sub:`2` refer to the first and second transpose modifier, respectively. -Arguments -......... +If the atomic flag is set, C is updated atomically. -The first operand must have memref or group type. -The indices must be of ``index`` type. +Operands +~~~~~~~~ -Returns -....... +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 memref-type B +4 scalar-type :math:`\beta` +5 memref-type C +======= =========== ============== -A value of the memref's element type or the group's memref type. -Examples: +Restrictions +~~~~~~~~~~~~ -#. 
``load %0[] : memref`` returns a ``f32`` value. -#. ``load %0[5, %1] : memref`` returns a ``f32`` value. -#. ``load %0[%1] : group>`` returns a ``memref`` value. -#. ``load %0[%1] : group, offset: ?>`` returns a ``memref`` value. +* :math:`\text{order}(A) = \text{order}(B) = \text{order}(C) = 2` +* :math:`\text{columns}(\text{op}_1(A)) = \text{rows}(\text{op}_2(B))` +* :math:`\text{rows}(C) = \text{rows}(\text{op}_1(A))` +* :math:`\text{columns}(C) = \text{columns}(\text{op}_2(B))` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(A), \text{element_type}(B)) \preceq \text{element_type}(C)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -Size ----- +GEMV +.... .. code:: abnf - size-instruction = "size" local-identifier "[" integer-constant "]" ":" memref-type + instruction =/ "gemv" [".atomic"] [transpose] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Replicated instruction.* -The size instruction returns the i-th entry of the tensor's shape, where "i" is given by the integer -constant in square brackets. +GEMV implements the well-known GEMV BLAS-2 operation. -Arguments -......... +.. math:: -The first argument must point to a value of memref type. -The integer constant i gives the mode for which the size shall be returned. -It is required that + c := \alpha \text{op}_1(A) b + \beta c -.. math:: - - 0 \leq i < order(memref) +where :math:`\text{op}_1` is defined as in GEMM. -The local identifier must have the memref type specified last. -The instruction returns an integer of index type. +If the atomic flag is set, c is updated atomically. -Subview -------- +Operands +~~~~~~~~ -.. code:: abnf +======= =========== ============== +Op.-No.
Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 memref-type b +4 scalar-type :math:`\beta` +5 memref-type c +======= =========== ============== - subview-instruction = "subview" local-identifier "[" [index-or-slice-list] "]" ":" memref-type - index-or-slice-list = index-or-slice *("," index-or-slice) - index-or-slice = identifier-or-int-constant [":" (identifier-or-int-constant / "?")] / ":" +Restrictions +~~~~~~~~~~~~ -Overview -........ +* :math:`\text{order}(A) = 2` +* :math:`\text{order}(b) = \text{order}(c) = 1` +* :math:`\text{columns}(\text{op}_1(A)) = \text{rows}(b)` +* :math:`\text{rows}(c) = \text{rows}(\text{op}_1(A))` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(A), \text{element_type}(b)) \preceq \text{element_type}(c)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(c)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -*Replicated instruction.* -The subview instruction returns a view on a tensor. +GER +... -Arguments -......... +.. code:: abnf -The first argument must point to a value of memref type. -The number of indices in square brackets must match the order of the memref. -The indices are either given as single index or as a slice, where -slices are given in offset plus size notation ("%offset : %size"). -E.g. the slice "%0 : %1" extracts a block of %1 elements beginning from %0, which is equivalent -to the index interval [%0, %0 + %1). + instruction =/ "ger" [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier -.. admonition:: Note +Overview +~~~~~~~~ - A slice is often defined as "%0 : %1" being the index interval [%0, %1). - However, then the compiler needs to figure out whether %1 - %0 is constant or not in order - to determine whether the mode size is known at compile-time or not. - Therefore, we prefer the offset plus size notation.
+Computes the general rank-1 update: -A dynamic size ("?") means that the size is the mode size inferred from the memref type -minus the offset. -A plain colon is syntactic sugar for "0:?". +.. math:: -There is no run-time check whether the indices are within bounds. -Offset and size must be of index type. -Offset must be non-negative and size must be positive. + C := \alpha a b^T + \beta C -The local identifier must have the memref type specified last. -The output type is a memref type according to the following rules: +If the atomic flag is set, C is updated atomically. -#. **Invariant-stride:** The stride is not changed. +Operands +~~~~~~~~ - .. code:: +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type a +3 memref-type b +4 scalar-type :math:`\beta` +5 memref-type C +======= =========== ============== - subview %0[4:8,8:4] : memref ; Returns memref> +Restrictions +~~~~~~~~~~~~ +* :math:`\text{order}(a) = \text{order}(b) = 1` +* :math:`\text{order}(C) = 2` +* :math:`\text{rows}(C) = \text{rows}(a)` +* :math:`\text{columns}(C) = \text{rows}(b)` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(a), \text{element_type}(b)) \preceq \text{element_type}(C)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -#. **Rank-reduction:** A mode accessed by a single constant or value is removed from the output tensor. - .. code:: +Hadamard product +................ - subview %0[2:4, %1] : memref ; Returns memref> - subview %0[2:4, %1:1] : memref ; Returns memref> +.. code:: abnf -#. **Output-mode size:** The size of the output mode is determined by the size field of a slice - and may be dynamic. + instruction =/ "hadamard_product" [".atomic"] local-identifier "," local-identifier "," + local-identifier "," local-identifier "," local-identifier - ..
code:: +Overview +~~~~~~~~ - subview %0[%1:4] : memref ; Returns memref - subview %0[%2:%2] : memref ; Returns memref - subview %0[2:4, %2:%2, 6:7] : memref ; Returns memref - subview %0[2:4, %2:%2, 6:7] : memref> ; Returns memref +Computes the Hadamard product of two vectors or two matrices. +That is, in index notation we have -#. **Dynamic size:** +.. math:: - .. code:: + c_{i} := \alpha a_{i} b_{i} + \beta c_{i} - subview %0[:] : memref ; Returns memref - subview %0[:] : memref ; Returns memref - subview %0[5:?] : memref ; Returns memref - subview %0[%2:?] : memref ; Returns memref +for vectors and -If --- +.. math:: -.. code:: abnf + C_{ij} := \alpha A_{ij} B_{ij} + \beta C_{ij} - if-instruction = "if" identifier-or-int-constant ["->" "(" scalar-type-list ")"] - region ["else" region] - type-list = scalar-type *("," scalar-type) +for matrices. If the atomic flag is set, c/C is updated atomically. -Overview -........ +Operands +~~~~~~~~ -An if statement. -Both regions are *mixed regions*. +======= =========== ============== +Op.-No. Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type a/A +3 memref-type b/B +4 scalar-type :math:`\beta` +5 memref-type c/C +======= =========== ============== -The condition must be of bool type. +Restrictions +~~~~~~~~~~~~ -Arguments -......... +* :math:`\text{order}(a) = \text{order}(b) = \text{order}(c) = o` with :math:`o\in\{1,2\}` +* :math:`\text{shape}(a) = \text{shape}(b) = \text{shape}(c)` +* :math:`\text{type}(\alpha) \preceq \text{promote}(\text{element_type}(A), \text{element_type}(b)) \preceq \text{element_type}(C)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(C)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -The if instruction may return multiple values, where the number of values and the value types -are given by the scalar-type-list. 
-If values are returned, the last instruction in both the "then"-region and the "else"-region must -be a yield instruction (the "else"-region cannot be omitted). +Parallel +........ -Example: +.. code:: abnf - .. code:: + instruction =/ "parallel" region - %1 = cmp.lt %0, 16 : i32 - %x = if %1 -> (i32) { - yield %0 : i32 - } else { - yield 16 : i32 - } +Overview +~~~~~~~~ -Axpby ------ +Opens an *spmd region*. + +Sum +... .. code:: abnf - transpose = ".t" / ".n" - const-or-val = floating-constant / local-identifier - axpby-instruction = "axpby" transpose [".atomic"] - const-or-val "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type + instruction =/ "sum" [".atomic"] [transpose] local-identifier "," local-identifier "," + local-identifier "," local-identifier Overview -........ +~~~~~~~~ -*Collective instruction.* -Axpby implements +Computes the matrix-vector product or the dot product of A with a vector of ones. +That is, if the result is a vector we have .. math:: - B := \alpha \text{op}(A) + \beta B + b := \alpha \text{op}(A) \vec{1} + \beta b, -for vectors and matrices. -If the atomic flag is set, B is updated atomically. +where :math:`\text{op}(A)` is defined as in the axpby instruction, +and if the result is a scalar we have -Arguments -......... +.. math:: -The first argument gives :math:`\alpha`, and the third argument gives :math:`\beta`. -The second and the fourth argument must have memref type and give A and B, respectively. + b := \alpha \left\langle A, \vec{1} \right\rangle + \beta b -The transpose modifier defines :math:`\text{op}` as following: +If the atomic flag is set, b is updated atomically. -.. math:: - \text{op}_i(X) := \left\{ - \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i= t \wedge \text{order}(X) = 2,\\ - X & \text{ else. } - \end{array} - \right. +Operands +~~~~~~~~ -(Note that ".t" has no effect on vectors.) +======= =========== ============== +Op.-No.
Type Description +======= =========== ============== +1 scalar-type :math:`\alpha` +2 memref-type A +3 scalar-type :math:`\beta` +4 memref-type b +======= =========== ============== -The shape of :math:`\text{op}(A)` and B must be identical and the order of A and B needs to be 1 (vector) -or 2 (matrix). +Restrictions +~~~~~~~~~~~~ +* :math:`\text{order}(b) = 1 \lor \text{order}(b) = 0` +* :math:`\text{order}(A) = \text{order}(b)+1` +* :math:`\text{rows}(b) = \text{rows}(\text{op}(A)) \text{ if } \text{order}(b) = 1` +* :math:`\text{type}(\alpha) \preceq \text{element_type}(A) \preceq \text{element_type}(b)` +* :math:`\text{type}(\beta) \preceq \text{element_type}(b)` +* If the atomic flag is set, :math:`\beta` must be constant and :math:`\beta \in \{0,1\}`. -For ---- + +Additional instructions +....................... + +.. code:: abnf + + instruction =/ "lifetime_stop" local-identifier + + +Mixed instructions +------------------ + +Arithmetic (binary) +................... .. code:: abnf - for-instruction = "for" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant - ["," identifier-or-int-constant] [":" integer-type] region + arith-binary-type = "add" / + "sub" / + "mul" / + "div" / + "rem" / + "max" / + "min" / + "shl" / + "shr" / + "and" / + "or" / + "xor" + value-instruction =/ arith-binary-type local-identifier "," local-identifier + ":" (boolean-type / scalar-type / coopmatrix-type) Overview +~~~~~~~~ + +Binary arithmetic operation on scalars and cooperative matrices. +Both operands, as well as the returned type, have the same scalar or component type. +Arithmetic on cooperative matrices is done component-wise. + +The following table shows the operations' description and the types that are allowed for the operation. +The backslash "\\" is used to exclude types from the list of allowed types.
+ +=== ============================= ====================================================== +Op Allowed type Description +=== ============================= ====================================================== +add scalar-type Sum of operands +sub scalar-type Difference of operands +mul scalar-type Product of operands +div scalar-type Quotient of operands +rem scalar-type \\ complex-type Remainder from the division of operands +max scalar-type \\ complex-type Maximum of operands +min scalar-type \\ complex-type Minimum of operands +shl integer-type Left shift first operand by second operand +shr integer-type Arithmetic right shift first operand by second operand +and boolean-type / integer-type Bitwise and +or boolean-type / integer-type Bitwise or +xor boolean-type / integer-type Bitwise xor +=== ============================= ====================================================== + +Arithmetic (unary) +.................. + +.. code:: abnf + + arith-unary-type = "abs" / + "neg" / + "not" / + "conj" / + "im" / + "re" + value-instruction =/ arith-unary-type local-identifier + ":" (scalar-type / coopmatrix-type) + +Overview +~~~~~~~~ + +Unary arithmetic operation on scalars and cooperative matrices. +For integer and floating point input, the operand must have the same type as the returned value. +For complex input, the returned value has the component floating point type +for ".abs", ".im", and ".re", and the returned value has the same type as the operand +for ".neg" and ".conj". +Arithmetic on cooperative matrices is done component-wise. + +The following table shows the operations' description and the types that are allowed for the operation. 
+ +==== ============================= ============================= +Op Allowed type Description +==== ============================= ============================= +abs scalar-type Compute absolute value +neg scalar-type Negation +not boolean-type / integer-type Bitwise not +conj complex-type Complex conjugate +im complex-type Extract imaginary part +re complex-type Extract real part +==== ============================= ============================= + +Barrier +....... + +.. code:: abnf + + instruction =/ "barrier" [".global"] [".local"] + +Overview +~~~~~~~~ + +**Note:** Barriers are inserted automatically in collective regions, but not in SPMD regions. +Manual barrier insertion should only be necessary in SPMD regions. + + +Control barrier. +The barrier must be encountered by all work-items. +A work-item in a work-group is not allowed to continue until all work-items in the work-group +have reached the barrier. + +Additional memory fences are controlled by the following attributes: + +========= ====================================================================================== +Attribute Description +========= ====================================================================================== +.global Ensure that global memory accesses become visible to the work-group. +.local Ensure that local memory accesses become visible to the work-group. +========= ====================================================================================== + +Builtin (mixed) +............... + +.. code:: abnf + + mixed-builtin-type = "group_id" comp3 / + "num_groups" comp3 / + "num_subgroups" comp3 / + "subgroup_size" + comp3 = ".x" / ".y" / ".z" + value-instruction =/ mixed-builtin-type ":" integer-type + +Overview +~~~~~~~~ + +Returns a builtin value. + +The group id is three dimensional; the mode is selected with the .x, .y, and .z suffix. +Each mode starts with zero and is limited by the corresponding num_groups mode. That is, + +..
math:: + + \forall d \in \{x,y,z\} : 0 \leq \text{group_id}_d < \text{num_groups}_d + +The number of subgroups is related to the 2-dimensional work-group size as following: + +.. math:: + + \begin{aligned} + \text{num_subgroups}_x &= \frac{\text{work_group_size[0]}}{\text{subgroup_size}} \\ + \text{num_subgroups}_y &= \text{work_group_size[1]} \\ + \text{num_subgroups}_z &= 1 + \end{aligned} + +The following table shows the builtins' description and the types that are returned. + +============= ===== ====================== ====================================================== +Builtin Type OpenCL analogue Description +============= ===== ====================== ====================================================== +group_id index get_group_id Returns the x, y, or z mode of the group id +num_groups index get_num_groups Returns number of groups in the x, y, or z mode +num_subgroups i32 N/A Returns the number of subgroups in the x, y, or z mode +subgroup_size i32 get_max_sub_group_size Returns the subgroup size +============= ===== ====================== ====================================================== + +Cast +.... + +.. code:: abnf + + value-instruction =/ "cast" local-identifier ":" scalar-type + value-instruction =/ "cast" local-identifier ":" coopmatrix-type + +Overview +~~~~~~~~ + +Cast scalar values or cooperative matrices to type indicated after the colon. + +The source type must be a coopmatrix type if the destination type is a coopmatrix type, +and the shapes must match. +The coopmatrix use must either match, or +the use of the source type must be matrix_acc and the use of the destination type +must be matrix_a or matrix_b. + +Casts from complex types to non-complex types are forbidden. 
+The following table summarizes the casts and the mapping to SPIR-V +(the casts are done component-wise for coopmatrix types): + +============= ============= ================================================== +Operand type Result type SPIR-V Op +============= ============= ================================================== +integer-type integer-type OpSConvert +floating-type floating-type OpFConvert +complex-type complex-type OpFConvert (on vector2) +integer-type floating-type OpConvertSToF +floating-type integer-type OpConvertFToS +floating-type complex-type OpFConvert on real part, imaginary part is zero +integer-type complex-type OpConvertSToF on real part, imaginary part is zero +complex-type integer-type Forbidden +complex-type floating-type Forbidden +============= ============= ================================================== + +Comparison +.......... + +.. code:: abnf + + comparison-type = "equal" / + "not_equal" / + "greater_than" / + "greater_than_equal" / + "less_than" / + "less_than_equal" + value-instruction =/ comparison-type local-identifier "," local-identifier ":" "bool" + +Overview +~~~~~~~~ + +Scalar comparison. +Both operands must have the same scalar type and the returned value has boolean type. + +The following table shows the comparisons' description and the types that are allowed for the comparison. +The backslash "\\" is used to exclude types from the list of allowed types. + +=================== =========================== ===================== +Cond Allowed type Description +=================== =========================== ===================== +equal scalar-type Equal +not_equal scalar-type Not equal +greater_than scalar-type \\ complex-type Greater than +greater_than_equal scalar-type \\ complex-type Greater than or equal +less_than scalar-type \\ complex-type Less than +less_than_equal scalar-type \\ complex-type Less than or equal +=================== =========================== ===================== + +Constant ........ +..
code:: abnf + + value-instruction =/ "constant" constant ":" (boolean-type / scalar-type / coopmatrix-type) + +Overview +~~~~~~~~ + +Sets the result value to a constant value. +The type of the constant must match the scalar or component type +(e.g. an integer type requires an integer-constant and a floating type requires a floating-constant). + +When the result is a cooperative matrix, all entries are set to the same constant value. + +Expand +...... + +.. code:: abnf + + value-instruction =/ "expand" local-identifier "[" integer-constant "->" expand-shape "]" ":" memref-type + expand-shape = integer-constant-or-identifier 1*("x" integer-constant-or-identifier) + integer-constant-or-identifier = integer-constant / local-identifier + +Overview +~~~~~~~~ + +The expand instruction returns a view on a tensor with a mode viewed as higher-order mode. + +Operands +~~~~~~~~ + +The first argument must point to a value of memref type. +The first integer constant before "->" gives the mode that shall be expanded. +The expand shape coming after "->" gives the new shape of the mode. +Dynamic values in the expand shape must have `index` type. + +Restrictions +~~~~~~~~~~~~ + +The memref type of the result must conform with the following rules: + +#. Element type and address space must match the operand's memref type. +#. **Shape:** The mode size is replaced with the expand shape. + The product of the expand shape must equal the size of the expanded mode. + + .. code:: + + expand %0[1 -> 2x8] : memref ; %0: memref + expand %0[1 -> 2x2x2x2] : memref ; %0: memref + +#. **Identifiers:** Local identifiers in the expand shape are dynamic in the resulting memref type. + The product of the dynamic expand shape must equal the size of the expanded mode. + + .. 
code:: + + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> 2 x %1] : memref ; %0: memref + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> %1 x 2] : memref ; %0: memref + expand %0[1 -> %1 x %2 x 2] : memref ; %0: memref + expand %0[1 -> %2 x 2 x %1] : memref ; %0: memref + expand %0[1 -> %1 x %2] : memref ; %0: memref + expand %0[1 -> %1 x %2] : memref ; %0: memref + + *Note:* In the third example above, %1 must be equal to 8. + The output mode corresponding to %1 is still dynamic. + +#. **Stride:** A new stride entry is entered that follows the canonical stride computation. + It is also permissible to put '?' for a stride instead of the constant value. + + .. code:: + + expand %0[0->4 x 8] : memref> ; %0: memref> + expand %0[0->4 x 8] : memref> ; %0: memref> + expand %0[0->%1 x 4] : memref> ; %0: memref> + expand %0[0->4 x %1] : memref> ; %0: memref> + expand %0[0->4 x %1] : memref> ; %0: memref> + +Further restrictions: + +* The product of the expand shape must be the same as the mode size. +* If the product of the expand shape is only known at runtime, then it is undefined behaviour + if the dynamic product does not match the mode size. + +For +... + +.. code:: abnf + + multi-value-instruction = "for" local-identifier "=" + local-identifier "," local-identifier ["," local-identifier] + ["init" "(" init-value-list ")" "->" "(" return-type-list ")" ] region + [dictionary-attribute] + init-value-list = init-value *("," init-value) + init-value = local-identifier "=" local-identifier + return-type-list = return-type *("," return-type) + return-type = boolean-type / scalar-type / coopmatrix-type + + +Overview +~~~~~~~~ + A for loop. Instructions in the for loop execute sequentially and its region is a *mixed region*. -The loop's range [from; to) is given by the first integer constant and second integer constant, -and the trip count is stored in the local identifier. -A step size can be given with the third integer constant. 
+Arguments +~~~~~~~~~ + +The trip count is stored in the first local identifier and is accessible within the loop body. +The loop's range [from; to) is given by the first and the second local identifier after the equals sign, +and a step size may be given with the third local identifier after the equals sign. The step size defaults to 1 if omitted. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. +The integer type of "from", "to", and "step" must be identical, and the integer type of the loop variable +follows the loop range's type. -Foreach -------- +Values that are given in the init-value-list may be carried from one iteration to the next. +The local identifier gives the name of the loop-carried value as it is accessible in the loop body. +The local identifier given on the right-hand side of the init-value expression determines +the initial value of the loop-carried value, and its type must coincide with the return-type-list. +When loop-carried values are present, the loop's last instruction must be a yield instruction that +updates the loop-carried values for the next iteration. +The number and types of the yielded values must correspond to the return-type-list. + +Returns +~~~~~~~ + +The final values of the loop-carried values are returned by the for instruction. + + +Example: + +.. code:: + + %from = constant 2 : i32 + %to = constant 6 : i32 + %f0 = constant 0 : i64 + %f1 = constant 1 : i64 + %fn_1, %fn = for %n:i32=%from,%to init(%fn_2=%f0,%fn_1=%f1) -> (i64,i64) { + %fn = arith.add %fn_2, %fn_1 : i64 + yield (%fn_1, %fn) + } + ; %fn_1 contains the fourth Fibonacci number and %fn the fifth Fibonacci number + +Attributes +~~~~~~~~~~ + +The following named attributes may be passed in the attribute dictionary: + +..
list-table:: + + * - Name + - Type + - Description + * - unroll + - boolean-attribute or integer-attribute + - true: request to unroll loop, false: request to not unroll loop, integer: partial unroll count + +Fuse +.... .. code:: abnf - foreach-instruction = "foreach" local-identifier "=" identifier-or-int-constant "," identifier-or-int-constant - [":" integer-type] region + value-instruction =/ "fuse" local-identifier "[" integer-constant "," integer-constant "]" + ":" memref-type Overview -........ +~~~~~~~~ -A foreach loop that executes the loop's range [from; to) without any sequence guarantee. -The region of a foreach is a *spmd region*. +The fuse instruction returns a view on a tensor with two or more adjacent modes viewed as a single mode. -The loop's range [from; to) is given by the first integer constant and second integer constant, -and the trip count is stored in the local identifier. -The integer type of the loop variable is given after the colon. -The integer type of the loop variable and the loop bounds is given after the colon. -The default integer type is ``index``. +Fused modes are specified as the interval [from, to], where counting starts from 0. +From and to must refer to existing modes, that is, we require :math:`0 \leq \text{from} < \text{to} < \text{order}(\text{tensor})`. +Moreover, the stride vector S and the shape vector s must satisfy the following compatibility condition: -GEMM ----- +:math:`\forall k \in [\text{from},\text{to}): S_{k}s_{k} = S_{k+1}` + +If S(i:j) and s(i:j) are known at compile time, the fuse instruction is illegal if the compatibility +condition is not satisfied. +If a single entry in S(i:j) or s(i:j) is dynamic, then fusing modes that violate the compatibility condition +is undefined behaviour, e.g. + +..
code:: + + ; Illegal, modes cannot be fused + fuse %0[0,1] : memref ; %0: memref> + ; Undefined behaviour if dynamic stride != 8 + fuse %0[0,1] : memref> ; %0: memref> + +Operands +~~~~~~~~ + +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 memref-type tensor +2 integer-constant from +3 integer-constant to +======= ================ =========== + +Restrictions +~~~~~~~~~~~~ + +The memref type of the result must conform with the following rules: + +#. Element type and address space must match the operand's memref type. +#. **Shape:** The mode size of the fused modes is the product of the mode sizes. If one mode is dynamic the fused mode size is dynamic. + + .. code:: + + fuse %0[1,3] : memref ; %0: memref + fuse %0[1,3] : memref> ; %0: memref> + +#. **Stride:** Strides remain unchanged or are replaced by '?'. + + .. code:: + + fuse %0[1,2] : memref> ; %0: memref> + fuse %0[1,2] : memref> ; %0: memref> + fuse %0[0,1] : memref> ; %0: memref> + +If +.. .. code:: abnf - gemm-instruction = "gemm" transpose transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + multi-value-instruction =/ "if" local-identifier ["->" "(" return-type-list ")"] + region ["else" region] Overview -........ +~~~~~~~~ -*Collective instruction.* -GEMM implements the well-known GEMM BLAS-3 operation. +An if statement. +Both regions are *mixed regions*. + +The condition (first operand) must have boolean type. + +Returns +~~~~~~~ + +The if instruction may return multiple values, where the number of values and the value types +are given by the return-type-list. +If values are returned, the last instruction in both the "then"-region and the "else"-region must +be a yield instruction (the "else"-region cannot be omitted). + +Example: + + .. 
code:: + + %1 = cmp.lt %0, 16 : i32 + %x = if %1 -> (i32) { + yield (%0) + } else { + %c16 = constant 16 : i32 + yield (%c16) + } + + +Load +.... + +.. code:: abnf + + value-instruction =/ "load" local-identifier "[" [local-identifier-list] "]" + ":" scalar-or-memref-type + scalar-or-memref-type = scalar-type / memref-type + +Overview +~~~~~~~~ + +Load the element given by the index list from a memref or group. +The number of indices must match the order of the memref +and a single index must be given for a group. + +Operands +~~~~~~~~~ + +======= ======================== =========== +Op.-No. Type Description +======= ======================== =========== +1 memref-type / group-type tensor +2... index index list +======= ======================== =========== + +Returns +~~~~~~~ + +A value of the memref's element type or the group's memref type. +Examples: + +#. ``load %0[] : f32 ; %0: memref`` +#. ``load %0[5, %1] : f32 ; %0: memref`` +#. ``load %0[%1] : memref ; %0: groupx?>`` +#. ``load %0[%1] : memref ; %0: groupx?, offset: ?>`` + +Math (unary) +............ + +.. code:: abnf + + math-unary-type = "cos" / + "sin" / + "exp" / + "exp2" / + "native_cos" / + "native_sin" / + "native_exp" / + "native_exp2" + value-instruction =/ math-unary-type local-identifier ":" scalar-type + +Overview +~~~~~~~~ + +Unary math operation on scalars. +The operand must have the same type as the returned value. + +The following table shows the operations' description and the types that are allowed for the operation.
+ +=========== ============================= ===================================================================== +Op Allowed type Description +=========== ============================= ===================================================================== +cos floating-type Compute cosine function +sin floating-type Compute sine function +exp floating-type / complex-type Compute base-e exponential function +exp2 floating-type / complex-type Compute base-2 exponential function +native_cos floating-type Compute cosine function with implementation-defined error +native_sin floating-type Compute sine function with implementation-defined error +native_exp floating-type / complex-type Compute base-e exponential function with implementation-defined error +native_exp2 floating-type / complex-type Compute base-2 exponential function with implementation-defined error +=========== ============================= ===================================================================== + +.. _size instruction: + +Size +.... + +.. code:: abnf + + value-instruction =/ "size" local-identifier "[" integer-constant "]" ":" "index" + +Overview +~~~~~~~~ + +The size instruction returns the i-th entry of the tensor's shape, where "i" is given by the integer +constant in square brackets. +"i" must be in bounds, i.e. :math:`0 \leq i < \text{order}(tensor)`. + +For group types, the group size is returned and "i" must be 0. + +Operands +~~~~~~~~~ + +======= ======================== =========== +Op.-No. Type Description +======= ======================== =========== +1 memref-type / group-type tensor +2 integer-constant mode index +======= ======================== =========== + +Subview +....... + +.. 
code:: abnf + + value-instruction =/ "subview" local-identifier "[" [index-or-slice-list] "]" + ":" memref-type + index-or-slice-list = index-or-slice *("," index-or-slice) + index-or-slice = integer-constant-or-identifier [":" integer-constant-or-identifier] + +Overview +~~~~~~~~ + +The subview instruction returns a view on a tensor. + +The first argument must point to a value of memref type. +The number of indices in square brackets must match the order of the memref type. +The indices are either given as single index or as a slice, where +slices are given in offset plus size notation ("%offset : %size"). +E.g. the slice "%0 : %1" extracts a block of %1 elements beginning from %0, which is equivalent +to the index interval [%0, %0 + %1). + +.. admonition:: Note + + A slice is often defined as "%0 : %1" being the index interval [%0, %1). + However, then the compiler needs to figure out whether %1 - %0 is constant or not in order + to determine whether the mode size is known at compile-time or not. + Therefore, we prefer the offset plus size notation. + +Zero sizes are used to encode that a rank-reduction is required, that is, +the rank of size 0 is removed from the output memref type. +A single index is syntactic sugar for offset plus size 0, e.g. %0 is syntactic sugar for %0:0. +(Note that a zero-size rank, e.g. in memref, is non-sense, because any multi-index passed +to the memref would be out-of-bounds. However, a one-sized rank, e.g. memref, might be desirable.) +A dynamic size of zero is undefined behaviour. + +There is no run-time check whether the indices are within bounds. +Offset and size must be of index type. +Offset must be non-negative and size must be positive. + +Restrictions +~~~~~~~~~~~~ + +The memref type of the result must conform with the following rules: + +#. Element type and address space must match the operand's memref type. +#. **Invariant-stride:** The stride is not changed or replaced with '?'. + + .. 
code:: + + subview %0[4:8,8:4] : memref> ; %0: memref + subview %0[4:8,8:4] : memref> ; %0: memref + + +#. **Rank-reduction:** A mode accessed by offset only or a mode with size statically known to be 0 is removed from the output tensor. + + .. code:: + + subview %0[2:4, %1] : memref ; %0: memref + subview %0[2:4, %1:0] : memref ; %0: memref + subview %0[2:4, %1:1] : memref> ; %0: memref + +#. **Output-mode size:** The size of the output mode is determined by the size field of a slice + and may be dynamic. + + .. code:: + + subview %0[%1:4] : memref ; %0: memref + subview %0[%2:%2] : memref ; %0: memref + subview %0[2:4, %2:%2, 6:7] : memref ; %0: memref + subview %0[2:4, %2:%2, 6:7] : memref ; %0: memref> + +Store +..... + +.. code:: abnf + + instruction =/ "store" [store-flag] local-identifier "," + local-identifier "[" [local-identifier-list] "]" + store-flag = ".atomic" / ".atomic_add" / ".atomic_max" / ".atomic_min" + +Overview +~~~~~~~~ + +Store a scalar value (first operand) in a memref (second operand) at the position given by the index list. +The number of indices must match the order of the memref. + +The store is atomic when the atomic flag is set with relaxed memory ordering. +When the atomic_add/max/min flag is set, the following steps are done atomically: +The value at the memory location is fetched, the scalar value is combined with the fetched value +(sum, maximum, or minimum, respectively), and the resulting value is stored at the memory location. + +When storing a complex value the update may be pseudo-atomic, meaning that an atomic store is used +for the real and imaginary parts separately. + +*Note:* Store should only be used in SPMD regions as otherwise the same memory location is written +from all work-items. + +Operands +~~~~~~~~ + +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 scalar-type value +2 memref-type tensor +3... 
index index list +======= ================ =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{type}(value) = \text{element_type}(tensor)` + +Yield +..... + +.. code:: abnf + + instruction =/ "yield" "(" [local-identifier-list] ")" + +Overview +~~~~~~~~ + +Yield returns values from an if or for instruction. + +Operands +~~~~~~~~ + +======= ============================================ =========== +Op.-No. Type Description +======= ============================================ =========== +1... boolean-type / scalar-type / coopmatrix-type value +======= ============================================ =========== + +SPMD instructions +----------------- + +Builtin (SPMD) +.............. + +.. code:: abnf + + spmd-builtin-type = "subgroup_id" comp3 / + "subgroup_linear_id" / + "subgroup_local_id" + value-instruction =/ spmd-builtin-type ":" integer-type + +Overview +~~~~~~~~ + +Returns a builtin value. + +The mode of the subgroup id is selected with the .x, .y, and .z suffix. +Each mode starts with zero and is limited by the corresponding num_subgroups mode. That is, .. math:: - C := \alpha \text{op}_1(A) \text{op}_2(B) + \beta C + \forall d \in \{x,y,z\} : 0 \leq \text{subgroup_id}_d < \text{num_subgroups}_d -If the atomic flag is set, C is updated atomically. +The subgroup linear id combines the x, y, and z modes of the subgroup id as follows (note that +:math:`\text{subgroup_id}_z = 0` due to :math:`\text{num_subgroups}_z = 1`): + +.. math:: + + \text{subgroup_linear_id} = \text{subgroup_id}_x + + \text{subgroup_id}_y\cdot \text{num_subgroups}_x + + +The subgroup local id is the invocation id within the subgroup and ranges from 0 to subgroup_size-1. + +The following table shows the builtins' description and the types that are returned. 
+ +================== ===== ====================== ==================================================== +Builtin Type OpenCL analogue Description +================== ===== ====================== ==================================================== +subgroup_id i32 N/A Returns the x, y, or z mode of the subgroup id +subgroup_linear_id i32 get_sub_group_id Returns linear subgroup id +subgroup_local_id i32 get_sub_group_local_id Returns the local invocation id in the subgroup +================== ===== ====================== ==================================================== + +Cooperative matrix apply +........................ + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_apply" + "(" local-identifier "," local-identifier "," local-identifier ")" + "=" local-identifier + "->" coopmatrix-type region + +Overview +~~~~~~~~ + +Apply an action on every component of a coopmatrix and update the component with the result of the action. +The action is described in the *parallel region* of the instruction. Arguments -......... +~~~~~~~~~ -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, B, and C, respectively. +The first three local identifiers introduce SSA values for the row index, column index, and component value. +The row and column values have i32 type and the component value has the same component type as the resulting +coopmatrix type. +The fourth identifier, after "=", gives the input coopmatrix, and its type must match the result type. -The first transpose modifier defines :math:`\text{op}_1` and the second transpose modifier -defines :math:`\text{op}_2` as following: +The region must yield exactly one value whose scalar type is identical to the component type of the coopmatrix. -.. 
math:: +Example: - \text{op}_i(X) := \left\{ - \begin{array}{rcl} - X^T & \text{ if } & \text{modifier}_i = t,\\ - X & \text{ if } & \text{modifier}_i = n. - \end{array} - \right. +.. code:: + %0 = ... ; contains a coopmatrix of type coopmatrix + %1 = cooperative_matrix_apply (%i,%j,%v)=%0 -> coopmatrix { + %mask = cmp.le %i, %j : bool + %exp_v_masked = if %mask -> (f32) { + %exp_v = math.native_exp %v : f32 + yield (%exp_v) + } else { + %zero = constant 0.0 : f32 + yield (%zero) + } + yield (%exp_v_masked) + } + ; The entries of %1 are given by %1[i,j] = exp(%0[i,j]) if i <= j else 0 -If :math:`\text{op}_1(A)` has the shape MxK and -:math:`\text{op}_2(B)` has the shape KxN then C must have the shape MxN. +Cooperative matrix extract +.......................... -GEMV ----- +.. code:: abnf + + value-instruction =/ "cooperative_matrix_extract" + local-identifier "[" integer-constant "]" ":" scalar-type + +Overview +~~~~~~~~ + +Return an element of the coopmatrix's work-item vector. +The index is supplied in square brackets and must be greater than or equal to zero +and smaller than the length of the work-item vector, cf. :ref:`coopmatrix layout`. + +The scalar type of the returned value must match the component type of the coopmatrix. + +Operands +~~~~~~~~~ + +======= ================ =========================== +Op.-No. Type Description +======= ================ =========================== +1 coopmatrix-type Cooperative matrix +2 integer-constant Index into work-item vector +======= ================ =========================== + +Cooperative matrix insert +......................... .. 
code:: abnf - gemv-instruction = "gemm" transpose [".atomic"] - "," const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + value-instruction =/ "cooperative_matrix_insert" local-identifier "," + local-identifier "[" integer-constant "]" ":" coopmatrix-type Overview -........ +~~~~~~~~ -*Collective instruction.* -GEMV implements the well-known GEMM BLAS-2 operation. +Return a copy of the coopmatrix, while modifying one entry of the coopmatrix. +The index is supplied in square brackets and must be greater than or equal to zero +and smaller than the length of the work-item vector, cf. :ref:`coopmatrix layout`. + +The coopmatrix type of the returned value must match the coopmatrix type of the incoming matrix. +The scalar type of the inserted scalar must match the component type of the coopmatrix. + +Operands +~~~~~~~~~ + +======= ================ =========================== +Op.-No. Type Description +======= ================ =========================== +1 scalar-type Inserted scalar +2 coopmatrix-type Cooperative matrix +3 integer-constant Index into work-item vector +======= ================ =========================== + +Cooperative matrix load +....................... + +.. code:: abnf + + value-instruction =/ "cooperative_matrix_load" [transpose] [checked-flag] + local-identifier "[" local-identifier "," local-identifier "]" + ":" coopmatrix-type + checked-flag = ".rows_checked" / ".cols_checked" / ".both_checked" + +Overview +~~~~~~~~ + +Load a cooperative matrix from a 2d-memref at the position given by the indices in square brackets. +The position gives the starting row and column index, that is, +when a coopmatrix of size :math:`X\times Y` is loaded from memref :math:`M` at +position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are given by .. 
math:: - c := \alpha \text{op}_1(A) b + \beta C + \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + i) S_1 + (y + j) S_2], -If the atomic flag is set, c is updated atomically. +where :math:`S_1` and :math:`S_2` are the entries of the memref's stride array. +When the transpose modifier ".t" is given, we have -Arguments -......... +.. math:: -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -A, b, and c, respectively. + \forall i \in [0,X), j \in [0,Y): A_{ij} := M[(x + j) S_1 + (y + i) S_2] + +When the checked flag is set, the following out-of-bound checks are added +(with memref shape :math:`s_1\times s_2`): + +=============== ===================================================================== +Flag Description +=============== ===================================================================== +.n.rows_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq x+i < s_1 \text{ else } 0` +.t.rows_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq y+i < s_2 \text{ else } 0` +.n.cols_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq y+j < s_2 \text{ else } 0` +.t.cols_checked :math:`A_{ij} := M[...] \text{ if } 0 \leq x+j < s_1 \text{ else } 0` +.n.both_checked .n.rows_checked and .n.cols_checked +.t.both_checked .t.rows_checked and .t.cols_checked +=============== ===================================================================== + +Operands +~~~~~~~~ + +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 memref-type M +2 index x +3 index y +======= =============== =========== -The transpose modifier for A as in GEMM. +Restrictions +~~~~~~~~~~~~ -:math:`\text{op}_1(A)` has the shape MxK and :math:`B` has the shape K then c must have the shape M. +* :math:`\text{order}(M) = 2` +* :math:`\text{component_type}(A) = \text{element_type}(M)` +* All arguments **must** be dynamically uniform. 
-GER ---- +Cooperative matrix mul add +.......................... .. code:: abnf - ger-instruction = "ger" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + value-instruction =/ "cooperative_matrix_mul_add" local-identifier "," + local-identifier "," local-identifier ":" coopmatrix-type Overview -........ +~~~~~~~~ -Computes the general rank-1 update: +Matrix mul add returns the value of .. math:: - C := \alpha a b^T + \beta C + D := AB + C, -If the atomic flag is set, C is updated atomically. +where A, B, and C are matrices given by the three operands. -Arguments -......... +The number of rows of matrix A,C, and D must be a multiple of the subgroup size. -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and C, respectively. +Operands +~~~~~~~~ -a and b must be vectors. If the size of a is M and the size of b is N the shape of C must be :math:`M\times N`. +======= =============== ========== =========== +Op.-No. 
Type Use Description +======= =============== ========== =========== +1 coopmatrix-type matrix_a A +2 coopmatrix-type matrix_b B +3 coopmatrix-type matrix_acc C +======= =============== ========== =========== +Restrictions +~~~~~~~~~~~~ -Hadamard product ----------------- +* :math:`\forall X\in\{A,C,D\}: \text{rows}(X) \bmod \text{subgroup_size} = 0` +* :math:`\text{columns}(A) = \text{rows}(B)` +* :math:`\text{rows}(C) = \text{rows}(A) \land \text{columns}(C) = \text{columns}(B)` +* :math:`\text{shape}(D) = \text{shape}(C)` +* :math:`\text{use}(D) = \text{matrix_acc}` +* :math:`\text{promote}(\text{component_type}(A), \text{component_type}(B)) \preceq \text{component_type}(C)` +* Cast of :math:`\text{component_type}(C)` to :math:`\text{component_type}(D)` must be allowed + +Cooperative matrix prefetch +........................... .. code:: abnf - hadamard-product-instruction = "hadamard_product" [".atomic"] - const-or-val "," local-identifier "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," memref-type "," scalar-type "," memref-type + instruction =/ "cooperative_matrix_prefetch" integer-constant "," + local-identifier "[" local-identifier "," local-identifier "]" "," + integer-constant "," integer-constant Overview -........ +~~~~~~~~ -*Collective instruction.* -Computes the Hadamard product of two tensors. -That is, in index notation we have +Cooperatively prefetch memory into device cache. +The cache level is given by the first non-negative integer constant, where "0" is the cache closest to the core +and core distance increases with increasing cache level. +The prefetch instruction is ignored if the cache level does not exist in the target device. +The position in square brackets gives the starting row and column index. +The last two positive integer constants give the size of the memory region to fetch (in rows by columns). +The following memory locations are prefetched: .. 
math:: - c_{i} := \alpha a_{i} b_{i} + \beta c_{i} + \{\forall i \in [0,X), j \in [0,Y): M[(x + i) S_1 + (y + j) S_2]\} -If the atomic flag is set, c is updated atomically. +Prefetch is an optimization hint and may be disregarded by the compiler. -Arguments -......... +Operands +~~~~~~~~ -The first argument gives :math:`\alpha` and the fourth argument gives :math:`\beta`. -The second, the third, and the fifth argument must have memref type and give -a, b, and c, respectively. +======= ================ =========== +Op.-No. Type Description +======= ================ =========== +1 integer-constant Cache-level +2 memref-type M +3 index x +4 index y +5 integer-constant X +6 integer-constant Y +======= ================ =========== -a, b, and c must be vectors and have equal shape. +Restrictions +~~~~~~~~~~~~ +* All arguments **must** be dynamically uniform. -Store ------ +Cooperative matrix reduce +......................... .. code:: abnf - store-instruction = "store" local-identifier "," local-identifier "[" [index-list] "]" ":" memref-type + coopmatrix-reduce-op = "cooperative_matrix_reduce_add" / + "cooperative_matrix_reduce_max" / + "cooperative_matrix_reduce_min" + value-instruction =/ coopmatrix-reduce-op reduce-mode local-identifier ":" coopmatrix-type + reduce-mode = ".row" / ".column" Overview -........ +~~~~~~~~ -*Replicated instruction.* -Store a scalar value in a memref at the position given by the index list. -The number of indices must match the order of the memref. +Computes the sum, maximum, or minimum over either the rows or columns of a coopmatrix. -*Note:* Store should only be used in SPMD regions as otherwise the same memory location is written -from all work-items. +The component type and use of the returned value's coopmatrix type +must match the component type and use of the incoming matrix. -Arguments -......... 
+For a row reduction the resulting shape must be :math:`M\times 1` and for a column reduction +the resulting shape must be :math:`1\times N`, where the shape of the incoming matrix is :math:`M\times N`. -The first operand must have the same scalar type as the memref type. -The indices must be of ``index`` type. +Operands +~~~~~~~~~ -Sum ---- +======= ================ =========================== +Op.-No. Type Description +======= ================ =========================== +1 coopmatrix-type Incoming cooperative matrix +======= ================ =========================== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{rows}(A) \bmod \text{subgroup_size} = 0` + +Cooperative matrix scale +........................ .. code:: abnf - sum-instruction = "sum" transpose [".atomic"] - "," const-or-val "," local-identifier "," const-or-val "," local-identifier - ":" scalar-type "," memref-type "," scalar-type "," memref-type + value-instruction =/ "cooperative_matrix_scale" local-identifier "," local-identifier + ":" coopmatrix-type Overview -........ +~~~~~~~~ -*Collective instruction.* -Computes the matrix-vector product or the dot product of A with a vector of ones. -That is, for matrices we have +Scale a coopmatrix by a scalar. +The scalar type of the scalar and the component type of the coopmatrix must match, +and the returned value must have the same coopmatrix type as the matrix operand. + +Operands +~~~~~~~~ + +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 scalar-type scalar +2 coopmatrix-type matrix +======= =============== =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{type}(scalar) = \text{component_type}(matrix)` +* :math:`\text{type}(result) = \text{type}(matrix)` + +Cooperative matrix store +........................ + +.. 
code:: abnf + + instruction =/ "cooperative_matrix_store" [transpose] [checked-flag] [store-flag] + local-identifier "," local-identifier + "[" local-identifier "," local-identifier "]" + +Overview +~~~~~~~~ + +Store a cooperative matrix value in a 2d-memref at the position given by the indices in square brackets. +The position gives the starting row and column index, that is, +when a coopmatrix of size :math:`X\times Y` is written to memref :math:`M` at +position :math:`x, y`, then the components :math:`A_{ij}` of the coopmatrix are written to .. math:: - B := \alpha \text{op}(A) \vec{1} + \beta B + \forall i \in [0,X), j \in [0,Y): M[(x + i) S_1 + (y + j) S_2] := A_{ij}, -and for vectors we have +where :math:`S_1` and :math:`S_2` are the entries of the memref's stride array. +When the transpose modifier ".t" is given, we have .. math:: - b := \alpha \left + \beta b + \forall i \in [0,X), j \in [0,Y): M[(x + j) S_1 + (y + i) S_2] := A_{ij} -If the atomic flag is set, B is updated atomically. +When the checked flag is set, the following out-of-bound checks are added +(with memref shape :math:`s_1\times s_2`): +=============== ============================================== +Flag Description +=============== ============================================== +.n.rows_checked Only execute store if :math:`0 \leq x+i < s_1` +.t.rows_checked Only execute store if :math:`0 \leq y+i < s_2` +.n.cols_checked Only execute store if :math:`0 \leq y+j < s_2` +.t.cols_checked Only execute store if :math:`0 \leq x+j < s_1` +.n.both_checked .n.rows_checked + .n.cols_checked +.t.both_checked .t.rows_checked + .t.cols_checked +=============== ============================================== -Arguments -......... +The store is atomic when the atomic flag is set with relaxed memory ordering. +When the atomic_add flag is set, the coopmatrix is added to the memref atomically. -The first argument gives :math:`\alpha` and the third argument gives :math:`\beta`. 
-The second and the fourth argument must have memref type and give A and B, respectively. -If A is a matrix then B must be a vector. -The first mode size of :math:`\text{op}(A)` must match the size of B. -If A is a vector, then B must be a scalar memref. +When storing a complex value the update may be pseudo-atomic, meaning that an atomic store is used +for the real and imaginary parts separately. -The transpose op is defined as in the axpby instruction. +Operands +~~~~~~~~ -Yield ------ +======= =============== =========== +Op.-No. Type Description +======= =============== =========== +1 coopmatrix-type A +2 memref-type M +3 index x +4 index y +======= =============== =========== + +Restrictions +~~~~~~~~~~~~ + +* :math:`\text{component_type}(A) = \text{element_type}(M)` +* All arguments **must** be dynamically uniform. + +Subgroup broadcast +.................. .. code:: abnf - yield-instruction = "yield" [local-identifier-list] ":" [scalar-type-list] - identifier-or-constant-list = identifier-or-constant *("," identifier-or-constant) + value-instruction =/ "subgroup_broadcast" local-identifier "," local-identifier ":" scalar-type Overview -........ +~~~~~~~~ -Yield returns values from an if or for instruction. +Broadcast a scalar to all work-items in the subgroup. +The scalar type of the first operand and the type of the result must match. +The second identifier must have i32 type. -Arguments -......... +Operands +~~~~~~~~ + +======= =============== ================================================================================================== +Op.-No. 
Type Description +======= =============== ================================================================================================== +1 scalar-type Value that is to be distributed to all work-items of the sub-group +2 i32 Subgroup local index that identifies the work-item whose value is returned to all other work-items +======= =============== ================================================================================================== -The length of the local identifier list must equal the length of the scalar type list. +Restrictions +~~~~~~~~~~~~ +* The second operand **must** be dynamically uniform. -Additional instructions ------------------------ +Subgroup operation +.................. .. code:: abnf - barrier-instruction = "barrier" - lifetime-stop-instruction = "lifetime_stop" local-identifier + subgroup-operation-type = "subgroup_exclusive_scan_add" / + "subgroup_exclusive_scan_max" / + "subgroup_exclusive_scan_min" / + "subgroup_inclusive_scan_add" / + "subgroup_inclusive_scan_max" / + "subgroup_inclusive_scan_min" / + "subgroup_reduce_add" / + "subgroup_reduce_max" / + "subgroup_reduce_min" + value-instruction =/ subgroup-operation-type local-identifier ":" scalar-type + +Overview +~~~~~~~~ + +Let :math:`[x_0,x_1,\dots,x_{n-1}]` be the input vector contributed by a subgroup of size *n*. +(The work-item with subgroup local id *i* contributes :math:`x_i`.) +Let :math:`\diamond` be the binary operator and *I* the identity. 
+We define the output vector of size *n* for the group operations in the following table: + +============== ============================================================================================= +Operation type Result +============== ============================================================================================= +exclusive_scan :math:`[I, x_0, (x_0 \diamond x_1), \dots, x_0 \diamond x_1 \diamond \dots \diamond x_{n-2}]` +inclusive_scan :math:`[x_0, (x_0 \diamond x_1), \dots, x_0 \diamond x_1 \diamond \dots \diamond x_{n-1}]` +reduce :math:`[s,s,\dots,s] \text{ with } s := x_0 \diamond \dots \diamond x_{n-1}` +============== ============================================================================================= + +Add +~~~ + +Computes the subgroup operation with :math:`\diamond:=+` and :math:`I:=0`. + +Max +~~~ + +Computes the subgroup operation with :math:`\diamond:=\max` and identity as given in the following table: + +============= ============================================== +Identity Value +============= ============================================== +integer-type Smallest integer representable by integer type +floating-type :math:`-\infty` +complex type Forbidden +============= ============================================== + +Min +~~~ + +Computes the subgroup operation with :math:`\diamond:=\min` and identity as given in the following table: + +============= ============================================= +Identity Value +============= ============================================= +integer-type Largest integer representable by integer type +floating-type :math:`+\infty` +complex type Forbidden +============= ============================================= + Sample code =========== @@ -1057,16 +2194,16 @@ where B and C are constant matrices and A and D are matrix batches. .. 
code:: func @fused_kernel(%alpha: f32, - %A: group>, + %A: groupx?>, %B: memref, %C: memref, %D: memref) { - %0 = group_id - %1 = load %A[%0] : group> ; Returns memref - %2 = subview %D[:,:,%0] : memref ; Returns memref - %tmp0 = alloca -> memref - gemm.n.t 1.0, %1, %B, 0.0, %tmp0 - : f32, memref, memref, f32, memref - gemm.n.n %alpha, %tmp0, %C, 1.0, %2 - : f32, memref, memref, f32, memref + %0 = group_id : index + %1 = load %A[%0] : memref + %2 = subview %D[:,:,%0] : memref + %tmp0 = alloca : memref + %zero = constant 0.0 : f32 + %one = constant 1.0 : f32 + gemm.n.t %one, %1, %B, %zero, %tmp0 + gemm.n.n %alpha, %tmp0, %C, %one, %2 } diff --git a/docs/manual/tutorial_matrix_chain.rst b/docs/manual/tutorial_matrix_chain.rst index e9c552d6..77299544 100644 --- a/docs/manual/tutorial_matrix_chain.rst +++ b/docs/manual/tutorial_matrix_chain.rst @@ -27,50 +27,57 @@ In the :ref:`tensor language ` we can implement the kernel as f func @fused_kernel(%K: memref, %P: memref, - %A: group>, + %A: groupx?>, %Q: memref) { - %gid = group_id ; Get our index e - - %p = subview %P[:,:,%gid] : memref ; %p has type memref - %a = load %A[%gid] : group> ; %a has type memref - %q = subview %Q[:,:,%gid] : memref ; %q has type memref - - %tmp = alloca -> memref ; Reserve temporary memory - - gemm.n.n 1.0, %K, %p, 0.0, %tmp ; Compute tmp <- K P(:,:,e) - : f32, memref, memref, f32, memref - gemm.n.n 1.0, %tmp, %a, 1.0, %q ; Update Q(:,:,e) <- Q(:,:,e) + tmp A_e - : f32, memref, memref, f32, memref + %gid = builtin.group_id : index ; Get our index e + + %p = subview %P[0:56,0:9,%gid] : memref ; Get view on submatrix + %a = load %A[%gid] : memref ; Load matrix from group + %q = subview %Q[0:56,0:9,%gid] : memref ; Get view on submatrix + + %tmp = alloca : memref ; Reserve temporary memory + ; in the Shared Local Memory + %c0 = constant 0.0 : f32 + %c1 = constant 1.0 : f32 + gemm.n.n %c1, %K, %p, %c0, %tmp ; Compute tmp <- K P(:,:,e) + gemm.n.n %c1, %tmp, %a, %c1, %q ; Update Q(:,:,e) <- 
Q(:,:,e) + tmp A_e } -Compilation with the Tiny Tensor Compiler generates the following OpenCL-C code - -.. code-block:: c - - kernel - __attribute__((reqd_work_group_size(64,1,1))) - __attribute__((intel_reqd_sub_group_size(32))) - fused_kernel(global float *K, global float *P, uint P_shape2, global float *global *A, - global float *Q, uint Q_shape2) { - local uchar stack[2016] __attribute__((aligned(64))); - uint gid = get_global_id(2); - global float *p = P + 0ll * 1 + 0ll * 56 + gid * 504; - global float *a = *(A + gid); - global float *q = Q + 0ll * 1 + 0ll * 56 + gid * 504; - local float *tmp = (local float *)(stack + 0); - gemm_f32f32f32f32f32_An_Bn_M56_N9_K56_Astride1_56_Bstride1_56_Cstride1_56_alpha3ff0000000000000_beta0( - 56, 9, 56, 0x1p+0f, K, 1, 56, p, 1, 56, 0x0p+0f, tmp, 1, 56); - barrier(CLK_LOCAL_MEM_FENCE); - gemm_f32f32f32f32f32_An_Bn_M56_N9_K9_Astride1_56_Bstride1_9_Cstride1_56_alpha3ff0000000000000_beta3ff0000000000000( - 56, 9, 9, 0x1p+0f, tmp, 1, 56, a, 1, 9, 0x1p+0f, q, 1, 56); +Using the *tinytc-opt* tool we can run compiler passes on the code to get insight on what is happening under the hood. +For example, running the insert-lifetime-stop, insert-barrier, and work-group-size pass, + +.. code-block:: bash + + tinytc-opt -pinsert-lifetime-stop -pinsert-barrier -pwork-group-size test.ir + +we get + +.. code-block:: + :emphasize-lines: 4, 13, 15 + + func @fused_kernel(%K: memref, + %P: memref, + %A: groupx?>, + %Q: memref) attributes{subgroup_size=32, work_group_size=[64,1]} { + %gid = builtin.group_id : index + %p = subview %P[0:56,0:9,%gid] : memref + %a = load %A[%gid] : memref + %q = subview %Q[0:56,0:9,%gid] : memref + %tmp = alloca : memref + %c0 = constant 0x0p+0 : f32 + %c1 = constant 0x1p+0 : f32 + gemm.n.n %c1, %K, %p, %c0, %tmp + barrier.local + gemm.n.n %c1, %tmp, %a, %c1, %q + lifetime_stop %tmp } -where the definition of the generated GEMM functions have been omitted for brevity. 
We observe that -* a GEMM is processed in parallel by a work-group with 64 threads, -* temporary memory is mapped to shared local memory (local uchar stack), -* load and subview calls translate to simple pointer manipulation, +* the kernel is executed concurrently by 64 work-items, +* temporary memory is only needed until after the lifetime_stop instruction after the GEMM + (if multiple alloca's are present that do not overlap, that is, lifetime_stop for alloca #1 appears before alloca #2, + then Shared Local Memory is reused, reducing the total amount needed), * and that a barrier has been introduced between the GEMM calls to avoid data races. When using SYCL, we can run the kernel using the following pseudo-code: @@ -81,16 +88,15 @@ When using SYCL, we can run the kernel using the following pseudo-code: #include #include - auto source_ctx = tinytc::make_source_context(); + #include + + auto ctx = tinytc::make_compiler_context(); + ctx.set_error_reporter([](char const *what, const tinytc_location_t *, + void *) { std::cerr << what << std::endl; }, + nullptr); try { // Parse tensor program - auto prog = tinytc::parse_file("fused_kernel.ir", source_ctx); - - // JIT compile program - auto q = sycl::queue{}; - auto info = tinytc::make_core_info(q.get_device()); - auto bin = tinytc::compile_to_binary(std::move(prog), info, tinytc::bundle_format::native, - source_ctx); + auto prog = tinytc::parse_file("fused_kernel.ir", ctx); // Initialize tensors float* K = ...; @@ -98,20 +104,21 @@ When using SYCL, we can run the kernel using the following pseudo-code: float** A = ...; float* Q = ...; - auto bundle = tinytc::make_kernel_bundle(q.get_context(), q.get_device(), bin); + // JIT compile program + auto q = sycl::queue{}; + auto bundle = tinytc::make_kernel_bundle(q.get_context(), q.get_device(), prog); + auto kernel = tinytc::make_kernel(bundle, "fused_kernel"); auto exe_range = tinytc::get_execution_range(kernel, howmany); for (int timestep = 0; timestep < num_timesteps; 
++timestep) { q.submit([&](sycl::handler &h) { - h.set_args(K, P, howmany, A, Q, howmany); - h.parallel_for(exec_range, kernel); + h.set_args(K, P, howmany, A, howmany, Q, howmany); + h.parallel_for(exe_range, kernel); }).wait(); } } catch (tinytc::status const& st) { std::cerr << "Error (" << static_cast(st) << "): " - << tinytc::error_string(st) << std::endl; - std::cerr << "Error log:" << std::endl - << source_ctx.get_error_log() << std::endl; + << tinytc::to_string(st) << std::endl; } catch (std::exception const &e) { std::cerr << e.what() << std::endl; } diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index d3f351a4..81cc0224 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) find_package(SYCL REQUIRED) -add_executable(tinytc-bench main.cpp args.cpp) +add_executable(tinytc-bench main.cpp) add_sycl_to_target(TARGET tinytc-bench SOURCES main.cpp) -target_link_libraries(tinytc-bench PRIVATE tinytc tinytc_sycl) +target_link_libraries(tinytc-bench PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(tinytc-bench) diff --git a/examples/benchmark/args.cpp b/examples/benchmark/args.cpp deleted file mode 100644 index 61030e6b..00000000 --- a/examples/benchmark/args.cpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" - -#include -#include -#include -#include -#include - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.internal_repetitions = 1; - a.transA = tinytc::transpose::N; - a.transB = tinytc::transpose::N; - a.beta = 0.0; - auto num = std::vector(3); - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error("==> Error: unrecognized argument " + - std::string(argv[i])); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else 
if (std::strcmp(argv[i], "--trans-a") == 0) { - a.transA = tinytc::transpose::T; - } else if (std::strcmp(argv[i], "--trans-b") == 0) { - a.transB = tinytc::transpose::T; - } else if (std::strcmp(argv[i], "-v") == 0 || std::strcmp(argv[i], "--verify") == 0) { - a.verify = true; - } else if (std::strcmp(argv[i], "-a") == 0 || std::strcmp(argv[i], "--atomic") == 0) { - a.atomic = true; - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-i") == 0 || - std::strcmp(argv[i], "--internal-reps") == 0) { - a.internal_repetitions = atoi(argv[++i]); - } else if (std::strcmp(argv[i], "-b") == 0 || std::strcmp(argv[i], "--beta") == 0) { - ++i; - a.beta = atof(argv[i]); - } else if (std::strcmp(argv[i], "-p") == 0 || - std::strcmp(argv[i], "--precision") == 0) { - ++i; - if (argv[i][0] == 'd') { - a.double_precision = true; - } else if (argv[i][0] == 's') { - a.double_precision = false; - } else { - fail(); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - num.clear(); - char const *delim = "x"; - auto arg = std::string(argv[i]); - char *token = std::strtok(argv[i], delim); - while (token) { - num.emplace_back(atoi(token)); - token = std::strtok(nullptr, delim); - } - if (num.size() != 3) { - throw std::runtime_error("==> Could not parse test case: " + arg); - } - a.tc.push_back({num[0], num[1], num[2]}); - } - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tinytcbench test-case1 test-case2 ..." << std::endl - << R"HELP( -positional arguments: - test-caseN MxNxK triplet (e.g. 
64x64x64) - -optional arguments: - -h, --help Show help and quit - -i, --internal-reps Number of GEMM repetitions inside kernel (default: 1) - -p, --precision Precision (single = s, double = d) - --trans-a Transpose A matrix - --trans-b Transpose B matrix - -v, --verify Verify optimized implementation - -a, --atomic Update C atomically -)HELP"; -} diff --git a/examples/benchmark/args.hpp b/examples/benchmark/args.hpp deleted file mode 100644 index 23f69c65..00000000 --- a/examples/benchmark/args.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20230417_HPP -#define ARGS_20230417_HPP - -#include "tinytc/types.hpp" - -#include -#include -#include - -struct test_case { - std::int64_t m; - std::int64_t n; - std::int64_t k; -}; - -struct args { - std::vector tc; - int internal_repetitions; - bool double_precision; - bool help; - tinytc::transpose transA; - tinytc::transpose transB; - double beta; - bool verify; - bool atomic; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20230417_HPP diff --git a/examples/benchmark/main.cpp b/examples/benchmark/main.cpp index f62101a7..4f4d2264 100644 --- a/examples/benchmark/main.cpp +++ b/examples/benchmark/main.cpp @@ -1,15 +1,20 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" +#include "../gemm_common.hpp" +#include #include +#include #include #include #include +#include #include +#include #include +#include #include #include #include @@ -21,6 +26,19 @@ using namespace sycl; using namespace tinytc; +struct args { + std::int32_t alignment = 0; + bool atomic = false; + bool dump = false; + std::int32_t internal_repetitions = 1; + bool trans_a = false; + bool trans_b = false; + examples::test_type ty = examples::test_type::f32; + bool update = false; + bool verify = false; + std::vector tc; +}; 
+ template double bench(F f, int nrepeat = 10) { f(); double min_exec_time_ns = std::numeric_limits::max(); @@ -35,15 +53,16 @@ template double bench(F f, int nrepeat = 10) { return min_exec_time_ns; } -auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose tB, bool atomic, - std::int64_t M, std::int64_t N, std::int64_t K, +auto gemm_kernel_with_inner_repetition(tinytc_type_t element_ty, transpose tA, transpose tB, + bool atomic, std::int64_t M, std::int64_t N, std::int64_t K, std::array A_stride, - std::array B_stride, double beta, - std::array C_stride, - std::int32_t repetitions, queue q) -> source { - auto ctx = make_source_context(); + std::array B_stride, bool update, + std::array C_stride, std::int32_t alignment, + std::int32_t repetitions, bool dump_code, queue q) + -> shared_handle { + auto ctx = get_compiler_context(element_ty); char const *file_name = std::source_location::current().file_name(); - auto const source_id = ctx.add_source(file_name, ""); + auto const source_id = add_source(ctx.get(), file_name, ""); auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { auto l = location{}; @@ -54,88 +73,117 @@ auto gemm_kernel_with_inner_repetition(scalar_type ty, transpose tA, transpose t ++l.end.column; return l; }; + auto const make_memref = [](tinytc_type_t element_ty, transpose t, int64_t A, std::int64_t B, + std::array const &stride) { + auto s = std::array{A, B}; + if (t == transpose::T) { + std::swap(s[0], s[1]); + } + return get(element_ty, s, stride, address_space::global); + }; - auto kernel = [&](function_builder &fb) { - auto A = fb.argument( - make_group(make_memref( - ty, {M, K}, std::vector(A_stride.begin(), A_stride.end()), my_loc())), - "A", my_loc()); - auto B = fb.argument( - make_group(make_memref( - ty, {K, N}, std::vector(B_stride.begin(), B_stride.end()), my_loc())), - "B", my_loc()); - auto C = fb.argument( - make_group(make_memref( - ty, {M, N}, 
std::vector(C_stride.begin(), C_stride.end()), my_loc())), - "C", my_loc()); - fb.body( - [&](region_builder &bb) { - auto gid = bb.add(make_group_id(my_loc())); - auto a = bb.add(make_load(A, {gid}, my_loc())); - auto b = bb.add(make_load(B, {gid}, my_loc())); - auto c = bb.add(make_load(C, {gid}, my_loc())); - bb.for_loop( - scalar_type::index, make_index(0, my_loc()), make_index(repetitions, my_loc()), - [&](region_builder &bb) { - bb.add(make_gemm(tA, tB, atomic, make_imm(1.0, ty, my_loc()), a, b, - make_imm(beta, ty, my_loc()), c, my_loc())); - }, - "r", my_loc()); + auto kernel = [&](tinytc_compiler_context_t ctx) { + auto index_ty = get(ctx); + auto A_ty = make_memref(element_ty, tA, M, K, A_stride); + auto B_ty = make_memref(element_ty, tB, K, N, B_stride); + auto C_ty = make_memref(element_ty, transpose::N, M, N, C_stride); + auto void_ty = get(ctx); + auto f = create_func("gemm", + {get(A_ty, dynamic, 0), get(B_ty, dynamic, 0), + get(C_ty, dynamic, 0)}, + void_ty, my_loc()); + if (alignment > 0) { + auto align_attr = get_dictionary_attr_with_sorted( + ctx, tinytc_named_attr_t{get(ctx, "align"), + get(ctx, alignment)}); + set_parameter_attr(f.get(), 0, align_attr); + set_parameter_attr(f.get(), 1, align_attr); + set_parameter_attr(f.get(), 2, align_attr); + } + auto fn_body = get_body(f.get()); + auto params = std::array{}; + get_parameters(fn_body, params); + + auto bb = region_builder{fn_body}; + auto gid = bb.create(comp3::x, index_ty, my_loc()); + auto from = bb.constant_zero(index_ty, my_loc()); + auto to = bb.create(repetitions, index_ty, my_loc()); + auto calpha = bb.constant_one(element_ty, my_loc()); + auto cbeta = + update ? 
bb.constant_one(element_ty, my_loc()) : bb.constant_zero(element_ty, my_loc()); + auto a = bb.create(params[0], array_view{gid}, A_ty, my_loc()); + auto b = bb.create(params[1], array_view{gid}, B_ty, my_loc()); + auto c = bb.create(params[2], array_view{gid}, C_ty, my_loc()); + bb.for_loop( + from, to, + [&](region_builder &bb, tinytc_value_t const &) { + bb.create(atomic, tA, tB, calpha, a, b, cbeta, c, my_loc()); }, - my_loc()); + nullptr, my_loc()); + + return f; }; try { - auto pb = program_builder{}; - pb.create("gemm", kernel, my_loc()); + auto p = create_prog(ctx.get(), my_loc()); + add_function(p.get(), kernel(ctx.get())); + if (dump_code) { + dump(p.get()); + } - auto info = make_core_info(q.get_device()); - info.set_core_features(tinytc_core_feature_flag_large_register_file); - return compile_to_opencl(pb.get_product(my_loc()), info, ctx); + auto info = create_core_info(q.get_device()); + set_core_features(info.get(), tinytc_core_feature_flag_large_register_file); + return compile_to_spirv_and_assemble(p.get(), info.get()); } catch (builder_error const &e) { - ctx.report_error(e.loc(), e.what()); - std::cerr << "Error (" << static_cast(e.code()) << "): " << std::endl - << ctx.get_error_log() << std::endl; + report_error(ctx.get(), e.loc(), e.what()); + std::cerr << "Error (" << static_cast(e.code()) << "): " << to_string(e.code()) + << std::endl; } catch (status const &st) { - std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl - << "Error log:" << std::endl - << ctx.get_error_log() << std::endl; + std::cerr << "Error (" << static_cast(st) << "): " << to_string(st) << std::endl; } - return source{nullptr}; + return {}; } template void test(queue q, args &a) { - auto const fill = [](T *ptr, std::size_t n) { - for (std::size_t i = 0; i < n; ++i) { - ptr[i] = i % 101; - } - }; + auto ctx = create_compiler_context(); + set_error_reporter(ctx.get(), [](char const *what, const tinytc_location_t *, void *) { + std::cerr << what 
<< std::endl; + }); + auto total_reals = 1024 * 1024 * 1024 / sizeof(T); T *A_host = new T[total_reals]; T *B_host = new T[total_reals]; T *C_host = new T[total_reals]; - T *C_ref_host = new T[total_reals]; - T *C_ref = malloc_device(total_reals, q); - T *A = malloc_device(total_reals, q); - T *B = malloc_device(total_reals, q); - T *C = malloc_device(total_reals, q); - fill(A_host, total_reals); - fill(B_host, total_reals); - q.copy(A_host, A, total_reals).wait(); - q.copy(B_host, B, total_reals).wait(); + auto const alloc_device = [&a, &q](std::size_t num_bytes) { + if (a.alignment == 0) { + return malloc_device(num_bytes, q); + } else { + return aligned_alloc_device(a.alignment, num_bytes, q); + } + }; + T *A = alloc_device(total_reals); + T *B = alloc_device(total_reals); + T *C = alloc_device(total_reals); - auto const check = [&](std::int64_t M, std::int64_t N, std::size_t howmany) { - q.copy(C_ref, C_ref_host, total_reals).wait(); + auto const check = [&](std::int64_t M, std::int64_t N, std::int64_t K, std::int64_t howmany) { q.copy(C, C_host, total_reals).wait(); std::size_t num_err = 0; - for (std::size_t i = 0; i < M * N * howmany; ++i) { - auto err = std::abs(C_host[i] - C_ref_host[i]); - if (err > 10.0 * std::numeric_limits::epsilon()) { - if (num_err < 10) { - std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] - << std::endl; + const auto error_bound = examples::test_gemm_error_bound(K); + for (std::int64_t b = 0; b < howmany; ++b) { + auto C_host_b = C_host + b * M * N; + for (std::int64_t j = 0; j < N; ++j) { + for (std::int64_t i = 0; i < M; ++i) { + const auto relerr = examples::test_gemm_rel_error(C_host_b, i, j, M); + if (relerr > error_bound) { + if (num_err < 10) { + std::cout + << "C_{" << i << "," << j << "," << b << "}=" << C_host_b[i + j * M] + << ", relative_error=" << relerr << ", error_bound=" << error_bound + << std::endl; + } + ++num_err; + } } - ++num_err; } } if (num_err > 10) { @@ -143,7 +191,6 @@ template void 
test(queue q, args &a) { } }; - auto const &type = typeid(T); for (auto &c : a.tc) { auto na = c.m * c.k; auto nb = c.k * c.n; @@ -151,29 +198,13 @@ template void test(queue q, args &a) { auto max_reals = std::max(std::max(na, nb), nc); auto howmany = total_reals / max_reals; - if (a.verify && a.internal_repetitions == 1) { - q.submit([&](auto &h) { - bool transa = a.transA == transpose::T; - bool transb = a.transB == transpose::T; - h.parallel_for(range{howmany, 32}, [=](id<2> it) { - auto batch = it[0]; - auto m = it[1]; - auto a = A + batch * na; - auto b = B + batch * nb; - auto c_ref = C_ref + batch * nc; - for (std::int64_t mb = m; mb < c.m; mb += 32) { - for (std::int64_t n = 0; n < c.n; ++n) { - auto c_acc = 0.0f; - for (std::int64_t k = 0; k < c.k; ++k) { - c_acc += a[transa ? k + mb * c.k : mb + k * c.m] * - b[transb ? n + k * c.n : k + n * c.k]; - } - c_ref[mb + n * c.m] = c_acc; - } - } - }); - }); + for (std::size_t i = 0; i < howmany; ++i) { + examples::test_gemm_matrix(A_host + i * na, c.m, c.k, a.trans_a); + examples::test_gemm_matrix(B_host + i * nb, c.k, c.n, a.trans_b); } + q.copy(A_host, A, total_reals).wait(); + q.copy(B_host, B, total_reals).wait(); + q.memset(C, 0, total_reals * sizeof(T)).wait(); T const **AA = malloc_shared(howmany, q); T const **BB = malloc_shared(howmany, q); @@ -186,42 +217,53 @@ template void test(queue q, args &a) { double min_exec_time_ns = 0.0; try { + auto element_ty = to_type(ctx.get()); auto src = gemm_kernel_with_inner_repetition( - to_scalar_type_v, a.transA, a.transB, a.atomic, c.m, c.n, c.k, - {1, a.transA == transpose::T ? c.k : c.m}, - {1, a.transB == transpose::T ? c.n : c.k}, a.beta, {1, c.m}, a.internal_repetitions, - q); + element_ty, a.trans_a ? transpose::T : transpose::N, + a.trans_b ? transpose::T : transpose::N, a.atomic, c.m, c.n, c.k, + {1, a.trans_a ? c.k : c.m}, {1, a.trans_b ? 
c.n : c.k}, a.update, {1, c.m}, + a.alignment, a.internal_repetitions, a.dump, q); if (src) { - auto bundle = make_kernel_bundle(q.get_context(), q.get_device(), src); - auto kernel = make_kernel(bundle, "gemm"); - auto exe_range = get_execution_range(kernel, howmany); + auto bundle = create_kernel_bundle(q.get_context(), q.get_device(), src.get()); + auto kernel = create_kernel(bundle, "gemm"); + auto exe_range = get_execution_range(kernel, sycl::range<3u>{1u, 1u, howmany}); q.submit([&](handler &h) { - h.set_args(AA, BB, CC); + h.set_args(AA, howmany, BB, howmany, CC, howmany); h.parallel_for(exe_range, kernel); }).wait(); if (a.internal_repetitions == 1 && a.verify) { - check(c.m, c.n, howmany); + check(c.m, c.n, c.k, howmany); } min_exec_time_ns = bench([&]() { q.submit([&](handler &h) { - h.set_args(AA, BB, CC); + h.set_args(AA, howmany, BB, howmany, CC, howmany); h.parallel_for(exe_range, kernel); }).wait(); }); - auto gflops = - a.internal_repetitions * 2 * c.m * c.n * c.k * howmany / min_exec_time_ns; + const auto ops_per_mnk = [&] { + switch (a.ty) { + case examples::test_type::c32: + case examples::test_type::c64: + return 8; + default: + return 2; + } + }(); + + auto gflops = a.internal_repetitions * ops_per_mnk * c.m * c.n * c.k * howmany / + min_exec_time_ns; auto roofline_gflops = std::min(512 * 32 * 1.6e9, a.internal_repetitions * 2 * c.m * c.n * c.k / (sizeof(T) * (na + nb + nc) / 1.1e12)) / 1e9; - std::cout << type.name() << "," << c.m << "," << c.n << "," << c.k << "," << howmany - << "," << min_exec_time_ns / 1e9 << "," << gflops << "," + std::cout << to_string(a.ty) << "," << c.m << "," << c.n << "," << c.k << "," + << howmany << "," << min_exec_time_ns / 1e9 << "," << gflops << "," << roofline_gflops << "," << std::round(gflops / roofline_gflops * 100) << "%," << a.internal_repetitions << std::endl; } } catch (status const &st) { - std::cerr << "Error: " << error_string(st) << std::endl; + std::cerr << "Error: " << to_string(st) << std::endl; } 
catch (std::exception const &e) { std::cerr << "Error: " << e.what() << std::endl; } @@ -234,24 +276,45 @@ template void test(queue q, args &a) { free(A, q); free(B, q); free(C, q); - free(C_ref, q); delete[] A_host; delete[] B_host; delete[] C_host; - delete[] C_ref_host; }; int main(int argc, char **argv) { auto a = args{}; + bool help = false; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); - } catch (std::runtime_error const &e) { + parser.set_short_opt('a', &a.atomic, "Update C atomically"); + parser.set_short_opt('d', &a.dump, "Dump IR to stdout"); + parser.set_short_opt('f', &a.ty, "Data type (bf16, f16, f32, f64, c32, c64)") + .converter(examples::convert_data_type); + parser + .set_short_opt('i', &a.internal_repetitions, + "Number of GEMM repetitions inside kernel (default: 1)") + .validator([](std::int32_t rep) { return 0 <= rep; }); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('u', &a.update, + "Add A*B to C (beta=1) instead of overwriting C (beta=0)"); + parser.set_short_opt('v', &a.verify, "Verify optimized implementation"); + parser.set_long_opt("help", &help, "Show help"); + parser.set_long_opt("alignment", &a.alignment, "Memory alignment"); + parser.set_long_opt("transpose-a", &a.trans_a, "Transpose A matrix"); + parser.set_long_opt("transpose-b", &a.trans_b, "Transpose B matrix"); + parser.add_positional_arg("test-case", &a.tc, "MxNxK triplet (e.g. 64x64x64)") + .converter(examples::convert_test_case) + .validator(examples::validate_test_case); + + parser.parse(argc, argv); + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } - if (a.help || a.tc.empty()) { - arg_parser::show_help(std::cout); - return 0; + if (help || a.tc.empty()) { + parser.print_help(std::cout, "tinytc-bench", ""); + return !help ? 
-1 : 0; } auto q = queue{}; @@ -260,11 +323,7 @@ int main(int argc, char **argv) { "repetitions" << std::endl; try { - if (a.double_precision) { - test(std::move(q), a); - } else { - test(std::move(q), a); - } + dispatch(a.ty, [&]() { test(q, a); }); } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; diff --git a/examples/builder/main.c b/examples/builder/main.c index 49ad78d4..9dcbdf36 100644 --- a/examples/builder/main.c +++ b/examples/builder/main.c @@ -1,53 +1,71 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include "tinytc/builder.h" #include "tinytc/tinytc.h" #include int main(void) { - tinytc_scalar_type_t type = tinytc_scalar_type_f32; int64_t M = 64; int64_t N = 32; - tinytc_data_type_t dt; + char const *copy_fun_name = "copy"; + size_t num_results; + size_t num_params; + tinytc_compiler_context_t ctx; + tinytc_prog_t program; + tinytc_type_t void_ty, element_ty, ty; + tinytc_func_t copy_fun; + tinytc_region_t copy_body; + tinytc_inst_t tmp; + tinytc_value_t params[2]; + tinytc_value_t alpha, beta; + + tinytc_compiler_context_create(&ctx); + + // Create program + tinytc_prog_create(&program, ctx, NULL); + + // Get types + tinytc_f32_type_get(&element_ty, ctx); int64_t shape[2] = {M, N}; - tinytc_memref_type_create(&dt, type, 2, shape, 0, NULL, NULL); - - tinytc_value_t A, B, alpha, beta; - tinytc_value_create(&A, dt, NULL); - tinytc_value_create(&B, dt, NULL); - tinytc_float_imm_create(&alpha, 1.0, type, NULL); - tinytc_float_imm_create(&beta, 0.0, type, NULL); - tinytc_data_type_release(dt); - - tinytc_inst_t copy_inst; - tinytc_axpby_inst_create(&copy_inst, tinytc_transpose_N, 0, alpha, A, beta, B, NULL); - tinytc_value_release(alpha); - tinytc_value_release(beta); - - tinytc_func_t copy_proto; - tinytc_value_t args[2] = {A, B}; - tinytc_function_prototype_create(&copy_proto, "copy", 2, args, NULL); - tinytc_value_release(A); - tinytc_value_release(B); + tinytc_memref_type_get(&ty, 
element_ty, 2, shape, 0, NULL, tinytc_address_space_global); - tinytc_region_t copy_body; - tinytc_region_create(&copy_body, 1, &copy_inst, NULL); - tinytc_inst_release(copy_inst); + // Get void type + tinytc_void_type_get(&void_ty, ctx); - tinytc_func_t copy_fun; - tinytc_function_create(&copy_fun, copy_proto, copy_body, NULL); - tinytc_func_release(copy_proto); - tinytc_region_release(copy_body); + // Create function + tinytc_type_t param_types[2] = {ty, ty}; + tinytc_func_create(&copy_fun, sizeof(copy_fun_name) - 1, copy_fun_name, 2, param_types, void_ty, + NULL); + tinytc_prog_add_function(program, copy_fun); - tinytc_prog_t program; - tinytc_program_create(&program, 1, &copy_fun, NULL); - tinytc_func_release(copy_fun); + // Get body + tinytc_func_get_body(copy_fun, &copy_body); + num_params = 2; + tinytc_region_get_parameters(copy_body, &num_params, params); + + // Create instructions + tinytc_constant_inst_create_one(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &alpha); + tinytc_region_append(copy_body, tmp); + + tinytc_constant_inst_create_zero(&tmp, element_ty, NULL); + num_results = 1; + tinytc_inst_get_values(tmp, &num_results, &beta); + tinytc_region_append(copy_body, tmp); + + tinytc_axpby_inst_create(&tmp, 0, tinytc_transpose_N, alpha, params[0], beta, params[1], NULL); + tinytc_region_append(copy_body, tmp); + // Dump program tinytc_prog_dump(program); + // Clean-up tinytc_prog_release(program); + tinytc_compiler_context_release(ctx); return 0; } diff --git a/examples/builder/main.cpp b/examples/builder/main.cpp index f506aeeb..47370bfc 100644 --- a/examples/builder/main.cpp +++ b/examples/builder/main.cpp @@ -1,34 +1,44 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "tinytc/tinytc.hpp" +#include "tinytc/builder.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include #include +#include using namespace tinytc; int main() { - scalar_type 
type = scalar_type::f32; int64_t M = 64; int64_t N = 32; try { - auto pb = program_builder{}; - pb.create("copy", [&](function_builder &fb) { - auto dt = make_memref(type, {M, N}); - auto A = fb.argument(dt); - auto B = fb.argument(dt); - fb.body([&](region_builder &bb) { - auto alpha = make_imm(1.0, type); - auto beta = make_imm(0.0, type); - bb.add(make_axpby(transpose::N, false, alpha, A, beta, B)); - }); - }); - auto program = pb.get_product(); - - program.dump(); + auto ctx = create_compiler_context(); + auto element_ty = get(ctx.get()); + auto ty = get(element_ty, array_view{M, N}, array_view{}, + address_space::global); + + auto void_ty = get(ctx.get()); + auto f = create_func("copy", {ty, ty}, void_ty); + + auto body = get_body(f.get()); + std::array params; + get_parameters(body, params); + + auto bb = region_builder{body}; + auto alpha = bb.constant_one(element_ty); + auto beta = bb.constant_zero(element_ty); + bb.create(false, transpose::N, alpha, params[0], beta, params[1]); + + auto p = create_prog(ctx.get()); + add_function(p.get(), std::move(f)); + + dump(p.get()); } catch (builder_error const &e) { std::cerr << "Error " << static_cast(e.code()) << std::endl; } catch (status const &st) { diff --git a/examples/gemm_common.hpp b/examples/gemm_common.hpp new file mode 100644 index 00000000..0762185d --- /dev/null +++ b/examples/gemm_common.hpp @@ -0,0 +1,183 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GEMM_COMMON_20241014_HPP +#define GEMM_COMMON_20241014_HPP + +#include "argparser.hpp" +#include "tinytc/tinytc.hpp" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::examples { + +struct test_case { + std::int64_t m; + std::int64_t n; + std::int64_t k; +}; + +enum class test_type { bf16, f16, f32, f64, c32, c64 }; +auto to_string(test_type ty) { + switch (ty) { + case test_type::bf16: + return "bf16"; + case test_type::f16: + return 
"f16"; + case test_type::f32: + return "f32"; + case test_type::f64: + return "f64"; + case test_type::c32: + return "c32"; + case test_type::c64: + return "c64"; + } + return "unknown"; +} + +inline auto convert_data_type(char const *str, test_type &val) -> cmd::parser_status { + if (std::strcmp(str, "bf16") == 0) { + val = test_type::bf16; + } else if (std::strcmp(str, "f16") == 0) { + val = test_type::f16; + } else if (std::strcmp(str, "f32") == 0) { + val = test_type::f32; + } else if (std::strcmp(str, "f64") == 0) { + val = test_type::f64; + } else if (std::strcmp(str, "c32") == 0) { + val = test_type::c32; + } else if (std::strcmp(str, "c64") == 0) { + val = test_type::c64; + } else { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; +} +template auto dispatch(test_type ty, F &&f) { + switch (ty) { + case test_type::bf16: + f.template operator()(); + break; + case test_type::f16: + f.template operator()(); + break; + case test_type::f32: + f.template operator()(); + break; + case test_type::f64: + f.template operator()(); + break; + case test_type::c32: + f.template operator()>(); + break; + case test_type::c64: + f.template operator()>(); + break; + default: + throw std::runtime_error("Unknown test type"); + } +} + +inline auto convert_test_case(char const *str, test_case &tc) -> cmd::parser_status { + auto const parse = [](std::int64_t *v, char const *str, char **end, char sep) { + *v = strtol(str, end, 10); + if (*v == 0 || **end != sep) { + throw cmd::parser_status::invalid_argument; + } + if (errno == ERANGE) { + throw cmd::parser_status::argument_out_of_range; + } + }; + char *end = nullptr; + try { + parse(&tc.m, str, &end, 'x'); + parse(&tc.n, end + 1, &end, 'x'); + parse(&tc.k, end + 1, &end, 0); + } catch (cmd::parser_status st) { + return st; + } + return cmd::parser_status::success; +} +inline auto validate_test_case(test_case const &tc) -> bool { + return tc.m > 0 && tc.n > 0 && tc.k > 0; +} + +template auto 
fabs(T x) { + if constexpr (std::is_same_v) { + return sycl::fabs(x); + } else { + return std::abs(x); + } +} + +template auto compute_error(T x, T x_ref) { + auto err = examples::fabs(x - x_ref); + const auto scale = examples::fabs(x_ref); + return scale > std::numeric_limits::epsilon() ? err / scale : err; +} + +// Increment values in bf16 epsilons +constexpr double test_gemm_smallest_eps = 0.0078125; + +template +void test_gemm_matrix(T *data, std::size_t M, std::size_t N, bool transposed = false) { + for (std::size_t j = 0; j < N; ++j) { + for (std::size_t i = 0; i < M; ++i) { + const auto idx = transposed ? j + i * N : i + j * M; + if constexpr (Use == matrix_use::a) { + data[idx] = (1.0 + i * test_gemm_smallest_eps) * (j + 1) / N; + } else if constexpr (Use == matrix_use::b) { + data[idx] = 1.0 / ((i + 1) * (1 + j * test_gemm_smallest_eps)); + } else { + data[idx] = 0; + } + } + } +} + +template struct is_complex : public std::false_type {}; +template struct is_complex> : public std::true_type {}; +template inline constexpr bool is_complex_v = is_complex::value; + +template struct is_lp_float : public std::false_type {}; +template +struct is_lp_float> : public std::true_type {}; +template inline constexpr bool is_lp_float_v = is_lp_float::value; + +template +auto test_gemm_rel_error(T *data, std::size_t i, std::size_t j, std::size_t M) -> double { + const double ref = (1 + i * test_gemm_smallest_eps) / (1 + j * test_gemm_smallest_eps); + if constexpr (is_complex_v) { + return std::abs(static_cast>(data[i + j * M]) - + std::complex{ref}) / + ref; + } else { + return std::abs(static_cast(data[i + j * M]) - ref) / ref; + } +} + +template auto test_gemm_error_bound(std::size_t K) { + const auto gamma = [](std::size_t K, double u) { return K * u / (1.0 - K * u); }; + if constexpr (is_lp_float_v) { + const double u = std::pow(2.0, -static_cast(T::lp_format::mantissa_bits)); + const double u_f32 = std::numeric_limits::epsilon(); + // Accumulation is done in single 
precision + return 2.0 * u + u * u + gamma(K, u_f32) * (1 + u) * (1 + u); + } else if constexpr (is_complex_v) { + return gamma(K, std::numeric_limits::epsilon()); + } else { + return gamma(K, std::numeric_limits::epsilon()); + } +} + +} // namespace tinytc::examples + +#endif // GEMM_COMMON_20241014_HPP diff --git a/examples/jit/main.cpp b/examples/jit/main.cpp index 14aee5a3..0fa102ed 100644 --- a/examples/jit/main.cpp +++ b/examples/jit/main.cpp @@ -1,12 +1,11 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "tinytc/tinytc.hpp" +#include "tinytc/core.hpp" #include "tinytc/types.hpp" #include #include -#include using namespace tinytc; @@ -15,18 +14,15 @@ int main(int argc, char **argv) { return -1; } - auto ctx = source_context{}; try { - ctx = make_source_context(); - auto info = make_core_info_intel_from_arch(intel_gpu_architecture::pvc); - auto prog = parse_file(argv[1], ctx); + auto info = create_core_info_intel_from_arch(intel_gpu_architecture::pvc); + auto prog = parse_file(argv[1]); if (!prog) { return -1; } - compile_to_opencl(std::move(prog), info, ctx); + compile_to_spirv_and_assemble(prog.get(), info.get()); } catch (status const &st) { - std::cerr << "Error (" << static_cast(st) << "): " << error_string(st) << std::endl; - std::cerr << "Error log: " << std::endl << ctx.get_error_log() << std::endl; + std::cerr << "Error (" << static_cast(st) << "): " << to_string(st) << std::endl; return 1; } catch (std::exception const &e) { std::cerr << e.what() << std::endl; diff --git a/examples/matrix_chain/CMakeLists.txt b/examples/matrix_chain/CMakeLists.txt index 7fa1358c..9f5e9eb0 100644 --- a/examples/matrix_chain/CMakeLists.txt +++ b/examples/matrix_chain/CMakeLists.txt @@ -14,5 +14,5 @@ set(SOURCES add_executable(matrix_chain ${SOURCES}) add_sycl_to_target(TARGET matrix_chain SOURCES ${SOURCES}) -target_link_libraries(matrix_chain PRIVATE tinytc tinytc_sycl) +target_link_libraries(matrix_chain PRIVATE tinytc 
tinytc_sycl argparser) set_cxx_common_options(matrix_chain) diff --git a/examples/matrix_chain/main.cpp b/examples/matrix_chain/main.cpp index e8062846..1355b4e1 100644 --- a/examples/matrix_chain/main.cpp +++ b/examples/matrix_chain/main.cpp @@ -4,6 +4,7 @@ #include "test.hpp" #include "test_multi.hpp" +#include #include #include @@ -14,8 +15,58 @@ #include using namespace sycl; +using namespace tinytc; int main(int argc, char **argv) { + bool dump = false; + std::int64_t N = 5, P = 9, howmany; + std::size_t alignment = 0; + char precision = 's'; + test_case tc = test_case::volume; + bool help = false; + + auto parser = cmd::arg_parser{}; + try { + parser.set_short_opt('a', &alignment, "Alignment (in number of bytes)"); + parser.set_short_opt('d', &dump, "Dump IR to stdout"); + parser.set_short_opt('f', &precision, "Data type (s or d)").validator([](char f) { + return f == 's' || f == 'd'; + }); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('N', &N, "Polynomial degree").validator([](std::int64_t p) { + return p > 0; + }); + parser.set_short_opt('P', &P, "Number of quantities").validator([](std::int64_t n) { + return n > 0; + }); + parser.set_long_opt("help", &help, "Show help"); + parser.add_positional_arg("test_case", &tc, "Test case (volume or ader)", true) + .converter([](char const *str, test_case &val) -> cmd::parser_status { + if (strcmp(str, "volume") == 0) { + val = test_case::volume; + } else if (strcmp(str, "ader") == 0) { + val = test_case::ader; + } else { + return cmd::parser_status::invalid_argument; + } + return cmd::parser_status::success; + }); + parser.add_positional_arg("howmany", &howmany, "Batch size", true) + .validator([](std::int64_t h) { return h > 0; }); + + parser.parse(argc, argv); + } catch (std::exception const &e) { + if (!help) { + std::cerr << e.what() << std::endl; + } + parser.print_help(std::cout, "matrix_chain", ""); + return help ? 
0 : -1; + } + if (help) { + parser.print_help(std::cout, "matrix_chain", ""); + return 0; + } + auto devices = platform{}.get_devices(); auto sub_devices = std::vector{}; for (auto &device : devices) { @@ -33,41 +84,9 @@ int main(int argc, char **argv) { q.emplace_back(queue(device)); } - std::int64_t N = 5, P = 9, howmany; - std::size_t alignment = 0; - char precision = 's'; - test_case tc = test_case::volume; - - if (argc < 5) { - std::cerr << "Usage: matrix_chain

[alignment] [s/d]" - << std::endl; - return -1; - } - if (strcmp(argv[1], "volume") == 0) { - tc = test_case::volume; - } else if (strcmp(argv[1], "ader") == 0) { - tc = test_case::ader; - } else { - std::cerr << "Unknown test case " << argv[1] << ". Available are: ader, volume." - << std::endl; - return -1; - } - N = static_cast(std::atol(argv[2])); - P = static_cast(std::atol(argv[3])); - howmany = static_cast(std::atol(argv[4])); - if (argc >= 6) { - alignment = static_cast(std::atol(argv[5])); - } - if (argc >= 7) { - precision = argv[6][0]; - if (precision != 's' && precision != 'd') { - std::cerr << "Precision must be single (s) or double (d)" << std::endl; - return -1; - } - } auto run_test_multi = [&](auto precision) { using T = decltype(precision); - auto t = test_multi(N, P, howmany, alignment, tc, q); + auto t = test_multi(N, P, howmany, alignment, tc, q, dump); if (!t.check()) { std::cerr << "Result mismatch between reference and optimized!" << std::endl; // return; diff --git a/examples/matrix_chain/matrix_batch.hpp b/examples/matrix_chain/matrix_batch.hpp index f882b5b7..68e5b44f 100644 --- a/examples/matrix_chain/matrix_batch.hpp +++ b/examples/matrix_chain/matrix_batch.hpp @@ -7,8 +7,10 @@ #include "device_array.hpp" #include +#include #include +#include #include #include @@ -16,30 +18,41 @@ template class matrix_batch { public: matrix_batch(std::int64_t nrows, std::int64_t ncols, std::int64_t ld, std::int64_t howmany, sycl::queue q) - : nrows_(nrows), ncols_(ncols), ld_(ld), howmany_(howmany), - data_(ld_ * ncols_ * howmany, std::move(q)) {} + : shape_{nrows, ncols}, ld_{ld}, howmany_{howmany}, + data_(stride() * howmany_, std::move(q)) {} inline T *get() { return data_.get(); } inline T const *get() const { return data_.get(); } - inline std::int64_t nrows() const { return nrows_; } - inline std::int64_t ncols() const { return ncols_; } - inline std::int64_t ld() const { return ld_; } + inline auto shape() const -> tinytc::array_view { return 
shape_; } + inline std::int64_t nrows() const { return shape_[0]; } + inline std::int64_t ncols() const { return shape_[1]; } inline std::int64_t howmany() const { return howmany_; } - inline std::int64_t stride() const { return ld_ * ncols_; } + inline std::int64_t ld() const { return ld_; } + inline std::int64_t stride() const { return ld_ * ncols(); } inline std::size_t size() const { return data_.size(); } inline void fill(T const &v) { data_.fill(v); } inline void random() { data_.random(); } - inline tinytc::data_type type(bool include_batch_dim = true) { - constexpr auto real_t = tinytc::to_scalar_type_v; - if (include_batch_dim && howmany() > 1) { - return tinytc::make_memref(real_t, {nrows(), ncols(), tinytc::dynamic}, - {1, ld(), stride()}); + inline auto type(tinytc_type_t element_ty) -> tinytc_type_t { + auto shape = std::array{nrows(), ncols(), tinytc::dynamic}; + auto strid = std::array{std::int64_t{1}, ld(), stride()}; + if (howmany_ == 1) { + return tinytc::get(element_ty, tinytc::array_view(shape.data(), 2), + tinytc::array_view(strid.data(), 2), + tinytc::address_space::global); } - return tinytc::make_memref(real_t, {nrows(), ncols()}, {1, ld()}); + return tinytc::get(element_ty, shape, strid, + tinytc::address_space::global); + } + inline auto local_type(tinytc_type_t element_ty) -> tinytc_type_t { + auto shape = std::array{nrows(), ncols()}; + auto strid = std::array{std::int64_t{1}, ld()}; + return tinytc::get(element_ty, shape, strid, + tinytc::address_space::local); } private: - std::int64_t nrows_, ncols_, ld_, howmany_; + std::array shape_; + std::int64_t ld_, howmany_; device_array data_; }; diff --git a/examples/matrix_chain/test_ader.cpp b/examples/matrix_chain/test_ader.cpp index 78192ff9..bba0c06d 100644 --- a/examples/matrix_chain/test_ader.cpp +++ b/examples/matrix_chain/test_ader.cpp @@ -2,11 +2,12 @@ // SPDX-License-Identifier: BSD-3-Clause #include "test_ader.hpp" +#include #include #include -#include #include +#include #include 
using namespace sycl; @@ -14,14 +15,15 @@ using namespace tinytc; template test_ader::test_ader(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - queue q) + queue q, bool dump) : N_(N), P_(P), howmany_(howmany), alignment_(alignment), q_(std::move(q)), - dev_info_(make_core_info(q_.get_device())), I_ref_(Bd(), P_, Bd_aligned(), howmany_, q_), + dev_info_(create_core_info(q_.get_device())), I_ref_(Bd(), P_, Bd_aligned(), howmany_, q_), I_opt_(Bd(), P_, Bd_aligned(), howmany_, q_), tmp_(Bd(), P_, Bd_aligned(N_ - 1), howmany_, q_), A_(dim, matrix_batch(P_, P_, P_, howmany_, q_)), K_(dim, matrix_batch(Bd(), Bd(), Bd_aligned(N_ - 1), 1, q_)), dQ_(make_dQ()), - opt_bundle_(make_optimized_kernel()), opt_kernel_(make_kernel(opt_bundle_, "ader_kernel")) { + opt_bundle_(make_optimized_kernel(dump)), + opt_kernel_(create_kernel(opt_bundle_, "ader_kernel")) { I_ref_.random(); I_opt_.random(); for (auto &a : A_) { @@ -36,16 +38,19 @@ test_ader::test_ader(std::int64_t N, std::int64_t P, std::int64_t howmany, st d->fill(0); } + auto ctx = create_compiler_context(); for (std::int64_t n = 1; n <= N_; ++n) { auto bn = Bd_aligned(N_ - n); - g_.emplace_back(make_recipe_handler( - q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, - bn, P_, Bd(N_ - n + 1), K_[0].ld(), 0, dQ_[n - 1].ld(), - dQ_[n - 1].stride(), bn, bn * P_))); - g_.emplace_back(make_recipe_handler( - q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, - bn, P_, P_, bn, bn * P_, A_[0].ld(), A_[0].stride(), - dQ_[n].ld(), dQ_[n].stride()))); + g_.emplace_back(create_recipe_handler( + q_, create_small_gemm_batched(dev_info_.get(), to_type(ctx.get()), transpose::N, + transpose::N, bn, P_, Bd(N_ - n + 1), K_[0].ld(), 0, + dQ_[n - 1].ld(), dQ_[n - 1].stride(), bn, bn * P_) + .get())); + g_.emplace_back(create_recipe_handler( + q_, create_small_gemm_batched(dev_info_.get(), to_type(ctx.get()), transpose::N, + transpose::N, bn, P_, 
P_, bn, bn * P_, A_[0].ld(), + A_[0].stride(), dQ_[n].ld(), dQ_[n].stride()) + .get())); } } @@ -58,63 +63,119 @@ template std::vector> test_ader::make_dQ() { } template -auto test_ader::make_optimized_kernel() -> sycl::kernel_bundle { - constexpr auto real_t = to_scalar_type_v; - auto opt_kernel = [&](function_builder &fb) { - T dt = 1.01; - T num = T(1.0); - int denom = 1; - std::array A; - std::array K; +auto test_ader::make_optimized_kernel(bool dump_code) + -> sycl::kernel_bundle { + auto opt_kernel = [&](tinytc_compiler_context_t ctx) { + auto element_ty = to_type(ctx); + std::array param_types; + param_types[0] = element_ty; for (std::size_t i = 0; i < dim; ++i) { - A[i] = fb.argument(A_[i].type(), "A"); + param_types[1 + i] = A_[i].type(element_ty); } for (std::size_t i = 0; i < dim; ++i) { - K[i] = fb.argument(K_[i].type(), "K"); + param_types[1 + dim + i] = K_[i].type(element_ty); + } + param_types[1 + 2 * dim + 0] = dQ_[0].type(element_ty); + param_types[1 + 2 * dim + 1] = I_opt_.type(element_ty); + + auto void_ty = get(ctx); + auto f = create_func("ader_kernel", param_types, void_ty); + auto fn_body = get_body(f.get()); + + std::array params; + get_parameters(fn_body, params); + + auto dt = params[0]; + set_name(dt, "dt"); + auto A = [¶ms](std::size_t i) -> tinytc_value_t & { return params[1 + i]; }; + auto K = [¶ms](std::size_t i) -> tinytc_value_t & { return params[1 + dim + i]; }; + auto Q = params[1 + 2 * dim + 0]; + auto I = params[1 + 2 * dim + 1]; + for (std::size_t i = 0; i < dim; ++i) { + set_name(A(i), (std::ostringstream{} << 'A' << i).str()); + set_name(K(i), (std::ostringstream{} << 'K' << i).str()); + } + set_name(Q, "Q"); + set_name(I, "I"); + + auto bb = region_builder{fn_body}; + auto const c0 = bb.constant_zero(element_ty); + auto const c1 = bb.constant_one(element_ty); + auto const gid = bb.create(comp3::x, get(ctx)); + auto const static_offsets3 = std::array{0, 0, dynamic}; + auto const static_sizes3 = [](matrix_batch const &b) -> 
std::array { + return {b.nrows(), b.ncols(), 0}; + }; + auto const static_sizes2 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols()}; + }; + auto const offsets3 = array_view(gid); + const auto dynamic_stride = std::array{std::int64_t{1}, dynamic}; + auto dqt = get(element_ty, static_sizes2(dQ_[0]), dynamic_stride, + address_space::global); + auto dq = bb.create(static_offsets3, static_sizes3(dQ_[0]), Q, offsets3, + array_view{}, dqt); + for (std::size_t d = 0; d < dim; ++d) { + auto At = get(element_ty, static_sizes2(A_[d]), array_view{}, + address_space::global); + A(d) = bb.create(static_offsets3, static_sizes3(A_[d]), A(d), offsets3, + array_view{}, At); } - auto Q = fb.argument(dQ_[0].type(), "dQ"); - auto I = fb.argument(I_opt_.type(), "I"); - fb.body([&](region_builder &bb) { - auto const gid = bb.add(make_group_id()); - auto const offsets3 = std::vector{make_index(0), make_index(0), gid}; - auto const size3 = std::vector{make_dynamic(), make_dynamic(), value{}}; - auto dq = bb.add(make_subview(Q, offsets3, size3)); + auto it = get(element_ty, static_sizes2(I_opt_), dynamic_stride, + address_space::global); + auto i = bb.create(static_offsets3, static_sizes3(I_opt_), I, offsets3, + array_view{}, it); + bb.create(false, transpose::N, c1, dq, c1, i); + + int denom = 1; + auto cnum = c1; + auto const static_offsets2 = std::array{0, 0}; + for (std::int64_t n = 1; n <= N_; ++n) { + cnum = bb.create(cnum, dt, get_type(dt)); + denom *= n + 1; + auto cdenom = bb.create(static_cast(denom), element_ty); + auto cfactor = bb.create(cnum, cdenom, get_type(cnum)); + auto bn = Bd_aligned(N_ - n); + auto dq_next = bb.create(dQ_[n].local_type(element_ty)); + auto dq_nextvt = get(element_ty, std::array{bn, P_}, dynamic_stride, + address_space::local); + auto dq_nextv = bb.create(static_offsets2, array_view{bn, P_}, dq_next, + array_view{}, + array_view{}, dq_nextvt); + auto tmp = bb.create(get( + element_ty, std::array{bn, P_}, dynamic_stride, 
address_space::local)); for (std::size_t d = 0; d < dim; ++d) { - A[d] = bb.add(make_subview(A[d], offsets3, size3)); + auto Kvt = get(element_ty, std::array{bn, Bd(N_ - n + 1)}, + dynamic_stride, address_space::global); + auto Kv = bb.create(static_offsets2, array_view{bn, Bd(N_ - n + 1)}, + K(d), array_view{}, + array_view{}, Kvt); + bb.create(false, transpose::N, transpose::N, c1, Kv, dq, c0, tmp); + bb.create(false, transpose::N, transpose::N, c1, tmp, A(d), + d > 0 ? c1 : c0, dq_nextv); } - auto i = bb.add(make_subview(I, offsets3, size3)); - bb.add(make_axpby(transpose::N, false, make_imm(num / denom), dq, make_imm(T(1.0)), i)); - - auto const offsets2 = std::vector{make_index(0), make_index(0)}; - for (std::int64_t n = 1; n <= N_; ++n) { - num *= dt; - denom *= n + 1; - auto bn = Bd_aligned(N_ - n); - auto dq_next = bb.add(make_alloca(dQ_[n].type(false))); - auto dq_nextv = - bb.add(make_subview(dq_next, offsets2, {make_index(bn), make_index(P_)})); - auto tmp = bb.add(make_alloca(make_memref(real_t, {bn, P_}, {1, bn}))); - for (std::size_t d = 0; d < dim; ++d) { - auto Kv = bb.add( - make_subview(K[d], offsets2, {make_index(bn), make_index(Bd(N_ - n + 1))})); - bb.add(make_gemm(transpose::N, transpose::N, false, make_imm(T(1.0)), Kv, dq, - make_imm(T(0.0)), tmp)); - bb.add(make_gemm(transpose::N, transpose::N, false, make_imm(T(1.0)), tmp, A[d], - make_imm(T(d > 0 ? 
1.0 : 0.0)), dq_nextv)); - } - auto iv = - bb.add(make_subview(i, offsets2, {make_index(Bd(N_ - n)), make_index(P_)})); - bb.add(make_axpby(transpose::N, false, make_imm(num / denom), dq_next, - make_imm(T(1.0)), iv)); - dq = dq_next; - } - }); - }; - auto pb = program_builder{}; - pb.create("ader_kernel", opt_kernel); + auto ivt = get(element_ty, std::array{Bd(N_ - n), P_}, dynamic_stride, + address_space::global); + auto iv = bb.create(static_offsets2, array_view{Bd(N_ - n), P_}, i, + array_view{}, + array_view{}, ivt); + bb.create(false, transpose::N, cfactor, dq_next, c1, iv); + dq = dq_next; + } - return make_kernel_bundle(q_.get_context(), q_.get_device(), - compile_to_opencl(pb.get_product(), dev_info_)); + return f; + }; + auto ctx = create_compiler_context(); + set_error_reporter(ctx.get(), [](char const *what, const tinytc_location_t *, void *) { + std::cerr << what << std::endl; + }); + auto p = create_prog(ctx.get()); + add_function(p.get(), opt_kernel(ctx.get())); + if (dump_code) { + dump(p.get()); + } + auto bin = compile_to_spirv_and_assemble(p.get(), dev_info_.get()); + return create_kernel_bundle(q_.get_context(), q_.get_device(), bin.get()); } template @@ -144,12 +205,13 @@ template std::vector test_ader::reference() { num *= dt; denom *= n + 1; for (std::size_t d = 0; d < dim; ++d) { - small_gemm_batched::set_args(g_[2 * (n - 1)], howmany_, T(1.0), K_[d].get(), - dQ_[n - 1].get(), T(0.0), tmp_.get()); - e[0] = g_[2 * (n - 1)].submit(q_, e); - small_gemm_batched::set_args(g_[2 * n - 1], howmany_, T(1.0), tmp_.get(), A_[d].get(), - T(1.0), dQ_[n].get()); - e[0] = g_[2 * n - 1].submit(q_, e); + auto handler = g_[2 * (n - 1)].get(); + set_small_gemm_batched_args(handler, howmany_, T(1.0), K_[d].get(), dQ_[n - 1].get(), + T(0.0), tmp_.get()); + e[0] = submit(handler, q_, e); + set_small_gemm_batched_args(handler, howmany_, T(1.0), tmp_.get(), A_[d].get(), T(1.0), + dQ_[n].get()); + e[0] = submit(handler, q_, e); } e[0] = taylor_sum(I_ref_, dQ_[n], 
num / denom, e); } @@ -157,10 +219,13 @@ template std::vector test_ader::reference() { } template std::vector test_ader::optimized() { - auto exe_range = get_execution_range(opt_kernel_, howmany_); + T dt = 1.01; + auto exe_range = get_execution_range( + opt_kernel_, sycl::range<3u>{1u, 1u, static_cast(howmany_)}); return {q_.submit([&](handler &h) { - h.set_args(A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, K_[0].get(), - K_[1].get(), K_[2].get(), dQ_[0].get(), howmany_, I_opt_.get(), howmany_); + h.set_args(dt, A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, + K_[0].get(), K_[1].get(), K_[2].get(), dQ_[0].get(), howmany_, I_opt_.get(), + howmany_); h.parallel_for(exe_range, opt_kernel_); })}; } diff --git a/examples/matrix_chain/test_ader.hpp b/examples/matrix_chain/test_ader.hpp index 0195c975..a7a34442 100644 --- a/examples/matrix_chain/test_ader.hpp +++ b/examples/matrix_chain/test_ader.hpp @@ -20,7 +20,7 @@ template class test_ader : public test { public: test_ader(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - sycl::queue q); + sycl::queue q, bool dump = false); ~test_ader() = default; test_ader(test_ader const &other) = delete; test_ader(test_ader &&other) = default; @@ -41,16 +41,16 @@ template class test_ader : public test { inline std::int64_t Bd_aligned() { return aligned(Bd(N_), alignment_); } inline std::int64_t Bd_aligned(std::int64_t N) { return aligned(Bd(N), alignment_); } std::vector> make_dQ(); - auto make_optimized_kernel() -> sycl::kernel_bundle; + auto make_optimized_kernel(bool dump) -> sycl::kernel_bundle; sycl::event taylor_sum(matrix_batch &I, matrix_batch &dQ, T factor, std::vector const &dep_events = {}); std::int64_t N_, P_, howmany_, alignment_; sycl::queue q_; - tinytc::core_info dev_info_; + tinytc::shared_handle dev_info_; matrix_batch I_ref_, I_opt_, tmp_; std::vector> A_, K_, dQ_; - std::vector g_; + std::vector> g_; sycl::kernel_bundle opt_bundle_; 
sycl::kernel opt_kernel_; }; diff --git a/examples/matrix_chain/test_multi.cpp b/examples/matrix_chain/test_multi.cpp index c0d30308..34fbc50f 100644 --- a/examples/matrix_chain/test_multi.cpp +++ b/examples/matrix_chain/test_multi.cpp @@ -25,14 +25,17 @@ template double bench(F f, int nrepeat = 10) { template test_multi::test_multi(std::int64_t N, std::int64_t P, std::int64_t howmany, - std::size_t alignment, test_case tc, std::vector const &q) { + std::size_t alignment, test_case tc, std::vector const &q, + bool dump) { for (auto &qu : q) { switch (tc) { case test_case::ader: - instances_.emplace_back(std::make_unique>(N, P, howmany, alignment, qu)); + instances_.emplace_back( + std::make_unique>(N, P, howmany, alignment, qu, dump)); break; case test_case::volume: - instances_.emplace_back(std::make_unique>(N, P, howmany, alignment, qu)); + instances_.emplace_back( + std::make_unique>(N, P, howmany, alignment, qu, dump)); break; default: break; diff --git a/examples/matrix_chain/test_multi.hpp b/examples/matrix_chain/test_multi.hpp index 1bd4beac..b7b3027f 100644 --- a/examples/matrix_chain/test_multi.hpp +++ b/examples/matrix_chain/test_multi.hpp @@ -17,7 +17,7 @@ enum class test_case { volume, ader }; template class test_multi { public: test_multi(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - test_case tc, std::vector<::sycl::queue> const &q); + test_case tc, std::vector<::sycl::queue> const &q, bool dump = false); void reference(); void optimized(); diff --git a/examples/matrix_chain/test_volume.cpp b/examples/matrix_chain/test_volume.cpp index 479e5e03..b420f11a 100644 --- a/examples/matrix_chain/test_volume.cpp +++ b/examples/matrix_chain/test_volume.cpp @@ -6,8 +6,8 @@ #include #include -#include #include +#include #include using namespace sycl; @@ -15,15 +15,16 @@ using namespace tinytc; template test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany, - std::size_t alignment, queue q) + std::size_t 
alignment, queue q, bool dump) : B3_(num_basis(N, dim)), B2_(num_basis(N - 1, dim)), P_(P), howmany_(howmany), B3_aligned_(aligned(B3_, alignment)), B2_aligned_(aligned(B2_, alignment)), - q_(std::move(q)), dev_info_(make_core_info(q_.get_device())), + q_(std::move(q)), dev_info_(create_core_info(q_.get_device())), Q_ref_(B3_, P_, B3_aligned_, howmany_, q_), Q_opt_(B3_, P_, B3_aligned_, howmany_, q_), I_(B3_, P_, B3_aligned_, howmany_, q_), tmp_(B3_, P_, B2_aligned_, howmany_, q_), A_(dim, matrix_batch(P_, P_, P_, howmany_, q_)), - K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), opt_bundle_(make_optimized_kernel()), - opt_kernel_(make_kernel(opt_bundle_, "volume_kernel")) { + K_(dim, matrix_batch(B3_, B3_, B3_aligned_, 1, q_)), ctx_(make_compiler_context()), + opt_bundle_(make_optimized_kernel(dump)), + opt_kernel_(create_kernel(opt_bundle_, "volume_kernel")) { Q_ref_.random(); Q_opt_.random(); I_.random(); @@ -35,114 +36,155 @@ test_volume::test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany k.random(); } - g_.emplace_back(make_recipe_handler( - q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, - B2_aligned_, P_, P_, B3_aligned_, B3_aligned_ * P_, P_, P_ * P_, - B2_aligned_, B2_aligned_ * P_))); - g_.emplace_back(make_recipe_handler( - q_, make_small_gemm_batched(dev_info_, to_scalar_type_v, transpose::N, transpose::N, - B3_aligned_, P_, B2_, B3_aligned_, 0, B2_aligned_, - B2_aligned_ * P_, B3_aligned_, B3_aligned_ * P_))); + g_.emplace_back(create_recipe_handler( + q_, create_small_gemm_batched(dev_info_.get(), to_type(ctx_.get()), transpose::N, + transpose::N, B2_aligned_, P_, P_, B3_aligned_, + B3_aligned_ * P_, P_, P_ * P_, B2_aligned_, B2_aligned_ * P_) + .get())); + g_.emplace_back(create_recipe_handler( + q_, create_small_gemm_batched(dev_info_.get(), to_type(ctx_.get()), transpose::N, + transpose::N, B3_aligned_, P_, B2_, B3_aligned_, 0, + B2_aligned_, B2_aligned_ * P_, B3_aligned_, B3_aligned_ * P_) + 
.get())); } template -auto test_volume::make_optimized_kernel() +auto test_volume::make_compiler_context() -> shared_handle { + auto ctx = ::tinytc::create_compiler_context(); + set_error_reporter(ctx.get(), [](char const *what, const tinytc_location_t *, void *) { + std::cerr << what << std::endl; + }); + return ctx; +} + +template +auto test_volume::make_optimized_kernel(bool dump_code) -> sycl::kernel_bundle { - constexpr auto real_t = to_scalar_type_v; - /** - * With B3_ = 56, B3_aligned_ = 64, B2_ = 35, B2_aligned_ = 48, P_ = 9 - * - * func chain(A0: batch, distance<81>>, - * A1: batch, distance<81>>, - * A2: batch, distance<81>>, - * K0: memref, - * K1: memref, - * K2: memref, - * Q: batch, distance<576>>, - * I: batch, distance<576>>) { - * a0 = get_work_item A0 - * a1 = get_work_item A1 - * a2 = get_work_item A2 - * q = get_work_item Q - * i = get_work_item i - * tmp = alloca matrix; - * K0v = submatrix K0[0:64,0:35] - * K1v = submatrix K1[0:64,0:35] - * K2v = submatrix K2[0:64,0:35] - * qv = submatrix Q[0:64,0:9] - * iv = submatrix I[0:48,0:9] - * tmpv = submatrix tmp[0:48,0:9] - * matmul(iv, a0, tmpv, 1.0, 0.0); - * matmul(K0v, tmp, qv, 1.0, 0.0); - * matmul(iv, a1, tmpv, 1.0, 0.0); - * matmul(K1v, tmp, qv, 1.0, 0.0); - * matmul(iv, a2, tmpv, 1.0, 0.0); - * matmul(K2v, tmp, qv, 1.0, 0.0); - * } - */ // Optimized kernel - auto opt_kernel = [&](function_builder &fb) { - auto A0 = fb.argument(A_[0].type(), "A0"); - auto A1 = fb.argument(A_[1].type(), "A1"); - auto A2 = fb.argument(A_[2].type(), "A2"); - auto K0 = fb.argument(K_[0].type(), "K0"); - auto K1 = fb.argument(K_[1].type(), "K1"); - auto K2 = fb.argument(K_[2].type(), "K2"); - auto Q = fb.argument(Q_opt_.type(), "Q"); - auto I = fb.argument(I_.type(), "I"); - fb.body([&](region_builder &bb) { - auto gid = bb.add(make_group_id()); - auto const offsets2 = std::vector{make_index(0), make_index(0)}; - auto const offsets3 = std::vector{make_index(0), make_index(0), gid}; - auto const size3 = 
std::vector{make_dynamic(), make_dynamic(), value{}}; - auto const sizeK2 = std::vector{make_index(B3_aligned_), make_index(B2_)}; - auto tmp = bb.add(make_alloca(make_memref(real_t, {B2_, P_}, {1, B2_aligned_}))); - auto a0 = bb.add(make_subview(A0, offsets3, size3)); - auto a1 = bb.add(make_subview(A1, offsets3, size3)); - auto a2 = bb.add(make_subview(A2, offsets3, size3)); - auto K0v = bb.add(make_subview(K0, offsets2, sizeK2)); - auto K1v = bb.add(make_subview(K1, offsets2, sizeK2)); - auto K2v = bb.add(make_subview(K2, offsets2, sizeK2)); - auto qv = bb.add( - make_subview(Q, offsets3, {make_index(B3_aligned_), make_index(P_), value{}})); - auto iv = bb.add( - make_subview(I, offsets3, {make_index(B2_aligned_), make_index(P_), value{}})); - auto tmpv = - bb.add(make_subview(tmp, offsets2, {make_index(B2_aligned_), make_index(P_)})); - auto const s0 = make_imm(T(0.0)); - auto const s1 = make_imm(T(1.0)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a0, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K0v, tmp, s1, qv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a1, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K1v, tmp, s1, qv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, iv, a2, s0, tmpv)); - bb.add(make_gemm(transpose::N, transpose::N, false, s1, K2v, tmp, s1, qv)); - }); - }; - auto pb = program_builder{}; - pb.create("volume_kernel", opt_kernel); + auto opt_kernel = [&](tinytc_compiler_context_t ctx) { + auto element_ty = to_type(ctx); + std::array param_types; + for (std::size_t i = 0; i < dim; ++i) { + param_types[i] = A_[i].type(element_ty); + } + for (std::size_t i = 0; i < dim; ++i) { + param_types[dim + i] = K_[i].type(element_ty); + } + param_types[2 * dim + 0] = Q_opt_.type(element_ty); + param_types[2 * dim + 1] = I_.type(element_ty); + + auto void_ty = get(ctx); + auto f = create_func("volume_kernel", param_types, void_ty); + auto fn_body = 
get_body(f.get()); + + std::array params; + get_parameters(fn_body, params); - return make_kernel_bundle(q_.get_context(), q_.get_device(), - compile_to_opencl(pb.get_product(), dev_info_)); + auto A = [¶ms](std::size_t i) -> tinytc_value_t & { return params[i]; }; + auto K = [¶ms](std::size_t i) -> tinytc_value_t & { return params[dim + i]; }; + auto Q = params[2 * dim + 0]; + auto I = params[2 * dim + 1]; + + for (std::size_t i = 0; i < dim; ++i) { + set_name(A(i), (std::ostringstream{} << 'A' << i).str()); + set_name(K(i), (std::ostringstream{} << 'K' << i).str()); + } + set_name(Q, "Q"); + set_name(I, "I"); + + auto bb = region_builder{fn_body}; + auto gid = bb.create(comp3::x, get(ctx)); + auto const static_offsets2 = std::array{0, 0}; + auto const static_offsets3 = std::array{0, 0, dynamic}; + auto const static_sizes2 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols()}; + }; + auto const static_sizes3 = [](matrix_batch const &b) -> std::array { + return {b.nrows(), b.ncols(), 0}; + }; + auto const default_stride = array_view{}; + auto const offsets3 = array_view(gid); + auto const sizeK2 = std::array{B3_aligned_, B2_}; + auto tmp = bb.create(get(element_ty, std::array{B2_aligned_, P_}, + default_stride, address_space::local)); + + auto a0t = get(element_ty, static_sizes2(A_[0]), default_stride, + address_space::global); + auto a1t = get(element_ty, static_sizes2(A_[1]), default_stride, + address_space::global); + auto a2t = get(element_ty, static_sizes2(A_[2]), default_stride, + address_space::global); + auto k0t = get(element_ty, sizeK2, default_stride, address_space::global); + auto k1t = get(element_ty, sizeK2, default_stride, address_space::global); + auto k2t = get(element_ty, sizeK2, default_stride, address_space::global); + auto qvt = get(element_ty, std::array{B3_aligned_, P_}, default_stride, + address_space::global); + auto ivt = get(element_ty, std::array{B2_aligned_, P_}, + std::array{std::int64_t{1}, dynamic}, 
address_space::global); + auto tmpvt = + get(element_ty, std::array{B2_, P_}, default_stride, address_space::local); + auto a0 = bb.create(static_offsets3, static_sizes3(A_[0]), A(0), offsets3, + array_view{}, a0t); + auto a1 = bb.create(static_offsets3, static_sizes3(A_[1]), A(1), offsets3, + array_view{}, a1t); + auto a2 = bb.create(static_offsets3, static_sizes3(A_[2]), A(2), offsets3, + array_view{}, a2t); + auto k0 = + bb.create(static_offsets2, sizeK2, K(0), array_view{}, + array_view{}, k0t); + auto k1 = + bb.create(static_offsets2, sizeK2, K(1), array_view{}, + array_view{}, k1t); + auto k2 = + bb.create(static_offsets2, sizeK2, K(2), array_view{}, + array_view{}, k2t); + auto qv = + bb.create(static_offsets3, array_view{B3_aligned_, P_, std::int64_t{0}}, + Q, offsets3, array_view{}, qvt); + auto iv = + bb.create(static_offsets3, array_view{B2_aligned_, P_, std::int64_t{0}}, + I, offsets3, array_view{}, ivt); + auto tmpv = bb.create(static_offsets2, array_view{B2_, P_}, tmp, + array_view{}, + array_view{}, tmpvt); + auto const c0 = bb.constant_zero(element_ty); + auto const c1 = bb.constant_one(element_ty); + bb.create(false, transpose::N, transpose::N, c1, iv, a0, c0, tmp); + bb.create(false, transpose::N, transpose::N, c1, k0, tmpv, c1, qv); + bb.create(false, transpose::N, transpose::N, c1, iv, a1, c0, tmp); + bb.create(false, transpose::N, transpose::N, c1, k1, tmpv, c1, qv); + bb.create(false, transpose::N, transpose::N, c1, iv, a2, c0, tmp); + bb.create(false, transpose::N, transpose::N, c1, k2, tmpv, c1, qv); + + return f; + }; + auto p = create_prog(ctx_.get()); + add_function(p.get(), opt_kernel(ctx_.get())); + if (dump_code) { + dump(p.get()); + } + auto bin = compile_to_spirv_and_assemble(p.get(), dev_info_.get()); + return create_kernel_bundle(q_.get_context(), q_.get_device(), bin.get()); } template std::vector test_volume::reference() { auto e = std::vector{}; for (std::size_t d = 0; d < dim; ++d) { - small_gemm_batched::set_args(g_[0], 
howmany_, T(1.0), I_.get(), A_[d].get(), T(0.0), - tmp_.get()); - e.emplace_back(g_[0].submit(q_, e)); + set_small_gemm_batched_args(g_[0].get(), howmany_, T(1.0), I_.get(), A_[d].get(), T(0.0), + tmp_.get()); + e.emplace_back(submit(g_[0].get(), q_, e)); e.front() = e.back(); e.pop_back(); - small_gemm_batched::set_args(g_[1], howmany_, T(1.0), K_[d].get(), tmp_.get(), T(1.0), - Q_ref_.get()); - e.emplace_back(g_[1].submit(q_, e)); + set_small_gemm_batched_args(g_[1].get(), howmany_, T(1.0), K_[d].get(), tmp_.get(), T(1.0), + Q_ref_.get()); + e.emplace_back(submit(g_[1].get(), q_, e)); } return e; } template std::vector test_volume::optimized() { - auto exe_range = get_execution_range(opt_kernel_, howmany_); + auto exe_range = get_execution_range( + opt_kernel_, sycl::range<3u>{1u, 1u, static_cast(howmany_)}); return {q_.submit([&](handler &h) { h.set_args(A_[0].get(), howmany_, A_[1].get(), howmany_, A_[2].get(), howmany_, K_[0].get(), K_[1].get(), K_[2].get(), Q_opt_.get(), howmany_, I_.get(), howmany_); diff --git a/examples/matrix_chain/test_volume.hpp b/examples/matrix_chain/test_volume.hpp index ebd64f55..fdab07c7 100644 --- a/examples/matrix_chain/test_volume.hpp +++ b/examples/matrix_chain/test_volume.hpp @@ -19,7 +19,7 @@ template class test_volume : public test { public: test_volume(std::int64_t N, std::int64_t P, std::int64_t howmany, std::size_t alignment, - sycl::queue q); + sycl::queue q, bool dump = false); ~test_volume() = default; test_volume(test_volume const &other) = delete; test_volume(test_volume &&other) = default; @@ -35,16 +35,18 @@ template class test_volume : public test { private: constexpr static std::size_t dim = 3; - auto make_optimized_kernel() -> sycl::kernel_bundle; + auto make_optimized_kernel(bool dump) -> sycl::kernel_bundle; + auto make_compiler_context() -> tinytc::shared_handle; std::int64_t B3_, B2_, P_, howmany_, B3_aligned_, B2_aligned_; sycl::queue q_; - tinytc::core_info dev_info_; + tinytc::shared_handle dev_info_; 
matrix_batch Q_ref_, Q_opt_, I_, tmp_; std::vector> A_, K_; + tinytc::shared_handle ctx_; sycl::kernel_bundle opt_bundle_; sycl::kernel opt_kernel_; - std::vector g_; + std::vector> g_; }; extern template class test_volume; diff --git a/examples/simple_cl/main.c b/examples/simple_cl/main.c index 0ae5c520..73042d73 100644 --- a/examples/simple_cl/main.c +++ b/examples/simple_cl/main.c @@ -1,6 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include #include #include @@ -14,7 +15,7 @@ do { \ status = X; \ if (status != tinytc_status_success) { \ - printf("Error (%d): %s\n", status, tinytc_error_string(status)); \ + printf("Error (%d): %s\n", status, tinytc_status_to_string(status)); \ printf("in %s:%d: \"%s\"\n", __FILE__, __LINE__, #X); \ goto err; \ } \ @@ -24,8 +25,7 @@ do { \ cl_int result = X; \ if (result != CL_SUCCESS) { \ - status = tinytc_cl_convert_status(result); \ - printf("Error (%d): %s\n", status, tinytc_error_string(status)); \ + printf("OpenCL error (%d)\n", result); \ printf("in %s:%d: \"%s\"\n", __FILE__, __LINE__, #X); \ goto err; \ } \ @@ -39,7 +39,8 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue queue) { tinytc_status_t status = tinytc_status_success; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; + tinytc_compiler_context_t ctx = NULL; + tinytc_type_t f32_ty = NULL; tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; cl_mem A = NULL, B = NULL, C = NULL; @@ -47,13 +48,14 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue q cl_int err; CHECK(tinytc_cl_core_info_create(&info, device)); + CHECK(tinytc_compiler_context_create(&ctx)); + CHECK(tinytc_f32_type_get(&f32_ty, ctx)); const uint32_t M = 64, N = 64, K = 64, howmany = 1000; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, tinytc_scalar_type_f32, - tinytc_transpose_N, 
tinytc_transpose_N, M, N, K, - M, M * K, K, K * N, M, M * N, source_ctx)); - CHECK(tinytc_cl_recipe_handler_create(&handler, context, device, recipe, source_ctx)); + CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, f32_ty, tinytc_transpose_N, + tinytc_transpose_N, M, N, K, M, M * K, K, K * N, + M, M * N)); + CHECK(tinytc_cl_recipe_handler_create(&handler, context, device, recipe)); const size_t Abytes = M * K * howmany * sizeof(float); const size_t Bbytes = K * N * howmany * sizeof(float); @@ -114,14 +116,7 @@ tinytc_status_t gemm(cl_context context, cl_device_id device, cl_command_queue q } tinytc_recipe_handler_release(handler); tinytc_recipe_release(recipe); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } + tinytc_compiler_context_release(ctx); tinytc_core_info_release(info); return status; @@ -132,7 +127,6 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman int32_t *host = NULL; cl_mem A = NULL, B = NULL; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_prog_t program = NULL; cl_program module = NULL; cl_kernel kernel = NULL; @@ -158,17 +152,16 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman static const char source_text[] = "func @copy(%A: memref, %B: memref) {\n" - " %gid = group_id\n" - " %a = subview %A[:,%gid] : memref\n" - " %b = subview %B[:,%gid] : memref\n" - " axpby.n 1, %a, 0, %b\n" - " : i32, memref, i32, memref\n" + " %gid = group_id.x : index\n" + " %a = subview %A[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %b = subview %B[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %c0 = constant 0 : i32\n" + " %c1 = constant 1 : i32\n" + " axpby.n %c1, %a, %c0, %b\n" "}\n"; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_parse_string(&program, 
sizeof(source_text), source_text, source_ctx)); - CHECK(tinytc_cl_kernel_bundle_create_with_program(&module, context, device, program, 0u, - source_ctx)); + CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, NULL)); + CHECK(tinytc_cl_kernel_bundle_create_with_program(&module, context, device, program, 0u)); kernel = clCreateKernel(module, "copy", &err); CL_CHECK(err); @@ -176,9 +169,10 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman CL_CHECK(clSetKernelArg(kernel, 1, sizeof(howmany), &howmany)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(B), &B)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(howmany), &howmany)); + size_t ng[3] = {howmany, 1, 1}; size_t ls[3], gs[3]; CHECK(tinytc_cl_get_group_size(kernel, ls)); - tinytc_cl_get_global_size(howmany, ls, gs); + tinytc_cl_get_global_size(ng, ls, gs); struct timespec start_time, end_time; clock_gettime(CLOCK_MONOTONIC, &start_time); @@ -211,14 +205,6 @@ tinytc_status_t custom_kernel(cl_context context, cl_device_id device, cl_comman clReleaseProgram(module); } tinytc_prog_release(program); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); if (B) { clReleaseMemObject(B); diff --git a/examples/simple_ze/main.c b/examples/simple_ze/main.c index d4a6a713..da1d4734 100644 --- a/examples/simple_ze/main.c +++ b/examples/simple_ze/main.c @@ -1,6 +1,7 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include #include #include @@ -13,7 +14,7 @@ do { \ status = X; \ if (status != tinytc_status_success) { \ - printf("Error (%d): %s\n", status, tinytc_error_string(status)); \ + printf("Error (%d): %s\n", status, tinytc_status_to_string(status)); \ printf("in %s:%d: \"%s\"\n", __FILE__, __LINE__, #X); \ goto err; \ } \ @@ -23,8 +24,7 @@ 
do { \ ze_result_t result = X; \ if (result != ZE_RESULT_SUCCESS) { \ - status = tinytc_ze_convert_status(result); \ - printf("Error (%d): %s\n", status, tinytc_error_string(status)); \ + printf("Level Zero error (%d)\n", result); \ printf("in %s:%d: \"%s\"\n", __FILE__, __LINE__, #X); \ goto err; \ } \ @@ -41,20 +41,22 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, ze_command_list_handle_t list) { tinytc_status_t status = tinytc_status_success; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; + tinytc_compiler_context_t ctx = NULL; + tinytc_type_t f32_ty = NULL; tinytc_recipe_t recipe = NULL; tinytc_recipe_handler_t handler = NULL; void *A = NULL, *B = NULL, *C = NULL; float *Chost = NULL; CHECK(tinytc_ze_core_info_create(&info, device)); + CHECK(tinytc_compiler_context_create(&ctx)); + CHECK(tinytc_f32_type_get(&f32_ty, ctx)); const uint32_t M = 64, N = 64, K = 64, howmany = 1000; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, tinytc_scalar_type_f32, - tinytc_transpose_N, tinytc_transpose_N, M, N, K, - M, M * K, K, K * N, M, M * N, source_ctx)); - CHECK(tinytc_ze_recipe_handler_create(&handler, context, device, recipe, source_ctx)); + CHECK(tinytc_recipe_small_gemm_batched_create(&recipe, info, f32_ty, tinytc_transpose_N, + tinytc_transpose_N, M, N, K, M, M * K, K, K * N, + M, M * N)); + CHECK(tinytc_ze_recipe_handler_create(&handler, context, device, recipe)); const size_t Abytes = M * K * howmany * sizeof(float); const size_t Bbytes = K * N * howmany * sizeof(float); @@ -110,14 +112,7 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, } tinytc_recipe_handler_release(handler); tinytc_recipe_release(recipe); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - 
tinytc_source_context_release(source_ctx); - } + tinytc_compiler_context_release(ctx); tinytc_core_info_release(info); return status; @@ -126,18 +121,17 @@ tinytc_status_t gemm(ze_context_handle_t context, ze_device_handle_t device, tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t device, ze_command_list_handle_t list) { tinytc_status_t status = tinytc_status_success; - int32_t *host = NULL; + int16_t *host = NULL; void *A = NULL, *B = NULL; tinytc_core_info_t info = NULL; - tinytc_source_context_t source_ctx = NULL; tinytc_prog_t program = NULL; ze_module_handle_t module = NULL; ze_kernel_handle_t kernel = NULL; const uint32_t howmany = 1000; const int32_t elements = CHUNK_SIZE * howmany; - const size_t bytes = elements * sizeof(float); - host = (int32_t *)malloc(bytes); + const size_t bytes = elements * sizeof(int16_t); + host = (int16_t *)malloc(bytes); if (!host) { goto err; } @@ -156,24 +150,21 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de static const char source_text[] = "func @copy(%A: memref, %B: memref) {\n" - " %gid = group_id\n" - " %a = subview %A[:,%gid] : memref\n" - " %b = subview %B[:,%gid] : memref\n" - " axpby.n 1, %a, 0, %b\n" - " : i32, memref, i32, memref\n" + " %gid = group_id.x : index\n" + " %a = subview %A[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %b = subview %B[0:" CHUNK_SIZE_S ",%gid] : memref\n" + " %c0 = constant 0 : i32\n" + " %c1 = constant 1 : i32\n" + " axpby.n %c1, %a, %c0, %b\n" "}\n"; - CHECK(tinytc_source_context_create(&source_ctx)); - CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, source_ctx)); - CHECK(tinytc_ze_kernel_bundle_create_with_program(&module, context, device, program, 0u, - source_ctx)); + CHECK(tinytc_parse_string(&program, sizeof(source_text), source_text, NULL)); + CHECK(tinytc_ze_kernel_bundle_create_with_program(&module, context, device, program, 0u)); CHECK(tinytc_ze_kernel_create(&kernel, module, "copy")); 
ZE_CHECK(zeKernelSetArgumentValue(kernel, 0, sizeof(A), &A)); - ZE_CHECK(zeKernelSetArgumentValue(kernel, 1, sizeof(howmany), &howmany)); - ZE_CHECK(zeKernelSetArgumentValue(kernel, 2, sizeof(B), &B)); - ZE_CHECK(zeKernelSetArgumentValue(kernel, 3, sizeof(howmany), &howmany)); - ze_group_count_t group_count = tinytc_ze_get_group_count(howmany); + ZE_CHECK(zeKernelSetArgumentValue(kernel, 1, sizeof(B), &B)); + ze_group_count_t group_count = {howmany, 1u, 1u}; ZE_CHECK(zeCommandListAppendLaunchKernel(list, kernel, &group_count, NULL, 0, NULL)); ZE_CHECK(zeCommandListHostSynchronize(list, TIMEOUT)); @@ -182,10 +173,19 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de uint32_t ok = 0; for (int32_t i = 0; i < elements; ++i) { + int32_t c = i / 64; + int32_t r = i % 64; + if (r < 16 && c < 8) { + printf("%d ", host[i]); + } + if (r == 63 && c < 8) { + printf("\n"); + } if (host[i] == i) { ++ok; } } + printf("\n"); if (ok == (uint32_t)elements) { printf("Custom kernel was successful\n"); } else { @@ -200,14 +200,6 @@ tinytc_status_t custom_kernel(ze_context_handle_t context, ze_device_handle_t de zeModuleDestroy(module); } tinytc_prog_release(program); - if (source_ctx) { - const char *error_log; - tinytc_source_context_get_error_log(source_ctx, &error_log); - if (error_log[0] != '\0') { - printf("\nError log:\n%s\n", error_log); - } - tinytc_source_context_release(source_ctx); - } tinytc_core_info_release(info); if (B) { zeMemFree(context, B); diff --git a/examples/tall_and_skinny/CMakeLists.txt b/examples/tall_and_skinny/CMakeLists.txt index f65f8efb..9defb5c6 100644 --- a/examples/tall_and_skinny/CMakeLists.txt +++ b/examples/tall_and_skinny/CMakeLists.txt @@ -5,7 +5,7 @@ include(CommonOptions) find_package(SYCL REQUIRED) -add_executable(tall_and_skinny main.cpp args.cpp) +add_executable(tall_and_skinny main.cpp) add_sycl_to_target(TARGET tall_and_skinny SOURCES main.cpp) -target_link_libraries(tall_and_skinny PRIVATE tinytc 
tinytc_sycl) +target_link_libraries(tall_and_skinny PRIVATE tinytc tinytc_sycl argparser) set_cxx_common_options(tall_and_skinny) diff --git a/examples/tall_and_skinny/args.cpp b/examples/tall_and_skinny/args.cpp deleted file mode 100644 index 6aa1268a..00000000 --- a/examples/tall_and_skinny/args.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "args.hpp" - -#include -#include -#include -#include -#include - -args arg_parser::parse_args(int argc, char **argv) { - args a = {}; - a.beta = 0.0; - a.specialize_M = false; - a.specialize_ld = false; - auto num = std::vector(3); - for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') { - auto const fail = [&]() { - throw std::runtime_error("==> Error: unrecognized argument " + - std::string(argv[i])); - }; - if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) { - a.help = true; - } else if (std::strcmp(argv[i], "-v") == 0 || std::strcmp(argv[i], "--verify") == 0) { - a.verify = true; - } else if (std::strcmp(argv[i], "--specialize-M") == 0) { - a.specialize_M = true; - } else if (std::strcmp(argv[i], "--specialize-ld") == 0) { - a.specialize_ld = true; - } else if (i + 1 < argc) { - if (std::strcmp(argv[i], "-b") == 0 || std::strcmp(argv[i], "--beta") == 0) { - ++i; - a.beta = atof(argv[i]); - } else if (std::strcmp(argv[i], "-p") == 0 || - std::strcmp(argv[i], "--precision") == 0) { - ++i; - if (argv[i][0] == 'd') { - a.double_precision = true; - } else if (argv[i][0] == 's') { - a.double_precision = false; - } else { - fail(); - } - } else { - fail(); - } - } else { - fail(); - } - } else { - num.clear(); - char const *delim = "x"; - auto arg = std::string(argv[i]); - char *token = std::strtok(argv[i], delim); - while (token) { - num.emplace_back(atoi(token)); - token = std::strtok(nullptr, delim); - } - if (num.size() != 3) { - throw std::runtime_error("==> Could not parse test case: " + arg); - } - 
a.tc.push_back({num[0], num[1], num[2]}); - } - } - - return a; -} - -void arg_parser::show_help(std::ostream &os) { - os << "usage: tall_and_skinny test-case1 test-case2 ..." << std::endl - << R"HELP( -positional arguments: - test-caseN MxNxK triplet (e.g. 300000x64x64) - -optional arguments: - -h, --help Show help and quit - -p, --precision Precision (single = s, double = d) - -v, --verify Verify optimized implementation - --specialize-M Specialize M instead of using dynamic value - --specialize-ld Specialize ldA, ldB, ldC instead of using dynamic value -)HELP"; -} diff --git a/examples/tall_and_skinny/args.hpp b/examples/tall_and_skinny/args.hpp deleted file mode 100644 index 701710fe..00000000 --- a/examples/tall_and_skinny/args.hpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef ARGS_20240215_HPP -#define ARGS_20240215_HPP - -#include -#include -#include - -struct test_case { - std::int64_t m; - std::int64_t n; - std::int64_t k; -}; - -struct args { - std::vector tc; - bool double_precision; - bool help; - bool verify; - double beta; - bool specialize_M; - bool specialize_ld; -}; - -class arg_parser { - public: - static args parse_args(int argc, char **argv); - static void show_help(std::ostream &os); -}; - -#endif // ARGS_20240215_HPP diff --git a/examples/tall_and_skinny/main.cpp b/examples/tall_and_skinny/main.cpp index 8538cf0e..9023e65e 100644 --- a/examples/tall_and_skinny/main.cpp +++ b/examples/tall_and_skinny/main.cpp @@ -1,14 +1,17 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "args.hpp" +#include "../gemm_common.hpp" +#include #include +#include #include #include #include #include +#include #include #include #include @@ -20,6 +23,18 @@ using namespace sycl; using namespace tinytc; +struct args { + bool dump = false; + bool specialize_M = false; + bool specialize_ld = false; + examples::test_type ty = examples::test_type::f32; 
+ bool update = false; + bool verify = false; + std::int32_t alignment = 0; + std::int32_t M_block_size = 0; + std::vector tc; +}; + template double bench(F f, int nrepeat = 10) { f(); double min_exec_time_ns = std::numeric_limits::max(); @@ -35,11 +50,6 @@ template double bench(F f, int nrepeat = 10) { } template void test(queue q, args &a) { - auto const fill = [](std::vector &x) { - for (std::size_t i = 0; i < x.size(); ++i) { - x[i] = i % 101; - } - }; std::int64_t na_max = 0; std::int64_t nb_max = 0; std::int64_t nc_max = 0; @@ -51,28 +61,32 @@ template void test(queue q, args &a) { auto A_host = std::vector(na_max); auto B_host = std::vector(nb_max); auto C_host = std::vector(nc_max); - auto C_ref_host = std::vector(nc_max); - T *C_ref = malloc_device(nc_max, q); - T *A = malloc_device(na_max, q); - T *B = malloc_device(nb_max, q); - T *C = malloc_device(nc_max, q); - fill(A_host); - fill(B_host); - q.copy(A_host.data(), A, na_max).wait(); - q.copy(B_host.data(), B, nb_max).wait(); - - auto const check = [&](std::int64_t M, std::int64_t N) { - q.copy(C_ref, C_ref_host.data(), M * N).wait(); + auto const alloc_device = [&a, &q](std::size_t num_bytes) { + if (a.alignment == 0) { + return malloc_device(num_bytes, q); + } else { + return aligned_alloc_device(a.alignment, num_bytes, q); + } + }; + T *A = alloc_device(na_max); + T *B = alloc_device(nb_max); + T *C = alloc_device(nc_max); + + auto const check = [&](std::int64_t M, std::int64_t N, std::int64_t K) { q.copy(C, C_host.data(), M * N).wait(); std::size_t num_err = 0; - for (std::int64_t i = 0; i < M * N; ++i) { - auto err = std::abs(C_host[i] - C_ref_host[i]); - if (err > 10.0 * std::numeric_limits::epsilon()) { - if (num_err < 10) { - std::cout << i << " " << err << " " << C_host[i] << " " << C_ref_host[i] - << std::endl; + const auto error_bound = examples::test_gemm_error_bound(K); + for (std::int64_t j = 0; j < N; ++j) { + for (std::int64_t i = 0; i < M; ++i) { + const auto relerr = 
examples::test_gemm_rel_error(C_host.data(), i, j, M); + if (relerr > error_bound) { + if (num_err < 10) { + std::cout << "C_{" << i << "," << j << "}=" << C_host[i + j * M] + << ", relative_error=" << relerr + << ", error_bound=" << error_bound << std::endl; + } + ++num_err; } - ++num_err; } } if (num_err > 10) { @@ -80,31 +94,17 @@ template void test(queue q, args &a) { } }; - auto const &type = typeid(T); for (auto &c : a.tc) { - if (a.verify) { - q.memset(C, 0, c.m * c.n * sizeof(T)).wait(); - q.memset(C_ref, 0, c.m * c.n * sizeof(T)).wait(); - q.submit([&](auto &h) { - auto beta = a.beta; - h.parallel_for(range{static_cast(c.n), static_cast(c.m)}, - [=](id<2> it) { - auto m = it[1]; - auto n = it[0]; - auto c_acc = T(0.0); - for (std::int64_t k = 0; k < c.k; ++k) { - c_acc += A[m + k * c.m] * B[k + n * c.k]; - } - C_ref[m + n * c.m] = c_acc + T(beta) * C_ref[m + n * c.m]; - }); - }).wait(); - } + examples::test_gemm_matrix(A_host.data(), c.m, c.k); + examples::test_gemm_matrix(B_host.data(), c.k, c.n); + q.copy(A_host.data(), A, c.m * c.k).wait(); + q.copy(B_host.data(), B, c.k * c.n).wait(); + q.memset(C, 0, c.m * c.n * sizeof(T)).wait(); - auto source_ctx = source_context{}; + auto beta = a.update ? T{1} : T{0}; try { - source_ctx = make_source_context(); - auto info = make_core_info(q.get_device()); - info.set_core_features(tinytc_core_feature_flag_large_register_file); + auto info = create_core_info(q.get_device()); + set_core_features(info.get(), tinytc_core_feature_flag_large_register_file); std::int64_t M = a.specialize_M ? 
c.m : dynamic; std::int64_t ldA = dynamic, ldB = dynamic, ldC = dynamic; @@ -113,31 +113,47 @@ template void test(queue q, args &a) { ldB = c.k; ldC = c.m; } - auto tas = make_recipe_handler( - q, - make_tall_and_skinny_specialized(info, to_scalar_type_v, M, c.n, c.k, ldA, ldB, - ldC, 0, source_ctx), - source_ctx); - - tall_and_skinny::set_args(tas, c.m, T(1.0), A, c.m, B, c.k, T(a.beta), C, c.m); - tas.submit(q).wait(); + auto ctx = create_compiler_context(); + set_error_reporter(ctx.get(), [](char const *what, const tinytc_location_t *, void *) { + std::cerr << what << std::endl; + }); + auto r = create_tall_and_skinny_specialized(info.get(), to_type(ctx.get()), M, c.n, + c.k, ldA, ldB, ldC, a.alignment, + a.alignment, a.alignment, a.M_block_size); + if (a.dump) { + dump(get_prog(r.get()).get()); + } + auto tas = create_recipe_handler(q, r.get()); + + set_tall_and_skinny_args(tas.get(), c.m, T{1}, mem(A, mem_type::usm_pointer), c.m, + mem(B, mem_type::usm_pointer), c.k, beta, + mem(C, mem_type::usm_pointer), c.m); + submit(tas.get(), q).wait(); if (a.verify) { - check(c.m, c.n); + check(c.m, c.n, c.k); } - double min_exec_time_ns = bench([&]() { tas.submit(q).wait(); }); + double min_exec_time_ns = bench([&]() { submit(tas.get(), q).wait(); }); + + const auto ops_per_mnk = [&] { + switch (a.ty) { + case examples::test_type::c32: + case examples::test_type::c64: + return 8; + default: + return 2; + } + }(); - auto bw_C_factor = a.beta != 0.0 ? 2 : 1; + auto bw_C_factor = a.update ? 
2 : 1; auto bw = sizeof(T) * (c.m * c.n * bw_C_factor + c.m * c.k + c.k * c.n) / min_exec_time_ns; - auto gflops = 2 * c.m * c.n * c.k / min_exec_time_ns; - std::cout << type.name() << "," << c.m << "," << c.n << "," << c.k << "," << a.beta - << "," << min_exec_time_ns / 1e9 << "," << bw << "," << gflops << std::endl; + auto gflops = ops_per_mnk * c.m * c.n * c.k / min_exec_time_ns; + std::cout << to_string(a.ty) << "," << c.m << "," << c.n << "," << c.k << "," + << a.update << "," << min_exec_time_ns / 1e9 << "," << bw << "," << gflops + << std::endl; } catch (status const &st) { - std::cerr << "Error (" << static_cast(st) << "): " << tinytc::error_string(st) + std::cerr << "Error (" << static_cast(st) << "): " << tinytc::to_string(st) << std::endl; - if (source_ctx.get_error_log()[0] != '\0') { - std::cerr << "Error log: " << std::endl << source_ctx.get_error_log() << std::endl; - } } catch (std::exception const &e) { std::cerr << "Error: " << e.what() << std::endl; } @@ -146,31 +162,48 @@ template void test(queue q, args &a) { free(A, q); free(B, q); free(C, q); - free(C_ref, q); }; int main(int argc, char **argv) { auto a = args{}; + bool help = false; + + auto parser = cmd::arg_parser{}; try { - a = arg_parser::parse_args(argc, argv); - } catch (std::runtime_error const &e) { + parser.set_short_opt('a', &a.alignment, "Override memory alignment"); + parser.set_short_opt('d', &a.dump, "Dump IR to stdout"); + parser.set_short_opt('f', &a.ty, "Data type (bf16, f16, f32, f64, c32, c64)") + .converter(examples::convert_data_type); + parser.set_short_opt('h', &help, "Show help"); + parser.set_short_opt('u', &a.update, + "Add A*B to C (beta=1) instead of overwriting C (beta=0)"); + parser.set_short_opt('v', &a.verify, "Verify optimized implementation"); + parser.set_long_opt("help", &help, "Show help"); + parser.set_long_opt("m-block-size", &a.M_block_size, + "Set block size for M mode (one work-group per block)"); + parser.set_long_opt("specialize-m", 
&a.specialize_M, + "Specialize M instead of using dynamic value"); + parser.set_long_opt("specialize-ld", &a.specialize_ld, + "Specialize ldA, ldB, ldC instead of using dynamic value"); + parser.add_positional_arg("test-case", &a.tc, "MxNxK triplet (e.g. 300000x64x64)") + .converter(examples::convert_test_case) + .validator(examples::validate_test_case); + + parser.parse(argc, argv); + } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; } - if (a.help || a.tc.empty()) { - arg_parser::show_help(std::cout); - return 0; + if (help || a.tc.empty()) { + parser.print_help(std::cout, "tall_and_skinny", ""); + return !help ? -1 : 0; } auto q = queue{}; - std::cout << "precision,m,n,k,beta,time,bandwidth,gflops" << std::endl; + std::cout << "precision,m,n,k,update,time,bandwidth,gflops" << std::endl; try { - if (a.double_precision) { - test(std::move(q), a); - } else { - test(std::move(q), a); - } + dispatch(a.ty, [&]() { test(q, a); }); } catch (std::exception const &e) { std::cerr << e.what() << std::endl; return -1; diff --git a/external/doctest/doctest.h b/external/doctest/doctest.h index 8f442842..5c754cde 100644 --- a/external/doctest/doctest.h +++ b/external/doctest/doctest.h @@ -1,7019 +1,7106 @@ -// ====================================================================== lgtm [cpp/missing-header-guard] -// == DO NOT MODIFY THIS FILE BY HAND - IT IS AUTO GENERATED BY CMAKE! 
== -// ====================================================================== -// -// doctest.h - the lightest feature-rich C++ single-header testing framework for unit tests and TDD -// -// Copyright (c) 2016-2021 Viktor Kirilov -// -// Distributed under the MIT Software License -// See accompanying file LICENSE.txt or copy at -// https://opensource.org/licenses/MIT -// -// The documentation can be found at the library's page: -// https://github.com/doctest/doctest/blob/master/doc/markdown/readme.md -// -// ================================================================================================= -// ================================================================================================= -// ================================================================================================= -// -// The library is heavily influenced by Catch - https://github.com/catchorg/Catch2 -// which uses the Boost Software License - Version 1.0 -// see here - https://github.com/catchorg/Catch2/blob/master/LICENSE.txt -// -// The concept of subcases (sections in Catch) and expression decomposition are from there. 
-// Some parts of the code are taken directly: -// - stringification - the detection of "ostream& operator<<(ostream&, const T&)" and StringMaker<> -// - the Approx() helper class for floating point comparison -// - colors in the console -// - breaking into a debugger -// - signal / SEH handling -// - timer -// - XmlWriter class - thanks to Phil Nash for allowing the direct reuse (AKA copy/paste) -// -// The expression decomposing templates are taken from lest - https://github.com/martinmoene/lest -// which uses the Boost Software License - Version 1.0 -// see here - https://github.com/martinmoene/lest/blob/master/LICENSE.txt -// -// ================================================================================================= -// ================================================================================================= -// ================================================================================================= - -#ifndef DOCTEST_LIBRARY_INCLUDED -#define DOCTEST_LIBRARY_INCLUDED - -// ================================================================================================= -// == VERSION ====================================================================================== -// ================================================================================================= - -#define DOCTEST_VERSION_MAJOR 2 -#define DOCTEST_VERSION_MINOR 4 -#define DOCTEST_VERSION_PATCH 9 - -// util we need here -#define DOCTEST_TOSTR_IMPL(x) #x -#define DOCTEST_TOSTR(x) DOCTEST_TOSTR_IMPL(x) - -#define DOCTEST_VERSION_STR \ - DOCTEST_TOSTR(DOCTEST_VERSION_MAJOR) "." \ - DOCTEST_TOSTR(DOCTEST_VERSION_MINOR) "." 
\ - DOCTEST_TOSTR(DOCTEST_VERSION_PATCH) - -#define DOCTEST_VERSION \ - (DOCTEST_VERSION_MAJOR * 10000 + DOCTEST_VERSION_MINOR * 100 + DOCTEST_VERSION_PATCH) - -// ================================================================================================= -// == COMPILER VERSION ============================================================================= -// ================================================================================================= - -// ideas for the version stuff are taken from here: https://github.com/cxxstuff/cxx_detect - -#ifdef _MSC_VER -#define DOCTEST_CPLUSPLUS _MSVC_LANG -#else -#define DOCTEST_CPLUSPLUS __cplusplus -#endif - -#define DOCTEST_COMPILER(MAJOR, MINOR, PATCH) ((MAJOR)*10000000 + (MINOR)*100000 + (PATCH)) - -// GCC/Clang and GCC/MSVC are mutually exclusive, but Clang/MSVC are not because of clang-cl... -#if defined(_MSC_VER) && defined(_MSC_FULL_VER) -#if _MSC_VER == _MSC_FULL_VER / 10000 -#define DOCTEST_MSVC DOCTEST_COMPILER(_MSC_VER / 100, _MSC_VER % 100, _MSC_FULL_VER % 10000) -#else // MSVC -#define DOCTEST_MSVC \ - DOCTEST_COMPILER(_MSC_VER / 100, (_MSC_FULL_VER / 100000) % 100, _MSC_FULL_VER % 100000) -#endif // MSVC -#endif // MSVC -#if defined(__clang__) && defined(__clang_minor__) -#define DOCTEST_CLANG DOCTEST_COMPILER(__clang_major__, __clang_minor__, __clang_patchlevel__) -#elif defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) && \ - !defined(__INTEL_COMPILER) -#define DOCTEST_GCC DOCTEST_COMPILER(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#endif // GCC - -#ifndef DOCTEST_MSVC -#define DOCTEST_MSVC 0 -#endif // DOCTEST_MSVC -#ifndef DOCTEST_CLANG -#define DOCTEST_CLANG 0 -#endif // DOCTEST_CLANG -#ifndef DOCTEST_GCC -#define DOCTEST_GCC 0 -#endif // DOCTEST_GCC - -// ================================================================================================= -// == COMPILER WARNINGS HELPERS ==================================================================== -// 
================================================================================================= - -#if DOCTEST_CLANG -#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) -#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") -#define DOCTEST_CLANG_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(clang diagnostic ignored w) -#define DOCTEST_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") -#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ - DOCTEST_CLANG_SUPPRESS_WARNING_PUSH DOCTEST_CLANG_SUPPRESS_WARNING(w) -#else // DOCTEST_CLANG -#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -#define DOCTEST_CLANG_SUPPRESS_WARNING(w) -#define DOCTEST_CLANG_SUPPRESS_WARNING_POP -#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) -#endif // DOCTEST_CLANG - -#if DOCTEST_GCC -#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) -#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH _Pragma("GCC diagnostic push") -#define DOCTEST_GCC_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(GCC diagnostic ignored w) -#define DOCTEST_GCC_SUPPRESS_WARNING_POP _Pragma("GCC diagnostic pop") -#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) \ - DOCTEST_GCC_SUPPRESS_WARNING_PUSH DOCTEST_GCC_SUPPRESS_WARNING(w) -#else // DOCTEST_GCC -#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH -#define DOCTEST_GCC_SUPPRESS_WARNING(w) -#define DOCTEST_GCC_SUPPRESS_WARNING_POP -#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) -#endif // DOCTEST_GCC - -#if DOCTEST_MSVC -#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH __pragma(warning(push)) -#define DOCTEST_MSVC_SUPPRESS_WARNING(w) __pragma(warning(disable : w)) -#define DOCTEST_MSVC_SUPPRESS_WARNING_POP __pragma(warning(pop)) -#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) \ - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH DOCTEST_MSVC_SUPPRESS_WARNING(w) -#else // DOCTEST_MSVC -#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -#define DOCTEST_MSVC_SUPPRESS_WARNING(w) -#define DOCTEST_MSVC_SUPPRESS_WARNING_POP -#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) -#endif // DOCTEST_MSVC - 
-// ================================================================================================= -// == COMPILER WARNINGS ============================================================================ -// ================================================================================================= - -// both the header and the implementation suppress all of these, -// so it only makes sense to aggregrate them like so -#define DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH \ - DOCTEST_CLANG_SUPPRESS_WARNING_PUSH \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") \ - DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") \ - \ - DOCTEST_GCC_SUPPRESS_WARNING_PUSH \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") \ - DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") \ - \ - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ - /* these 4 also disabled globally via cmake: */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4514) /* unreferenced inline function has been removed */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4571) /* SEH related */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4710) /* function not inlined */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4711) /* function selected for inline expansion*/ \ - /* */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4616) /* invalid compiler warning */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4619) /* invalid compiler warning */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4996) /* The compiler encountered a deprecated declaration */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4706) /* assignment 
within conditional expression */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4512) /* 'class' : assignment operator could not be generated */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4127) /* conditional expression is constant */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4640) /* construction of local static object not thread-safe */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ - /* static analysis */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26439) /* Function may not throw. Declare it 'noexcept' */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26495) /* Always initialize a member variable */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26451) /* Arithmetic overflow ... */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26444) /* Avoid unnamed objects with custom ctor and dtor... 
*/ \ - DOCTEST_MSVC_SUPPRESS_WARNING(26812) /* Prefer 'enum class' over 'enum' */ - -#define DOCTEST_SUPPRESS_COMMON_WARNINGS_POP \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP \ - DOCTEST_GCC_SUPPRESS_WARNING_POP \ - DOCTEST_MSVC_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH - -DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -DOCTEST_CLANG_SUPPRESS_WARNING("-Wnon-virtual-dtor") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wdeprecated") - -DOCTEST_GCC_SUPPRESS_WARNING_PUSH -DOCTEST_GCC_SUPPRESS_WARNING("-Wctor-dtor-privacy") -DOCTEST_GCC_SUPPRESS_WARNING("-Wnon-virtual-dtor") -DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-promo") - -DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -DOCTEST_MSVC_SUPPRESS_WARNING(4623) // default constructor was implicitly defined as deleted - -#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN \ - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ - DOCTEST_MSVC_SUPPRESS_WARNING(4548) /* before comma no effect; expected side - effect */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4265) /* virtual functions, but destructor is not virtual */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4986) /* exception specification does not match previous */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4350) /* 'member1' called instead of 'member2' */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4668) /* not defined as a preprocessor macro */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4365) /* signed/unsigned mismatch */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4774) /* format string not a string literal */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4623) /* default constructor was implicitly deleted */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5039) /* pointer 
to pot. throwing function passed to extern C */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(5105) /* macro producing 'defined' has undefined behavior */ \ - DOCTEST_MSVC_SUPPRESS_WARNING(4738) /* storing float result in memory, loss of performance */ - -#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END DOCTEST_MSVC_SUPPRESS_WARNING_POP - -// ================================================================================================= -// == FEATURE DETECTION ============================================================================ -// ================================================================================================= - -// general compiler feature support table: https://en.cppreference.com/w/cpp/compiler_support -// MSVC C++11 feature support table: https://msdn.microsoft.com/en-us/library/hh567368.aspx -// GCC C++11 feature support table: https://gcc.gnu.org/projects/cxx-status.html -// MSVC version table: -// https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B#Internal_version_numbering -// MSVC++ 14.3 (17) _MSC_VER == 1930 (Visual Studio 2022) -// MSVC++ 14.2 (16) _MSC_VER == 1920 (Visual Studio 2019) -// MSVC++ 14.1 (15) _MSC_VER == 1910 (Visual Studio 2017) -// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) -// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013) -// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012) -// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010) -// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) -// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) - -// Universal Windows Platform support -#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) -#define DOCTEST_CONFIG_NO_WINDOWS_SEH -#endif // WINAPI_FAMILY -#if DOCTEST_MSVC && !defined(DOCTEST_CONFIG_WINDOWS_SEH) -#define DOCTEST_CONFIG_WINDOWS_SEH -#endif // MSVC -#if defined(DOCTEST_CONFIG_NO_WINDOWS_SEH) && defined(DOCTEST_CONFIG_WINDOWS_SEH) -#undef 
DOCTEST_CONFIG_WINDOWS_SEH -#endif // DOCTEST_CONFIG_NO_WINDOWS_SEH - -#if !defined(_WIN32) && !defined(__QNX__) && !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && \ - !defined(__EMSCRIPTEN__) && !defined(__wasi__) -#define DOCTEST_CONFIG_POSIX_SIGNALS -#endif // _WIN32 -#if defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) && defined(DOCTEST_CONFIG_POSIX_SIGNALS) -#undef DOCTEST_CONFIG_POSIX_SIGNALS -#endif // DOCTEST_CONFIG_NO_POSIX_SIGNALS - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS -#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) \ - || defined(__wasi__) -#define DOCTEST_CONFIG_NO_EXCEPTIONS -#endif // no exceptions -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS -#define DOCTEST_CONFIG_NO_EXCEPTIONS -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS - -#if defined(DOCTEST_CONFIG_NO_EXCEPTIONS) && !defined(DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS) -#define DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS && !DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS - -#ifdef __wasi__ -#define DOCTEST_CONFIG_NO_MULTITHREADING -#endif - -#if defined(DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN) && !defined(DOCTEST_CONFIG_IMPLEMENT) -#define DOCTEST_CONFIG_IMPLEMENT -#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN - -#if defined(_WIN32) || defined(__CYGWIN__) -#if DOCTEST_MSVC -#define DOCTEST_SYMBOL_EXPORT __declspec(dllexport) -#define DOCTEST_SYMBOL_IMPORT __declspec(dllimport) -#else // MSVC -#define DOCTEST_SYMBOL_EXPORT __attribute__((dllexport)) -#define DOCTEST_SYMBOL_IMPORT __attribute__((dllimport)) -#endif // MSVC -#else // _WIN32 -#define DOCTEST_SYMBOL_EXPORT __attribute__((visibility("default"))) -#define DOCTEST_SYMBOL_IMPORT -#endif // _WIN32 - -#ifdef DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL -#ifdef DOCTEST_CONFIG_IMPLEMENT -#define DOCTEST_INTERFACE DOCTEST_SYMBOL_EXPORT -#else // 
DOCTEST_CONFIG_IMPLEMENT -#define DOCTEST_INTERFACE DOCTEST_SYMBOL_IMPORT -#endif // DOCTEST_CONFIG_IMPLEMENT -#else // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL -#define DOCTEST_INTERFACE -#endif // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL - -// needed for extern template instantiations -// see https://github.com/fmtlib/fmt/issues/2228 -#if DOCTEST_MSVC -#define DOCTEST_INTERFACE_DECL -#define DOCTEST_INTERFACE_DEF DOCTEST_INTERFACE -#else // DOCTEST_MSVC -#define DOCTEST_INTERFACE_DECL DOCTEST_INTERFACE -#define DOCTEST_INTERFACE_DEF -#endif // DOCTEST_MSVC - -#define DOCTEST_EMPTY - -#if DOCTEST_MSVC -#define DOCTEST_NOINLINE __declspec(noinline) -#define DOCTEST_UNUSED -#define DOCTEST_ALIGNMENT(x) -#elif DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 5, 0) -#define DOCTEST_NOINLINE -#define DOCTEST_UNUSED -#define DOCTEST_ALIGNMENT(x) -#else -#define DOCTEST_NOINLINE __attribute__((noinline)) -#define DOCTEST_UNUSED __attribute__((unused)) -#define DOCTEST_ALIGNMENT(x) __attribute__((aligned(x))) -#endif - -#ifndef DOCTEST_NORETURN -#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_NORETURN -#else // DOCTEST_MSVC -#define DOCTEST_NORETURN [[noreturn]] -#endif // DOCTEST_MSVC -#endif // DOCTEST_NORETURN - -#ifndef DOCTEST_NOEXCEPT -#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_NOEXCEPT -#else // DOCTEST_MSVC -#define DOCTEST_NOEXCEPT noexcept -#endif // DOCTEST_MSVC -#endif // DOCTEST_NOEXCEPT - -#ifndef DOCTEST_CONSTEXPR -#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_CONSTEXPR const -#define DOCTEST_CONSTEXPR_FUNC inline -#else // DOCTEST_MSVC -#define DOCTEST_CONSTEXPR constexpr -#define DOCTEST_CONSTEXPR_FUNC constexpr -#endif // DOCTEST_MSVC -#endif // DOCTEST_CONSTEXPR - -// ================================================================================================= -// == FEATURE DETECTION END 
======================================================================== -// ================================================================================================= - -#define DOCTEST_DECLARE_INTERFACE(name) \ - virtual ~name(); \ - name() = default; \ - name(const name&) = delete; \ - name(name&&) = delete; \ - name& operator=(const name&) = delete; \ - name& operator=(name&&) = delete; - -#define DOCTEST_DEFINE_INTERFACE(name) \ - name::~name() = default; - -// internal macros for string concatenation and anonymous variable name generation -#define DOCTEST_CAT_IMPL(s1, s2) s1##s2 -#define DOCTEST_CAT(s1, s2) DOCTEST_CAT_IMPL(s1, s2) -#ifdef __COUNTER__ // not standard and may be missing for some compilers -#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __COUNTER__) -#else // __COUNTER__ -#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __LINE__) -#endif // __COUNTER__ - -#ifndef DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE -#define DOCTEST_REF_WRAP(x) x& -#else // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE -#define DOCTEST_REF_WRAP(x) x -#endif // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE - -// not using __APPLE__ because... this is how Catch does it -#ifdef __MAC_OS_X_VERSION_MIN_REQUIRED -#define DOCTEST_PLATFORM_MAC -#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) -#define DOCTEST_PLATFORM_IPHONE -#elif defined(_WIN32) -#define DOCTEST_PLATFORM_WINDOWS -#elif defined(__wasi__) -#define DOCTEST_PLATFORM_WASI -#else // DOCTEST_PLATFORM -#define DOCTEST_PLATFORM_LINUX -#endif // DOCTEST_PLATFORM - -namespace doctest { namespace detail { - static DOCTEST_CONSTEXPR int consume(const int*, int) noexcept { return 0; } -}} - -#define DOCTEST_GLOBAL_NO_WARNINGS(var, ...) 
\ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ - static const int var = doctest::detail::consume(&var, __VA_ARGS__); \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP - -#ifndef DOCTEST_BREAK_INTO_DEBUGGER -// should probably take a look at https://github.com/scottt/debugbreak -#ifdef DOCTEST_PLATFORM_LINUX -#if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) -// Break at the location of the failing check if possible -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) -#else -#include -#define DOCTEST_BREAK_INTO_DEBUGGER() raise(SIGTRAP) -#endif -#elif defined(DOCTEST_PLATFORM_MAC) -#if defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || defined(__i386) -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) -#elif defined(__ppc__) || defined(__ppc64__) -// https://www.cocoawithlove.com/2008/03/break-into-debugger.html -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n": : : "memory","r0","r3","r4") // NOLINT(hicpp-no-assembler) -#else -#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("brk #0"); // NOLINT(hicpp-no-assembler) -#endif -#elif DOCTEST_MSVC -#define DOCTEST_BREAK_INTO_DEBUGGER() __debugbreak() -#elif defined(__MINGW32__) -DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wredundant-decls") -extern "C" __declspec(dllimport) void __stdcall DebugBreak(); -DOCTEST_GCC_SUPPRESS_WARNING_POP -#define DOCTEST_BREAK_INTO_DEBUGGER() ::DebugBreak() -#else // linux -#define DOCTEST_BREAK_INTO_DEBUGGER() (static_cast(0)) -#endif // linux -#endif // DOCTEST_BREAK_INTO_DEBUGGER - -// this is kept here for backwards compatibility since the config option was changed -#ifdef DOCTEST_CONFIG_USE_IOSFWD -#ifndef DOCTEST_CONFIG_USE_STD_HEADERS -#define DOCTEST_CONFIG_USE_STD_HEADERS -#endif -#endif // DOCTEST_CONFIG_USE_IOSFWD - -// for clang - always include ciso646 (which drags some std stuff) because -// we want to check 
if we are using libc++ with the _LIBCPP_VERSION macro in -// which case we don't want to forward declare stuff from std - for reference: -// https://github.com/doctest/doctest/issues/126 -// https://github.com/doctest/doctest/issues/356 -#if DOCTEST_CLANG -#include -#ifdef _LIBCPP_VERSION -#ifndef DOCTEST_CONFIG_USE_STD_HEADERS -#define DOCTEST_CONFIG_USE_STD_HEADERS -#endif -#endif // _LIBCPP_VERSION -#endif // clang - -#ifdef DOCTEST_CONFIG_USE_STD_HEADERS -#ifndef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#define DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN -#include -#include -#include -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END -#else // DOCTEST_CONFIG_USE_STD_HEADERS - -// Forward declaring 'X' in namespace std is not permitted by the C++ Standard. -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4643) - -namespace std { // NOLINT(cert-dcl58-cpp) -typedef decltype(nullptr) nullptr_t; // NOLINT(modernize-use-using) -typedef decltype(sizeof(void*)) size_t; // NOLINT(modernize-use-using) -template -struct char_traits; -template <> -struct char_traits; -template -class basic_ostream; // NOLINT(fuchsia-virtual-inheritance) -typedef basic_ostream> ostream; // NOLINT(modernize-use-using) -template -// NOLINTNEXTLINE -basic_ostream& operator<<(basic_ostream&, const char*); -template -class basic_istream; -typedef basic_istream> istream; // NOLINT(modernize-use-using) -template -class tuple; -#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 -template -class allocator; -template -class basic_string; -using string = basic_string, allocator>; -#endif // VS 2019 -} // namespace std - -DOCTEST_MSVC_SUPPRESS_WARNING_POP - -#endif // DOCTEST_CONFIG_USE_STD_HEADERS - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#include -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - -namespace doctest { - -using 
std::size_t; - -DOCTEST_INTERFACE extern bool is_running_in_test; - -#ifndef DOCTEST_CONFIG_STRING_SIZE_TYPE -#define DOCTEST_CONFIG_STRING_SIZE_TYPE unsigned -#endif - -// A 24 byte string class (can be as small as 17 for x64 and 13 for x86) that can hold strings with length -// of up to 23 chars on the stack before going on the heap - the last byte of the buffer is used for: -// - "is small" bit - the highest bit - if "0" then it is small - otherwise its "1" (128) -// - if small - capacity left before going on the heap - using the lowest 5 bits -// - if small - 2 bits are left unused - the second and third highest ones -// - if small - acts as a null terminator if strlen() is 23 (24 including the null terminator) -// and the "is small" bit remains "0" ("as well as the capacity left") so its OK -// Idea taken from this lecture about the string implementation of facebook/folly - fbstring -// https://www.youtube.com/watch?v=kPR8h4-qZdk -// TODO: -// - optimizations - like not deleting memory unnecessarily in operator= and etc. 
-// - resize/reserve/clear -// - replace -// - back/front -// - iterator stuff -// - find & friends -// - push_back/pop_back -// - assign/insert/erase -// - relational operators as free functions - taking const char* as one of the params -class DOCTEST_INTERFACE String -{ -public: - using size_type = DOCTEST_CONFIG_STRING_SIZE_TYPE; - -private: - static DOCTEST_CONSTEXPR size_type len = 24; //!OCLINT avoid private static members - static DOCTEST_CONSTEXPR size_type last = len - 1; //!OCLINT avoid private static members - - struct view // len should be more than sizeof(view) - because of the final byte for flags - { - char* ptr; - size_type size; - size_type capacity; - }; - - union - { - char buf[len]; // NOLINT(*-avoid-c-arrays) - view data; - }; - - char* allocate(size_type sz); - - bool isOnStack() const noexcept { return (buf[last] & 128) == 0; } - void setOnHeap() noexcept; - void setLast(size_type in = last) noexcept; - void setSize(size_type sz) noexcept; - - void copy(const String& other); - -public: - static DOCTEST_CONSTEXPR size_type npos = static_cast(-1); - - String() noexcept; - ~String(); - - // cppcheck-suppress noExplicitConstructor - String(const char* in); - String(const char* in, size_type in_size); - - String(std::istream& in, size_type in_size); - - String(const String& other); - String& operator=(const String& other); - - String& operator+=(const String& other); - - String(String&& other) noexcept; - String& operator=(String&& other) noexcept; - - char operator[](size_type i) const; - char& operator[](size_type i); - - // the only functions I'm willing to leave in the interface - available for inlining - const char* c_str() const { return const_cast(this)->c_str(); } // NOLINT - char* c_str() { - if (isOnStack()) { - return reinterpret_cast(buf); - } - return data.ptr; - } - - size_type size() const; - size_type capacity() const; - - String substr(size_type pos, size_type cnt = npos) &&; - String substr(size_type pos, size_type cnt = npos) 
const &; - - size_type find(char ch, size_type pos = 0) const; - size_type rfind(char ch, size_type pos = npos) const; - - int compare(const char* other, bool no_case = false) const; - int compare(const String& other, bool no_case = false) const; - -friend DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); -}; - -DOCTEST_INTERFACE String operator+(const String& lhs, const String& rhs); - -DOCTEST_INTERFACE bool operator==(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator!=(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator<(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator>(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator<=(const String& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator>=(const String& lhs, const String& rhs); - -class DOCTEST_INTERFACE Contains { -public: - explicit Contains(const String& string); - - bool checkWith(const String& other) const; - - String string; -}; - -DOCTEST_INTERFACE String toString(const Contains& in); - -DOCTEST_INTERFACE bool operator==(const String& lhs, const Contains& rhs); -DOCTEST_INTERFACE bool operator==(const Contains& lhs, const String& rhs); -DOCTEST_INTERFACE bool operator!=(const String& lhs, const Contains& rhs); -DOCTEST_INTERFACE bool operator!=(const Contains& lhs, const String& rhs); - -namespace Color { - enum Enum - { - None = 0, - White, - Red, - Green, - Blue, - Cyan, - Yellow, - Grey, - - Bright = 0x10, - - BrightRed = Bright | Red, - BrightGreen = Bright | Green, - LightGrey = Bright | Grey, - BrightWhite = Bright | White - }; - - DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, Color::Enum code); -} // namespace Color - -namespace assertType { - enum Enum - { - // macro traits - - is_warn = 1, - is_check = 2 * is_warn, - is_require = 2 * is_check, - - is_normal = 2 * is_require, - is_throws = 2 * is_normal, - is_throws_as = 2 * is_throws, - is_throws_with = 2 * 
is_throws_as, - is_nothrow = 2 * is_throws_with, - - is_false = 2 * is_nothrow, - is_unary = 2 * is_false, // not checked anywhere - used just to distinguish the types - - is_eq = 2 * is_unary, - is_ne = 2 * is_eq, - - is_lt = 2 * is_ne, - is_gt = 2 * is_lt, - - is_ge = 2 * is_gt, - is_le = 2 * is_ge, - - // macro types - - DT_WARN = is_normal | is_warn, - DT_CHECK = is_normal | is_check, - DT_REQUIRE = is_normal | is_require, - - DT_WARN_FALSE = is_normal | is_false | is_warn, - DT_CHECK_FALSE = is_normal | is_false | is_check, - DT_REQUIRE_FALSE = is_normal | is_false | is_require, - - DT_WARN_THROWS = is_throws | is_warn, - DT_CHECK_THROWS = is_throws | is_check, - DT_REQUIRE_THROWS = is_throws | is_require, - - DT_WARN_THROWS_AS = is_throws_as | is_warn, - DT_CHECK_THROWS_AS = is_throws_as | is_check, - DT_REQUIRE_THROWS_AS = is_throws_as | is_require, - - DT_WARN_THROWS_WITH = is_throws_with | is_warn, - DT_CHECK_THROWS_WITH = is_throws_with | is_check, - DT_REQUIRE_THROWS_WITH = is_throws_with | is_require, - - DT_WARN_THROWS_WITH_AS = is_throws_with | is_throws_as | is_warn, - DT_CHECK_THROWS_WITH_AS = is_throws_with | is_throws_as | is_check, - DT_REQUIRE_THROWS_WITH_AS = is_throws_with | is_throws_as | is_require, - - DT_WARN_NOTHROW = is_nothrow | is_warn, - DT_CHECK_NOTHROW = is_nothrow | is_check, - DT_REQUIRE_NOTHROW = is_nothrow | is_require, - - DT_WARN_EQ = is_normal | is_eq | is_warn, - DT_CHECK_EQ = is_normal | is_eq | is_check, - DT_REQUIRE_EQ = is_normal | is_eq | is_require, - - DT_WARN_NE = is_normal | is_ne | is_warn, - DT_CHECK_NE = is_normal | is_ne | is_check, - DT_REQUIRE_NE = is_normal | is_ne | is_require, - - DT_WARN_GT = is_normal | is_gt | is_warn, - DT_CHECK_GT = is_normal | is_gt | is_check, - DT_REQUIRE_GT = is_normal | is_gt | is_require, - - DT_WARN_LT = is_normal | is_lt | is_warn, - DT_CHECK_LT = is_normal | is_lt | is_check, - DT_REQUIRE_LT = is_normal | is_lt | is_require, - - DT_WARN_GE = is_normal | is_ge | is_warn, - 
DT_CHECK_GE = is_normal | is_ge | is_check, - DT_REQUIRE_GE = is_normal | is_ge | is_require, - - DT_WARN_LE = is_normal | is_le | is_warn, - DT_CHECK_LE = is_normal | is_le | is_check, - DT_REQUIRE_LE = is_normal | is_le | is_require, - - DT_WARN_UNARY = is_normal | is_unary | is_warn, - DT_CHECK_UNARY = is_normal | is_unary | is_check, - DT_REQUIRE_UNARY = is_normal | is_unary | is_require, - - DT_WARN_UNARY_FALSE = is_normal | is_false | is_unary | is_warn, - DT_CHECK_UNARY_FALSE = is_normal | is_false | is_unary | is_check, - DT_REQUIRE_UNARY_FALSE = is_normal | is_false | is_unary | is_require, - }; -} // namespace assertType - -DOCTEST_INTERFACE const char* assertString(assertType::Enum at); -DOCTEST_INTERFACE const char* failureString(assertType::Enum at); -DOCTEST_INTERFACE const char* skipPathFromFilename(const char* file); - -struct DOCTEST_INTERFACE TestCaseData -{ - String m_file; // the file in which the test was registered (using String - see #350) - unsigned m_line; // the line where the test was registered - const char* m_name; // name of the test case - const char* m_test_suite; // the test suite in which the test was added - const char* m_description; - bool m_skip; - bool m_no_breaks; - bool m_no_output; - bool m_may_fail; - bool m_should_fail; - int m_expected_failures; - double m_timeout; -}; - -struct DOCTEST_INTERFACE AssertData -{ - // common - for all asserts - const TestCaseData* m_test_case; - assertType::Enum m_at; - const char* m_file; - int m_line; - const char* m_expr; - bool m_failed; - - // exception-related - for all asserts - bool m_threw; - String m_exception; - - // for normal asserts - String m_decomp; - - // for specific exception-related asserts - bool m_threw_as; - const char* m_exception_type; - - class DOCTEST_INTERFACE StringContains { - private: - Contains content; - bool isContains; - - public: - StringContains(const String& str) : content(str), isContains(false) { } - StringContains(Contains cntn) : 
content(static_cast(cntn)), isContains(true) { } - - bool check(const String& str) { return isContains ? (content == str) : (content.string == str); } - - operator const String&() const { return content.string; } - - const char* c_str() const { return content.string.c_str(); } - } m_exception_string; - - AssertData(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const StringContains& exception_string); -}; - -struct DOCTEST_INTERFACE MessageData -{ - String m_string; - const char* m_file; - int m_line; - assertType::Enum m_severity; -}; - -struct DOCTEST_INTERFACE SubcaseSignature -{ - String m_name; - const char* m_file; - int m_line; - - bool operator==(const SubcaseSignature& other) const; - bool operator<(const SubcaseSignature& other) const; -}; - -struct DOCTEST_INTERFACE IContextScope -{ - DOCTEST_DECLARE_INTERFACE(IContextScope) - virtual void stringify(std::ostream*) const = 0; -}; - -namespace detail { - struct DOCTEST_INTERFACE TestCase; -} // namespace detail - -struct ContextOptions //!OCLINT too many fields -{ - std::ostream* cout = nullptr; // stdout stream - String binary_name; // the test binary name - - const detail::TestCase* currentTest = nullptr; - - // == parameters from the command line - String out; // output filename - String order_by; // how tests should be ordered - unsigned rand_seed; // the seed for rand ordering - - unsigned first; // the first (matching) test to be executed - unsigned last; // the last (matching) test to be executed - - int abort_after; // stop tests after this many failed assertions - int subcase_filter_levels; // apply the subcase filters for the first N levels - - bool success; // include successful assertions in output - bool case_sensitive; // if filtering should be case sensitive - bool exit; // if the program should be exited after the tests are ran/whatever - bool duration; // print the time duration of each test case - bool minimal; // minimal console output 
(only test failures) - bool quiet; // no console output - bool no_throw; // to skip exceptions-related assertion macros - bool no_exitcode; // if the framework should return 0 as the exitcode - bool no_run; // to not run the tests at all (can be done with an "*" exclude) - bool no_intro; // to not print the intro of the framework - bool no_version; // to not print the version of the framework - bool no_colors; // if output to the console should be colorized - bool force_colors; // forces the use of colors even when a tty cannot be detected - bool no_breaks; // to not break into the debugger - bool no_skip; // don't skip test cases which are marked to be skipped - bool gnu_file_line; // if line numbers should be surrounded with :x: and not (x): - bool no_path_in_filenames; // if the path to files should be removed from the output - bool no_line_numbers; // if source code line numbers should be omitted from the output - bool no_debug_output; // no output in the debug console when a debugger is attached - bool no_skipped_summary; // don't print "skipped" in the summary !!! UNDOCUMENTED !!! - bool no_time_in_output; // omit any time/timestamps from output !!! UNDOCUMENTED !!! 
- - bool help; // to print the help - bool version; // to print the version - bool count; // if only the count of matching tests is to be retrieved - bool list_test_cases; // to list all tests matching the filters - bool list_test_suites; // to list all suites matching the filters - bool list_reporters; // lists all registered reporters -}; - -namespace detail { - namespace types { -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - using namespace std; -#else - template - struct enable_if { }; - - template - struct enable_if { using type = T; }; - - struct true_type { static DOCTEST_CONSTEXPR bool value = true; }; - struct false_type { static DOCTEST_CONSTEXPR bool value = false; }; - - template struct remove_reference { using type = T; }; - template struct remove_reference { using type = T; }; - template struct remove_reference { using type = T; }; - - template struct is_rvalue_reference : false_type { }; - template struct is_rvalue_reference : true_type { }; - - template struct remove_const { using type = T; }; - template struct remove_const { using type = T; }; - - // Compiler intrinsics - template struct is_enum { static DOCTEST_CONSTEXPR bool value = __is_enum(T); }; - template struct underlying_type { using type = __underlying_type(T); }; - - template struct is_pointer : false_type { }; - template struct is_pointer : true_type { }; - - template struct is_array : false_type { }; - // NOLINTNEXTLINE(*-avoid-c-arrays) - template struct is_array : true_type { }; -#endif - } - - // - template - T&& declval(); - - template - DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type& t) DOCTEST_NOEXCEPT { - return static_cast(t); - } - - template - DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type&& t) DOCTEST_NOEXCEPT { - return static_cast(t); - } - - template - struct deferred_false : types::false_type { }; - -// MSVS 2015 :( -#if defined(_MSC_VER) && _MSC_VER <= 1900 - template - struct has_global_insertion_operator : 
types::false_type { }; - - template - struct has_global_insertion_operator(), declval()), void())> : types::true_type { }; - - template - struct has_insertion_operator { static DOCTEST_CONSTEXPR bool value = has_global_insertion_operator::value; }; - - template - struct insert_hack; - - template - struct insert_hack { - static void insert(std::ostream& os, const T& t) { ::operator<<(os, t); } - }; - - template - struct insert_hack { - static void insert(std::ostream& os, const T& t) { operator<<(os, t); } - }; - - template - using insert_hack_t = insert_hack::value>; -#else - template - struct has_insertion_operator : types::false_type { }; -#endif - -template -struct has_insertion_operator(), declval()), void())> : types::true_type { }; - - DOCTEST_INTERFACE std::ostream* tlssPush(); - DOCTEST_INTERFACE String tlssPop(); - - template - struct StringMakerBase { - template - static String convert(const DOCTEST_REF_WRAP(T)) { -#ifdef DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES - static_assert(deferred_false::value, "No stringification detected for type T. See string conversion manual"); -#endif - return "{?}"; - } - }; - - template - struct filldata; - - template - void filloss(std::ostream* stream, const T& in) { - filldata::fill(stream, in); - } - - template - void filloss(std::ostream* stream, const T (&in)[N]) { // NOLINT(*-avoid-c-arrays) - // T[N], T(&)[N], T(&&)[N] have same behaviour. - // Hence remove reference. 
- filloss::type>(stream, in); - } - - template - String toStream(const T& in) { - std::ostream* stream = tlssPush(); - filloss(stream, in); - return tlssPop(); - } - - template <> - struct StringMakerBase { - template - static String convert(const DOCTEST_REF_WRAP(T) in) { - return toStream(in); - } - }; -} // namespace detail - -template -struct StringMaker : public detail::StringMakerBase< - detail::has_insertion_operator::value || detail::types::is_pointer::value || detail::types::is_array::value> -{}; - -#ifndef DOCTEST_STRINGIFY -#ifdef DOCTEST_CONFIG_DOUBLE_STRINGIFY -#define DOCTEST_STRINGIFY(...) toString(toString(__VA_ARGS__)) -#else -#define DOCTEST_STRINGIFY(...) toString(__VA_ARGS__) -#endif -#endif - -template -String toString() { -#if DOCTEST_MSVC >= 0 && DOCTEST_CLANG == 0 && DOCTEST_GCC == 0 - String ret = __FUNCSIG__; // class doctest::String __cdecl doctest::toString(void) - String::size_type beginPos = ret.find('<'); - return ret.substr(beginPos + 1, ret.size() - beginPos - static_cast(sizeof(">(void)"))); -#else - String ret = __PRETTY_FUNCTION__; // doctest::String toString() [with T = TYPE] - String::size_type begin = ret.find('=') + 2; - return ret.substr(begin, ret.size() - begin - 1); -#endif -} - -template ::value, bool>::type = true> -String toString(const DOCTEST_REF_WRAP(T) value) { - return StringMaker::convert(value); -} - -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -DOCTEST_INTERFACE String toString(const char* in); -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - -#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 -DOCTEST_INTERFACE String toString(const std::string& in); -#endif // VS 2019 - -DOCTEST_INTERFACE String toString(String in); - -DOCTEST_INTERFACE String toString(std::nullptr_t); - -DOCTEST_INTERFACE String toString(bool in); - -DOCTEST_INTERFACE String toString(float in); -DOCTEST_INTERFACE String toString(double in); 
-DOCTEST_INTERFACE String toString(double long in); - -DOCTEST_INTERFACE String toString(char in); -DOCTEST_INTERFACE String toString(char signed in); -DOCTEST_INTERFACE String toString(char unsigned in); -DOCTEST_INTERFACE String toString(short in); -DOCTEST_INTERFACE String toString(short unsigned in); -DOCTEST_INTERFACE String toString(signed in); -DOCTEST_INTERFACE String toString(unsigned in); -DOCTEST_INTERFACE String toString(long in); -DOCTEST_INTERFACE String toString(long unsigned in); -DOCTEST_INTERFACE String toString(long long in); -DOCTEST_INTERFACE String toString(long long unsigned in); - -template ::value, bool>::type = true> -String toString(const DOCTEST_REF_WRAP(T) value) { - using UT = typename detail::types::underlying_type::type; - return (DOCTEST_STRINGIFY(static_cast(value))); -} - -namespace detail { - template - struct filldata - { - static void fill(std::ostream* stream, const T& in) { -#if defined(_MSC_VER) && _MSC_VER <= 1900 - insert_hack_t::insert(*stream, in); -#else - operator<<(*stream, in); -#endif - } - }; - -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) -// NOLINTBEGIN(*-avoid-c-arrays) - template - struct filldata { - static void fill(std::ostream* stream, const T(&in)[N]) { - *stream << "["; - for (size_t i = 0; i < N; i++) { - if (i != 0) { *stream << ", "; } - *stream << (DOCTEST_STRINGIFY(in[i])); - } - *stream << "]"; - } - }; -// NOLINTEND(*-avoid-c-arrays) -DOCTEST_MSVC_SUPPRESS_WARNING_POP - - // Specialized since we don't want the terminating null byte! -// NOLINTBEGIN(*-avoid-c-arrays) - template - struct filldata { - static void fill(std::ostream* stream, const char (&in)[N]) { - *stream << String(in, in[N - 1] ? 
N : N - 1); - } // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - }; -// NOLINTEND(*-avoid-c-arrays) - - template <> - struct filldata { - static void fill(std::ostream* stream, const void* in); - }; - - template - struct filldata { - static void fill(std::ostream* stream, const T* in) { - filldata::fill(stream, in); - } - }; -} - -struct DOCTEST_INTERFACE Approx -{ - Approx(double value); - - Approx operator()(double value) const; - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - template - explicit Approx(const T& value, - typename detail::types::enable_if::value>::type* = - static_cast(nullptr)) { - *this = static_cast(value); - } -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - Approx& epsilon(double newEpsilon); - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - template - typename std::enable_if::value, Approx&>::type epsilon( - const T& newEpsilon) { - m_epsilon = static_cast(newEpsilon); - return *this; - } -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - Approx& scale(double newScale); - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - template - typename std::enable_if::value, Approx&>::type scale( - const T& newScale) { - m_scale = static_cast(newScale); - return *this; - } -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - // clang-format off - DOCTEST_INTERFACE friend bool operator==(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator==(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator!=(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator!=(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator<=(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator<=(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator>=(double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator>=(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator< (double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator< 
(const Approx & lhs, double rhs); - DOCTEST_INTERFACE friend bool operator> (double lhs, const Approx & rhs); - DOCTEST_INTERFACE friend bool operator> (const Approx & lhs, double rhs); - -#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS -#define DOCTEST_APPROX_PREFIX \ - template friend typename std::enable_if::value, bool>::type - - DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(static_cast(lhs), rhs); } - DOCTEST_APPROX_PREFIX operator==(const Approx& lhs, const T& rhs) { return operator==(rhs, lhs); } - DOCTEST_APPROX_PREFIX operator!=(const T& lhs, const Approx& rhs) { return !operator==(lhs, rhs); } - DOCTEST_APPROX_PREFIX operator!=(const Approx& lhs, const T& rhs) { return !operator==(rhs, lhs); } - DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) || lhs == rhs; } - DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value && lhs != rhs; } - DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) && lhs != rhs; } -#undef DOCTEST_APPROX_PREFIX -#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS - - // clang-format on - - double m_epsilon; - double m_scale; - double m_value; -}; - -DOCTEST_INTERFACE String toString(const Approx& in); - -DOCTEST_INTERFACE const 
ContextOptions* getContextOptions(); - -template -struct DOCTEST_INTERFACE_DECL IsNaN -{ - F value; bool flipped; - IsNaN(F f, bool flip = false) : value(f), flipped(flip) { } - IsNaN operator!() const { return { value, !flipped }; } - operator bool() const; -}; -#ifndef __MINGW32__ -extern template struct DOCTEST_INTERFACE_DECL IsNaN; -extern template struct DOCTEST_INTERFACE_DECL IsNaN; -extern template struct DOCTEST_INTERFACE_DECL IsNaN; -#endif -DOCTEST_INTERFACE String toString(IsNaN in); -DOCTEST_INTERFACE String toString(IsNaN in); -DOCTEST_INTERFACE String toString(IsNaN in); - -#ifndef DOCTEST_CONFIG_DISABLE - -namespace detail { - // clang-format off -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - template struct decay_array { using type = T; }; - template struct decay_array { using type = T*; }; - template struct decay_array { using type = T*; }; - - template struct not_char_pointer { static DOCTEST_CONSTEXPR value = 1; }; - template<> struct not_char_pointer { static DOCTEST_CONSTEXPR value = 0; }; - template<> struct not_char_pointer { static DOCTEST_CONSTEXPR value = 0; }; - - template struct can_use_op : public not_char_pointer::type> {}; -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - // clang-format on - - struct DOCTEST_INTERFACE TestFailureException - { - }; - - DOCTEST_INTERFACE bool checkIfShouldThrow(assertType::Enum at); - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - DOCTEST_NORETURN -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - DOCTEST_INTERFACE void throwException(); - - struct DOCTEST_INTERFACE Subcase - { - SubcaseSignature m_signature; - bool m_entered = false; - - Subcase(const String& name, const char* file, int line); - Subcase(const Subcase&) = delete; - Subcase(Subcase&&) = delete; - Subcase& operator=(const Subcase&) = delete; - Subcase& operator=(Subcase&&) = delete; - ~Subcase(); - - operator bool() const; - - private: - bool checkFilters(); - }; - - template - String stringifyBinaryExpr(const DOCTEST_REF_WRAP(L) lhs, const 
char* op, - const DOCTEST_REF_WRAP(R) rhs) { - return (DOCTEST_STRINGIFY(lhs)) + op + (DOCTEST_STRINGIFY(rhs)); - } - -#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) -DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") -#endif - -// This will check if there is any way it could find a operator like member or friend and uses it. -// If not it doesn't find the operator or if the operator at global scope is defined after -// this template, the template won't be instantiated due to SFINAE. Once the template is not -// instantiated it can look for global operator using normal conversions. -#define SFINAE_OP(ret,op) decltype((void)(doctest::detail::declval() op doctest::detail::declval()),ret{}) - -#define DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(op, op_str, op_macro) \ - template \ - DOCTEST_NOINLINE SFINAE_OP(Result,op) operator op(R&& rhs) { \ - bool res = op_macro(doctest::detail::forward(lhs), doctest::detail::forward(rhs)); \ - if(m_at & assertType::is_false) \ - res = !res; \ - if(!res || doctest::getContextOptions()->success) \ - return Result(res, stringifyBinaryExpr(lhs, op_str, rhs)); \ - return Result(res); \ - } - - // more checks could be added - like in Catch: - // https://github.com/catchorg/Catch2/pull/1480/files - // https://github.com/catchorg/Catch2/pull/1481/files -#define DOCTEST_FORBIT_EXPRESSION(rt, op) \ - template \ - rt& operator op(const R&) { \ - static_assert(deferred_false::value, \ - "Expression Too Complex Please Rewrite As Binary Comparison!"); \ - return *this; \ - } - - struct DOCTEST_INTERFACE Result // NOLINT(*-member-init) - { - bool m_passed; - String m_decomp; - - Result() = default; // TODO: Why do we need this? 
(To remove NOLINT) - Result(bool passed, const String& decomposition = String()); - - // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence - DOCTEST_FORBIT_EXPRESSION(Result, &) - DOCTEST_FORBIT_EXPRESSION(Result, ^) - DOCTEST_FORBIT_EXPRESSION(Result, |) - DOCTEST_FORBIT_EXPRESSION(Result, &&) - DOCTEST_FORBIT_EXPRESSION(Result, ||) - DOCTEST_FORBIT_EXPRESSION(Result, ==) - DOCTEST_FORBIT_EXPRESSION(Result, !=) - DOCTEST_FORBIT_EXPRESSION(Result, <) - DOCTEST_FORBIT_EXPRESSION(Result, >) - DOCTEST_FORBIT_EXPRESSION(Result, <=) - DOCTEST_FORBIT_EXPRESSION(Result, >=) - DOCTEST_FORBIT_EXPRESSION(Result, =) - DOCTEST_FORBIT_EXPRESSION(Result, +=) - DOCTEST_FORBIT_EXPRESSION(Result, -=) - DOCTEST_FORBIT_EXPRESSION(Result, *=) - DOCTEST_FORBIT_EXPRESSION(Result, /=) - DOCTEST_FORBIT_EXPRESSION(Result, %=) - DOCTEST_FORBIT_EXPRESSION(Result, <<=) - DOCTEST_FORBIT_EXPRESSION(Result, >>=) - DOCTEST_FORBIT_EXPRESSION(Result, &=) - DOCTEST_FORBIT_EXPRESSION(Result, ^=) - DOCTEST_FORBIT_EXPRESSION(Result, |=) - }; - -#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - - DOCTEST_CLANG_SUPPRESS_WARNING_PUSH - DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") - DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-compare") - //DOCTEST_CLANG_SUPPRESS_WARNING("-Wdouble-promotion") - //DOCTEST_CLANG_SUPPRESS_WARNING("-Wconversion") - //DOCTEST_CLANG_SUPPRESS_WARNING("-Wfloat-equal") - - DOCTEST_GCC_SUPPRESS_WARNING_PUSH - DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") - DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-compare") - //DOCTEST_GCC_SUPPRESS_WARNING("-Wdouble-promotion") - //DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") - //DOCTEST_GCC_SUPPRESS_WARNING("-Wfloat-equal") - - DOCTEST_MSVC_SUPPRESS_WARNING_PUSH - // https://stackoverflow.com/questions/39479163 what's the difference between 4018 and 4389 - DOCTEST_MSVC_SUPPRESS_WARNING(4388) // signed/unsigned mismatch - DOCTEST_MSVC_SUPPRESS_WARNING(4389) // 'operator' : 
signed/unsigned mismatch - DOCTEST_MSVC_SUPPRESS_WARNING(4018) // 'expression' : signed/unsigned mismatch - //DOCTEST_MSVC_SUPPRESS_WARNING(4805) // 'operation' : unsafe mix of type 'type' and type 'type' in operation - -#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - - // clang-format off -#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_COMPARISON_RETURN_TYPE bool -#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_COMPARISON_RETURN_TYPE typename types::enable_if::value || can_use_op::value, bool>::type - inline bool eq(const char* lhs, const char* rhs) { return String(lhs) == String(rhs); } - inline bool ne(const char* lhs, const char* rhs) { return String(lhs) != String(rhs); } - inline bool lt(const char* lhs, const char* rhs) { return String(lhs) < String(rhs); } - inline bool gt(const char* lhs, const char* rhs) { return String(lhs) > String(rhs); } - inline bool le(const char* lhs, const char* rhs) { return String(lhs) <= String(rhs); } - inline bool ge(const char* lhs, const char* rhs) { return String(lhs) >= String(rhs); } -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - // clang-format on - -#define DOCTEST_RELATIONAL_OP(name, op) \ - template \ - DOCTEST_COMPARISON_RETURN_TYPE name(const DOCTEST_REF_WRAP(L) lhs, \ - const DOCTEST_REF_WRAP(R) rhs) { \ - return lhs op rhs; \ - } - - DOCTEST_RELATIONAL_OP(eq, ==) - DOCTEST_RELATIONAL_OP(ne, !=) - DOCTEST_RELATIONAL_OP(lt, <) - DOCTEST_RELATIONAL_OP(gt, >) - DOCTEST_RELATIONAL_OP(le, <=) - DOCTEST_RELATIONAL_OP(ge, >=) - -#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_CMP_EQ(l, r) l == r -#define DOCTEST_CMP_NE(l, r) l != r -#define DOCTEST_CMP_GT(l, r) l > r -#define DOCTEST_CMP_LT(l, r) l < r -#define DOCTEST_CMP_GE(l, r) l >= r -#define DOCTEST_CMP_LE(l, r) l <= r -#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -#define DOCTEST_CMP_EQ(l, r) eq(l, r) -#define DOCTEST_CMP_NE(l, r) ne(l, r) -#define DOCTEST_CMP_GT(l, r) gt(l, r) 
-#define DOCTEST_CMP_LT(l, r) lt(l, r) -#define DOCTEST_CMP_GE(l, r) ge(l, r) -#define DOCTEST_CMP_LE(l, r) le(l, r) -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - - template - // cppcheck-suppress copyCtorAndEqOperator - struct Expression_lhs - { - L lhs; - assertType::Enum m_at; - - explicit Expression_lhs(L&& in, assertType::Enum at) - : lhs(static_cast(in)) - , m_at(at) {} - - DOCTEST_NOINLINE operator Result() { -// this is needed only for MSVC 2015 -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4800) // 'int': forcing value to bool - bool res = static_cast(lhs); -DOCTEST_MSVC_SUPPRESS_WARNING_POP - if(m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional - res = !res; - } - - if(!res || getContextOptions()->success) { - return { res, (DOCTEST_STRINGIFY(lhs)) }; - } - return { res }; - } - - /* This is required for user-defined conversions from Expression_lhs to L */ - operator L() const { return lhs; } - - // clang-format off - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(==, " == ", DOCTEST_CMP_EQ) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(!=, " != ", DOCTEST_CMP_NE) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>, " > ", DOCTEST_CMP_GT) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<, " < ", DOCTEST_CMP_LT) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>=, " >= ", DOCTEST_CMP_GE) //!OCLINT bitwise operator in conditional - DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<=, " <= ", DOCTEST_CMP_LE) //!OCLINT bitwise operator in conditional - // clang-format on - - // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &&) - 
DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ||) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, =) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, +=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, -=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, *=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, /=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, %=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^=) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |=) - // these 2 are unfortunate because they should be allowed - they have higher precedence over the comparisons, but the - // ExpressionDecomposer class uses the left shift operator to capture the left operand of the binary expression... - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<) - DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>) - }; - -#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_MSVC_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP - -#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION - -#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) -DOCTEST_CLANG_SUPPRESS_WARNING_POP -#endif - - struct DOCTEST_INTERFACE ExpressionDecomposer - { - assertType::Enum m_at; - - ExpressionDecomposer(assertType::Enum at); - - // The right operator for capturing expressions is "<=" instead of "<<" (based on the operator precedence table) - // but then there will be warnings from GCC about "-Wparentheses" and since "_Pragma()" is problematic this will stay for now... 
- // https://github.com/catchorg/Catch2/issues/870 - // https://github.com/catchorg/Catch2/issues/565 - template - Expression_lhs operator<<(L&& operand) { - return Expression_lhs(static_cast(operand), m_at); - } - - template ::value,void >::type* = nullptr> - Expression_lhs operator<<(const L &operand) { - return Expression_lhs(operand, m_at); - } - }; - - struct DOCTEST_INTERFACE TestSuite - { - const char* m_test_suite = nullptr; - const char* m_description = nullptr; - bool m_skip = false; - bool m_no_breaks = false; - bool m_no_output = false; - bool m_may_fail = false; - bool m_should_fail = false; - int m_expected_failures = 0; - double m_timeout = 0; - - TestSuite& operator*(const char* in); - - template - TestSuite& operator*(const T& in) { - in.fill(*this); - return *this; - } - }; - - using funcType = void (*)(); - - struct DOCTEST_INTERFACE TestCase : public TestCaseData - { - funcType m_test; // a function pointer to the test case - - String m_type; // for templated test cases - gets appended to the real name - int m_template_id; // an ID used to distinguish between the different versions of a templated test case - String m_full_name; // contains the name (only for templated test cases!) 
+ the template type - - TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, - const String& type = String(), int template_id = -1); - - TestCase(const TestCase& other); - TestCase(TestCase&&) = delete; - - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function - TestCase& operator=(const TestCase& other); - DOCTEST_MSVC_SUPPRESS_WARNING_POP - - TestCase& operator=(TestCase&&) = delete; - - TestCase& operator*(const char* in); - - template - TestCase& operator*(const T& in) { - in.fill(*this); - return *this; - } - - bool operator<(const TestCase& other) const; - - ~TestCase() = default; - }; - - // forward declarations of functions used by the macros - DOCTEST_INTERFACE int regTest(const TestCase& tc); - DOCTEST_INTERFACE int setTestSuite(const TestSuite& ts); - DOCTEST_INTERFACE bool isDebuggerActive(); - - template - int instantiationHelper(const T&) { return 0; } - - namespace binaryAssertComparison { - enum Enum - { - eq = 0, - ne, - gt, - lt, - ge, - le - }; - } // namespace binaryAssertComparison - - // clang-format off - template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L), const DOCTEST_REF_WRAP(R) ) const { return false; } }; - -#define DOCTEST_BINARY_RELATIONAL_OP(n, op) \ - template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) const { return op(lhs, rhs); } }; - // clang-format on - - DOCTEST_BINARY_RELATIONAL_OP(0, doctest::detail::eq) - DOCTEST_BINARY_RELATIONAL_OP(1, doctest::detail::ne) - DOCTEST_BINARY_RELATIONAL_OP(2, doctest::detail::gt) - DOCTEST_BINARY_RELATIONAL_OP(3, doctest::detail::lt) - DOCTEST_BINARY_RELATIONAL_OP(4, doctest::detail::ge) - DOCTEST_BINARY_RELATIONAL_OP(5, doctest::detail::le) - - struct DOCTEST_INTERFACE ResultBuilder : public AssertData - { - ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type = "", const String& 
exception_string = ""); - - ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const Contains& exception_string); - - void setResult(const Result& res); - - template - DOCTEST_NOINLINE bool binary_assert(const DOCTEST_REF_WRAP(L) lhs, - const DOCTEST_REF_WRAP(R) rhs) { - m_failed = !RelationalComparator()(lhs, rhs); - if (m_failed || getContextOptions()->success) { - m_decomp = stringifyBinaryExpr(lhs, ", ", rhs); - } - return !m_failed; - } - - template - DOCTEST_NOINLINE bool unary_assert(const DOCTEST_REF_WRAP(L) val) { - m_failed = !val; - - if (m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional - m_failed = !m_failed; - } - - if (m_failed || getContextOptions()->success) { - m_decomp = (DOCTEST_STRINGIFY(val)); - } - - return !m_failed; - } - - void translateException(); - - bool log(); - void react() const; - }; - - namespace assertAction { - enum Enum - { - nothing = 0, - dbgbreak = 1, - shouldthrow = 2 - }; - } // namespace assertAction - - DOCTEST_INTERFACE void failed_out_of_a_testing_context(const AssertData& ad); - - DOCTEST_INTERFACE bool decomp_assert(assertType::Enum at, const char* file, int line, - const char* expr, const Result& result); - -#define DOCTEST_ASSERT_OUT_OF_TESTS(decomp) \ - do { \ - if(!is_running_in_test) { \ - if(failed) { \ - ResultBuilder rb(at, file, line, expr); \ - rb.m_failed = failed; \ - rb.m_decomp = decomp; \ - failed_out_of_a_testing_context(rb); \ - if(isDebuggerActive() && !getContextOptions()->no_breaks) \ - DOCTEST_BREAK_INTO_DEBUGGER(); \ - if(checkIfShouldThrow(at)) \ - throwException(); \ - } \ - return !failed; \ - } \ - } while(false) - -#define DOCTEST_ASSERT_IN_TESTS(decomp) \ - ResultBuilder rb(at, file, line, expr); \ - rb.m_failed = failed; \ - if(rb.m_failed || getContextOptions()->success) \ - rb.m_decomp = decomp; \ - if(rb.log()) \ - DOCTEST_BREAK_INTO_DEBUGGER(); \ - if(rb.m_failed && checkIfShouldThrow(at)) \ - 
throwException() - - template - DOCTEST_NOINLINE bool binary_assert(assertType::Enum at, const char* file, int line, - const char* expr, const DOCTEST_REF_WRAP(L) lhs, - const DOCTEST_REF_WRAP(R) rhs) { - bool failed = !RelationalComparator()(lhs, rhs); - - // ################################################################################### - // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT - // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED - // ################################################################################### - DOCTEST_ASSERT_OUT_OF_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); - DOCTEST_ASSERT_IN_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); - return !failed; - } - - template - DOCTEST_NOINLINE bool unary_assert(assertType::Enum at, const char* file, int line, - const char* expr, const DOCTEST_REF_WRAP(L) val) { - bool failed = !val; - - if(at & assertType::is_false) //!OCLINT bitwise operator in conditional - failed = !failed; - - // ################################################################################### - // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT - // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED - // ################################################################################### - DOCTEST_ASSERT_OUT_OF_TESTS((DOCTEST_STRINGIFY(val))); - DOCTEST_ASSERT_IN_TESTS((DOCTEST_STRINGIFY(val))); - return !failed; - } - - struct DOCTEST_INTERFACE IExceptionTranslator - { - DOCTEST_DECLARE_INTERFACE(IExceptionTranslator) - virtual bool translate(String&) const = 0; - }; - - template - class ExceptionTranslator : public IExceptionTranslator //!OCLINT destructor of virtual class - { - public: - explicit ExceptionTranslator(String (*translateFunction)(T)) - : m_translateFunction(translateFunction) {} - - bool translate(String& res) const override { -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - try { - throw; // 
lgtm [cpp/rethrow-no-exception] - // cppcheck-suppress catchExceptionByValue - } catch(const T& ex) { - res = m_translateFunction(ex); //!OCLINT parameter reassignment - return true; - } catch(...) {} //!OCLINT - empty catch statement -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - static_cast(res); // to silence -Wunused-parameter - return false; - } - - private: - String (*m_translateFunction)(T); - }; - - DOCTEST_INTERFACE void registerExceptionTranslatorImpl(const IExceptionTranslator* et); - - // ContextScope base class used to allow implementing methods of ContextScope - // that don't depend on the template parameter in doctest.cpp. - struct DOCTEST_INTERFACE ContextScopeBase : public IContextScope { - ContextScopeBase(const ContextScopeBase&) = delete; - - ContextScopeBase& operator=(const ContextScopeBase&) = delete; - ContextScopeBase& operator=(ContextScopeBase&&) = delete; - - ~ContextScopeBase() override = default; - - protected: - ContextScopeBase(); - ContextScopeBase(ContextScopeBase&& other) noexcept; - - void destroy(); - bool need_to_destroy{true}; - }; - - template class ContextScope : public ContextScopeBase - { - L lambda_; - - public: - explicit ContextScope(const L &lambda) : lambda_(lambda) {} - explicit ContextScope(L&& lambda) : lambda_(static_cast(lambda)) { } - - ContextScope(const ContextScope&) = delete; - ContextScope(ContextScope&&) noexcept = default; - - ContextScope& operator=(const ContextScope&) = delete; - ContextScope& operator=(ContextScope&&) = delete; - - void stringify(std::ostream* s) const override { lambda_(s); } - - ~ContextScope() override { - if (need_to_destroy) { - destroy(); - } - } - }; - - struct DOCTEST_INTERFACE MessageBuilder : public MessageData - { - std::ostream* m_stream; - bool logged = false; - - MessageBuilder(const char* file, int line, assertType::Enum severity); - - MessageBuilder(const MessageBuilder&) = delete; - MessageBuilder(MessageBuilder&&) = delete; - - MessageBuilder& operator=(const 
MessageBuilder&) = delete; - MessageBuilder& operator=(MessageBuilder&&) = delete; - - ~MessageBuilder(); - - // the preferred way of chaining parameters for stringification -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) - template - MessageBuilder& operator,(const T& in) { - *m_stream << (DOCTEST_STRINGIFY(in)); - return *this; - } -DOCTEST_MSVC_SUPPRESS_WARNING_POP - - // kept here just for backwards-compatibility - the comma operator should be preferred now - template - MessageBuilder& operator<<(const T& in) { return this->operator,(in); } - - // the `,` operator has the lowest operator precedence - if `<<` is used by the user then - // the `,` operator will be called last which is not what we want and thus the `*` operator - // is used first (has higher operator precedence compared to `<<`) so that we guarantee that - // an operator of the MessageBuilder class is called first before the rest of the parameters - template - MessageBuilder& operator*(const T& in) { return this->operator,(in); } - - bool log(); - void react(); - }; - - template - ContextScope MakeContextScope(const L &lambda) { - return ContextScope(lambda); - } -} // namespace detail - -#define DOCTEST_DEFINE_DECORATOR(name, type, def) \ - struct name \ - { \ - type data; \ - name(type in = def) \ - : data(in) {} \ - void fill(detail::TestCase& state) const { state.DOCTEST_CAT(m_, name) = data; } \ - void fill(detail::TestSuite& state) const { state.DOCTEST_CAT(m_, name) = data; } \ - } - -DOCTEST_DEFINE_DECORATOR(test_suite, const char*, ""); -DOCTEST_DEFINE_DECORATOR(description, const char*, ""); -DOCTEST_DEFINE_DECORATOR(skip, bool, true); -DOCTEST_DEFINE_DECORATOR(no_breaks, bool, true); -DOCTEST_DEFINE_DECORATOR(no_output, bool, true); -DOCTEST_DEFINE_DECORATOR(timeout, double, 0); -DOCTEST_DEFINE_DECORATOR(may_fail, bool, true); -DOCTEST_DEFINE_DECORATOR(should_fail, bool, true); -DOCTEST_DEFINE_DECORATOR(expected_failures, int, 0); - -template -int registerExceptionTranslator(String 
(*translateFunction)(T)) { - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") - static detail::ExceptionTranslator exceptionTranslator(translateFunction); - DOCTEST_CLANG_SUPPRESS_WARNING_POP - detail::registerExceptionTranslatorImpl(&exceptionTranslator); - return 0; -} - -} // namespace doctest - -// in a separate namespace outside of doctest because the DOCTEST_TEST_SUITE macro -// introduces an anonymous namespace in which getCurrentTestSuite gets overridden -namespace doctest_detail_test_suite_ns { -DOCTEST_INTERFACE doctest::detail::TestSuite& getCurrentTestSuite(); -} // namespace doctest_detail_test_suite_ns - -namespace doctest { -#else // DOCTEST_CONFIG_DISABLE -template -int registerExceptionTranslator(String (*)(T)) { - return 0; -} -#endif // DOCTEST_CONFIG_DISABLE - -namespace detail { - using assert_handler = void (*)(const AssertData&); - struct ContextState; -} // namespace detail - -class DOCTEST_INTERFACE Context -{ - detail::ContextState* p; - - void parseArgs(int argc, const char* const* argv, bool withDefaults = false); - -public: - explicit Context(int argc = 0, const char* const* argv = nullptr); - - Context(const Context&) = delete; - Context(Context&&) = delete; - - Context& operator=(const Context&) = delete; - Context& operator=(Context&&) = delete; - - ~Context(); // NOLINT(performance-trivially-destructible) - - void applyCommandLine(int argc, const char* const* argv); - - void addFilter(const char* filter, const char* value); - void clearFilters(); - void setOption(const char* option, bool value); - void setOption(const char* option, int value); - void setOption(const char* option, const char* value); - - bool shouldExit(); - - void setAsDefaultForAssertsOutOfTestCases(); - - void setAssertHandler(detail::assert_handler ah); - - void setCout(std::ostream* out); - - int run(); -}; - -namespace TestCaseFailureReason { - enum Enum - { - None = 0, - AssertFailure = 1, // an assertion has failed in the test case - 
Exception = 2, // test case threw an exception - Crash = 4, // a crash... - TooManyFailedAsserts = 8, // the abort-after option - Timeout = 16, // see the timeout decorator - ShouldHaveFailedButDidnt = 32, // see the should_fail decorator - ShouldHaveFailedAndDid = 64, // see the should_fail decorator - DidntFailExactlyNumTimes = 128, // see the expected_failures decorator - FailedExactlyNumTimes = 256, // see the expected_failures decorator - CouldHaveFailedAndDid = 512 // see the may_fail decorator - }; -} // namespace TestCaseFailureReason - -struct DOCTEST_INTERFACE CurrentTestCaseStats -{ - int numAssertsCurrentTest; - int numAssertsFailedCurrentTest; - double seconds; - int failure_flags; // use TestCaseFailureReason::Enum - bool testCaseSuccess; -}; - -struct DOCTEST_INTERFACE TestCaseException -{ - String error_string; - bool is_crash; -}; - -struct DOCTEST_INTERFACE TestRunStats -{ - unsigned numTestCases; - unsigned numTestCasesPassingFilters; - unsigned numTestSuitesPassingFilters; - unsigned numTestCasesFailed; - int numAsserts; - int numAssertsFailed; -}; - -struct QueryData -{ - const TestRunStats* run_stats = nullptr; - const TestCaseData** data = nullptr; - unsigned num_data = 0; -}; - -struct DOCTEST_INTERFACE IReporter -{ - // The constructor has to accept "const ContextOptions&" as a single argument - // which has most of the options for the run + a pointer to the stdout stream - // Reporter(const ContextOptions& in) - - // called when a query should be reported (listing test cases, printing the version, etc.) 
- virtual void report_query(const QueryData&) = 0; - - // called when the whole test run starts - virtual void test_run_start() = 0; - // called when the whole test run ends (caching a pointer to the input doesn't make sense here) - virtual void test_run_end(const TestRunStats&) = 0; - - // called when a test case is started (safe to cache a pointer to the input) - virtual void test_case_start(const TestCaseData&) = 0; - // called when a test case is reentered because of unfinished subcases (safe to cache a pointer to the input) - virtual void test_case_reenter(const TestCaseData&) = 0; - // called when a test case has ended - virtual void test_case_end(const CurrentTestCaseStats&) = 0; - - // called when an exception is thrown from the test case (or it crashes) - virtual void test_case_exception(const TestCaseException&) = 0; - - // called whenever a subcase is entered (don't cache pointers to the input) - virtual void subcase_start(const SubcaseSignature&) = 0; - // called whenever a subcase is exited (don't cache pointers to the input) - virtual void subcase_end() = 0; - - // called for each assert (don't cache pointers to the input) - virtual void log_assert(const AssertData&) = 0; - // called for each message (don't cache pointers to the input) - virtual void log_message(const MessageData&) = 0; - - // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator - // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) - virtual void test_case_skipped(const TestCaseData&) = 0; - - DOCTEST_DECLARE_INTERFACE(IReporter) - - // can obtain all currently active contexts and stringify them if one wishes to do so - static int get_num_active_contexts(); - static const IContextScope* const* get_active_contexts(); - - // can iterate through contexts which have been stringified automatically in their destructors when an exception has been thrown - static int get_num_stringified_contexts(); - 
static const String* get_stringified_contexts(); -}; - -namespace detail { - using reporterCreatorFunc = IReporter* (*)(const ContextOptions&); - - DOCTEST_INTERFACE void registerReporterImpl(const char* name, int prio, reporterCreatorFunc c, bool isReporter); - - template - IReporter* reporterCreator(const ContextOptions& o) { - return new Reporter(o); - } -} // namespace detail - -template -int registerReporter(const char* name, int priority, bool isReporter) { - detail::registerReporterImpl(name, priority, detail::reporterCreator, isReporter); - return 0; -} -} // namespace doctest - -#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES -#define DOCTEST_FUNC_EMPTY [] { return false; }() -#else -#define DOCTEST_FUNC_EMPTY (void)0 -#endif - -// if registering is not disabled -#ifndef DOCTEST_CONFIG_DISABLE - -#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES -#define DOCTEST_FUNC_SCOPE_BEGIN [&] -#define DOCTEST_FUNC_SCOPE_END () -#define DOCTEST_FUNC_SCOPE_RET(v) return v -#else -#define DOCTEST_FUNC_SCOPE_BEGIN do -#define DOCTEST_FUNC_SCOPE_END while(false) -#define DOCTEST_FUNC_SCOPE_RET(v) (void)0 -#endif - -// common code in asserts - for convenience -#define DOCTEST_ASSERT_LOG_REACT_RETURN(b) \ - if(b.log()) DOCTEST_BREAK_INTO_DEBUGGER(); \ - b.react(); \ - DOCTEST_FUNC_SCOPE_RET(!b.m_failed) - -#ifdef DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS -#define DOCTEST_WRAP_IN_TRY(x) x; -#else // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS -#define DOCTEST_WRAP_IN_TRY(x) \ - try { \ - x; \ - } catch(...) { DOCTEST_RB.translateException(); } -#endif // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS - -#ifdef DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS -#define DOCTEST_CAST_TO_VOID(...) \ - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wuseless-cast") \ - static_cast(__VA_ARGS__); \ - DOCTEST_GCC_SUPPRESS_WARNING_POP -#else // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS -#define DOCTEST_CAST_TO_VOID(...) 
__VA_ARGS__; -#endif // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS - -// registers the test by initializing a dummy var with a function -#define DOCTEST_REGISTER_FUNCTION(global_prefix, f, decorators) \ - global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT */ \ - doctest::detail::regTest( \ - doctest::detail::TestCase( \ - f, __FILE__, __LINE__, \ - doctest_detail_test_suite_ns::getCurrentTestSuite()) * \ - decorators)) - -#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, decorators) \ - namespace { /* NOLINT */ \ - struct der : public base \ - { \ - void f(); \ - }; \ - static inline DOCTEST_NOINLINE void func() { \ - der v; \ - v.f(); \ - } \ - DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, func, decorators) \ - } \ - inline DOCTEST_NOINLINE void der::f() // NOLINT(misc-definitions-in-headers) - -#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, decorators) \ - static void f(); \ - DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, f, decorators) \ - static void f() - -#define DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(f, proxy, decorators) \ - static doctest::detail::funcType proxy() { return f; } \ - DOCTEST_REGISTER_FUNCTION(inline, proxy(), decorators) \ - static void f() - -// for registering tests -#define DOCTEST_TEST_CASE(decorators) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) - -// for registering tests in classes - requires C++17 for inline variables! -#if DOCTEST_CPLUSPLUS >= 201703L -#define DOCTEST_TEST_CASE_CLASS(decorators) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_PROXY_), \ - decorators) -#else // DOCTEST_TEST_CASE_CLASS -#define DOCTEST_TEST_CASE_CLASS(...) 
\ - TEST_CASES_CAN_BE_REGISTERED_IN_CLASSES_ONLY_IN_CPP17_MODE_OR_WITH_VS_2017_OR_NEWER -#endif // DOCTEST_TEST_CASE_CLASS - -// for registering tests with a fixture -#define DOCTEST_TEST_CASE_FIXTURE(c, decorators) \ - DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), c, \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) - -// for converting types to strings without the header and demangling -#define DOCTEST_TYPE_TO_STRING_AS(str, ...) \ - namespace doctest { \ - template <> \ - inline String toString<__VA_ARGS__>() { \ - return str; \ - } \ - } \ - static_assert(true, "") - -#define DOCTEST_TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING_AS(#__VA_ARGS__, __VA_ARGS__) - -#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, iter, func) \ - template \ - static void func(); \ - namespace { /* NOLINT */ \ - template \ - struct iter; \ - template \ - struct iter> \ - { \ - iter(const char* file, unsigned line, int index) { \ - doctest::detail::regTest(doctest::detail::TestCase(func, file, line, \ - doctest_detail_test_suite_ns::getCurrentTestSuite(), \ - doctest::toString(), \ - int(line) * 1000 + index) \ - * dec); \ - iter>(file, line, index + 1); \ - } \ - }; \ - template <> \ - struct iter> \ - { \ - iter(const char*, unsigned, int) {} \ - }; \ - } \ - template \ - static void func() - -#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(dec, T, id) \ - DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(id, ITERATOR), \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)) - -#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, anon, ...) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY), /* NOLINT(cert-err58-cpp, fuchsia-statically-constructed-objects) */ \ - doctest::detail::instantiationHelper( \ - DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0))) - -#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) 
\ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ - static_assert(true, "") - -#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) \ - static_assert(true, "") - -#define DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, anon, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(anon, ITERATOR), anon); \ - DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(anon, anon, std::tuple<__VA_ARGS__>) \ - template \ - static void anon() - -#define DOCTEST_TEST_CASE_TEMPLATE(dec, T, ...) \ - DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) - -// for subcases -#define DOCTEST_SUBCASE(name) \ - if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ - doctest::detail::Subcase(name, __FILE__, __LINE__)) - -// for grouping tests in test suites by using code blocks -#define DOCTEST_TEST_SUITE_IMPL(decorators, ns_name) \ - namespace ns_name { namespace doctest_detail_test_suite_ns { \ - static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() noexcept { \ - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4640) \ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") \ - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmissing-field-initializers") \ - static doctest::detail::TestSuite data{}; \ - static bool inited = false; \ - DOCTEST_MSVC_SUPPRESS_WARNING_POP \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP \ - DOCTEST_GCC_SUPPRESS_WARNING_POP \ - if(!inited) { \ - data* decorators; \ - inited = true; \ - } \ - return data; \ - } \ - } \ - } \ - namespace ns_name - -#define DOCTEST_TEST_SUITE(decorators) \ - DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(DOCTEST_ANON_SUITE_)) - -// for starting a testsuite block -#define DOCTEST_TEST_SUITE_BEGIN(decorators) \ - 
DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ - doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators)) \ - static_assert(true, "") - -// for ending a testsuite block -#define DOCTEST_TEST_SUITE_END \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ - doctest::detail::setTestSuite(doctest::detail::TestSuite() * "")) \ - using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int - -// for registering exception translators -#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(translatorName, signature) \ - inline doctest::String translatorName(signature); \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), /* NOLINT(cert-err58-cpp) */ \ - doctest::registerExceptionTranslator(translatorName)) \ - doctest::String translatorName(signature) - -#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ - DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), \ - signature) - -// for registering reporters -#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ - doctest::registerReporter(name, priority, true)) \ - static_assert(true, "") - -// for registering listeners -#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) \ - DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ - doctest::registerReporter(name, priority, false)) \ - static_assert(true, "") - -// clang-format off -// for logging - disabling formatting because it's important to have these on 2 separate lines - see PR #557 -#define DOCTEST_INFO(...) \ - DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_), \ - DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_OTHER_), \ - __VA_ARGS__) -// clang-format on - -#define DOCTEST_INFO_IMPL(mb_name, s_name, ...) 
\ - auto DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope( \ - [&](std::ostream* s_name) { \ - doctest::detail::MessageBuilder mb_name(__FILE__, __LINE__, doctest::assertType::is_warn); \ - mb_name.m_stream = s_name; \ - mb_name * __VA_ARGS__; \ - }) - -#define DOCTEST_CAPTURE(x) DOCTEST_INFO(#x " := ", x) - -#define DOCTEST_ADD_AT_IMPL(type, file, line, mb, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::MessageBuilder mb(file, line, doctest::assertType::type); \ - mb * __VA_ARGS__; \ - if(mb.log()) \ - DOCTEST_BREAK_INTO_DEBUGGER(); \ - mb.react(); \ - } DOCTEST_FUNC_SCOPE_END - -// clang-format off -#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) -#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) -#define DOCTEST_ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) -// clang-format on - -#define DOCTEST_MESSAGE(...) DOCTEST_ADD_MESSAGE_AT(__FILE__, __LINE__, __VA_ARGS__) -#define DOCTEST_FAIL_CHECK(...) DOCTEST_ADD_FAIL_CHECK_AT(__FILE__, __LINE__, __VA_ARGS__) -#define DOCTEST_FAIL(...) DOCTEST_ADD_FAIL_AT(__FILE__, __LINE__, __VA_ARGS__) - -#define DOCTEST_TO_LVALUE(...) __VA_ARGS__ // Not removed to keep backwards compatibility. - -#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -#define DOCTEST_ASSERT_IMPLEMENT_2(assert_type, ...) 
\ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ - /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY(DOCTEST_RB.setResult( \ - doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ - << __VA_ARGS__)) /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB) \ - DOCTEST_CLANG_SUPPRESS_WARNING_POP - -#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - DOCTEST_ASSERT_IMPLEMENT_2(assert_type, __VA_ARGS__); \ - } DOCTEST_FUNC_SCOPE_END // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - -#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY( \ - DOCTEST_RB.binary_assert( \ - __VA_ARGS__)) \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } DOCTEST_FUNC_SCOPE_END - -#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - DOCTEST_WRAP_IN_TRY(DOCTEST_RB.unary_assert(__VA_ARGS__)) \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } DOCTEST_FUNC_SCOPE_END - -#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -// necessary for _MESSAGE -#define DOCTEST_ASSERT_IMPLEMENT_2 DOCTEST_ASSERT_IMPLEMENT_1 - -#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) 
\ - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ - doctest::detail::decomp_assert( \ - doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, \ - doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ - << __VA_ARGS__) DOCTEST_CLANG_SUPPRESS_WARNING_POP - -#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ - doctest::detail::binary_assert( \ - doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) - -#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ - doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ - #__VA_ARGS__, __VA_ARGS__) - -#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS - -#define DOCTEST_WARN(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN, __VA_ARGS__) -#define DOCTEST_CHECK(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK, __VA_ARGS__) -#define DOCTEST_REQUIRE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE, __VA_ARGS__) -#define DOCTEST_WARN_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN_FALSE, __VA_ARGS__) -#define DOCTEST_CHECK_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK_FALSE, __VA_ARGS__) -#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE_FALSE, __VA_ARGS__) - -// clang-format off -#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) 
DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } DOCTEST_FUNC_SCOPE_END -// clang-format on - -#define DOCTEST_WARN_EQ(...) DOCTEST_BINARY_ASSERT(DT_WARN_EQ, eq, __VA_ARGS__) -#define DOCTEST_CHECK_EQ(...) DOCTEST_BINARY_ASSERT(DT_CHECK_EQ, eq, __VA_ARGS__) -#define DOCTEST_REQUIRE_EQ(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_EQ, eq, __VA_ARGS__) -#define DOCTEST_WARN_NE(...) DOCTEST_BINARY_ASSERT(DT_WARN_NE, ne, __VA_ARGS__) -#define DOCTEST_CHECK_NE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_NE, ne, __VA_ARGS__) -#define DOCTEST_REQUIRE_NE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_NE, ne, __VA_ARGS__) -#define DOCTEST_WARN_GT(...) DOCTEST_BINARY_ASSERT(DT_WARN_GT, gt, __VA_ARGS__) -#define DOCTEST_CHECK_GT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GT, gt, __VA_ARGS__) -#define DOCTEST_REQUIRE_GT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GT, gt, __VA_ARGS__) -#define DOCTEST_WARN_LT(...) DOCTEST_BINARY_ASSERT(DT_WARN_LT, lt, __VA_ARGS__) -#define DOCTEST_CHECK_LT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LT, lt, __VA_ARGS__) -#define DOCTEST_REQUIRE_LT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LT, lt, __VA_ARGS__) -#define DOCTEST_WARN_GE(...) DOCTEST_BINARY_ASSERT(DT_WARN_GE, ge, __VA_ARGS__) -#define DOCTEST_CHECK_GE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GE, ge, __VA_ARGS__) -#define DOCTEST_REQUIRE_GE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GE, ge, __VA_ARGS__) -#define DOCTEST_WARN_LE(...) DOCTEST_BINARY_ASSERT(DT_WARN_LE, le, __VA_ARGS__) -#define DOCTEST_CHECK_LE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LE, le, __VA_ARGS__) -#define DOCTEST_REQUIRE_LE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LE, le, __VA_ARGS__) - -#define DOCTEST_WARN_UNARY(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY, __VA_ARGS__) -#define DOCTEST_CHECK_UNARY(...) 
DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY, __VA_ARGS__) -#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY, __VA_ARGS__) -#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY_FALSE, __VA_ARGS__) -#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY_FALSE, __VA_ARGS__) -#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY_FALSE, __VA_ARGS__) - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - -#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - if(!doctest::getContextOptions()->no_throw) { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #expr, #__VA_ARGS__, message); \ - try { \ - DOCTEST_CAST_TO_VOID(expr) \ - } catch(const typename doctest::detail::types::remove_const< \ - typename doctest::detail::types::remove_reference<__VA_ARGS__>::type>::type&) {\ - DOCTEST_RB.translateException(); \ - DOCTEST_RB.m_threw_as = true; \ - } catch(...) { DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } else { /* NOLINT(*-else-after-return) */ \ - DOCTEST_FUNC_SCOPE_RET(false); \ - } \ - } DOCTEST_FUNC_SCOPE_END - -#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ - DOCTEST_FUNC_SCOPE_BEGIN { \ - if(!doctest::getContextOptions()->no_throw) { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, expr_str, "", __VA_ARGS__); \ - try { \ - DOCTEST_CAST_TO_VOID(expr) \ - } catch(...) { DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } else { /* NOLINT(*-else-after-return) */ \ - DOCTEST_FUNC_SCOPE_RET(false); \ - } \ - } DOCTEST_FUNC_SCOPE_END - -#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) 
\ - DOCTEST_FUNC_SCOPE_BEGIN { \ - doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ - __LINE__, #__VA_ARGS__); \ - try { \ - DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ - } catch(...) { DOCTEST_RB.translateException(); } \ - DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ - } DOCTEST_FUNC_SCOPE_END - -// clang-format off -#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") -#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") -#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") - -#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) - -#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) -#define DOCTEST_CHECK_NOTHROW(...) 
DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) -#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) 
DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END -// clang-format on - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -// ================================================================================================= -// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == -// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! 
== -// ================================================================================================= -#else // DOCTEST_CONFIG_DISABLE - -#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ - namespace /* NOLINT */ { \ - template \ - struct der : public base \ - { void f(); }; \ - } \ - template \ - inline void der::f() - -#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ - template \ - static inline void f() - -// for registering tests -#define DOCTEST_TEST_CASE(name) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) - -// for registering tests in classes -#define DOCTEST_TEST_CASE_CLASS(name) \ - DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) - -// for registering tests with a fixture -#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ - DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), x, \ - DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) - -// for converting types to strings without the header and demangling -#define DOCTEST_TYPE_TO_STRING_AS(str, ...) static_assert(true, "") -#define DOCTEST_TYPE_TO_STRING(...) static_assert(true, "") - -// for typed tests -#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ - template \ - inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() - -#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ - template \ - inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() - -#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) static_assert(true, "") -#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) 
static_assert(true, "") - -// for subcases -#define DOCTEST_SUBCASE(name) - -// for a testsuite block -#define DOCTEST_TEST_SUITE(name) namespace // NOLINT - -// for starting a testsuite block -#define DOCTEST_TEST_SUITE_BEGIN(name) static_assert(true, "") - -// for ending a testsuite block -#define DOCTEST_TEST_SUITE_END using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int - -#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ - template \ - static inline doctest::String DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_)(signature) - -#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) -#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) - -#define DOCTEST_INFO(...) (static_cast(0)) -#define DOCTEST_CAPTURE(x) (static_cast(0)) -#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_ADD_FAIL_AT(file, line, ...) (static_cast(0)) -#define DOCTEST_MESSAGE(...) (static_cast(0)) -#define DOCTEST_FAIL_CHECK(...) (static_cast(0)) -#define DOCTEST_FAIL(...) (static_cast(0)) - -#if defined(DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED) \ - && defined(DOCTEST_CONFIG_ASSERTS_RETURN_VALUES) - -#define DOCTEST_WARN(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_CHECK(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_REQUIRE(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_WARN_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_CHECK_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_FALSE(...) [&] { return !(__VA_ARGS__); }() - -#define DOCTEST_WARN_MESSAGE(cond, ...) [&] { return cond; }() -#define DOCTEST_CHECK_MESSAGE(cond, ...) [&] { return cond; }() -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) [&] { return cond; }() -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) 
[&] { return !(cond); }() - -namespace doctest { -namespace detail { -#define DOCTEST_RELATIONAL_OP(name, op) \ - template \ - bool name(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) { return lhs op rhs; } - - DOCTEST_RELATIONAL_OP(eq, ==) - DOCTEST_RELATIONAL_OP(ne, !=) - DOCTEST_RELATIONAL_OP(lt, <) - DOCTEST_RELATIONAL_OP(gt, >) - DOCTEST_RELATIONAL_OP(le, <=) - DOCTEST_RELATIONAL_OP(ge, >=) -} // namespace detail -} // namespace doctest - -#define DOCTEST_WARN_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() -#define DOCTEST_CHECK_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() -#define DOCTEST_WARN_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() -#define DOCTEST_CHECK_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() -#define DOCTEST_WARN_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() -#define DOCTEST_CHECK_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() -#define DOCTEST_WARN_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() -#define DOCTEST_CHECK_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() -#define DOCTEST_WARN_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() -#define DOCTEST_CHECK_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() -#define DOCTEST_WARN_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() -#define DOCTEST_CHECK_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() -#define DOCTEST_WARN_UNARY(...) 
[&] { return __VA_ARGS__; }() -#define DOCTEST_CHECK_UNARY(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_REQUIRE_UNARY(...) [&] { return __VA_ARGS__; }() -#define DOCTEST_WARN_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_CHECK_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() -#define DOCTEST_REQUIRE_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - -#define DOCTEST_WARN_THROWS_WITH(expr, with, ...) [] { static_assert(false, "Exception translation is not available when doctest is disabled."); return false; }() -#define DOCTEST_CHECK_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) - -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) - -#define DOCTEST_WARN_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_CHECK_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_REQUIRE_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_WARN_THROWS_AS(expr, ...) 
[&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_CHECK_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_WARN_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_CHECK_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_REQUIRE_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) 
{ return false; } }() - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -#else // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED - -#define DOCTEST_WARN(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_FALSE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_FALSE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_EQ(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_EQ(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_EQ(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_NE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_NE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_NE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_GT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_GT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_GT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_LT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_LT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_LT(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_GE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_GE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_GE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_LE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_LE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_LE(...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_UNARY(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_UNARY(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_UNARY_FALSE(...) 
DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - -#define DOCTEST_WARN_THROWS(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_NOTHROW(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_FUNC_EMPTY - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) 
DOCTEST_FUNC_EMPTY -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -#endif // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED - -#endif // DOCTEST_CONFIG_DISABLE - -#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS - -#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS -#define DOCTEST_EXCEPTION_EMPTY_FUNC DOCTEST_FUNC_EMPTY -#else // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS -#define DOCTEST_EXCEPTION_EMPTY_FUNC [] { static_assert(false, "Exceptions are disabled! " \ - "Use DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS if you want to compile with exceptions disabled."); return false; }() - -#undef DOCTEST_REQUIRE -#undef DOCTEST_REQUIRE_FALSE -#undef DOCTEST_REQUIRE_MESSAGE -#undef DOCTEST_REQUIRE_FALSE_MESSAGE -#undef DOCTEST_REQUIRE_EQ -#undef DOCTEST_REQUIRE_NE -#undef DOCTEST_REQUIRE_GT -#undef DOCTEST_REQUIRE_LT -#undef DOCTEST_REQUIRE_GE -#undef DOCTEST_REQUIRE_LE -#undef DOCTEST_REQUIRE_UNARY -#undef DOCTEST_REQUIRE_UNARY_FALSE - -#define DOCTEST_REQUIRE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_FALSE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_EQ DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_NE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_GT DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_LT DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_GE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_LE DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_UNARY DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_UNARY_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS - -#define DOCTEST_WARN_THROWS(...) 
DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC - -#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) 
DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC -#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC - -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - -// clang-format off -// KEPT FOR BACKWARDS COMPATIBILITY - FORWARDING TO THE RIGHT MACROS -#define DOCTEST_FAST_WARN_EQ DOCTEST_WARN_EQ -#define DOCTEST_FAST_CHECK_EQ DOCTEST_CHECK_EQ -#define DOCTEST_FAST_REQUIRE_EQ DOCTEST_REQUIRE_EQ -#define DOCTEST_FAST_WARN_NE DOCTEST_WARN_NE -#define DOCTEST_FAST_CHECK_NE DOCTEST_CHECK_NE -#define DOCTEST_FAST_REQUIRE_NE DOCTEST_REQUIRE_NE -#define DOCTEST_FAST_WARN_GT DOCTEST_WARN_GT -#define DOCTEST_FAST_CHECK_GT DOCTEST_CHECK_GT -#define DOCTEST_FAST_REQUIRE_GT DOCTEST_REQUIRE_GT -#define DOCTEST_FAST_WARN_LT DOCTEST_WARN_LT -#define DOCTEST_FAST_CHECK_LT DOCTEST_CHECK_LT -#define DOCTEST_FAST_REQUIRE_LT DOCTEST_REQUIRE_LT -#define DOCTEST_FAST_WARN_GE DOCTEST_WARN_GE -#define DOCTEST_FAST_CHECK_GE DOCTEST_CHECK_GE -#define DOCTEST_FAST_REQUIRE_GE DOCTEST_REQUIRE_GE -#define DOCTEST_FAST_WARN_LE DOCTEST_WARN_LE -#define DOCTEST_FAST_CHECK_LE DOCTEST_CHECK_LE -#define DOCTEST_FAST_REQUIRE_LE DOCTEST_REQUIRE_LE - -#define DOCTEST_FAST_WARN_UNARY DOCTEST_WARN_UNARY -#define DOCTEST_FAST_CHECK_UNARY DOCTEST_CHECK_UNARY -#define DOCTEST_FAST_REQUIRE_UNARY DOCTEST_REQUIRE_UNARY -#define DOCTEST_FAST_WARN_UNARY_FALSE DOCTEST_WARN_UNARY_FALSE -#define DOCTEST_FAST_CHECK_UNARY_FALSE DOCTEST_CHECK_UNARY_FALSE -#define DOCTEST_FAST_REQUIRE_UNARY_FALSE DOCTEST_REQUIRE_UNARY_FALSE - -#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) 
DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id,__VA_ARGS__) -// clang-format on - -// BDD style macros -// clang-format off -#define DOCTEST_SCENARIO(name) DOCTEST_TEST_CASE(" Scenario: " name) -#define DOCTEST_SCENARIO_CLASS(name) DOCTEST_TEST_CASE_CLASS(" Scenario: " name) -#define DOCTEST_SCENARIO_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(" Scenario: " name, T, __VA_ARGS__) -#define DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(" Scenario: " name, T, id) - -#define DOCTEST_GIVEN(name) DOCTEST_SUBCASE(" Given: " name) -#define DOCTEST_WHEN(name) DOCTEST_SUBCASE(" When: " name) -#define DOCTEST_AND_WHEN(name) DOCTEST_SUBCASE("And when: " name) -#define DOCTEST_THEN(name) DOCTEST_SUBCASE(" Then: " name) -#define DOCTEST_AND_THEN(name) DOCTEST_SUBCASE(" And: " name) -// clang-format on - -// == SHORT VERSIONS OF THE MACROS -#ifndef DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES - -#define TEST_CASE(name) DOCTEST_TEST_CASE(name) -#define TEST_CASE_CLASS(name) DOCTEST_TEST_CASE_CLASS(name) -#define TEST_CASE_FIXTURE(x, name) DOCTEST_TEST_CASE_FIXTURE(x, name) -#define TYPE_TO_STRING_AS(str, ...) DOCTEST_TYPE_TO_STRING_AS(str, __VA_ARGS__) -#define TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING(__VA_ARGS__) -#define TEST_CASE_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(name, T, __VA_ARGS__) -#define TEST_CASE_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, T, id) -#define TEST_CASE_TEMPLATE_INVOKE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, __VA_ARGS__) -#define TEST_CASE_TEMPLATE_APPLY(id, ...) 
DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, __VA_ARGS__) -#define SUBCASE(name) DOCTEST_SUBCASE(name) -#define TEST_SUITE(decorators) DOCTEST_TEST_SUITE(decorators) -#define TEST_SUITE_BEGIN(name) DOCTEST_TEST_SUITE_BEGIN(name) -#define TEST_SUITE_END DOCTEST_TEST_SUITE_END -#define REGISTER_EXCEPTION_TRANSLATOR(signature) DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) -#define REGISTER_REPORTER(name, priority, reporter) DOCTEST_REGISTER_REPORTER(name, priority, reporter) -#define REGISTER_LISTENER(name, priority, reporter) DOCTEST_REGISTER_LISTENER(name, priority, reporter) -#define INFO(...) DOCTEST_INFO(__VA_ARGS__) -#define CAPTURE(x) DOCTEST_CAPTURE(x) -#define ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_MESSAGE_AT(file, line, __VA_ARGS__) -#define ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_FAIL_CHECK_AT(file, line, __VA_ARGS__) -#define ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_FAIL_AT(file, line, __VA_ARGS__) -#define MESSAGE(...) DOCTEST_MESSAGE(__VA_ARGS__) -#define FAIL_CHECK(...) DOCTEST_FAIL_CHECK(__VA_ARGS__) -#define FAIL(...) DOCTEST_FAIL(__VA_ARGS__) -#define TO_LVALUE(...) DOCTEST_TO_LVALUE(__VA_ARGS__) - -#define WARN(...) DOCTEST_WARN(__VA_ARGS__) -#define WARN_FALSE(...) DOCTEST_WARN_FALSE(__VA_ARGS__) -#define WARN_THROWS(...) DOCTEST_WARN_THROWS(__VA_ARGS__) -#define WARN_THROWS_AS(expr, ...) DOCTEST_WARN_THROWS_AS(expr, __VA_ARGS__) -#define WARN_THROWS_WITH(expr, ...) DOCTEST_WARN_THROWS_WITH(expr, __VA_ARGS__) -#define WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_WARN_THROWS_WITH_AS(expr, with, __VA_ARGS__) -#define WARN_NOTHROW(...) DOCTEST_WARN_NOTHROW(__VA_ARGS__) -#define CHECK(...) DOCTEST_CHECK(__VA_ARGS__) -#define CHECK_FALSE(...) DOCTEST_CHECK_FALSE(__VA_ARGS__) -#define CHECK_THROWS(...) DOCTEST_CHECK_THROWS(__VA_ARGS__) -#define CHECK_THROWS_AS(expr, ...) DOCTEST_CHECK_THROWS_AS(expr, __VA_ARGS__) -#define CHECK_THROWS_WITH(expr, ...) DOCTEST_CHECK_THROWS_WITH(expr, __VA_ARGS__) -#define CHECK_THROWS_WITH_AS(expr, with, ...) 
DOCTEST_CHECK_THROWS_WITH_AS(expr, with, __VA_ARGS__) -#define CHECK_NOTHROW(...) DOCTEST_CHECK_NOTHROW(__VA_ARGS__) -#define REQUIRE(...) DOCTEST_REQUIRE(__VA_ARGS__) -#define REQUIRE_FALSE(...) DOCTEST_REQUIRE_FALSE(__VA_ARGS__) -#define REQUIRE_THROWS(...) DOCTEST_REQUIRE_THROWS(__VA_ARGS__) -#define REQUIRE_THROWS_AS(expr, ...) DOCTEST_REQUIRE_THROWS_AS(expr, __VA_ARGS__) -#define REQUIRE_THROWS_WITH(expr, ...) DOCTEST_REQUIRE_THROWS_WITH(expr, __VA_ARGS__) -#define REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, __VA_ARGS__) -#define REQUIRE_NOTHROW(...) DOCTEST_REQUIRE_NOTHROW(__VA_ARGS__) - -#define WARN_MESSAGE(cond, ...) DOCTEST_WARN_MESSAGE(cond, __VA_ARGS__) -#define WARN_FALSE_MESSAGE(cond, ...) DOCTEST_WARN_FALSE_MESSAGE(cond, __VA_ARGS__) -#define WARN_THROWS_MESSAGE(expr, ...) DOCTEST_WARN_THROWS_MESSAGE(expr, __VA_ARGS__) -#define WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) -#define WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) -#define WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) -#define WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_WARN_NOTHROW_MESSAGE(expr, __VA_ARGS__) -#define CHECK_MESSAGE(cond, ...) DOCTEST_CHECK_MESSAGE(cond, __VA_ARGS__) -#define CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_CHECK_FALSE_MESSAGE(cond, __VA_ARGS__) -#define CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_CHECK_THROWS_MESSAGE(expr, __VA_ARGS__) -#define CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) -#define CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) -#define CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) -#define CHECK_NOTHROW_MESSAGE(expr, ...) 
DOCTEST_CHECK_NOTHROW_MESSAGE(expr, __VA_ARGS__) -#define REQUIRE_MESSAGE(cond, ...) DOCTEST_REQUIRE_MESSAGE(cond, __VA_ARGS__) -#define REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_REQUIRE_FALSE_MESSAGE(cond, __VA_ARGS__) -#define REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_REQUIRE_THROWS_MESSAGE(expr, __VA_ARGS__) -#define REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) -#define REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) -#define REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) -#define REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, __VA_ARGS__) - -#define SCENARIO(name) DOCTEST_SCENARIO(name) -#define SCENARIO_CLASS(name) DOCTEST_SCENARIO_CLASS(name) -#define SCENARIO_TEMPLATE(name, T, ...) DOCTEST_SCENARIO_TEMPLATE(name, T, __VA_ARGS__) -#define SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) -#define GIVEN(name) DOCTEST_GIVEN(name) -#define WHEN(name) DOCTEST_WHEN(name) -#define AND_WHEN(name) DOCTEST_AND_WHEN(name) -#define THEN(name) DOCTEST_THEN(name) -#define AND_THEN(name) DOCTEST_AND_THEN(name) - -#define WARN_EQ(...) DOCTEST_WARN_EQ(__VA_ARGS__) -#define CHECK_EQ(...) DOCTEST_CHECK_EQ(__VA_ARGS__) -#define REQUIRE_EQ(...) DOCTEST_REQUIRE_EQ(__VA_ARGS__) -#define WARN_NE(...) DOCTEST_WARN_NE(__VA_ARGS__) -#define CHECK_NE(...) DOCTEST_CHECK_NE(__VA_ARGS__) -#define REQUIRE_NE(...) DOCTEST_REQUIRE_NE(__VA_ARGS__) -#define WARN_GT(...) DOCTEST_WARN_GT(__VA_ARGS__) -#define CHECK_GT(...) DOCTEST_CHECK_GT(__VA_ARGS__) -#define REQUIRE_GT(...) DOCTEST_REQUIRE_GT(__VA_ARGS__) -#define WARN_LT(...) DOCTEST_WARN_LT(__VA_ARGS__) -#define CHECK_LT(...) DOCTEST_CHECK_LT(__VA_ARGS__) -#define REQUIRE_LT(...) DOCTEST_REQUIRE_LT(__VA_ARGS__) -#define WARN_GE(...) DOCTEST_WARN_GE(__VA_ARGS__) -#define CHECK_GE(...) 
DOCTEST_CHECK_GE(__VA_ARGS__) -#define REQUIRE_GE(...) DOCTEST_REQUIRE_GE(__VA_ARGS__) -#define WARN_LE(...) DOCTEST_WARN_LE(__VA_ARGS__) -#define CHECK_LE(...) DOCTEST_CHECK_LE(__VA_ARGS__) -#define REQUIRE_LE(...) DOCTEST_REQUIRE_LE(__VA_ARGS__) -#define WARN_UNARY(...) DOCTEST_WARN_UNARY(__VA_ARGS__) -#define CHECK_UNARY(...) DOCTEST_CHECK_UNARY(__VA_ARGS__) -#define REQUIRE_UNARY(...) DOCTEST_REQUIRE_UNARY(__VA_ARGS__) -#define WARN_UNARY_FALSE(...) DOCTEST_WARN_UNARY_FALSE(__VA_ARGS__) -#define CHECK_UNARY_FALSE(...) DOCTEST_CHECK_UNARY_FALSE(__VA_ARGS__) -#define REQUIRE_UNARY_FALSE(...) DOCTEST_REQUIRE_UNARY_FALSE(__VA_ARGS__) - -// KEPT FOR BACKWARDS COMPATIBILITY -#define FAST_WARN_EQ(...) DOCTEST_FAST_WARN_EQ(__VA_ARGS__) -#define FAST_CHECK_EQ(...) DOCTEST_FAST_CHECK_EQ(__VA_ARGS__) -#define FAST_REQUIRE_EQ(...) DOCTEST_FAST_REQUIRE_EQ(__VA_ARGS__) -#define FAST_WARN_NE(...) DOCTEST_FAST_WARN_NE(__VA_ARGS__) -#define FAST_CHECK_NE(...) DOCTEST_FAST_CHECK_NE(__VA_ARGS__) -#define FAST_REQUIRE_NE(...) DOCTEST_FAST_REQUIRE_NE(__VA_ARGS__) -#define FAST_WARN_GT(...) DOCTEST_FAST_WARN_GT(__VA_ARGS__) -#define FAST_CHECK_GT(...) DOCTEST_FAST_CHECK_GT(__VA_ARGS__) -#define FAST_REQUIRE_GT(...) DOCTEST_FAST_REQUIRE_GT(__VA_ARGS__) -#define FAST_WARN_LT(...) DOCTEST_FAST_WARN_LT(__VA_ARGS__) -#define FAST_CHECK_LT(...) DOCTEST_FAST_CHECK_LT(__VA_ARGS__) -#define FAST_REQUIRE_LT(...) DOCTEST_FAST_REQUIRE_LT(__VA_ARGS__) -#define FAST_WARN_GE(...) DOCTEST_FAST_WARN_GE(__VA_ARGS__) -#define FAST_CHECK_GE(...) DOCTEST_FAST_CHECK_GE(__VA_ARGS__) -#define FAST_REQUIRE_GE(...) DOCTEST_FAST_REQUIRE_GE(__VA_ARGS__) -#define FAST_WARN_LE(...) DOCTEST_FAST_WARN_LE(__VA_ARGS__) -#define FAST_CHECK_LE(...) DOCTEST_FAST_CHECK_LE(__VA_ARGS__) -#define FAST_REQUIRE_LE(...) DOCTEST_FAST_REQUIRE_LE(__VA_ARGS__) - -#define FAST_WARN_UNARY(...) DOCTEST_FAST_WARN_UNARY(__VA_ARGS__) -#define FAST_CHECK_UNARY(...) DOCTEST_FAST_CHECK_UNARY(__VA_ARGS__) -#define FAST_REQUIRE_UNARY(...) 
DOCTEST_FAST_REQUIRE_UNARY(__VA_ARGS__) -#define FAST_WARN_UNARY_FALSE(...) DOCTEST_FAST_WARN_UNARY_FALSE(__VA_ARGS__) -#define FAST_CHECK_UNARY_FALSE(...) DOCTEST_FAST_CHECK_UNARY_FALSE(__VA_ARGS__) -#define FAST_REQUIRE_UNARY_FALSE(...) DOCTEST_FAST_REQUIRE_UNARY_FALSE(__VA_ARGS__) - -#define TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, __VA_ARGS__) - -#endif // DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES - -#ifndef DOCTEST_CONFIG_DISABLE - -// this is here to clear the 'current test suite' for the current translation unit - at the top -DOCTEST_TEST_SUITE_END(); - -#endif // DOCTEST_CONFIG_DISABLE - -DOCTEST_CLANG_SUPPRESS_WARNING_POP -DOCTEST_MSVC_SUPPRESS_WARNING_POP -DOCTEST_GCC_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_POP - -#endif // DOCTEST_LIBRARY_INCLUDED - -#ifndef DOCTEST_SINGLE_HEADER -#define DOCTEST_SINGLE_HEADER -#endif // DOCTEST_SINGLE_HEADER - -#if defined(DOCTEST_CONFIG_IMPLEMENT) || !defined(DOCTEST_SINGLE_HEADER) - -#ifndef DOCTEST_SINGLE_HEADER -#include "doctest_fwd.h" -#endif // DOCTEST_SINGLE_HEADER - -DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-macros") - -#ifndef DOCTEST_LIBRARY_IMPLEMENTATION -#define DOCTEST_LIBRARY_IMPLEMENTATION - -DOCTEST_CLANG_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH - -DOCTEST_CLANG_SUPPRESS_WARNING_PUSH -DOCTEST_CLANG_SUPPRESS_WARNING("-Wglobal-constructors") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wexit-time-destructors") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wshorten-64-to-32") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-variable-declarations") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch-enum") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wcovered-switch-default") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-noreturn") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wdisabled-macro-expansion") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-braces") 
-DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-field-initializers") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-member-function") -DOCTEST_CLANG_SUPPRESS_WARNING("-Wnonportable-system-include-path") - -DOCTEST_GCC_SUPPRESS_WARNING_PUSH -DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") -DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-field-initializers") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-braces") -DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch") -DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-enum") -DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-default") -DOCTEST_GCC_SUPPRESS_WARNING("-Wunsafe-loop-optimizations") -DOCTEST_GCC_SUPPRESS_WARNING("-Wold-style-cast") -DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-function") -DOCTEST_GCC_SUPPRESS_WARNING("-Wmultiple-inheritance") -DOCTEST_GCC_SUPPRESS_WARNING("-Wsuggest-attribute") - -DOCTEST_MSVC_SUPPRESS_WARNING_PUSH -DOCTEST_MSVC_SUPPRESS_WARNING(4267) // 'var' : conversion from 'x' to 'y', possible loss of data -DOCTEST_MSVC_SUPPRESS_WARNING(4530) // C++ exception handler used, but unwind semantics not enabled -DOCTEST_MSVC_SUPPRESS_WARNING(4577) // 'noexcept' used with no exception handling mode specified -DOCTEST_MSVC_SUPPRESS_WARNING(4774) // format string expected in argument is not a string literal -DOCTEST_MSVC_SUPPRESS_WARNING(4365) // conversion from 'int' to 'unsigned', signed/unsigned mismatch -DOCTEST_MSVC_SUPPRESS_WARNING(5039) // pointer to potentially throwing function passed to extern C -DOCTEST_MSVC_SUPPRESS_WARNING(4800) // forcing value to bool 'true' or 'false' (performance warning) -DOCTEST_MSVC_SUPPRESS_WARNING(5245) // unreferenced function with internal linkage has been removed - -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN - -// required includes - will go only in one translation unit! 
-#include -#include -#include -// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/doctest/doctest/pull/37 -#ifdef __BORLANDC__ -#include -#endif // __BORLANDC__ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifndef DOCTEST_CONFIG_NO_MULTITHREADING -#include -#include -#define DOCTEST_DECLARE_MUTEX(name) std::mutex name; -#define DOCTEST_DECLARE_STATIC_MUTEX(name) static DOCTEST_DECLARE_MUTEX(name) -#define DOCTEST_LOCK_MUTEX(name) std::lock_guard DOCTEST_ANONYMOUS(DOCTEST_ANON_LOCK_)(name); -#else // DOCTEST_CONFIG_NO_MULTITHREADING -#define DOCTEST_DECLARE_MUTEX(name) -#define DOCTEST_DECLARE_STATIC_MUTEX(name) -#define DOCTEST_LOCK_MUTEX(name) -#endif // DOCTEST_CONFIG_NO_MULTITHREADING -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef DOCTEST_PLATFORM_MAC -#include -#include -#include -#endif // DOCTEST_PLATFORM_MAC - -#ifdef DOCTEST_PLATFORM_WINDOWS - -// defines for a leaner windows.h -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif // WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -#define NOMINMAX -#endif // NOMINMAX - -// not sure what AfxWin.h is for - here I do what Catch does -#ifdef __AFXDLL -#include -#else -#include -#endif -#include - -#else // DOCTEST_PLATFORM_WINDOWS - -#include -#include - -#endif // DOCTEST_PLATFORM_WINDOWS - -// this is a fix for https://github.com/doctest/doctest/issues/348 -// https://mail.gnome.org/archives/xml/2012-January/msg00000.html -#if !defined(HAVE_UNISTD_H) && !defined(STDOUT_FILENO) -#define STDOUT_FILENO fileno(stdout) -#endif // HAVE_UNISTD_H - -DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END - -// counts the number of elements in a C array -#define DOCTEST_COUNTOF(x) (sizeof(x) / sizeof(x[0])) - -#ifdef DOCTEST_CONFIG_DISABLE -#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_disabled -#else // 
DOCTEST_CONFIG_DISABLE -#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_not_disabled -#endif // DOCTEST_CONFIG_DISABLE - -#ifndef DOCTEST_CONFIG_OPTIONS_PREFIX -#define DOCTEST_CONFIG_OPTIONS_PREFIX "dt-" -#endif - -#ifndef DOCTEST_THREAD_LOCAL -#if defined(DOCTEST_CONFIG_NO_MULTITHREADING) || DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) -#define DOCTEST_THREAD_LOCAL -#else // DOCTEST_MSVC -#define DOCTEST_THREAD_LOCAL thread_local -#endif // DOCTEST_MSVC -#endif // DOCTEST_THREAD_LOCAL - -#ifndef DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES -#define DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES 32 -#endif - -#ifndef DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE -#define DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE 64 -#endif - -#ifdef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS -#define DOCTEST_OPTIONS_PREFIX_DISPLAY DOCTEST_CONFIG_OPTIONS_PREFIX -#else -#define DOCTEST_OPTIONS_PREFIX_DISPLAY "" -#endif - -#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) -#define DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS -#endif - -#ifndef DOCTEST_CDECL -#define DOCTEST_CDECL __cdecl -#endif - -namespace doctest { - -bool is_running_in_test = false; - -namespace { - using namespace detail; - - template - DOCTEST_NORETURN void throw_exception(Ex const& e) { -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - throw e; -#else // DOCTEST_CONFIG_NO_EXCEPTIONS - std::cerr << "doctest will terminate because it needed to throw an exception.\n" - << "The message was: " << e.what() << '\n'; - std::terminate(); -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - } - -#ifndef DOCTEST_INTERNAL_ERROR -#define DOCTEST_INTERNAL_ERROR(msg) \ - throw_exception(std::logic_error( \ - __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) -#endif // DOCTEST_INTERNAL_ERROR - - // case insensitive strcmp - int stricmp(const char* a, const char* b) { - for(;; a++, b++) { - const int d = tolower(*a) - tolower(*b); - if(d != 0 || !*a) - return d; - } - } - - struct Endianness - { - enum 
Arch - { - Big, - Little - }; - - static Arch which() { - int x = 1; - // casting any data pointer to char* is allowed - auto ptr = reinterpret_cast(&x); - if(*ptr) - return Little; - return Big; - } - }; -} // namespace - -namespace detail { - DOCTEST_THREAD_LOCAL class - { - std::vector stack; - std::stringstream ss; - - public: - std::ostream* push() { - stack.push_back(ss.tellp()); - return &ss; - } - - String pop() { - if (stack.empty()) - DOCTEST_INTERNAL_ERROR("TLSS was empty when trying to pop!"); - - std::streampos pos = stack.back(); - stack.pop_back(); - unsigned sz = static_cast(ss.tellp() - pos); - ss.rdbuf()->pubseekpos(pos, std::ios::in | std::ios::out); - return String(ss, sz); - } - } g_oss; - - std::ostream* tlssPush() { - return g_oss.push(); - } - - String tlssPop() { - return g_oss.pop(); - } - -#ifndef DOCTEST_CONFIG_DISABLE - -namespace timer_large_integer -{ - -#if defined(DOCTEST_PLATFORM_WINDOWS) - using type = ULONGLONG; -#else // DOCTEST_PLATFORM_WINDOWS - using type = std::uint64_t; -#endif // DOCTEST_PLATFORM_WINDOWS -} - -using ticks_t = timer_large_integer::type; - -#ifdef DOCTEST_CONFIG_GETCURRENTTICKS - ticks_t getCurrentTicks() { return DOCTEST_CONFIG_GETCURRENTTICKS(); } -#elif defined(DOCTEST_PLATFORM_WINDOWS) - ticks_t getCurrentTicks() { - static LARGE_INTEGER hz = { {0} }, hzo = { {0} }; - if(!hz.QuadPart) { - QueryPerformanceFrequency(&hz); - QueryPerformanceCounter(&hzo); - } - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return ((t.QuadPart - hzo.QuadPart) * LONGLONG(1000000)) / hz.QuadPart; - } -#else // DOCTEST_PLATFORM_WINDOWS - ticks_t getCurrentTicks() { - timeval t; - gettimeofday(&t, nullptr); - return static_cast(t.tv_sec) * 1000000 + static_cast(t.tv_usec); - } -#endif // DOCTEST_PLATFORM_WINDOWS - - struct Timer - { - void start() { m_ticks = getCurrentTicks(); } - unsigned int getElapsedMicroseconds() const { - return static_cast(getCurrentTicks() - m_ticks); - } - //unsigned int getElapsedMilliseconds() 
const { - // return static_cast(getElapsedMicroseconds() / 1000); - //} - double getElapsedSeconds() const { return static_cast(getCurrentTicks() - m_ticks) / 1000000.0; } - - private: - ticks_t m_ticks = 0; - }; - -#ifdef DOCTEST_CONFIG_NO_MULTITHREADING - template - using Atomic = T; -#else // DOCTEST_CONFIG_NO_MULTITHREADING - template - using Atomic = std::atomic; -#endif // DOCTEST_CONFIG_NO_MULTITHREADING - -#if defined(DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS) || defined(DOCTEST_CONFIG_NO_MULTITHREADING) - template - using MultiLaneAtomic = Atomic; -#else // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS - // Provides a multilane implementation of an atomic variable that supports add, sub, load, - // store. Instead of using a single atomic variable, this splits up into multiple ones, - // each sitting on a separate cache line. The goal is to provide a speedup when most - // operations are modifying. It achieves this with two properties: - // - // * Multiple atomics are used, so chance of congestion from the same atomic is reduced. - // * Each atomic sits on a separate cache line, so false sharing is reduced. - // - // The disadvantage is that there is a small overhead due to the use of TLS, and load/store - // is slower because all atomics have to be accessed. 
- template - class MultiLaneAtomic - { - struct CacheLineAlignedAtomic - { - Atomic atomic{}; - char padding[DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE - sizeof(Atomic)]; - }; - CacheLineAlignedAtomic m_atomics[DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES]; - - static_assert(sizeof(CacheLineAlignedAtomic) == DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE, - "guarantee one atomic takes exactly one cache line"); - - public: - T operator++() DOCTEST_NOEXCEPT { return fetch_add(1) + 1; } - - T operator++(int) DOCTEST_NOEXCEPT { return fetch_add(1); } - - T fetch_add(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { - return myAtomic().fetch_add(arg, order); - } - - T fetch_sub(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { - return myAtomic().fetch_sub(arg, order); - } - - operator T() const DOCTEST_NOEXCEPT { return load(); } - - T load(std::memory_order order = std::memory_order_seq_cst) const DOCTEST_NOEXCEPT { - auto result = T(); - for(auto const& c : m_atomics) { - result += c.atomic.load(order); - } - return result; - } - - T operator=(T desired) DOCTEST_NOEXCEPT { // lgtm [cpp/assignment-does-not-return-this] - store(desired); - return desired; - } - - void store(T desired, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { - // first value becomes desired", all others become 0. - for(auto& c : m_atomics) { - c.atomic.store(desired, order); - desired = {}; - } - } - - private: - // Each thread has a different atomic that it operates on. If more than NumLanes threads - // use this, some will use the same atomic. So performance will degrade a bit, but still - // everything will work. - // - // The logic here is a bit tricky. The call should be as fast as possible, so that there - // is minimal to no overhead in determining the correct atomic for the current thread. - // - // 1. A global static counter laneCounter counts continuously up. - // 2. 
Each successive thread will use modulo operation of that counter so it gets an atomic - // assigned in a round-robin fashion. - // 3. This tlsLaneIdx is stored in the thread local data, so it is directly available with - // little overhead. - Atomic& myAtomic() DOCTEST_NOEXCEPT { - static Atomic laneCounter; - DOCTEST_THREAD_LOCAL size_t tlsLaneIdx = - laneCounter++ % DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES; - - return m_atomics[tlsLaneIdx].atomic; - } - }; -#endif // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS - - // this holds both parameters from the command line and runtime data for tests - struct ContextState : ContextOptions, TestRunStats, CurrentTestCaseStats - { - MultiLaneAtomic numAssertsCurrentTest_atomic; - MultiLaneAtomic numAssertsFailedCurrentTest_atomic; - - std::vector> filters = decltype(filters)(9); // 9 different filters - - std::vector reporters_currently_used; - - assert_handler ah = nullptr; - - Timer timer; - - std::vector stringifiedContexts; // logging from INFO() due to an exception - - // stuff for subcases - bool reachedLeaf; - std::vector subcaseStack; - std::vector nextSubcaseStack; - std::unordered_set fullyTraversedSubcases; - size_t currentSubcaseDepth; - Atomic shouldLogCurrentException; - - void resetRunData() { - numTestCases = 0; - numTestCasesPassingFilters = 0; - numTestSuitesPassingFilters = 0; - numTestCasesFailed = 0; - numAsserts = 0; - numAssertsFailed = 0; - numAssertsCurrentTest = 0; - numAssertsFailedCurrentTest = 0; - } - - void finalizeTestCaseData() { - seconds = timer.getElapsedSeconds(); - - // update the non-atomic counters - numAsserts += numAssertsCurrentTest_atomic; - numAssertsFailed += numAssertsFailedCurrentTest_atomic; - numAssertsCurrentTest = numAssertsCurrentTest_atomic; - numAssertsFailedCurrentTest = numAssertsFailedCurrentTest_atomic; - - if(numAssertsFailedCurrentTest) - failure_flags |= TestCaseFailureReason::AssertFailure; - - if(Approx(currentTest->m_timeout).epsilon(DBL_EPSILON) != 0 && - 
Approx(seconds).epsilon(DBL_EPSILON) > currentTest->m_timeout) - failure_flags |= TestCaseFailureReason::Timeout; - - if(currentTest->m_should_fail) { - if(failure_flags) { - failure_flags |= TestCaseFailureReason::ShouldHaveFailedAndDid; - } else { - failure_flags |= TestCaseFailureReason::ShouldHaveFailedButDidnt; - } - } else if(failure_flags && currentTest->m_may_fail) { - failure_flags |= TestCaseFailureReason::CouldHaveFailedAndDid; - } else if(currentTest->m_expected_failures > 0) { - if(numAssertsFailedCurrentTest == currentTest->m_expected_failures) { - failure_flags |= TestCaseFailureReason::FailedExactlyNumTimes; - } else { - failure_flags |= TestCaseFailureReason::DidntFailExactlyNumTimes; - } - } - - bool ok_to_fail = (TestCaseFailureReason::ShouldHaveFailedAndDid & failure_flags) || - (TestCaseFailureReason::CouldHaveFailedAndDid & failure_flags) || - (TestCaseFailureReason::FailedExactlyNumTimes & failure_flags); - - // if any subcase has failed - the whole test case has failed - testCaseSuccess = !(failure_flags && !ok_to_fail); - if(!testCaseSuccess) - numTestCasesFailed++; - } - }; - - ContextState* g_cs = nullptr; - - // used to avoid locks for the debug output - // TODO: figure out if this is indeed necessary/correct - seems like either there still - // could be a race or that there wouldn't be a race even if using the context directly - DOCTEST_THREAD_LOCAL bool g_no_colors; - -#endif // DOCTEST_CONFIG_DISABLE -} // namespace detail - -char* String::allocate(size_type sz) { - if (sz <= last) { - buf[sz] = '\0'; - setLast(last - sz); - return buf; - } else { - setOnHeap(); - data.size = sz; - data.capacity = data.size + 1; - data.ptr = new char[data.capacity]; - data.ptr[sz] = '\0'; - return data.ptr; - } -} - -void String::setOnHeap() noexcept { *reinterpret_cast(&buf[last]) = 128; } -void String::setLast(size_type in) noexcept { buf[last] = char(in); } -void String::setSize(size_type sz) noexcept { - if (isOnStack()) { buf[sz] = '\0'; 
setLast(last - sz); } - else { data.ptr[sz] = '\0'; data.size = sz; } -} - -void String::copy(const String& other) { - if(other.isOnStack()) { - memcpy(buf, other.buf, len); - } else { - memcpy(allocate(other.data.size), other.data.ptr, other.data.size); - } -} - -String::String() noexcept { - buf[0] = '\0'; - setLast(); -} - -String::~String() { - if(!isOnStack()) - delete[] data.ptr; -} // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) - -String::String(const char* in) - : String(in, strlen(in)) {} - -String::String(const char* in, size_type in_size) { - memcpy(allocate(in_size), in, in_size); -} - -String::String(std::istream& in, size_type in_size) { - in.read(allocate(in_size), in_size); -} - -String::String(const String& other) { copy(other); } - -String& String::operator=(const String& other) { - if(this != &other) { - if(!isOnStack()) - delete[] data.ptr; - - copy(other); - } - - return *this; -} - -String& String::operator+=(const String& other) { - const size_type my_old_size = size(); - const size_type other_size = other.size(); - const size_type total_size = my_old_size + other_size; - if(isOnStack()) { - if(total_size < len) { - // append to the current stack space - memcpy(buf + my_old_size, other.c_str(), other_size + 1); - // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) - setLast(last - total_size); - } else { - // alloc new chunk - char* temp = new char[total_size + 1]; - // copy current data to new location before writing in the union - memcpy(temp, buf, my_old_size); // skip the +1 ('\0') for speed - // update data in union - setOnHeap(); - data.size = total_size; - data.capacity = data.size + 1; - data.ptr = temp; - // transfer the rest of the data - memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); - } - } else { - if(data.capacity > total_size) { - // append to the current heap block - data.size = total_size; - memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); - } else { - // resize - data.capacity *= 2; - 
if(data.capacity <= total_size) - data.capacity = total_size + 1; - // alloc new chunk - char* temp = new char[data.capacity]; - // copy current data to new location before releasing it - memcpy(temp, data.ptr, my_old_size); // skip the +1 ('\0') for speed - // release old chunk - delete[] data.ptr; - // update the rest of the union members - data.size = total_size; - data.ptr = temp; - // transfer the rest of the data - memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); - } - } - - return *this; -} - -String::String(String&& other) noexcept { - memcpy(buf, other.buf, len); - other.buf[0] = '\0'; - other.setLast(); -} - -String& String::operator=(String&& other) noexcept { - if(this != &other) { - if(!isOnStack()) - delete[] data.ptr; - memcpy(buf, other.buf, len); - other.buf[0] = '\0'; - other.setLast(); - } - return *this; -} - -char String::operator[](size_type i) const { - return const_cast(this)->operator[](i); -} - -char& String::operator[](size_type i) { - if(isOnStack()) - return reinterpret_cast(buf)[i]; - return data.ptr[i]; -} - -DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmaybe-uninitialized") -String::size_type String::size() const { - if(isOnStack()) - return last - (size_type(buf[last]) & 31); // using "last" would work only if "len" is 32 - return data.size; -} -DOCTEST_GCC_SUPPRESS_WARNING_POP - -String::size_type String::capacity() const { - if(isOnStack()) - return len; - return data.capacity; -} - -String String::substr(size_type pos, size_type cnt) && { - cnt = std::min(cnt, size() - 1 - pos); - char* cptr = c_str(); - memmove(cptr, cptr + pos, cnt); - setSize(cnt); - return std::move(*this); -} - -String String::substr(size_type pos, size_type cnt) const & { - cnt = std::min(cnt, size() - 1 - pos); - return String{ c_str() + pos, cnt }; -} - -String::size_type String::find(char ch, size_type pos) const { - const char* begin = c_str(); - const char* end = begin + size(); - const char* it = begin + pos; - for (; it < end && *it != ch; 
it++); - if (it < end) { return static_cast(it - begin); } - else { return npos; } -} - -String::size_type String::rfind(char ch, size_type pos) const { - const char* begin = c_str(); - const char* it = begin + std::min(pos, size() - 1); - for (; it >= begin && *it != ch; it--); - if (it >= begin) { return static_cast(it - begin); } - else { return npos; } -} - -int String::compare(const char* other, bool no_case) const { - if(no_case) - return doctest::stricmp(c_str(), other); - return std::strcmp(c_str(), other); -} - -int String::compare(const String& other, bool no_case) const { - return compare(other.c_str(), no_case); -} - -String operator+(const String& lhs, const String& rhs) { return String(lhs) += rhs; } - -bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } -bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } -bool operator< (const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } -bool operator> (const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } -bool operator<=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) < 0 : true; } -bool operator>=(const String& lhs, const String& rhs) { return (lhs != rhs) ? 
lhs.compare(rhs) > 0 : true; } - -std::ostream& operator<<(std::ostream& s, const String& in) { return s << in.c_str(); } - -Contains::Contains(const String& str) : string(str) { } - -bool Contains::checkWith(const String& other) const { - return strstr(other.c_str(), string.c_str()) != nullptr; -} - -String toString(const Contains& in) { - return "Contains( " + in.string + " )"; -} - -bool operator==(const String& lhs, const Contains& rhs) { return rhs.checkWith(lhs); } -bool operator==(const Contains& lhs, const String& rhs) { return lhs.checkWith(rhs); } -bool operator!=(const String& lhs, const Contains& rhs) { return !rhs.checkWith(lhs); } -bool operator!=(const Contains& lhs, const String& rhs) { return !lhs.checkWith(rhs); } - -namespace { - void color_to_stream(std::ostream&, Color::Enum) DOCTEST_BRANCH_ON_DISABLED({}, ;) -} // namespace - -namespace Color { - std::ostream& operator<<(std::ostream& s, Color::Enum code) { - color_to_stream(s, code); - return s; - } -} // namespace Color - -// clang-format off -const char* assertString(assertType::Enum at) { - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4061) // enum 'x' in switch of enum 'y' is not explicitely handled - #define DOCTEST_GENERATE_ASSERT_TYPE_CASE(assert_type) case assertType::DT_ ## assert_type: return #assert_type - #define DOCTEST_GENERATE_ASSERT_TYPE_CASES(assert_type) \ - DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN_ ## assert_type); \ - DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK_ ## assert_type); \ - DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE_ ## assert_type) - switch(at) { - DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN); - DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK); - DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(FALSE); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_AS); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH_AS); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(NOTHROW); 
- - DOCTEST_GENERATE_ASSERT_TYPE_CASES(EQ); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(NE); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(GT); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(LT); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(GE); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(LE); - - DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY); - DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY_FALSE); - - default: DOCTEST_INTERNAL_ERROR("Tried stringifying invalid assert type!"); - } - DOCTEST_MSVC_SUPPRESS_WARNING_POP -} -// clang-format on - -const char* failureString(assertType::Enum at) { - if(at & assertType::is_warn) //!OCLINT bitwise operator in conditional - return "WARNING"; - if(at & assertType::is_check) //!OCLINT bitwise operator in conditional - return "ERROR"; - if(at & assertType::is_require) //!OCLINT bitwise operator in conditional - return "FATAL ERROR"; - return ""; -} - -DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") -DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") -// depending on the current options this will remove the path of filenames -const char* skipPathFromFilename(const char* file) { -#ifndef DOCTEST_CONFIG_DISABLE - if(getContextOptions()->no_path_in_filenames) { - auto back = std::strrchr(file, '\\'); - auto forward = std::strrchr(file, '/'); - if(back || forward) { - if(back > forward) - forward = back; - return forward + 1; - } - } -#endif // DOCTEST_CONFIG_DISABLE - return file; -} -DOCTEST_CLANG_SUPPRESS_WARNING_POP -DOCTEST_GCC_SUPPRESS_WARNING_POP - -bool SubcaseSignature::operator==(const SubcaseSignature& other) const { - return m_line == other.m_line - && std::strcmp(m_file, other.m_file) == 0 - && m_name == other.m_name; -} - -bool SubcaseSignature::operator<(const SubcaseSignature& other) const { - if(m_line != other.m_line) - return m_line < other.m_line; - if(std::strcmp(m_file, other.m_file) != 0) - return std::strcmp(m_file, other.m_file) < 0; - return m_name.compare(other.m_name) < 0; -} - -DOCTEST_DEFINE_INTERFACE(IContextScope) - -namespace 
detail { - void filldata::fill(std::ostream* stream, const void* in) { - if (in) { *stream << in; } - else { *stream << "nullptr"; } - } - - template - String toStreamLit(T t) { - std::ostream* os = tlssPush(); - os->operator<<(t); - return tlssPop(); - } -} - -#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING -String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } -#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING - -#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) -// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 -String toString(const std::string& in) { return in.c_str(); } -#endif // VS 2019 - -String toString(String in) { return in; } - -String toString(std::nullptr_t) { return "nullptr"; } - -String toString(bool in) { return in ? "true" : "false"; } - -String toString(float in) { return toStreamLit(in); } -String toString(double in) { return toStreamLit(in); } -String toString(double long in) { return toStreamLit(in); } - -String toString(char in) { return toStreamLit(static_cast(in)); } -String toString(char signed in) { return toStreamLit(static_cast(in)); } -String toString(char unsigned in) { return toStreamLit(static_cast(in)); } -String toString(short in) { return toStreamLit(in); } -String toString(short unsigned in) { return toStreamLit(in); } -String toString(signed in) { return toStreamLit(in); } -String toString(unsigned in) { return toStreamLit(in); } -String toString(long in) { return toStreamLit(in); } -String toString(long unsigned in) { return toStreamLit(in); } -String toString(long long in) { return toStreamLit(in); } -String toString(long long unsigned in) { return toStreamLit(in); } - -Approx::Approx(double value) - : m_epsilon(static_cast(std::numeric_limits::epsilon()) * 100) - , m_scale(1.0) - , m_value(value) {} - -Approx Approx::operator()(double value) const { - Approx approx(value); - approx.epsilon(m_epsilon); - approx.scale(m_scale); - return approx; -} - 
-Approx& Approx::epsilon(double newEpsilon) { - m_epsilon = newEpsilon; - return *this; -} -Approx& Approx::scale(double newScale) { - m_scale = newScale; - return *this; -} - -bool operator==(double lhs, const Approx& rhs) { - // Thanks to Richard Harris for his help refining this formula - return std::fabs(lhs - rhs.m_value) < - rhs.m_epsilon * (rhs.m_scale + std::max(std::fabs(lhs), std::fabs(rhs.m_value))); -} -bool operator==(const Approx& lhs, double rhs) { return operator==(rhs, lhs); } -bool operator!=(double lhs, const Approx& rhs) { return !operator==(lhs, rhs); } -bool operator!=(const Approx& lhs, double rhs) { return !operator==(rhs, lhs); } -bool operator<=(double lhs, const Approx& rhs) { return lhs < rhs.m_value || lhs == rhs; } -bool operator<=(const Approx& lhs, double rhs) { return lhs.m_value < rhs || lhs == rhs; } -bool operator>=(double lhs, const Approx& rhs) { return lhs > rhs.m_value || lhs == rhs; } -bool operator>=(const Approx& lhs, double rhs) { return lhs.m_value > rhs || lhs == rhs; } -bool operator<(double lhs, const Approx& rhs) { return lhs < rhs.m_value && lhs != rhs; } -bool operator<(const Approx& lhs, double rhs) { return lhs.m_value < rhs && lhs != rhs; } -bool operator>(double lhs, const Approx& rhs) { return lhs > rhs.m_value && lhs != rhs; } -bool operator>(const Approx& lhs, double rhs) { return lhs.m_value > rhs && lhs != rhs; } - -String toString(const Approx& in) { - return "Approx( " + doctest::toString(in.m_value) + " )"; -} -const ContextOptions* getContextOptions() { return DOCTEST_BRANCH_ON_DISABLED(nullptr, g_cs); } - -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4738) -template -IsNaN::operator bool() const { - return std::isnan(value) ^ flipped; -} -DOCTEST_MSVC_SUPPRESS_WARNING_POP -template struct DOCTEST_INTERFACE_DEF IsNaN; -template struct DOCTEST_INTERFACE_DEF IsNaN; -template struct DOCTEST_INTERFACE_DEF IsNaN; -template -String toString(IsNaN in) { return String(in.flipped ? "! 
" : "") + "IsNaN( " + doctest::toString(in.value) + " )"; } -String toString(IsNaN in) { return toString(in); } -String toString(IsNaN in) { return toString(in); } -String toString(IsNaN in) { return toString(in); } - -} // namespace doctest - -#ifdef DOCTEST_CONFIG_DISABLE -namespace doctest { -Context::Context(int, const char* const*) {} -Context::~Context() = default; -void Context::applyCommandLine(int, const char* const*) {} -void Context::addFilter(const char*, const char*) {} -void Context::clearFilters() {} -void Context::setOption(const char*, bool) {} -void Context::setOption(const char*, int) {} -void Context::setOption(const char*, const char*) {} -bool Context::shouldExit() { return false; } -void Context::setAsDefaultForAssertsOutOfTestCases() {} -void Context::setAssertHandler(detail::assert_handler) {} -void Context::setCout(std::ostream*) {} -int Context::run() { return 0; } - -int IReporter::get_num_active_contexts() { return 0; } -const IContextScope* const* IReporter::get_active_contexts() { return nullptr; } -int IReporter::get_num_stringified_contexts() { return 0; } -const String* IReporter::get_stringified_contexts() { return nullptr; } - -int registerReporter(const char*, int, IReporter*) { return 0; } - -} // namespace doctest -#else // DOCTEST_CONFIG_DISABLE - -#if !defined(DOCTEST_CONFIG_COLORS_NONE) -#if !defined(DOCTEST_CONFIG_COLORS_WINDOWS) && !defined(DOCTEST_CONFIG_COLORS_ANSI) -#ifdef DOCTEST_PLATFORM_WINDOWS -#define DOCTEST_CONFIG_COLORS_WINDOWS -#else // linux -#define DOCTEST_CONFIG_COLORS_ANSI -#endif // platform -#endif // DOCTEST_CONFIG_COLORS_WINDOWS && DOCTEST_CONFIG_COLORS_ANSI -#endif // DOCTEST_CONFIG_COLORS_NONE - -namespace doctest_detail_test_suite_ns { -// holds the current test suite -doctest::detail::TestSuite& getCurrentTestSuite() { - static doctest::detail::TestSuite data{}; - return data; -} -} // namespace doctest_detail_test_suite_ns - -namespace doctest { -namespace { - // the int (priority) is part of the 
key for automatic sorting - sadly one can register a - // reporter with a duplicate name and a different priority but hopefully that won't happen often :| - using reporterMap = std::map, reporterCreatorFunc>; - - reporterMap& getReporters() { - static reporterMap data; - return data; - } - reporterMap& getListeners() { - static reporterMap data; - return data; - } -} // namespace -namespace detail { -#define DOCTEST_ITERATE_THROUGH_REPORTERS(function, ...) \ - for(auto& curr_rep : g_cs->reporters_currently_used) \ - curr_rep->function(__VA_ARGS__) - - bool checkIfShouldThrow(assertType::Enum at) { - if(at & assertType::is_require) //!OCLINT bitwise operator in conditional - return true; - - if((at & assertType::is_check) //!OCLINT bitwise operator in conditional - && getContextOptions()->abort_after > 0 && - (g_cs->numAssertsFailed + g_cs->numAssertsFailedCurrentTest_atomic) >= - getContextOptions()->abort_after) - return true; - - return false; - } - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - DOCTEST_NORETURN void throwException() { - g_cs->shouldLogCurrentException = false; - throw TestFailureException(); // NOLINT(hicpp-exception-baseclass) - } -#else // DOCTEST_CONFIG_NO_EXCEPTIONS - void throwException() {} -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS -} // namespace detail - -namespace { - using namespace detail; - // matching of a string against a wildcard mask (case sensitivity configurable) taken from - // https://www.codeproject.com/Articles/1088/Wildcard-string-compare-globbing - int wildcmp(const char* str, const char* wild, bool caseSensitive) { - const char* cp = str; - const char* mp = wild; - - while((*str) && (*wild != '*')) { - if((caseSensitive ? (*wild != *str) : (tolower(*wild) != tolower(*str))) && - (*wild != '?')) { - return 0; - } - wild++; - str++; - } - - while(*str) { - if(*wild == '*') { - if(!*++wild) { - return 1; - } - mp = wild; - cp = str + 1; - } else if((caseSensitive ? 
(*wild == *str) : (tolower(*wild) == tolower(*str))) || - (*wild == '?')) { - wild++; - str++; - } else { - wild = mp; //!OCLINT parameter reassignment - str = cp++; //!OCLINT parameter reassignment - } - } - - while(*wild == '*') { - wild++; - } - return !*wild; - } - - // checks if the name matches any of the filters (and can be configured what to do when empty) - bool matchesAny(const char* name, const std::vector& filters, bool matchEmpty, - bool caseSensitive) { - if (filters.empty() && matchEmpty) - return true; - for (auto& curr : filters) - if (wildcmp(name, curr.c_str(), caseSensitive)) - return true; - return false; - } - - unsigned long long hash(unsigned long long a, unsigned long long b) { - return (a << 5) + b; - } - - // C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html - unsigned long long hash(const char* str) { - unsigned long long hash = 5381; - char c; - while ((c = *str++)) - hash = ((hash << 5) + hash) + c; // hash * 33 + c - return hash; - } - - unsigned long long hash(const SubcaseSignature& sig) { - return hash(hash(hash(sig.m_file), hash(sig.m_name.c_str())), sig.m_line); - } - - unsigned long long hash(const std::vector& sigs, size_t count) { - unsigned long long running = 0; - auto end = sigs.begin() + count; - for (auto it = sigs.begin(); it != end; it++) { - running = hash(running, hash(*it)); - } - return running; - } - - unsigned long long hash(const std::vector& sigs) { - unsigned long long running = 0; - for (const SubcaseSignature& sig : sigs) { - running = hash(running, hash(sig)); - } - return running; - } -} // namespace -namespace detail { - bool Subcase::checkFilters() { - if (g_cs->subcaseStack.size() < size_t(g_cs->subcase_filter_levels)) { - if (!matchesAny(m_signature.m_name.c_str(), g_cs->filters[6], true, g_cs->case_sensitive)) - return true; - if (matchesAny(m_signature.m_name.c_str(), g_cs->filters[7], false, g_cs->case_sensitive)) - return true; - } - return false; - } - - 
Subcase::Subcase(const String& name, const char* file, int line) - : m_signature({name, file, line}) { - if (!g_cs->reachedLeaf) { - if (g_cs->nextSubcaseStack.size() <= g_cs->subcaseStack.size() - || g_cs->nextSubcaseStack[g_cs->subcaseStack.size()] == m_signature) { - // Going down. - if (checkFilters()) { return; } - - g_cs->subcaseStack.push_back(m_signature); - g_cs->currentSubcaseDepth++; - m_entered = true; - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); - } - } else { - if (g_cs->subcaseStack[g_cs->currentSubcaseDepth] == m_signature) { - // This subcase is reentered via control flow. - g_cs->currentSubcaseDepth++; - m_entered = true; - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); - } else if (g_cs->nextSubcaseStack.size() <= g_cs->currentSubcaseDepth - && g_cs->fullyTraversedSubcases.find(hash(hash(g_cs->subcaseStack, g_cs->currentSubcaseDepth), hash(m_signature))) - == g_cs->fullyTraversedSubcases.end()) { - if (checkFilters()) { return; } - // This subcase is part of the one to be executed next. - g_cs->nextSubcaseStack.clear(); - g_cs->nextSubcaseStack.insert(g_cs->nextSubcaseStack.end(), - g_cs->subcaseStack.begin(), g_cs->subcaseStack.begin() + g_cs->currentSubcaseDepth); - g_cs->nextSubcaseStack.push_back(m_signature); - } - } - } - - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - - Subcase::~Subcase() { - if (m_entered) { - g_cs->currentSubcaseDepth--; - - if (!g_cs->reachedLeaf) { - // Leaf. - g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); - g_cs->nextSubcaseStack.clear(); - g_cs->reachedLeaf = true; - } else if (g_cs->nextSubcaseStack.empty()) { - // All children are finished. 
- g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); - } - -#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) - if(std::uncaught_exceptions() > 0 -#else - if(std::uncaught_exception() -#endif - && g_cs->shouldLogCurrentException) { - DOCTEST_ITERATE_THROUGH_REPORTERS( - test_case_exception, {"exception thrown in subcase - will translate later " - "when the whole test case has been exited (cannot " - "translate while there is an active exception)", - false}); - g_cs->shouldLogCurrentException = false; - } - - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); - } - } - - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP - DOCTEST_MSVC_SUPPRESS_WARNING_POP - - Subcase::operator bool() const { return m_entered; } - - Result::Result(bool passed, const String& decomposition) - : m_passed(passed) - , m_decomp(decomposition) {} - - ExpressionDecomposer::ExpressionDecomposer(assertType::Enum at) - : m_at(at) {} - - TestSuite& TestSuite::operator*(const char* in) { - m_test_suite = in; - return *this; - } - - TestCase::TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, - const String& type, int template_id) { - m_file = file; - m_line = line; - m_name = nullptr; // will be later overridden in operator* - m_test_suite = test_suite.m_test_suite; - m_description = test_suite.m_description; - m_skip = test_suite.m_skip; - m_no_breaks = test_suite.m_no_breaks; - m_no_output = test_suite.m_no_output; - m_may_fail = test_suite.m_may_fail; - m_should_fail = test_suite.m_should_fail; - m_expected_failures = test_suite.m_expected_failures; - m_timeout = test_suite.m_timeout; - - m_test = test; - m_type = type; - m_template_id = template_id; - } - - TestCase::TestCase(const TestCase& other) - : TestCaseData() { - *this = other; - } - - 
DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function - TestCase& TestCase::operator=(const TestCase& other) { - TestCaseData::operator=(other); - m_test = other.m_test; - m_type = other.m_type; - m_template_id = other.m_template_id; - m_full_name = other.m_full_name; - - if(m_template_id != -1) - m_name = m_full_name.c_str(); - return *this; - } - DOCTEST_MSVC_SUPPRESS_WARNING_POP - - TestCase& TestCase::operator*(const char* in) { - m_name = in; - // make a new name with an appended type for templated test case - if(m_template_id != -1) { - m_full_name = String(m_name) + "<" + m_type + ">"; - // redirect the name to point to the newly constructed full name - m_name = m_full_name.c_str(); - } - return *this; - } - - bool TestCase::operator<(const TestCase& other) const { - // this will be used only to differentiate between test cases - not relevant for sorting - if(m_line != other.m_line) - return m_line < other.m_line; - const int name_cmp = strcmp(m_name, other.m_name); - if(name_cmp != 0) - return name_cmp < 0; - const int file_cmp = m_file.compare(other.m_file); - if(file_cmp != 0) - return file_cmp < 0; - return m_template_id < other.m_template_id; - } - - // all the registered tests - std::set& getRegisteredTests() { - static std::set data; - return data; - } -} // namespace detail -namespace { - using namespace detail; - // for sorting tests by file/line - bool fileOrderComparator(const TestCase* lhs, const TestCase* rhs) { - // this is needed because MSVC gives different case for drive letters - // for __FILE__ when evaluated in a header and a source file - const int res = lhs->m_file.compare(rhs->m_file, bool(DOCTEST_MSVC)); - if(res != 0) - return res < 0; - if(lhs->m_line != rhs->m_line) - return lhs->m_line < rhs->m_line; - return lhs->m_template_id < rhs->m_template_id; - } - - // for sorting tests by suite/file/line - bool suiteOrderComparator(const TestCase* lhs, const TestCase* rhs) { - const int res = 
std::strcmp(lhs->m_test_suite, rhs->m_test_suite); - if(res != 0) - return res < 0; - return fileOrderComparator(lhs, rhs); - } - - // for sorting tests by name/suite/file/line - bool nameOrderComparator(const TestCase* lhs, const TestCase* rhs) { - const int res = std::strcmp(lhs->m_name, rhs->m_name); - if(res != 0) - return res < 0; - return suiteOrderComparator(lhs, rhs); - } - - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - void color_to_stream(std::ostream& s, Color::Enum code) { - static_cast(s); // for DOCTEST_CONFIG_COLORS_NONE or DOCTEST_CONFIG_COLORS_WINDOWS - static_cast(code); // for DOCTEST_CONFIG_COLORS_NONE -#ifdef DOCTEST_CONFIG_COLORS_ANSI - if(g_no_colors || - (isatty(STDOUT_FILENO) == false && getContextOptions()->force_colors == false)) - return; - - auto col = ""; - // clang-format off - switch(code) { //!OCLINT missing break in switch statement / unnecessary default statement in covered switch statement - case Color::Red: col = "[0;31m"; break; - case Color::Green: col = "[0;32m"; break; - case Color::Blue: col = "[0;34m"; break; - case Color::Cyan: col = "[0;36m"; break; - case Color::Yellow: col = "[0;33m"; break; - case Color::Grey: col = "[1;30m"; break; - case Color::LightGrey: col = "[0;37m"; break; - case Color::BrightRed: col = "[1;31m"; break; - case Color::BrightGreen: col = "[1;32m"; break; - case Color::BrightWhite: col = "[1;37m"; break; - case Color::Bright: // invalid - case Color::None: - case Color::White: - default: col = "[0m"; - } - // clang-format on - s << "\033" << col; -#endif // DOCTEST_CONFIG_COLORS_ANSI - -#ifdef DOCTEST_CONFIG_COLORS_WINDOWS - if(g_no_colors || - (_isatty(_fileno(stdout)) == false && getContextOptions()->force_colors == false)) - return; - - static struct ConsoleHelper { - HANDLE stdoutHandle; - WORD origFgAttrs; - WORD origBgAttrs; - - ConsoleHelper() { - stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); - CONSOLE_SCREEN_BUFFER_INFO csbiInfo; - 
GetConsoleScreenBufferInfo(stdoutHandle, &csbiInfo); - origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | - BACKGROUND_BLUE | BACKGROUND_INTENSITY); - origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | - FOREGROUND_BLUE | FOREGROUND_INTENSITY); - } - } ch; - -#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(ch.stdoutHandle, x | ch.origBgAttrs) - - // clang-format off - switch (code) { - case Color::White: DOCTEST_SET_ATTR(FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; - case Color::Red: DOCTEST_SET_ATTR(FOREGROUND_RED); break; - case Color::Green: DOCTEST_SET_ATTR(FOREGROUND_GREEN); break; - case Color::Blue: DOCTEST_SET_ATTR(FOREGROUND_BLUE); break; - case Color::Cyan: DOCTEST_SET_ATTR(FOREGROUND_BLUE | FOREGROUND_GREEN); break; - case Color::Yellow: DOCTEST_SET_ATTR(FOREGROUND_RED | FOREGROUND_GREEN); break; - case Color::Grey: DOCTEST_SET_ATTR(0); break; - case Color::LightGrey: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY); break; - case Color::BrightRed: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_RED); break; - case Color::BrightGreen: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN); break; - case Color::BrightWhite: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; - case Color::None: - case Color::Bright: // invalid - default: DOCTEST_SET_ATTR(ch.origFgAttrs); - } - // clang-format on -#endif // DOCTEST_CONFIG_COLORS_WINDOWS - } - DOCTEST_CLANG_SUPPRESS_WARNING_POP - - std::vector& getExceptionTranslators() { - static std::vector data; - return data; - } - - String translateActiveException() { -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - String res; - auto& translators = getExceptionTranslators(); - for(auto& curr : translators) - if(curr->translate(res)) - return res; - // clang-format off - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wcatch-value") - try { - throw; - } catch(std::exception& ex) { - return ex.what(); - } catch(std::string& 
msg) { - return msg.c_str(); - } catch(const char* msg) { - return msg; - } catch(...) { - return "unknown exception"; - } - DOCTEST_GCC_SUPPRESS_WARNING_POP -// clang-format on -#else // DOCTEST_CONFIG_NO_EXCEPTIONS - return ""; -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - } -} // namespace - -namespace detail { - // used by the macros for registering tests - int regTest(const TestCase& tc) { - getRegisteredTests().insert(tc); - return 0; - } - - // sets the current test suite - int setTestSuite(const TestSuite& ts) { - doctest_detail_test_suite_ns::getCurrentTestSuite() = ts; - return 0; - } - -#ifdef DOCTEST_IS_DEBUGGER_ACTIVE - bool isDebuggerActive() { return DOCTEST_IS_DEBUGGER_ACTIVE(); } -#else // DOCTEST_IS_DEBUGGER_ACTIVE -#ifdef DOCTEST_PLATFORM_LINUX - class ErrnoGuard { - public: - ErrnoGuard() : m_oldErrno(errno) {} - ~ErrnoGuard() { errno = m_oldErrno; } - private: - int m_oldErrno; - }; - // See the comments in Catch2 for the reasoning behind this implementation: - // https://github.com/catchorg/Catch2/blob/v2.13.1/include/internal/catch_debugger.cpp#L79-L102 - bool isDebuggerActive() { - ErrnoGuard guard; - std::ifstream in("/proc/self/status"); - for(std::string line; std::getline(in, line);) { - static const int PREFIX_LEN = 11; - if(line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0) { - return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; - } - } - return false; - } -#elif defined(DOCTEST_PLATFORM_MAC) - // The following function is taken directly from the following technical note: - // https://developer.apple.com/library/archive/qa/qa1361/_index.html - // Returns true if the current process is being debugged (either - // running under the debugger or has a debugger attached post facto). - bool isDebuggerActive() { - int mib[4]; - kinfo_proc info; - size_t size; - // Initialize the flags so that, if sysctl fails for some bizarre - // reason, we get a predictable result. 
- info.kp_proc.p_flag = 0; - // Initialize mib, which tells sysctl the info we want, in this case - // we're looking for information about a specific process ID. - mib[0] = CTL_KERN; - mib[1] = KERN_PROC; - mib[2] = KERN_PROC_PID; - mib[3] = getpid(); - // Call sysctl. - size = sizeof(info); - if(sysctl(mib, DOCTEST_COUNTOF(mib), &info, &size, 0, 0) != 0) { - std::cerr << "\nCall to sysctl failed - unable to determine if debugger is active **\n"; - return false; - } - // We're being debugged if the P_TRACED flag is set. - return ((info.kp_proc.p_flag & P_TRACED) != 0); - } -#elif DOCTEST_MSVC || defined(__MINGW32__) || defined(__MINGW64__) - bool isDebuggerActive() { return ::IsDebuggerPresent() != 0; } -#else - bool isDebuggerActive() { return false; } -#endif // Platform -#endif // DOCTEST_IS_DEBUGGER_ACTIVE - - void registerExceptionTranslatorImpl(const IExceptionTranslator* et) { - if(std::find(getExceptionTranslators().begin(), getExceptionTranslators().end(), et) == - getExceptionTranslators().end()) - getExceptionTranslators().push_back(et); - } - - DOCTEST_THREAD_LOCAL std::vector g_infoContexts; // for logging with INFO() - - ContextScopeBase::ContextScopeBase() { - g_infoContexts.push_back(this); - } - - ContextScopeBase::ContextScopeBase(ContextScopeBase&& other) noexcept { - if (other.need_to_destroy) { - other.destroy(); - } - other.need_to_destroy = false; - g_infoContexts.push_back(this); - } - - DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 - DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") - - // destroy cannot be inlined into the destructor because that would mean calling stringify after - // ContextScope has been destroyed (base class destructors run after derived class destructors). - // Instead, ContextScope calls this method directly from its destructor. 
- void ContextScopeBase::destroy() { -#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) - if(std::uncaught_exceptions() > 0) { -#else - if(std::uncaught_exception()) { -#endif - std::ostringstream s; - this->stringify(&s); - g_cs->stringifiedContexts.push_back(s.str().c_str()); - } - g_infoContexts.pop_back(); - } - - DOCTEST_CLANG_SUPPRESS_WARNING_POP - DOCTEST_GCC_SUPPRESS_WARNING_POP - DOCTEST_MSVC_SUPPRESS_WARNING_POP -} // namespace detail -namespace { - using namespace detail; - -#if !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && !defined(DOCTEST_CONFIG_WINDOWS_SEH) - struct FatalConditionHandler - { - static void reset() {} - static void allocateAltStackMem() {} - static void freeAltStackMem() {} - }; -#else // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH - - void reportFatal(const std::string&); - -#ifdef DOCTEST_PLATFORM_WINDOWS - - struct SignalDefs - { - DWORD id; - const char* name; - }; - // There is no 1-1 mapping between signals and windows exceptions. - // Windows can easily distinguish between SO and SigSegV, - // but SigInt, SigTerm, etc are handled differently. - SignalDefs signalDefs[] = { - {static_cast(EXCEPTION_ILLEGAL_INSTRUCTION), - "SIGILL - Illegal instruction signal"}, - {static_cast(EXCEPTION_STACK_OVERFLOW), "SIGSEGV - Stack overflow"}, - {static_cast(EXCEPTION_ACCESS_VIOLATION), - "SIGSEGV - Segmentation violation signal"}, - {static_cast(EXCEPTION_INT_DIVIDE_BY_ZERO), "Divide by zero error"}, - }; - - struct FatalConditionHandler - { - static LONG CALLBACK handleException(PEXCEPTION_POINTERS ExceptionInfo) { - // Multiple threads may enter this filter/handler at once. We want the error message to be printed on the - // console just once no matter how many threads have crashed. 
- DOCTEST_DECLARE_STATIC_MUTEX(mutex) - static bool execute = true; - { - DOCTEST_LOCK_MUTEX(mutex) - if(execute) { - bool reported = false; - for(size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - if(ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) { - reportFatal(signalDefs[i].name); - reported = true; - break; - } - } - if(reported == false) - reportFatal("Unhandled SEH exception caught"); - if(isDebuggerActive() && !g_cs->no_breaks) - DOCTEST_BREAK_INTO_DEBUGGER(); - } - execute = false; - } - std::exit(EXIT_FAILURE); - } - - static void allocateAltStackMem() {} - static void freeAltStackMem() {} - - FatalConditionHandler() { - isSet = true; - // 32k seems enough for doctest to handle stack overflow, - // but the value was found experimentally, so there is no strong guarantee - guaranteeSize = 32 * 1024; - // Register an unhandled exception filter - previousTop = SetUnhandledExceptionFilter(handleException); - // Pass in guarantee size to be filled - SetThreadStackGuarantee(&guaranteeSize); - - // On Windows uncaught exceptions from another thread, exceptions from - // destructors, or calls to std::terminate are not a SEH exception - - // The terminal handler gets called when: - // - std::terminate is called FROM THE TEST RUNNER THREAD - // - an exception is thrown from a destructor FROM THE TEST RUNNER THREAD - original_terminate_handler = std::get_terminate(); - std::set_terminate([]() DOCTEST_NOEXCEPT { - reportFatal("Terminate handler called"); - if(isDebuggerActive() && !g_cs->no_breaks) - DOCTEST_BREAK_INTO_DEBUGGER(); - std::exit(EXIT_FAILURE); // explicitly exit - otherwise the SIGABRT handler may be called as well - }); - - // SIGABRT is raised when: - // - std::terminate is called FROM A DIFFERENT THREAD - // - an exception is thrown from a destructor FROM A DIFFERENT THREAD - // - an uncaught exception is thrown FROM A DIFFERENT THREAD - prev_sigabrt_handler = std::signal(SIGABRT, [](int signal) DOCTEST_NOEXCEPT { - if(signal 
== SIGABRT) { - reportFatal("SIGABRT - Abort (abnormal termination) signal"); - if(isDebuggerActive() && !g_cs->no_breaks) - DOCTEST_BREAK_INTO_DEBUGGER(); - std::exit(EXIT_FAILURE); - } - }); - - // The following settings are taken from google test, and more - // specifically from UnitTest::Run() inside of gtest.cc - - // the user does not want to see pop-up dialogs about crashes - prev_error_mode_1 = SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | - SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); - // This forces the abort message to go to stderr in all circumstances. - prev_error_mode_2 = _set_error_mode(_OUT_TO_STDERR); - // In the debug version, Visual Studio pops up a separate dialog - // offering a choice to debug the aborted program - we want to disable that. - prev_abort_behavior = _set_abort_behavior(0x0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); - // In debug mode, the Windows CRT can crash with an assertion over invalid - // input (e.g. passing an invalid file descriptor). The default handling - // for these assertions is to pop up a dialog and wait for user input. - // Instead ask the CRT to dump such assertions to stderr non-interactively. 
- prev_report_mode = _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); - prev_report_file = _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); - } - - static void reset() { - if(isSet) { - // Unregister handler and restore the old guarantee - SetUnhandledExceptionFilter(previousTop); - SetThreadStackGuarantee(&guaranteeSize); - std::set_terminate(original_terminate_handler); - std::signal(SIGABRT, prev_sigabrt_handler); - SetErrorMode(prev_error_mode_1); - _set_error_mode(prev_error_mode_2); - _set_abort_behavior(prev_abort_behavior, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); - static_cast(_CrtSetReportMode(_CRT_ASSERT, prev_report_mode)); - static_cast(_CrtSetReportFile(_CRT_ASSERT, prev_report_file)); - isSet = false; - } - } - - ~FatalConditionHandler() { reset(); } - - private: - static UINT prev_error_mode_1; - static int prev_error_mode_2; - static unsigned int prev_abort_behavior; - static int prev_report_mode; - static _HFILE prev_report_file; - static void (DOCTEST_CDECL *prev_sigabrt_handler)(int); - static std::terminate_handler original_terminate_handler; - static bool isSet; - static ULONG guaranteeSize; - static LPTOP_LEVEL_EXCEPTION_FILTER previousTop; - }; - - UINT FatalConditionHandler::prev_error_mode_1; - int FatalConditionHandler::prev_error_mode_2; - unsigned int FatalConditionHandler::prev_abort_behavior; - int FatalConditionHandler::prev_report_mode; - _HFILE FatalConditionHandler::prev_report_file; - void (DOCTEST_CDECL *FatalConditionHandler::prev_sigabrt_handler)(int); - std::terminate_handler FatalConditionHandler::original_terminate_handler; - bool FatalConditionHandler::isSet = false; - ULONG FatalConditionHandler::guaranteeSize = 0; - LPTOP_LEVEL_EXCEPTION_FILTER FatalConditionHandler::previousTop = nullptr; - -#else // DOCTEST_PLATFORM_WINDOWS - - struct SignalDefs - { - int id; - const char* name; - }; - SignalDefs signalDefs[] = {{SIGINT, "SIGINT - Terminal interrupt signal"}, - {SIGILL, "SIGILL - Illegal 
instruction signal"}, - {SIGFPE, "SIGFPE - Floating point error signal"}, - {SIGSEGV, "SIGSEGV - Segmentation violation signal"}, - {SIGTERM, "SIGTERM - Termination request signal"}, - {SIGABRT, "SIGABRT - Abort (abnormal termination) signal"}}; - - struct FatalConditionHandler - { - static bool isSet; - static struct sigaction oldSigActions[DOCTEST_COUNTOF(signalDefs)]; - static stack_t oldSigStack; - static size_t altStackSize; - static char* altStackMem; - - static void handleSignal(int sig) { - const char* name = ""; - for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - SignalDefs& def = signalDefs[i]; - if(sig == def.id) { - name = def.name; - break; - } - } - reset(); - reportFatal(name); - raise(sig); - } - - static void allocateAltStackMem() { - altStackMem = new char[altStackSize]; - } - - static void freeAltStackMem() { - delete[] altStackMem; - } - - FatalConditionHandler() { - isSet = true; - stack_t sigStack; - sigStack.ss_sp = altStackMem; - sigStack.ss_size = altStackSize; - sigStack.ss_flags = 0; - sigaltstack(&sigStack, &oldSigStack); - struct sigaction sa = {}; - sa.sa_handler = handleSignal; - sa.sa_flags = SA_ONSTACK; - for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); - } - } - - ~FatalConditionHandler() { reset(); } - static void reset() { - if(isSet) { - // Set signals back to previous values -- hopefully nobody overwrote them in the meantime - for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { - sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); - } - // Return the old stack - sigaltstack(&oldSigStack, nullptr); - isSet = false; - } - } - }; - - bool FatalConditionHandler::isSet = false; - struct sigaction FatalConditionHandler::oldSigActions[DOCTEST_COUNTOF(signalDefs)] = {}; - stack_t FatalConditionHandler::oldSigStack = {}; - size_t FatalConditionHandler::altStackSize = 4 * SIGSTKSZ; - char* FatalConditionHandler::altStackMem = nullptr; - 
-#endif // DOCTEST_PLATFORM_WINDOWS -#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH - -} // namespace - -namespace { - using namespace detail; - -#ifdef DOCTEST_PLATFORM_WINDOWS -#define DOCTEST_OUTPUT_DEBUG_STRING(text) ::OutputDebugStringA(text) -#else - // TODO: integration with XCode and other IDEs -#define DOCTEST_OUTPUT_DEBUG_STRING(text) -#endif // Platform - - void addAssert(assertType::Enum at) { - if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional - g_cs->numAssertsCurrentTest_atomic++; - } - - void addFailedAssert(assertType::Enum at) { - if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional - g_cs->numAssertsFailedCurrentTest_atomic++; - } - -#if defined(DOCTEST_CONFIG_POSIX_SIGNALS) || defined(DOCTEST_CONFIG_WINDOWS_SEH) - void reportFatal(const std::string& message) { - g_cs->failure_flags |= TestCaseFailureReason::Crash; - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, {message.c_str(), true}); - - while (g_cs->subcaseStack.size()) { - g_cs->subcaseStack.pop_back(); - DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); - } - - g_cs->finalizeTestCaseData(); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); - } -#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH -} // namespace - -AssertData::AssertData(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const StringContains& exception_string) - : m_test_case(g_cs->currentTest), m_at(at), m_file(file), m_line(line), m_expr(expr), - m_failed(true), m_threw(false), m_threw_as(false), m_exception_type(exception_type), - m_exception_string(exception_string) { -#if DOCTEST_MSVC - if (m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC - ++m_expr; -#endif // MSVC -} - -namespace detail { - ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, 
int line, const char* expr, - const char* exception_type, const String& exception_string) - : AssertData(at, file, line, expr, exception_type, exception_string) { } - - ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, - const char* exception_type, const Contains& exception_string) - : AssertData(at, file, line, expr, exception_type, exception_string) { } - - void ResultBuilder::setResult(const Result& res) { - m_decomp = res.m_decomp; - m_failed = !res.m_passed; - } - - void ResultBuilder::translateException() { - m_threw = true; - m_exception = translateActiveException(); - } - - bool ResultBuilder::log() { - if(m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional - m_failed = !m_threw; - } else if((m_at & assertType::is_throws_as) && (m_at & assertType::is_throws_with)) { //!OCLINT - m_failed = !m_threw_as || !m_exception_string.check(m_exception); - } else if(m_at & assertType::is_throws_as) { //!OCLINT bitwise operator in conditional - m_failed = !m_threw_as; - } else if(m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional - m_failed = !m_exception_string.check(m_exception); - } else if(m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional - m_failed = m_threw; - } - - if(m_exception.size()) - m_exception = "\"" + m_exception + "\""; - - if(is_running_in_test) { - addAssert(m_at); - DOCTEST_ITERATE_THROUGH_REPORTERS(log_assert, *this); - - if(m_failed) - addFailedAssert(m_at); - } else if(m_failed) { - failed_out_of_a_testing_context(*this); - } - - return m_failed && isDebuggerActive() && !getContextOptions()->no_breaks && - (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger - } - - void ResultBuilder::react() const { - if(m_failed && checkIfShouldThrow(m_at)) - throwException(); - } - - void failed_out_of_a_testing_context(const AssertData& ad) { - if(g_cs->ah) - g_cs->ah(ad); - else - std::abort(); - } - - 
bool decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, - const Result& result) { - bool failed = !result.m_passed; - - // ################################################################################### - // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT - // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED - // ################################################################################### - DOCTEST_ASSERT_OUT_OF_TESTS(result.m_decomp); - DOCTEST_ASSERT_IN_TESTS(result.m_decomp); - return !failed; - } - - MessageBuilder::MessageBuilder(const char* file, int line, assertType::Enum severity) { - m_stream = tlssPush(); - m_file = file; - m_line = line; - m_severity = severity; - } - - MessageBuilder::~MessageBuilder() { - if (!logged) - tlssPop(); - } - - DOCTEST_DEFINE_INTERFACE(IExceptionTranslator) - - bool MessageBuilder::log() { - if (!logged) { - m_string = tlssPop(); - logged = true; - } - - DOCTEST_ITERATE_THROUGH_REPORTERS(log_message, *this); - - const bool isWarn = m_severity & assertType::is_warn; - - // warn is just a message in this context so we don't treat it as an assert - if(!isWarn) { - addAssert(m_severity); - addFailedAssert(m_severity); - } - - return isDebuggerActive() && !getContextOptions()->no_breaks && !isWarn && - (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger - } - - void MessageBuilder::react() { - if(m_severity & assertType::is_require) //!OCLINT bitwise operator in conditional - throwException(); - } -} // namespace detail -namespace { - using namespace detail; - - // clang-format off - -// ================================================================================================= -// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp -// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. 
-// ================================================================================================= - - class XmlEncode { - public: - enum ForWhat { ForTextNodes, ForAttributes }; - - XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); - - void encodeTo( std::ostream& os ) const; - - friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); - - private: - std::string m_str; - ForWhat m_forWhat; - }; - - class XmlWriter { - public: - - class ScopedElement { - public: - ScopedElement( XmlWriter* writer ); - - ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT; - ScopedElement& operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT; - - ~ScopedElement(); - - ScopedElement& writeText( std::string const& text, bool indent = true ); - - template - ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { - m_writer->writeAttribute( name, attribute ); - return *this; - } - - private: - mutable XmlWriter* m_writer = nullptr; - }; - - XmlWriter( std::ostream& os = std::cout ); - ~XmlWriter(); - - XmlWriter( XmlWriter const& ) = delete; - XmlWriter& operator=( XmlWriter const& ) = delete; - - XmlWriter& startElement( std::string const& name ); - - ScopedElement scopedElement( std::string const& name ); - - XmlWriter& endElement(); - - XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); - - XmlWriter& writeAttribute( std::string const& name, const char* attribute ); - - XmlWriter& writeAttribute( std::string const& name, bool attribute ); - - template - XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { - std::stringstream rss; - rss << attribute; - return writeAttribute( name, rss.str() ); - } - - XmlWriter& writeText( std::string const& text, bool indent = true ); - - //XmlWriter& writeComment( std::string const& text ); - - //void writeStylesheetRef( std::string const& url ); - - //XmlWriter& writeBlankLine(); - - void ensureTagClosed(); - - void 
writeDeclaration(); - - private: - - void newlineIfNecessary(); - - bool m_tagIsOpen = false; - bool m_needsNewline = false; - std::vector m_tags; - std::string m_indent; - std::ostream& m_os; - }; - -// ================================================================================================= -// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp -// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. -// ================================================================================================= - -using uchar = unsigned char; - -namespace { - - size_t trailingBytes(unsigned char c) { - if ((c & 0xE0) == 0xC0) { - return 2; - } - if ((c & 0xF0) == 0xE0) { - return 3; - } - if ((c & 0xF8) == 0xF0) { - return 4; - } - DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); - } - - uint32_t headerValue(unsigned char c) { - if ((c & 0xE0) == 0xC0) { - return c & 0x1F; - } - if ((c & 0xF0) == 0xE0) { - return c & 0x0F; - } - if ((c & 0xF8) == 0xF0) { - return c & 0x07; - } - DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); - } - - void hexEscapeChar(std::ostream& os, unsigned char c) { - std::ios_base::fmtflags f(os.flags()); - os << "\\x" - << std::uppercase << std::hex << std::setfill('0') << std::setw(2) - << static_cast(c); - os.flags(f); - } - -} // anonymous namespace - - XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) - : m_str( str ), - m_forWhat( forWhat ) - {} - - void XmlEncode::encodeTo( std::ostream& os ) const { - // Apostrophe escaping not necessary if we always use " to write attributes - // (see: https://www.w3.org/TR/xml/#syntax) - - for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { - uchar c = m_str[idx]; - switch (c) { - case '<': os << "<"; break; - case '&': os << "&"; break; - - case '>': - // See: https://www.w3.org/TR/xml/#syntax - if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 
2] == ']') - os << ">"; - else - os << c; - break; - - case '\"': - if (m_forWhat == ForAttributes) - os << """; - else - os << c; - break; - - default: - // Check for control characters and invalid utf-8 - - // Escape control characters in standard ascii - // see https://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 - if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { - hexEscapeChar(os, c); - break; - } - - // Plain ASCII: Write it to stream - if (c < 0x7F) { - os << c; - break; - } - - // UTF-8 territory - // Check if the encoding is valid and if it is not, hex escape bytes. - // Important: We do not check the exact decoded values for validity, only the encoding format - // First check that this bytes is a valid lead byte: - // This means that it is not encoded as 1111 1XXX - // Or as 10XX XXXX - if (c < 0xC0 || - c >= 0xF8) { - hexEscapeChar(os, c); - break; - } - - auto encBytes = trailingBytes(c); - // Are there enough bytes left to avoid accessing out-of-bounds memory? 
- if (idx + encBytes - 1 >= m_str.size()) { - hexEscapeChar(os, c); - break; - } - // The header is valid, check data - // The next encBytes bytes must together be a valid utf-8 - // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) - bool valid = true; - uint32_t value = headerValue(c); - for (std::size_t n = 1; n < encBytes; ++n) { - uchar nc = m_str[idx + n]; - valid &= ((nc & 0xC0) == 0x80); - value = (value << 6) | (nc & 0x3F); - } - - if ( - // Wrong bit pattern of following bytes - (!valid) || - // Overlong encodings - (value < 0x80) || - ( value < 0x800 && encBytes > 2) || // removed "0x80 <= value &&" because redundant - (0x800 < value && value < 0x10000 && encBytes > 3) || - // Encoded value out of range - (value >= 0x110000) - ) { - hexEscapeChar(os, c); - break; - } - - // If we got here, this is in fact a valid(ish) utf-8 sequence - for (std::size_t n = 0; n < encBytes; ++n) { - os << m_str[idx + n]; - } - idx += encBytes - 1; - break; - } - } - } - - std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { - xmlEncode.encodeTo( os ); - return os; - } - - XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) - : m_writer( writer ) - {} - - XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT - : m_writer( other.m_writer ){ - other.m_writer = nullptr; - } - XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT { - if ( m_writer ) { - m_writer->endElement(); - } - m_writer = other.m_writer; - other.m_writer = nullptr; - return *this; - } - - - XmlWriter::ScopedElement::~ScopedElement() { - if( m_writer ) - m_writer->endElement(); - } - - XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { - m_writer->writeText( text, indent ); - return *this; - } - - XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) - { - // writeDeclaration(); // called explicitly by the reporters that use 
the writer class - see issue #627 - } - - XmlWriter::~XmlWriter() { - while( !m_tags.empty() ) - endElement(); - } - - XmlWriter& XmlWriter::startElement( std::string const& name ) { - ensureTagClosed(); - newlineIfNecessary(); - m_os << m_indent << '<' << name; - m_tags.push_back( name ); - m_indent += " "; - m_tagIsOpen = true; - return *this; - } - - XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { - ScopedElement scoped( this ); - startElement( name ); - return scoped; - } - - XmlWriter& XmlWriter::endElement() { - newlineIfNecessary(); - m_indent = m_indent.substr( 0, m_indent.size()-2 ); - if( m_tagIsOpen ) { - m_os << "/>"; - m_tagIsOpen = false; - } - else { - m_os << m_indent << ""; - } - m_os << std::endl; - m_tags.pop_back(); - return *this; - } - - XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { - if( !name.empty() && !attribute.empty() ) - m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; - return *this; - } - - XmlWriter& XmlWriter::writeAttribute( std::string const& name, const char* attribute ) { - if( !name.empty() && attribute && attribute[0] != '\0' ) - m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; - return *this; - } - - XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { - m_os << ' ' << name << "=\"" << ( attribute ? 
"true" : "false" ) << '"'; - return *this; - } - - XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { - if( !text.empty() ){ - bool tagWasOpen = m_tagIsOpen; - ensureTagClosed(); - if( tagWasOpen && indent ) - m_os << m_indent; - m_os << XmlEncode( text ); - m_needsNewline = true; - } - return *this; - } - - //XmlWriter& XmlWriter::writeComment( std::string const& text ) { - // ensureTagClosed(); - // m_os << m_indent << ""; - // m_needsNewline = true; - // return *this; - //} - - //void XmlWriter::writeStylesheetRef( std::string const& url ) { - // m_os << "\n"; - //} - - //XmlWriter& XmlWriter::writeBlankLine() { - // ensureTagClosed(); - // m_os << '\n'; - // return *this; - //} - - void XmlWriter::ensureTagClosed() { - if( m_tagIsOpen ) { - m_os << ">" << std::endl; - m_tagIsOpen = false; - } - } - - void XmlWriter::writeDeclaration() { - m_os << "\n"; - } - - void XmlWriter::newlineIfNecessary() { - if( m_needsNewline ) { - m_os << std::endl; - m_needsNewline = false; - } - } - -// ================================================================================================= -// End of copy-pasted code from Catch -// ================================================================================================= - - // clang-format on - - struct XmlReporter : public IReporter - { - XmlWriter xml; - DOCTEST_DECLARE_MUTEX(mutex) - - // caching pointers/references to objects of these types - safe to do - const ContextOptions& opt; - const TestCaseData* tc = nullptr; - - XmlReporter(const ContextOptions& co) - : xml(*co.cout) - , opt(co) {} - - void log_contexts() { - int num_contexts = get_num_active_contexts(); - if(num_contexts) { - auto contexts = get_active_contexts(); - std::stringstream ss; - for(int i = 0; i < num_contexts; ++i) { - contexts[i]->stringify(&ss); - xml.scopedElement("Info").writeText(ss.str()); - ss.str(""); - } - } - } - - unsigned line(unsigned l) const { return opt.no_line_numbers ? 
0 : l; } - - void test_case_start_impl(const TestCaseData& in) { - bool open_ts_tag = false; - if(tc != nullptr) { // we have already opened a test suite - if(std::strcmp(tc->m_test_suite, in.m_test_suite) != 0) { - xml.endElement(); - open_ts_tag = true; - } - } - else { - open_ts_tag = true; // first test case ==> first test suite - } - - if(open_ts_tag) { - xml.startElement("TestSuite"); - xml.writeAttribute("name", in.m_test_suite); - } - - tc = ∈ - xml.startElement("TestCase") - .writeAttribute("name", in.m_name) - .writeAttribute("filename", skipPathFromFilename(in.m_file.c_str())) - .writeAttribute("line", line(in.m_line)) - .writeAttribute("description", in.m_description); - - if(Approx(in.m_timeout) != 0) - xml.writeAttribute("timeout", in.m_timeout); - if(in.m_may_fail) - xml.writeAttribute("may_fail", true); - if(in.m_should_fail) - xml.writeAttribute("should_fail", true); - } - - // ========================================================================================= - // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE - // ========================================================================================= - - void report_query(const QueryData& in) override { - test_run_start(); - if(opt.list_reporters) { - for(auto& curr : getListeners()) - xml.scopedElement("Listener") - .writeAttribute("priority", curr.first.first) - .writeAttribute("name", curr.first.second); - for(auto& curr : getReporters()) - xml.scopedElement("Reporter") - .writeAttribute("priority", curr.first.first) - .writeAttribute("name", curr.first.second); - } else if(opt.count || opt.list_test_cases) { - for(unsigned i = 0; i < in.num_data; ++i) { - xml.scopedElement("TestCase").writeAttribute("name", in.data[i]->m_name) - .writeAttribute("testsuite", in.data[i]->m_test_suite) - .writeAttribute("filename", skipPathFromFilename(in.data[i]->m_file.c_str())) - .writeAttribute("line", line(in.data[i]->m_line)) - .writeAttribute("skipped", 
in.data[i]->m_skip); - } - xml.scopedElement("OverallResultsTestCases") - .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); - } else if(opt.list_test_suites) { - for(unsigned i = 0; i < in.num_data; ++i) - xml.scopedElement("TestSuite").writeAttribute("name", in.data[i]->m_test_suite); - xml.scopedElement("OverallResultsTestCases") - .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); - xml.scopedElement("OverallResultsTestSuites") - .writeAttribute("unskipped", in.run_stats->numTestSuitesPassingFilters); - } - xml.endElement(); - } - - void test_run_start() override { - xml.writeDeclaration(); - - // remove .exe extension - mainly to have the same output on UNIX and Windows - std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); -#ifdef DOCTEST_PLATFORM_WINDOWS - if(binary_name.rfind(".exe") != std::string::npos) - binary_name = binary_name.substr(0, binary_name.length() - 4); -#endif // DOCTEST_PLATFORM_WINDOWS - - xml.startElement("doctest").writeAttribute("binary", binary_name); - if(opt.no_version == false) - xml.writeAttribute("version", DOCTEST_VERSION_STR); - - // only the consequential ones (TODO: filters) - xml.scopedElement("Options") - .writeAttribute("order_by", opt.order_by.c_str()) - .writeAttribute("rand_seed", opt.rand_seed) - .writeAttribute("first", opt.first) - .writeAttribute("last", opt.last) - .writeAttribute("abort_after", opt.abort_after) - .writeAttribute("subcase_filter_levels", opt.subcase_filter_levels) - .writeAttribute("case_sensitive", opt.case_sensitive) - .writeAttribute("no_throw", opt.no_throw) - .writeAttribute("no_skip", opt.no_skip); - } - - void test_run_end(const TestRunStats& p) override { - if(tc) // the TestSuite tag - only if there has been at least 1 test case - xml.endElement(); - - xml.scopedElement("OverallResultsAsserts") - .writeAttribute("successes", p.numAsserts - p.numAssertsFailed) - .writeAttribute("failures", p.numAssertsFailed); - - 
xml.startElement("OverallResultsTestCases") - .writeAttribute("successes", - p.numTestCasesPassingFilters - p.numTestCasesFailed) - .writeAttribute("failures", p.numTestCasesFailed); - if(opt.no_skipped_summary == false) - xml.writeAttribute("skipped", p.numTestCases - p.numTestCasesPassingFilters); - xml.endElement(); - - xml.endElement(); - } - - void test_case_start(const TestCaseData& in) override { - test_case_start_impl(in); - xml.ensureTagClosed(); - } - - void test_case_reenter(const TestCaseData&) override {} - - void test_case_end(const CurrentTestCaseStats& st) override { - xml.startElement("OverallResultsAsserts") - .writeAttribute("successes", - st.numAssertsCurrentTest - st.numAssertsFailedCurrentTest) - .writeAttribute("failures", st.numAssertsFailedCurrentTest) - .writeAttribute("test_case_success", st.testCaseSuccess); - if(opt.duration) - xml.writeAttribute("duration", st.seconds); - if(tc->m_expected_failures) - xml.writeAttribute("expected_failures", tc->m_expected_failures); - xml.endElement(); - - xml.endElement(); - } - - void test_case_exception(const TestCaseException& e) override { - DOCTEST_LOCK_MUTEX(mutex) - - xml.scopedElement("Exception") - .writeAttribute("crash", e.is_crash) - .writeText(e.error_string.c_str()); - } - - void subcase_start(const SubcaseSignature& in) override { - xml.startElement("SubCase") - .writeAttribute("name", in.m_name) - .writeAttribute("filename", skipPathFromFilename(in.m_file)) - .writeAttribute("line", line(in.m_line)); - xml.ensureTagClosed(); - } - - void subcase_end() override { xml.endElement(); } - - void log_assert(const AssertData& rb) override { - if(!rb.m_failed && !opt.success) - return; - - DOCTEST_LOCK_MUTEX(mutex) - - xml.startElement("Expression") - .writeAttribute("success", !rb.m_failed) - .writeAttribute("type", assertString(rb.m_at)) - .writeAttribute("filename", skipPathFromFilename(rb.m_file)) - .writeAttribute("line", line(rb.m_line)); - - 
xml.scopedElement("Original").writeText(rb.m_expr); - - if(rb.m_threw) - xml.scopedElement("Exception").writeText(rb.m_exception.c_str()); - - if(rb.m_at & assertType::is_throws_as) - xml.scopedElement("ExpectedException").writeText(rb.m_exception_type); - if(rb.m_at & assertType::is_throws_with) - xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string.c_str()); - if((rb.m_at & assertType::is_normal) && !rb.m_threw) - xml.scopedElement("Expanded").writeText(rb.m_decomp.c_str()); - - log_contexts(); - - xml.endElement(); - } - - void log_message(const MessageData& mb) override { - DOCTEST_LOCK_MUTEX(mutex) - - xml.startElement("Message") - .writeAttribute("type", failureString(mb.m_severity)) - .writeAttribute("filename", skipPathFromFilename(mb.m_file)) - .writeAttribute("line", line(mb.m_line)); - - xml.scopedElement("Text").writeText(mb.m_string.c_str()); - - log_contexts(); - - xml.endElement(); - } - - void test_case_skipped(const TestCaseData& in) override { - if(opt.no_skipped_summary == false) { - test_case_start_impl(in); - xml.writeAttribute("skipped", "true"); - xml.endElement(); - } - } - }; - - DOCTEST_REGISTER_REPORTER("xml", 0, XmlReporter); - - void fulltext_log_assert_to_stream(std::ostream& s, const AssertData& rb) { - if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) == - 0) //!OCLINT bitwise operator in conditional - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << " ) " - << Color::None; - - if(rb.m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional - s << (rb.m_threw ? "threw as expected!" 
: "did NOT throw at all!") << "\n"; - } else if((rb.m_at & assertType::is_throws_as) && - (rb.m_at & assertType::is_throws_with)) { //!OCLINT - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" - << rb.m_exception_string.c_str() - << "\", " << rb.m_exception_type << " ) " << Color::None; - if(rb.m_threw) { - if(!rb.m_failed) { - s << "threw as expected!\n"; - } else { - s << "threw a DIFFERENT exception! (contents: " << rb.m_exception << ")\n"; - } - } else { - s << "did NOT throw at all!\n"; - } - } else if(rb.m_at & - assertType::is_throws_as) { //!OCLINT bitwise operator in conditional - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", " - << rb.m_exception_type << " ) " << Color::None - << (rb.m_threw ? (rb.m_threw_as ? "threw as expected!" : - "threw a DIFFERENT exception: ") : - "did NOT throw at all!") - << Color::Cyan << rb.m_exception << "\n"; - } else if(rb.m_at & - assertType::is_throws_with) { //!OCLINT bitwise operator in conditional - s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" - << rb.m_exception_string.c_str() - << "\" ) " << Color::None - << (rb.m_threw ? (!rb.m_failed ? "threw as expected!" : - "threw a DIFFERENT exception: ") : - "did NOT throw at all!") - << Color::Cyan << rb.m_exception << "\n"; - } else if(rb.m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional - s << (rb.m_threw ? "THREW exception: " : "didn't throw!") << Color::Cyan - << rb.m_exception << "\n"; - } else { - s << (rb.m_threw ? "THREW exception: " : - (!rb.m_failed ? 
"is correct!\n" : "is NOT correct!\n")); - if(rb.m_threw) - s << rb.m_exception << "\n"; - else - s << " values: " << assertString(rb.m_at) << "( " << rb.m_decomp << " )\n"; - } - } - - // TODO: - // - log_message() - // - respond to queries - // - honor remaining options - // - more attributes in tags - struct JUnitReporter : public IReporter - { - XmlWriter xml; - DOCTEST_DECLARE_MUTEX(mutex) - Timer timer; - std::vector deepestSubcaseStackNames; - - struct JUnitTestCaseData - { - static std::string getCurrentTimestamp() { - // Beware, this is not reentrant because of backward compatibility issues - // Also, UTC only, again because of backward compatibility (%z is C++11) - time_t rawtime; - std::time(&rawtime); - auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); - - std::tm timeInfo; -#ifdef DOCTEST_PLATFORM_WINDOWS - gmtime_s(&timeInfo, &rawtime); -#else // DOCTEST_PLATFORM_WINDOWS - gmtime_r(&rawtime, &timeInfo); -#endif // DOCTEST_PLATFORM_WINDOWS - - char timeStamp[timeStampSize]; - const char* const fmt = "%Y-%m-%dT%H:%M:%SZ"; - - std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); - return std::string(timeStamp); - } - - struct JUnitTestMessage - { - JUnitTestMessage(const std::string& _message, const std::string& _type, const std::string& _details) - : message(_message), type(_type), details(_details) {} - - JUnitTestMessage(const std::string& _message, const std::string& _details) - : message(_message), type(), details(_details) {} - - std::string message, type, details; - }; - - struct JUnitTestCase - { - JUnitTestCase(const std::string& _classname, const std::string& _name) - : classname(_classname), name(_name), time(0), failures() {} - - std::string classname, name; - double time; - std::vector failures, errors; - }; - - void add(const std::string& classname, const std::string& name) { - testcases.emplace_back(classname, name); - } - - void appendSubcaseNamesToLastTestcase(std::vector nameStack) { - for(auto& curr: nameStack) - 
if(curr.size()) - testcases.back().name += std::string("/") + curr.c_str(); - } - - void addTime(double time) { - if(time < 1e-4) - time = 0; - testcases.back().time = time; - totalSeconds += time; - } - - void addFailure(const std::string& message, const std::string& type, const std::string& details) { - testcases.back().failures.emplace_back(message, type, details); - ++totalFailures; - } - - void addError(const std::string& message, const std::string& details) { - testcases.back().errors.emplace_back(message, details); - ++totalErrors; - } - - std::vector testcases; - double totalSeconds = 0; - int totalErrors = 0, totalFailures = 0; - }; - - JUnitTestCaseData testCaseData; - - // caching pointers/references to objects of these types - safe to do - const ContextOptions& opt; - const TestCaseData* tc = nullptr; - - JUnitReporter(const ContextOptions& co) - : xml(*co.cout) - , opt(co) {} - - unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } - - // ========================================================================================= - // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE - // ========================================================================================= - - void report_query(const QueryData&) override { - xml.writeDeclaration(); - } - - void test_run_start() override { - xml.writeDeclaration(); - } - - void test_run_end(const TestRunStats& p) override { - // remove .exe extension - mainly to have the same output on UNIX and Windows - std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); -#ifdef DOCTEST_PLATFORM_WINDOWS - if(binary_name.rfind(".exe") != std::string::npos) - binary_name = binary_name.substr(0, binary_name.length() - 4); -#endif // DOCTEST_PLATFORM_WINDOWS - xml.startElement("testsuites"); - xml.startElement("testsuite").writeAttribute("name", binary_name) - .writeAttribute("errors", testCaseData.totalErrors) - .writeAttribute("failures", 
testCaseData.totalFailures) - .writeAttribute("tests", p.numAsserts); - if(opt.no_time_in_output == false) { - xml.writeAttribute("time", testCaseData.totalSeconds); - xml.writeAttribute("timestamp", JUnitTestCaseData::getCurrentTimestamp()); - } - if(opt.no_version == false) - xml.writeAttribute("doctest_version", DOCTEST_VERSION_STR); - - for(const auto& testCase : testCaseData.testcases) { - xml.startElement("testcase") - .writeAttribute("classname", testCase.classname) - .writeAttribute("name", testCase.name); - if(opt.no_time_in_output == false) - xml.writeAttribute("time", testCase.time); - // This is not ideal, but it should be enough to mimic gtest's junit output. - xml.writeAttribute("status", "run"); - - for(const auto& failure : testCase.failures) { - xml.scopedElement("failure") - .writeAttribute("message", failure.message) - .writeAttribute("type", failure.type) - .writeText(failure.details, false); - } - - for(const auto& error : testCase.errors) { - xml.scopedElement("error") - .writeAttribute("message", error.message) - .writeText(error.details); - } - - xml.endElement(); - } - xml.endElement(); - xml.endElement(); - } - - void test_case_start(const TestCaseData& in) override { - testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); - timer.start(); - } - - void test_case_reenter(const TestCaseData& in) override { - testCaseData.addTime(timer.getElapsedSeconds()); - testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); - deepestSubcaseStackNames.clear(); - - timer.start(); - testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); - } - - void test_case_end(const CurrentTestCaseStats&) override { - testCaseData.addTime(timer.getElapsedSeconds()); - testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); - deepestSubcaseStackNames.clear(); - } - - void test_case_exception(const TestCaseException& e) override { - DOCTEST_LOCK_MUTEX(mutex) - testCaseData.addError("exception", 
e.error_string.c_str()); - } - - void subcase_start(const SubcaseSignature& in) override { - deepestSubcaseStackNames.push_back(in.m_name); - } - - void subcase_end() override {} - - void log_assert(const AssertData& rb) override { - if(!rb.m_failed) // report only failures & ignore the `success` option - return; - - DOCTEST_LOCK_MUTEX(mutex) - - std::ostringstream os; - os << skipPathFromFilename(rb.m_file) << (opt.gnu_file_line ? ":" : "(") - << line(rb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; - - fulltext_log_assert_to_stream(os, rb); - log_contexts(os); - testCaseData.addFailure(rb.m_decomp.c_str(), assertString(rb.m_at), os.str()); - } - - void log_message(const MessageData&) override {} - - void test_case_skipped(const TestCaseData&) override {} - - void log_contexts(std::ostringstream& s) { - int num_contexts = get_num_active_contexts(); - if(num_contexts) { - auto contexts = get_active_contexts(); - - s << " logged: "; - for(int i = 0; i < num_contexts; ++i) { - s << (i == 0 ? 
"" : " "); - contexts[i]->stringify(&s); - s << std::endl; - } - } - } - }; - - DOCTEST_REGISTER_REPORTER("junit", 0, JUnitReporter); - - struct Whitespace - { - int nrSpaces; - explicit Whitespace(int nr) - : nrSpaces(nr) {} - }; - - std::ostream& operator<<(std::ostream& out, const Whitespace& ws) { - if(ws.nrSpaces != 0) - out << std::setw(ws.nrSpaces) << ' '; - return out; - } - - struct ConsoleReporter : public IReporter - { - std::ostream& s; - bool hasLoggedCurrentTestStart; - std::vector subcasesStack; - size_t currentSubcaseLevel; - DOCTEST_DECLARE_MUTEX(mutex) - - // caching pointers/references to objects of these types - safe to do - const ContextOptions& opt; - const TestCaseData* tc; - - ConsoleReporter(const ContextOptions& co) - : s(*co.cout) - , opt(co) {} - - ConsoleReporter(const ContextOptions& co, std::ostream& ostr) - : s(ostr) - , opt(co) {} - - // ========================================================================================= - // WHAT FOLLOWS ARE HELPERS USED BY THE OVERRIDES OF THE VIRTUAL METHODS OF THE INTERFACE - // ========================================================================================= - - void separator_to_stream() { - s << Color::Yellow - << "===============================================================================" - "\n"; - } - - const char* getSuccessOrFailString(bool success, assertType::Enum at, - const char* success_str) { - if(success) - return success_str; - return failureString(at); - } - - Color::Enum getSuccessOrFailColor(bool success, assertType::Enum at) { - return success ? Color::BrightGreen : - (at & assertType::is_warn) ? 
Color::Yellow : Color::Red; - } - - void successOrFailColoredStringToStream(bool success, assertType::Enum at, - const char* success_str = "SUCCESS") { - s << getSuccessOrFailColor(success, at) - << getSuccessOrFailString(success, at, success_str) << ": "; - } - - void log_contexts() { - int num_contexts = get_num_active_contexts(); - if(num_contexts) { - auto contexts = get_active_contexts(); - - s << Color::None << " logged: "; - for(int i = 0; i < num_contexts; ++i) { - s << (i == 0 ? "" : " "); - contexts[i]->stringify(&s); - s << "\n"; - } - } - - s << "\n"; - } - - // this was requested to be made virtual so users could override it - virtual void file_line_to_stream(const char* file, int line, - const char* tail = "") { - s << Color::LightGrey << skipPathFromFilename(file) << (opt.gnu_file_line ? ":" : "(") - << (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option - << (opt.gnu_file_line ? ":" : "):") << tail; - } - - void logTestStart() { - if(hasLoggedCurrentTestStart) - return; - - separator_to_stream(); - file_line_to_stream(tc->m_file.c_str(), tc->m_line, "\n"); - if(tc->m_description) - s << Color::Yellow << "DESCRIPTION: " << Color::None << tc->m_description << "\n"; - if(tc->m_test_suite && tc->m_test_suite[0] != '\0') - s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n"; - if(strncmp(tc->m_name, " Scenario:", 11) != 0) - s << Color::Yellow << "TEST CASE: "; - s << Color::None << tc->m_name << "\n"; - - for(size_t i = 0; i < currentSubcaseLevel; ++i) { - if(subcasesStack[i].m_name[0] != '\0') - s << " " << subcasesStack[i].m_name << "\n"; - } - - if(currentSubcaseLevel != subcasesStack.size()) { - s << Color::Yellow << "\nDEEPEST SUBCASE STACK REACHED (DIFFERENT FROM THE CURRENT ONE):\n" << Color::None; - for(size_t i = 0; i < subcasesStack.size(); ++i) { - if(subcasesStack[i].m_name[0] != '\0') - s << " " << subcasesStack[i].m_name << "\n"; - } - } - - s << "\n"; - - hasLoggedCurrentTestStart = 
true; - } - - void printVersion() { - if(opt.no_version == false) - s << Color::Cyan << "[doctest] " << Color::None << "doctest version is \"" - << DOCTEST_VERSION_STR << "\"\n"; - } - - void printIntro() { - if(opt.no_intro == false) { - printVersion(); - s << Color::Cyan << "[doctest] " << Color::None - << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; - } - } - - void printHelp() { - int sizePrefixDisplay = static_cast(strlen(DOCTEST_OPTIONS_PREFIX_DISPLAY)); - printVersion(); - // clang-format off - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "boolean values: \"1/on/yes/true\" or \"0/off/no/false\"\n"; - s << Color::Cyan << "[doctest] " << Color::None; - s << "filter values: \"str1,str2,str3\" (comma separated strings)\n"; - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "filters use wildcards for matching strings\n"; - s << Color::Cyan << "[doctest] " << Color::None; - s << "something passes a filter if any of the strings in a filter matches\n"; -#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "ALL FLAGS, OPTIONS AND FILTERS ALSO AVAILABLE WITH A \"" DOCTEST_CONFIG_OPTIONS_PREFIX "\" PREFIX!!!\n"; -#endif - s << Color::Cyan << "[doctest]\n" << Color::None; - s << Color::Cyan << "[doctest] " << Color::None; - s << "Query flags - the program quits after them. 
Available:\n\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "?, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "help, -" DOCTEST_OPTIONS_PREFIX_DISPLAY "h " - << Whitespace(sizePrefixDisplay*0) << "prints this message\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "v, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "version " - << Whitespace(sizePrefixDisplay*1) << "prints the version\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "c, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "count " - << Whitespace(sizePrefixDisplay*1) << "prints the number of matching tests\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ltc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-cases " - << Whitespace(sizePrefixDisplay*1) << "lists all matching tests by name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-suites " - << Whitespace(sizePrefixDisplay*1) << "lists all matching test suites\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-reporters " - << Whitespace(sizePrefixDisplay*1) << "lists all registered reporters\n\n"; - // ================================================================================== << 79 - s << Color::Cyan << "[doctest] " << Color::None; - s << "The available / options/filters are:\n\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case= " - << Whitespace(sizePrefixDisplay*1) << "filters tests by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file= " - << Whitespace(sizePrefixDisplay*1) << "filters tests by their file\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sfe, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their file\n"; - s << " -" 
DOCTEST_OPTIONS_PREFIX_DISPLAY "ts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite= " - << Whitespace(sizePrefixDisplay*1) << "filters tests by their test suite\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tse, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their test suite\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase= " - << Whitespace(sizePrefixDisplay*1) << "filters subcases by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-exclude= " - << Whitespace(sizePrefixDisplay*1) << "filters OUT subcases by their name\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "r, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "reporters= " - << Whitespace(sizePrefixDisplay*1) << "reporters to use (console is default)\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "o, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "out= " - << Whitespace(sizePrefixDisplay*1) << "output filename\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ob, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "order-by= " - << Whitespace(sizePrefixDisplay*1) << "how the tests should be ordered\n"; - s << Whitespace(sizePrefixDisplay*3) << " - [file/suite/name/rand/none]\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "rs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "rand-seed= " - << Whitespace(sizePrefixDisplay*1) << "seed for random ordering\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "f, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "first= " - << Whitespace(sizePrefixDisplay*1) << "the first test passing the filters to\n"; - s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "l, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "last= " - << Whitespace(sizePrefixDisplay*1) << "the last test passing the filters to\n"; - s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; - s << " -" 
DOCTEST_OPTIONS_PREFIX_DISPLAY "aa, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "abort-after= " - << Whitespace(sizePrefixDisplay*1) << "stop after failed assertions\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "scfl,--" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-filter-levels= " - << Whitespace(sizePrefixDisplay*1) << "apply filters for the first levels\n"; - s << Color::Cyan << "\n[doctest] " << Color::None; - s << "Bool options - can be used like flags and true is assumed. Available:\n\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "s, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "success= " - << Whitespace(sizePrefixDisplay*1) << "include successful assertions in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "cs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "case-sensitive= " - << Whitespace(sizePrefixDisplay*1) << "filters being treated as case sensitive\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "e, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "exit= " - << Whitespace(sizePrefixDisplay*1) << "exits after the tests finish\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "d, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "duration= " - << Whitespace(sizePrefixDisplay*1) << "prints the time duration of each test\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "m, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "minimal= " - << Whitespace(sizePrefixDisplay*1) << "minimal console output (only failures)\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "q, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "quiet= " - << Whitespace(sizePrefixDisplay*1) << "no console output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nt, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-throw= " - << Whitespace(sizePrefixDisplay*1) << "skips exceptions-related assert checks\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ne, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-exitcode= " - << Whitespace(sizePrefixDisplay*1) << "returns (or exits) always with success\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-run= " - << 
Whitespace(sizePrefixDisplay*1) << "skips all runtime doctest operations\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ni, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-intro= " - << Whitespace(sizePrefixDisplay*1) << "omit the framework intro in the output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nv, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-version= " - << Whitespace(sizePrefixDisplay*1) << "omit the framework version in the output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-colors= " - << Whitespace(sizePrefixDisplay*1) << "disables colors in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "fc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "force-colors= " - << Whitespace(sizePrefixDisplay*1) << "use colors even when not in a tty\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nb, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-breaks= " - << Whitespace(sizePrefixDisplay*1) << "disables breakpoints in debuggers\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ns, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-skip= " - << Whitespace(sizePrefixDisplay*1) << "don't skip test cases marked as skip\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "gfl, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "gnu-file-line= " - << Whitespace(sizePrefixDisplay*1) << ":n: vs (n): for line numbers in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "npf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-path-filenames= " - << Whitespace(sizePrefixDisplay*1) << "only filenames and no paths in output\n"; - s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nln, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-line-numbers= " - << Whitespace(sizePrefixDisplay*1) << "0 instead of real line numbers in output\n"; - // ================================================================================== << 79 - // clang-format on - - s << Color::Cyan << "\n[doctest] " << Color::None; - s << "for more information visit the project documentation\n\n"; - } - - void printRegisteredReporters() { - 
printVersion(); - auto printReporters = [this] (const reporterMap& reporters, const char* type) { - if(reporters.size()) { - s << Color::Cyan << "[doctest] " << Color::None << "listing all registered " << type << "\n"; - for(auto& curr : reporters) - s << "priority: " << std::setw(5) << curr.first.first - << " name: " << curr.first.second << "\n"; - } - }; - printReporters(getListeners(), "listeners"); - printReporters(getReporters(), "reporters"); - } - - // ========================================================================================= - // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE - // ========================================================================================= - - void report_query(const QueryData& in) override { - if(opt.version) { - printVersion(); - } else if(opt.help) { - printHelp(); - } else if(opt.list_reporters) { - printRegisteredReporters(); - } else if(opt.count || opt.list_test_cases) { - if(opt.list_test_cases) { - s << Color::Cyan << "[doctest] " << Color::None - << "listing all test case names\n"; - separator_to_stream(); - } - - for(unsigned i = 0; i < in.num_data; ++i) - s << Color::None << in.data[i]->m_name << "\n"; - - separator_to_stream(); - - s << Color::Cyan << "[doctest] " << Color::None - << "unskipped test cases passing the current filters: " - << g_cs->numTestCasesPassingFilters << "\n"; - - } else if(opt.list_test_suites) { - s << Color::Cyan << "[doctest] " << Color::None << "listing all test suites\n"; - separator_to_stream(); - - for(unsigned i = 0; i < in.num_data; ++i) - s << Color::None << in.data[i]->m_test_suite << "\n"; - - separator_to_stream(); - - s << Color::Cyan << "[doctest] " << Color::None - << "unskipped test cases passing the current filters: " - << g_cs->numTestCasesPassingFilters << "\n"; - s << Color::Cyan << "[doctest] " << Color::None - << "test suites with unskipped test cases passing the current filters: " - << g_cs->numTestSuitesPassingFilters << 
"\n"; - } - } - - void test_run_start() override { - if(!opt.minimal) - printIntro(); - } - - void test_run_end(const TestRunStats& p) override { - if(opt.minimal && p.numTestCasesFailed == 0) - return; - - separator_to_stream(); - s << std::dec; - - auto totwidth = int(std::ceil(log10((std::max(p.numTestCasesPassingFilters, static_cast(p.numAsserts))) + 1))); - auto passwidth = int(std::ceil(log10((std::max(p.numTestCasesPassingFilters - p.numTestCasesFailed, static_cast(p.numAsserts - p.numAssertsFailed))) + 1))); - auto failwidth = int(std::ceil(log10((std::max(p.numTestCasesFailed, static_cast(p.numAssertsFailed))) + 1))); - const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0; - s << Color::Cyan << "[doctest] " << Color::None << "test cases: " << std::setw(totwidth) - << p.numTestCasesPassingFilters << " | " - << ((p.numTestCasesPassingFilters == 0 || anythingFailed) ? Color::None : - Color::Green) - << std::setw(passwidth) << p.numTestCasesPassingFilters - p.numTestCasesFailed << " passed" - << Color::None << " | " << (p.numTestCasesFailed > 0 ? Color::Red : Color::None) - << std::setw(failwidth) << p.numTestCasesFailed << " failed" << Color::None << " |"; - if(opt.no_skipped_summary == false) { - const int numSkipped = p.numTestCases - p.numTestCasesPassingFilters; - s << " " << (numSkipped == 0 ? Color::None : Color::Yellow) << numSkipped - << " skipped" << Color::None; - } - s << "\n"; - s << Color::Cyan << "[doctest] " << Color::None << "assertions: " << std::setw(totwidth) - << p.numAsserts << " | " - << ((p.numAsserts == 0 || anythingFailed) ? Color::None : Color::Green) - << std::setw(passwidth) << (p.numAsserts - p.numAssertsFailed) << " passed" << Color::None - << " | " << (p.numAssertsFailed > 0 ? Color::Red : Color::None) << std::setw(failwidth) - << p.numAssertsFailed << " failed" << Color::None << " |\n"; - s << Color::Cyan << "[doctest] " << Color::None - << "Status: " << (p.numTestCasesFailed > 0 ? 
Color::Red : Color::Green) - << ((p.numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl; - } - - void test_case_start(const TestCaseData& in) override { - hasLoggedCurrentTestStart = false; - tc = ∈ - subcasesStack.clear(); - currentSubcaseLevel = 0; - } - - void test_case_reenter(const TestCaseData&) override { - subcasesStack.clear(); - } - - void test_case_end(const CurrentTestCaseStats& st) override { - if(tc->m_no_output) - return; - - // log the preamble of the test case only if there is something - // else to print - something other than that an assert has failed - if(opt.duration || - (st.failure_flags && st.failure_flags != static_cast(TestCaseFailureReason::AssertFailure))) - logTestStart(); - - if(opt.duration) - s << Color::None << std::setprecision(6) << std::fixed << st.seconds - << " s: " << tc->m_name << "\n"; - - if(st.failure_flags & TestCaseFailureReason::Timeout) - s << Color::Red << "Test case exceeded time limit of " << std::setprecision(6) - << std::fixed << tc->m_timeout << "!\n"; - - if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedButDidnt) { - s << Color::Red << "Should have failed but didn't! 
Marking it as failed!\n"; - } else if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedAndDid) { - s << Color::Yellow << "Failed as expected so marking it as not failed\n"; - } else if(st.failure_flags & TestCaseFailureReason::CouldHaveFailedAndDid) { - s << Color::Yellow << "Allowed to fail so marking it as not failed\n"; - } else if(st.failure_flags & TestCaseFailureReason::DidntFailExactlyNumTimes) { - s << Color::Red << "Didn't fail exactly " << tc->m_expected_failures - << " times so marking it as failed!\n"; - } else if(st.failure_flags & TestCaseFailureReason::FailedExactlyNumTimes) { - s << Color::Yellow << "Failed exactly " << tc->m_expected_failures - << " times as expected so marking it as not failed!\n"; - } - if(st.failure_flags & TestCaseFailureReason::TooManyFailedAsserts) { - s << Color::Red << "Aborting - too many failed asserts!\n"; - } - s << Color::None; // lgtm [cpp/useless-expression] - } - - void test_case_exception(const TestCaseException& e) override { - DOCTEST_LOCK_MUTEX(mutex) - if(tc->m_no_output) - return; - - logTestStart(); - - file_line_to_stream(tc->m_file.c_str(), tc->m_line, " "); - successOrFailColoredStringToStream(false, e.is_crash ? assertType::is_require : - assertType::is_check); - s << Color::Red << (e.is_crash ? "test case CRASHED: " : "test case THREW exception: ") - << Color::Cyan << e.error_string << "\n"; - - int num_stringified_contexts = get_num_stringified_contexts(); - if(num_stringified_contexts) { - auto stringified_contexts = get_stringified_contexts(); - s << Color::None << " logged: "; - for(int i = num_stringified_contexts; i > 0; --i) { - s << (i == num_stringified_contexts ? 
"" : " ") - << stringified_contexts[i - 1] << "\n"; - } - } - s << "\n" << Color::None; - } - - void subcase_start(const SubcaseSignature& subc) override { - subcasesStack.push_back(subc); - ++currentSubcaseLevel; - hasLoggedCurrentTestStart = false; - } - - void subcase_end() override { - --currentSubcaseLevel; - hasLoggedCurrentTestStart = false; - } - - void log_assert(const AssertData& rb) override { - if((!rb.m_failed && !opt.success) || tc->m_no_output) - return; - - DOCTEST_LOCK_MUTEX(mutex) - - logTestStart(); - - file_line_to_stream(rb.m_file, rb.m_line, " "); - successOrFailColoredStringToStream(!rb.m_failed, rb.m_at); - - fulltext_log_assert_to_stream(s, rb); - - log_contexts(); - } - - void log_message(const MessageData& mb) override { - if(tc->m_no_output) - return; - - DOCTEST_LOCK_MUTEX(mutex) - - logTestStart(); - - file_line_to_stream(mb.m_file, mb.m_line, " "); - s << getSuccessOrFailColor(false, mb.m_severity) - << getSuccessOrFailString(mb.m_severity & assertType::is_warn, mb.m_severity, - "MESSAGE") << ": "; - s << Color::None << mb.m_string << "\n"; - log_contexts(); - } - - void test_case_skipped(const TestCaseData&) override {} - }; - - DOCTEST_REGISTER_REPORTER("console", 0, ConsoleReporter); - -#ifdef DOCTEST_PLATFORM_WINDOWS - struct DebugOutputWindowReporter : public ConsoleReporter - { - DOCTEST_THREAD_LOCAL static std::ostringstream oss; - - DebugOutputWindowReporter(const ContextOptions& co) - : ConsoleReporter(co, oss) {} - -#define DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(func, type, arg) \ - void func(type arg) override { \ - bool with_col = g_no_colors; \ - g_no_colors = false; \ - ConsoleReporter::func(arg); \ - if(oss.tellp() != std::streampos{}) { \ - DOCTEST_OUTPUT_DEBUG_STRING(oss.str().c_str()); \ - oss.str(""); \ - } \ - g_no_colors = with_col; \ - } - - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_start, DOCTEST_EMPTY, DOCTEST_EMPTY) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_end, const TestRunStats&, in) - 
DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_start, const TestCaseData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_reenter, const TestCaseData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_end, const CurrentTestCaseStats&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_exception, const TestCaseException&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_start, const SubcaseSignature&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_end, DOCTEST_EMPTY, DOCTEST_EMPTY) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_assert, const AssertData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_message, const MessageData&, in) - DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_skipped, const TestCaseData&, in) - }; - - DOCTEST_THREAD_LOCAL std::ostringstream DebugOutputWindowReporter::oss; -#endif // DOCTEST_PLATFORM_WINDOWS - - // the implementation of parseOption() - bool parseOptionImpl(int argc, const char* const* argv, const char* pattern, String* value) { - // going from the end to the beginning and stopping on the first occurrence from the end - for(int i = argc; i > 0; --i) { - auto index = i - 1; - auto temp = std::strstr(argv[index], pattern); - if(temp && (value || strlen(temp) == strlen(pattern))) { //!OCLINT prefer early exits and continue - // eliminate matches in which the chars before the option are not '-' - bool noBadCharsFound = true; - auto curr = argv[index]; - while(curr != temp) { - if(*curr++ != '-') { - noBadCharsFound = false; - break; - } - } - if(noBadCharsFound && argv[index][0] == '-') { - if(value) { - // parsing the value of an option - temp += strlen(pattern); - const unsigned len = strlen(temp); - if(len) { - *value = temp; - return true; - } - } else { - // just a flag - no value - return true; - } - } - } - } - return false; - } - - // parses an option and returns the string after the '=' character - bool parseOption(int argc, const char* const* argv, const char* pattern, String* value = 
nullptr, - const String& defaultVal = String()) { - if(value) - *value = defaultVal; -#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS - // offset (normally 3 for "dt-") to skip prefix - if(parseOptionImpl(argc, argv, pattern + strlen(DOCTEST_CONFIG_OPTIONS_PREFIX), value)) - return true; -#endif // DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS - return parseOptionImpl(argc, argv, pattern, value); - } - - // locates a flag on the command line - bool parseFlag(int argc, const char* const* argv, const char* pattern) { - return parseOption(argc, argv, pattern); - } - - // parses a comma separated list of words after a pattern in one of the arguments in argv - bool parseCommaSepArgs(int argc, const char* const* argv, const char* pattern, - std::vector& res) { - String filtersString; - if(parseOption(argc, argv, pattern, &filtersString)) { - // tokenize with "," as a separator, unless escaped with backslash - std::ostringstream s; - auto flush = [&s, &res]() { - auto string = s.str(); - if(string.size() > 0) { - res.push_back(string.c_str()); - } - s.str(""); - }; - - bool seenBackslash = false; - const char* current = filtersString.c_str(); - const char* end = current + strlen(current); - while(current != end) { - char character = *current++; - if(seenBackslash) { - seenBackslash = false; - if(character == ',' || character == '\\') { - s.put(character); - continue; - } - s.put('\\'); - } - if(character == '\\') { - seenBackslash = true; - } else if(character == ',') { - flush(); - } else { - s.put(character); - } - } - - if(seenBackslash) { - s.put('\\'); - } - flush(); - return true; - } - return false; - } - - enum optionType - { - option_bool, - option_int - }; - - // parses an int/bool option from the command line - bool parseIntOption(int argc, const char* const* argv, const char* pattern, optionType type, - int& res) { - String parsedValue; - if(!parseOption(argc, argv, pattern, &parsedValue)) - return false; - - if(type) { - // integer - // TODO: change this to use std::stoi 
or something else! currently it uses undefined behavior - assumes '0' on failed parse... - int theInt = std::atoi(parsedValue.c_str()); - if (theInt != 0) { - res = theInt; //!OCLINT parameter reassignment - return true; - } - } else { - // boolean - const char positive[][5] = { "1", "true", "on", "yes" }; // 5 - strlen("true") + 1 - const char negative[][6] = { "0", "false", "off", "no" }; // 6 - strlen("false") + 1 - - // if the value matches any of the positive/negative possibilities - for (unsigned i = 0; i < 4; i++) { - if (parsedValue.compare(positive[i], true) == 0) { - res = 1; //!OCLINT parameter reassignment - return true; - } - if (parsedValue.compare(negative[i], true) == 0) { - res = 0; //!OCLINT parameter reassignment - return true; - } - } - } - return false; - } -} // namespace - -Context::Context(int argc, const char* const* argv) - : p(new detail::ContextState) { - parseArgs(argc, argv, true); - if(argc) - p->binary_name = argv[0]; -} - -Context::~Context() { - if(g_cs == p) - g_cs = nullptr; - delete p; -} - -void Context::applyCommandLine(int argc, const char* const* argv) { - parseArgs(argc, argv); - if(argc) - p->binary_name = argv[0]; -} - -// parses args -void Context::parseArgs(int argc, const char* const* argv, bool withDefaults) { - using namespace detail; - - // clang-format off - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file=", p->filters[0]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sf=", p->filters[0]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file-exclude=",p->filters[1]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sfe=", p->filters[1]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite=", p->filters[2]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ts=", p->filters[2]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite-exclude=", p->filters[3]); - parseCommaSepArgs(argc, 
argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tse=", p->filters[3]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case=", p->filters[4]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tc=", p->filters[4]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case-exclude=", p->filters[5]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tce=", p->filters[5]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase=", p->filters[6]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sc=", p->filters[6]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase-exclude=", p->filters[7]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sce=", p->filters[7]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "reporters=", p->filters[8]); - parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "r=", p->filters[8]); - // clang-format on - - int intRes = 0; - String strRes; - -#define DOCTEST_PARSE_AS_BOOL_OR_FLAG(name, sname, var, default) \ - if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_bool, intRes) || \ - parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_bool, intRes)) \ - p->var = static_cast(intRes); \ - else if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name) || \ - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname)) \ - p->var = true; \ - else if(withDefaults) \ - p->var = default - -#define DOCTEST_PARSE_INT_OPTION(name, sname, var, default) \ - if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_int, intRes) || \ - parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_int, intRes)) \ - p->var = intRes; \ - else if(withDefaults) \ - p->var = default - -#define DOCTEST_PARSE_STR_OPTION(name, sname, var, default) \ - if(parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", &strRes, default) || \ - 
parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", &strRes, default) || \ - withDefaults) \ - p->var = strRes - - // clang-format off - DOCTEST_PARSE_STR_OPTION("out", "o", out, ""); - DOCTEST_PARSE_STR_OPTION("order-by", "ob", order_by, "file"); - DOCTEST_PARSE_INT_OPTION("rand-seed", "rs", rand_seed, 0); - - DOCTEST_PARSE_INT_OPTION("first", "f", first, 0); - DOCTEST_PARSE_INT_OPTION("last", "l", last, UINT_MAX); - - DOCTEST_PARSE_INT_OPTION("abort-after", "aa", abort_after, 0); - DOCTEST_PARSE_INT_OPTION("subcase-filter-levels", "scfl", subcase_filter_levels, INT_MAX); - - DOCTEST_PARSE_AS_BOOL_OR_FLAG("success", "s", success, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("case-sensitive", "cs", case_sensitive, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("exit", "e", exit, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("duration", "d", duration, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("minimal", "m", minimal, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("quiet", "q", quiet, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-throw", "nt", no_throw, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-exitcode", "ne", no_exitcode, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-run", "nr", no_run, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-intro", "ni", no_intro, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-version", "nv", no_version, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-colors", "nc", no_colors, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("force-colors", "fc", force_colors, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-breaks", "nb", no_breaks, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skip", "ns", no_skip, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("gnu-file-line", "gfl", gnu_file_line, !bool(DOCTEST_MSVC)); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-path-filenames", "npf", no_path_in_filenames, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-line-numbers", "nln", no_line_numbers, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-debug-output", "ndo", no_debug_output, false); - 
DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skipped-summary", "nss", no_skipped_summary, false); - DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-time-in-output", "ntio", no_time_in_output, false); - // clang-format on - - if(withDefaults) { - p->help = false; - p->version = false; - p->count = false; - p->list_test_cases = false; - p->list_test_suites = false; - p->list_reporters = false; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "help") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "h") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "?")) { - p->help = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "version") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "v")) { - p->version = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "count") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "c")) { - p->count = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-cases") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ltc")) { - p->list_test_cases = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-suites") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lts")) { - p->list_test_suites = true; - p->exit = true; - } - if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-reporters") || - parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lr")) { - p->list_reporters = true; - p->exit = true; - } -} - -// allows the user to add procedurally to the filters from the command line -void Context::addFilter(const char* filter, const char* value) { setOption(filter, value); } - -// allows the user to clear all filters from the command line -void Context::clearFilters() { - for(auto& curr : p->filters) - curr.clear(); -} - -// allows the user to override procedurally the bool options from the command line -void Context::setOption(const char* option, bool 
value) { - setOption(option, value ? "true" : "false"); -} - -// allows the user to override procedurally the int options from the command line -void Context::setOption(const char* option, int value) { - setOption(option, toString(value).c_str()); -} - -// allows the user to override procedurally the string options from the command line -void Context::setOption(const char* option, const char* value) { - auto argv = String("-") + option + "=" + value; - auto lvalue = argv.c_str(); - parseArgs(1, &lvalue); -} - -// users should query this in their main() and exit the program if true -bool Context::shouldExit() { return p->exit; } - -void Context::setAsDefaultForAssertsOutOfTestCases() { g_cs = p; } - -void Context::setAssertHandler(detail::assert_handler ah) { p->ah = ah; } - -void Context::setCout(std::ostream* out) { p->cout = out; } - -static class DiscardOStream : public std::ostream -{ -private: - class : public std::streambuf - { - private: - // allowing some buffering decreases the amount of calls to overflow - char buf[1024]; - - protected: - std::streamsize xsputn(const char_type*, std::streamsize count) override { return count; } - - int_type overflow(int_type ch) override { - setp(std::begin(buf), std::end(buf)); - return traits_type::not_eof(ch); - } - } discardBuf; - -public: - DiscardOStream() - : std::ostream(&discardBuf) {} -} discardOut; - -// the main function that does all the filtering and test running -int Context::run() { - using namespace detail; - - // save the old context state in case such was setup - for using asserts out of a testing context - auto old_cs = g_cs; - // this is the current contest - g_cs = p; - is_running_in_test = true; - - g_no_colors = p->no_colors; - p->resetRunData(); - - std::fstream fstr; - if(p->cout == nullptr) { - if(p->quiet) { - p->cout = &discardOut; - } else if(p->out.size()) { - // to a file if specified - fstr.open(p->out.c_str(), std::fstream::out); - p->cout = &fstr; - } else { - // stdout by default - 
p->cout = &std::cout; - } - } - - FatalConditionHandler::allocateAltStackMem(); - - auto cleanup_and_return = [&]() { - FatalConditionHandler::freeAltStackMem(); - - if(fstr.is_open()) - fstr.close(); - - // restore context - g_cs = old_cs; - is_running_in_test = false; - - // we have to free the reporters which were allocated when the run started - for(auto& curr : p->reporters_currently_used) - delete curr; - p->reporters_currently_used.clear(); - - if(p->numTestCasesFailed && !p->no_exitcode) - return EXIT_FAILURE; - return EXIT_SUCCESS; - }; - - // setup default reporter if none is given through the command line - if(p->filters[8].empty()) - p->filters[8].push_back("console"); - - // check to see if any of the registered reporters has been selected - for(auto& curr : getReporters()) { - if(matchesAny(curr.first.second.c_str(), p->filters[8], false, p->case_sensitive)) - p->reporters_currently_used.push_back(curr.second(*g_cs)); - } - - // TODO: check if there is nothing in reporters_currently_used - - // prepend all listeners - for(auto& curr : getListeners()) - p->reporters_currently_used.insert(p->reporters_currently_used.begin(), curr.second(*g_cs)); - -#ifdef DOCTEST_PLATFORM_WINDOWS - if(isDebuggerActive() && p->no_debug_output == false) - p->reporters_currently_used.push_back(new DebugOutputWindowReporter(*g_cs)); -#endif // DOCTEST_PLATFORM_WINDOWS - - // handle version, help and no_run - if(p->no_run || p->version || p->help || p->list_reporters) { - DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, QueryData()); - - return cleanup_and_return(); - } - - std::vector testArray; - for(auto& curr : getRegisteredTests()) - testArray.push_back(&curr); - p->numTestCases = testArray.size(); - - // sort the collected records - if(!testArray.empty()) { - if(p->order_by.compare("file", true) == 0) { - std::sort(testArray.begin(), testArray.end(), fileOrderComparator); - } else if(p->order_by.compare("suite", true) == 0) { - std::sort(testArray.begin(), 
testArray.end(), suiteOrderComparator); - } else if(p->order_by.compare("name", true) == 0) { - std::sort(testArray.begin(), testArray.end(), nameOrderComparator); - } else if(p->order_by.compare("rand", true) == 0) { - std::srand(p->rand_seed); - - // random_shuffle implementation - const auto first = &testArray[0]; - for(size_t i = testArray.size() - 1; i > 0; --i) { - int idxToSwap = std::rand() % (i + 1); - - const auto temp = first[i]; - - first[i] = first[idxToSwap]; - first[idxToSwap] = temp; - } - } else if(p->order_by.compare("none", true) == 0) { - // means no sorting - beneficial for death tests which call into the executable - // with a specific test case in mind - we don't want to slow down the startup times - } - } - - std::set testSuitesPassingFilt; - - bool query_mode = p->count || p->list_test_cases || p->list_test_suites; - std::vector queryResults; - - if(!query_mode) - DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_start, DOCTEST_EMPTY); - - // invoke the registered functions if they match the filter criteria (or just count them) - for(auto& curr : testArray) { - const auto& tc = *curr; - - bool skip_me = false; - if(tc.m_skip && !p->no_skip) - skip_me = true; - - if(!matchesAny(tc.m_file.c_str(), p->filters[0], true, p->case_sensitive)) - skip_me = true; - if(matchesAny(tc.m_file.c_str(), p->filters[1], false, p->case_sensitive)) - skip_me = true; - if(!matchesAny(tc.m_test_suite, p->filters[2], true, p->case_sensitive)) - skip_me = true; - if(matchesAny(tc.m_test_suite, p->filters[3], false, p->case_sensitive)) - skip_me = true; - if(!matchesAny(tc.m_name, p->filters[4], true, p->case_sensitive)) - skip_me = true; - if(matchesAny(tc.m_name, p->filters[5], false, p->case_sensitive)) - skip_me = true; - - if(!skip_me) - p->numTestCasesPassingFilters++; - - // skip the test if it is not in the execution range - if((p->last < p->numTestCasesPassingFilters && p->first <= p->last) || - (p->first > p->numTestCasesPassingFilters)) - skip_me = true; - - 
if(skip_me) { - if(!query_mode) - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_skipped, tc); - continue; - } - - // do not execute the test if we are to only count the number of filter passing tests - if(p->count) - continue; - - // print the name of the test and don't execute it - if(p->list_test_cases) { - queryResults.push_back(&tc); - continue; - } - - // print the name of the test suite if not done already and don't execute it - if(p->list_test_suites) { - if((testSuitesPassingFilt.count(tc.m_test_suite) == 0) && tc.m_test_suite[0] != '\0') { - queryResults.push_back(&tc); - testSuitesPassingFilt.insert(tc.m_test_suite); - p->numTestSuitesPassingFilters++; - } - continue; - } - - // execute the test if it passes all the filtering - { - p->currentTest = &tc; - - p->failure_flags = TestCaseFailureReason::None; - p->seconds = 0; - - // reset atomic counters - p->numAssertsFailedCurrentTest_atomic = 0; - p->numAssertsCurrentTest_atomic = 0; - - p->fullyTraversedSubcases.clear(); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_start, tc); - - p->timer.start(); - - bool run_test = true; - - do { - // reset some of the fields for subcases (except for the set of fully passed ones) - p->reachedLeaf = false; - // May not be empty if previous subcase exited via exception. 
- p->subcaseStack.clear(); - p->currentSubcaseDepth = 0; - - p->shouldLogCurrentException = true; - - // reset stuff for logging with INFO() - p->stringifiedContexts.clear(); - -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - try { -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS -// MSVC 2015 diagnoses fatalConditionHandler as unused (because reset() is a static method) -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4101) // unreferenced local variable - FatalConditionHandler fatalConditionHandler; // Handle signals - // execute the test - tc.m_test(); - fatalConditionHandler.reset(); -DOCTEST_MSVC_SUPPRESS_WARNING_POP -#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS - } catch(const TestFailureException&) { - p->failure_flags |= TestCaseFailureReason::AssertFailure; - } catch(...) { - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, - {translateActiveException(), false}); - p->failure_flags |= TestCaseFailureReason::Exception; - } -#endif // DOCTEST_CONFIG_NO_EXCEPTIONS - - // exit this loop if enough assertions have failed - even if there are more subcases - if(p->abort_after > 0 && - p->numAssertsFailed + p->numAssertsFailedCurrentTest_atomic >= p->abort_after) { - run_test = false; - p->failure_flags |= TestCaseFailureReason::TooManyFailedAsserts; - } - - if(!p->nextSubcaseStack.empty() && run_test) - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_reenter, tc); - if(p->nextSubcaseStack.empty()) - run_test = false; - } while(run_test); - - p->finalizeTestCaseData(); - - DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); - - p->currentTest = nullptr; - - // stop executing tests if enough assertions have failed - if(p->abort_after > 0 && p->numAssertsFailed >= p->abort_after) - break; - } - } - - if(!query_mode) { - DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); - } else { - QueryData qdata; - qdata.run_stats = g_cs; - qdata.data = queryResults.data(); - qdata.num_data = unsigned(queryResults.size()); - DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, qdata); - } - - return 
cleanup_and_return(); -} - -DOCTEST_DEFINE_INTERFACE(IReporter) - -int IReporter::get_num_active_contexts() { return detail::g_infoContexts.size(); } -const IContextScope* const* IReporter::get_active_contexts() { - return get_num_active_contexts() ? &detail::g_infoContexts[0] : nullptr; -} - -int IReporter::get_num_stringified_contexts() { return detail::g_cs->stringifiedContexts.size(); } -const String* IReporter::get_stringified_contexts() { - return get_num_stringified_contexts() ? &detail::g_cs->stringifiedContexts[0] : nullptr; -} - -namespace detail { - void registerReporterImpl(const char* name, int priority, reporterCreatorFunc c, bool isReporter) { - if(isReporter) - getReporters().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); - else - getListeners().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); - } -} // namespace detail - -} // namespace doctest - -#endif // DOCTEST_CONFIG_DISABLE - -#ifdef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN -DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007) // 'function' : must be 'attribute' - see issue #182 -int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); } -DOCTEST_MSVC_SUPPRESS_WARNING_POP -#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN - -DOCTEST_CLANG_SUPPRESS_WARNING_POP -DOCTEST_MSVC_SUPPRESS_WARNING_POP -DOCTEST_GCC_SUPPRESS_WARNING_POP - -DOCTEST_SUPPRESS_COMMON_WARNINGS_POP - -#endif // DOCTEST_LIBRARY_IMPLEMENTATION -#endif // DOCTEST_CONFIG_IMPLEMENT +// ====================================================================== lgtm [cpp/missing-header-guard] +// == DO NOT MODIFY THIS FILE BY HAND - IT IS AUTO GENERATED BY CMAKE! 
== +// ====================================================================== +// +// doctest.h - the lightest feature-rich C++ single-header testing framework for unit tests and TDD +// +// Copyright (c) 2016-2023 Viktor Kirilov +// +// Distributed under the MIT Software License +// See accompanying file LICENSE.txt or copy at +// https://opensource.org/licenses/MIT +// +// The documentation can be found at the library's page: +// https://github.com/doctest/doctest/blob/master/doc/markdown/readme.md +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= +// +// The library is heavily influenced by Catch - https://github.com/catchorg/Catch2 +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/catchorg/Catch2/blob/master/LICENSE.txt +// +// The concept of subcases (sections in Catch) and expression decomposition are from there. 
+// Some parts of the code are taken directly: +// - stringification - the detection of "ostream& operator<<(ostream&, const T&)" and StringMaker<> +// - the Approx() helper class for floating point comparison +// - colors in the console +// - breaking into a debugger +// - signal / SEH handling +// - timer +// - XmlWriter class - thanks to Phil Nash for allowing the direct reuse (AKA copy/paste) +// +// The expression decomposing templates are taken from lest - https://github.com/martinmoene/lest +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/martinmoene/lest/blob/master/LICENSE.txt +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= + +#ifndef DOCTEST_LIBRARY_INCLUDED +#define DOCTEST_LIBRARY_INCLUDED + +// ================================================================================================= +// == VERSION ====================================================================================== +// ================================================================================================= + +#define DOCTEST_VERSION_MAJOR 2 +#define DOCTEST_VERSION_MINOR 4 +#define DOCTEST_VERSION_PATCH 11 + +// util we need here +#define DOCTEST_TOSTR_IMPL(x) #x +#define DOCTEST_TOSTR(x) DOCTEST_TOSTR_IMPL(x) + +#define DOCTEST_VERSION_STR \ + DOCTEST_TOSTR(DOCTEST_VERSION_MAJOR) "." \ + DOCTEST_TOSTR(DOCTEST_VERSION_MINOR) "." 
\ + DOCTEST_TOSTR(DOCTEST_VERSION_PATCH) + +#define DOCTEST_VERSION \ + (DOCTEST_VERSION_MAJOR * 10000 + DOCTEST_VERSION_MINOR * 100 + DOCTEST_VERSION_PATCH) + +// ================================================================================================= +// == COMPILER VERSION ============================================================================= +// ================================================================================================= + +// ideas for the version stuff are taken from here: https://github.com/cxxstuff/cxx_detect + +#ifdef _MSC_VER +#define DOCTEST_CPLUSPLUS _MSVC_LANG +#else +#define DOCTEST_CPLUSPLUS __cplusplus +#endif + +#define DOCTEST_COMPILER(MAJOR, MINOR, PATCH) ((MAJOR)*10000000 + (MINOR)*100000 + (PATCH)) + +// GCC/Clang and GCC/MSVC are mutually exclusive, but Clang/MSVC are not because of clang-cl... +#if defined(_MSC_VER) && defined(_MSC_FULL_VER) +#if _MSC_VER == _MSC_FULL_VER / 10000 +#define DOCTEST_MSVC DOCTEST_COMPILER(_MSC_VER / 100, _MSC_VER % 100, _MSC_FULL_VER % 10000) +#else // MSVC +#define DOCTEST_MSVC \ + DOCTEST_COMPILER(_MSC_VER / 100, (_MSC_FULL_VER / 100000) % 100, _MSC_FULL_VER % 100000) +#endif // MSVC +#endif // MSVC +#if defined(__clang__) && defined(__clang_minor__) && defined(__clang_patchlevel__) +#define DOCTEST_CLANG DOCTEST_COMPILER(__clang_major__, __clang_minor__, __clang_patchlevel__) +#elif defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) && \ + !defined(__INTEL_COMPILER) +#define DOCTEST_GCC DOCTEST_COMPILER(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#endif // GCC +#if defined(__INTEL_COMPILER) +#define DOCTEST_ICC DOCTEST_COMPILER(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) +#endif // ICC + +#ifndef DOCTEST_MSVC +#define DOCTEST_MSVC 0 +#endif // DOCTEST_MSVC +#ifndef DOCTEST_CLANG +#define DOCTEST_CLANG 0 +#endif // DOCTEST_CLANG +#ifndef DOCTEST_GCC +#define DOCTEST_GCC 0 +#endif // DOCTEST_GCC +#ifndef DOCTEST_ICC +#define DOCTEST_ICC 0 
+#endif // DOCTEST_ICC + +// ================================================================================================= +// == COMPILER WARNINGS HELPERS ==================================================================== +// ================================================================================================= + +#if DOCTEST_CLANG && !DOCTEST_ICC +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(clang diagnostic ignored w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH DOCTEST_CLANG_SUPPRESS_WARNING(w) +#else // DOCTEST_CLANG +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_CLANG + +#if DOCTEST_GCC +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH _Pragma("GCC diagnostic push") +#define DOCTEST_GCC_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(GCC diagnostic ignored w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP _Pragma("GCC diagnostic pop") +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_GCC_SUPPRESS_WARNING_PUSH DOCTEST_GCC_SUPPRESS_WARNING(w) +#else // DOCTEST_GCC +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH +#define DOCTEST_GCC_SUPPRESS_WARNING(w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_GCC + +#if DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH __pragma(warning(push)) +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) __pragma(warning(disable : w)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP __pragma(warning(pop)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH 
DOCTEST_MSVC_SUPPRESS_WARNING(w) +#else // DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_MSVC + +// ================================================================================================= +// == COMPILER WARNINGS ============================================================================ +// ================================================================================================= + +// both the header and the implementation suppress all of these, +// so it only makes sense to aggregate them like so +#define DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH \ + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") \ + \ + DOCTEST_GCC_SUPPRESS_WARNING_PUSH \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") \ + DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") \ + \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ + /* these 4 also disabled globally via cmake: */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4514) /* unreferenced inline function has been removed */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4571) /* SEH related */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4710) /* function not inlined */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4711) /* function selected for inline expansion*/ \ + /* common ones */ \ + 
DOCTEST_MSVC_SUPPRESS_WARNING(4616) /* invalid compiler warning */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4619) /* invalid compiler warning */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4996) /* The compiler encountered a deprecated declaration */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4706) /* assignment within conditional expression */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4512) /* 'class' : assignment operator could not be generated */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4127) /* conditional expression is constant */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4640) /* construction of local static object not thread-safe */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5264) /* 'variable-name': 'const' variable is not used */ \ + /* static analysis */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26439) /* Function may not throw. Declare it 'noexcept' */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26495) /* Always initialize a member variable */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26451) /* Arithmetic overflow ... */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26444) /* Avoid unnamed objects with custom ctor and dtor... 
*/ \ + DOCTEST_MSVC_SUPPRESS_WARNING(26812) /* Prefer 'enum class' over 'enum' */ + +#define DOCTEST_SUPPRESS_COMMON_WARNINGS_POP \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP \ + DOCTEST_GCC_SUPPRESS_WARNING_POP \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdeprecated") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wctor-dtor-privacy") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-promo") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4623) // default constructor was implicitly defined as deleted + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ + DOCTEST_MSVC_SUPPRESS_WARNING(4548) /* before comma no effect; expected side - effect */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4265) /* virtual functions, but destructor is not virtual */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4986) /* exception specification does not match previous */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4350) /* 'member1' called instead of 'member2' */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4668) /* not defined as a preprocessor macro */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4365) /* signed/unsigned mismatch */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4774) /* format string not a string literal */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) /* padding */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) /* copy constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) /* assignment operator was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) /* move assignment operator implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) /* move constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4623) /* default constructor was implicitly deleted */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5039) /* pointer 
to pot. throwing function passed to extern C */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) /* Spectre mitigation for memory load */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5105) /* macro producing 'defined' has undefined behavior */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(4738) /* storing float result in memory, loss of performance */ \ + DOCTEST_MSVC_SUPPRESS_WARNING(5262) /* implicit fall-through */ + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END DOCTEST_MSVC_SUPPRESS_WARNING_POP + +// ================================================================================================= +// == FEATURE DETECTION ============================================================================ +// ================================================================================================= + +// general compiler feature support table: https://en.cppreference.com/w/cpp/compiler_support +// MSVC C++11 feature support table: https://msdn.microsoft.com/en-us/library/hh567368.aspx +// GCC C++11 feature support table: https://gcc.gnu.org/projects/cxx-status.html +// MSVC version table: +// https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B#Internal_version_numbering +// MSVC++ 14.3 (17) _MSC_VER == 1930 (Visual Studio 2022) +// MSVC++ 14.2 (16) _MSC_VER == 1920 (Visual Studio 2019) +// MSVC++ 14.1 (15) _MSC_VER == 1910 (Visual Studio 2017) +// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) +// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013) +// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012) +// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010) +// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) +// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) + +// Universal Windows Platform support +#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +#define DOCTEST_CONFIG_NO_WINDOWS_SEH +#endif // WINAPI_FAMILY +#if DOCTEST_MSVC && !defined(DOCTEST_CONFIG_WINDOWS_SEH) +#define DOCTEST_CONFIG_WINDOWS_SEH +#endif // MSVC +#if 
defined(DOCTEST_CONFIG_NO_WINDOWS_SEH) && defined(DOCTEST_CONFIG_WINDOWS_SEH) +#undef DOCTEST_CONFIG_WINDOWS_SEH +#endif // DOCTEST_CONFIG_NO_WINDOWS_SEH + +#if !defined(_WIN32) && !defined(__QNX__) && !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && \ + !defined(__EMSCRIPTEN__) && !defined(__wasi__) +#define DOCTEST_CONFIG_POSIX_SIGNALS +#endif // _WIN32 +#if defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) && defined(DOCTEST_CONFIG_POSIX_SIGNALS) +#undef DOCTEST_CONFIG_POSIX_SIGNALS +#endif // DOCTEST_CONFIG_NO_POSIX_SIGNALS + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) \ + || defined(__wasi__) +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // no exceptions +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#if defined(DOCTEST_CONFIG_NO_EXCEPTIONS) && !defined(DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS) +#define DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS && !DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#ifdef __wasi__ +#define DOCTEST_CONFIG_NO_MULTITHREADING +#endif + +#if defined(DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN) && !defined(DOCTEST_CONFIG_IMPLEMENT) +#define DOCTEST_CONFIG_IMPLEMENT +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +#if defined(_WIN32) || defined(__CYGWIN__) +#if DOCTEST_MSVC +#define DOCTEST_SYMBOL_EXPORT __declspec(dllexport) +#define DOCTEST_SYMBOL_IMPORT __declspec(dllimport) +#else // MSVC +#define DOCTEST_SYMBOL_EXPORT __attribute__((dllexport)) +#define DOCTEST_SYMBOL_IMPORT __attribute__((dllimport)) +#endif // MSVC +#else // _WIN32 +#define DOCTEST_SYMBOL_EXPORT __attribute__((visibility("default"))) +#define DOCTEST_SYMBOL_IMPORT +#endif // _WIN32 + +#ifdef DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#ifdef 
DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_EXPORT +#else // DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_IMPORT +#endif // DOCTEST_CONFIG_IMPLEMENT +#else // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#define DOCTEST_INTERFACE +#endif // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL + +// needed for extern template instantiations +// see https://github.com/fmtlib/fmt/issues/2228 +#if DOCTEST_MSVC +#define DOCTEST_INTERFACE_DECL +#define DOCTEST_INTERFACE_DEF DOCTEST_INTERFACE +#else // DOCTEST_MSVC +#define DOCTEST_INTERFACE_DECL DOCTEST_INTERFACE +#define DOCTEST_INTERFACE_DEF +#endif // DOCTEST_MSVC + +#define DOCTEST_EMPTY + +#if DOCTEST_MSVC +#define DOCTEST_NOINLINE __declspec(noinline) +#define DOCTEST_UNUSED +#define DOCTEST_ALIGNMENT(x) +#elif DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 5, 0) +#define DOCTEST_NOINLINE +#define DOCTEST_UNUSED +#define DOCTEST_ALIGNMENT(x) +#else +#define DOCTEST_NOINLINE __attribute__((noinline)) +#define DOCTEST_UNUSED __attribute__((unused)) +#define DOCTEST_ALIGNMENT(x) __attribute__((aligned(x))) +#endif + +#ifdef DOCTEST_CONFIG_NO_CONTRADICTING_INLINE +#define DOCTEST_INLINE_NOINLINE inline +#else +#define DOCTEST_INLINE_NOINLINE inline DOCTEST_NOINLINE +#endif + +#ifndef DOCTEST_NORETURN +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_NORETURN +#else // DOCTEST_MSVC +#define DOCTEST_NORETURN [[noreturn]] +#endif // DOCTEST_MSVC +#endif // DOCTEST_NORETURN + +#ifndef DOCTEST_NOEXCEPT +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_NOEXCEPT +#else // DOCTEST_MSVC +#define DOCTEST_NOEXCEPT noexcept +#endif // DOCTEST_MSVC +#endif // DOCTEST_NOEXCEPT + +#ifndef DOCTEST_CONSTEXPR +#if DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_CONSTEXPR const +#define DOCTEST_CONSTEXPR_FUNC inline +#else // DOCTEST_MSVC +#define DOCTEST_CONSTEXPR constexpr +#define DOCTEST_CONSTEXPR_FUNC constexpr 
+#endif // DOCTEST_MSVC +#endif // DOCTEST_CONSTEXPR + +#ifndef DOCTEST_NO_SANITIZE_INTEGER +#if DOCTEST_CLANG >= DOCTEST_COMPILER(3, 7, 0) +#define DOCTEST_NO_SANITIZE_INTEGER __attribute__((no_sanitize("integer"))) +#else +#define DOCTEST_NO_SANITIZE_INTEGER +#endif +#endif // DOCTEST_NO_SANITIZE_INTEGER + +// ================================================================================================= +// == FEATURE DETECTION END ======================================================================== +// ================================================================================================= + +#define DOCTEST_DECLARE_INTERFACE(name) \ + virtual ~name(); \ + name() = default; \ + name(const name&) = delete; \ + name(name&&) = delete; \ + name& operator=(const name&) = delete; \ + name& operator=(name&&) = delete; + +#define DOCTEST_DEFINE_INTERFACE(name) \ + name::~name() = default; + +// internal macros for string concatenation and anonymous variable name generation +#define DOCTEST_CAT_IMPL(s1, s2) s1##s2 +#define DOCTEST_CAT(s1, s2) DOCTEST_CAT_IMPL(s1, s2) +#ifdef __COUNTER__ // not standard and may be missing for some compilers +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __COUNTER__) +#else // __COUNTER__ +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __LINE__) +#endif // __COUNTER__ + +#ifndef DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x& +#else // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x +#endif // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE + +// not using __APPLE__ because... 
this is how Catch does it +#ifdef __MAC_OS_X_VERSION_MIN_REQUIRED +#define DOCTEST_PLATFORM_MAC +#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) +#define DOCTEST_PLATFORM_IPHONE +#elif defined(_WIN32) +#define DOCTEST_PLATFORM_WINDOWS +#elif defined(__wasi__) +#define DOCTEST_PLATFORM_WASI +#else // DOCTEST_PLATFORM +#define DOCTEST_PLATFORM_LINUX +#endif // DOCTEST_PLATFORM + +namespace doctest { namespace detail { + static DOCTEST_CONSTEXPR int consume(const int*, int) noexcept { return 0; } +}} + +#define DOCTEST_GLOBAL_NO_WARNINGS(var, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ + static const int var = doctest::detail::consume(&var, __VA_ARGS__); \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#ifndef DOCTEST_BREAK_INTO_DEBUGGER +// should probably take a look at https://github.com/scottt/debugbreak +#ifdef DOCTEST_PLATFORM_LINUX +#if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) +// Break at the location of the failing check if possible +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) +#else +#include +#define DOCTEST_BREAK_INTO_DEBUGGER() raise(SIGTRAP) +#endif +#elif defined(DOCTEST_PLATFORM_MAC) +#if defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || defined(__i386) +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT(hicpp-no-assembler) +#elif defined(__ppc__) || defined(__ppc64__) +// https://www.cocoawithlove.com/2008/03/break-into-debugger.html +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n": : : "memory","r0","r3","r4") // NOLINT(hicpp-no-assembler) +#else +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("brk #0"); // NOLINT(hicpp-no-assembler) +#endif +#elif DOCTEST_MSVC +#define DOCTEST_BREAK_INTO_DEBUGGER() __debugbreak() +#elif defined(__MINGW32__) +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wredundant-decls") +extern "C" __declspec(dllimport) void __stdcall DebugBreak(); 
+DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_BREAK_INTO_DEBUGGER() ::DebugBreak() +#else // linux +#define DOCTEST_BREAK_INTO_DEBUGGER() (static_cast(0)) +#endif // linux +#endif // DOCTEST_BREAK_INTO_DEBUGGER + +// this is kept here for backwards compatibility since the config option was changed +#ifdef DOCTEST_CONFIG_USE_IOSFWD +#ifndef DOCTEST_CONFIG_USE_STD_HEADERS +#define DOCTEST_CONFIG_USE_STD_HEADERS +#endif +#endif // DOCTEST_CONFIG_USE_IOSFWD + +// for clang - always include ciso646 (which drags some std stuff) because +// we want to check if we are using libc++ with the _LIBCPP_VERSION macro in +// which case we don't want to forward declare stuff from std - for reference: +// https://github.com/doctest/doctest/issues/126 +// https://github.com/doctest/doctest/issues/356 +#if DOCTEST_CLANG +#include +#endif // clang + +#ifdef _LIBCPP_VERSION +#ifndef DOCTEST_CONFIG_USE_STD_HEADERS +#define DOCTEST_CONFIG_USE_STD_HEADERS +#endif +#endif // _LIBCPP_VERSION + +#ifdef DOCTEST_CONFIG_USE_STD_HEADERS +#ifndef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#define DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN +#include +#include +#include +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END +#else // DOCTEST_CONFIG_USE_STD_HEADERS + +// Forward declaring 'X' in namespace std is not permitted by the C++ Standard. 
+DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4643) + +namespace std { // NOLINT(cert-dcl58-cpp) +typedef decltype(nullptr) nullptr_t; // NOLINT(modernize-use-using) +typedef decltype(sizeof(void*)) size_t; // NOLINT(modernize-use-using) +template +struct char_traits; +template <> +struct char_traits; +template +class basic_ostream; // NOLINT(fuchsia-virtual-inheritance) +typedef basic_ostream> ostream; // NOLINT(modernize-use-using) +template +// NOLINTNEXTLINE +basic_ostream& operator<<(basic_ostream&, const char*); +template +class basic_istream; +typedef basic_istream> istream; // NOLINT(modernize-use-using) +template +class tuple; +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +template +class allocator; +template +class basic_string; +using string = basic_string, allocator>; +#endif // VS 2019 +} // namespace std + +DOCTEST_MSVC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_USE_STD_HEADERS + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#include +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + +namespace doctest { + +using std::size_t; + +DOCTEST_INTERFACE extern bool is_running_in_test; + +#ifndef DOCTEST_CONFIG_STRING_SIZE_TYPE +#define DOCTEST_CONFIG_STRING_SIZE_TYPE unsigned +#endif + +// A 24 byte string class (can be as small as 17 for x64 and 13 for x86) that can hold strings with length +// of up to 23 chars on the stack before going on the heap - the last byte of the buffer is used for: +// - "is small" bit - the highest bit - if "0" then it is small - otherwise its "1" (128) +// - if small - capacity left before going on the heap - using the lowest 5 bits +// - if small - 2 bits are left unused - the second and third highest ones +// - if small - acts as a null terminator if strlen() is 23 (24 including the null terminator) +// and the "is small" bit remains "0" ("as well as the capacity left") so its OK +// Idea taken from this lecture about the string 
implementation of facebook/folly - fbstring +// https://www.youtube.com/watch?v=kPR8h4-qZdk +// TODO: +// - optimizations - like not deleting memory unnecessarily in operator= and etc. +// - resize/reserve/clear +// - replace +// - back/front +// - iterator stuff +// - find & friends +// - push_back/pop_back +// - assign/insert/erase +// - relational operators as free functions - taking const char* as one of the params +class DOCTEST_INTERFACE String +{ +public: + using size_type = DOCTEST_CONFIG_STRING_SIZE_TYPE; + +private: + static DOCTEST_CONSTEXPR size_type len = 24; //!OCLINT avoid private static members + static DOCTEST_CONSTEXPR size_type last = len - 1; //!OCLINT avoid private static members + + struct view // len should be more than sizeof(view) - because of the final byte for flags + { + char* ptr; + size_type size; + size_type capacity; + }; + + union + { + char buf[len]; // NOLINT(*-avoid-c-arrays) + view data; + }; + + char* allocate(size_type sz); + + bool isOnStack() const noexcept { return (buf[last] & 128) == 0; } + void setOnHeap() noexcept; + void setLast(size_type in = last) noexcept; + void setSize(size_type sz) noexcept; + + void copy(const String& other); + +public: + static DOCTEST_CONSTEXPR size_type npos = static_cast(-1); + + String() noexcept; + ~String(); + + // cppcheck-suppress noExplicitConstructor + String(const char* in); + String(const char* in, size_type in_size); + + String(std::istream& in, size_type in_size); + + String(const String& other); + String& operator=(const String& other); + + String& operator+=(const String& other); + + String(String&& other) noexcept; + String& operator=(String&& other) noexcept; + + char operator[](size_type i) const; + char& operator[](size_type i); + + // the only functions I'm willing to leave in the interface - available for inlining + const char* c_str() const { return const_cast(this)->c_str(); } // NOLINT + char* c_str() { + if (isOnStack()) { + return reinterpret_cast(buf); + } + return 
data.ptr; + } + + size_type size() const; + size_type capacity() const; + + String substr(size_type pos, size_type cnt = npos) &&; + String substr(size_type pos, size_type cnt = npos) const &; + + size_type find(char ch, size_type pos = 0) const; + size_type rfind(char ch, size_type pos = npos) const; + + int compare(const char* other, bool no_case = false) const; + int compare(const String& other, bool no_case = false) const; + +friend DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); +}; + +DOCTEST_INTERFACE String operator+(const String& lhs, const String& rhs); + +DOCTEST_INTERFACE bool operator==(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator!=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>=(const String& lhs, const String& rhs); + +class DOCTEST_INTERFACE Contains { +public: + explicit Contains(const String& string); + + bool checkWith(const String& other) const; + + String string; +}; + +DOCTEST_INTERFACE String toString(const Contains& in); + +DOCTEST_INTERFACE bool operator==(const String& lhs, const Contains& rhs); +DOCTEST_INTERFACE bool operator==(const Contains& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator!=(const String& lhs, const Contains& rhs); +DOCTEST_INTERFACE bool operator!=(const Contains& lhs, const String& rhs); + +namespace Color { + enum Enum + { + None = 0, + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White + }; + + DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, Color::Enum code); +} // namespace Color + +namespace assertType { + enum Enum + { + // macro traits + + is_warn 
= 1, + is_check = 2 * is_warn, + is_require = 2 * is_check, + + is_normal = 2 * is_require, + is_throws = 2 * is_normal, + is_throws_as = 2 * is_throws, + is_throws_with = 2 * is_throws_as, + is_nothrow = 2 * is_throws_with, + + is_false = 2 * is_nothrow, + is_unary = 2 * is_false, // not checked anywhere - used just to distinguish the types + + is_eq = 2 * is_unary, + is_ne = 2 * is_eq, + + is_lt = 2 * is_ne, + is_gt = 2 * is_lt, + + is_ge = 2 * is_gt, + is_le = 2 * is_ge, + + // macro types + + DT_WARN = is_normal | is_warn, + DT_CHECK = is_normal | is_check, + DT_REQUIRE = is_normal | is_require, + + DT_WARN_FALSE = is_normal | is_false | is_warn, + DT_CHECK_FALSE = is_normal | is_false | is_check, + DT_REQUIRE_FALSE = is_normal | is_false | is_require, + + DT_WARN_THROWS = is_throws | is_warn, + DT_CHECK_THROWS = is_throws | is_check, + DT_REQUIRE_THROWS = is_throws | is_require, + + DT_WARN_THROWS_AS = is_throws_as | is_warn, + DT_CHECK_THROWS_AS = is_throws_as | is_check, + DT_REQUIRE_THROWS_AS = is_throws_as | is_require, + + DT_WARN_THROWS_WITH = is_throws_with | is_warn, + DT_CHECK_THROWS_WITH = is_throws_with | is_check, + DT_REQUIRE_THROWS_WITH = is_throws_with | is_require, + + DT_WARN_THROWS_WITH_AS = is_throws_with | is_throws_as | is_warn, + DT_CHECK_THROWS_WITH_AS = is_throws_with | is_throws_as | is_check, + DT_REQUIRE_THROWS_WITH_AS = is_throws_with | is_throws_as | is_require, + + DT_WARN_NOTHROW = is_nothrow | is_warn, + DT_CHECK_NOTHROW = is_nothrow | is_check, + DT_REQUIRE_NOTHROW = is_nothrow | is_require, + + DT_WARN_EQ = is_normal | is_eq | is_warn, + DT_CHECK_EQ = is_normal | is_eq | is_check, + DT_REQUIRE_EQ = is_normal | is_eq | is_require, + + DT_WARN_NE = is_normal | is_ne | is_warn, + DT_CHECK_NE = is_normal | is_ne | is_check, + DT_REQUIRE_NE = is_normal | is_ne | is_require, + + DT_WARN_GT = is_normal | is_gt | is_warn, + DT_CHECK_GT = is_normal | is_gt | is_check, + DT_REQUIRE_GT = is_normal | is_gt | is_require, + + DT_WARN_LT = 
is_normal | is_lt | is_warn, + DT_CHECK_LT = is_normal | is_lt | is_check, + DT_REQUIRE_LT = is_normal | is_lt | is_require, + + DT_WARN_GE = is_normal | is_ge | is_warn, + DT_CHECK_GE = is_normal | is_ge | is_check, + DT_REQUIRE_GE = is_normal | is_ge | is_require, + + DT_WARN_LE = is_normal | is_le | is_warn, + DT_CHECK_LE = is_normal | is_le | is_check, + DT_REQUIRE_LE = is_normal | is_le | is_require, + + DT_WARN_UNARY = is_normal | is_unary | is_warn, + DT_CHECK_UNARY = is_normal | is_unary | is_check, + DT_REQUIRE_UNARY = is_normal | is_unary | is_require, + + DT_WARN_UNARY_FALSE = is_normal | is_false | is_unary | is_warn, + DT_CHECK_UNARY_FALSE = is_normal | is_false | is_unary | is_check, + DT_REQUIRE_UNARY_FALSE = is_normal | is_false | is_unary | is_require, + }; +} // namespace assertType + +DOCTEST_INTERFACE const char* assertString(assertType::Enum at); +DOCTEST_INTERFACE const char* failureString(assertType::Enum at); +DOCTEST_INTERFACE const char* skipPathFromFilename(const char* file); + +struct DOCTEST_INTERFACE TestCaseData +{ + String m_file; // the file in which the test was registered (using String - see #350) + unsigned m_line; // the line where the test was registered + const char* m_name; // name of the test case + const char* m_test_suite; // the test suite in which the test was added + const char* m_description; + bool m_skip; + bool m_no_breaks; + bool m_no_output; + bool m_may_fail; + bool m_should_fail; + int m_expected_failures; + double m_timeout; +}; + +struct DOCTEST_INTERFACE AssertData +{ + // common - for all asserts + const TestCaseData* m_test_case; + assertType::Enum m_at; + const char* m_file; + int m_line; + const char* m_expr; + bool m_failed; + + // exception-related - for all asserts + bool m_threw; + String m_exception; + + // for normal asserts + String m_decomp; + + // for specific exception-related asserts + bool m_threw_as; + const char* m_exception_type; + + class DOCTEST_INTERFACE StringContains { + private: + 
Contains content; + bool isContains; + + public: + StringContains(const String& str) : content(str), isContains(false) { } + StringContains(Contains cntn) : content(static_cast(cntn)), isContains(true) { } + + bool check(const String& str) { return isContains ? (content == str) : (content.string == str); } + + operator const String&() const { return content.string; } + + const char* c_str() const { return content.string.c_str(); } + } m_exception_string; + + AssertData(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const StringContains& exception_string); +}; + +struct DOCTEST_INTERFACE MessageData +{ + String m_string; + const char* m_file; + int m_line; + assertType::Enum m_severity; +}; + +struct DOCTEST_INTERFACE SubcaseSignature +{ + String m_name; + const char* m_file; + int m_line; + + bool operator==(const SubcaseSignature& other) const; + bool operator<(const SubcaseSignature& other) const; +}; + +struct DOCTEST_INTERFACE IContextScope +{ + DOCTEST_DECLARE_INTERFACE(IContextScope) + virtual void stringify(std::ostream*) const = 0; +}; + +namespace detail { + struct DOCTEST_INTERFACE TestCase; +} // namespace detail + +struct ContextOptions //!OCLINT too many fields +{ + std::ostream* cout = nullptr; // stdout stream + String binary_name; // the test binary name + + const detail::TestCase* currentTest = nullptr; + + // == parameters from the command line + String out; // output filename + String order_by; // how tests should be ordered + unsigned rand_seed; // the seed for rand ordering + + unsigned first; // the first (matching) test to be executed + unsigned last; // the last (matching) test to be executed + + int abort_after; // stop tests after this many failed assertions + int subcase_filter_levels; // apply the subcase filters for the first N levels + + bool success; // include successful assertions in output + bool case_sensitive; // if filtering should be case sensitive + bool exit; // if the program 
should be exited after the tests are ran/whatever + bool duration; // print the time duration of each test case + bool minimal; // minimal console output (only test failures) + bool quiet; // no console output + bool no_throw; // to skip exceptions-related assertion macros + bool no_exitcode; // if the framework should return 0 as the exitcode + bool no_run; // to not run the tests at all (can be done with an "*" exclude) + bool no_intro; // to not print the intro of the framework + bool no_version; // to not print the version of the framework + bool no_colors; // if output to the console should be colorized + bool force_colors; // forces the use of colors even when a tty cannot be detected + bool no_breaks; // to not break into the debugger + bool no_skip; // don't skip test cases which are marked to be skipped + bool gnu_file_line; // if line numbers should be surrounded with :x: and not (x): + bool no_path_in_filenames; // if the path to files should be removed from the output + bool no_line_numbers; // if source code line numbers should be omitted from the output + bool no_debug_output; // no output in the debug console when a debugger is attached + bool no_skipped_summary; // don't print "skipped" in the summary !!! UNDOCUMENTED !!! + bool no_time_in_output; // omit any time/timestamps from output !!! UNDOCUMENTED !!! 
+ + bool help; // to print the help + bool version; // to print the version + bool count; // if only the count of matching tests is to be retrieved + bool list_test_cases; // to list all tests matching the filters + bool list_test_suites; // to list all suites matching the filters + bool list_reporters; // lists all registered reporters +}; + +namespace detail { + namespace types { +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + using namespace std; +#else + template + struct enable_if { }; + + template + struct enable_if { using type = T; }; + + struct true_type { static DOCTEST_CONSTEXPR bool value = true; }; + struct false_type { static DOCTEST_CONSTEXPR bool value = false; }; + + template struct remove_reference { using type = T; }; + template struct remove_reference { using type = T; }; + template struct remove_reference { using type = T; }; + + template struct is_rvalue_reference : false_type { }; + template struct is_rvalue_reference : true_type { }; + + template struct remove_const { using type = T; }; + template struct remove_const { using type = T; }; + + // Compiler intrinsics + template struct is_enum { static DOCTEST_CONSTEXPR bool value = __is_enum(T); }; + template struct underlying_type { using type = __underlying_type(T); }; + + template struct is_pointer : false_type { }; + template struct is_pointer : true_type { }; + + template struct is_array : false_type { }; + // NOLINTNEXTLINE(*-avoid-c-arrays) + template struct is_array : true_type { }; +#endif + } + + // + template + T&& declval(); + + template + DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type& t) DOCTEST_NOEXCEPT { + return static_cast(t); + } + + template + DOCTEST_CONSTEXPR_FUNC T&& forward(typename types::remove_reference::type&& t) DOCTEST_NOEXCEPT { + return static_cast(t); + } + + template + struct deferred_false : types::false_type { }; + +// MSVS 2015 :( +#if !DOCTEST_CLANG && defined(_MSC_VER) && _MSC_VER <= 1900 + template + struct 
has_global_insertion_operator : types::false_type { }; + + template + struct has_global_insertion_operator(), declval()), void())> : types::true_type { }; + + template + struct has_insertion_operator { static DOCTEST_CONSTEXPR bool value = has_global_insertion_operator::value; }; + + template + struct insert_hack; + + template + struct insert_hack { + static void insert(std::ostream& os, const T& t) { ::operator<<(os, t); } + }; + + template + struct insert_hack { + static void insert(std::ostream& os, const T& t) { operator<<(os, t); } + }; + + template + using insert_hack_t = insert_hack::value>; +#else + template + struct has_insertion_operator : types::false_type { }; +#endif + + template + struct has_insertion_operator(), declval()), void())> : types::true_type { }; + + template + struct should_stringify_as_underlying_type { + static DOCTEST_CONSTEXPR bool value = detail::types::is_enum::value && !doctest::detail::has_insertion_operator::value; + }; + + DOCTEST_INTERFACE std::ostream* tlssPush(); + DOCTEST_INTERFACE String tlssPop(); + + template + struct StringMakerBase { + template + static String convert(const DOCTEST_REF_WRAP(T)) { +#ifdef DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES + static_assert(deferred_false::value, "No stringification detected for type T. See string conversion manual"); +#endif + return "{?}"; + } + }; + + template + struct filldata; + + template + void filloss(std::ostream* stream, const T& in) { + filldata::fill(stream, in); + } + + template + void filloss(std::ostream* stream, const T (&in)[N]) { // NOLINT(*-avoid-c-arrays) + // T[N], T(&)[N], T(&&)[N] have same behaviour. + // Hence remove reference. 
+ filloss::type>(stream, in); + } + + template + String toStream(const T& in) { + std::ostream* stream = tlssPush(); + filloss(stream, in); + return tlssPop(); + } + + template <> + struct StringMakerBase { + template + static String convert(const DOCTEST_REF_WRAP(T) in) { + return toStream(in); + } + }; +} // namespace detail + +template +struct StringMaker : public detail::StringMakerBase< + detail::has_insertion_operator::value || detail::types::is_pointer::value || detail::types::is_array::value> +{}; + +#ifndef DOCTEST_STRINGIFY +#ifdef DOCTEST_CONFIG_DOUBLE_STRINGIFY +#define DOCTEST_STRINGIFY(...) toString(toString(__VA_ARGS__)) +#else +#define DOCTEST_STRINGIFY(...) toString(__VA_ARGS__) +#endif +#endif + +template +String toString() { +#if DOCTEST_CLANG == 0 && DOCTEST_GCC == 0 && DOCTEST_ICC == 0 + String ret = __FUNCSIG__; // class doctest::String __cdecl doctest::toString(void) + String::size_type beginPos = ret.find('<'); + return ret.substr(beginPos + 1, ret.size() - beginPos - static_cast(sizeof(">(void)"))); +#else + String ret = __PRETTY_FUNCTION__; // doctest::String toString() [with T = TYPE] + String::size_type begin = ret.find('=') + 2; + return ret.substr(begin, ret.size() - begin - 1); +#endif +} + +template ::value, bool>::type = true> +String toString(const DOCTEST_REF_WRAP(T) value) { + return StringMaker::convert(value); +} + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +DOCTEST_INTERFACE String toString(const char* in); +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +DOCTEST_INTERFACE String toString(const std::string& in); +#endif // VS 2019 + +DOCTEST_INTERFACE String toString(String in); + +DOCTEST_INTERFACE String toString(std::nullptr_t); + +DOCTEST_INTERFACE String toString(bool in); + +DOCTEST_INTERFACE String toString(float in); +DOCTEST_INTERFACE String toString(double in); 
+DOCTEST_INTERFACE String toString(double long in); + +DOCTEST_INTERFACE String toString(char in); +DOCTEST_INTERFACE String toString(char signed in); +DOCTEST_INTERFACE String toString(char unsigned in); +DOCTEST_INTERFACE String toString(short in); +DOCTEST_INTERFACE String toString(short unsigned in); +DOCTEST_INTERFACE String toString(signed in); +DOCTEST_INTERFACE String toString(unsigned in); +DOCTEST_INTERFACE String toString(long in); +DOCTEST_INTERFACE String toString(long unsigned in); +DOCTEST_INTERFACE String toString(long long in); +DOCTEST_INTERFACE String toString(long long unsigned in); + +template ::value, bool>::type = true> +String toString(const DOCTEST_REF_WRAP(T) value) { + using UT = typename detail::types::underlying_type::type; + return (DOCTEST_STRINGIFY(static_cast(value))); +} + +namespace detail { + template + struct filldata + { + static void fill(std::ostream* stream, const T& in) { +#if defined(_MSC_VER) && _MSC_VER <= 1900 + insert_hack_t::insert(*stream, in); +#else + operator<<(*stream, in); +#endif + } + }; + +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) +// NOLINTBEGIN(*-avoid-c-arrays) + template + struct filldata { + static void fill(std::ostream* stream, const T(&in)[N]) { + *stream << "["; + for (size_t i = 0; i < N; i++) { + if (i != 0) { *stream << ", "; } + *stream << (DOCTEST_STRINGIFY(in[i])); + } + *stream << "]"; + } + }; +// NOLINTEND(*-avoid-c-arrays) +DOCTEST_MSVC_SUPPRESS_WARNING_POP + + // Specialized since we don't want the terminating null byte! +// NOLINTBEGIN(*-avoid-c-arrays) + template + struct filldata { + static void fill(std::ostream* stream, const char (&in)[N]) { + *stream << String(in, in[N - 1] ? 
N : N - 1); + } // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + }; +// NOLINTEND(*-avoid-c-arrays) + + template <> + struct filldata { + static void fill(std::ostream* stream, const void* in); + }; + + template + struct filldata { +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4180) + static void fill(std::ostream* stream, const T* in) { +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wmicrosoft-cast") + filldata::fill(stream, +#if DOCTEST_GCC == 0 || DOCTEST_GCC >= DOCTEST_COMPILER(4, 9, 0) + reinterpret_cast(in) +#else + *reinterpret_cast(&in) +#endif + ); +DOCTEST_CLANG_SUPPRESS_WARNING_POP + } + }; +} + +struct DOCTEST_INTERFACE Approx +{ + Approx(double value); + + Approx operator()(double value) const; + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + explicit Approx(const T& value, + typename detail::types::enable_if::value>::type* = + static_cast(nullptr)) { + *this = static_cast(value); + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& epsilon(double newEpsilon); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename std::enable_if::value, Approx&>::type epsilon( + const T& newEpsilon) { + m_epsilon = static_cast(newEpsilon); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& scale(double newScale); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename std::enable_if::value, Approx&>::type scale( + const T& newScale) { + m_scale = static_cast(newScale); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format off + DOCTEST_INTERFACE friend bool operator==(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator==(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator!=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator!=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator<=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool 
operator<=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator>=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator>=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator< (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator< (const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator> (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator> (const Approx & lhs, double rhs); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#define DOCTEST_APPROX_PREFIX \ + template friend typename std::enable_if::value, bool>::type + + DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(static_cast(lhs), rhs); } + DOCTEST_APPROX_PREFIX operator==(const Approx& lhs, const T& rhs) { return operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator!=(const T& lhs, const Approx& rhs) { return !operator==(lhs, rhs); } + DOCTEST_APPROX_PREFIX operator!=(const Approx& lhs, const T& rhs) { return !operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > static_cast(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return static_cast(lhs) < rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < static_cast(rhs) && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return static_cast(lhs) > rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& 
rhs) { return lhs.m_value > static_cast(rhs) && lhs != rhs; } +#undef DOCTEST_APPROX_PREFIX +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format on + + double m_epsilon; + double m_scale; + double m_value; +}; + +DOCTEST_INTERFACE String toString(const Approx& in); + +DOCTEST_INTERFACE const ContextOptions* getContextOptions(); + +template +struct DOCTEST_INTERFACE_DECL IsNaN +{ + F value; bool flipped; + IsNaN(F f, bool flip = false) : value(f), flipped(flip) { } + IsNaN operator!() const { return { value, !flipped }; } + operator bool() const; +}; +#ifndef __MINGW32__ +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +extern template struct DOCTEST_INTERFACE_DECL IsNaN; +#endif +DOCTEST_INTERFACE String toString(IsNaN in); +DOCTEST_INTERFACE String toString(IsNaN in); +DOCTEST_INTERFACE String toString(IsNaN in); + +#ifndef DOCTEST_CONFIG_DISABLE + +namespace detail { + // clang-format off +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + template struct decay_array { using type = T; }; + template struct decay_array { using type = T*; }; + template struct decay_array { using type = T*; }; + + template struct not_char_pointer { static DOCTEST_CONSTEXPR int value = 1; }; + template<> struct not_char_pointer { static DOCTEST_CONSTEXPR int value = 0; }; + template<> struct not_char_pointer { static DOCTEST_CONSTEXPR int value = 0; }; + + template struct can_use_op : public not_char_pointer::type> {}; +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + + struct DOCTEST_INTERFACE TestFailureException + { + }; + + DOCTEST_INTERFACE bool checkIfShouldThrow(assertType::Enum at); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_INTERFACE void throwException(); + + struct DOCTEST_INTERFACE Subcase + { + SubcaseSignature m_signature; + bool m_entered = false; + + Subcase(const String& name, const char* file, int line); + 
Subcase(const Subcase&) = delete; + Subcase(Subcase&&) = delete; + Subcase& operator=(const Subcase&) = delete; + Subcase& operator=(Subcase&&) = delete; + ~Subcase(); + + operator bool() const; + + private: + bool checkFilters(); + }; + + template + String stringifyBinaryExpr(const DOCTEST_REF_WRAP(L) lhs, const char* op, + const DOCTEST_REF_WRAP(R) rhs) { + return (DOCTEST_STRINGIFY(lhs)) + op + (DOCTEST_STRINGIFY(rhs)); + } + +#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") +#endif + +// This will check if there is any way it could find a operator like member or friend and uses it. +// If not it doesn't find the operator or if the operator at global scope is defined after +// this template, the template won't be instantiated due to SFINAE. Once the template is not +// instantiated it can look for global operator using normal conversions. +#ifdef __NVCC__ +#define SFINAE_OP(ret,op) ret +#else +#define SFINAE_OP(ret,op) decltype((void)(doctest::detail::declval() op doctest::detail::declval()),ret{}) +#endif + +#define DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(op, op_str, op_macro) \ + template \ + DOCTEST_NOINLINE SFINAE_OP(Result,op) operator op(R&& rhs) { \ + bool res = op_macro(doctest::detail::forward(lhs), doctest::detail::forward(rhs)); \ + if(m_at & assertType::is_false) \ + res = !res; \ + if(!res || doctest::getContextOptions()->success) \ + return Result(res, stringifyBinaryExpr(lhs, op_str, rhs)); \ + return Result(res); \ + } + + // more checks could be added - like in Catch: + // https://github.com/catchorg/Catch2/pull/1480/files + // https://github.com/catchorg/Catch2/pull/1481/files +#define DOCTEST_FORBIT_EXPRESSION(rt, op) \ + template \ + rt& operator op(const R&) { \ + static_assert(deferred_false::value, \ + "Expression Too Complex Please Rewrite As Binary Comparison!"); \ + return *this; \ + } + + struct DOCTEST_INTERFACE Result // NOLINT(*-member-init) + { + bool 
m_passed; + String m_decomp; + + Result() = default; // TODO: Why do we need this? (To remove NOLINT) + Result(bool passed, const String& decomposition = String()); + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Result, &) + DOCTEST_FORBIT_EXPRESSION(Result, ^) + DOCTEST_FORBIT_EXPRESSION(Result, |) + DOCTEST_FORBIT_EXPRESSION(Result, &&) + DOCTEST_FORBIT_EXPRESSION(Result, ||) + DOCTEST_FORBIT_EXPRESSION(Result, ==) + DOCTEST_FORBIT_EXPRESSION(Result, !=) + DOCTEST_FORBIT_EXPRESSION(Result, <) + DOCTEST_FORBIT_EXPRESSION(Result, >) + DOCTEST_FORBIT_EXPRESSION(Result, <=) + DOCTEST_FORBIT_EXPRESSION(Result, >=) + DOCTEST_FORBIT_EXPRESSION(Result, =) + DOCTEST_FORBIT_EXPRESSION(Result, +=) + DOCTEST_FORBIT_EXPRESSION(Result, -=) + DOCTEST_FORBIT_EXPRESSION(Result, *=) + DOCTEST_FORBIT_EXPRESSION(Result, /=) + DOCTEST_FORBIT_EXPRESSION(Result, %=) + DOCTEST_FORBIT_EXPRESSION(Result, <<=) + DOCTEST_FORBIT_EXPRESSION(Result, >>=) + DOCTEST_FORBIT_EXPRESSION(Result, &=) + DOCTEST_FORBIT_EXPRESSION(Result, ^=) + DOCTEST_FORBIT_EXPRESSION(Result, |=) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_GCC_SUPPRESS_WARNING_PUSH + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH + // https://stackoverflow.com/questions/39479163 what's the difference between 4018 and 4389 + DOCTEST_MSVC_SUPPRESS_WARNING(4388) // 
signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4389) // 'operator' : signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4018) // 'expression' : signed/unsigned mismatch + //DOCTEST_MSVC_SUPPRESS_WARNING(4805) // 'operation' : unsafe mix of type 'type' and type 'type' in operation + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + // clang-format off +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE bool +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE typename types::enable_if::value || can_use_op::value, bool>::type + inline bool eq(const char* lhs, const char* rhs) { return String(lhs) == String(rhs); } + inline bool ne(const char* lhs, const char* rhs) { return String(lhs) != String(rhs); } + inline bool lt(const char* lhs, const char* rhs) { return String(lhs) < String(rhs); } + inline bool gt(const char* lhs, const char* rhs) { return String(lhs) > String(rhs); } + inline bool le(const char* lhs, const char* rhs) { return String(lhs) <= String(rhs); } + inline bool ge(const char* lhs, const char* rhs) { return String(lhs) >= String(rhs); } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + +#define DOCTEST_RELATIONAL_OP(name, op) \ + template \ + DOCTEST_COMPARISON_RETURN_TYPE name(const DOCTEST_REF_WRAP(L) lhs, \ + const DOCTEST_REF_WRAP(R) rhs) { \ + return lhs op rhs; \ + } + + DOCTEST_RELATIONAL_OP(eq, ==) + DOCTEST_RELATIONAL_OP(ne, !=) + DOCTEST_RELATIONAL_OP(lt, <) + DOCTEST_RELATIONAL_OP(gt, >) + DOCTEST_RELATIONAL_OP(le, <=) + DOCTEST_RELATIONAL_OP(ge, >=) + +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) l == r +#define DOCTEST_CMP_NE(l, r) l != r +#define DOCTEST_CMP_GT(l, r) l > r +#define DOCTEST_CMP_LT(l, r) l < r +#define DOCTEST_CMP_GE(l, r) l >= r +#define DOCTEST_CMP_LE(l, r) l <= r +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) eq(l, r) 
+#define DOCTEST_CMP_NE(l, r) ne(l, r) +#define DOCTEST_CMP_GT(l, r) gt(l, r) +#define DOCTEST_CMP_LT(l, r) lt(l, r) +#define DOCTEST_CMP_GE(l, r) ge(l, r) +#define DOCTEST_CMP_LE(l, r) le(l, r) +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + + template + // cppcheck-suppress copyCtorAndEqOperator + struct Expression_lhs + { + L lhs; + assertType::Enum m_at; + + explicit Expression_lhs(L&& in, assertType::Enum at) + : lhs(static_cast(in)) + , m_at(at) {} + + DOCTEST_NOINLINE operator Result() { +// this is needed only for MSVC 2015 +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4800) // 'int': forcing value to bool + bool res = static_cast(lhs); +DOCTEST_MSVC_SUPPRESS_WARNING_POP + if(m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional + res = !res; + } + + if(!res || getContextOptions()->success) { + return { res, (DOCTEST_STRINGIFY(lhs)) }; + } + return { res }; + } + + /* This is required for user-defined conversions from Expression_lhs to L */ + operator L() const { return lhs; } + + // clang-format off + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(==, " == ", DOCTEST_CMP_EQ) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(!=, " != ", DOCTEST_CMP_NE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>, " > ", DOCTEST_CMP_GT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<, " < ", DOCTEST_CMP_LT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>=, " >= ", DOCTEST_CMP_GE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<=, " <= ", DOCTEST_CMP_LE) //!OCLINT bitwise operator in conditional + // clang-format on + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |) + 
DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &&) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ||) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, =) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, +=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, -=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, *=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, /=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, %=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |=) + // these 2 are unfortunate because they should be allowed - they have higher precedence over the comparisons, but the + // ExpressionDecomposer class uses the left shift operator to capture the left operand of the binary expression... + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + +#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) +DOCTEST_CLANG_SUPPRESS_WARNING_POP +#endif + + struct DOCTEST_INTERFACE ExpressionDecomposer + { + assertType::Enum m_at; + + ExpressionDecomposer(assertType::Enum at); + + // The right operator for capturing expressions is "<=" instead of "<<" (based on the operator precedence table) + // but then there will be warnings from GCC about "-Wparentheses" and since "_Pragma()" is problematic this will stay for now... 
+ // https://github.com/catchorg/Catch2/issues/870 + // https://github.com/catchorg/Catch2/issues/565 + template + Expression_lhs operator<<(L&& operand) { + return Expression_lhs(static_cast(operand), m_at); + } + + template ::value,void >::type* = nullptr> + Expression_lhs operator<<(const L &operand) { + return Expression_lhs(operand, m_at); + } + }; + + struct DOCTEST_INTERFACE TestSuite + { + const char* m_test_suite = nullptr; + const char* m_description = nullptr; + bool m_skip = false; + bool m_no_breaks = false; + bool m_no_output = false; + bool m_may_fail = false; + bool m_should_fail = false; + int m_expected_failures = 0; + double m_timeout = 0; + + TestSuite& operator*(const char* in); + + template + TestSuite& operator*(const T& in) { + in.fill(*this); + return *this; + } + }; + + using funcType = void (*)(); + + struct DOCTEST_INTERFACE TestCase : public TestCaseData + { + funcType m_test; // a function pointer to the test case + + String m_type; // for templated test cases - gets appended to the real name + int m_template_id; // an ID used to distinguish between the different versions of a templated test case + String m_full_name; // contains the name (only for templated test cases!) 
+ the template type + + TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const String& type = String(), int template_id = -1); + + TestCase(const TestCase& other); + TestCase(TestCase&&) = delete; + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + TestCase& operator=(const TestCase& other); + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& operator=(TestCase&&) = delete; + + TestCase& operator*(const char* in); + + template + TestCase& operator*(const T& in) { + in.fill(*this); + return *this; + } + + bool operator<(const TestCase& other) const; + + ~TestCase() = default; + }; + + // forward declarations of functions used by the macros + DOCTEST_INTERFACE int regTest(const TestCase& tc); + DOCTEST_INTERFACE int setTestSuite(const TestSuite& ts); + DOCTEST_INTERFACE bool isDebuggerActive(); + + template + int instantiationHelper(const T&) { return 0; } + + namespace binaryAssertComparison { + enum Enum + { + eq = 0, + ne, + gt, + lt, + ge, + le + }; + } // namespace binaryAssertComparison + + // clang-format off + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L), const DOCTEST_REF_WRAP(R) ) const { return false; } }; + +#define DOCTEST_BINARY_RELATIONAL_OP(n, op) \ + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) const { return op(lhs, rhs); } }; + // clang-format on + + DOCTEST_BINARY_RELATIONAL_OP(0, doctest::detail::eq) + DOCTEST_BINARY_RELATIONAL_OP(1, doctest::detail::ne) + DOCTEST_BINARY_RELATIONAL_OP(2, doctest::detail::gt) + DOCTEST_BINARY_RELATIONAL_OP(3, doctest::detail::lt) + DOCTEST_BINARY_RELATIONAL_OP(4, doctest::detail::ge) + DOCTEST_BINARY_RELATIONAL_OP(5, doctest::detail::le) + + struct DOCTEST_INTERFACE ResultBuilder : public AssertData + { + ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type = "", const String& 
exception_string = ""); + + ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const Contains& exception_string); + + void setResult(const Result& res); + + template + DOCTEST_NOINLINE bool binary_assert(const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + m_failed = !RelationalComparator()(lhs, rhs); + if (m_failed || getContextOptions()->success) { + m_decomp = stringifyBinaryExpr(lhs, ", ", rhs); + } + return !m_failed; + } + + template + DOCTEST_NOINLINE bool unary_assert(const DOCTEST_REF_WRAP(L) val) { + m_failed = !val; + + if (m_at & assertType::is_false) { //!OCLINT bitwise operator in conditional + m_failed = !m_failed; + } + + if (m_failed || getContextOptions()->success) { + m_decomp = (DOCTEST_STRINGIFY(val)); + } + + return !m_failed; + } + + void translateException(); + + bool log(); + void react() const; + }; + + namespace assertAction { + enum Enum + { + nothing = 0, + dbgbreak = 1, + shouldthrow = 2 + }; + } // namespace assertAction + + DOCTEST_INTERFACE void failed_out_of_a_testing_context(const AssertData& ad); + + DOCTEST_INTERFACE bool decomp_assert(assertType::Enum at, const char* file, int line, + const char* expr, const Result& result); + +#define DOCTEST_ASSERT_OUT_OF_TESTS(decomp) \ + do { \ + if(!is_running_in_test) { \ + if(failed) { \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + rb.m_decomp = decomp; \ + failed_out_of_a_testing_context(rb); \ + if(isDebuggerActive() && !getContextOptions()->no_breaks) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(checkIfShouldThrow(at)) \ + throwException(); \ + } \ + return !failed; \ + } \ + } while(false) + +#define DOCTEST_ASSERT_IN_TESTS(decomp) \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + if(rb.m_failed || getContextOptions()->success) \ + rb.m_decomp = decomp; \ + if(rb.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(rb.m_failed && checkIfShouldThrow(at)) \ + 
throwException() + + template + DOCTEST_NOINLINE bool binary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + bool failed = !RelationalComparator()(lhs, rhs); + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + DOCTEST_ASSERT_IN_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + return !failed; + } + + template + DOCTEST_NOINLINE bool unary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) val) { + bool failed = !val; + + if(at & assertType::is_false) //!OCLINT bitwise operator in conditional + failed = !failed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS((DOCTEST_STRINGIFY(val))); + DOCTEST_ASSERT_IN_TESTS((DOCTEST_STRINGIFY(val))); + return !failed; + } + + struct DOCTEST_INTERFACE IExceptionTranslator + { + DOCTEST_DECLARE_INTERFACE(IExceptionTranslator) + virtual bool translate(String&) const = 0; + }; + + template + class ExceptionTranslator : public IExceptionTranslator //!OCLINT destructor of virtual class + { + public: + explicit ExceptionTranslator(String (*translateFunction)(T)) + : m_translateFunction(translateFunction) {} + + bool translate(String& res) const override { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { + throw; // 
lgtm [cpp/rethrow-no-exception] + // cppcheck-suppress catchExceptionByValue + } catch(const T& ex) { + res = m_translateFunction(ex); //!OCLINT parameter reassignment + return true; + } catch(...) {} //!OCLINT - empty catch statement +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + static_cast(res); // to silence -Wunused-parameter + return false; + } + + private: + String (*m_translateFunction)(T); + }; + + DOCTEST_INTERFACE void registerExceptionTranslatorImpl(const IExceptionTranslator* et); + + // ContextScope base class used to allow implementing methods of ContextScope + // that don't depend on the template parameter in doctest.cpp. + struct DOCTEST_INTERFACE ContextScopeBase : public IContextScope { + ContextScopeBase(const ContextScopeBase&) = delete; + + ContextScopeBase& operator=(const ContextScopeBase&) = delete; + ContextScopeBase& operator=(ContextScopeBase&&) = delete; + + ~ContextScopeBase() override = default; + + protected: + ContextScopeBase(); + ContextScopeBase(ContextScopeBase&& other) noexcept; + + void destroy(); + bool need_to_destroy{true}; + }; + + template class ContextScope : public ContextScopeBase + { + L lambda_; + + public: + explicit ContextScope(const L &lambda) : lambda_(lambda) {} + explicit ContextScope(L&& lambda) : lambda_(static_cast(lambda)) { } + + ContextScope(const ContextScope&) = delete; + ContextScope(ContextScope&&) noexcept = default; + + ContextScope& operator=(const ContextScope&) = delete; + ContextScope& operator=(ContextScope&&) = delete; + + void stringify(std::ostream* s) const override { lambda_(s); } + + ~ContextScope() override { + if (need_to_destroy) { + destroy(); + } + } + }; + + struct DOCTEST_INTERFACE MessageBuilder : public MessageData + { + std::ostream* m_stream; + bool logged = false; + + MessageBuilder(const char* file, int line, assertType::Enum severity); + + MessageBuilder(const MessageBuilder&) = delete; + MessageBuilder(MessageBuilder&&) = delete; + + MessageBuilder& operator=(const 
MessageBuilder&) = delete; + MessageBuilder& operator=(MessageBuilder&&) = delete; + + ~MessageBuilder(); + + // the preferred way of chaining parameters for stringification +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4866) + template + MessageBuilder& operator,(const T& in) { + *m_stream << (DOCTEST_STRINGIFY(in)); + return *this; + } +DOCTEST_MSVC_SUPPRESS_WARNING_POP + + // kept here just for backwards-compatibility - the comma operator should be preferred now + template + MessageBuilder& operator<<(const T& in) { return this->operator,(in); } + + // the `,` operator has the lowest operator precedence - if `<<` is used by the user then + // the `,` operator will be called last which is not what we want and thus the `*` operator + // is used first (has higher operator precedence compared to `<<`) so that we guarantee that + // an operator of the MessageBuilder class is called first before the rest of the parameters + template + MessageBuilder& operator*(const T& in) { return this->operator,(in); } + + bool log(); + void react(); + }; + + template + ContextScope MakeContextScope(const L &lambda) { + return ContextScope(lambda); + } +} // namespace detail + +#define DOCTEST_DEFINE_DECORATOR(name, type, def) \ + struct name \ + { \ + type data; \ + name(type in = def) \ + : data(in) {} \ + void fill(detail::TestCase& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + void fill(detail::TestSuite& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + } + +DOCTEST_DEFINE_DECORATOR(test_suite, const char*, ""); +DOCTEST_DEFINE_DECORATOR(description, const char*, ""); +DOCTEST_DEFINE_DECORATOR(skip, bool, true); +DOCTEST_DEFINE_DECORATOR(no_breaks, bool, true); +DOCTEST_DEFINE_DECORATOR(no_output, bool, true); +DOCTEST_DEFINE_DECORATOR(timeout, double, 0); +DOCTEST_DEFINE_DECORATOR(may_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(should_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(expected_failures, int, 0); + +template +int registerExceptionTranslator(String 
(*translateFunction)(T)) { + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") + static detail::ExceptionTranslator exceptionTranslator(translateFunction); + DOCTEST_CLANG_SUPPRESS_WARNING_POP + detail::registerExceptionTranslatorImpl(&exceptionTranslator); + return 0; +} + +} // namespace doctest + +// in a separate namespace outside of doctest because the DOCTEST_TEST_SUITE macro +// introduces an anonymous namespace in which getCurrentTestSuite gets overridden +namespace doctest_detail_test_suite_ns { +DOCTEST_INTERFACE doctest::detail::TestSuite& getCurrentTestSuite(); +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +#else // DOCTEST_CONFIG_DISABLE +template +int registerExceptionTranslator(String (*)(T)) { + return 0; +} +#endif // DOCTEST_CONFIG_DISABLE + +namespace detail { + using assert_handler = void (*)(const AssertData&); + struct ContextState; +} // namespace detail + +class DOCTEST_INTERFACE Context +{ + detail::ContextState* p; + + void parseArgs(int argc, const char* const* argv, bool withDefaults = false); + +public: + explicit Context(int argc = 0, const char* const* argv = nullptr); + + Context(const Context&) = delete; + Context(Context&&) = delete; + + Context& operator=(const Context&) = delete; + Context& operator=(Context&&) = delete; + + ~Context(); // NOLINT(performance-trivially-destructible) + + void applyCommandLine(int argc, const char* const* argv); + + void addFilter(const char* filter, const char* value); + void clearFilters(); + void setOption(const char* option, bool value); + void setOption(const char* option, int value); + void setOption(const char* option, const char* value); + + bool shouldExit(); + + void setAsDefaultForAssertsOutOfTestCases(); + + void setAssertHandler(detail::assert_handler ah); + + void setCout(std::ostream* out); + + int run(); +}; + +namespace TestCaseFailureReason { + enum Enum + { + None = 0, + AssertFailure = 1, // an assertion has failed in the test case + 
Exception = 2, // test case threw an exception + Crash = 4, // a crash... + TooManyFailedAsserts = 8, // the abort-after option + Timeout = 16, // see the timeout decorator + ShouldHaveFailedButDidnt = 32, // see the should_fail decorator + ShouldHaveFailedAndDid = 64, // see the should_fail decorator + DidntFailExactlyNumTimes = 128, // see the expected_failures decorator + FailedExactlyNumTimes = 256, // see the expected_failures decorator + CouldHaveFailedAndDid = 512 // see the may_fail decorator + }; +} // namespace TestCaseFailureReason + +struct DOCTEST_INTERFACE CurrentTestCaseStats +{ + int numAssertsCurrentTest; + int numAssertsFailedCurrentTest; + double seconds; + int failure_flags; // use TestCaseFailureReason::Enum + bool testCaseSuccess; +}; + +struct DOCTEST_INTERFACE TestCaseException +{ + String error_string; + bool is_crash; +}; + +struct DOCTEST_INTERFACE TestRunStats +{ + unsigned numTestCases; + unsigned numTestCasesPassingFilters; + unsigned numTestSuitesPassingFilters; + unsigned numTestCasesFailed; + int numAsserts; + int numAssertsFailed; +}; + +struct QueryData +{ + const TestRunStats* run_stats = nullptr; + const TestCaseData** data = nullptr; + unsigned num_data = 0; +}; + +struct DOCTEST_INTERFACE IReporter +{ + // The constructor has to accept "const ContextOptions&" as a single argument + // which has most of the options for the run + a pointer to the stdout stream + // Reporter(const ContextOptions& in) + + // called when a query should be reported (listing test cases, printing the version, etc.) 
+ virtual void report_query(const QueryData&) = 0; + + // called when the whole test run starts + virtual void test_run_start() = 0; + // called when the whole test run ends (caching a pointer to the input doesn't make sense here) + virtual void test_run_end(const TestRunStats&) = 0; + + // called when a test case is started (safe to cache a pointer to the input) + virtual void test_case_start(const TestCaseData&) = 0; + // called when a test case is reentered because of unfinished subcases (safe to cache a pointer to the input) + virtual void test_case_reenter(const TestCaseData&) = 0; + // called when a test case has ended + virtual void test_case_end(const CurrentTestCaseStats&) = 0; + + // called when an exception is thrown from the test case (or it crashes) + virtual void test_case_exception(const TestCaseException&) = 0; + + // called whenever a subcase is entered (don't cache pointers to the input) + virtual void subcase_start(const SubcaseSignature&) = 0; + // called whenever a subcase is exited (don't cache pointers to the input) + virtual void subcase_end() = 0; + + // called for each assert (don't cache pointers to the input) + virtual void log_assert(const AssertData&) = 0; + // called for each message (don't cache pointers to the input) + virtual void log_message(const MessageData&) = 0; + + // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator + // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) + virtual void test_case_skipped(const TestCaseData&) = 0; + + DOCTEST_DECLARE_INTERFACE(IReporter) + + // can obtain all currently active contexts and stringify them if one wishes to do so + static int get_num_active_contexts(); + static const IContextScope* const* get_active_contexts(); + + // can iterate through contexts which have been stringified automatically in their destructors when an exception has been thrown + static int get_num_stringified_contexts(); + 
static const String* get_stringified_contexts(); +}; + +namespace detail { + using reporterCreatorFunc = IReporter* (*)(const ContextOptions&); + + DOCTEST_INTERFACE void registerReporterImpl(const char* name, int prio, reporterCreatorFunc c, bool isReporter); + + template + IReporter* reporterCreator(const ContextOptions& o) { + return new Reporter(o); + } +} // namespace detail + +template +int registerReporter(const char* name, int priority, bool isReporter) { + detail::registerReporterImpl(name, priority, detail::reporterCreator, isReporter); + return 0; +} +} // namespace doctest + +#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES +#define DOCTEST_FUNC_EMPTY [] { return false; }() +#else +#define DOCTEST_FUNC_EMPTY (void)0 +#endif + +// if registering is not disabled +#ifndef DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_ASSERTS_RETURN_VALUES +#define DOCTEST_FUNC_SCOPE_BEGIN [&] +#define DOCTEST_FUNC_SCOPE_END () +#define DOCTEST_FUNC_SCOPE_RET(v) return v +#else +#define DOCTEST_FUNC_SCOPE_BEGIN do +#define DOCTEST_FUNC_SCOPE_END while(false) +#define DOCTEST_FUNC_SCOPE_RET(v) (void)0 +#endif + +// common code in asserts - for convenience +#define DOCTEST_ASSERT_LOG_REACT_RETURN(b) \ + if(b.log()) DOCTEST_BREAK_INTO_DEBUGGER(); \ + b.react(); \ + DOCTEST_FUNC_SCOPE_RET(!b.m_failed) + +#ifdef DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) x; +#else // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) \ + try { \ + x; \ + } catch(...) { DOCTEST_RB.translateException(); } +#endif // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#ifdef DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) \ + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wuseless-cast") \ + static_cast(__VA_ARGS__); \ + DOCTEST_GCC_SUPPRESS_WARNING_POP +#else // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) 
__VA_ARGS__; +#endif // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS + +// registers the test by initializing a dummy var with a function +#define DOCTEST_REGISTER_FUNCTION(global_prefix, f, decorators) \ + global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT */ \ + doctest::detail::regTest( \ + doctest::detail::TestCase( \ + f, __FILE__, __LINE__, \ + doctest_detail_test_suite_ns::getCurrentTestSuite()) * \ + decorators)) + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, decorators) \ + namespace { /* NOLINT */ \ + struct der : public base \ + { \ + void f(); \ + }; \ + static DOCTEST_INLINE_NOINLINE void func() { \ + der v; \ + v.f(); \ + } \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, func, decorators) \ + } \ + DOCTEST_INLINE_NOINLINE void der::f() // NOLINT(misc-definitions-in-headers) + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, decorators) \ + static void f(); \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, f, decorators) \ + static void f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(f, proxy, decorators) \ + static doctest::detail::funcType proxy() { return f; } \ + DOCTEST_REGISTER_FUNCTION(inline, proxy(), decorators) \ + static void f() + +// for registering tests +#define DOCTEST_TEST_CASE(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) + +// for registering tests in classes - requires C++17 for inline variables! +#if DOCTEST_CPLUSPLUS >= 201703L +#define DOCTEST_TEST_CASE_CLASS(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_PROXY_), \ + decorators) +#else // DOCTEST_TEST_CASE_CLASS +#define DOCTEST_TEST_CASE_CLASS(...) 
\ + TEST_CASES_CAN_BE_REGISTERED_IN_CLASSES_ONLY_IN_CPP17_MODE_OR_WITH_VS_2017_OR_NEWER +#endif // DOCTEST_TEST_CASE_CLASS + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(c, decorators) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), c, \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), decorators) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING_AS(str, ...) \ + namespace doctest { \ + template <> \ + inline String toString<__VA_ARGS__>() { \ + return str; \ + } \ + } \ + static_assert(true, "") + +#define DOCTEST_TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING_AS(#__VA_ARGS__, __VA_ARGS__) + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, iter, func) \ + template \ + static void func(); \ + namespace { /* NOLINT */ \ + template \ + struct iter; \ + template \ + struct iter> \ + { \ + iter(const char* file, unsigned line, int index) { \ + doctest::detail::regTest(doctest::detail::TestCase(func, file, line, \ + doctest_detail_test_suite_ns::getCurrentTestSuite(), \ + doctest::toString(), \ + int(line) * 1000 + index) \ + * dec); \ + iter>(file, line, index + 1); \ + } \ + }; \ + template <> \ + struct iter> \ + { \ + iter(const char*, unsigned, int) {} \ + }; \ + } \ + template \ + static void func() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(dec, T, id) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(id, ITERATOR), \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)) + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, anon, ...) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY), /* NOLINT(cert-err58-cpp, fuchsia-statically-constructed-objects) */ \ + doctest::detail::instantiationHelper( \ + DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0))) + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) 
\ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ + static_assert(true, "") + +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) \ + static_assert(true, "") + +#define DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, anon, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(anon, ITERATOR), anon); \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(anon, anon, std::tuple<__VA_ARGS__>) \ + template \ + static void anon() + +#define DOCTEST_TEST_CASE_TEMPLATE(dec, T, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_), __VA_ARGS__) + +// for subcases +#define DOCTEST_SUBCASE(name) \ + if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ + doctest::detail::Subcase(name, __FILE__, __LINE__)) + +// for grouping tests in test suites by using code blocks +#define DOCTEST_TEST_SUITE_IMPL(decorators, ns_name) \ + namespace ns_name { namespace doctest_detail_test_suite_ns { \ + static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() noexcept { \ + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4640) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") \ + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmissing-field-initializers") \ + static doctest::detail::TestSuite data{}; \ + static bool inited = false; \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP \ + DOCTEST_GCC_SUPPRESS_WARNING_POP \ + if(!inited) { \ + data* decorators; \ + inited = true; \ + } \ + return data; \ + } \ + } \ + } \ + namespace ns_name + +#define DOCTEST_TEST_SUITE(decorators) \ + DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(DOCTEST_ANON_SUITE_)) + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(decorators) \ + 
DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators)) \ + static_assert(true, "") + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_VAR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * "")) \ + using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int + +// for registering exception translators +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(translatorName, signature) \ + inline doctest::String translatorName(signature); \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerExceptionTranslator(translatorName)) \ + doctest::String translatorName(signature) + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_), \ + signature) + +// for registering reporters +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerReporter(name, priority, true)) \ + static_assert(true, "") + +// for registering listeners +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(DOCTEST_ANON_REPORTER_), /* NOLINT(cert-err58-cpp) */ \ + doctest::registerReporter(name, priority, false)) \ + static_assert(true, "") + +// clang-format off +// for logging - disabling formatting because it's important to have these on 2 separate lines - see PR #557 +#define DOCTEST_INFO(...) \ + DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_), \ + DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_OTHER_), \ + __VA_ARGS__) +// clang-format on + +#define DOCTEST_INFO_IMPL(mb_name, s_name, ...) 
\ + auto DOCTEST_ANONYMOUS(DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope( \ + [&](std::ostream* s_name) { \ + doctest::detail::MessageBuilder mb_name(__FILE__, __LINE__, doctest::assertType::is_warn); \ + mb_name.m_stream = s_name; \ + mb_name * __VA_ARGS__; \ + }) + +#define DOCTEST_CAPTURE(x) DOCTEST_INFO(#x " := ", x) + +#define DOCTEST_ADD_AT_IMPL(type, file, line, mb, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::MessageBuilder mb(file, line, doctest::assertType::type); \ + mb * __VA_ARGS__; \ + if(mb.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + mb.react(); \ + } DOCTEST_FUNC_SCOPE_END + +// clang-format off +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(DOCTEST_MESSAGE_), __VA_ARGS__) +// clang-format on + +#define DOCTEST_MESSAGE(...) DOCTEST_ADD_MESSAGE_AT(__FILE__, __LINE__, __VA_ARGS__) +#define DOCTEST_FAIL_CHECK(...) DOCTEST_ADD_FAIL_CHECK_AT(__FILE__, __LINE__, __VA_ARGS__) +#define DOCTEST_FAIL(...) DOCTEST_ADD_FAIL_AT(__FILE__, __LINE__, __VA_ARGS__) + +#define DOCTEST_TO_LVALUE(...) __VA_ARGS__ // Not removed to keep backwards compatibility. + +#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_ASSERT_IMPLEMENT_2(assert_type, ...) 
\ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(DOCTEST_RB.setResult( \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__)) /* NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) */ \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB) \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + DOCTEST_ASSERT_IMPLEMENT_2(assert_type, __VA_ARGS__); \ + } DOCTEST_FUNC_SCOPE_END // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + +#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY( \ + DOCTEST_RB.binary_assert( \ + __VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(DOCTEST_RB.unary_assert(__VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +// necessary for _MESSAGE +#define DOCTEST_ASSERT_IMPLEMENT_2 DOCTEST_ASSERT_IMPLEMENT_1 + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) 
\ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + doctest::detail::decomp_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__) DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ + doctest::detail::binary_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ + #__VA_ARGS__, __VA_ARGS__) + +#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_WARN(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN, __VA_ARGS__) +#define DOCTEST_CHECK(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK, __VA_ARGS__) +#define DOCTEST_REQUIRE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE, __VA_ARGS__) +#define DOCTEST_WARN_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE_FALSE, __VA_ARGS__) + +// clang-format off +#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) 
DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } DOCTEST_FUNC_SCOPE_END +// clang-format on + +#define DOCTEST_WARN_EQ(...) DOCTEST_BINARY_ASSERT(DT_WARN_EQ, eq, __VA_ARGS__) +#define DOCTEST_CHECK_EQ(...) DOCTEST_BINARY_ASSERT(DT_CHECK_EQ, eq, __VA_ARGS__) +#define DOCTEST_REQUIRE_EQ(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_EQ, eq, __VA_ARGS__) +#define DOCTEST_WARN_NE(...) DOCTEST_BINARY_ASSERT(DT_WARN_NE, ne, __VA_ARGS__) +#define DOCTEST_CHECK_NE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_NE, ne, __VA_ARGS__) +#define DOCTEST_REQUIRE_NE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_NE, ne, __VA_ARGS__) +#define DOCTEST_WARN_GT(...) DOCTEST_BINARY_ASSERT(DT_WARN_GT, gt, __VA_ARGS__) +#define DOCTEST_CHECK_GT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GT, gt, __VA_ARGS__) +#define DOCTEST_REQUIRE_GT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GT, gt, __VA_ARGS__) +#define DOCTEST_WARN_LT(...) DOCTEST_BINARY_ASSERT(DT_WARN_LT, lt, __VA_ARGS__) +#define DOCTEST_CHECK_LT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LT, lt, __VA_ARGS__) +#define DOCTEST_REQUIRE_LT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LT, lt, __VA_ARGS__) +#define DOCTEST_WARN_GE(...) DOCTEST_BINARY_ASSERT(DT_WARN_GE, ge, __VA_ARGS__) +#define DOCTEST_CHECK_GE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GE, ge, __VA_ARGS__) +#define DOCTEST_REQUIRE_GE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GE, ge, __VA_ARGS__) +#define DOCTEST_WARN_LE(...) DOCTEST_BINARY_ASSERT(DT_WARN_LE, le, __VA_ARGS__) +#define DOCTEST_CHECK_LE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LE, le, __VA_ARGS__) +#define DOCTEST_REQUIRE_LE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LE, le, __VA_ARGS__) + +#define DOCTEST_WARN_UNARY(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY(...) 
DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY, __VA_ARGS__) +#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY_FALSE, __VA_ARGS__) + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #expr, #__VA_ARGS__, message); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(const typename doctest::detail::types::remove_const< \ + typename doctest::detail::types::remove_reference<__VA_ARGS__>::type>::type&) {\ + DOCTEST_RB.translateException(); \ + DOCTEST_RB.m_threw_as = true; \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } else { /* NOLINT(*-else-after-return) */ \ + DOCTEST_FUNC_SCOPE_RET(false); \ + } \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ + DOCTEST_FUNC_SCOPE_BEGIN { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, expr_str, "", __VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } else { /* NOLINT(*-else-after-return) */ \ + DOCTEST_FUNC_SCOPE_RET(false); \ + } \ + } DOCTEST_FUNC_SCOPE_END + +#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) 
\ + DOCTEST_FUNC_SCOPE_BEGIN { \ + doctest::detail::ResultBuilder DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ + } catch(...) { DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_REACT_RETURN(DOCTEST_RB); \ + } DOCTEST_FUNC_SCOPE_END + +// clang-format off +#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") +#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") + +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) + +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) +#define DOCTEST_CHECK_NOTHROW(...) 
DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) 
DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_SCOPE_BEGIN { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_NOTHROW(expr); } DOCTEST_FUNC_SCOPE_END +// clang-format on + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +// ================================================================================================= +// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == +// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! 
== +// ================================================================================================= +#else // DOCTEST_CONFIG_DISABLE + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ + namespace /* NOLINT */ { \ + template \ + struct der : public base \ + { void f(); }; \ + } \ + template \ + inline void der::f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ + template \ + static inline void f() + +// for registering tests +#define DOCTEST_TEST_CASE(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for registering tests in classes +#define DOCTEST_TEST_CASE_CLASS(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(DOCTEST_ANON_CLASS_), x, \ + DOCTEST_ANONYMOUS(DOCTEST_ANON_FUNC_), name) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING_AS(str, ...) static_assert(true, "") +#define DOCTEST_TYPE_TO_STRING(...) static_assert(true, "") + +// for typed tests +#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ + template \ + inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ + template \ + inline void DOCTEST_ANONYMOUS(DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) static_assert(true, "") +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) 
static_assert(true, "") + +// for subcases +#define DOCTEST_SUBCASE(name) + +// for a testsuite block +#define DOCTEST_TEST_SUITE(name) namespace // NOLINT + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(name) static_assert(true, "") + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END using DOCTEST_ANONYMOUS(DOCTEST_ANON_FOR_SEMICOLON_) = int + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + template \ + static inline doctest::String DOCTEST_ANONYMOUS(DOCTEST_ANON_TRANSLATOR_)(signature) + +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) + +#define DOCTEST_INFO(...) (static_cast(0)) +#define DOCTEST_CAPTURE(x) (static_cast(0)) +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_MESSAGE(...) (static_cast(0)) +#define DOCTEST_FAIL_CHECK(...) (static_cast(0)) +#define DOCTEST_FAIL(...) (static_cast(0)) + +#if defined(DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED) \ + && defined(DOCTEST_CONFIG_ASSERTS_RETURN_VALUES) + +#define DOCTEST_WARN(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_CHECK(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_REQUIRE(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_WARN_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_CHECK_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_FALSE(...) [&] { return !(__VA_ARGS__); }() + +#define DOCTEST_WARN_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_CHECK_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) [&] { return cond; }() +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) [&] { return !(cond); }() +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) 
[&] { return !(cond); }() + +namespace doctest { +namespace detail { +#define DOCTEST_RELATIONAL_OP(name, op) \ + template \ + bool name(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) { return lhs op rhs; } + + DOCTEST_RELATIONAL_OP(eq, ==) + DOCTEST_RELATIONAL_OP(ne, !=) + DOCTEST_RELATIONAL_OP(lt, <) + DOCTEST_RELATIONAL_OP(gt, >) + DOCTEST_RELATIONAL_OP(le, <=) + DOCTEST_RELATIONAL_OP(ge, >=) +} // namespace detail +} // namespace doctest + +#define DOCTEST_WARN_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_CHECK_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_EQ(...) [&] { return doctest::detail::eq(__VA_ARGS__); }() +#define DOCTEST_WARN_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_CHECK_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_NE(...) [&] { return doctest::detail::ne(__VA_ARGS__); }() +#define DOCTEST_WARN_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_CHECK_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_LT(...) [&] { return doctest::detail::lt(__VA_ARGS__); }() +#define DOCTEST_WARN_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_CHECK_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_GT(...) [&] { return doctest::detail::gt(__VA_ARGS__); }() +#define DOCTEST_WARN_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_CHECK_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_LE(...) [&] { return doctest::detail::le(__VA_ARGS__); }() +#define DOCTEST_WARN_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_CHECK_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_GE(...) [&] { return doctest::detail::ge(__VA_ARGS__); }() +#define DOCTEST_WARN_UNARY(...) 
[&] { return __VA_ARGS__; }() +#define DOCTEST_CHECK_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_REQUIRE_UNARY(...) [&] { return __VA_ARGS__; }() +#define DOCTEST_WARN_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_CHECK_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() +#define DOCTEST_REQUIRE_UNARY_FALSE(...) [&] { return !(__VA_ARGS__); }() + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_WARN_THROWS_WITH(expr, with, ...) [] { static_assert(false, "Exception translation is not available when doctest is disabled."); return false; }() +#define DOCTEST_CHECK_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) + +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH(,,) + +#define DOCTEST_WARN_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_CHECK_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_REQUIRE_THROWS(...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_WARN_THROWS_AS(expr, ...) 
[&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_CHECK_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_WARN_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_CHECK_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_REQUIRE_NOTHROW(...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return false; } catch (...) { return true; } }() +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) [&] { try { expr; } catch (__VA_ARGS__) { return true; } catch (...) { } return false; }() +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) { return false; } }() +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) [&] { try { __VA_ARGS__; return true; } catch (...) 
{ return false; } }() + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#else // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED + +#define DOCTEST_WARN(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_EQ(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_GT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_LT(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_GE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_LE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_LE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_LE(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_UNARY_FALSE(...) 
DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_FUNC_EMPTY + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + +#define DOCTEST_WARN_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_FUNC_EMPTY + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) 
DOCTEST_FUNC_EMPTY +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_FUNC_EMPTY + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#endif // DOCTEST_CONFIG_EVALUATE_ASSERTS_EVEN_WHEN_DISABLED + +#endif // DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#define DOCTEST_EXCEPTION_EMPTY_FUNC DOCTEST_FUNC_EMPTY +#else // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#define DOCTEST_EXCEPTION_EMPTY_FUNC [] { static_assert(false, "Exceptions are disabled! " \ + "Use DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS if you want to compile with exceptions disabled."); return false; }() + +#undef DOCTEST_REQUIRE +#undef DOCTEST_REQUIRE_FALSE +#undef DOCTEST_REQUIRE_MESSAGE +#undef DOCTEST_REQUIRE_FALSE_MESSAGE +#undef DOCTEST_REQUIRE_EQ +#undef DOCTEST_REQUIRE_NE +#undef DOCTEST_REQUIRE_GT +#undef DOCTEST_REQUIRE_LT +#undef DOCTEST_REQUIRE_GE +#undef DOCTEST_REQUIRE_LE +#undef DOCTEST_REQUIRE_UNARY +#undef DOCTEST_REQUIRE_UNARY_FALSE + +#define DOCTEST_REQUIRE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_FALSE_MESSAGE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_EQ DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_GT DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_LT DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_GE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_LE DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_UNARY DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_UNARY_FALSE DOCTEST_EXCEPTION_EMPTY_FUNC + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#define DOCTEST_WARN_THROWS(...) 
DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_EXCEPTION_EMPTY_FUNC + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) 
DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_EXCEPTION_EMPTY_FUNC + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +// clang-format off +// KEPT FOR BACKWARDS COMPATIBILITY - FORWARDING TO THE RIGHT MACROS +#define DOCTEST_FAST_WARN_EQ DOCTEST_WARN_EQ +#define DOCTEST_FAST_CHECK_EQ DOCTEST_CHECK_EQ +#define DOCTEST_FAST_REQUIRE_EQ DOCTEST_REQUIRE_EQ +#define DOCTEST_FAST_WARN_NE DOCTEST_WARN_NE +#define DOCTEST_FAST_CHECK_NE DOCTEST_CHECK_NE +#define DOCTEST_FAST_REQUIRE_NE DOCTEST_REQUIRE_NE +#define DOCTEST_FAST_WARN_GT DOCTEST_WARN_GT +#define DOCTEST_FAST_CHECK_GT DOCTEST_CHECK_GT +#define DOCTEST_FAST_REQUIRE_GT DOCTEST_REQUIRE_GT +#define DOCTEST_FAST_WARN_LT DOCTEST_WARN_LT +#define DOCTEST_FAST_CHECK_LT DOCTEST_CHECK_LT +#define DOCTEST_FAST_REQUIRE_LT DOCTEST_REQUIRE_LT +#define DOCTEST_FAST_WARN_GE DOCTEST_WARN_GE +#define DOCTEST_FAST_CHECK_GE DOCTEST_CHECK_GE +#define DOCTEST_FAST_REQUIRE_GE DOCTEST_REQUIRE_GE +#define DOCTEST_FAST_WARN_LE DOCTEST_WARN_LE +#define DOCTEST_FAST_CHECK_LE DOCTEST_CHECK_LE +#define DOCTEST_FAST_REQUIRE_LE DOCTEST_REQUIRE_LE + +#define DOCTEST_FAST_WARN_UNARY DOCTEST_WARN_UNARY +#define DOCTEST_FAST_CHECK_UNARY DOCTEST_CHECK_UNARY +#define DOCTEST_FAST_REQUIRE_UNARY DOCTEST_REQUIRE_UNARY +#define DOCTEST_FAST_WARN_UNARY_FALSE DOCTEST_WARN_UNARY_FALSE +#define DOCTEST_FAST_CHECK_UNARY_FALSE DOCTEST_CHECK_UNARY_FALSE +#define DOCTEST_FAST_REQUIRE_UNARY_FALSE DOCTEST_REQUIRE_UNARY_FALSE + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) 
DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id,__VA_ARGS__) +// clang-format on + +// BDD style macros +// clang-format off +#define DOCTEST_SCENARIO(name) DOCTEST_TEST_CASE(" Scenario: " name) +#define DOCTEST_SCENARIO_CLASS(name) DOCTEST_TEST_CASE_CLASS(" Scenario: " name) +#define DOCTEST_SCENARIO_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(" Scenario: " name, T, __VA_ARGS__) +#define DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(" Scenario: " name, T, id) + +#define DOCTEST_GIVEN(name) DOCTEST_SUBCASE(" Given: " name) +#define DOCTEST_WHEN(name) DOCTEST_SUBCASE(" When: " name) +#define DOCTEST_AND_WHEN(name) DOCTEST_SUBCASE("And when: " name) +#define DOCTEST_THEN(name) DOCTEST_SUBCASE(" Then: " name) +#define DOCTEST_AND_THEN(name) DOCTEST_SUBCASE(" And: " name) +// clang-format on + +// == SHORT VERSIONS OF THE MACROS +#ifndef DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES + +#define TEST_CASE(name) DOCTEST_TEST_CASE(name) +#define TEST_CASE_CLASS(name) DOCTEST_TEST_CASE_CLASS(name) +#define TEST_CASE_FIXTURE(x, name) DOCTEST_TEST_CASE_FIXTURE(x, name) +#define TYPE_TO_STRING_AS(str, ...) DOCTEST_TYPE_TO_STRING_AS(str, __VA_ARGS__) +#define TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING(__VA_ARGS__) +#define TEST_CASE_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(name, T, __VA_ARGS__) +#define TEST_CASE_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, T, id) +#define TEST_CASE_TEMPLATE_INVOKE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, __VA_ARGS__) +#define TEST_CASE_TEMPLATE_APPLY(id, ...) 
DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, __VA_ARGS__) +#define SUBCASE(name) DOCTEST_SUBCASE(name) +#define TEST_SUITE(decorators) DOCTEST_TEST_SUITE(decorators) +#define TEST_SUITE_BEGIN(name) DOCTEST_TEST_SUITE_BEGIN(name) +#define TEST_SUITE_END DOCTEST_TEST_SUITE_END +#define REGISTER_EXCEPTION_TRANSLATOR(signature) DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) +#define REGISTER_REPORTER(name, priority, reporter) DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define REGISTER_LISTENER(name, priority, reporter) DOCTEST_REGISTER_LISTENER(name, priority, reporter) +#define INFO(...) DOCTEST_INFO(__VA_ARGS__) +#define CAPTURE(x) DOCTEST_CAPTURE(x) +#define ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_MESSAGE_AT(file, line, __VA_ARGS__) +#define ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_FAIL_CHECK_AT(file, line, __VA_ARGS__) +#define ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_FAIL_AT(file, line, __VA_ARGS__) +#define MESSAGE(...) DOCTEST_MESSAGE(__VA_ARGS__) +#define FAIL_CHECK(...) DOCTEST_FAIL_CHECK(__VA_ARGS__) +#define FAIL(...) DOCTEST_FAIL(__VA_ARGS__) +#define TO_LVALUE(...) DOCTEST_TO_LVALUE(__VA_ARGS__) + +#define WARN(...) DOCTEST_WARN(__VA_ARGS__) +#define WARN_FALSE(...) DOCTEST_WARN_FALSE(__VA_ARGS__) +#define WARN_THROWS(...) DOCTEST_WARN_THROWS(__VA_ARGS__) +#define WARN_THROWS_AS(expr, ...) DOCTEST_WARN_THROWS_AS(expr, __VA_ARGS__) +#define WARN_THROWS_WITH(expr, ...) DOCTEST_WARN_THROWS_WITH(expr, __VA_ARGS__) +#define WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_WARN_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define WARN_NOTHROW(...) DOCTEST_WARN_NOTHROW(__VA_ARGS__) +#define CHECK(...) DOCTEST_CHECK(__VA_ARGS__) +#define CHECK_FALSE(...) DOCTEST_CHECK_FALSE(__VA_ARGS__) +#define CHECK_THROWS(...) DOCTEST_CHECK_THROWS(__VA_ARGS__) +#define CHECK_THROWS_AS(expr, ...) DOCTEST_CHECK_THROWS_AS(expr, __VA_ARGS__) +#define CHECK_THROWS_WITH(expr, ...) DOCTEST_CHECK_THROWS_WITH(expr, __VA_ARGS__) +#define CHECK_THROWS_WITH_AS(expr, with, ...) 
DOCTEST_CHECK_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define CHECK_NOTHROW(...) DOCTEST_CHECK_NOTHROW(__VA_ARGS__) +#define REQUIRE(...) DOCTEST_REQUIRE(__VA_ARGS__) +#define REQUIRE_FALSE(...) DOCTEST_REQUIRE_FALSE(__VA_ARGS__) +#define REQUIRE_THROWS(...) DOCTEST_REQUIRE_THROWS(__VA_ARGS__) +#define REQUIRE_THROWS_AS(expr, ...) DOCTEST_REQUIRE_THROWS_AS(expr, __VA_ARGS__) +#define REQUIRE_THROWS_WITH(expr, ...) DOCTEST_REQUIRE_THROWS_WITH(expr, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define REQUIRE_NOTHROW(...) DOCTEST_REQUIRE_NOTHROW(__VA_ARGS__) + +#define WARN_MESSAGE(cond, ...) DOCTEST_WARN_MESSAGE(cond, __VA_ARGS__) +#define WARN_FALSE_MESSAGE(cond, ...) DOCTEST_WARN_FALSE_MESSAGE(cond, __VA_ARGS__) +#define WARN_THROWS_MESSAGE(expr, ...) DOCTEST_WARN_THROWS_MESSAGE(expr, __VA_ARGS__) +#define WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_WARN_NOTHROW_MESSAGE(expr, __VA_ARGS__) +#define CHECK_MESSAGE(cond, ...) DOCTEST_CHECK_MESSAGE(cond, __VA_ARGS__) +#define CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_CHECK_FALSE_MESSAGE(cond, __VA_ARGS__) +#define CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_CHECK_THROWS_MESSAGE(expr, __VA_ARGS__) +#define CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define CHECK_NOTHROW_MESSAGE(expr, ...) 
DOCTEST_CHECK_NOTHROW_MESSAGE(expr, __VA_ARGS__) +#define REQUIRE_MESSAGE(cond, ...) DOCTEST_REQUIRE_MESSAGE(cond, __VA_ARGS__) +#define REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_REQUIRE_FALSE_MESSAGE(cond, __VA_ARGS__) +#define REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_REQUIRE_THROWS_MESSAGE(expr, __VA_ARGS__) +#define REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, __VA_ARGS__) + +#define SCENARIO(name) DOCTEST_SCENARIO(name) +#define SCENARIO_CLASS(name) DOCTEST_SCENARIO_CLASS(name) +#define SCENARIO_TEMPLATE(name, T, ...) DOCTEST_SCENARIO_TEMPLATE(name, T, __VA_ARGS__) +#define SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) +#define GIVEN(name) DOCTEST_GIVEN(name) +#define WHEN(name) DOCTEST_WHEN(name) +#define AND_WHEN(name) DOCTEST_AND_WHEN(name) +#define THEN(name) DOCTEST_THEN(name) +#define AND_THEN(name) DOCTEST_AND_THEN(name) + +#define WARN_EQ(...) DOCTEST_WARN_EQ(__VA_ARGS__) +#define CHECK_EQ(...) DOCTEST_CHECK_EQ(__VA_ARGS__) +#define REQUIRE_EQ(...) DOCTEST_REQUIRE_EQ(__VA_ARGS__) +#define WARN_NE(...) DOCTEST_WARN_NE(__VA_ARGS__) +#define CHECK_NE(...) DOCTEST_CHECK_NE(__VA_ARGS__) +#define REQUIRE_NE(...) DOCTEST_REQUIRE_NE(__VA_ARGS__) +#define WARN_GT(...) DOCTEST_WARN_GT(__VA_ARGS__) +#define CHECK_GT(...) DOCTEST_CHECK_GT(__VA_ARGS__) +#define REQUIRE_GT(...) DOCTEST_REQUIRE_GT(__VA_ARGS__) +#define WARN_LT(...) DOCTEST_WARN_LT(__VA_ARGS__) +#define CHECK_LT(...) DOCTEST_CHECK_LT(__VA_ARGS__) +#define REQUIRE_LT(...) DOCTEST_REQUIRE_LT(__VA_ARGS__) +#define WARN_GE(...) DOCTEST_WARN_GE(__VA_ARGS__) +#define CHECK_GE(...) 
DOCTEST_CHECK_GE(__VA_ARGS__) +#define REQUIRE_GE(...) DOCTEST_REQUIRE_GE(__VA_ARGS__) +#define WARN_LE(...) DOCTEST_WARN_LE(__VA_ARGS__) +#define CHECK_LE(...) DOCTEST_CHECK_LE(__VA_ARGS__) +#define REQUIRE_LE(...) DOCTEST_REQUIRE_LE(__VA_ARGS__) +#define WARN_UNARY(...) DOCTEST_WARN_UNARY(__VA_ARGS__) +#define CHECK_UNARY(...) DOCTEST_CHECK_UNARY(__VA_ARGS__) +#define REQUIRE_UNARY(...) DOCTEST_REQUIRE_UNARY(__VA_ARGS__) +#define WARN_UNARY_FALSE(...) DOCTEST_WARN_UNARY_FALSE(__VA_ARGS__) +#define CHECK_UNARY_FALSE(...) DOCTEST_CHECK_UNARY_FALSE(__VA_ARGS__) +#define REQUIRE_UNARY_FALSE(...) DOCTEST_REQUIRE_UNARY_FALSE(__VA_ARGS__) + +// KEPT FOR BACKWARDS COMPATIBILITY +#define FAST_WARN_EQ(...) DOCTEST_FAST_WARN_EQ(__VA_ARGS__) +#define FAST_CHECK_EQ(...) DOCTEST_FAST_CHECK_EQ(__VA_ARGS__) +#define FAST_REQUIRE_EQ(...) DOCTEST_FAST_REQUIRE_EQ(__VA_ARGS__) +#define FAST_WARN_NE(...) DOCTEST_FAST_WARN_NE(__VA_ARGS__) +#define FAST_CHECK_NE(...) DOCTEST_FAST_CHECK_NE(__VA_ARGS__) +#define FAST_REQUIRE_NE(...) DOCTEST_FAST_REQUIRE_NE(__VA_ARGS__) +#define FAST_WARN_GT(...) DOCTEST_FAST_WARN_GT(__VA_ARGS__) +#define FAST_CHECK_GT(...) DOCTEST_FAST_CHECK_GT(__VA_ARGS__) +#define FAST_REQUIRE_GT(...) DOCTEST_FAST_REQUIRE_GT(__VA_ARGS__) +#define FAST_WARN_LT(...) DOCTEST_FAST_WARN_LT(__VA_ARGS__) +#define FAST_CHECK_LT(...) DOCTEST_FAST_CHECK_LT(__VA_ARGS__) +#define FAST_REQUIRE_LT(...) DOCTEST_FAST_REQUIRE_LT(__VA_ARGS__) +#define FAST_WARN_GE(...) DOCTEST_FAST_WARN_GE(__VA_ARGS__) +#define FAST_CHECK_GE(...) DOCTEST_FAST_CHECK_GE(__VA_ARGS__) +#define FAST_REQUIRE_GE(...) DOCTEST_FAST_REQUIRE_GE(__VA_ARGS__) +#define FAST_WARN_LE(...) DOCTEST_FAST_WARN_LE(__VA_ARGS__) +#define FAST_CHECK_LE(...) DOCTEST_FAST_CHECK_LE(__VA_ARGS__) +#define FAST_REQUIRE_LE(...) DOCTEST_FAST_REQUIRE_LE(__VA_ARGS__) + +#define FAST_WARN_UNARY(...) DOCTEST_FAST_WARN_UNARY(__VA_ARGS__) +#define FAST_CHECK_UNARY(...) DOCTEST_FAST_CHECK_UNARY(__VA_ARGS__) +#define FAST_REQUIRE_UNARY(...) 
DOCTEST_FAST_REQUIRE_UNARY(__VA_ARGS__) +#define FAST_WARN_UNARY_FALSE(...) DOCTEST_FAST_WARN_UNARY_FALSE(__VA_ARGS__) +#define FAST_CHECK_UNARY_FALSE(...) DOCTEST_FAST_CHECK_UNARY_FALSE(__VA_ARGS__) +#define FAST_REQUIRE_UNARY_FALSE(...) DOCTEST_FAST_REQUIRE_UNARY_FALSE(__VA_ARGS__) + +#define TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, __VA_ARGS__) + +#endif // DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES + +#ifndef DOCTEST_CONFIG_DISABLE + +// this is here to clear the 'current test suite' for the current translation unit - at the top +DOCTEST_TEST_SUITE_END(); + +#endif // DOCTEST_CONFIG_DISABLE + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_POP + +#endif // DOCTEST_LIBRARY_INCLUDED + +#ifndef DOCTEST_SINGLE_HEADER +#define DOCTEST_SINGLE_HEADER +#endif // DOCTEST_SINGLE_HEADER + +#if defined(DOCTEST_CONFIG_IMPLEMENT) || !defined(DOCTEST_SINGLE_HEADER) + +#ifndef DOCTEST_SINGLE_HEADER +#include "doctest_fwd.h" +#endif // DOCTEST_SINGLE_HEADER + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-macros") + +#ifndef DOCTEST_LIBRARY_IMPLEMENTATION +#define DOCTEST_LIBRARY_IMPLEMENTATION + +DOCTEST_CLANG_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_PUSH + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wglobal-constructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wexit-time-destructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wshorten-64-to-32") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-variable-declarations") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wcovered-switch-default") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-noreturn") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdisabled-macro-expansion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-braces") 
+DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-member-function") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wnonportable-system-include-path") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-braces") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-default") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunsafe-loop-optimizations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wold-style-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-function") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmultiple-inheritance") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsuggest-attribute") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4267) // 'var' : conversion from 'x' to 'y', possible loss of data +DOCTEST_MSVC_SUPPRESS_WARNING(4530) // C++ exception handler used, but unwind semantics not enabled +DOCTEST_MSVC_SUPPRESS_WARNING(4577) // 'noexcept' used with no exception handling mode specified +DOCTEST_MSVC_SUPPRESS_WARNING(4774) // format string expected in argument is not a string literal +DOCTEST_MSVC_SUPPRESS_WARNING(4365) // conversion from 'int' to 'unsigned', signed/unsigned mismatch +DOCTEST_MSVC_SUPPRESS_WARNING(5039) // pointer to potentially throwing function passed to extern C +DOCTEST_MSVC_SUPPRESS_WARNING(4800) // forcing value to bool 'true' or 'false' (performance warning) +DOCTEST_MSVC_SUPPRESS_WARNING(5245) // unreferenced function with internal linkage has been removed + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN + +// required includes - will go only in one translation unit! 
+#include +#include +#include +// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/doctest/doctest/pull/37 +#ifdef __BORLANDC__ +#include +#endif // __BORLANDC__ +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM +#include +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM +#include +#include +#include +#ifndef DOCTEST_CONFIG_NO_MULTITHREADING +#include +#include +#define DOCTEST_DECLARE_MUTEX(name) std::mutex name; +#define DOCTEST_DECLARE_STATIC_MUTEX(name) static DOCTEST_DECLARE_MUTEX(name) +#define DOCTEST_LOCK_MUTEX(name) std::lock_guard DOCTEST_ANONYMOUS(DOCTEST_ANON_LOCK_)(name); +#else // DOCTEST_CONFIG_NO_MULTITHREADING +#define DOCTEST_DECLARE_MUTEX(name) +#define DOCTEST_DECLARE_STATIC_MUTEX(name) +#define DOCTEST_LOCK_MUTEX(name) +#endif // DOCTEST_CONFIG_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef DOCTEST_PLATFORM_MAC +#include +#include +#include +#endif // DOCTEST_PLATFORM_MAC + +#ifdef DOCTEST_PLATFORM_WINDOWS + +// defines for a leaner windows.h +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#define DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#define DOCTEST_UNDEF_NOMINMAX +#endif // NOMINMAX + +// not sure what AfxWin.h is for - here I do what Catch does +#ifdef __AFXDLL +#include +#else +#include +#endif +#include + +#else // DOCTEST_PLATFORM_WINDOWS + +#include +#include + +#endif // DOCTEST_PLATFORM_WINDOWS + +// this is a fix for https://github.com/doctest/doctest/issues/348 +// https://mail.gnome.org/archives/xml/2012-January/msg00000.html +#if !defined(HAVE_UNISTD_H) && !defined(STDOUT_FILENO) +#define STDOUT_FILENO fileno(stdout) +#endif // HAVE_UNISTD_H + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END + +// counts the number of elements in a C array +#define DOCTEST_COUNTOF(x) 
(sizeof(x) / sizeof(x[0])) + +#ifdef DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_disabled +#else // DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_not_disabled +#endif // DOCTEST_CONFIG_DISABLE + +#ifndef DOCTEST_CONFIG_OPTIONS_PREFIX +#define DOCTEST_CONFIG_OPTIONS_PREFIX "dt-" +#endif + +#ifndef DOCTEST_THREAD_LOCAL +#if defined(DOCTEST_CONFIG_NO_MULTITHREADING) || DOCTEST_MSVC && (DOCTEST_MSVC < DOCTEST_COMPILER(19, 0, 0)) +#define DOCTEST_THREAD_LOCAL +#else // DOCTEST_MSVC +#define DOCTEST_THREAD_LOCAL thread_local +#endif // DOCTEST_MSVC +#endif // DOCTEST_THREAD_LOCAL + +#ifndef DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES +#define DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES 32 +#endif + +#ifndef DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE +#define DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE 64 +#endif + +#ifdef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS +#define DOCTEST_OPTIONS_PREFIX_DISPLAY DOCTEST_CONFIG_OPTIONS_PREFIX +#else +#define DOCTEST_OPTIONS_PREFIX_DISPLAY "" +#endif + +#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +#define DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS +#endif + +#ifndef DOCTEST_CDECL +#define DOCTEST_CDECL __cdecl +#endif + +namespace doctest { + +bool is_running_in_test = false; + +namespace { + using namespace detail; + + template + DOCTEST_NORETURN void throw_exception(Ex const& e) { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + throw e; +#else // DOCTEST_CONFIG_NO_EXCEPTIONS +#ifdef DOCTEST_CONFIG_HANDLE_EXCEPTION + DOCTEST_CONFIG_HANDLE_EXCEPTION(e); +#else // DOCTEST_CONFIG_HANDLE_EXCEPTION +#ifndef DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + std::cerr << "doctest will terminate because it needed to throw an exception.\n" + << "The message was: " << e.what() << '\n'; +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM +#endif // DOCTEST_CONFIG_HANDLE_EXCEPTION + std::terminate(); +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } + +#ifndef 
DOCTEST_INTERNAL_ERROR +#define DOCTEST_INTERNAL_ERROR(msg) \ + throw_exception(std::logic_error( \ + __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) +#endif // DOCTEST_INTERNAL_ERROR + + // case insensitive strcmp + int stricmp(const char* a, const char* b) { + for(;; a++, b++) { + const int d = tolower(*a) - tolower(*b); + if(d != 0 || !*a) + return d; + } + } + + struct Endianness + { + enum Arch + { + Big, + Little + }; + + static Arch which() { + int x = 1; + // casting any data pointer to char* is allowed + auto ptr = reinterpret_cast(&x); + if(*ptr) + return Little; + return Big; + } + }; +} // namespace + +namespace detail { + DOCTEST_THREAD_LOCAL class + { + std::vector stack; + std::stringstream ss; + + public: + std::ostream* push() { + stack.push_back(ss.tellp()); + return &ss; + } + + String pop() { + if (stack.empty()) + DOCTEST_INTERNAL_ERROR("TLSS was empty when trying to pop!"); + + std::streampos pos = stack.back(); + stack.pop_back(); + unsigned sz = static_cast(ss.tellp() - pos); + ss.rdbuf()->pubseekpos(pos, std::ios::in | std::ios::out); + return String(ss, sz); + } + } g_oss; + + std::ostream* tlssPush() { + return g_oss.push(); + } + + String tlssPop() { + return g_oss.pop(); + } + +#ifndef DOCTEST_CONFIG_DISABLE + +namespace timer_large_integer +{ + +#if defined(DOCTEST_PLATFORM_WINDOWS) + using type = ULONGLONG; +#else // DOCTEST_PLATFORM_WINDOWS + using type = std::uint64_t; +#endif // DOCTEST_PLATFORM_WINDOWS +} + +using ticks_t = timer_large_integer::type; + +#ifdef DOCTEST_CONFIG_GETCURRENTTICKS + ticks_t getCurrentTicks() { return DOCTEST_CONFIG_GETCURRENTTICKS(); } +#elif defined(DOCTEST_PLATFORM_WINDOWS) + ticks_t getCurrentTicks() { + static LARGE_INTEGER hz = { {0} }, hzo = { {0} }; + if(!hz.QuadPart) { + QueryPerformanceFrequency(&hz); + QueryPerformanceCounter(&hzo); + } + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart - hzo.QuadPart) * LONGLONG(1000000)) / hz.QuadPart; + } +#else 
// DOCTEST_PLATFORM_WINDOWS + ticks_t getCurrentTicks() { + timeval t; + gettimeofday(&t, nullptr); + return static_cast(t.tv_sec) * 1000000 + static_cast(t.tv_usec); + } +#endif // DOCTEST_PLATFORM_WINDOWS + + struct Timer + { + void start() { m_ticks = getCurrentTicks(); } + unsigned int getElapsedMicroseconds() const { + return static_cast(getCurrentTicks() - m_ticks); + } + //unsigned int getElapsedMilliseconds() const { + // return static_cast(getElapsedMicroseconds() / 1000); + //} + double getElapsedSeconds() const { return static_cast(getCurrentTicks() - m_ticks) / 1000000.0; } + + private: + ticks_t m_ticks = 0; + }; + +#ifdef DOCTEST_CONFIG_NO_MULTITHREADING + template + using Atomic = T; +#else // DOCTEST_CONFIG_NO_MULTITHREADING + template + using Atomic = std::atomic; +#endif // DOCTEST_CONFIG_NO_MULTITHREADING + +#if defined(DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS) || defined(DOCTEST_CONFIG_NO_MULTITHREADING) + template + using MultiLaneAtomic = Atomic; +#else // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS + // Provides a multilane implementation of an atomic variable that supports add, sub, load, + // store. Instead of using a single atomic variable, this splits up into multiple ones, + // each sitting on a separate cache line. The goal is to provide a speedup when most + // operations are modifying. It achieves this with two properties: + // + // * Multiple atomics are used, so chance of congestion from the same atomic is reduced. + // * Each atomic sits on a separate cache line, so false sharing is reduced. + // + // The disadvantage is that there is a small overhead due to the use of TLS, and load/store + // is slower because all atomics have to be accessed. 
+ template + class MultiLaneAtomic + { + struct CacheLineAlignedAtomic + { + Atomic atomic{}; + char padding[DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE - sizeof(Atomic)]; + }; + CacheLineAlignedAtomic m_atomics[DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES]; + + static_assert(sizeof(CacheLineAlignedAtomic) == DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE, + "guarantee one atomic takes exactly one cache line"); + + public: + T operator++() DOCTEST_NOEXCEPT { return fetch_add(1) + 1; } + + T operator++(int) DOCTEST_NOEXCEPT { return fetch_add(1); } + + T fetch_add(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + return myAtomic().fetch_add(arg, order); + } + + T fetch_sub(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + return myAtomic().fetch_sub(arg, order); + } + + operator T() const DOCTEST_NOEXCEPT { return load(); } + + T load(std::memory_order order = std::memory_order_seq_cst) const DOCTEST_NOEXCEPT { + auto result = T(); + for(auto const& c : m_atomics) { + result += c.atomic.load(order); + } + return result; + } + + T operator=(T desired) DOCTEST_NOEXCEPT { // lgtm [cpp/assignment-does-not-return-this] + store(desired); + return desired; + } + + void store(T desired, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + // first value becomes desired", all others become 0. + for(auto& c : m_atomics) { + c.atomic.store(desired, order); + desired = {}; + } + } + + private: + // Each thread has a different atomic that it operates on. If more than NumLanes threads + // use this, some will use the same atomic. So performance will degrade a bit, but still + // everything will work. + // + // The logic here is a bit tricky. The call should be as fast as possible, so that there + // is minimal to no overhead in determining the correct atomic for the current thread. + // + // 1. A global static counter laneCounter counts continuously up. + // 2. 
Each successive thread will use modulo operation of that counter so it gets an atomic + // assigned in a round-robin fashion. + // 3. This tlsLaneIdx is stored in the thread local data, so it is directly available with + // little overhead. + Atomic& myAtomic() DOCTEST_NOEXCEPT { + static Atomic laneCounter; + DOCTEST_THREAD_LOCAL size_t tlsLaneIdx = + laneCounter++ % DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES; + + return m_atomics[tlsLaneIdx].atomic; + } + }; +#endif // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS + + // this holds both parameters from the command line and runtime data for tests + struct ContextState : ContextOptions, TestRunStats, CurrentTestCaseStats + { + MultiLaneAtomic numAssertsCurrentTest_atomic; + MultiLaneAtomic numAssertsFailedCurrentTest_atomic; + + std::vector> filters = decltype(filters)(9); // 9 different filters + + std::vector reporters_currently_used; + + assert_handler ah = nullptr; + + Timer timer; + + std::vector stringifiedContexts; // logging from INFO() due to an exception + + // stuff for subcases + bool reachedLeaf; + std::vector subcaseStack; + std::vector nextSubcaseStack; + std::unordered_set fullyTraversedSubcases; + size_t currentSubcaseDepth; + Atomic shouldLogCurrentException; + + void resetRunData() { + numTestCases = 0; + numTestCasesPassingFilters = 0; + numTestSuitesPassingFilters = 0; + numTestCasesFailed = 0; + numAsserts = 0; + numAssertsFailed = 0; + numAssertsCurrentTest = 0; + numAssertsFailedCurrentTest = 0; + } + + void finalizeTestCaseData() { + seconds = timer.getElapsedSeconds(); + + // update the non-atomic counters + numAsserts += numAssertsCurrentTest_atomic; + numAssertsFailed += numAssertsFailedCurrentTest_atomic; + numAssertsCurrentTest = numAssertsCurrentTest_atomic; + numAssertsFailedCurrentTest = numAssertsFailedCurrentTest_atomic; + + if(numAssertsFailedCurrentTest) + failure_flags |= TestCaseFailureReason::AssertFailure; + + if(Approx(currentTest->m_timeout).epsilon(DBL_EPSILON) != 0 && + 
Approx(seconds).epsilon(DBL_EPSILON) > currentTest->m_timeout) + failure_flags |= TestCaseFailureReason::Timeout; + + if(currentTest->m_should_fail) { + if(failure_flags) { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedAndDid; + } else { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedButDidnt; + } + } else if(failure_flags && currentTest->m_may_fail) { + failure_flags |= TestCaseFailureReason::CouldHaveFailedAndDid; + } else if(currentTest->m_expected_failures > 0) { + if(numAssertsFailedCurrentTest == currentTest->m_expected_failures) { + failure_flags |= TestCaseFailureReason::FailedExactlyNumTimes; + } else { + failure_flags |= TestCaseFailureReason::DidntFailExactlyNumTimes; + } + } + + bool ok_to_fail = (TestCaseFailureReason::ShouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::CouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::FailedExactlyNumTimes & failure_flags); + + // if any subcase has failed - the whole test case has failed + testCaseSuccess = !(failure_flags && !ok_to_fail); + if(!testCaseSuccess) + numTestCasesFailed++; + } + }; + + ContextState* g_cs = nullptr; + + // used to avoid locks for the debug output + // TODO: figure out if this is indeed necessary/correct - seems like either there still + // could be a race or that there wouldn't be a race even if using the context directly + DOCTEST_THREAD_LOCAL bool g_no_colors; + +#endif // DOCTEST_CONFIG_DISABLE +} // namespace detail + +char* String::allocate(size_type sz) { + if (sz <= last) { + buf[sz] = '\0'; + setLast(last - sz); + return buf; + } else { + setOnHeap(); + data.size = sz; + data.capacity = data.size + 1; + data.ptr = new char[data.capacity]; + data.ptr[sz] = '\0'; + return data.ptr; + } +} + +void String::setOnHeap() noexcept { *reinterpret_cast(&buf[last]) = 128; } +void String::setLast(size_type in) noexcept { buf[last] = char(in); } +void String::setSize(size_type sz) noexcept { + if (isOnStack()) { buf[sz] = '\0'; 
setLast(last - sz); } + else { data.ptr[sz] = '\0'; data.size = sz; } +} + +void String::copy(const String& other) { + if(other.isOnStack()) { + memcpy(buf, other.buf, len); + } else { + memcpy(allocate(other.data.size), other.data.ptr, other.data.size); + } +} + +String::String() noexcept { + buf[0] = '\0'; + setLast(); +} + +String::~String() { + if(!isOnStack()) + delete[] data.ptr; +} // NOLINT(clang-analyzer-cplusplus.NewDeleteLeaks) + +String::String(const char* in) + : String(in, strlen(in)) {} + +String::String(const char* in, size_type in_size) { + memcpy(allocate(in_size), in, in_size); +} + +String::String(std::istream& in, size_type in_size) { + in.read(allocate(in_size), in_size); +} + +String::String(const String& other) { copy(other); } + +String& String::operator=(const String& other) { + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + + copy(other); + } + + return *this; +} + +String& String::operator+=(const String& other) { + const size_type my_old_size = size(); + const size_type other_size = other.size(); + const size_type total_size = my_old_size + other_size; + if(isOnStack()) { + if(total_size < len) { + // append to the current stack space + memcpy(buf + my_old_size, other.c_str(), other_size + 1); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + setLast(last - total_size); + } else { + // alloc new chunk + char* temp = new char[total_size + 1]; + // copy current data to new location before writing in the union + memcpy(temp, buf, my_old_size); // skip the +1 ('\0') for speed + // update data in union + setOnHeap(); + data.size = total_size; + data.capacity = data.size + 1; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } else { + if(data.capacity > total_size) { + // append to the current heap block + data.size = total_size; + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } else { + // resize + data.capacity *= 2; + 
if(data.capacity <= total_size) + data.capacity = total_size + 1; + // alloc new chunk + char* temp = new char[data.capacity]; + // copy current data to new location before releasing it + memcpy(temp, data.ptr, my_old_size); // skip the +1 ('\0') for speed + // release old chunk + delete[] data.ptr; + // update the rest of the union members + data.size = total_size; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } + + return *this; +} + +String::String(String&& other) noexcept { + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); +} + +String& String::operator=(String&& other) noexcept { + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); + } + return *this; +} + +char String::operator[](size_type i) const { + return const_cast(this)->operator[](i); +} + +char& String::operator[](size_type i) { + if(isOnStack()) + return reinterpret_cast(buf)[i]; + return data.ptr[i]; +} + +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmaybe-uninitialized") +String::size_type String::size() const { + if(isOnStack()) + return last - (size_type(buf[last]) & 31); // using "last" would work only if "len" is 32 + return data.size; +} +DOCTEST_GCC_SUPPRESS_WARNING_POP + +String::size_type String::capacity() const { + if(isOnStack()) + return len; + return data.capacity; +} + +String String::substr(size_type pos, size_type cnt) && { + cnt = std::min(cnt, size() - 1 - pos); + char* cptr = c_str(); + memmove(cptr, cptr + pos, cnt); + setSize(cnt); + return std::move(*this); +} + +String String::substr(size_type pos, size_type cnt) const & { + cnt = std::min(cnt, size() - 1 - pos); + return String{ c_str() + pos, cnt }; +} + +String::size_type String::find(char ch, size_type pos) const { + const char* begin = c_str(); + const char* end = begin + size(); + const char* it = begin + pos; + for (; it < end && *it != ch; 
it++); + if (it < end) { return static_cast(it - begin); } + else { return npos; } +} + +String::size_type String::rfind(char ch, size_type pos) const { + const char* begin = c_str(); + const char* it = begin + std::min(pos, size() - 1); + for (; it >= begin && *it != ch; it--); + if (it >= begin) { return static_cast(it - begin); } + else { return npos; } +} + +int String::compare(const char* other, bool no_case) const { + if(no_case) + return doctest::stricmp(c_str(), other); + return std::strcmp(c_str(), other); +} + +int String::compare(const String& other, bool no_case) const { + return compare(other.c_str(), no_case); +} + +String operator+(const String& lhs, const String& rhs) { return String(lhs) += rhs; } + +bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } +bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } +bool operator< (const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } +bool operator> (const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } +bool operator<=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) < 0 : true; } +bool operator>=(const String& lhs, const String& rhs) { return (lhs != rhs) ? 
lhs.compare(rhs) > 0 : true; } + +std::ostream& operator<<(std::ostream& s, const String& in) { return s << in.c_str(); } + +Contains::Contains(const String& str) : string(str) { } + +bool Contains::checkWith(const String& other) const { + return strstr(other.c_str(), string.c_str()) != nullptr; +} + +String toString(const Contains& in) { + return "Contains( " + in.string + " )"; +} + +bool operator==(const String& lhs, const Contains& rhs) { return rhs.checkWith(lhs); } +bool operator==(const Contains& lhs, const String& rhs) { return lhs.checkWith(rhs); } +bool operator!=(const String& lhs, const Contains& rhs) { return !rhs.checkWith(lhs); } +bool operator!=(const Contains& lhs, const String& rhs) { return !lhs.checkWith(rhs); } + +namespace { + void color_to_stream(std::ostream&, Color::Enum) DOCTEST_BRANCH_ON_DISABLED({}, ;) +} // namespace + +namespace Color { + std::ostream& operator<<(std::ostream& s, Color::Enum code) { + color_to_stream(s, code); + return s; + } +} // namespace Color + +// clang-format off +const char* assertString(assertType::Enum at) { + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4061) // enum 'x' in switch of enum 'y' is not explicitly handled + #define DOCTEST_GENERATE_ASSERT_TYPE_CASE(assert_type) case assertType::DT_ ## assert_type: return #assert_type + #define DOCTEST_GENERATE_ASSERT_TYPE_CASES(assert_type) \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN_ ## assert_type); \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK_ ## assert_type); \ + DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE_ ## assert_type) + switch(at) { + DOCTEST_GENERATE_ASSERT_TYPE_CASE(WARN); + DOCTEST_GENERATE_ASSERT_TYPE_CASE(CHECK); + DOCTEST_GENERATE_ASSERT_TYPE_CASE(REQUIRE); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(FALSE); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_AS); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(THROWS_WITH_AS); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(NOTHROW); 
+ + DOCTEST_GENERATE_ASSERT_TYPE_CASES(EQ); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(NE); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(GT); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(LT); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(GE); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(LE); + + DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY); + DOCTEST_GENERATE_ASSERT_TYPE_CASES(UNARY_FALSE); + + default: DOCTEST_INTERNAL_ERROR("Tried stringifying invalid assert type!"); + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP +} +// clang-format on + +const char* failureString(assertType::Enum at) { + if(at & assertType::is_warn) //!OCLINT bitwise operator in conditional + return "WARNING"; + if(at & assertType::is_check) //!OCLINT bitwise operator in conditional + return "ERROR"; + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return "FATAL ERROR"; + return ""; +} + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +// depending on the current options this will remove the path of filenames +const char* skipPathFromFilename(const char* file) { +#ifndef DOCTEST_CONFIG_DISABLE + if(getContextOptions()->no_path_in_filenames) { + auto back = std::strrchr(file, '\\'); + auto forward = std::strrchr(file, '/'); + if(back || forward) { + if(back > forward) + forward = back; + return forward + 1; + } + } +#endif // DOCTEST_CONFIG_DISABLE + return file; +} +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +bool SubcaseSignature::operator==(const SubcaseSignature& other) const { + return m_line == other.m_line + && std::strcmp(m_file, other.m_file) == 0 + && m_name == other.m_name; +} + +bool SubcaseSignature::operator<(const SubcaseSignature& other) const { + if(m_line != other.m_line) + return m_line < other.m_line; + if(std::strcmp(m_file, other.m_file) != 0) + return std::strcmp(m_file, other.m_file) < 0; + return m_name.compare(other.m_name) < 0; +} + +DOCTEST_DEFINE_INTERFACE(IContextScope) + +namespace 
detail { + void filldata::fill(std::ostream* stream, const void* in) { + if (in) { *stream << in; } + else { *stream << "nullptr"; } + } + + template + String toStreamLit(T t) { + std::ostream* os = tlssPush(); + os->operator<<(t); + return tlssPop(); + } +} + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/doctest/doctest/issues/183 +String toString(const std::string& in) { return in.c_str(); } +#endif // VS 2019 + +String toString(String in) { return in; } + +String toString(std::nullptr_t) { return "nullptr"; } + +String toString(bool in) { return in ? "true" : "false"; } + +String toString(float in) { return toStreamLit(in); } +String toString(double in) { return toStreamLit(in); } +String toString(double long in) { return toStreamLit(in); } + +String toString(char in) { return toStreamLit(static_cast(in)); } +String toString(char signed in) { return toStreamLit(static_cast(in)); } +String toString(char unsigned in) { return toStreamLit(static_cast(in)); } +String toString(short in) { return toStreamLit(in); } +String toString(short unsigned in) { return toStreamLit(in); } +String toString(signed in) { return toStreamLit(in); } +String toString(unsigned in) { return toStreamLit(in); } +String toString(long in) { return toStreamLit(in); } +String toString(long unsigned in) { return toStreamLit(in); } +String toString(long long in) { return toStreamLit(in); } +String toString(long long unsigned in) { return toStreamLit(in); } + +Approx::Approx(double value) + : m_epsilon(static_cast(std::numeric_limits::epsilon()) * 100) + , m_scale(1.0) + , m_value(value) {} + +Approx Approx::operator()(double value) const { + Approx approx(value); + approx.epsilon(m_epsilon); + approx.scale(m_scale); + return approx; +} + 
+Approx& Approx::epsilon(double newEpsilon) { + m_epsilon = newEpsilon; + return *this; +} +Approx& Approx::scale(double newScale) { + m_scale = newScale; + return *this; +} + +bool operator==(double lhs, const Approx& rhs) { + // Thanks to Richard Harris for his help refining this formula + return std::fabs(lhs - rhs.m_value) < + rhs.m_epsilon * (rhs.m_scale + std::max(std::fabs(lhs), std::fabs(rhs.m_value))); +} +bool operator==(const Approx& lhs, double rhs) { return operator==(rhs, lhs); } +bool operator!=(double lhs, const Approx& rhs) { return !operator==(lhs, rhs); } +bool operator!=(const Approx& lhs, double rhs) { return !operator==(rhs, lhs); } +bool operator<=(double lhs, const Approx& rhs) { return lhs < rhs.m_value || lhs == rhs; } +bool operator<=(const Approx& lhs, double rhs) { return lhs.m_value < rhs || lhs == rhs; } +bool operator>=(double lhs, const Approx& rhs) { return lhs > rhs.m_value || lhs == rhs; } +bool operator>=(const Approx& lhs, double rhs) { return lhs.m_value > rhs || lhs == rhs; } +bool operator<(double lhs, const Approx& rhs) { return lhs < rhs.m_value && lhs != rhs; } +bool operator<(const Approx& lhs, double rhs) { return lhs.m_value < rhs && lhs != rhs; } +bool operator>(double lhs, const Approx& rhs) { return lhs > rhs.m_value && lhs != rhs; } +bool operator>(const Approx& lhs, double rhs) { return lhs.m_value > rhs && lhs != rhs; } + +String toString(const Approx& in) { + return "Approx( " + doctest::toString(in.m_value) + " )"; +} +const ContextOptions* getContextOptions() { return DOCTEST_BRANCH_ON_DISABLED(nullptr, g_cs); } + +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4738) +template +IsNaN::operator bool() const { + return std::isnan(value) ^ flipped; +} +DOCTEST_MSVC_SUPPRESS_WARNING_POP +template struct DOCTEST_INTERFACE_DEF IsNaN; +template struct DOCTEST_INTERFACE_DEF IsNaN; +template struct DOCTEST_INTERFACE_DEF IsNaN; +template +String toString(IsNaN in) { return String(in.flipped ? "! 
" : "") + "IsNaN( " + doctest::toString(in.value) + " )"; } +String toString(IsNaN in) { return toString(in); } +String toString(IsNaN in) { return toString(in); } +String toString(IsNaN in) { return toString(in); } + +} // namespace doctest + +#ifdef DOCTEST_CONFIG_DISABLE +namespace doctest { +Context::Context(int, const char* const*) {} +Context::~Context() = default; +void Context::applyCommandLine(int, const char* const*) {} +void Context::addFilter(const char*, const char*) {} +void Context::clearFilters() {} +void Context::setOption(const char*, bool) {} +void Context::setOption(const char*, int) {} +void Context::setOption(const char*, const char*) {} +bool Context::shouldExit() { return false; } +void Context::setAsDefaultForAssertsOutOfTestCases() {} +void Context::setAssertHandler(detail::assert_handler) {} +void Context::setCout(std::ostream*) {} +int Context::run() { return 0; } + +int IReporter::get_num_active_contexts() { return 0; } +const IContextScope* const* IReporter::get_active_contexts() { return nullptr; } +int IReporter::get_num_stringified_contexts() { return 0; } +const String* IReporter::get_stringified_contexts() { return nullptr; } + +int registerReporter(const char*, int, IReporter*) { return 0; } + +} // namespace doctest +#else // DOCTEST_CONFIG_DISABLE + +#if !defined(DOCTEST_CONFIG_COLORS_NONE) +#if !defined(DOCTEST_CONFIG_COLORS_WINDOWS) && !defined(DOCTEST_CONFIG_COLORS_ANSI) +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_CONFIG_COLORS_WINDOWS +#else // linux +#define DOCTEST_CONFIG_COLORS_ANSI +#endif // platform +#endif // DOCTEST_CONFIG_COLORS_WINDOWS && DOCTEST_CONFIG_COLORS_ANSI +#endif // DOCTEST_CONFIG_COLORS_NONE + +namespace doctest_detail_test_suite_ns { +// holds the current test suite +doctest::detail::TestSuite& getCurrentTestSuite() { + static doctest::detail::TestSuite data{}; + return data; +} +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +namespace { + // the int (priority) is part of the 
key for automatic sorting - sadly one can register a + // reporter with a duplicate name and a different priority but hopefully that won't happen often :| + using reporterMap = std::map, reporterCreatorFunc>; + + reporterMap& getReporters() { + static reporterMap data; + return data; + } + reporterMap& getListeners() { + static reporterMap data; + return data; + } +} // namespace +namespace detail { +#define DOCTEST_ITERATE_THROUGH_REPORTERS(function, ...) \ + for(auto& curr_rep : g_cs->reporters_currently_used) \ + curr_rep->function(__VA_ARGS__) + + bool checkIfShouldThrow(assertType::Enum at) { + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return true; + + if((at & assertType::is_check) //!OCLINT bitwise operator in conditional + && getContextOptions()->abort_after > 0 && + (g_cs->numAssertsFailed + g_cs->numAssertsFailedCurrentTest_atomic) >= + getContextOptions()->abort_after) + return true; + + return false; + } + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN void throwException() { + g_cs->shouldLogCurrentException = false; + throw TestFailureException(); // NOLINT(hicpp-exception-baseclass) + } +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + void throwException() {} +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +} // namespace detail + +namespace { + using namespace detail; + // matching of a string against a wildcard mask (case sensitivity configurable) taken from + // https://www.codeproject.com/Articles/1088/Wildcard-string-compare-globbing + int wildcmp(const char* str, const char* wild, bool caseSensitive) { + const char* cp = str; + const char* mp = wild; + + while((*str) && (*wild != '*')) { + if((caseSensitive ? (*wild != *str) : (tolower(*wild) != tolower(*str))) && + (*wild != '?')) { + return 0; + } + wild++; + str++; + } + + while(*str) { + if(*wild == '*') { + if(!*++wild) { + return 1; + } + mp = wild; + cp = str + 1; + } else if((caseSensitive ? 
(*wild == *str) : (tolower(*wild) == tolower(*str))) || + (*wild == '?')) { + wild++; + str++; + } else { + wild = mp; //!OCLINT parameter reassignment + str = cp++; //!OCLINT parameter reassignment + } + } + + while(*wild == '*') { + wild++; + } + return !*wild; + } + + // checks if the name matches any of the filters (and can be configured what to do when empty) + bool matchesAny(const char* name, const std::vector& filters, bool matchEmpty, + bool caseSensitive) { + if (filters.empty() && matchEmpty) + return true; + for (auto& curr : filters) + if (wildcmp(name, curr.c_str(), caseSensitive)) + return true; + return false; + } + + DOCTEST_NO_SANITIZE_INTEGER + unsigned long long hash(unsigned long long a, unsigned long long b) { + return (a << 5) + b; + } + + // C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html + DOCTEST_NO_SANITIZE_INTEGER + unsigned long long hash(const char* str) { + unsigned long long hash = 5381; + char c; + while ((c = *str++)) + hash = ((hash << 5) + hash) + c; // hash * 33 + c + return hash; + } + + unsigned long long hash(const SubcaseSignature& sig) { + return hash(hash(hash(sig.m_file), hash(sig.m_name.c_str())), sig.m_line); + } + + unsigned long long hash(const std::vector& sigs, size_t count) { + unsigned long long running = 0; + auto end = sigs.begin() + count; + for (auto it = sigs.begin(); it != end; it++) { + running = hash(running, hash(*it)); + } + return running; + } + + unsigned long long hash(const std::vector& sigs) { + unsigned long long running = 0; + for (const SubcaseSignature& sig : sigs) { + running = hash(running, hash(sig)); + } + return running; + } +} // namespace +namespace detail { + bool Subcase::checkFilters() { + if (g_cs->subcaseStack.size() < size_t(g_cs->subcase_filter_levels)) { + if (!matchesAny(m_signature.m_name.c_str(), g_cs->filters[6], true, g_cs->case_sensitive)) + return true; + if (matchesAny(m_signature.m_name.c_str(), g_cs->filters[7], false, 
g_cs->case_sensitive)) + return true; + } + return false; + } + + Subcase::Subcase(const String& name, const char* file, int line) + : m_signature({name, file, line}) { + if (!g_cs->reachedLeaf) { + if (g_cs->nextSubcaseStack.size() <= g_cs->subcaseStack.size() + || g_cs->nextSubcaseStack[g_cs->subcaseStack.size()] == m_signature) { + // Going down. + if (checkFilters()) { return; } + + g_cs->subcaseStack.push_back(m_signature); + g_cs->currentSubcaseDepth++; + m_entered = true; + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } + } else { + if (g_cs->subcaseStack[g_cs->currentSubcaseDepth] == m_signature) { + // This subcase is reentered via control flow. + g_cs->currentSubcaseDepth++; + m_entered = true; + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } else if (g_cs->nextSubcaseStack.size() <= g_cs->currentSubcaseDepth + && g_cs->fullyTraversedSubcases.find(hash(hash(g_cs->subcaseStack, g_cs->currentSubcaseDepth), hash(m_signature))) + == g_cs->fullyTraversedSubcases.end()) { + if (checkFilters()) { return; } + // This subcase is part of the one to be executed next. + g_cs->nextSubcaseStack.clear(); + g_cs->nextSubcaseStack.insert(g_cs->nextSubcaseStack.end(), + g_cs->subcaseStack.begin(), g_cs->subcaseStack.begin() + g_cs->currentSubcaseDepth); + g_cs->nextSubcaseStack.push_back(m_signature); + } + } + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + Subcase::~Subcase() { + if (m_entered) { + g_cs->currentSubcaseDepth--; + + if (!g_cs->reachedLeaf) { + // Leaf. + g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); + g_cs->nextSubcaseStack.clear(); + g_cs->reachedLeaf = true; + } else if (g_cs->nextSubcaseStack.empty()) { + // All children are finished. 
+ g_cs->fullyTraversedSubcases.insert(hash(g_cs->subcaseStack)); + } + +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) + if(std::uncaught_exceptions() > 0 +#else + if(std::uncaught_exception() +#endif + && g_cs->shouldLogCurrentException) { + DOCTEST_ITERATE_THROUGH_REPORTERS( + test_case_exception, {"exception thrown in subcase - will translate later " + "when the whole test case has been exited (cannot " + "translate while there is an active exception)", + false}); + g_cs->shouldLogCurrentException = false; + } + + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + Subcase::operator bool() const { return m_entered; } + + Result::Result(bool passed, const String& decomposition) + : m_passed(passed) + , m_decomp(decomposition) {} + + ExpressionDecomposer::ExpressionDecomposer(assertType::Enum at) + : m_at(at) {} + + TestSuite& TestSuite::operator*(const char* in) { + m_test_suite = in; + return *this; + } + + TestCase::TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const String& type, int template_id) { + m_file = file; + m_line = line; + m_name = nullptr; // will be later overridden in operator* + m_test_suite = test_suite.m_test_suite; + m_description = test_suite.m_description; + m_skip = test_suite.m_skip; + m_no_breaks = test_suite.m_no_breaks; + m_no_output = test_suite.m_no_output; + m_may_fail = test_suite.m_may_fail; + m_should_fail = test_suite.m_should_fail; + m_expected_failures = test_suite.m_expected_failures; + m_timeout = test_suite.m_timeout; + + m_test = test; + m_type = type; + m_template_id = template_id; + } + + TestCase::TestCase(const TestCase& other) + : TestCaseData() { + *this = other; + } + + 
DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + TestCase& TestCase::operator=(const TestCase& other) { + TestCaseData::operator=(other); + m_test = other.m_test; + m_type = other.m_type; + m_template_id = other.m_template_id; + m_full_name = other.m_full_name; + + if(m_template_id != -1) + m_name = m_full_name.c_str(); + return *this; + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& TestCase::operator*(const char* in) { + m_name = in; + // make a new name with an appended type for templated test case + if(m_template_id != -1) { + m_full_name = String(m_name) + "<" + m_type + ">"; + // redirect the name to point to the newly constructed full name + m_name = m_full_name.c_str(); + } + return *this; + } + + bool TestCase::operator<(const TestCase& other) const { + // this will be used only to differentiate between test cases - not relevant for sorting + if(m_line != other.m_line) + return m_line < other.m_line; + const int name_cmp = strcmp(m_name, other.m_name); + if(name_cmp != 0) + return name_cmp < 0; + const int file_cmp = m_file.compare(other.m_file); + if(file_cmp != 0) + return file_cmp < 0; + return m_template_id < other.m_template_id; + } + + // all the registered tests + std::set& getRegisteredTests() { + static std::set data; + return data; + } +} // namespace detail +namespace { + using namespace detail; + // for sorting tests by file/line + bool fileOrderComparator(const TestCase* lhs, const TestCase* rhs) { + // this is needed because MSVC gives different case for drive letters + // for __FILE__ when evaluated in a header and a source file + const int res = lhs->m_file.compare(rhs->m_file, bool(DOCTEST_MSVC)); + if(res != 0) + return res < 0; + if(lhs->m_line != rhs->m_line) + return lhs->m_line < rhs->m_line; + return lhs->m_template_id < rhs->m_template_id; + } + + // for sorting tests by suite/file/line + bool suiteOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = 
std::strcmp(lhs->m_test_suite, rhs->m_test_suite); + if(res != 0) + return res < 0; + return fileOrderComparator(lhs, rhs); + } + + // for sorting tests by name/suite/file/line + bool nameOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = std::strcmp(lhs->m_name, rhs->m_name); + if(res != 0) + return res < 0; + return suiteOrderComparator(lhs, rhs); + } + + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + void color_to_stream(std::ostream& s, Color::Enum code) { + static_cast(s); // for DOCTEST_CONFIG_COLORS_NONE or DOCTEST_CONFIG_COLORS_WINDOWS + static_cast(code); // for DOCTEST_CONFIG_COLORS_NONE +#ifdef DOCTEST_CONFIG_COLORS_ANSI + if(g_no_colors || + (isatty(STDOUT_FILENO) == false && getContextOptions()->force_colors == false)) + return; + + auto col = ""; + // clang-format off + switch(code) { //!OCLINT missing break in switch statement / unnecessary default statement in covered switch statement + case Color::Red: col = "[0;31m"; break; + case Color::Green: col = "[0;32m"; break; + case Color::Blue: col = "[0;34m"; break; + case Color::Cyan: col = "[0;36m"; break; + case Color::Yellow: col = "[0;33m"; break; + case Color::Grey: col = "[1;30m"; break; + case Color::LightGrey: col = "[0;37m"; break; + case Color::BrightRed: col = "[1;31m"; break; + case Color::BrightGreen: col = "[1;32m"; break; + case Color::BrightWhite: col = "[1;37m"; break; + case Color::Bright: // invalid + case Color::None: + case Color::White: + default: col = "[0m"; + } + // clang-format on + s << "\033" << col; +#endif // DOCTEST_CONFIG_COLORS_ANSI + +#ifdef DOCTEST_CONFIG_COLORS_WINDOWS + if(g_no_colors || + (_isatty(_fileno(stdout)) == false && getContextOptions()->force_colors == false)) + return; + + static struct ConsoleHelper { + HANDLE stdoutHandle; + WORD origFgAttrs; + WORD origBgAttrs; + + ConsoleHelper() { + stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + 
GetConsoleScreenBufferInfo(stdoutHandle, &csbiInfo); + origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | + BACKGROUND_BLUE | BACKGROUND_INTENSITY); + origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | + FOREGROUND_BLUE | FOREGROUND_INTENSITY); + } + } ch; + +#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(ch.stdoutHandle, x | ch.origBgAttrs) + + // clang-format off + switch (code) { + case Color::White: DOCTEST_SET_ATTR(FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::Red: DOCTEST_SET_ATTR(FOREGROUND_RED); break; + case Color::Green: DOCTEST_SET_ATTR(FOREGROUND_GREEN); break; + case Color::Blue: DOCTEST_SET_ATTR(FOREGROUND_BLUE); break; + case Color::Cyan: DOCTEST_SET_ATTR(FOREGROUND_BLUE | FOREGROUND_GREEN); break; + case Color::Yellow: DOCTEST_SET_ATTR(FOREGROUND_RED | FOREGROUND_GREEN); break; + case Color::Grey: DOCTEST_SET_ATTR(0); break; + case Color::LightGrey: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY); break; + case Color::BrightRed: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_RED); break; + case Color::BrightGreen: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN); break; + case Color::BrightWhite: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::None: + case Color::Bright: // invalid + default: DOCTEST_SET_ATTR(ch.origFgAttrs); + } + // clang-format on +#endif // DOCTEST_CONFIG_COLORS_WINDOWS + } + DOCTEST_CLANG_SUPPRESS_WARNING_POP + + std::vector& getExceptionTranslators() { + static std::vector data; + return data; + } + + String translateActiveException() { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + String res; + auto& translators = getExceptionTranslators(); + for(auto& curr : translators) + if(curr->translate(res)) + return res; + // clang-format off + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wcatch-value") + try { + throw; + } catch(std::exception& ex) { + return ex.what(); + } catch(std::string& 
msg) { + return msg.c_str(); + } catch(const char* msg) { + return msg; + } catch(...) { + return "unknown exception"; + } + DOCTEST_GCC_SUPPRESS_WARNING_POP +// clang-format on +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + return ""; +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } +} // namespace + +namespace detail { + // used by the macros for registering tests + int regTest(const TestCase& tc) { + getRegisteredTests().insert(tc); + return 0; + } + + // sets the current test suite + int setTestSuite(const TestSuite& ts) { + doctest_detail_test_suite_ns::getCurrentTestSuite() = ts; + return 0; + } + +#ifdef DOCTEST_IS_DEBUGGER_ACTIVE + bool isDebuggerActive() { return DOCTEST_IS_DEBUGGER_ACTIVE(); } +#else // DOCTEST_IS_DEBUGGER_ACTIVE +#ifdef DOCTEST_PLATFORM_LINUX + class ErrnoGuard { + public: + ErrnoGuard() : m_oldErrno(errno) {} + ~ErrnoGuard() { errno = m_oldErrno; } + private: + int m_oldErrno; + }; + // See the comments in Catch2 for the reasoning behind this implementation: + // https://github.com/catchorg/Catch2/blob/v2.13.1/include/internal/catch_debugger.cpp#L79-L102 + bool isDebuggerActive() { + ErrnoGuard guard; + std::ifstream in("/proc/self/status"); + for(std::string line; std::getline(in, line);) { + static const int PREFIX_LEN = 11; + if(line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0) { + return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; + } + } + return false; + } +#elif defined(DOCTEST_PLATFORM_MAC) + // The following function is taken directly from the following technical note: + // https://developer.apple.com/library/archive/qa/qa1361/_index.html + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive() { + int mib[4]; + kinfo_proc info; + size_t size; + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. 
+ info.kp_proc.p_flag = 0; + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + // Call sysctl. + size = sizeof(info); + if(sysctl(mib, DOCTEST_COUNTOF(mib), &info, &size, 0, 0) != 0) { + std::cerr << "\nCall to sysctl failed - unable to determine if debugger is active **\n"; + return false; + } + // We're being debugged if the P_TRACED flag is set. + return ((info.kp_proc.p_flag & P_TRACED) != 0); + } +#elif DOCTEST_MSVC || defined(__MINGW32__) || defined(__MINGW64__) + bool isDebuggerActive() { return ::IsDebuggerPresent() != 0; } +#else + bool isDebuggerActive() { return false; } +#endif // Platform +#endif // DOCTEST_IS_DEBUGGER_ACTIVE + + void registerExceptionTranslatorImpl(const IExceptionTranslator* et) { + if(std::find(getExceptionTranslators().begin(), getExceptionTranslators().end(), et) == + getExceptionTranslators().end()) + getExceptionTranslators().push_back(et); + } + + DOCTEST_THREAD_LOCAL std::vector g_infoContexts; // for logging with INFO() + + ContextScopeBase::ContextScopeBase() { + g_infoContexts.push_back(this); + } + + ContextScopeBase::ContextScopeBase(ContextScopeBase&& other) noexcept { + if (other.need_to_destroy) { + other.destroy(); + } + other.need_to_destroy = false; + g_infoContexts.push_back(this); + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + // destroy cannot be inlined into the destructor because that would mean calling stringify after + // ContextScope has been destroyed (base class destructors run after derived class destructors). + // Instead, ContextScope calls this method directly from its destructor. 
+ void ContextScopeBase::destroy() { +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) + if(std::uncaught_exceptions() > 0) { +#else + if(std::uncaught_exception()) { +#endif + std::ostringstream s; + this->stringify(&s); + g_cs->stringifiedContexts.push_back(s.str().c_str()); + } + g_infoContexts.pop_back(); + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP +} // namespace detail +namespace { + using namespace detail; + +#if !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && !defined(DOCTEST_CONFIG_WINDOWS_SEH) + struct FatalConditionHandler + { + static void reset() {} + static void allocateAltStackMem() {} + static void freeAltStackMem() {} + }; +#else // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + + void reportFatal(const std::string&); + +#ifdef DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + DWORD id; + const char* name; + }; + // There is no 1-1 mapping between signals and windows exceptions. + // Windows can easily distinguish between SO and SigSegV, + // but SigInt, SigTerm, etc are handled differently. + SignalDefs signalDefs[] = { + {static_cast(EXCEPTION_ILLEGAL_INSTRUCTION), + "SIGILL - Illegal instruction signal"}, + {static_cast(EXCEPTION_STACK_OVERFLOW), "SIGSEGV - Stack overflow"}, + {static_cast(EXCEPTION_ACCESS_VIOLATION), + "SIGSEGV - Segmentation violation signal"}, + {static_cast(EXCEPTION_INT_DIVIDE_BY_ZERO), "Divide by zero error"}, + }; + + struct FatalConditionHandler + { + static LONG CALLBACK handleException(PEXCEPTION_POINTERS ExceptionInfo) { + // Multiple threads may enter this filter/handler at once. We want the error message to be printed on the + // console just once no matter how many threads have crashed. 
+ DOCTEST_DECLARE_STATIC_MUTEX(mutex) + static bool execute = true; + { + DOCTEST_LOCK_MUTEX(mutex) + if(execute) { + bool reported = false; + for(size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + if(ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) { + reportFatal(signalDefs[i].name); + reported = true; + break; + } + } + if(reported == false) + reportFatal("Unhandled SEH exception caught"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + } + execute = false; + } + std::exit(EXIT_FAILURE); + } + + static void allocateAltStackMem() {} + static void freeAltStackMem() {} + + FatalConditionHandler() { + isSet = true; + // 32k seems enough for doctest to handle stack overflow, + // but the value was found experimentally, so there is no strong guarantee + guaranteeSize = 32 * 1024; + // Register an unhandled exception filter + previousTop = SetUnhandledExceptionFilter(handleException); + // Pass in guarantee size to be filled + SetThreadStackGuarantee(&guaranteeSize); + + // On Windows uncaught exceptions from another thread, exceptions from + // destructors, or calls to std::terminate are not a SEH exception + + // The terminal handler gets called when: + // - std::terminate is called FROM THE TEST RUNNER THREAD + // - an exception is thrown from a destructor FROM THE TEST RUNNER THREAD + original_terminate_handler = std::get_terminate(); + std::set_terminate([]() DOCTEST_NOEXCEPT { + reportFatal("Terminate handler called"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + std::exit(EXIT_FAILURE); // explicitly exit - otherwise the SIGABRT handler may be called as well + }); + + // SIGABRT is raised when: + // - std::terminate is called FROM A DIFFERENT THREAD + // - an exception is thrown from a destructor FROM A DIFFERENT THREAD + // - an uncaught exception is thrown FROM A DIFFERENT THREAD + prev_sigabrt_handler = std::signal(SIGABRT, [](int signal) DOCTEST_NOEXCEPT { + if(signal 
== SIGABRT) { + reportFatal("SIGABRT - Abort (abnormal termination) signal"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + std::exit(EXIT_FAILURE); + } + }); + + // The following settings are taken from google test, and more + // specifically from UnitTest::Run() inside of gtest.cc + + // the user does not want to see pop-up dialogs about crashes + prev_error_mode_1 = SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | + SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); + // This forces the abort message to go to stderr in all circumstances. + prev_error_mode_2 = _set_error_mode(_OUT_TO_STDERR); + // In the debug version, Visual Studio pops up a separate dialog + // offering a choice to debug the aborted program - we want to disable that. + prev_abort_behavior = _set_abort_behavior(0x0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); + // In debug mode, the Windows CRT can crash with an assertion over invalid + // input (e.g. passing an invalid file descriptor). The default handling + // for these assertions is to pop up a dialog and wait for user input. + // Instead ask the CRT to dump such assertions to stderr non-interactively. 
+ prev_report_mode = _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); + prev_report_file = _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); + } + + static void reset() { + if(isSet) { + // Unregister handler and restore the old guarantee + SetUnhandledExceptionFilter(previousTop); + SetThreadStackGuarantee(&guaranteeSize); + std::set_terminate(original_terminate_handler); + std::signal(SIGABRT, prev_sigabrt_handler); + SetErrorMode(prev_error_mode_1); + _set_error_mode(prev_error_mode_2); + _set_abort_behavior(prev_abort_behavior, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); + static_cast(_CrtSetReportMode(_CRT_ASSERT, prev_report_mode)); + static_cast(_CrtSetReportFile(_CRT_ASSERT, prev_report_file)); + isSet = false; + } + } + + ~FatalConditionHandler() { reset(); } + + private: + static UINT prev_error_mode_1; + static int prev_error_mode_2; + static unsigned int prev_abort_behavior; + static int prev_report_mode; + static _HFILE prev_report_file; + static void (DOCTEST_CDECL *prev_sigabrt_handler)(int); + static std::terminate_handler original_terminate_handler; + static bool isSet; + static ULONG guaranteeSize; + static LPTOP_LEVEL_EXCEPTION_FILTER previousTop; + }; + + UINT FatalConditionHandler::prev_error_mode_1; + int FatalConditionHandler::prev_error_mode_2; + unsigned int FatalConditionHandler::prev_abort_behavior; + int FatalConditionHandler::prev_report_mode; + _HFILE FatalConditionHandler::prev_report_file; + void (DOCTEST_CDECL *FatalConditionHandler::prev_sigabrt_handler)(int); + std::terminate_handler FatalConditionHandler::original_terminate_handler; + bool FatalConditionHandler::isSet = false; + ULONG FatalConditionHandler::guaranteeSize = 0; + LPTOP_LEVEL_EXCEPTION_FILTER FatalConditionHandler::previousTop = nullptr; + +#else // DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + int id; + const char* name; + }; + SignalDefs signalDefs[] = {{SIGINT, "SIGINT - Terminal interrupt signal"}, + {SIGILL, "SIGILL - Illegal 
instruction signal"}, + {SIGFPE, "SIGFPE - Floating point error signal"}, + {SIGSEGV, "SIGSEGV - Segmentation violation signal"}, + {SIGTERM, "SIGTERM - Termination request signal"}, + {SIGABRT, "SIGABRT - Abort (abnormal termination) signal"}}; + + struct FatalConditionHandler + { + static bool isSet; + static struct sigaction oldSigActions[DOCTEST_COUNTOF(signalDefs)]; + static stack_t oldSigStack; + static size_t altStackSize; + static char* altStackMem; + + static void handleSignal(int sig) { + const char* name = ""; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + SignalDefs& def = signalDefs[i]; + if(sig == def.id) { + name = def.name; + break; + } + } + reset(); + reportFatal(name); + raise(sig); + } + + static void allocateAltStackMem() { + altStackMem = new char[altStackSize]; + } + + static void freeAltStackMem() { + delete[] altStackMem; + } + + FatalConditionHandler() { + isSet = true; + stack_t sigStack; + sigStack.ss_sp = altStackMem; + sigStack.ss_size = altStackSize; + sigStack.ss_flags = 0; + sigaltstack(&sigStack, &oldSigStack); + struct sigaction sa = {}; + sa.sa_handler = handleSignal; + sa.sa_flags = SA_ONSTACK; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); + } + } + + ~FatalConditionHandler() { reset(); } + static void reset() { + if(isSet) { + // Set signals back to previous values -- hopefully nobody overwrote them in the meantime + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); + } + // Return the old stack + sigaltstack(&oldSigStack, nullptr); + isSet = false; + } + } + }; + + bool FatalConditionHandler::isSet = false; + struct sigaction FatalConditionHandler::oldSigActions[DOCTEST_COUNTOF(signalDefs)] = {}; + stack_t FatalConditionHandler::oldSigStack = {}; + size_t FatalConditionHandler::altStackSize = 4 * SIGSTKSZ; + char* FatalConditionHandler::altStackMem = nullptr; + 
+#endif // DOCTEST_PLATFORM_WINDOWS +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + +} // namespace + +namespace { + using namespace detail; + +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_OUTPUT_DEBUG_STRING(text) ::OutputDebugStringA(text) +#else + // TODO: integration with XCode and other IDEs +#define DOCTEST_OUTPUT_DEBUG_STRING(text) +#endif // Platform + + void addAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsCurrentTest_atomic++; + } + + void addFailedAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsFailedCurrentTest_atomic++; + } + +#if defined(DOCTEST_CONFIG_POSIX_SIGNALS) || defined(DOCTEST_CONFIG_WINDOWS_SEH) + void reportFatal(const std::string& message) { + g_cs->failure_flags |= TestCaseFailureReason::Crash; + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, {message.c_str(), true}); + + while (g_cs->subcaseStack.size()) { + g_cs->subcaseStack.pop_back(); + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + + g_cs->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH +} // namespace + +AssertData::AssertData(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const StringContains& exception_string) + : m_test_case(g_cs->currentTest), m_at(at), m_file(file), m_line(line), m_expr(expr), + m_failed(true), m_threw(false), m_threw_as(false), m_exception_type(exception_type), + m_exception_string(exception_string) { +#if DOCTEST_MSVC + if (m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC + ++m_expr; +#endif // MSVC +} + +namespace detail { + ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, 
int line, const char* expr, + const char* exception_type, const String& exception_string) + : AssertData(at, file, line, expr, exception_type, exception_string) { } + + ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const Contains& exception_string) + : AssertData(at, file, line, expr, exception_type, exception_string) { } + + void ResultBuilder::setResult(const Result& res) { + m_decomp = res.m_decomp; + m_failed = !res.m_passed; + } + + void ResultBuilder::translateException() { + m_threw = true; + m_exception = translateActiveException(); + } + + bool ResultBuilder::log() { + if(m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw; + } else if((m_at & assertType::is_throws_as) && (m_at & assertType::is_throws_with)) { //!OCLINT + m_failed = !m_threw_as || !m_exception_string.check(m_exception); + } else if(m_at & assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw_as; + } else if(m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + m_failed = !m_exception_string.check(m_exception); + } else if(m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + m_failed = m_threw; + } + + if(m_exception.size()) + m_exception = "\"" + m_exception + "\""; + + if(is_running_in_test) { + addAssert(m_at); + DOCTEST_ITERATE_THROUGH_REPORTERS(log_assert, *this); + + if(m_failed) + addFailedAssert(m_at); + } else if(m_failed) { + failed_out_of_a_testing_context(*this); + } + + return m_failed && isDebuggerActive() && !getContextOptions()->no_breaks && + (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger + } + + void ResultBuilder::react() const { + if(m_failed && checkIfShouldThrow(m_at)) + throwException(); + } + + void failed_out_of_a_testing_context(const AssertData& ad) { + if(g_cs->ah) + g_cs->ah(ad); + else + std::abort(); + } + + 
bool decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, + const Result& result) { + bool failed = !result.m_passed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(result.m_decomp); + DOCTEST_ASSERT_IN_TESTS(result.m_decomp); + return !failed; + } + + MessageBuilder::MessageBuilder(const char* file, int line, assertType::Enum severity) { + m_stream = tlssPush(); + m_file = file; + m_line = line; + m_severity = severity; + } + + MessageBuilder::~MessageBuilder() { + if (!logged) + tlssPop(); + } + + DOCTEST_DEFINE_INTERFACE(IExceptionTranslator) + + bool MessageBuilder::log() { + if (!logged) { + m_string = tlssPop(); + logged = true; + } + + DOCTEST_ITERATE_THROUGH_REPORTERS(log_message, *this); + + const bool isWarn = m_severity & assertType::is_warn; + + // warn is just a message in this context so we don't treat it as an assert + if(!isWarn) { + addAssert(m_severity); + addFailedAssert(m_severity); + } + + return isDebuggerActive() && !getContextOptions()->no_breaks && !isWarn && + (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger + } + + void MessageBuilder::react() { + if(m_severity & assertType::is_require) //!OCLINT bitwise operator in conditional + throwException(); + } +} // namespace detail +namespace { + using namespace detail; + + // clang-format off + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. 
+// ================================================================================================= + + class XmlEncode { + public: + enum ForWhat { ForTextNodes, ForAttributes }; + + XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); + + void encodeTo( std::ostream& os ) const; + + friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); + + private: + std::string m_str; + ForWhat m_forWhat; + }; + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ); + + ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT; + ScopedElement& operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT; + + ~ScopedElement(); + + ScopedElement& writeText( std::string const& text, bool indent = true ); + + template + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer = nullptr; + }; + +#ifndef DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + XmlWriter( std::ostream& os = std::cout ); +#else // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + XmlWriter( std::ostream& os ); +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + ~XmlWriter(); + + XmlWriter( XmlWriter const& ) = delete; + XmlWriter& operator=( XmlWriter const& ) = delete; + + XmlWriter& startElement( std::string const& name ); + + ScopedElement scopedElement( std::string const& name ); + + XmlWriter& endElement(); + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); + + XmlWriter& writeAttribute( std::string const& name, const char* attribute ); + + XmlWriter& writeAttribute( std::string const& name, bool attribute ); + + template + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + std::stringstream rss; + rss << attribute; + return writeAttribute( name, rss.str() ); + } + + XmlWriter& writeText( std::string const& text, bool indent = true ); + + //XmlWriter& 
writeComment( std::string const& text ); + + //void writeStylesheetRef( std::string const& url ); + + //XmlWriter& writeBlankLine(); + + void ensureTagClosed(); + + void writeDeclaration(); + + private: + + void newlineIfNecessary(); + + bool m_tagIsOpen = false; + bool m_needsNewline = false; + std::vector m_tags; + std::string m_indent; + std::ostream& m_os; + }; + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. +// ================================================================================================= + +using uchar = unsigned char; + +namespace { + + size_t trailingBytes(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return 2; + } + if ((c & 0xF0) == 0xE0) { + return 3; + } + if ((c & 0xF8) == 0xF0) { + return 4; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + uint32_t headerValue(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return c & 0x1F; + } + if ((c & 0xF0) == 0xE0) { + return c & 0x0F; + } + if ((c & 0xF8) == 0xF0) { + return c & 0x07; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + void hexEscapeChar(std::ostream& os, unsigned char c) { + std::ios_base::fmtflags f(os.flags()); + os << "\\x" + << std::uppercase << std::hex << std::setfill('0') << std::setw(2) + << static_cast(c); + os.flags(f); + } + +} // anonymous namespace + + XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) + : m_str( str ), + m_forWhat( forWhat ) + {} + + void XmlEncode::encodeTo( std::ostream& os ) const { + // Apostrophe escaping not necessary if we always use " to write attributes + // (see: https://www.w3.org/TR/xml/#syntax) + + for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { + uchar c = m_str[idx]; + switch (c) { + 
case '<': os << "<"; break; + case '&': os << "&"; break; + + case '>': + // See: https://www.w3.org/TR/xml/#syntax + if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']') + os << ">"; + else + os << c; + break; + + case '\"': + if (m_forWhat == ForAttributes) + os << """; + else + os << c; + break; + + default: + // Check for control characters and invalid utf-8 + + // Escape control characters in standard ascii + // see https://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 + if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { + hexEscapeChar(os, c); + break; + } + + // Plain ASCII: Write it to stream + if (c < 0x7F) { + os << c; + break; + } + + // UTF-8 territory + // Check if the encoding is valid and if it is not, hex escape bytes. + // Important: We do not check the exact decoded values for validity, only the encoding format + // First check that this bytes is a valid lead byte: + // This means that it is not encoded as 1111 1XXX + // Or as 10XX XXXX + if (c < 0xC0 || + c >= 0xF8) { + hexEscapeChar(os, c); + break; + } + + auto encBytes = trailingBytes(c); + // Are there enough bytes left to avoid accessing out-of-bounds memory? 
+ if (idx + encBytes - 1 >= m_str.size()) { + hexEscapeChar(os, c); + break; + } + // The header is valid, check data + // The next encBytes bytes must together be a valid utf-8 + // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) + bool valid = true; + uint32_t value = headerValue(c); + for (std::size_t n = 1; n < encBytes; ++n) { + uchar nc = m_str[idx + n]; + valid &= ((nc & 0xC0) == 0x80); + value = (value << 6) | (nc & 0x3F); + } + + if ( + // Wrong bit pattern of following bytes + (!valid) || + // Overlong encodings + (value < 0x80) || + ( value < 0x800 && encBytes > 2) || // removed "0x80 <= value &&" because redundant + (0x800 < value && value < 0x10000 && encBytes > 3) || + // Encoded value out of range + (value >= 0x110000) + ) { + hexEscapeChar(os, c); + break; + } + + // If we got here, this is in fact a valid(ish) utf-8 sequence + for (std::size_t n = 0; n < encBytes; ++n) { + os << m_str[idx + n]; + } + idx += encBytes - 1; + break; + } + } + } + + std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { + xmlEncode.encodeTo( os ); + return os; + } + + XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT + : m_writer( other.m_writer ){ + other.m_writer = nullptr; + } + XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT { + if ( m_writer ) { + m_writer->endElement(); + } + m_writer = other.m_writer; + other.m_writer = nullptr; + return *this; + } + + + XmlWriter::ScopedElement::~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { + m_writer->writeText( text, indent ); + return *this; + } + + XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) + { + // writeDeclaration(); // called explicitly by the reporters that use 
the writer class - see issue #627 + } + + XmlWriter::~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + + XmlWriter& XmlWriter::startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + m_os << m_indent << '<' << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& XmlWriter::endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + m_os << "/>"; + m_tagIsOpen = false; + } + else { + m_os << m_indent << ""; + } + m_os << std::endl; + m_tags.pop_back(); + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, const char* attribute ) { + if( !name.empty() && attribute && attribute[0] != '\0' ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { + m_os << ' ' << name << "=\"" << ( attribute ? 
"true" : "false" ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + m_os << m_indent; + m_os << XmlEncode( text ); + m_needsNewline = true; + } + return *this; + } + + //XmlWriter& XmlWriter::writeComment( std::string const& text ) { + // ensureTagClosed(); + // m_os << m_indent << ""; + // m_needsNewline = true; + // return *this; + //} + + //void XmlWriter::writeStylesheetRef( std::string const& url ) { + // m_os << "\n"; + //} + + //XmlWriter& XmlWriter::writeBlankLine() { + // ensureTagClosed(); + // m_os << '\n'; + // return *this; + //} + + void XmlWriter::ensureTagClosed() { + if( m_tagIsOpen ) { + m_os << ">" << std::endl; + m_tagIsOpen = false; + } + } + + void XmlWriter::writeDeclaration() { + m_os << "\n"; + } + + void XmlWriter::newlineIfNecessary() { + if( m_needsNewline ) { + m_os << std::endl; + m_needsNewline = false; + } + } + +// ================================================================================================= +// End of copy-pasted code from Catch +// ================================================================================================= + + // clang-format on + + struct XmlReporter : public IReporter + { + XmlWriter xml; + DOCTEST_DECLARE_MUTEX(mutex) + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + XmlReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + std::stringstream ss; + for(int i = 0; i < num_contexts; ++i) { + contexts[i]->stringify(&ss); + xml.scopedElement("Info").writeText(ss.str()); + ss.str(""); + } + } + } + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 
0 : l; } + + void test_case_start_impl(const TestCaseData& in) { + bool open_ts_tag = false; + if(tc != nullptr) { // we have already opened a test suite + if(std::strcmp(tc->m_test_suite, in.m_test_suite) != 0) { + xml.endElement(); + open_ts_tag = true; + } + } + else { + open_ts_tag = true; // first test case ==> first test suite + } + + if(open_ts_tag) { + xml.startElement("TestSuite"); + xml.writeAttribute("name", in.m_test_suite); + } + + tc = ∈ + xml.startElement("TestCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file.c_str())) + .writeAttribute("line", line(in.m_line)) + .writeAttribute("description", in.m_description); + + if(Approx(in.m_timeout) != 0) + xml.writeAttribute("timeout", in.m_timeout); + if(in.m_may_fail) + xml.writeAttribute("may_fail", true); + if(in.m_should_fail) + xml.writeAttribute("should_fail", true); + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + test_run_start(); + if(opt.list_reporters) { + for(auto& curr : getListeners()) + xml.scopedElement("Listener") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + for(auto& curr : getReporters()) + xml.scopedElement("Reporter") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + } else if(opt.count || opt.list_test_cases) { + for(unsigned i = 0; i < in.num_data; ++i) { + xml.scopedElement("TestCase").writeAttribute("name", in.data[i]->m_name) + .writeAttribute("testsuite", in.data[i]->m_test_suite) + .writeAttribute("filename", skipPathFromFilename(in.data[i]->m_file.c_str())) + .writeAttribute("line", line(in.data[i]->m_line)) + .writeAttribute("skipped", 
in.data[i]->m_skip); + } + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + } else if(opt.list_test_suites) { + for(unsigned i = 0; i < in.num_data; ++i) + xml.scopedElement("TestSuite").writeAttribute("name", in.data[i]->m_test_suite); + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + xml.scopedElement("OverallResultsTestSuites") + .writeAttribute("unskipped", in.run_stats->numTestSuitesPassingFilters); + } + xml.endElement(); + } + + void test_run_start() override { + xml.writeDeclaration(); + + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + + xml.startElement("doctest").writeAttribute("binary", binary_name); + if(opt.no_version == false) + xml.writeAttribute("version", DOCTEST_VERSION_STR); + + // only the consequential ones (TODO: filters) + xml.scopedElement("Options") + .writeAttribute("order_by", opt.order_by.c_str()) + .writeAttribute("rand_seed", opt.rand_seed) + .writeAttribute("first", opt.first) + .writeAttribute("last", opt.last) + .writeAttribute("abort_after", opt.abort_after) + .writeAttribute("subcase_filter_levels", opt.subcase_filter_levels) + .writeAttribute("case_sensitive", opt.case_sensitive) + .writeAttribute("no_throw", opt.no_throw) + .writeAttribute("no_skip", opt.no_skip); + } + + void test_run_end(const TestRunStats& p) override { + if(tc) // the TestSuite tag - only if there has been at least 1 test case + xml.endElement(); + + xml.scopedElement("OverallResultsAsserts") + .writeAttribute("successes", p.numAsserts - p.numAssertsFailed) + .writeAttribute("failures", p.numAssertsFailed); + + 
xml.startElement("OverallResultsTestCases") + .writeAttribute("successes", + p.numTestCasesPassingFilters - p.numTestCasesFailed) + .writeAttribute("failures", p.numTestCasesFailed); + if(opt.no_skipped_summary == false) + xml.writeAttribute("skipped", p.numTestCases - p.numTestCasesPassingFilters); + xml.endElement(); + + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + test_case_start_impl(in); + xml.ensureTagClosed(); + } + + void test_case_reenter(const TestCaseData&) override {} + + void test_case_end(const CurrentTestCaseStats& st) override { + xml.startElement("OverallResultsAsserts") + .writeAttribute("successes", + st.numAssertsCurrentTest - st.numAssertsFailedCurrentTest) + .writeAttribute("failures", st.numAssertsFailedCurrentTest) + .writeAttribute("test_case_success", st.testCaseSuccess); + if(opt.duration) + xml.writeAttribute("duration", st.seconds); + if(tc->m_expected_failures) + xml.writeAttribute("expected_failures", tc->m_expected_failures); + xml.endElement(); + + xml.endElement(); + } + + void test_case_exception(const TestCaseException& e) override { + DOCTEST_LOCK_MUTEX(mutex) + + xml.scopedElement("Exception") + .writeAttribute("crash", e.is_crash) + .writeText(e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + xml.startElement("SubCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file)) + .writeAttribute("line", line(in.m_line)); + xml.ensureTagClosed(); + } + + void subcase_end() override { xml.endElement(); } + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed && !opt.success) + return; + + DOCTEST_LOCK_MUTEX(mutex) + + xml.startElement("Expression") + .writeAttribute("success", !rb.m_failed) + .writeAttribute("type", assertString(rb.m_at)) + .writeAttribute("filename", skipPathFromFilename(rb.m_file)) + .writeAttribute("line", line(rb.m_line)); + + 
xml.scopedElement("Original").writeText(rb.m_expr); + + if(rb.m_threw) + xml.scopedElement("Exception").writeText(rb.m_exception.c_str()); + + if(rb.m_at & assertType::is_throws_as) + xml.scopedElement("ExpectedException").writeText(rb.m_exception_type); + if(rb.m_at & assertType::is_throws_with) + xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string.c_str()); + if((rb.m_at & assertType::is_normal) && !rb.m_threw) + xml.scopedElement("Expanded").writeText(rb.m_decomp.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void log_message(const MessageData& mb) override { + DOCTEST_LOCK_MUTEX(mutex) + + xml.startElement("Message") + .writeAttribute("type", failureString(mb.m_severity)) + .writeAttribute("filename", skipPathFromFilename(mb.m_file)) + .writeAttribute("line", line(mb.m_line)); + + xml.scopedElement("Text").writeText(mb.m_string.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void test_case_skipped(const TestCaseData& in) override { + if(opt.no_skipped_summary == false) { + test_case_start_impl(in); + xml.writeAttribute("skipped", "true"); + xml.endElement(); + } + } + }; + + DOCTEST_REGISTER_REPORTER("xml", 0, XmlReporter); + + void fulltext_log_assert_to_stream(std::ostream& s, const AssertData& rb) { + if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) == + 0) //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << " ) " + << Color::None; + + if(rb.m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "threw as expected!" 
: "did NOT throw at all!") << "\n"; + } else if((rb.m_at & assertType::is_throws_as) && + (rb.m_at & assertType::is_throws_with)) { //!OCLINT + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string.c_str() + << "\", " << rb.m_exception_type << " ) " << Color::None; + if(rb.m_threw) { + if(!rb.m_failed) { + s << "threw as expected!\n"; + } else { + s << "threw a DIFFERENT exception! (contents: " << rb.m_exception << ")\n"; + } + } else { + s << "did NOT throw at all!\n"; + } + } else if(rb.m_at & + assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", " + << rb.m_exception_type << " ) " << Color::None + << (rb.m_threw ? (rb.m_threw_as ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & + assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string.c_str() + << "\" ) " << Color::None + << (rb.m_threw ? (!rb.m_failed ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "THREW exception: " : "didn't throw!") << Color::Cyan + << rb.m_exception << "\n"; + } else { + s << (rb.m_threw ? "THREW exception: " : + (!rb.m_failed ? 
"is correct!\n" : "is NOT correct!\n")); + if(rb.m_threw) + s << rb.m_exception << "\n"; + else + s << " values: " << assertString(rb.m_at) << "( " << rb.m_decomp << " )\n"; + } + } + + // TODO: + // - log_message() + // - respond to queries + // - honor remaining options + // - more attributes in tags + struct JUnitReporter : public IReporter + { + XmlWriter xml; + DOCTEST_DECLARE_MUTEX(mutex) + Timer timer; + std::vector deepestSubcaseStackNames; + + struct JUnitTestCaseData + { + static std::string getCurrentTimestamp() { + // Beware, this is not reentrant because of backward compatibility issues + // Also, UTC only, again because of backward compatibility (%z is C++11) + time_t rawtime; + std::time(&rawtime); + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + + std::tm timeInfo; +#ifdef DOCTEST_PLATFORM_WINDOWS + gmtime_s(&timeInfo, &rawtime); +#else // DOCTEST_PLATFORM_WINDOWS + gmtime_r(&rawtime, &timeInfo); +#endif // DOCTEST_PLATFORM_WINDOWS + + char timeStamp[timeStampSize]; + const char* const fmt = "%Y-%m-%dT%H:%M:%SZ"; + + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); + return std::string(timeStamp); + } + + struct JUnitTestMessage + { + JUnitTestMessage(const std::string& _message, const std::string& _type, const std::string& _details) + : message(_message), type(_type), details(_details) {} + + JUnitTestMessage(const std::string& _message, const std::string& _details) + : message(_message), type(), details(_details) {} + + std::string message, type, details; + }; + + struct JUnitTestCase + { + JUnitTestCase(const std::string& _classname, const std::string& _name) + : classname(_classname), name(_name), time(0), failures() {} + + std::string classname, name; + double time; + std::vector failures, errors; + }; + + void add(const std::string& classname, const std::string& name) { + testcases.emplace_back(classname, name); + } + + void appendSubcaseNamesToLastTestcase(std::vector nameStack) { + for(auto& curr: nameStack) + 
if(curr.size()) + testcases.back().name += std::string("/") + curr.c_str(); + } + + void addTime(double time) { + if(time < 1e-4) + time = 0; + testcases.back().time = time; + totalSeconds += time; + } + + void addFailure(const std::string& message, const std::string& type, const std::string& details) { + testcases.back().failures.emplace_back(message, type, details); + ++totalFailures; + } + + void addError(const std::string& message, const std::string& details) { + testcases.back().errors.emplace_back(message, details); + ++totalErrors; + } + + std::vector testcases; + double totalSeconds = 0; + int totalErrors = 0, totalFailures = 0; + }; + + JUnitTestCaseData testCaseData; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + JUnitReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData&) override { + xml.writeDeclaration(); + } + + void test_run_start() override { + xml.writeDeclaration(); + } + + void test_run_end(const TestRunStats& p) override { + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + xml.startElement("testsuites"); + xml.startElement("testsuite").writeAttribute("name", binary_name) + .writeAttribute("errors", testCaseData.totalErrors) + .writeAttribute("failures", 
testCaseData.totalFailures) + .writeAttribute("tests", p.numAsserts); + if(opt.no_time_in_output == false) { + xml.writeAttribute("time", testCaseData.totalSeconds); + xml.writeAttribute("timestamp", JUnitTestCaseData::getCurrentTimestamp()); + } + if(opt.no_version == false) + xml.writeAttribute("doctest_version", DOCTEST_VERSION_STR); + + for(const auto& testCase : testCaseData.testcases) { + xml.startElement("testcase") + .writeAttribute("classname", testCase.classname) + .writeAttribute("name", testCase.name); + if(opt.no_time_in_output == false) + xml.writeAttribute("time", testCase.time); + // This is not ideal, but it should be enough to mimic gtest's junit output. + xml.writeAttribute("status", "run"); + + for(const auto& failure : testCase.failures) { + xml.scopedElement("failure") + .writeAttribute("message", failure.message) + .writeAttribute("type", failure.type) + .writeText(failure.details, false); + } + + for(const auto& error : testCase.errors) { + xml.scopedElement("error") + .writeAttribute("message", error.message) + .writeText(error.details); + } + + xml.endElement(); + } + xml.endElement(); + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + timer.start(); + } + + void test_case_reenter(const TestCaseData& in) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + + timer.start(); + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + } + + void test_case_end(const CurrentTestCaseStats&) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + } + + void test_case_exception(const TestCaseException& e) override { + DOCTEST_LOCK_MUTEX(mutex) + testCaseData.addError("exception", 
e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + deepestSubcaseStackNames.push_back(in.m_name); + } + + void subcase_end() override {} + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed) // report only failures & ignore the `success` option + return; + + DOCTEST_LOCK_MUTEX(mutex) + + std::ostringstream os; + os << skipPathFromFilename(rb.m_file) << (opt.gnu_file_line ? ":" : "(") + << line(rb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; + + fulltext_log_assert_to_stream(os, rb); + log_contexts(os); + testCaseData.addFailure(rb.m_decomp.c_str(), assertString(rb.m_at), os.str()); + } + + void log_message(const MessageData& mb) override { + if(mb.m_severity & assertType::is_warn) // report only failures + return; + + DOCTEST_LOCK_MUTEX(mutex) + + std::ostringstream os; + os << skipPathFromFilename(mb.m_file) << (opt.gnu_file_line ? ":" : "(") + << line(mb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; + + os << mb.m_string.c_str() << "\n"; + log_contexts(os); + + testCaseData.addFailure(mb.m_string.c_str(), + mb.m_severity & assertType::is_check ? "FAIL_CHECK" : "FAIL", os.str()); + } + + void test_case_skipped(const TestCaseData&) override {} + + void log_contexts(std::ostringstream& s) { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + + s << " logged: "; + for(int i = 0; i < num_contexts; ++i) { + s << (i == 0 ? 
"" : " "); + contexts[i]->stringify(&s); + s << std::endl; + } + } + } + }; + + DOCTEST_REGISTER_REPORTER("junit", 0, JUnitReporter); + + struct Whitespace + { + int nrSpaces; + explicit Whitespace(int nr) + : nrSpaces(nr) {} + }; + + std::ostream& operator<<(std::ostream& out, const Whitespace& ws) { + if(ws.nrSpaces != 0) + out << std::setw(ws.nrSpaces) << ' '; + return out; + } + + struct ConsoleReporter : public IReporter + { + std::ostream& s; + bool hasLoggedCurrentTestStart; + std::vector subcasesStack; + size_t currentSubcaseLevel; + DOCTEST_DECLARE_MUTEX(mutex) + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc; + + ConsoleReporter(const ContextOptions& co) + : s(*co.cout) + , opt(co) {} + + ConsoleReporter(const ContextOptions& co, std::ostream& ostr) + : s(ostr) + , opt(co) {} + + // ========================================================================================= + // WHAT FOLLOWS ARE HELPERS USED BY THE OVERRIDES OF THE VIRTUAL METHODS OF THE INTERFACE + // ========================================================================================= + + void separator_to_stream() { + s << Color::Yellow + << "===============================================================================" + "\n"; + } + + const char* getSuccessOrFailString(bool success, assertType::Enum at, + const char* success_str) { + if(success) + return success_str; + return failureString(at); + } + + Color::Enum getSuccessOrFailColor(bool success, assertType::Enum at) { + return success ? Color::BrightGreen : + (at & assertType::is_warn) ? 
Color::Yellow : Color::Red; + } + + void successOrFailColoredStringToStream(bool success, assertType::Enum at, + const char* success_str = "SUCCESS") { + s << getSuccessOrFailColor(success, at) + << getSuccessOrFailString(success, at, success_str) << ": "; + } + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + + s << Color::None << " logged: "; + for(int i = 0; i < num_contexts; ++i) { + s << (i == 0 ? "" : " "); + contexts[i]->stringify(&s); + s << "\n"; + } + } + + s << "\n"; + } + + // this was requested to be made virtual so users could override it + virtual void file_line_to_stream(const char* file, int line, + const char* tail = "") { + s << Color::LightGrey << skipPathFromFilename(file) << (opt.gnu_file_line ? ":" : "(") + << (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option + << (opt.gnu_file_line ? ":" : "):") << tail; + } + + void logTestStart() { + if(hasLoggedCurrentTestStart) + return; + + separator_to_stream(); + file_line_to_stream(tc->m_file.c_str(), tc->m_line, "\n"); + if(tc->m_description) + s << Color::Yellow << "DESCRIPTION: " << Color::None << tc->m_description << "\n"; + if(tc->m_test_suite && tc->m_test_suite[0] != '\0') + s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n"; + if(strncmp(tc->m_name, " Scenario:", 11) != 0) + s << Color::Yellow << "TEST CASE: "; + s << Color::None << tc->m_name << "\n"; + + for(size_t i = 0; i < currentSubcaseLevel; ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + + if(currentSubcaseLevel != subcasesStack.size()) { + s << Color::Yellow << "\nDEEPEST SUBCASE STACK REACHED (DIFFERENT FROM THE CURRENT ONE):\n" << Color::None; + for(size_t i = 0; i < subcasesStack.size(); ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + } + + s << "\n"; + + hasLoggedCurrentTestStart = 
true; + } + + void printVersion() { + if(opt.no_version == false) + s << Color::Cyan << "[doctest] " << Color::None << "doctest version is \"" + << DOCTEST_VERSION_STR << "\"\n"; + } + + void printIntro() { + if(opt.no_intro == false) { + printVersion(); + s << Color::Cyan << "[doctest] " << Color::None + << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; + } + } + + void printHelp() { + int sizePrefixDisplay = static_cast(strlen(DOCTEST_OPTIONS_PREFIX_DISPLAY)); + printVersion(); + // clang-format off + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "boolean values: \"1/on/yes/true\" or \"0/off/no/false\"\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filter values: \"str1,str2,str3\" (comma separated strings)\n"; + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filters use wildcards for matching strings\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "something passes a filter if any of the strings in a filter matches\n"; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "ALL FLAGS, OPTIONS AND FILTERS ALSO AVAILABLE WITH A \"" DOCTEST_CONFIG_OPTIONS_PREFIX "\" PREFIX!!!\n"; +#endif + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "Query flags - the program quits after them. 
Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "?, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "help, -" DOCTEST_OPTIONS_PREFIX_DISPLAY "h " + << Whitespace(sizePrefixDisplay*0) << "prints this message\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "v, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "version " + << Whitespace(sizePrefixDisplay*1) << "prints the version\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "c, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "count " + << Whitespace(sizePrefixDisplay*1) << "prints the number of matching tests\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ltc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-cases " + << Whitespace(sizePrefixDisplay*1) << "lists all matching tests by name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-suites " + << Whitespace(sizePrefixDisplay*1) << "lists all matching test suites\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-reporters " + << Whitespace(sizePrefixDisplay*1) << "lists all registered reporters\n\n"; + // ================================================================================== << 79 + s << Color::Cyan << "[doctest] " << Color::None; + s << "The available / options/filters are:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their file\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sfe, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their file\n"; + s << " -" 
DOCTEST_OPTIONS_PREFIX_DISPLAY "ts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tse, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase= " + << Whitespace(sizePrefixDisplay*1) << "filters subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "r, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "reporters= " + << Whitespace(sizePrefixDisplay*1) << "reporters to use (console is default)\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "o, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "out= " + << Whitespace(sizePrefixDisplay*1) << "output filename\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ob, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "order-by= " + << Whitespace(sizePrefixDisplay*1) << "how the tests should be ordered\n"; + s << Whitespace(sizePrefixDisplay*3) << " - [file/suite/name/rand/none]\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "rs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "rand-seed= " + << Whitespace(sizePrefixDisplay*1) << "seed for random ordering\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "f, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "first= " + << Whitespace(sizePrefixDisplay*1) << "the first test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "l, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "last= " + << Whitespace(sizePrefixDisplay*1) << "the last test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" 
DOCTEST_OPTIONS_PREFIX_DISPLAY "aa, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "abort-after= " + << Whitespace(sizePrefixDisplay*1) << "stop after failed assertions\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "scfl,--" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-filter-levels= " + << Whitespace(sizePrefixDisplay*1) << "apply filters for the first levels\n"; + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "Bool options - can be used like flags and true is assumed. Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "s, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "success= " + << Whitespace(sizePrefixDisplay*1) << "include successful assertions in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "cs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "case-sensitive= " + << Whitespace(sizePrefixDisplay*1) << "filters being treated as case sensitive\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "e, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "exit= " + << Whitespace(sizePrefixDisplay*1) << "exits after the tests finish\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "d, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "duration= " + << Whitespace(sizePrefixDisplay*1) << "prints the time duration of each test\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "m, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "minimal= " + << Whitespace(sizePrefixDisplay*1) << "minimal console output (only failures)\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "q, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "quiet= " + << Whitespace(sizePrefixDisplay*1) << "no console output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nt, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-throw= " + << Whitespace(sizePrefixDisplay*1) << "skips exceptions-related assert checks\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ne, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-exitcode= " + << Whitespace(sizePrefixDisplay*1) << "returns (or exits) always with success\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-run= " + << 
Whitespace(sizePrefixDisplay*1) << "skips all runtime doctest operations\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ni, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-intro= " + << Whitespace(sizePrefixDisplay*1) << "omit the framework intro in the output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nv, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-version= " + << Whitespace(sizePrefixDisplay*1) << "omit the framework version in the output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-colors= " + << Whitespace(sizePrefixDisplay*1) << "disables colors in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "fc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "force-colors= " + << Whitespace(sizePrefixDisplay*1) << "use colors even when not in a tty\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nb, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-breaks= " + << Whitespace(sizePrefixDisplay*1) << "disables breakpoints in debuggers\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ns, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-skip= " + << Whitespace(sizePrefixDisplay*1) << "don't skip test cases marked as skip\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "gfl, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "gnu-file-line= " + << Whitespace(sizePrefixDisplay*1) << ":n: vs (n): for line numbers in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "npf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-path-filenames= " + << Whitespace(sizePrefixDisplay*1) << "only filenames and no paths in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nln, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-line-numbers= " + << Whitespace(sizePrefixDisplay*1) << "0 instead of real line numbers in output\n"; + // ================================================================================== << 79 + // clang-format on + + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "for more information visit the project documentation\n\n"; + } + + void printRegisteredReporters() { + 
printVersion(); + auto printReporters = [this] (const reporterMap& reporters, const char* type) { + if(reporters.size()) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all registered " << type << "\n"; + for(auto& curr : reporters) + s << "priority: " << std::setw(5) << curr.first.first + << " name: " << curr.first.second << "\n"; + } + }; + printReporters(getListeners(), "listeners"); + printReporters(getReporters(), "reporters"); + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + if(opt.version) { + printVersion(); + } else if(opt.help) { + printHelp(); + } else if(opt.list_reporters) { + printRegisteredReporters(); + } else if(opt.count || opt.list_test_cases) { + if(opt.list_test_cases) { + s << Color::Cyan << "[doctest] " << Color::None + << "listing all test case names\n"; + separator_to_stream(); + } + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_name << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + + } else if(opt.list_test_suites) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all test suites\n"; + separator_to_stream(); + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_test_suite << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "test suites with unskipped test cases passing the current filters: " + << g_cs->numTestSuitesPassingFilters << 
"\n"; + } + } + + void test_run_start() override { + if(!opt.minimal) + printIntro(); + } + + void test_run_end(const TestRunStats& p) override { + if(opt.minimal && p.numTestCasesFailed == 0) + return; + + separator_to_stream(); + s << std::dec; + + auto totwidth = int(std::ceil(log10(static_cast(std::max(p.numTestCasesPassingFilters, static_cast(p.numAsserts))) + 1))); + auto passwidth = int(std::ceil(log10(static_cast(std::max(p.numTestCasesPassingFilters - p.numTestCasesFailed, static_cast(p.numAsserts - p.numAssertsFailed))) + 1))); + auto failwidth = int(std::ceil(log10(static_cast(std::max(p.numTestCasesFailed, static_cast(p.numAssertsFailed))) + 1))); + const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0; + s << Color::Cyan << "[doctest] " << Color::None << "test cases: " << std::setw(totwidth) + << p.numTestCasesPassingFilters << " | " + << ((p.numTestCasesPassingFilters == 0 || anythingFailed) ? Color::None : + Color::Green) + << std::setw(passwidth) << p.numTestCasesPassingFilters - p.numTestCasesFailed << " passed" + << Color::None << " | " << (p.numTestCasesFailed > 0 ? Color::Red : Color::None) + << std::setw(failwidth) << p.numTestCasesFailed << " failed" << Color::None << " |"; + if(opt.no_skipped_summary == false) { + const int numSkipped = p.numTestCases - p.numTestCasesPassingFilters; + s << " " << (numSkipped == 0 ? Color::None : Color::Yellow) << numSkipped + << " skipped" << Color::None; + } + s << "\n"; + s << Color::Cyan << "[doctest] " << Color::None << "assertions: " << std::setw(totwidth) + << p.numAsserts << " | " + << ((p.numAsserts == 0 || anythingFailed) ? Color::None : Color::Green) + << std::setw(passwidth) << (p.numAsserts - p.numAssertsFailed) << " passed" << Color::None + << " | " << (p.numAssertsFailed > 0 ? 
Color::Red : Color::None) << std::setw(failwidth) + << p.numAssertsFailed << " failed" << Color::None << " |\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "Status: " << (p.numTestCasesFailed > 0 ? Color::Red : Color::Green) + << ((p.numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl; + } + + void test_case_start(const TestCaseData& in) override { + hasLoggedCurrentTestStart = false; + tc = ∈ + subcasesStack.clear(); + currentSubcaseLevel = 0; + } + + void test_case_reenter(const TestCaseData&) override { + subcasesStack.clear(); + } + + void test_case_end(const CurrentTestCaseStats& st) override { + if(tc->m_no_output) + return; + + // log the preamble of the test case only if there is something + // else to print - something other than that an assert has failed + if(opt.duration || + (st.failure_flags && st.failure_flags != static_cast(TestCaseFailureReason::AssertFailure))) + logTestStart(); + + if(opt.duration) + s << Color::None << std::setprecision(6) << std::fixed << st.seconds + << " s: " << tc->m_name << "\n"; + + if(st.failure_flags & TestCaseFailureReason::Timeout) + s << Color::Red << "Test case exceeded time limit of " << std::setprecision(6) + << std::fixed << tc->m_timeout << "!\n"; + + if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedButDidnt) { + s << Color::Red << "Should have failed but didn't! 
Marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedAndDid) { + s << Color::Yellow << "Failed as expected so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::CouldHaveFailedAndDid) { + s << Color::Yellow << "Allowed to fail so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::DidntFailExactlyNumTimes) { + s << Color::Red << "Didn't fail exactly " << tc->m_expected_failures + << " times so marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::FailedExactlyNumTimes) { + s << Color::Yellow << "Failed exactly " << tc->m_expected_failures + << " times as expected so marking it as not failed!\n"; + } + if(st.failure_flags & TestCaseFailureReason::TooManyFailedAsserts) { + s << Color::Red << "Aborting - too many failed asserts!\n"; + } + s << Color::None; // lgtm [cpp/useless-expression] + } + + void test_case_exception(const TestCaseException& e) override { + DOCTEST_LOCK_MUTEX(mutex) + if(tc->m_no_output) + return; + + logTestStart(); + + file_line_to_stream(tc->m_file.c_str(), tc->m_line, " "); + successOrFailColoredStringToStream(false, e.is_crash ? assertType::is_require : + assertType::is_check); + s << Color::Red << (e.is_crash ? "test case CRASHED: " : "test case THREW exception: ") + << Color::Cyan << e.error_string << "\n"; + + int num_stringified_contexts = get_num_stringified_contexts(); + if(num_stringified_contexts) { + auto stringified_contexts = get_stringified_contexts(); + s << Color::None << " logged: "; + for(int i = num_stringified_contexts; i > 0; --i) { + s << (i == num_stringified_contexts ? 
"" : " ") + << stringified_contexts[i - 1] << "\n"; + } + } + s << "\n" << Color::None; + } + + void subcase_start(const SubcaseSignature& subc) override { + subcasesStack.push_back(subc); + ++currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void subcase_end() override { + --currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void log_assert(const AssertData& rb) override { + if((!rb.m_failed && !opt.success) || tc->m_no_output) + return; + + DOCTEST_LOCK_MUTEX(mutex) + + logTestStart(); + + file_line_to_stream(rb.m_file, rb.m_line, " "); + successOrFailColoredStringToStream(!rb.m_failed, rb.m_at); + + fulltext_log_assert_to_stream(s, rb); + + log_contexts(); + } + + void log_message(const MessageData& mb) override { + if(tc->m_no_output) + return; + + DOCTEST_LOCK_MUTEX(mutex) + + logTestStart(); + + file_line_to_stream(mb.m_file, mb.m_line, " "); + s << getSuccessOrFailColor(false, mb.m_severity) + << getSuccessOrFailString(mb.m_severity & assertType::is_warn, mb.m_severity, + "MESSAGE") << ": "; + s << Color::None << mb.m_string << "\n"; + log_contexts(); + } + + void test_case_skipped(const TestCaseData&) override {} + }; + + DOCTEST_REGISTER_REPORTER("console", 0, ConsoleReporter); + +#ifdef DOCTEST_PLATFORM_WINDOWS + struct DebugOutputWindowReporter : public ConsoleReporter + { + DOCTEST_THREAD_LOCAL static std::ostringstream oss; + + DebugOutputWindowReporter(const ContextOptions& co) + : ConsoleReporter(co, oss) {} + +#define DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(func, type, arg) \ + void func(type arg) override { \ + bool with_col = g_no_colors; \ + g_no_colors = false; \ + ConsoleReporter::func(arg); \ + if(oss.tellp() != std::streampos{}) { \ + DOCTEST_OUTPUT_DEBUG_STRING(oss.str().c_str()); \ + oss.str(""); \ + } \ + g_no_colors = with_col; \ + } + + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_start, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_end, const TestRunStats&, in) + 
DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_start, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_reenter, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_end, const CurrentTestCaseStats&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_exception, const TestCaseException&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_start, const SubcaseSignature&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_end, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_assert, const AssertData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_message, const MessageData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_skipped, const TestCaseData&, in) + }; + + DOCTEST_THREAD_LOCAL std::ostringstream DebugOutputWindowReporter::oss; +#endif // DOCTEST_PLATFORM_WINDOWS + + // the implementation of parseOption() + bool parseOptionImpl(int argc, const char* const* argv, const char* pattern, String* value) { + // going from the end to the beginning and stopping on the first occurrence from the end + for(int i = argc; i > 0; --i) { + auto index = i - 1; + auto temp = std::strstr(argv[index], pattern); + if(temp && (value || strlen(temp) == strlen(pattern))) { //!OCLINT prefer early exits and continue + // eliminate matches in which the chars before the option are not '-' + bool noBadCharsFound = true; + auto curr = argv[index]; + while(curr != temp) { + if(*curr++ != '-') { + noBadCharsFound = false; + break; + } + } + if(noBadCharsFound && argv[index][0] == '-') { + if(value) { + // parsing the value of an option + temp += strlen(pattern); + const unsigned len = strlen(temp); + if(len) { + *value = temp; + return true; + } + } else { + // just a flag - no value + return true; + } + } + } + } + return false; + } + + // parses an option and returns the string after the '=' character + bool parseOption(int argc, const char* const* argv, const char* pattern, String* value = 
nullptr, + const String& defaultVal = String()) { + if(value) + *value = defaultVal; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + // offset (normally 3 for "dt-") to skip prefix + if(parseOptionImpl(argc, argv, pattern + strlen(DOCTEST_CONFIG_OPTIONS_PREFIX), value)) + return true; +#endif // DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + return parseOptionImpl(argc, argv, pattern, value); + } + + // locates a flag on the command line + bool parseFlag(int argc, const char* const* argv, const char* pattern) { + return parseOption(argc, argv, pattern); + } + + // parses a comma separated list of words after a pattern in one of the arguments in argv + bool parseCommaSepArgs(int argc, const char* const* argv, const char* pattern, + std::vector& res) { + String filtersString; + if(parseOption(argc, argv, pattern, &filtersString)) { + // tokenize with "," as a separator, unless escaped with backslash + std::ostringstream s; + auto flush = [&s, &res]() { + auto string = s.str(); + if(string.size() > 0) { + res.push_back(string.c_str()); + } + s.str(""); + }; + + bool seenBackslash = false; + const char* current = filtersString.c_str(); + const char* end = current + strlen(current); + while(current != end) { + char character = *current++; + if(seenBackslash) { + seenBackslash = false; + if(character == ',' || character == '\\') { + s.put(character); + continue; + } + s.put('\\'); + } + if(character == '\\') { + seenBackslash = true; + } else if(character == ',') { + flush(); + } else { + s.put(character); + } + } + + if(seenBackslash) { + s.put('\\'); + } + flush(); + return true; + } + return false; + } + + enum optionType + { + option_bool, + option_int + }; + + // parses an int/bool option from the command line + bool parseIntOption(int argc, const char* const* argv, const char* pattern, optionType type, + int& res) { + String parsedValue; + if(!parseOption(argc, argv, pattern, &parsedValue)) + return false; + + if(type) { + // integer + // TODO: change this to use std::stoi 
or something else! currently it uses undefined behavior - assumes '0' on failed parse... + int theInt = std::atoi(parsedValue.c_str()); + if (theInt != 0) { + res = theInt; //!OCLINT parameter reassignment + return true; + } + } else { + // boolean + const char positive[][5] = { "1", "true", "on", "yes" }; // 5 - strlen("true") + 1 + const char negative[][6] = { "0", "false", "off", "no" }; // 6 - strlen("false") + 1 + + // if the value matches any of the positive/negative possibilities + for (unsigned i = 0; i < 4; i++) { + if (parsedValue.compare(positive[i], true) == 0) { + res = 1; //!OCLINT parameter reassignment + return true; + } + if (parsedValue.compare(negative[i], true) == 0) { + res = 0; //!OCLINT parameter reassignment + return true; + } + } + } + return false; + } +} // namespace + +Context::Context(int argc, const char* const* argv) + : p(new detail::ContextState) { + parseArgs(argc, argv, true); + if(argc) + p->binary_name = argv[0]; +} + +Context::~Context() { + if(g_cs == p) + g_cs = nullptr; + delete p; +} + +void Context::applyCommandLine(int argc, const char* const* argv) { + parseArgs(argc, argv); + if(argc) + p->binary_name = argv[0]; +} + +// parses args +void Context::parseArgs(int argc, const char* const* argv, bool withDefaults) { + using namespace detail; + + // clang-format off + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sf=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file-exclude=",p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sfe=", p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ts=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite-exclude=", p->filters[3]); + parseCommaSepArgs(argc, 
argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tse=", p->filters[3]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tc=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case-exclude=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tce=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sc=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase-exclude=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sce=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "reporters=", p->filters[8]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "r=", p->filters[8]); + // clang-format on + + int intRes = 0; + String strRes; + +#define DOCTEST_PARSE_AS_BOOL_OR_FLAG(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_bool, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_bool, intRes)) \ + p->var = static_cast(intRes); \ + else if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name) || \ + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname)) \ + p->var = true; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_INT_OPTION(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_int, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_int, intRes)) \ + p->var = intRes; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_STR_OPTION(name, sname, var, default) \ + if(parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", &strRes, default) || \ + 
parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", &strRes, default) || \ + withDefaults) \ + p->var = strRes + + // clang-format off + DOCTEST_PARSE_STR_OPTION("out", "o", out, ""); + DOCTEST_PARSE_STR_OPTION("order-by", "ob", order_by, "file"); + DOCTEST_PARSE_INT_OPTION("rand-seed", "rs", rand_seed, 0); + + DOCTEST_PARSE_INT_OPTION("first", "f", first, 0); + DOCTEST_PARSE_INT_OPTION("last", "l", last, UINT_MAX); + + DOCTEST_PARSE_INT_OPTION("abort-after", "aa", abort_after, 0); + DOCTEST_PARSE_INT_OPTION("subcase-filter-levels", "scfl", subcase_filter_levels, INT_MAX); + + DOCTEST_PARSE_AS_BOOL_OR_FLAG("success", "s", success, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("case-sensitive", "cs", case_sensitive, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("exit", "e", exit, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("duration", "d", duration, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("minimal", "m", minimal, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("quiet", "q", quiet, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-throw", "nt", no_throw, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-exitcode", "ne", no_exitcode, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-run", "nr", no_run, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-intro", "ni", no_intro, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-version", "nv", no_version, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-colors", "nc", no_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("force-colors", "fc", force_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-breaks", "nb", no_breaks, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skip", "ns", no_skip, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("gnu-file-line", "gfl", gnu_file_line, !bool(DOCTEST_MSVC)); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-path-filenames", "npf", no_path_in_filenames, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-line-numbers", "nln", no_line_numbers, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-debug-output", "ndo", no_debug_output, false); + 
DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skipped-summary", "nss", no_skipped_summary, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-time-in-output", "ntio", no_time_in_output, false); + // clang-format on + + if(withDefaults) { + p->help = false; + p->version = false; + p->count = false; + p->list_test_cases = false; + p->list_test_suites = false; + p->list_reporters = false; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "help") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "h") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "?")) { + p->help = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "version") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "v")) { + p->version = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "count") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "c")) { + p->count = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-cases") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ltc")) { + p->list_test_cases = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-suites") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lts")) { + p->list_test_suites = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-reporters") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lr")) { + p->list_reporters = true; + p->exit = true; + } +} + +// allows the user to add procedurally to the filters from the command line +void Context::addFilter(const char* filter, const char* value) { setOption(filter, value); } + +// allows the user to clear all filters from the command line +void Context::clearFilters() { + for(auto& curr : p->filters) + curr.clear(); +} + +// allows the user to override procedurally the bool options from the command line +void Context::setOption(const char* option, bool 
value) { + setOption(option, value ? "true" : "false"); +} + +// allows the user to override procedurally the int options from the command line +void Context::setOption(const char* option, int value) { + setOption(option, toString(value).c_str()); +} + +// allows the user to override procedurally the string options from the command line +void Context::setOption(const char* option, const char* value) { + auto argv = String("-") + option + "=" + value; + auto lvalue = argv.c_str(); + parseArgs(1, &lvalue); +} + +// users should query this in their main() and exit the program if true +bool Context::shouldExit() { return p->exit; } + +void Context::setAsDefaultForAssertsOutOfTestCases() { g_cs = p; } + +void Context::setAssertHandler(detail::assert_handler ah) { p->ah = ah; } + +void Context::setCout(std::ostream* out) { p->cout = out; } + +static class DiscardOStream : public std::ostream +{ +private: + class : public std::streambuf + { + private: + // allowing some buffering decreases the amount of calls to overflow + char buf[1024]; + + protected: + std::streamsize xsputn(const char_type*, std::streamsize count) override { return count; } + + int_type overflow(int_type ch) override { + setp(std::begin(buf), std::end(buf)); + return traits_type::not_eof(ch); + } + } discardBuf; + +public: + DiscardOStream() + : std::ostream(&discardBuf) {} +} discardOut; + +// the main function that does all the filtering and test running +int Context::run() { + using namespace detail; + + // save the old context state in case such was setup - for using asserts out of a testing context + auto old_cs = g_cs; + // this is the current contest + g_cs = p; + is_running_in_test = true; + + g_no_colors = p->no_colors; + p->resetRunData(); + + std::fstream fstr; + if(p->cout == nullptr) { + if(p->quiet) { + p->cout = &discardOut; + } else if(p->out.size()) { + // to a file if specified + fstr.open(p->out.c_str(), std::fstream::out); + p->cout = &fstr; + } else { +#ifndef 
DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + // stdout by default + p->cout = &std::cout; +#else // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + return EXIT_FAILURE; +#endif // DOCTEST_CONFIG_NO_INCLUDE_IOSTREAM + } + } + + FatalConditionHandler::allocateAltStackMem(); + + auto cleanup_and_return = [&]() { + FatalConditionHandler::freeAltStackMem(); + + if(fstr.is_open()) + fstr.close(); + + // restore context + g_cs = old_cs; + is_running_in_test = false; + + // we have to free the reporters which were allocated when the run started + for(auto& curr : p->reporters_currently_used) + delete curr; + p->reporters_currently_used.clear(); + + if(p->numTestCasesFailed && !p->no_exitcode) + return EXIT_FAILURE; + return EXIT_SUCCESS; + }; + + // setup default reporter if none is given through the command line + if(p->filters[8].empty()) + p->filters[8].push_back("console"); + + // check to see if any of the registered reporters has been selected + for(auto& curr : getReporters()) { + if(matchesAny(curr.first.second.c_str(), p->filters[8], false, p->case_sensitive)) + p->reporters_currently_used.push_back(curr.second(*g_cs)); + } + + // TODO: check if there is nothing in reporters_currently_used + + // prepend all listeners + for(auto& curr : getListeners()) + p->reporters_currently_used.insert(p->reporters_currently_used.begin(), curr.second(*g_cs)); + +#ifdef DOCTEST_PLATFORM_WINDOWS + if(isDebuggerActive() && p->no_debug_output == false) + p->reporters_currently_used.push_back(new DebugOutputWindowReporter(*g_cs)); +#endif // DOCTEST_PLATFORM_WINDOWS + + // handle version, help and no_run + if(p->no_run || p->version || p->help || p->list_reporters) { + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, QueryData()); + + return cleanup_and_return(); + } + + std::vector testArray; + for(auto& curr : getRegisteredTests()) + testArray.push_back(&curr); + p->numTestCases = testArray.size(); + + // sort the collected records + if(!testArray.empty()) { + if(p->order_by.compare("file", true) == 
0) { + std::sort(testArray.begin(), testArray.end(), fileOrderComparator); + } else if(p->order_by.compare("suite", true) == 0) { + std::sort(testArray.begin(), testArray.end(), suiteOrderComparator); + } else if(p->order_by.compare("name", true) == 0) { + std::sort(testArray.begin(), testArray.end(), nameOrderComparator); + } else if(p->order_by.compare("rand", true) == 0) { + std::srand(p->rand_seed); + + // random_shuffle implementation + const auto first = &testArray[0]; + for(size_t i = testArray.size() - 1; i > 0; --i) { + int idxToSwap = std::rand() % (i + 1); + + const auto temp = first[i]; + + first[i] = first[idxToSwap]; + first[idxToSwap] = temp; + } + } else if(p->order_by.compare("none", true) == 0) { + // means no sorting - beneficial for death tests which call into the executable + // with a specific test case in mind - we don't want to slow down the startup times + } + } + + std::set testSuitesPassingFilt; + + bool query_mode = p->count || p->list_test_cases || p->list_test_suites; + std::vector queryResults; + + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_start, DOCTEST_EMPTY); + + // invoke the registered functions if they match the filter criteria (or just count them) + for(auto& curr : testArray) { + const auto& tc = *curr; + + bool skip_me = false; + if(tc.m_skip && !p->no_skip) + skip_me = true; + + if(!matchesAny(tc.m_file.c_str(), p->filters[0], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_file.c_str(), p->filters[1], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_test_suite, p->filters[2], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_test_suite, p->filters[3], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_name, p->filters[4], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_name, p->filters[5], false, p->case_sensitive)) + skip_me = true; + + if(!skip_me) + p->numTestCasesPassingFilters++; + + // skip the test if it is not in 
the execution range + if((p->last < p->numTestCasesPassingFilters && p->first <= p->last) || + (p->first > p->numTestCasesPassingFilters)) + skip_me = true; + + if(skip_me) { + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_skipped, tc); + continue; + } + + // do not execute the test if we are to only count the number of filter passing tests + if(p->count) + continue; + + // print the name of the test and don't execute it + if(p->list_test_cases) { + queryResults.push_back(&tc); + continue; + } + + // print the name of the test suite if not done already and don't execute it + if(p->list_test_suites) { + if((testSuitesPassingFilt.count(tc.m_test_suite) == 0) && tc.m_test_suite[0] != '\0') { + queryResults.push_back(&tc); + testSuitesPassingFilt.insert(tc.m_test_suite); + p->numTestSuitesPassingFilters++; + } + continue; + } + + // execute the test if it passes all the filtering + { + p->currentTest = &tc; + + p->failure_flags = TestCaseFailureReason::None; + p->seconds = 0; + + // reset atomic counters + p->numAssertsFailedCurrentTest_atomic = 0; + p->numAssertsCurrentTest_atomic = 0; + + p->fullyTraversedSubcases.clear(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_start, tc); + + p->timer.start(); + + bool run_test = true; + + do { + // reset some of the fields for subcases (except for the set of fully passed ones) + p->reachedLeaf = false; + // May not be empty if previous subcase exited via exception. 
+ p->subcaseStack.clear(); + p->currentSubcaseDepth = 0; + + p->shouldLogCurrentException = true; + + // reset stuff for logging with INFO() + p->stringifiedContexts.clear(); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +// MSVC 2015 diagnoses fatalConditionHandler as unused (because reset() is a static method) +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4101) // unreferenced local variable + FatalConditionHandler fatalConditionHandler; // Handle signals + // execute the test + tc.m_test(); + fatalConditionHandler.reset(); +DOCTEST_MSVC_SUPPRESS_WARNING_POP +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + } catch(const TestFailureException&) { + p->failure_flags |= TestCaseFailureReason::AssertFailure; + } catch(...) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, + {translateActiveException(), false}); + p->failure_flags |= TestCaseFailureReason::Exception; + } +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + + // exit this loop if enough assertions have failed - even if there are more subcases + if(p->abort_after > 0 && + p->numAssertsFailed + p->numAssertsFailedCurrentTest_atomic >= p->abort_after) { + run_test = false; + p->failure_flags |= TestCaseFailureReason::TooManyFailedAsserts; + } + + if(!p->nextSubcaseStack.empty() && run_test) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_reenter, tc); + if(p->nextSubcaseStack.empty()) + run_test = false; + } while(run_test); + + p->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + p->currentTest = nullptr; + + // stop executing tests if enough assertions have failed + if(p->abort_after > 0 && p->numAssertsFailed >= p->abort_after) + break; + } + } + + if(!query_mode) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } else { + QueryData qdata; + qdata.run_stats = g_cs; + qdata.data = queryResults.data(); + qdata.num_data = unsigned(queryResults.size()); + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, qdata); + } + + return 
cleanup_and_return(); +} + +DOCTEST_DEFINE_INTERFACE(IReporter) + +int IReporter::get_num_active_contexts() { return detail::g_infoContexts.size(); } +const IContextScope* const* IReporter::get_active_contexts() { + return get_num_active_contexts() ? &detail::g_infoContexts[0] : nullptr; +} + +int IReporter::get_num_stringified_contexts() { return detail::g_cs->stringifiedContexts.size(); } +const String* IReporter::get_stringified_contexts() { + return get_num_stringified_contexts() ? &detail::g_cs->stringifiedContexts[0] : nullptr; +} + +namespace detail { + void registerReporterImpl(const char* name, int priority, reporterCreatorFunc c, bool isReporter) { + if(isReporter) + getReporters().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + else + getListeners().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + } +} // namespace detail + +} // namespace doctest + +#endif // DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007) // 'function' : must be 'attribute' - see issue #182 +int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); } +DOCTEST_MSVC_SUPPRESS_WARNING_POP +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +DOCTEST_SUPPRESS_COMMON_WARNINGS_POP + +#endif // DOCTEST_LIBRARY_IMPLEMENTATION +#endif // DOCTEST_CONFIG_IMPLEMENT + +#ifdef DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN +#undef WIN32_LEAN_AND_MEAN +#undef DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN +#endif // DOCTEST_UNDEF_WIN32_LEAN_AND_MEAN + +#ifdef DOCTEST_UNDEF_NOMINMAX +#undef NOMINMAX +#undef DOCTEST_UNDEF_NOMINMAX +#endif // DOCTEST_UNDEF_NOMINMAX diff --git a/include/tinytc/builder.h.mochi b/include/tinytc/builder.h.mochi new file mode 100644 index 00000000..c2468fd6 --- /dev/null +++ b/include/tinytc/builder.h.mochi @@ -0,0 +1,534 @@ +// Copyright (C) 2025 Intel 
Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef BUILDER_H_20250625 +#define BUILDER_H_20250625 + +#include "tinytc/export.h" +#include "tinytc/types.h" + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +//////////////////////////// +///////// Attribute //////// +//////////////////////////// + +/** + * @brief Get array attribute + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param array_size [in] number of elements in array, must be 0 if array == nullptr + * @param array [in][range(0, array_size)] attribute array + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_array_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + size_t array_size, const tinytc_attr_t *array); + +/** + * @brief Get boolean attribute + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param value [in] value of attribute + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_boolean_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + tinytc_bool_t value); + +/** + * @brief Get dictionary attribute + * + * Each name must only appear once. + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param items_size [in] number of elements in items array, must be 0 if items == nullptr + * @param items [in][range(0, items_size)] array of items + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_dictionary_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + size_t items_size, + tinytc_named_attr_t *items); + +/** + * @brief Get dictionary attribute with pre-sorted items + * + * The list of items must be sorted by name and each name must only appear once. 
+ * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param items_size [in] number of elements in items array, must be 0 if items == nullptr + * @param items [in][range(0, items_size)] array of items + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t +tinytc_dictionary_attr_get_with_sorted(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + size_t items_size, const tinytc_named_attr_t *items); + +/** + * @brief Sort items array by name + * + * @param items_size [in] number of elements in items array, must be 0 if items == nullptr + * @param items [in][range(0, items_size)] array of items + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_dictionary_attr_sort(size_t items_size, + tinytc_named_attr_t *items); + +/** + * @brief Get integer attribute + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param value [in] value of attribute + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_integer_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, int64_t value); + +/** + * @brief Get string attribute + * + * @param attr [out] pointer to the attribute object + * @param ctx [inout] compiler context + * @param str_length [in] number of characters (not including a null terminator) + * @param str [in] string; not necessarily null-terminated + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_string_attr_get(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + size_t str_length, char const *str); + +//////////////////////////// +///////// Data type //////// +//////////////////////////// + +// もち api_builder_h "tinytc/types.anko" + +/** + * @brief Get context object from type object + * + * The reference count of the context 
remains unchanged. + * + * @param ty [in] type object + * @param ctx [out] pointer to context object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_type_get_compiler_context(const_tinytc_type_t ty, + tinytc_compiler_context_t *ctx); + +//////////////////////////// +/////////// Value ////////// +//////////////////////////// + +/** + * @brief Set name of value + * + * @param vl [inout] value object + * @param name [in] name; null-terminated string + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char const *name); + +/** + * @brief Set name of value with explicit number of characters + * + * @param vl [inout] value object + * @param name_length [in] number of characters + * @param name [in] name; not necessarily null-terminated + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_set_name_n(tinytc_value_t vl, size_t name_length, + char const *name); + +/** + * @brief Get name of value + * + * The returned pointer may be invalidated if the value or any node in the abstract syntax + * tree referencing the value is modified. 
+ * + * @param vl [in] value object + * @param name [out] pointer to C string + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, char const **name); + +/** + * @brief Get type of value + * + * @param vl [in] value object + * @param ty [out] pointer to data type + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_value_get_type(const_tinytc_value_t vl, tinytc_type_t *ty); + +//////////////////////////// +/////// Instructions /////// +//////////////////////////// + +// もち api_builder_h "tinytc/instructions.anko" + +/** + * @brief Create boolean constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_boolean(tinytc_inst_t *instr, + tinytc_bool_t value, + tinytc_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create complex constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value_re [in] constant value (real part) + * @param value_im [in] constant value (imaginary part) + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_complex(tinytc_inst_t *instr, + double value_re, double value_im, + tinytc_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create floating constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can 
be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_float(tinytc_inst_t *instr, double value, + tinytc_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Create integer constant instruction + * + * @param instr [out] pointer to the inst object created + * @param value [in] constant value + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_int(tinytc_inst_t *instr, int64_t value, + tinytc_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Creates the multiplicative identity constant (i.e. "1") for the given data type + * + * @param instr [out] pointer to the inst object created + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, + tinytc_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Creates the additive identity constant (i.e. 
"0") for the given data type + * + * @param instr [out] pointer to the inst object created + * @param ty [in] type of constant + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, + tinytc_type_t ty, + const tinytc_location_t *loc); + +/** + * @brief Get parent region of instruction + * + * @param instr [in] inst object + * @param parent [out] parent region + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_parent_region(tinytc_inst_t instr, + tinytc_region_t *parent); + +/** + * @brief Get child regions of instruction + * + * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain + * the number of results + * + * @param instr [in] inst object + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greater than the number of + * results, the value is updated with the correct number of results + * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result + * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, size_t *result_list_size, + tinytc_region_t *result_list); + +/** + * @brief Get values produced by instruction + * + * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain + * the number of results + * + * @param instr [in] inst object + * @param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is 
greater than the number of + * results, the value is updated with the correct number of results + * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result + * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, size_t *result_list_size, + tinytc_value_t *result_list); + +/** + * @brief Set instruction attributes + * + * @param instr [inout] inst object + * @param a [in] attribute object (dictionary attribute) + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_inst_set_attr(tinytc_inst_t instr, tinytc_attr_t a); + +//////////////////////////// +////////// Region ////////// +//////////////////////////// + +/** + * @brief Append instruction to region + * + * The region takes ownership of the instruction. + * An instruction must not be added to multiple regions. + * + * @param reg [inout] region object + * @param instr [in,pass_ownership] instruction + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_append(tinytc_region_t reg, tinytc_inst_t instr); + +/** + * @brief Returns iterator pointing to begin of the region + * + * @param reg [in] region + * @param iterator [out] inst iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_begin(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator); + +/** + * @brief Returns iterator pointing to the end of the region + * + * The end iterator must not be dereferenced. 
+ * + * @param reg [in] region + * @param iterator [out] inst iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_end(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator); + +/** + * @brief Erase instruction at the position of the iterator + * + * The iterator is updated to point to the instruction coming after the iterator or the end iterator + * + * @param reg [inout] region object + * @param iterator [inout] iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_erase(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator); + +/** + * @brief Insert instruction at the position before the iterator + * + * The iterator is updated to point to the instruction that was just inserted. + * + * The region takes ownership of the instruction. + * An instruction must not be inserted into multiple regions. + * + * @param reg [inout] region object + * @param iterator [inout] iterator + * @param instr [in,pass_ownership] instruction + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_insert(tinytc_region_t reg, + tinytc_inst_iterator_t *iterator, + tinytc_inst_t instr); + +/** + * @brief Move iterator to the next instruction + * + * @param iterator [inout] iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_next_inst(tinytc_inst_iterator_t *iterator); + +/** + * @brief Move iterator to the previous instruction + * + * @param iterator [inout] iterator + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prev_inst(tinytc_inst_iterator_t *iterator); + +/** + * @brief Get region parameters + * + * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain + * the number of results + * + * @param reg [in] region object + * 
@param result_list_size [inout] pointer to the number of results; if result_list_size is 0, then + * it is updated with the number of results; if result_list_size is greater than the number of + * results, the value is updated with the correct number of results + * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result + * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, + size_t *result_list_size, + tinytc_value_t *result_list); + +//////////////////////////// +/////////// Func /////////// +//////////////////////////// + +/** + * @brief Create function + * + * Function takes ownership of region. + * + * @param fun [out] pointer to the func object created + * @param name_length [in] length of function_name + * @param name [in] function name + * @param num_params [in] number of parameters + * @param param_type_list [in][range(0,num_params)] parameter data types; can be nullptr if + * num_params is 0 + * @param ty [in] result type (must be void for host-callable function) + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_func_create(tinytc_func_t *fun, size_t name_length, + char const *name, size_t num_params, + const tinytc_type_t *param_type_list, + tinytc_type_t ty, const tinytc_location_t *loc); + +/** + * @brief Set function attributes + * + * @param fun [inout] function object + * @param a [in] attribute dictionary + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_func_set_attr(tinytc_func_t fun, tinytc_attr_t a); + +/** + * @brief Set parameter attributes + * + * @param fun [in] function object + * @param param_no [in] parameter number (0 
to num_parameters-1) + * @param a [in] attribute dictionary + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_func_set_parameter_attr(tinytc_func_t fun, size_t param_no, + tinytc_attr_t a); + +/** + * @brief Get function body + * + * @param fun [in] function object + * @param body [out] pointer to body region + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_func_get_body(tinytc_func_t fun, tinytc_region_t *body); + +//////////////////////////// +/////////// Prog /////////// +//////////////////////////// + +/** + * @brief Create program + * + * @param prg [out] pointer to the prog object created + * @param ctx [in] compiler context object + * @param loc [in][optional] Source code location; can be nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc); + +/** + * @brief Append function to program + * + * The program takes ownership of the function. + * A function must not be added to multiple programs nor must the user destroy the function after + * adding it to the program. 
+ * + * @param prg [inout] program object + * @param fun [in,pass_ownership] function object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun); + +#ifdef __cplusplus +} +#endif + +#endif // BUILDER_H_20250625 diff --git a/include/tinytc/builder.hpp.mochi b/include/tinytc/builder.hpp.mochi new file mode 100644 index 00000000..367dc64f --- /dev/null +++ b/include/tinytc/builder.hpp.mochi @@ -0,0 +1,1065 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef BUILDER_HPP_20250625 +#define BUILDER_HPP_20250625 + +#include "tinytc/builder.h" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +//////////////////////////// +///////// Attribute //////// +//////////////////////////// + +template struct getter; + +//! getter specialization for array_attr +template <> struct getter { + /** + * @brief Get array attribute + * + * @param ctx compiler context + * @param array attribute array + * + * @return Attribute + */ + inline auto operator()(tinytc_compiler_context_t ctx, array_view array) + -> tinytc_attr_t { + tinytc_attr_t a; + CHECK_STATUS(tinytc_array_attr_get(&a, ctx, array.size(), array.data())); + return a; + } +}; + +//! getter specialization for boolean_attr +template <> struct getter { + /** + * @brief Get boolean attribute + * + * @param ctx compiler context + * @param value boolean value + * + * @return Attribute + */ + inline auto operator()(tinytc_compiler_context_t ctx, bool value) -> tinytc_attr_t { + tinytc_attr_t a; + CHECK_STATUS(tinytc_boolean_attr_get(&a, ctx, value)); + return a; + } +}; + +//! getter specialization for dictionary_attr +template <> struct getter { + /** + * @brief Get dictionary attribute + * + * Each name must only appear once. 
+ * + * @param ctx compiler context + * @param items named items array + * + * @return Attribute + */ + inline auto operator()(tinytc_compiler_context_t ctx, + mutable_array_view items) -> tinytc_attr_t { + tinytc_attr_t a; + CHECK_STATUS(tinytc_dictionary_attr_get(&a, ctx, items.size(), items.data())); + return a; + } +}; + +/** + * @brief Get dictionary attribute + * + * The list of items must be sorted by name and each name must only appear once. + * + * @param ctx compiler context + * @param items named items array + * + * @return Attribute + */ +inline tinytc_attr_t get_dictionary_attr_with_sorted(tinytc_compiler_context_t ctx, + array_view items) { + tinytc_attr_t a; + CHECK_STATUS(tinytc_dictionary_attr_get_with_sorted(&a, ctx, items.size(), items.data())); + return a; +} + +/** + * @brief Sort list of items + * + * Each name must only appear once. + * + * @param items named items array + */ +inline void sort_items(mutable_array_view items) { + CHECK_STATUS(tinytc_dictionary_attr_sort(items.size(), items.data())); +} + +//! getter specialization for integer_attr +template <> struct getter { + /** + * @brief Get integer attribute + * + * @param ctx compiler context + * @param value integer value + * + * @return Attribute + */ + inline auto operator()(tinytc_compiler_context_t ctx, std::int64_t value) -> tinytc_attr_t { + tinytc_attr_t a; + CHECK_STATUS(tinytc_integer_attr_get(&a, ctx, value)); + return a; + } +}; + +//! 
getter specialization for string_attr +template <> struct getter { + /** + * @brief Get string attribute + * + * @param ctx compiler context + * @param str string + * + * @return Attribute + */ + inline auto operator()(tinytc_compiler_context_t ctx, std::string_view str) -> tinytc_attr_t { + tinytc_attr_t a; + CHECK_STATUS(tinytc_string_attr_get(&a, ctx, str.size(), str.data())); + return a; + } +}; + +//////////////////////////// +///////// Data type //////// +//////////////////////////// + +// もち api_builder_hpp "tinytc/types.anko" + +/** + * @brief Get type + * + * @param args Arguments forwarded to getter + * + * @return Type + */ +template inline auto get(Args &&...args) { + return getter{}(std::forward(args)...); +} + +/** + * Returns an appropriate tinytc type for C++ type T + * + * Specializations exist for bool, int8_t, int16_t, int32_t, int64_t, bfloat16, half, float, double, + * std::complex, std::complex + */ +template auto to_type(tinytc_compiler_context_t ctx) -> tinytc_type_t { + if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v) { + return get(ctx); + } else if constexpr (std::is_same_v>) { + return get(ctx); + } else if constexpr (std::is_same_v>) { + return get(ctx); + } else { + static_assert(false, "Not implemented"); + } +} + +/** + * @brief Get context + * + * @param ty type + * + * @return Compiler context + */ +inline auto get_compiler_context(const_tinytc_type_t ty) + -> shared_handle { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_type_get_compiler_context(ty, &ctx)); + return 
shared_handle{ctx, true}; +} + +//////////////////////////// +/////////// Value ////////// +//////////////////////////// + +/** + * @brief Get name + * + * @param val Value + * + * @return Name as C-string + */ +inline auto get_name(tinytc_value_t val) -> char const * { + char const *name; + CHECK_STATUS(tinytc_value_get_name(val, &name)); + return name; +} + +/** + * @brief Set value name + * + * @param val Value + * @param name Name + */ +inline void set_name(tinytc_value_t val, std::string_view name) { + CHECK_STATUS(tinytc_value_set_name_n(val, name.size(), name.data())); +} + +/** + * @brief Get type + * + * @param val Value + * + * @return Data type + */ +inline auto get_type(tinytc_value_t val) -> tinytc_type_t { + tinytc_type_t ty; + CHECK_STATUS(tinytc_value_get_type(val, &ty)); + return ty; +} + +//////////////////////////// +/////////// Inst /////////// +//////////////////////////// + +/** + * @brief Get result values + * + * May be called with empty view (vals = {}) to get the number of results. + * + * @param in Instruction + * @param vals view on buffer that stores results + * + * @return Minimum of view size and actual number of result values + */ +inline auto get_values(tinytc_inst_t in, mutable_array_view vals) -> std::size_t { + std::size_t result_list_size = vals.size(); + tinytc_value_t *vs = reinterpret_cast(vals.data()); + CHECK_STATUS(tinytc_inst_get_values(in, &result_list_size, vs)); + return result_list_size; +} + +/** + * @brief Get child regions + * + * May be called with empty view (regs = {}) to get the number of child regions. 
+ * + * @param in Instruction + * @param regs view on buffer that stores results + * + * @return Minimum of view size and actual number of child regions + */ +inline auto get_regions(tinytc_inst_t in, mutable_array_view regs) -> std::size_t { + std::size_t result_list_size = regs.size(); + tinytc_region_t *rl = reinterpret_cast(regs.data()); + CHECK_STATUS(tinytc_inst_get_regions(in, &result_list_size, rl)); + return result_list_size; +} + +/** + * @brief Set attribute + * + * @param in Instruction + * @param a attribute + */ +inline void set_attr(tinytc_inst_t in, tinytc_attr_t a) { + CHECK_STATUS(tinytc_inst_set_attr(in, a)); +} + +//////////////////////////// +/////// Instructions /////// +//////////////////////////// + +template struct creator; + +// もち api_builder_hpp "tinytc/instructions.anko" + +//! creator specialization for constant_inst +template <> struct creator { + constexpr static std::int32_t max_returned_values = 1; + /** + * @brief Create boolean constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ + inline auto operator()(bool value, tinytc_type_t ty, location const &loc = {}) + -> unique_handle { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_boolean(&instr, value, ty, &loc), loc); + return unique_handle{instr}; + } + + /** + * @brief Create complex constant + * + * @param value Complex constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ + inline auto operator()(std::complex value, tinytc_type_t ty, location const &loc = {}) + -> unique_handle { + tinytc_inst_t instr; + CHECK_STATUS_LOC( + tinytc_constant_inst_create_complex(&instr, value.real(), value.imag(), ty, &loc), loc); + return unique_handle{instr}; + } + + /** + * @brief Create floating constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ + inline auto 
operator()(double value, tinytc_type_t ty, location const &loc = {}) + -> unique_handle { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_float(&instr, value, ty, &loc), loc); + return unique_handle{instr}; + } + + /** + * @brief Create integer constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ + inline auto operator()(std::int32_t value, tinytc_type_t ty, location const &loc = {}) + -> unique_handle { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); + return unique_handle{instr}; + } + + /** + * @brief Create integer constant + * + * @param value Constant + * @param ty Data type + * @param loc Source code location + * + * @return Instruction + */ + inline auto operator()(std::int64_t value, tinytc_type_t ty, location const &loc = {}) + -> unique_handle { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_int(&instr, value, ty, &loc), loc); + return unique_handle{instr}; + } + + /** + * @brief Create multiplicative identity constant ("1") for the given data type + * + * @param ty Scalar data type + * @param loc Source code location + * + * @return Instruction + */ + inline auto one(tinytc_type_t ty, location const &loc = {}) -> unique_handle { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_one(&instr, ty, &loc), loc); + return unique_handle{instr}; + } + + /** + * @brief Create additive identity constant ("0") for the given data type + * + * @param ty Scalar data type + * @param loc Source code location + * + * @return Instruction + */ + inline auto zero(tinytc_type_t ty, location const &loc = {}) -> unique_handle { + tinytc_inst_t instr; + CHECK_STATUS_LOC(tinytc_constant_inst_create_zero(&instr, ty, &loc), loc); + return unique_handle{instr}; + } +}; + +/** + * @brief Create instruction + * + * @param args Arguments forwarded to creator + * + * @return Instruction 
+ */ +template +inline auto create(Args &&...args) -> unique_handle { + return creator{}(std::forward(args)...); +} + +//////////////////////////// +////////// Region ////////// +//////////////////////////// + +/** + * @brief Append instruction to region + * + * @param reg Region + * @param instruction instruction object + */ +inline void append(tinytc_region_t reg, unique_handle &&instruction) { + CHECK_STATUS(tinytc_region_append(reg, instruction.release())); +} + +/** + * @brief Get iterator pointing to the begin of the region + * + * @param reg Region + * + * @return iterator + */ +inline auto begin(tinytc_region_t reg) -> tinytc_inst_iterator_t { + tinytc_inst_iterator_t it; + CHECK_STATUS(tinytc_region_begin(reg, &it)); + return it; +} + +/** + * @brief Get iterator pointing to past the end of the region + * + * @param reg Region + * + * @return iterator + */ +inline auto end(tinytc_region_t reg) -> tinytc_inst_iterator_t { + tinytc_inst_iterator_t it; + CHECK_STATUS(tinytc_region_end(reg, &it)); + return it; +} + +/** + * @brief Erase instruction at iterator + * + * @param reg Region + * @param iterator Iterator + * + * @return Iterator pointing to the instruction after the one erased + */ +inline auto erase(tinytc_region_t reg, tinytc_inst_iterator_t iterator) -> tinytc_inst_iterator_t { + auto it = iterator; + CHECK_STATUS(tinytc_region_erase(reg, &it)); + return it; +} + +/** + * @brief Insert instruction into region before the iterator + * + * @param reg Region + * @param iterator Iterator + * @param instruction instruction object + * + * @return Iterator pointing to the newly inserted instruction + */ +inline auto insert(tinytc_region_t reg, tinytc_inst_iterator_t iterator, + unique_handle &&instruction) -> tinytc_inst_iterator_t { + auto it = iterator; + CHECK_STATUS(tinytc_region_insert(reg, &it, instruction.release())); + return it; +} + +/** + * + * @brief Get region parameters + * + * May be called with empty view (params = {}) to get the number of 
parameters. + * + * @param reg Region + * @param params view on buffer that stores parameters + * + * @return Minimum of view size and actual number of parameters + */ +inline auto get_parameters(tinytc_region_t reg, mutable_array_view params) + -> std::size_t { + std::size_t result_list_size = params.size(); + tinytc_value_t *ps = reinterpret_cast(params.data()); + CHECK_STATUS(tinytc_region_get_parameters(reg, &result_list_size, ps)); + return result_list_size; +} + +/** + * @brief Move iterator to next instruction + * + * @param iterator Iterator to advance; updated in place + */ +inline void next(tinytc_inst_iterator_t &iterator) { CHECK_STATUS(tinytc_next_inst(&iterator)); } +/** + * @brief Move iterator to previous instruction + * + * @param iterator Iterator to move back; updated in place + */ +inline void prev(tinytc_inst_iterator_t &iterator) { CHECK_STATUS(tinytc_prev_inst(&iterator)); } + +//////////////////////////// +/////////// Func /////////// +//////////////////////////// + +/** + * @brief Get function body + * + * @param f function + * + * @return Region + */ +inline auto get_body(tinytc_func_t f) -> tinytc_region_t { + tinytc_region_t body; + CHECK_STATUS(tinytc_func_get_body(f, &body)); + return body; +} + +/** + * @brief Create function + * + * @param name Function name + * @param param_type_list List of parameter types + * @param ty Result type (must be void for host-callable function) + * @param loc Source code location + * + * @return Function + */ +inline auto create_func(std::string_view name, array_view param_type_list, + tinytc_type_t ty, location const &loc = {}) + -> unique_handle { + tinytc_func_t fun; + CHECK_STATUS_LOC(tinytc_func_create(&fun, name.size(), name.data(), param_type_list.size(), + param_type_list.data(), ty, &loc), + loc); + return unique_handle(fun); +} + +/** + * @brief Set function attributes + * + * @param f function + * @param a attribute + */ +inline void set_attr(tinytc_func_t f, tinytc_attr_t a) { CHECK_STATUS(tinytc_func_set_attr(f, a)); } + +/** + * @brief Set attribute of function 
parameter + * + * @param f function + * @param param_no parameter number + * @param a attribute + */ +inline void set_parameter_attr(tinytc_func_t f, std::size_t param_no, tinytc_attr_t a) { + CHECK_STATUS(tinytc_func_set_parameter_attr(f, param_no, a)); +} + +//////////////////////////// +/////////// Prog /////////// +//////////////////////////// + +/** + * @brief Append function to program + * + * @param prg program + * @param fun function + */ +inline void add_function(tinytc_prog_t prg, unique_handle &&fun) { + CHECK_STATUS(tinytc_prog_add_function(prg, fun.release())); +} + +/** + * @brief Create program + * + * @param ctx Compiler context + * @param loc Source code location + * + * @return Program + */ +inline auto create_prog(tinytc_compiler_context_t ctx, location const &loc = {}) + -> shared_handle { + tinytc_prog_t prg; + CHECK_STATUS_LOC(tinytc_prog_create(&prg, ctx, &loc), loc); + return shared_handle{prg}; +} + +//////////////////////////// +////////// Builder ///////// +//////////////////////////// + +//! 
Builder for regions +class region_builder { + public: + /** + * @brief ctor; the insertion point is set to the end of the region + * + * @param reg region object + */ + region_builder(tinytc_region_t reg) : reg_{reg}, ip_{end(reg_)} {} + /** + * @brief ctor + * + * @param reg region object + * @param ip insertion point + */ + region_builder(tinytc_region_t reg, tinytc_inst_iterator_t ip) : reg_{reg}, ip_{ip} {} + + /** + * @brief Get insertion point + * + * @return Iterator + */ + inline auto get_insertion_point() const -> tinytc_inst_iterator_t { return ip_; } + + /** + * @brief Add instruction + * + * @param i Instruction + * + * @return Value returned by instruction; may be empty + */ + [[maybe_unused]] inline auto add(unique_handle &&i) -> tinytc_value_t { + auto result = tinytc_value_t{}; + get_values(i.get(), result); + insert(reg_, ip_, std::move(i)); + return result; + } + + /** + * @brief Add instruction that returns multiple values + * + * @param i Instruction + * + * @return Values returned by instruction + */ + [[maybe_unused]] inline auto add_multivalued(unique_handle &&i) + -> std::vector { + auto num_results = get_values(i.get(), {}); + auto results = std::vector(static_cast(num_results)); + results.resize(get_values(i.get(), results)); + insert(reg_, ip_, std::move(i)); + return results; + } + + /** + * @brief Create and add instruction + * + * @param args Arguments forwarded to creator + * + * @return Type is either void (max_returned_values==0), value (max_returned_values==1), + * or std::vector (max_returned_values>1); max_returned_values is a member of the creator + */ + template inline auto create(Args &&...args) { + auto i = creator{}(std::forward(args)...); + if constexpr (creator::max_returned_values > 1) { + return add_multivalued(std::move(i)); + } else if constexpr (creator::max_returned_values == 1) { + return add(std::move(i)); + } else { + insert(reg_, ip_, std::move(i)); + } + } + + /** + * @brief Create multiplicative identity constant ("1") for the given data type + * + * @param ty Scalar data type + * @param loc Source 
code location + * + * @return Value returned by instruction + */ + inline auto constant_one(tinytc_type_t ty, location const &loc = {}) -> tinytc_value_t { + return add(creator{}.one(ty, loc)); + } + /** + * @brief Create additive identity constant ("0") for the given data type + * + * @param ty Scalar data type + * @param loc Source code location + * + * @return Value returned by instruction + */ + inline auto constant_zero(tinytc_type_t ty, location const &loc = {}) -> tinytc_value_t { + return add(creator{}.zero(ty, loc)); + } + + /** + * @brief Build for-loop with functor f(region_builder&, value) -> void + * + * The loop trip count is passed as second argument to the functor. + * + * @tparam F Functor type + * @param from Loop variable start + * @param to Loop variable bound + * @param f Functor + * @param attributes For attributes + * @param loc Source code location + */ + template + void for_loop(tinytc_value_t from, tinytc_value_t to, F &&f, tinytc_attr_t attributes = nullptr, + location const &loc = {}) { + for_loop(std::move(from), std::move(to), nullptr, std::forward(f), attributes, loc); + } + /** + * @brief Build for-loop with functor f(region_builder&, value) -> void + * + * The loop trip count is passed as second argument to the functor. 
+ * + * @tparam F Functor type + * @param from Loop variable start + * @param to Loop variable bound + * @param step Loop variable step + * @param f Functor + * @param attributes For attributes + * @param loc Source code location + */ + template + void for_loop(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, F &&f, + tinytc_attr_t attributes = nullptr, location const &loc = {}) { + auto fi = creator{}(from, to, step, {}, {}, loc); + set_attr(fi.get(), attributes); + auto reg = tinytc_region_t{}; + get_regions(fi.get(), reg); + auto loop_var = tinytc_value_t{}; + get_parameters(reg, loop_var); + if (!reg || !loop_var) { + throw status::internal_compiler_error; + } + insert(reg_, ip_, std::move(fi)); + auto bb = region_builder{reg}; + f(bb, loop_var); + } + /** + * @brief Build for-loop with functor f(region_builder&, array_view) -> void + * + * The loop trip count is the first value in the array_view. + * The following values are the loop-carried values. + * + * @tparam F Functor type + * @param from Loop variable start + * @param to Loop variable bound + * @param step Loop variable step + * @param initial_value_list Array of initial values; can be {} + * @param return_type_list Array of return types; can be {} + * @param f Functor + * @param attributes For attributes + * @param loc Source code location + */ + template + auto for_loop(tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, + array_view initial_value_list, + array_view return_type_list, F &&f, + tinytc_attr_t attributes = nullptr, location const &loc = {}) + -> std::vector { + auto fi = creator{}(from, to, step, initial_value_list, return_type_list, loc); + set_attr(fi.get(), attributes); + auto reg = tinytc_region_t{}; + get_regions(fi.get(), reg); + auto num_params = get_parameters(reg, {}); + auto params = std::vector(num_params); + get_parameters(reg, params); + if (!reg || num_params != 1 + initial_value_list.size()) { + throw status::internal_compiler_error; + } + auto results 
= add_multivalued(std::move(fi)); + auto bb = region_builder{reg}; + f(bb, array_view(params)); + return results; + } + /** + * @brief Build foreach-loop with functor f(region_builder&, array_view) -> void + * + * @tparam F Functor type + * @param from Loop variable start list + * @param to Loop variable bound list + * @param f functor + * @param loc Source code location + */ + template + void foreach_loop(array_view from, array_view to, F &&f, + location const &loc = {}) { + auto fi = creator{}(std::move(from), std::move(to), loc); + auto reg = tinytc_region_t{}; + get_regions(fi.get(), reg); + auto num_params = get_parameters(reg, {}); + auto params = std::vector(num_params); + get_parameters(reg, params); + if (!reg || num_params != from.size() || num_params != to.size()) { + throw status::internal_compiler_error; + } + insert(reg_, ip_, std::move(fi)); + auto bb = region_builder{reg}; + f(bb, array_view(params)); + } + + /** + * @brief Build if with functor then(region_builder&) -> void + * + * Note: If the if instruction returns values then we must have a "yield" instruction in + * both the "then" and the "else" branch. So to return values use the "ifelse" function. 
+ * + * @tparam F Functor type + * @param condition Condition value + * @param then Then region functor + * @param loc Source code location + */ + template + void if_condition(tinytc_value_t condition, F &&then, location const &loc = {}) { + auto ii = creator{}(std::move(condition), {}, loc); + auto reg = tinytc_region_t{}; + get_regions(ii.get(), reg); + if (!reg) { + throw status::internal_compiler_error; + } + insert(reg_, ip_, std::move(ii)); + auto bb = region_builder{reg}; + then(bb); + } + /** + * @brief Build if/else with functors then(region_builder&) -> void and + * otherwise(region_builder&) -> void + * + * @tparam F "if" functor type + * @tparam G "else" functor type + * @param condition If condition + * @param then "if" functor + * @param otherwise "else" functor + * @param return_type_list List of types of returned values + * @param loc Source code location + * + * @return Returned values + */ + template + auto ifelse(tinytc_value_t condition, F &&then, G &&otherwise, + array_view return_type_list = {}, location const &loc = {}) + -> std::vector { + auto ii = creator{}(std::move(condition), return_type_list, loc); + std::array regs = {}; + get_regions(ii.get(), regs); + if (!regs[0] || !regs[1]) { + throw status::internal_compiler_error; + } + auto results = add_multivalued(std::move(ii)); + auto bb0 = region_builder{regs[0]}; + then(bb0); + auto bb1 = region_builder{regs[1]}; + otherwise(bb1); + return results; + } + + /** + * @brief Get region + * + * @return Region + */ + inline auto get_region() -> tinytc_region_t { return reg_; } + + private: + tinytc_region_t reg_; + tinytc_inst_iterator_t ip_; +}; + +//////////////////////////// +////////// Recipe ////////// +//////////////////////////// + +/** + * @brief Get program + * + * @param rec Recipe + * + * @return Program + */ +inline auto get_prog(const_tinytc_recipe_t rec) -> shared_handle { + tinytc_prog_t prg; + CHECK_STATUS(tinytc_recipe_get_prog(rec, &prg)); + return shared_handle{prg, true}; 
+} + +/** + * @brief Get binary + * + * @param rec Recipe + * + * @return Binary + */ +inline auto get_binary(const_tinytc_recipe_t rec) -> shared_handle { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_recipe_get_binary(rec, &bin)); + return shared_handle{bin, true}; +} + +/** + * @brief Get recipe + * + * @param handler Recipe handler + * + * @return Recipe + */ +inline auto get_recipe(const_tinytc_recipe_handler_t handler) -> shared_handle { + tinytc_recipe_t rec; + CHECK_STATUS(tinytc_recipe_handler_get_recipe(handler, &rec)); + return shared_handle{rec, true}; +} + +/** + * @brief Set kernel arguments + * + * @tparam T Scalar type; must match scalar_type passed to constructor + * @param handler Recipe handler + * @param howmany Batch size + * @param alpha @f$\alpha@f$ + * @param A Memory object used for A-matrix + * @param B Memory object used for B-matrix + * @param beta @f$\beta@f$ + * @param C Memory object used for C-matrix + */ +template +static void set_small_gemm_batched_args(tinytc_recipe_handler_t handler, std::int64_t howmany, + T alpha, mem A, mem B, T beta, mem C) { + CHECK_STATUS(tinytc_recipe_small_gemm_batched_set_args( + handler, howmany, sizeof(alpha), &alpha, static_cast(A.type), A.value, + static_cast(B.type), B.value, sizeof(beta), &beta, + static_cast(C.type), C.value)); +} + +/** + * @brief Create small GEMM batched recipe + * + * Cf. 
@ref tinytc_recipe_small_gemm_batched_create + * + * @param info Core info + * @param number_ty Number type of @f$\alpha@f$, A, B, @f$\beta@f$, C + * @param tA Operation applied on A + * @param tB Operation applied on B + * @param M Number of rows of A and C + * @param N Number of columns of B and C + * @param K Number of columns of A, number of rows of B + * @param ldA Leading dimension of an A matrix + * @param strideA Stride of A-matrices + * @param ldB Leading dimension of an B matrix + * @param strideB Stride of B-matrices + * @param ldC Leading dimension of an C matrix + * @param strideC Stride of C-matrices + * + * @return Small GEMM batched recipe + */ +inline auto create_small_gemm_batched(tinytc_core_info_t info, tinytc_type_t number_ty, + transpose tA, transpose tB, std::int64_t M, std::int64_t N, + std::int64_t K, std::int64_t ldA, std::int64_t strideA, + std::int64_t ldB, std::int64_t strideB, std::int64_t ldC, + std::int64_t strideC) -> shared_handle { + tinytc_recipe_t rec; + CHECK_STATUS(tinytc_recipe_small_gemm_batched_create( + &rec, info, number_ty, static_cast(tA), + static_cast(tB), M, N, K, ldA, strideA, ldB, strideB, ldC, strideC)); + return shared_handle{rec}; +} + +/** + * @brief Set kernel arguments + * + * @tparam T Scalar type; must match scalar_type passed to constructor + * @param handler Recipe handler + * @param M Number of rows of A and C + * @param alpha @f$\alpha@f$ + * @param A Memory object used for A-matrix + * @param ldA Leading dimension of A + * @param B Memory object used for B-matrix + * @param ldB Leading dimension of B + * @param beta @f$\beta@f$ + * @param C Memory object used for C-matrix + * @param ldC Leading dimension of C + */ +template +static void set_tall_and_skinny_args(tinytc_recipe_handler_t handler, std::int64_t M, T alpha, + mem A, std::int64_t ldA, mem B, std::int64_t ldB, T beta, + mem C, std::int64_t ldC) { + CHECK_STATUS(tinytc_recipe_tall_and_skinny_set_args( + handler, M, sizeof(alpha), &alpha, 
static_cast(A.type), A.value, ldA, + static_cast(B.type), B.value, ldB, sizeof(beta), &beta, + static_cast(C.type), C.value, ldC)); +} + +/** + * @brief Create tall and skinny recipe + * + * Cf. @ref tinytc_recipe_tall_and_skinny_create + * + * @param info Core info + * @param number_ty Number type of @f$\alpha@f$, A, B, @f$\beta@f$, C + * @param N Number of columns of B and C + * @param K Number of columns of A, number of rows of B + * @param M_block_size Chunk size for M-mode + * + * @return Tall and skinny recipe + */ +inline auto create_tall_and_skinny(tinytc_core_info_t info, tinytc_type_t number_ty, std::int64_t N, + std::int64_t K, std::int32_t M_block_size = 0) + -> shared_handle { + tinytc_recipe_t rec; + CHECK_STATUS(tinytc_recipe_tall_and_skinny_create(&rec, info, number_ty, N, K, M_block_size)); + return shared_handle{rec}; +} + +/** + * @brief Create tall and skinny recipe with additional specialization constants + * + * Cf. @ref tinytc_recipe_tall_and_skinny_create_specialized + * + * @param info Core info + * @param number_ty Number type of @f$\alpha@f$, A, B, @f$\beta@f$, C + * @param M Number of rows of A and C; can be dynamic + * @param N Number of columns of B and C + * @param K Number of columns of A, number of rows of B + * @param ldA Leading dimension of A; can be dynamic + * @param ldB Leading dimension of B; can be dynamic + * @param ldC Leading dimension of C; can be dynamic + * @param alignA [in] Memory alignment of A; can be 0 + * @param alignB [in] Memory alignment of B; can be 0 + * @param alignC [in] Memory alignment of C; can be 0 + * @param M_block_size Chunk size for M-mode + * + * @return Tall and skinny recipe + */ +inline auto create_tall_and_skinny_specialized(tinytc_core_info_t info, tinytc_type_t number_ty, + std::int64_t M, std::int64_t N, std::int64_t K, + std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, + std::int32_t alignA, std::int32_t alignB, + std::int32_t alignC, std::int32_t M_block_size = 0) + -> shared_handle 
{ + tinytc_recipe_t rec; + CHECK_STATUS(tinytc_recipe_tall_and_skinny_create_specialized( + &rec, info, number_ty, M, N, K, ldA, ldB, ldC, alignA, alignB, alignC, M_block_size)); + return shared_handle{rec}; +} + +} // namespace tinytc + +#endif // BUILDER_HPP_20250625 diff --git a/include/tinytc/core.h b/include/tinytc/core.h new file mode 100644 index 00000000..eb17ceaa --- /dev/null +++ b/include/tinytc/core.h @@ -0,0 +1,790 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CORE_20240409_H +#define CORE_20240409_H + +#include "tinytc/export.h" +#include "tinytc/types.h" +#include "tinytc/version.h" + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +//////////////////////////// +/////////// Error ////////// +//////////////////////////// + +#define TINYTC_CHECK_STATUS(X) \ + do { \ + tinytc_status_t status = X; \ + if (status != tinytc_status_success) { \ + return status; \ + } \ + } while (0) + +//////////////////////////// +////////// FP math ///////// +//////////////////////////// + +/** + * @brief Convert f32 number to bf16 number (represented as ushort) + * + * @param x f32 number + * + * @return bf16 number + */ +TINYTC_EXPORT uint16_t tinytc_f32_to_bf16_as_ui16(float x); + +/** + * @brief Convert bf16 number (represented as ushort) to f32 number + * + * @param x bf16 number + * + * @return f32 number + */ +TINYTC_EXPORT float tinytc_bf16_as_ui16_to_f32(uint16_t x); + +/** + * @brief Convert f32 number to f16 number (represented as ushort) + * + * @param x f32 number + * + * @return f16 number + */ +TINYTC_EXPORT uint16_t tinytc_f32_to_f16_as_ui16(float x); + +/** + * @brief Convert f16 number (represented as ushort) to f32 number + * + * @param x f16 number + * + * @return f32 number + */ +TINYTC_EXPORT float tinytc_f16_as_ui16_to_f32(uint16_t x); + +//////////////////////////// +/////////// Prog /////////// +//////////////////////////// + +/** + * @brief Get context object from program object 
+ * + * The reference count of the context remains unchanged. + * + * @param prg [in] program object + * @param ctx [out] pointer to context object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, + tinytc_compiler_context_t *ctx); + +//////////////////////////// +// Visitors and transforms / +//////////////////////////// + +/** + * @brief Dump program to stderr + * + * @param prg [in] program object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_dump(tinytc_prog_t prg); + +/** + * @brief Print program to file + * + * @param prg [in] program object + * @param filename [in] filename + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_file(tinytc_prog_t prg, char const *filename); + +/** + * @brief Print program to string + * + * The user is responsible to dispose the string with tinytc_string_destroy. + * + * @param prg [in] program object + * @param str [out] pointer to string + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_string(tinytc_prog_t prg, char **str); + +/** + * @brief Dump SPIR-V module to stderr + * + * @param mod [in] module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_dump(const_tinytc_spv_mod_t mod); + +/** + * @brief Print SPIR-V module to file + * + * @param mod [in] module + * @param filename [in] filename + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_print_to_file(const_tinytc_spv_mod_t mod, + char const *filename); + +/** + * @brief Print SPIR-V module to string + * + * The user is responsible to dispose the string with tinytc_string_destroy. 
+ * + * @param mod [in] module + * @param str [out] pointer to string + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_print_to_string(const_tinytc_spv_mod_t mod, + char **str); + +//////////////////////////// +//////// Device info /////// +//////////////////////////// + +/** + * @brief Create core_info for generic GPUs + * + * @param info [out] pointer to the core_info object created + * @param register_space [in] Size of register file per subgroup in bytes + * @param max_work_group_size [in] Maximum size of local work group + * @param sgs_size [in] Length of sgs array + * @param sgs [in] Allowed subgroup sizes + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_generic_create(tinytc_core_info_t *info, + int32_t register_space, + int32_t max_work_group_size, + size_t sgs_size, int32_t const *sgs); + +/** + * @brief Look up core info for Intel GPU architecture + * + * @param info [out] pointer to the core_info object created + * @param arch [in] IP version + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create_from_arch( + tinytc_core_info_t *info, tinytc_intel_gpu_architecture_t arch); + +/** + * @brief Look up core info for Intel GPU architecture + * + * @param info [out] pointer to the core_info object created + * @param name [in] architecture name + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create_from_name(tinytc_core_info_t *info, + char const *name); + +/** + * @brief Create core_info for Intel GPUs + * + * @param info [out] pointer to the core_info object created + * @param ip_version [in] IP version of architecture + * @param num_eus_per_subslice [in] Number of Execution Units (Xe Vector Engines) per subslice (Xe + * Core) + * @param num_threads_per_eu
[in] Number of threads per Execution Unit (Xe Vector Engine) + * @param sgs_size [in] Length of sgs array + * @param sgs [in] Allowed subgroup sizes + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create(tinytc_core_info_t *info, + uint32_t ip_version, + int32_t num_eus_per_subslice, + int32_t num_threads_per_eu, + size_t sgs_size, int32_t const *sgs); + +/** + * @brief Returns available subgroup sizes + * + * @param info [in] core info object + * @param sgs_size [out] pointer to number of subgroup sizes + * @param sgs [out] pointer to subgroup size array; pointer is invalidated when core info is deleted + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_subgroup_sizes(const_tinytc_core_info_t info, + size_t *sgs_size, + int32_t const **sgs); + +/** + * @brief Returns register space per subgroup in bytes + * + * @param info [in] core info object + * @param space [out] pointer to register space + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_register_space(const_tinytc_core_info_t info, + int32_t *space); + +/** + * @brief Set core features + * + * @param info [inout] core info object + * @param flags [in] set core features; must be 0 or a combination of tinytc_core_feature_flag_t + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_set_core_features(tinytc_core_info_t info, + tinytc_core_feature_flags_t flags); + +/** + * @brief Get core features + * + * @param info [in] core info object + * @param flags [out] pointer to core feature flags + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_core_features( + const_tinytc_core_info_t info, tinytc_core_feature_flags_t *flags); + +/** + * @brief Set 
SPIR-V feature + * + * @param info [inout] core info object + * @param feature [in] SPIR-V feature + * @param available [in] Set to true if feature is available and false otherwise + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_set_spirv_feature(tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t available); + +/** + * @brief Get SPIR-V feature + * + * @param info [in] core info object + * @param feature [in] SPIR-V feature + * @param available [out] Writes true to available if feature is available and false otherwise + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_have_spirv_feature(const_tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t *available); + +/** + * @brief Get default memref alignment + * + * @param info [in] Core info + * @param alignment [out] pointer to alignment in bytes + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_default_alignment(const_tinytc_core_info_t info, + int32_t *alignment); + +/** + * @brief Set default memref alignment + * + * @param info [inout] Core info + * @param alignment [in] alignment in bytes + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_set_default_alignment(tinytc_core_info_t info, + int32_t alignment); + +//////////////////////////// +////////// Parser ////////// +//////////////////////////// + +/** + * @brief Parse tensor language source file and create prog + * + * @param prg [out] pointer to prog object created + * @param filename [in] path to source file + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t
tinytc_parse_file(tinytc_prog_t *prg, char const *filename, + tinytc_compiler_context_t ctx); + +/** + * @brief Parse tensor language source from stdin and create prog + * + * @param prg [out] pointer to prog object created + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_compiler_context_t ctx); + +/** + * @brief Parse tensor language source from string and create prog + * + * @param prg [out] pointer to prog object created + * @param source_size [in] length of source string + * @param source [in] source string + * @param ctx [inout][optional] context object; a new context is created if ctx is nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_parse_string(tinytc_prog_t *prg, size_t source_size, + char const *source, + tinytc_compiler_context_t ctx); +/** + * @brief Create context + * + * The context stores the tensor language source and reports enhanced error messages with + * source code context. Moreover, the context caches data such as types and constants. + * + * @param ctx [out] pointer to the context object created + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_create(tinytc_compiler_context_t *ctx); + +/** + * @brief Add source context + * + * Manually add a source file to the context that can be referenced in a tinytc_location. + * Useful to enhance error messages when using the builder methods and classes.
+ * + * @param ctx [in] context object + * @param name [in] source name + * @param text [in] source text + * @param source_id [out] pointer to source id + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler_context_t ctx, + char const *name, char const *text, + int32_t *source_id); + +/** + * @brief Set error reporter + * + * Error reporting function that is called whenever an error occurs in the parser or the builder. + * + * @param ctx [inout] context object + * @param reporter [in] error reporting callback; set to nullptr to disable reporting + * @param user_data [in][optional] pointer to user data that is passed to the callback; can be + * nullptr + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_error_reporter( + tinytc_compiler_context_t ctx, tinytc_error_reporter_t reporter, void *user_data); + +/** + * @brief Sets an optimization flag + * + * The state can be 0 (disabled), 1 (enabled), or -1 (use default according to optimization level). 
+ * + * @param ctx [inout] context object + * @param flag [in] optimization flag + * @param state [in] flag state + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_set_optimization_flag( + tinytc_compiler_context_t ctx, tinytc_optflag_t flag, int32_t state); + +/** + * @brief Set optimization level (from 0 to 2) + * + * @param ctx [inout] context object + * @param level [in] optimization level + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t +tinytc_compiler_context_set_optimization_level(tinytc_compiler_context_t ctx, int32_t level); + +/** + * @brief Report an error and augment the error with source context + * + * @param ctx [in] context object + * @param location [in] source location + * @param what [in] error description + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_report_error( + tinytc_compiler_context_t ctx, const tinytc_location_t *location, char const *what); + +//////////////////////////// +///////// Compiler ///////// +//////////////////////////// + +/** + * @brief Run a function pass on every function of a program + * + * @param pass_name [in] name of function pass; cf. 
tinytc_list_function_passes + * @param prg [inout] tensor program; modified as compiler pass is run + * @param info [in][optional] core info object; might be nullptr if core info is not required for + * pass + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, + const_tinytc_core_info_t info); + +/** + * @brief List function passes + * + * @param names_size [out] pointer to number of function pass names + * @param names [out][range(0,names_size)] pointer to array of C-strings; array owned by tinytc + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_list_function_passes(size_t *names_size, + char const *const **names); + +/** + * @brief Compile tensor language to SPIR-V + * + * @param mod [out] pointer to the SPIR-V module created + * @param prg [inout] tensor program; modified as compiler passes are run + * @param info [in] core info object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_spv_mod_t *mod, tinytc_prog_t prg, + const_tinytc_core_info_t info); + +/** + * @brief Compile tensor language to SPIR-V and assemble + * + * @param bin [out] pointer to the binary object created + * @param prg [inout] tensor program; modified as compiler passes are run + * @param info [in] core info object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_compile_to_spirv_and_assemble( + tinytc_binary_t *bin, tinytc_prog_t prg, const_tinytc_core_info_t info); + +/** + * @brief Assemble SPIR-V module + * + * @param bin [out] pointer to the binary object created + * @param mod [in] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spirv_assemble(tinytc_binary_t *bin, + const_tinytc_spv_mod_t mod); + +/** + * @brief Create
binary + * + * @param bin [out] pointer to binary object + * @param ctx [in] compiler context + * @param format [in] Bundle format (SPIR-V or Native) + * @param data_size [in] Size of data in bytes + * @param data [in][range(0, data_size)] Binary data; data is copied + * @param core_features [in][optional] requested core features; must be 0 (default) or a + * combination of tinytc_core_feature_flag_t + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, + tinytc_compiler_context_t ctx, + tinytc_bundle_format_t format, size_t data_size, + uint8_t const *data, + tinytc_core_feature_flags_t core_features); + +/** + * @brief Get context object from binary object + * + * The reference count of the context remains unchanged. + * + * @param bin [in] binary object + * @param ctx [out] pointer to context object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_get_compiler_context(const_tinytc_binary_t bin, + tinytc_compiler_context_t *ctx); + +/** + * @brief Get raw binary data + * + * @param bin [in] binary object + * @param format [out] binary format + * @param data_size [out] size of data + * @param data [out] data array; returned pointer is invalidated if the binary object is deleted + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_get_raw(const_tinytc_binary_t bin, + tinytc_bundle_format_t *format, + size_t *data_size, uint8_t const **data); +/** + * @brief Get requested core features + * + * @param bin [in] binary object + * @param core_features [out] core features + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_get_core_features( + const_tinytc_binary_t bin, tinytc_core_feature_flags_t *core_features); + +//////////////////////////// +////////// Recipe ////////// 
+//////////////////////////// + +/** + * @brief Returns a small batched GEMM recipe + * + * The program contains a kernel for @f$\beta=0@f$ called "gemm_beta0" and a kernel for + * @f$\beta\neq 0@f$ called "gemm". All matrix shapes and strides are known at compile-time. + * + * The signature of the generated kernels gemm and gemm_beta0 is (if A and B are not transposed) + * + * @code + * func @{name}(%alpha: {ty.alpha}, + * %A: memref<{ty.A}x{M}x{K}x?,strided<1,{ldA},{strideA}>>, + * %B: memref<{ty.B}x{K}x{N}x?,strided<1,{ldB},{strideB}>>, + * %beta: {ty.beta}, + * %C: memref<{ty.C}x{M}x{N}x?,strided<1,{ldC},{strideC}>>) + * @endcode + * + * meaning that its kernels need arguments in the following order: + * + * @code + * alpha, A_ptr, howmany, B_ptr, howmany, beta, C_ptr, howmany + * @endcode + * + * @param recipe [out] pointer to the recipe object created + * @param info [in] core info object + * @param number_ty [in] Number type of alpha, A, B, beta, C + * @param tA [in] Transpose A + * @param tB [in] Transpose B + * @param M [in] Number of rows of A, C + * @param N [in] Number of columns of B, C + * @param K [in] Number of columns of A, number of rows of B + * @param ldA [in] Leading dimension of A + * @param strideA [in] Number of elements between A-matrices + * @param ldB [in] Leading dimension of B + * @param strideB [in] Number of elements between B-matrices + * @param ldC [in] Leading dimension of C + * @param strideC [in] Number of elements between C-matrices + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_create( + tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_type_t number_ty, + tinytc_transpose_t tA, tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, int64_t ldA, + int64_t strideA, int64_t ldB, int64_t strideB, int64_t ldC, int64_t strideC); + +/** + * @brief Set kernel arguments for small GEMM batched recipe + * + * @param handler [inout]
Recipe handler object + * @param howmany [in] Group size + * @param alpha_size [in] Size of alpha argument + * @param alpha_value [in] Pointer to data used for alpha; data is copied + * @param A_type [in] Type of memory object used for A-matrix + * @param A_value [in] Memory object used for A-matrix + * @param B_type [in] Type of memory object used for B-matrix + * @param B_value [in] Memory object used for B-matrix + * @param beta_size [in] Size of beta argument + * @param beta_value [in] Pointer to data used for beta; data is copied + * @param C_type [in] Type of memory object used for C-matrix + * @param C_value [in] Memory object used for C-matrix + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( + tinytc_recipe_handler_t handler, int64_t howmany, size_t alpha_size, const void *alpha_value, + tinytc_mem_type_t A_type, const void *A_value, tinytc_mem_type_t B_type, const void *B_value, + size_t beta_size, const void *beta_value, tinytc_mem_type_t C_type, const void *C_value); + +/** + * @brief Returns a tall and skinny recipe + * + * The program contains a kernel for beta = 0 called "gemm_beta0" and a kernel for beta != 0 + * called "gemm". M (= number of rows of A, C) and strides are dynamic. + * + * The signature of the generated kernels gemm and gemm_beta0 is + * + * @code + * func @{name}(%alpha: {ty.alpha}, + * %A: memref<{ty.A}x?x{K},strided<1,?>>, + * %B: memref<{ty.B}x{K}x{N},strided<1,?>>, + * %beta: {ty.beta}, + * %C: memref<{ty.C}x?x{N},strided<1,?>>) + * @endcode + * + * meaning that its kernels need arguments in the following order: + * + * @code + * alpha, A_ptr, M, ldA, B_ptr, ldB, beta, C_ptr, M, ldC + * @endcode + * + * where ldA, ldB, ldC is the size of stride[1] of A, B, C, respectively. 
+ * + * @param recipe [out] pointer to the recipe object created + * @param info [in] core info object + * @param number_ty [in] Number type of alpha, A, B, beta, C + * @param N [in] Number of columns of B, C + * @param K [in] Number of columns of A, number of rows of B + * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have + * the parameter auto-selected + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create(tinytc_recipe_t *recipe, + const_tinytc_core_info_t info, + tinytc_type_t number_ty, + int64_t N, int64_t K, + int32_t M_block_size); + +/** + * @brief Returns a tall and skinny recipe with additional specialization constants + * + * Similar to tinytc_recipe_tall_and_skinny_create but with the additional specialization + * constants M, ldA, ldB, and ldC. + * The specialization constants may be either set to a fixed value or to TINYTC_DYNAMIC. + * Note that if a specialization constant is set to a fixed value then the parameter with the + * same name in tinytc_recipe_tall_and_skinny_set_args is ignored. + * + * Furthermore, the memory alignment may be passed with alignA, alignB, and alignC or set to 0 to + * use the default memory alignment (= size of scalar type).
+ * + * The generated kernels have the following signature: + * + * @code + * func @{name}(%alpha: {ty.alpha}, + * %A: memref<{ty.A}x{M}x{K},strided<1,{ldA}>>, + * %B: memref<{ty.B}x{K}x{N},strided<1,{ldB}>>, + * %beta: {ty.beta}, + * %C: memref<{ty.C}x{M}x{N},strided<1,{ldC}>>) + * @endcode + * + * @param recipe [out] pointer to the recipe object created + * @param info [in] core info object + * @param number_ty [in] Number type of alpha, A, B, beta, C + * @param M [in] Number of rows of A, C; can be TINYTC_DYNAMIC + * @param N [in] Number of columns of B, C + * @param K [in] Number of columns of A, number of rows of B + * @param ldA [in] Leading dimension of A; can be TINYTC_DYNAMIC + * @param ldB [in] Leading dimension of B; can be TINYTC_DYNAMIC + * @param ldC [in] Leading dimension of C; can be TINYTC_DYNAMIC + * @param alignA [in] Memory alignment of A; can be 0 + * @param alignB [in] Memory alignment of B; can be 0 + * @param alignC [in] Memory alignment of C; can be 0 + * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have + * the parameter auto-selected + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( + tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_type_t number_ty, int64_t M, + int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t alignA, int32_t alignB, + int32_t alignC, int32_t M_block_size); + +/** + * @brief Suggest an M block size for tall and skinny recipe + * + * @param info [in] core info object + * @param M_block_size [out] pointer to block size + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_suggest_block_size( + const_tinytc_core_info_t info, int32_t *M_block_size); + +/** + * @brief Set kernel arguments for tall and skinny GEMM recipe + * + * @param handler [inout] Recipe handler object + * @param M [in] Size of M-mode + * @param
alpha_size [in] Size of alpha argument + * @param alpha_value [in] Pointer to data used for alpha; data is copied + * @param A_type [in] Type of memory object used for A-matrix + * @param A_value [in] Memory object used for A-matrix + * @param ldA [in] Leading dimension of A + * @param B_type [in] Type of memory object used for B-matrix + * @param B_value [in] Memory object used for B-matrix + * @param ldB [in] Leading dimension of B + * @param beta_size [in] Size of beta argument + * @param beta_value [in] Pointer to data used for beta; data is copied + * @param C_type [in] Type of memory object used for C-matrix + * @param C_value [in] Memory object used for C-matrix + * @param ldC [in] Leading dimension of C + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( + tinytc_recipe_handler_t handler, int64_t M, size_t alpha_size, const void *alpha_value, + tinytc_mem_type_t A_type, const void *A_value, int64_t ldA, tinytc_mem_type_t B_type, + const void *B_value, int64_t ldB, size_t beta_size, const void *beta_value, + tinytc_mem_type_t C_type, const void *C_value, int64_t ldC); + +/** + * @brief Get prog object + * + * The reference count of the prog remains unchanged. + * The user must call tinytc_prog_retain if the prog shall be used after the recipe was released. + * + * @param recipe [in] recipe object + * @param prg [out] pointer to prog object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recipe, + tinytc_prog_t *prg); + +/** + * @brief Get binary + * + * The reference count of the binary remains unchanged. + * The user must call tinytc_binary_retain if the binary shall be used after the recipe was + * released. 
+ * + * @param recipe [in] recipe object + * @param bin [out] pointer to binary + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_binary(const_tinytc_recipe_t recipe, + tinytc_binary_t *bin); + +/** + * @brief Get recipe object + * + * The reference count of the recipe remains unchanged. + * The user must call tinytc_recipe_retain if the recipe shall be used after the recipe handler was + * released. + * + * @param handler [in] recipe handler object + * @param recipe [out] pointer to recipe object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t +tinytc_recipe_handler_get_recipe(const_tinytc_recipe_handler_t handler, tinytc_recipe_t *recipe); + +#ifdef __cplusplus +} +#endif + +#endif // CORE_20240409_H diff --git a/include/tinytc/core.hpp b/include/tinytc/core.hpp new file mode 100644 index 00000000..7c1076c1 --- /dev/null +++ b/include/tinytc/core.hpp @@ -0,0 +1,1110 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CORE_20240403_HPP +#define CORE_20240403_HPP + +#include "tinytc/tinytc.h" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// For bit_cast, memcpy for C++ < 2020 +#if __cplusplus >= 202002L +#include +#else +#include +#endif + +namespace tinytc { + +//////////////////////////// +/////////// Error ////////// +//////////////////////////// + +//! Builder exception enhanced with location +class builder_error : public std::exception { + public: + //! ctor; taking location and status code + inline builder_error(status code, location const &loc) : code_(code), loc_(loc) {} + //! Get status code + inline auto code() const noexcept { return code_; } + //! Get location + inline auto loc() const noexcept -> location const & { return loc_; } + //! 
Get explanatory string + inline char const *what() const noexcept override { return to_string(code_); } + + private: + status code_; + location loc_; +}; + +//! Throw exception for unsuccessful call to C-API +inline void CHECK_STATUS_LOC(tinytc_status_t code, location const &loc) { + if (code != tinytc_status_success) { + throw builder_error(status{std::underlying_type_t(code)}, loc); + } +} + +//////////////////////////// +////////// FP math ///////// +//////////////////////////// + +/** + * @brief IEEE754 floating point format parameters + * + * @tparam ExponentBits Number of exponent bits + * @tparam MantissaBits Number of mantissa bits + */ +template struct ieee754_format { + constexpr static uint32_t exponent_bits = ExponentBits; ///< Number of exponent bits + constexpr static uint32_t mantissa_bits = MantissaBits; ///< Number of mantissa bits + //! Total number of bits + constexpr static uint32_t num_bits = 1 + exponent_bits + mantissa_bits; + //! Bias + constexpr static uint32_t bias = (1 << (exponent_bits - 1)) - 1; + //! Max exponent when encoded with bias added + constexpr static uint32_t max_biased_exponent = (1 << exponent_bits) - 1; + //! Bit mask for sign bit + constexpr static uint32_t sign_mask = 1 << (num_bits - 1); + //! Bit mask for exponent bits + constexpr static uint32_t exponent_mask = max_biased_exponent << mantissa_bits; + //! Bit mask for mantissa bits + constexpr static uint32_t mantissa_mask = (1 << mantissa_bits) - 1; + //! Number of bytes + constexpr static uint32_t num_bytes = 1 + (num_bits - 1) / 8; + //! Unsigned integer type large enough to store bit pattern + using bits_type = std::conditional_t< + num_bytes == 1, std::uint8_t, + std::conditional_t< + num_bytes == 2, std::uint16_t, + std::conditional_t>>>; +}; + +//! Floating point format for bf16 (bfloat16) +using bf16_format = ieee754_format<8, 7>; +//! Floating point format for f16 (half) +using f16_format = ieee754_format<5, 10>; +//! 
Floating point format for f32 (float) +using f32_format = ieee754_format<8, 23>; + +/** + * @brief Truncate high precision floating point number and return low precision floating point + * number + * + * @tparam F16f low precision floating point format + * @tparam F32f high precision floating point format + * @param x bit pattern of high precision number + * + * @return bit pattern of low precision number + */ +template +constexpr auto ieee754_truncate(typename F32f::bits_type x) -> F16f::bits_type { + using UI = F32f::bits_type; + using UITrunc = F16f::bits_type; + constexpr UI num_shift_bits = F32f::mantissa_bits - F16f::mantissa_bits; + auto const round_nearest_even_and_truncate = [](UI mantissa32) { + constexpr UI midpoint = (1 << num_shift_bits) / 2; + const UI bias = ((mantissa32 >> num_shift_bits) & 0x1) + (midpoint - 1); + return (mantissa32 + bias) >> num_shift_bits; + }; + + const UITrunc sign = (x & F32f::sign_mask) >> (F32f::num_bits - F16f::num_bits); + const UI exponent32 = (x & F32f::exponent_mask) >> F32f::mantissa_bits; + const UI mantissa32 = x & F32f::mantissa_mask; + + UITrunc exponent16 = 0; + UITrunc mantissa16 = 0; + if (exponent32 > F32f::bias + F16f::bias) { + exponent16 = F16f::max_biased_exponent; + // Map numbers except NaN to inf + if (exponent32 < F32f::max_biased_exponent) { + mantissa16 = 0; + } else { + // Need to ceil to make sure that NaN is not truncated to inf + mantissa16 = 1 + ((mantissa32 - 1) >> num_shift_bits); + } + } else if (F32f::bias == F16f::bias || exponent32 > F32f::bias - F16f::bias) { + // convert bias + // E_{32} = e + F32f::bias + // E_{16} = e + F16f::bias + // = E_{32} - F32f::bias + F16f::bias + // = E_{32} - (F32f::bias - F16f::bias) + exponent16 = exponent32 - (F32f::bias - F16f::bias); + mantissa16 = round_nearest_even_and_truncate(mantissa32); + } else if (exponent32 >= F32f::bias + 1 - F16f::bias - F16f::mantissa_bits) { + exponent16 = 0; + mantissa16 = round_nearest_even_and_truncate((mantissa32 | (1 << 
F32f::mantissa_bits)) >> + ((F32f::bias + 1 - F16f::bias) - exponent32)); + } + + exponent16 <<= F16f::mantissa_bits; + + // Need to add mantissa as it might overflow during rounding and then we need to increase the + // exponent by 1 + return (sign | exponent16) + mantissa16; +} + +/** + * @brief Extend low precision floating point number and return high precision floating point + * number + * + * @tparam F32f high precision floating point format + * @tparam F16f low precision floating point format + * @param x bit pattern of low precision number + * + * @return bit pattern of high precision number + */ +template +constexpr auto ieee754_extend(typename F16f::bits_type x) -> F32f::bits_type { + using UIExt = F32f::bits_type; + const UIExt sign = (x & F16f::sign_mask) << (F32f::num_bits - F16f::num_bits); + const UIExt exponent16 = (x & F16f::exponent_mask) >> F16f::mantissa_bits; + const UIExt mantissa16 = x & F16f::mantissa_mask; + + UIExt exponent32 = exponent16; + UIExt mantissa32 = mantissa16; + if (F32f::exponent_bits != F16f::exponent_bits) { + if (exponent16 == F16f::max_biased_exponent) { + // Inf and NaN + exponent32 = F32f::max_biased_exponent; + } else if (exponent16 != 0) { + // convert bias + // E_{16} = e + F16f::bias + // E_{32} = e + F32f::bias + // = E_{16} - F16f::bias + F32f::bias + // = E_{16} + (F32f::bias - F16f::bias) + exponent32 += F32f::bias - F16f::bias; + } + + // Subnormal f16 numbers must be represented as f32 normal numbers + if (exponent16 == 0 && mantissa16 != 0) { + UIExt shift_count = 0; + do { + mantissa32 <<= 1; + ++shift_count; + } while ((mantissa32 & (1 << F16f::mantissa_bits)) != (1 << F16f::mantissa_bits)); + mantissa32 &= F16f::mantissa_mask; + exponent32 = F32f::bias + 1 - F16f::bias - shift_count; + } + } + + // shift mantissa + mantissa32 <<= F32f::mantissa_bits - F16f::mantissa_bits; + + // shift exponent + exponent32 <<= F32f::mantissa_bits; + + return sign | exponent32 | mantissa32; +} + +/** + * @brief Low precision 
float type + * + * For all operations, low precision floats are converted to single precision, the operation is done in + * single precision, and then the result is stored in the low precision type + * + * @tparam T storage type + * @tparam F16f low precision floating point format + */ +template class lp_float { + public: + using lp_format = F16f; + + constexpr lp_float() = default; + + constexpr lp_float(lp_float const &) = default; + constexpr lp_float(lp_float &&) = default; + constexpr lp_float &operator=(lp_float const &) = default; + constexpr lp_float &operator=(lp_float &&) = default; + +#if __cplusplus >= 202002L +#define TINYTC_LPFLOAT_CONSTEXPR constexpr + //! construct from float + constexpr lp_float(float const &val) + : data_{ieee754_truncate(std::bit_cast(val))} {} + //! assign float + constexpr auto operator=(float const &rhs) -> lp_float & { return *this = lp_float{rhs}; } + //! implicit conversion to float + constexpr operator float() const { + auto bits = ieee754_extend(data_); + return std::bit_cast(bits); + } +#else +#define TINYTC_LPFLOAT_CONSTEXPR + //! construct from float + lp_float(float const &val) { + f32_format::bits_type bits; + memcpy(&bits, &val, sizeof(f32_format::bits_type)); + data_ = ieee754_truncate(bits); + } + //! assign float + auto operator=(float const &rhs) -> lp_float & { return *this = lp_float{rhs}; } + //! implicit conversion to float + operator float() const { + auto bits = ieee754_extend(data_); + float number; + memcpy(&number, &bits, sizeof(f32_format::bits_type)); + return number; + } +#endif + + //! Get bit representation + TINYTC_LPFLOAT_CONSTEXPR auto bits() const -> T { return data_; } + //! Construct lp_float from bit representation + constexpr static auto from_bits(T const &val) -> lp_float { + auto r = lp_float{}; + r.data_ = val; + return r; + } + + //! add + TINYTC_LPFLOAT_CONSTEXPR auto operator+(lp_float const &rhs) const -> lp_float { + return operator float() + static_cast(rhs); + } + //! 
add to + TINYTC_LPFLOAT_CONSTEXPR auto operator+=(lp_float const &rhs) -> lp_float & { + return *this = *this + rhs; + } + //! subtract + TINYTC_LPFLOAT_CONSTEXPR auto operator-(lp_float const &rhs) const -> lp_float { + return operator float() - static_cast(rhs); + } + //! subtract from + TINYTC_LPFLOAT_CONSTEXPR auto operator-=(lp_float const &rhs) -> lp_float & { + return *this = *this - rhs; + } + //! multiply + TINYTC_LPFLOAT_CONSTEXPR auto operator*(lp_float const &rhs) const -> lp_float { + return operator float() * static_cast(rhs); + } + //! multiply with + TINYTC_LPFLOAT_CONSTEXPR auto operator*=(lp_float const &rhs) -> lp_float & { + return *this = *this * rhs; + } + //! divide + TINYTC_LPFLOAT_CONSTEXPR auto operator/(lp_float const &rhs) const -> lp_float { + return operator float() / static_cast(rhs); + } + //! divide with + TINYTC_LPFLOAT_CONSTEXPR auto operator/=(lp_float const &rhs) -> lp_float & { + return *this = *this / rhs; + } + //! unary minus; const because it only reads the value, so it also works on const operands + TINYTC_LPFLOAT_CONSTEXPR auto operator-() const -> lp_float { return -operator float(); } + //! pre-increase by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator++() -> lp_float & { + return *this = operator float() + 1.0f; + } + //! post-increase by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator++(int) -> lp_float { + lp_float tmp = *this; + operator++(); + return tmp; + } + //! pre-decrease by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator--() -> lp_float & { + return *this = operator float() - 1.0f; + } + //! post-decrease by 1 + TINYTC_LPFLOAT_CONSTEXPR auto operator--(int) -> lp_float { + lp_float tmp = *this; + operator--(); + return tmp; + } + //! equal + TINYTC_LPFLOAT_CONSTEXPR auto operator==(lp_float const &rhs) const -> bool { + return operator float() == static_cast(rhs); + } + //! not equal + TINYTC_LPFLOAT_CONSTEXPR auto operator!=(lp_float const &rhs) const -> bool { + return operator float() != static_cast(rhs); + } + //! 
greater than + TINYTC_LPFLOAT_CONSTEXPR auto operator>(lp_float const &rhs) const -> bool { + return operator float() > static_cast(rhs); + } + //! greater than or equal + TINYTC_LPFLOAT_CONSTEXPR auto operator>=(lp_float const &rhs) const -> bool { + return operator float() >= static_cast(rhs); + } + //! less than + TINYTC_LPFLOAT_CONSTEXPR auto operator<(lp_float const &rhs) const -> bool { + return operator float() < static_cast(rhs); + } + //! less than or equal + TINYTC_LPFLOAT_CONSTEXPR auto operator<=(lp_float const &rhs) const -> bool { + return operator float() <= static_cast(rhs); + } + + private: + T data_; +}; + +/** + * @brief bf16 host emulation type + */ +using bfloat16 = lp_float; +/** + * @brief fp16 host emulation type + */ +using half = lp_float; + +//////////////////////////// +//////// Array view //////// +//////////////////////////// + +/** + * @brief Base implementation of array view + * + * @tparam T array element type + */ +template class array_view_base { + public: + using iterator = T *; + + /** + * @brief Empty array view + */ + array_view_base() = default; + + /** + * @brief Single element view + * + * @param single the single element + */ + array_view_base(T &single) : data_{&single}, size_{1} {} + + /** + * @brief ctor + * + * @param data base pointer + * @param size array size + */ + array_view_base(T *data, std::size_t size) : data_{data}, size_{size} {} + + /** + * @brief ctor + * + * @param begin begin pointer + * @param end end pointer (not included) + */ + array_view_base(T *begin, T *end) + : data_{begin}, size_{static_cast(end - begin)} {} + + //! Begin iterator + auto begin() const -> iterator { return data_; } + //! End iterator + auto end() const -> iterator { return data_ + size_; } + //! Returns true if view is empty + auto empty() const -> bool { return size_ == 0; } + //! Returns array size + auto size() const -> std::size_t { return size_; } + //! 
Access first element; must not call when array size is 0 + auto front() const -> T & { return data_[0]; } + //! Access last element; must not call when array size is 0 + auto back() const -> T & { return data_[size_ - 1]; } + //! Get data pointer + auto data() const -> T * { return data_; } + //! Access operator + auto operator[](std::size_t n) const -> T & { return data_[n]; } + //! Convert to vector + operator std::vector>() const { + return std::vector>(data_, data_ + size_); + } + //! Equals operator; views of different size compare unequal + auto operator==(array_view_base const &other) const -> bool { + // Compare sizes first: otherwise a shorter `other` is read out of bounds and a view + // wrongly equals any longer view that starts with the same elements + if (size_ != other.size_) { + return false; + } + for (std::size_t i = 0; i < size_; ++i) { + if (!(data_[i] == other.data_[i])) { + return false; + } + } + return true; + } + //! Not-equals operator + auto operator!=(array_view_base const &other) const -> bool { return !(*this == other); } + + private: + T *data_ = nullptr; + std::size_t size_ = 0; +}; + +/** + * @brief Stores an immutable view on an array (pointer + size) + * + * @tparam T array element type + */ +template class array_view : public array_view_base { + public: + using array_view_base::array_view_base; + + /** + * @brief Convert vector to array view + * + * @param vec standard vector + */ + array_view(std::vector const &vec) + : array_view_base{(!vec.empty() ? vec.data() : nullptr), vec.size()} {} + + /** + * @brief Convert std::array to array view + * + * @tparam N array size + * @param arr standard array + */ + template + array_view(std::array const &arr) : array_view_base{arr.data(), arr.size()} {} + + /** + * @brief Convert initializer list to array view (array_view must be rvalue) + * + * @param arr initializer list + */ + array_view(std::initializer_list const &arr) + : array_view_base{(arr.begin() != arr.end() ? 
arr.begin() : nullptr), arr.size()} { + } +}; + +template array_view(T const &) -> array_view; +template array_view(T const *, std::size_t) -> array_view; +template array_view(T const *, T const *) -> array_view; + +/** + * @brief Stores a mutable view on an array (pointer + size) + * + * @tparam T array element type + */ +template class mutable_array_view : public array_view_base { + public: + using array_view_base::array_view_base; + + /** + * @brief Convert vector to array view + * + * @param vec standard vector + */ + mutable_array_view(std::vector &vec) + : array_view_base{(!vec.empty() ? vec.data() : nullptr), vec.size()} {} + + /** + * @brief Convert std::array to array view + * + * @tparam N array size + * @param arr standard array + */ + template + mutable_array_view(std::array &arr) : array_view_base{arr.data(), arr.size()} {} +}; + +template mutable_array_view(T &) -> mutable_array_view; +template mutable_array_view(T *, std::size_t) -> mutable_array_view; +template mutable_array_view(T *, T *) -> mutable_array_view; + +//////////////////////////// +///////// Mem type ///////// +//////////////////////////// + +/** + * @brief Guess memory type of memory object + * + * @tparam T memory object type + */ +template struct auto_mem_type; + +/** + * @brief Check whether T maps to a scalar data type + * + * @tparam T type + */ +template +constexpr bool is_supported_scalar_type = std::is_same_v || // i8 + std::is_same_v || // i16 + std::is_same_v || // i32 + std::is_same_v || // i64 + std::is_same_v || // f32 + std::is_same_v || // f64 + std::is_same_v> || // c32 + std::is_same_v>; // c64 + +/** + * @brief True if T is either a pointer to a supported scalar type or a pointer to a pointer to a + * supported scalar type; void* is fine, too + * + * @tparam T type + */ +template +constexpr bool is_usm_pointer_type = + std::is_same_v || + (std::is_pointer_v && + (is_supported_scalar_type> || + is_supported_scalar_type>>)); + +/** + * @brief Specialize auto_mem_type for 
pointer to non-class types + * + * All pointers to scalars are assumed to be Unified Shared Memory pointers. + * (Automatic guessing for Shared Virtual Memory pointers not implemented.) + * + * @tparam T memory object type + */ +template struct auto_mem_type>> { + constexpr static mem_type value = mem_type::usm_pointer; ///< Pointer maps to USM pointer type +}; + +/** + * @brief Convenience wrapper for auto_mem_type + * + * @tparam T memory object type + */ +template inline constexpr auto auto_mem_type_v = auto_mem_type::value; + +//! Type-safe wrapper for memory objects +struct mem { + /** + * @brief ctor + * + * @tparam T pointer type or buffer type + * @param value USM / SVM pointer or cl_mem (cl_mem implicitly converts to void*) + * @param type memory object type + */ + template + inline mem(T const value, mem_type type = auto_mem_type_v) : value{value}, type{type} {} + + const void *value; ///< USM / SVM pointer or cl_mem (passed by value) + mem_type type; ///< Memory object type +}; + +//////////////////////////// +///// Compiler context ///// +//////////////////////////// + +/** + * @brief Add source to context + * + * @param ctx compiler context + * @param name File name + * @param text Source text + * + * @return Source id (should be set in position.source_id) + */ +inline auto add_source(tinytc_compiler_context_t ctx, char const *name, char const *text) + -> std::int32_t { + std::int32_t source_id; + CHECK_STATUS(tinytc_compiler_context_add_source(ctx, name, text, &source_id)); + return source_id; +} + +/** + * @brief Create compiler context + * + * @return Compiler context + */ +inline auto create_compiler_context() -> shared_handle { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_compiler_context_create(&ctx)); + return shared_handle{ctx}; +} + +/** + * @brief Set error reporter + * + * Error reporting function that is called whenever an error occurs in the parser or the + * builder. 
+ * + * @param ctx compiler context + * @param reporter error reporting callback + * @param user_data pointer to user data that is passed to the callback + */ +inline void set_error_reporter(tinytc_compiler_context_t ctx, tinytc_error_reporter_t reporter, + void *user_data = nullptr) { + CHECK_STATUS(tinytc_compiler_context_set_error_reporter(ctx, reporter, user_data)); +} + +/** + * @brief Sets an optimization flag + * + * The state can be 0 (disabled), 1 (enabled), or -1 (use default according to optimization + * level). + * + * @param ctx compiler context + * @param flag optimization flag + * @param state flag state + */ +inline void set_optimization_flag(tinytc_compiler_context_t ctx, optflag flag, std::int32_t state) { + CHECK_STATUS(tinytc_compiler_context_set_optimization_flag( + ctx, static_cast(flag), state)); +} +/** + * @brief Set optimization level + * + * @param ctx compiler context + * @param level optimization level + */ +inline void set_optimization_level(tinytc_compiler_context_t ctx, std::int32_t level) { + CHECK_STATUS(tinytc_compiler_context_set_optimization_level(ctx, level)); +} +/** + * @brief Enhance error message with compiler context; useful when builder is used + * + * @param ctx compiler context + * @param loc Source location + * @param what Error description + */ +inline void report_error(tinytc_compiler_context_t ctx, location const &loc, char const *what) { + CHECK_STATUS(tinytc_compiler_context_report_error(ctx, &loc, what)); +} + +//////////////////////////// +/////////// Prog /////////// +//////////////////////////// + +/** + * @brief Dump program to stderr + * + * @param p program + */ +inline void dump(tinytc_prog_t p) { CHECK_STATUS(tinytc_prog_dump(p)); } +/** + * @brief Get context + * + * @param p program + * + * @return Compiler context + */ +inline auto get_compiler_context(const_tinytc_prog_t p) + -> shared_handle { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_prog_get_compiler_context(p, &ctx)); + return 
shared_handle{ctx, true}; +} +/** + * @brief Dump program to file + * + * @param p program + * @param filename Path to file + */ +inline void print_to_file(tinytc_prog_t p, char const *filename) { + CHECK_STATUS(tinytc_prog_print_to_file(p, filename)); +} +/** + * @brief Dump program to string + * + * @param p program + * + * @return C-string (unique handle) + */ +inline auto print_to_string(tinytc_prog_t p) -> unique_handle { + char *str; + CHECK_STATUS(tinytc_prog_print_to_string(p, &str)); + return unique_handle{str}; +} + +//////////////////////////// +/////// SPIR-V Module ////// +//////////////////////////// + +/** + * @brief Dump module to stderr + * + * @param mod SPIR-V module + */ +inline void dump(const_tinytc_spv_mod_t mod) { CHECK_STATUS(tinytc_spv_mod_dump(mod)); } +/** + * @brief Dump module to file + * + * @param mod SPIR-V module + * @param filename Path to file + */ +inline void print_to_file(const_tinytc_spv_mod_t mod, char const *filename) { + CHECK_STATUS(tinytc_spv_mod_print_to_file(mod, filename)); +} +/** + * @brief Dump module to string + * + * @param mod SPIR-V module + * + * @return C-string (unique handle) + */ +inline auto print_to_string(const_tinytc_spv_mod_t mod) -> unique_handle { + char *str; + CHECK_STATUS(tinytc_spv_mod_print_to_string(mod, &str)); + return unique_handle{str}; +} + +//////////////////////////// +//////// Device info /////// +//////////////////////////// + +/** + * @brief Get subgroup sizes + * + * @param info Core info + * + * @return Subgroup sizes + */ +inline auto get_subgroup_sizes(const_tinytc_core_info_t info) -> array_view { + std::size_t sgs_size = 0; + std::int32_t const *sgs = nullptr; + CHECK_STATUS(tinytc_core_info_get_subgroup_sizes(info, &sgs_size, &sgs)); + return array_view(sgs, sgs_size); +} + +/** + * @brief Get register space per subgroup in bytes + * + * @param info Core info + * + * @return Register space + */ +inline auto get_register_space(const_tinytc_core_info_t info) -> std::int32_t { + 
std::int32_t space; + CHECK_STATUS(tinytc_core_info_get_register_space(info, &space)); + return space; +} + +/** + * @brief Set core features + * + * @param info Core info + * + * @param flags set core features; must be 0 or a combination of tinytc_core_feature_flag_t + */ +inline void set_core_features(tinytc_core_info_t info, tinytc_core_feature_flags_t flags) { + CHECK_STATUS(tinytc_core_info_set_core_features(info, flags)); +} + +/** + * @brief Get core features + * + * @param info Core info + * + * @return Core features + */ +inline auto get_core_features(const_tinytc_core_info_t info) -> tinytc_core_feature_flags_t { + tinytc_core_feature_flags_t flags; + CHECK_STATUS(tinytc_core_info_get_core_features(info, &flags)); + return flags; +} + +/** + * @brief Set SPIR-V feature + * + * @param info Core info + * @param feature SPIR-V feature + * @param available true if feature is available and false otherwise + */ +inline void set_spirv_feature(tinytc_core_info_t info, spirv_feature feature, bool available) { + CHECK_STATUS(tinytc_core_info_set_spirv_feature( + info, static_cast(feature), available)); +} + +/** + * @brief Get SPIR-V feature + * + * @param info Core info + * @param feature SPIR-V feature + * + * @return true if feature is available and false otherwise + */ +inline auto have_spirv_feature(const_tinytc_core_info_t info, spirv_feature feature) -> bool { + tinytc_bool_t available; + CHECK_STATUS(tinytc_core_info_have_spirv_feature( + info, static_cast(feature), &available)); + return available; +} + +/** + * @brief Get default alignment + * + * @param info Core info + * + * @return alignment in bytes + */ +inline auto get_default_alignment(const_tinytc_core_info_t info) -> std::int32_t { + std::int32_t alignment; + CHECK_STATUS(tinytc_core_info_get_default_alignment(info, &alignment)); + return alignment; +} + +/** + * @brief Set default alignment + * + * @param info Core info + * + * @param alignment alignment in bytes + */ +inline void 
set_default_alignment(tinytc_core_info_t info, std::int32_t alignment) { + CHECK_STATUS(tinytc_core_info_set_default_alignment(info, alignment)); +} + +/** + * @brief Create core info for generic GPUs manually + * + * @param register_space Size of register file per subgroup in bytes + * @param max_work_group_size Maximum size of local work group + * @param sgs Subgroup sizes + * + * @return Core info + */ +inline auto create_core_info_generic(std::int32_t register_space, std::int32_t max_work_group_size, + array_view sgs) + -> shared_handle { + tinytc_core_info_t info; + CHECK_STATUS(tinytc_core_info_generic_create(&info, register_space, max_work_group_size, + sgs.size(), sgs.data())); + return shared_handle{info}; +} + +/** + * @brief Get core info for Intel GPUs from lookup table + * + * @param arch IP version + * + * @return Core info + */ +inline auto create_core_info_intel_from_arch(intel_gpu_architecture arch) + -> shared_handle { + tinytc_core_info_t info; + CHECK_STATUS(tinytc_core_info_intel_create_from_arch( + &info, static_cast(arch))); + return shared_handle{info}; +} + +/** + * @brief Get core info for Intel GPUs from lookup table + * + * @param name architecture name + * + * @return Core info + */ +inline auto create_core_info_intel_from_name(char const *name) + -> shared_handle { + tinytc_core_info_t info; + CHECK_STATUS(tinytc_core_info_intel_create_from_name(&info, name)); + return shared_handle{info}; +} + +/** + * @brief Create core info for Intel GPUs manually + * + * @param ip_version IP version + * @param num_eus_per_subslice Number of EUs (XVEs) per subslice (XeCore) + * @param num_threads_per_eu Number of hardware threads per EU (XVE) + * @param sgs Subgroup sizes + * + * @return Core info + */ +inline auto create_core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, + std::int32_t num_threads_per_eu, array_view sgs) + -> shared_handle { + tinytc_core_info_t info; + CHECK_STATUS(tinytc_core_info_intel_create(&info, 
ip_version, num_eus_per_subslice, + num_threads_per_eu, sgs.size(), sgs.data())); + return shared_handle{info}; +} + +//////////////////////////// +////////// Parser ////////// +//////////////////////////// + +/** + * @brief Parse source text from file + * + * @param filename Filename + * @param ctx Compiler context + * + * @return Program + */ +inline auto parse_file(char const *filename, tinytc_compiler_context_t ctx = {}) + -> shared_handle { + tinytc_prog_t prg; + CHECK_STATUS(tinytc_parse_file(&prg, filename, ctx)); + return shared_handle{prg}; +} + +/** + * @brief Parse source text from stdin + * + * @param ctx Compiler context + * + * @return Program + */ +inline auto parse_stdin(tinytc_compiler_context_t ctx = {}) -> shared_handle { + tinytc_prog_t prg; + CHECK_STATUS(tinytc_parse_stdin(&prg, ctx)); + return shared_handle{prg}; +} +/** + * @brief Parse source text from string + * + * @param src Source text + * @param ctx Compiler context + * + * @return Program + */ +inline auto parse_string(std::string const &src, tinytc_compiler_context_t ctx = {}) + -> shared_handle { + tinytc_prog_t prg; + CHECK_STATUS(tinytc_parse_string(&prg, src.size(), src.c_str(), ctx)); + return shared_handle{prg}; +} + +//////////////////////////// +///////// Compiler ///////// +//////////////////////////// + +/** + * @brief Get compiler context + * + * @param bin Binary + * + * @return Compiler context + */ +inline auto get_compiler_context(const_tinytc_binary_t bin) + -> shared_handle { + tinytc_compiler_context_t ctx; + CHECK_STATUS(tinytc_binary_get_compiler_context(bin, &ctx)); + return shared_handle{ctx, true}; +} +/** + * @brief Get core features + * + * @param bin Binary + * + * @return Core features + */ +inline auto get_core_features(const_tinytc_binary_t bin) -> tinytc_core_feature_flags_t { + tinytc_core_feature_flags_t cf; + CHECK_STATUS(tinytc_binary_get_core_features(bin, &cf)); + return cf; +} + +//! 
Container for raw data +struct raw_binary { + bundle_format format; ///< Bundle format + std::size_t data_size; ///< Size of binary data in bytes + std::uint8_t const *data; ///< Pointer to binary data +}; + +/** + * @brief Get raw data + * + * @param bin Binary + * + * @return Raw data + */ +inline auto get_raw(tinytc_binary_t bin) -> raw_binary { + raw_binary r; + tinytc_bundle_format_t f; + CHECK_STATUS(tinytc_binary_get_raw(bin, &f, &r.data_size, &r.data)); + r.format = bundle_format{std::underlying_type_t(f)}; + return r; +} + +/** + * @brief Create binary + * + * @param ctx Compiler context + * @param format Bundle format (SPIR-V or Native) + * @param data_size Size of data in bytes + * @param data Binary data; data is copied + * @param core_features requested core features; must be 0 (default) or a combination of + * tinytc_core_feature_flag_t + * + * @return Binary + */ +inline auto create_binary(tinytc_compiler_context_t ctx, bundle_format format, + std::size_t data_size, std::uint8_t const *data, + tinytc_core_feature_flags_t core_features) + -> shared_handle { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_binary_create(&bin, ctx, static_cast(format), + data_size, data, core_features)); + return shared_handle{bin}; +} + +/** + * @brief Run a function pass on every function of a program + * + * @param pass_name name of function pass; cf. 
list_function_passes + * @param prg tensor program; modified as compiler pass is run + * @param info core info object; might be nullptr if core info is not required for pass + */ +inline void run_function_pass(char const *pass_name, tinytc_prog_t prg, + const_tinytc_core_info_t info = {}) { + CHECK_STATUS(tinytc_run_function_pass(pass_name, prg, info)); +} + +/** + * @brief Get function pass names + * + * @param names_size Number of function pass names + * @param names Array of function pass names + */ +inline void list_function_passes(std::size_t &names_size, char const *const *&names) { + CHECK_STATUS(tinytc_list_function_passes(&names_size, &names)); +} + +/** + * @brief Convert tensor language to SPIR-V + * + * @param prg Program + * @param info Core info + * + * @return SPIR-V module + */ +inline auto compile_to_spirv(tinytc_prog_t prg, const_tinytc_core_info_t info) + -> shared_handle { + tinytc_spv_mod_t mod; + CHECK_STATUS(tinytc_prog_compile_to_spirv(&mod, prg, info)); + return shared_handle{mod}; +} + +/** + * @brief Compile program to SPIR-V and assemble + * + * @param prg Program + * @param info Core info + * + * @return Binary + */ +inline auto compile_to_spirv_and_assemble(tinytc_prog_t prg, const_tinytc_core_info_t info) + -> shared_handle { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_prog_compile_to_spirv_and_assemble(&bin, prg, info)); + return shared_handle{bin}; +} + +/** + * @brief Assemble SPIR-V module + * + * @param mod [in] SPIR-V module + * + * @return Binary + */ +inline auto spirv_assemble(tinytc_spv_mod_t mod) -> shared_handle { + tinytc_binary_t bin; + CHECK_STATUS(tinytc_spirv_assemble(&bin, mod)); + return shared_handle{bin}; +} + +} // namespace tinytc + +namespace std { +template struct hash> { + size_t operator()(tinytc::lp_float const &val) const noexcept { + using h = hash::lp_format::bits_type>; + return h{}(val.bits()); + } +}; +} // namespace std + +#endif // CORE_20240403_HPP diff --git a/include/tinytc/enums.anko 
b/include/tinytc/enums.anko new file mode 100644 index 00000000..cc3fb321 --- /dev/null +++ b/include/tinytc/enums.anko @@ -0,0 +1,233 @@ +; Copyright (C) 2025 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +enum @status "Status codes" { + doc_to_string + case %success => 0x0 "Success" + case %bad_alloc => 0x1 "Bad allocation" + case %invalid_arguments => 0x2 "Invalid arguments passed to function" + case %out_of_range => 0x3 "Out of range" + case %runtime_error => 0x4 "General runtime error" + case %internal_compiler_error => 0x5 "Internal compiler error" + case %unsupported_subgroup_size => 0x6 "Device does not support subgroup size" + case %unsupported_work_group_size => 0x7 "Device does not support work-group size" + case %compilation_error => 0x8 "Compilation error" + case %file_io_error => 0x9 "I/O error occured in file operation" + case %parse_error => 0xa "Parse error" + case %unavailable_extension => 0xb "Required vendor extension is unavailable" + case %unsupported_backend => 0xc "Unsupported backend" + case %invalid_kernel_arguments => 0xd "Invalid arguments passed to kernel" + case %unsupported_device => 0xe "Unsupported device" + case %invalid_core_info => 0xf "Invalid core info object (e.g. 
max work group size is 0 or " + "subgroup sizes vector is empty)" + case %unknown_pass_name => 0x10 "Unknown compiler pass name" + case %not_implemented => 0x11 "Not implemented" + case %compute_runtime_error => 0x12 "Error occured in compute runtime" + ; IR errors + case %ir_out_of_bounds => 0x100 "Argument is out of bounds" + case %ir_invalid_shape => 0x101 "Invalid shape" + case %ir_incompatible_shapes => 0x102 "Incompatible tensor shapes" + case %ir_shape_stride_mismatch => 0x103 "Dimension of shape and stride must match" + case %ir_number_mismatch => 0x104 "Number type mismatch" + case %ir_invalid_number_of_indices => 0x105 "Number of indices must match memref order " + "or must be 1 for group types" + case %ir_expected_boolean => 0x106 "Expected boolean type" + case %ir_expected_number => 0x107 "Expected number type" + case %ir_expected_int => 0x108 "Expected integer type" + case %ir_expected_float => 0x109 "Expected floating point type" + case %ir_expected_complex => 0x10a "Expected complex type" + case %ir_expected_i32 => 0x10b "Expected i32 type" + case %ir_expected_index => 0x10c "Expected index type" + case %ir_expected_coopmatrix => 0x10d "Expected coopmatrix type" + case %ir_expected_coopmatrix_or_number => 0x10e "Expected coopmatrix or number type" + case %ir_expected_coopmatrix_number_or_boolean => 0x10f "Expected coopmatrix type, number type, " + "or boolean type" + case %ir_expected_memref => 0x110 "Expected memref type" + case %ir_expected_memref_or_number => 0x111 "Expected memref or number type" + case %ir_expected_memref_or_group => 0x112 "Expected memref or group type" + case %ir_expected_memref_order_0 => 0x113 "Expected memref of order 0 (scalar)" + case %ir_expected_memref_order_1 => 0x114 "Expected memref of order 1 (vector)" + case %ir_expected_memref_order_2 => 0x115 "Expected memref of order 2 (matrix)" + case %ir_expected_memref_order_0_or_1 => 0x116 "Expected memref of order 0 or 1 (scalar or vector)" + case 
%ir_expected_memref_order_1_or_2 => 0x117 "Expected memref of order 1 or 2 (vector or matrix)" + case %ir_expected_memref_order_0_1_or_2 => 0x118 "Expected memref of order 0, 1, or 2 " + "(scalar, vector, or matrix)" + case %ir_unexpected_yield => 0x119 "Yield encountered in non-yielding region" + case %ir_yield_mismatch => 0x11a "Number of yielded values does not match number of " + "values yielded by region or the types are different" + case %ir_subview_mismatch => 0x11b "Number of dynamic offsets and sizes must match " + "number of dynamic operands" + case %ir_invalid_slice => 0x11c "Static offset and size must be non-negative or " + "dynamic (?)" + case %ir_expand_shape_order_too_small => 0x11d "Expand shape must have at least 2 entries" + case %ir_expand_shape_mismatch => 0x11e "Number of dynamic expand shape operands must equal " + "number of dynamic modes in static expand shape" + case %ir_collective_called_from_spmd => 0x11f "Collective instruction must not be called from " + "SPMD region" + case %ir_fp_unsupported => 0x120 "Floating point type unsupported by instruction" + case %ir_spmd_called_from_collective => 0x121 "SPMD instruction must not be called from collective " + "region" + case %ir_expected_local_address_space => 0x122 "Expected memref with local address space" + case %ir_expected_global_address_space => 0x123 "Expected memref with global address space" + case %ir_address_space_mismatch => 0x124 "Address space must match" + case %ir_invalid_offset => 0x125 "Offset must be non-negative or dynamic" + case %ir_int_unsupported => 0x126 "Instruction does not support int type" + case %ir_boolean_unsupported => 0x127 "Instruction does not support boolean type" + case %ir_complex_unsupported => 0x128 "Instruction does not support complex type" + case %ir_coopmatrix_unsupported => 0x129 "Instruction does not support coopmatrix type" + case %ir_forbidden_cast => 0x12a "Forbidden cast" + case %ir_invalid_beta => 0x12b "beta must be constant and 0 or 1 for 
atomic linear " + "algebra operations" + case %ir_init_return_mismatch => 0x12c "The number or types of the initial values does not " + "match the return type list" + case %ir_invalid_matrix_use => 0x12d "Operands have invalid matrix use" + case %ir_unsupported_coopmatrix_shape => 0x12e "Unsupported coopmatrix shape for the combination of " + "Number type, matrix use, and target architecture" + case %ir_forbidden_promotion => 0x12f "Number type promotion is forbidden" + case %ir_constant_mismatch => 0x130 "Type of constant does not match type or returned " + "value" + case %ir_insufficient_alignment => 0x131 "Pointer does not satisfy minimum alignment " + "requirements" + case %ir_must_have_yield => 0x132 "Last instruction of region that returns values must " + "be yield" + case %ir_yield_in_else_branch_missing => 0x133 "Else-branch must have yield instruction if " + "then-branch has yield instruction" + case %ir_from_to_mismatch => 0x134 "length(from) != length(to) in foreach" + case %ir_operand_type_must_match_return_type => 0x135 "Operand type must match return type" + case %ir_invalid_stride => 0x136 "Invalid stride" + case %ir_init_return_type_mismatch => 0x137 "Type of initializer does not match return type or " + "the number of return types is not equal the number " + "of initializers" + case %ir_value_still_has_uses => 0x138 "A value shall be erased that still has uses" + case %ir_expected_array_attribute => 0x139 "Expected array attribute" + case %ir_expected_boolean_attribute => 0x140 "Expected boolean attribute" + case %ir_expected_dictionary_attribute => 0x141 "Expected dictionary attribute" + case %ir_expected_integer_attribute => 0x142 "Expected integer attribute" + case %ir_expected_string_attribute => 0x143 "Expected string attribute" + case %ir_duplicate_key_in_dictionary => 0x144 "Duplicate key detected in list of named attributes" + case %ir_unexpected_array_attribute_size => 0x145 "Unexpected array size" + case %ir_expected_non_scalar_memref => 
0x146 "Expected memref of dimension greater or equal than 1" + case %ir_complex_number_type_unsupported => 0x147 "Complex number type not supported" + ; SPIR-V errors + case %spirv_forbidden_forward_declaration => 0x1000 "Forward declaration of id is forbidden" + case %spirv_undefined_value => 0x1001 "Undefined SPIR-V value" + case %spirv_missing_dope_vector => 0x1002 "Dope vector missing" + case %spirv_unsupported_atomic_data_type => 0x1003 "Atomic data type unsupported by SPIR-V" + case %spirv_required_feature_unavailable => 0x1004 "A required SPIR-V feature is unavailable" + ; The unknown error comes last + case %unknown => 0x7fffffff "Unknown status code" +} + +enum @comp3 "Named components of 3d vector" { + case %x => 0 ".x" + case %y => 1 ".y" + case %z => 2 ".z" +} + +enum @reduce_mode "Reduce mode" { + case %row => 0 "Reduction over rows" + case %column => 1 "Reduction over columns" +} + +enum @transpose "Transpose" { + doc_to_string + case %N => 0 "n" + case %T => 1 "t" +} + +enum @address_space "Address space" { + case %global => 0x1 "Global memory" + case %local => 0x2 "Local memory, returned by alloca" +} + +enum @checked_flag +"@brief Checked flag + +Checks can be combined by bitwise or, that is, + +tinytc_checked_flag_both = tinytc_checked_flag_rows | tinytc_checked_flag_cols +tinytc_checked_flag_rows = tinytc_checked_flag_rows | tinytc_checked_flag_none" +{ + case %none => 0x0 "Perform no checks" + case %rows => 0x1 "Check for out-of-bound rows" + case %cols => 0x2 "Check for out-of-bound cols" + case %both => 0x3 "Check for out-of-bound rows and cols" +} + +enum @store_flag "Store flag" { + case %regular => 0 "Non-atomic store" + case %atomic => 1 "Atomic store" + case %atomic_add => 2 "Atomic fetch add" + case %atomic_max => 3 "Atomic fetch max" + case %atomic_min => 4 "Atomic fetch min" +} + +enum @matrix_use "Matrix use" { + doc_to_string + case %a => 0 "matrix_a" + case %b => 1 "matrix_b" + case %acc => 2 "matrix_acc" +} + +enum @spirv_feature 
"SPIR-V features" { + case %float16 => 0 "f16 support" + case %float64 => 1 "f64 support" + case %int64_atomics => 2 "i64 atomics support" + case %groups => 3 "work group collectives" + case %subgroup_dispatch => 4 "subgroup support" + case %atomic_float16_add_local => 5 "f16 atomic add on local pointer" + case %atomic_float16_add_global => 6 "f16 atomic add on global pointer" + case %atomic_float32_add_local => 7 "f32 atomic add on local pointer" + case %atomic_float32_add_global => 8 "f32 atomic add on global pointer" + case %atomic_float64_add_local => 9 "f64 atomic add on local pointer" + case %atomic_float64_add_global => 10 "f64 atomic add on global pointer" + case %atomic_float16_min_max_local => 11 "f16 atomic min/max on local pointer" + case %atomic_float16_min_max_global => 12 "f16 atomic min/max on global pointer" + case %atomic_float32_min_max_local => 13 "f32 atomic min/max on local pointer" + case %atomic_float32_min_max_global => 14 "f32 atomic min/max on global pointer" + case %atomic_float64_min_max_local => 15 "f64 atomic min/max on local pointer" + case %atomic_float64_min_max_global => 16 "f64 atomic minmax on global pointer" + case %bfloat16_conversion => 17 "bf16 -> f32 and f32 -> bf16 conversion" + case %subgroup_buffer_block_io => 18 "subgroup block read/write support" +} + +enum @core_feature_flag "Core features that may be optionally enabled" { + case %large_register_file => 0x1 "Request a large register file. " + "On PVC this doubles the number of registers per vector engine " + "but halves the number of available hardware threads. " + "When this feature is activated, the kernel is compiled with " + "the -ze-opt-large-register-file option." 
+} + +enum @intel_gpu_architecture +"@brief IP versions for Intel GPUs + +Note: IP versions are extracted from +* https://github.com/intel/compute-runtime/blob/4b5d5f235abf0ff67c9188f8096afd4da2e0574d/third_party/aot_config_headers/platforms.h +* https://github.com/intel/llvm/blob/56e9067ba69809fb6ea1fd4328456ca3a009f984/sycl/source/detail/device_info.hpp#L619" +{ + case %tgl => 0x03000000 "Tiger Lake" + case %pvc => 0x030f0000 "Ponte Vecchio" + case %bmg => 0x05004000 "Battlemage" +} + +enum @bundle_format "Target binary format" { + case %spirv => 0 "SPIR-V" + case %native => 1 "Native device binary" +} + +enum @optflag "Flags for optimizer" { + case %unsafe_fp_math => 0 "Unsafe floating point math (e.g. 0.0 * x => 0.0)" +} + +enum @mem_type "Memory object type" { + case %buffer => 0x0 "Buffer object (e.g. cl_mem)" + case %usm_pointer => 0x1 "Unified shared memory pointer" + case %svm_pointer => 0x2 "Shared virtual memory pointer" +} + +enum @support_level "Support level of a device" { + case %none => 0x0 "Device is unsupported (e.g. 
subgroups feature missing in OpenCL-C)" + case %basic => 0x1 "Device provides necessary features but is not well tested" + case %tuned => 0x2 "Device provides necessary features and is well tested" +} diff --git a/include/tinytc/instructions.anko b/include/tinytc/instructions.anko new file mode 100644 index 00000000..849a5d6f --- /dev/null +++ b/include/tinytc/instructions.anko @@ -0,0 +1,382 @@ +; Copyright (C) 2025 Intel Corporation +; SPDX-License-Identifier: BSD-3-Clause + +include "tinytc/enums.anko" + +inst @alloca "Alloca instruction" { + collective + prop %stack_ptr private => i64 + ret %result "memref type of allocated variable" +} + +inst @barrier "Barrier instruction" { + prop %fence_flags => "tinytc_address_spaces_t" "address space(s) of memory fence; " + "set to 0 for no fence" + cxx "auto has_fence(address_space as) -> bool;" +} + +inst @cast "Cast instruction" { + op %a "operand" + ret %result "target type" +} + +inst @constant "Constant instruction" { + skip_builder + prop %value => "constant_value_type" "constant value" + ret %result "type of constant" + cxx "auto is_zero() -> bool;" + cxx "auto is_identity() -> bool;" +} + +inst @cooperative_matrix_apply "Cooperative matrix apply instruction" { + spmd + op %a "matrix" + reg %body "instructions to apply to coopmatrix entries" + ret %result "result type" + cxx "inline auto row() -> tinytc_value & { return body().param(0); }" + cxx "inline auto col() -> tinytc_value & { return body().param(1); }" + cxx "inline auto val() -> tinytc_value & { return body().param(2); }" +} + +inst @cooperative_matrix_extract "Cooperative matrix extract instruction" { + spmd + prop %index => i64 "index" + op %mat "matrix" + ret %result "result type" +} + +inst @cooperative_matrix_insert "Cooperative matrix insert instruction" { + spmd + prop %index => i64 "index" + op %val "value to insert" + op %mat "matrix" + ret %result "result type" +} + +inst @cooperative_matrix_load "Cooperative matrix load instruction" { + spmd + 
prop %t => @transpose "transposed load" + prop %checked => @checked_flag "boundary check" + op %operand "matrix" + op %pos0 "row offset" + op %pos1 "column offset" + ret %result "result type" +} + +inst @cooperative_matrix_mul_add "Cooperative matrix mul add instruction" { + spmd + op %a "a matrix" + op %b "b matrix" + op %c "c matrix" + ret %result "result type" + cxx "auto is_c_zero() -> bool;" +} + +inst @cooperative_matrix_prefetch "Cooperative matrix prefetch instruction" { + spmd + prop %cache_level => i32 "cache level; 0 is closest to the core" + prop %rows => i32 "number of rows" + prop %cols => i32 "number of columns" + op %operand "matrix" + op %pos0 "row offset" + op %pos1 "column offset" +} + +inst @cooperative_matrix_reduce "Cooperative matrix reduce instruction" { + prop %mode => @reduce_mode "reduce mode" + op %a "matrix" + ret %result "result type" +} +inst @cooperative_matrix_reduce_add : @cooperative_matrix_reduce "Cooperative matrix reduce add instruction" { spmd } +inst @cooperative_matrix_reduce_max : @cooperative_matrix_reduce "Cooperative matrix reduce max instruction" { spmd } +inst @cooperative_matrix_reduce_min : @cooperative_matrix_reduce "Cooperative matrix reduce min instruction" { spmd } + +inst @cooperative_matrix_scale "Cooperative matrix scale instruction" { + spmd + op %a "scalar" + op %b "matrix" + ret %result "result type" +} + +inst @cooperative_matrix_store "Cooperative matrix store instruction" { + spmd + prop %t => @transpose "transposed store" + prop %checked => @checked_flag "boundary check" + prop %flag => @store_flag "store mode" + op %val "value to store" + op %operand "matrix" + op %pos0 "row offset" + op %pos1 "column offset" +} + + +inst @expand "Expand instruction" { + prop %expanded_mode => i64 "expanded mode" + prop* %static_expand_shape => i64 "static expand shape" + op %operand "tensor" + op* %expand_shape "dynamic expand shape" + ret %result "result type" +} + +inst @fuse "Fuse instruction" { + prop %from => i64 
"first mode to fuse" + prop %to => i64 "last mode to fuse" + op %operand "tensor" + ret %result "result type" +} + +inst @if "If instruction"{ + op %condition "condition" + reg %then "then block" + reg %otherwise "else block" + ret* %results "return type array" + cxx "auto is_otherwise_empty() -> bool;" +} + +inst @lifetime_stop "Lifetime stop instruction" { + collective + op %object "stack object whose lifetime ends" +} + + +inst @load "Load instruction" { + op %operand "tensor or group of tensors" + op* %index_list "indices" + ret %result "result type" +} + +inst @parallel "Parallel instruction" { + collective + reg %body "parallel region" +} + +inst @size "Size instruction" { + prop %mode => i64 "mode for which size is extracted" + op %operand "tensor" + ret %result "result type" +} + +inst @subgroup_broadcast "Subgroup broadcast instruction" { + spmd + op %a "operand" + op %idx "subgroup local index" + ret %result "result index" +} + +inst @subview "Subview instruction" { + prop* %static_offsets => i64 "static offsets (need to add value to dynamic offsets " + "if static_offsets[i] == TINYTC_DYNAMIC)" + prop* %static_sizes => i64 "static sizes (need to add value to dynamic sizes " + "if static_sizes[i] == TINYTC_DYNAMIC)" + op %operand "operand" + op* %offsets "dynamic offsets" + op* %sizes "dynamic sizes" + ret %result "resulting memref type" +} + +inst @store "Store instruction" { + prop %flag => @store_flag "store flag" + op %val "value to store" + op %operand "operand" + op* %index_list "indices" +} + +inst @yield "Yield instruction" { + op* %yielded_vals "yielded values" +} + + +; Binary arithmetic + +inst @arith "Arithmetic instruction (binary)" { + op %a "left-hand operand" + op %b "right-hand operand" + ret %result "result type" + cxx "void setup_and_check(support_flags support);" +} +inst @add : @arith "Addition instruction" {} +inst @sub : @arith "Subtraction instruction" {} +inst @mul : @arith "Multiplication instruction" {} +inst @div : @arith 
"Division instruction" {} +inst @rem : @arith "Remainder instruction" {} +inst @max : @arith "Max instruction" {} +inst @min : @arith "Min instruction" {} +inst @shl : @arith "Shift-left instruction" {} +inst @shr : @arith "Shift-right instruction" {} +inst @and : @arith "Bitwise and instruction" {} +inst @or : @arith "Bitwise or instruction" {} +inst @xor : @arith "Bitwise xor instruction" {} + + +; Unary arithmetic + +inst @arith_unary "Arithmetic instruction (unary)" { + op %a "operand" + ret %result "result type" + cxx "void setup_and_check(support_flags support, bool component_type_match = false);" +} +inst @neg : @arith_unary "Negation instruction" {} +inst @not : @arith_unary "Bitwise not instruction" {} +inst @abs : @arith_unary "Absolute value instruction" {} +inst @conj : @arith_unary "Complex conjugate instruction" {} +inst @im : @arith_unary "Imaginary part instruction" {} +inst @re : @arith_unary "Real part instruction" {} + + +; BLAS with 2 tensor operands + +inst @blas_a2 "BLAS instruction with 2 tensor operands" { + prop %atomic => bool "atomic flag" + op %alpha "alpha scalar" + op %A "A tensor" + op %beta "beta scalar" + op %B "B tensor" +} + +inst @axpby : @blas_a2 "AXPBY instruction" { + collective + prop %tA => @transpose "transpose A" +} + +inst @cumsum : @blas_a2 "Cumsum instruction" { + collective + prop %mode => i64 "sum mode" +} + +inst @sum : @blas_a2 "Sum instruction" { + collective + prop %tA => @transpose "transpose A" +} + + +; BLAS with 3 tensor operands + +inst @blas_a3 "BLAS instruction with 3 tensor operands" { + prop %atomic => bool "atomic flag" + op %alpha "alpha scalar" + op %A "A tensor" + op %B "B tensor" + op %beta "beta scalar" + op %C "C tensor" +} + +inst @gemm : @blas_a3 "GEMM instruction" { + collective + prop %tA => @transpose "transpose A" + prop %tB => @transpose "transpose B" +} + +inst @gemv : @blas_a3 "GEMV instruction" { + collective + prop %tA => @transpose "transpose A" +} + +inst @ger : @blas_a3 "GER 
instruction" { + collective +} + +inst @hadamard : @blas_a3 "Hadamard instruction" { + collective +} + + +; Built-in functions + +inst @builtin "Builtin instruction" { + ret %result "result type" +} +inst @group_id : @builtin "Group ID instruction" { + prop %mode => @comp3 "mode" +} +inst @num_groups : @builtin "Number of groups instruction" { + prop %mode => @comp3 "mode" +} +inst @num_subgroups : @builtin "Number of subgroups instruction" { + prop %mode => @comp3 "mode" +} +inst @subgroup_size : @builtin "Subgroup size instruction" {} +inst @subgroup_id : @builtin "Subgroup id instruction" { + spmd + prop %mode => @comp3 "mode" +} +inst @subgroup_linear_id : @builtin "Subgroup linear id instruction" { + spmd +} +inst @subgroup_local_id : @builtin "Subgroup local id instruction" { + spmd +} + +; Comparison + +inst @compare "Comparison instruction" { + op %a "left-hand operand" + op %b "right-hand operand" + ret %result "result type" + cxx "void setup_and_check(support_flags support);" +} +inst @equal : @compare "Equal instruction" {} +inst @not_equal : @compare "Not equal instruction" {} +inst @greater_than : @compare "Greater instruction" {} +inst @greater_than_equal : @compare "Greater than or equal instruction" {} +inst @less_than : @compare "Less than instruction" {} +inst @less_than_equal : @compare "Less than or equal instruction" {} + + +; Loops + +inst @loop "Loop instruction" { + reg %body "loop body" +} + +inst @for : @loop "Create for loop" { + op %from "loop begin (inclusive)" + op %to "loop end (exclusive)" + op? 
%step "loop step" + op* %iter_init "array of initial values" + ret* %results "array of return types" + cxx "inline auto loop_var() -> tinytc_value & { return body().param(0); }" + cxx "inline auto iter_arg(std::int64_t no) -> tinytc_value & { return body().param(no + 1); }" +} + +inst @foreach : @loop "Create foreach loop" { + collective + op* %from "array of loop begin (inclusive)" + op* %to "array of loop end (exclusive)" + cxx "inline auto dim() -> std::int32_t { return get().num_operands() / 2; }" + cxx "inline auto loop_vars() { return body().params(); }" +} + +; Math + +inst @math_unary "Math instruction" { + op %a "argument" + ret %result "result type" + cxx "void setup_and_check(support_flags support);" +} +inst @cos : @math_unary "Cosine instruction" {} +inst @sin : @math_unary "Sine instruction" {} +inst @exp : @math_unary "Base-e exponential instruction" {} +inst @exp2 : @math_unary "Base-2 exponential instruction" {} +inst @native_cos : @math_unary "Cosine instruction (native)" {} +inst @native_sin : @math_unary "Sine instruction (native)" {} +inst @native_exp : @math_unary "Base-e exponential instruction (native)" {} +inst @native_exp2 : @math_unary "Base-2 exponential instruction (native)" {} + + +; Subgroup operation + +inst @subgroup_operation "Subgroup operation instruction" { + op %a "operand" + ret %result "result type" + cxx "void setup_and_check(support_flags support);" +} +inst @subgroup_exclusive_scan_add : @subgroup_operation "Subgroup exclusive scan add" { spmd } +inst @subgroup_exclusive_scan_max : @subgroup_operation "Subgroup exclusive scan max" { spmd } +inst @subgroup_exclusive_scan_min : @subgroup_operation "Subgroup exclusive scan min" { spmd } +inst @subgroup_inclusive_scan_add : @subgroup_operation "Subgroup inclusive scan add" { spmd } +inst @subgroup_inclusive_scan_max : @subgroup_operation "Subgroup inclusive scan max" { spmd } +inst @subgroup_inclusive_scan_min : @subgroup_operation "Subgroup inclusive scan min" { spmd } +inst 
@subgroup_reduce_add : @subgroup_operation "Subgroup reduce add" { spmd } +inst @subgroup_reduce_max : @subgroup_operation "Subgroup reduce max" { spmd } +inst @subgroup_reduce_min : @subgroup_operation "Subgroup reduce min" { spmd } + diff --git a/include/tinytc/tinytc.h b/include/tinytc/tinytc.h index fb70475c..8630981a 100644 --- a/include/tinytc/tinytc.h +++ b/include/tinytc/tinytc.h @@ -1,1577 +1,13 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#ifndef TINYTC_20240409_H -#define TINYTC_20240409_H +#ifndef TINYTC_20250704_H +#define TINYTC_20250704_H +#include "tinytc/builder.h" +#include "tinytc/core.h" #include "tinytc/export.h" #include "tinytc/types.h" #include "tinytc/version.h" -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -//////////////////////////// -/////////// Error ////////// -//////////////////////////// - -#define TINYTC_CHECK_STATUS(X) \ - do { \ - tinytc_status_t status = X; \ - if (status != tinytc_status_success) { \ - return status; \ - } \ - } while (0) - -/** - * @brief Translate status code to textual description - * - * @param status [in] status code - * - * @return String - */ -TINYTC_EXPORT char const *tinytc_error_string(tinytc_status_t status); - -//////////////////////////// -//////// Scalar type /////// -//////////////////////////// - -//! Convert scalar type to string -TINYTC_EXPORT char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty); -//! 
Size of scalar type in bytes -TINYTC_EXPORT size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty); - -//////////////////////////// -///////// Data type //////// -//////////////////////////// - -/** - * @brief Create scalar data type - * - * @param dt [out] pointer to the data type object created - * @param type [in] scalar type - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, - tinytc_scalar_type_t type, - const tinytc_location_t *loc); - -/** - * @brief Create memref data type - * - * @param dt [out] pointer to the data type object created - * @param scalar_ty [in] element type - * @param shape_size [in] tensor order; number of elements in shape array, must be 0 if shape == - * nullptr - * @param shape [in][range(0, shape_size)] array of mode sizes - * @param stride_size [in][optional] number of elements in stride array; must be either 0 for - * automatic stride calculation or must match shape_size; must be 0 if stride == nullptr - * @param stride [in][optional][range(0, stride_size)] stride array - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, - tinytc_scalar_type_t scalar_ty, - uint32_t shape_size, const int64_t *shape, - uint32_t stride_size, const int64_t *stride, - const tinytc_location_t *loc); - -/** - * @brief Create group data type - * - * @param dt [out] pointer to the data type object created - * @param memref_ty [in] memref data type object - * @param offset [in][optional] offset parameter; pass 0 for default - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t 
tinytc_group_type_create(tinytc_data_type_t *dt, - tinytc_data_type_t memref_ty, int64_t offset, - const tinytc_location_t *loc); - -/** - * @brief Release data type object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param dt [inout] data type object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_data_type_release(tinytc_data_type_t dt); - -/** - * @brief Increase reference count of data type object by 1 - * - * @param dt [inout] data type object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_data_type_retain(tinytc_data_type_t dt); - -//////////////////////////// -/////////// Value ////////// -//////////////////////////// - -/** - * @brief Create value - * - * @param vl [out] pointer to the value object created - * @param type [in] data type object - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_create(tinytc_value_t *vl, tinytc_data_type_t type, - const tinytc_location_t *loc); - -/** - * @brief Create floating point immediate value - * - * @param vl [out] pointer to the value object created - * @param imm [in] immediate value - * @param type [in] type of immediate value - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_float_imm_create(tinytc_value_t *vl, double imm, - tinytc_scalar_type_t type, - const tinytc_location_t *loc); -/** - * @brief Create integer immediate value - * - * @param vl [out] pointer to the value object created - * @param imm [in] immediate value - * @param type [in] type of immediate value - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and 
error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_int_imm_create(tinytc_value_t *vl, int64_t imm, - tinytc_scalar_type_t type, - const tinytc_location_t *loc); - -/** - * @brief Release value object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param vl [inout] value object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_release(tinytc_value_t vl); - -/** - * @brief Increase reference count of value object by 1 - * - * @param vl [inout] value object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_retain(tinytc_value_t vl); - -/** - * @brief Set name of value - * - * @param vl [inout] value object - * @param name [in] name; null-terminated string - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char const *name); - -/** - * @brief Get name of value - * - * The returned pointer may be invalidated if the value or any node in the abstract syntax - * tree referencing the value is modified. - * - * @param vl [in] value object - * @param name [out] pointer to C string - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, char const **name); - -//////////////////////////// -/////// Instructions /////// -//////////////////////////// - -//! Convert arithmetic operation type to string -TINYTC_EXPORT char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op); -//! Convert arithmetic operation type to string (unary) -TINYTC_EXPORT char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op); -//! Convert cmp condition to string -TINYTC_EXPORT char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond); -//! 
Convert transpose to string -TINYTC_EXPORT char const *tinytc_transpose_to_string(tinytc_transpose_t t); - -/** - * @brief Create arithmetic instruction (binary) - * - * @code %value = arith. %a, %b : type(%a) ; type(%a) == type(%b) @endcode - * - * @param instr [out] pointer to the inst object created - * @param op [in] arithmetic operation type - * @param a [in] left-hand operand - * @param b [in] right-hand operand - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_t op, - tinytc_value_t a, tinytc_value_t b, - const tinytc_location_t *loc); - -/** - * @brief Create arithmetic instruction (unary) - * - * @code %value = arith. %a : type(%a) @endcode - * - * @param instr [out] pointer to the inst object created - * @param op [in] unary arithmetic operation type - * @param a [in] operand - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, - tinytc_arithmetic_unary_t op, - tinytc_value_t a, - const tinytc_location_t *loc); - -/** - * @brief Create cast instruction - * - * @code %value = cast %a, %b : type(%a) -> %to_ty @endcode - * - * @param instr [out] pointer to the inst object created - * @param a [in] operand - * @param to_ty [in] target type - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - tinytc_scalar_type_t to_ty, - const tinytc_location_t *loc); - -/** - * @brief Create binary op instruction - * - * @code %value = cmp. 
%a, %b : type(%a) ; type(%a) == type(%b) @endcode - * - * @param instr [out] pointer to the inst object created - * @param cond [in] compare type - * @param a [in] left-hand operand - * @param b [in] right-hand operand - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_cmp_inst_create(tinytc_inst_t *instr, - tinytc_cmp_condition_t cond, tinytc_value_t a, - tinytc_value_t b, - const tinytc_location_t *loc); - -/** - * @brief Create alloca instruction - * - * @code %value = alloca -> %ty @endcode - * - * @param instr [out] pointer to the inst object created - * @param ty [in] type that is allocated - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type_t ty, - const tinytc_location_t *loc); - -/** - * @brief Create axpby instruction - * - * @code - * axpby.. 
%alpha, %A, %beta, %B : type(%alpha), type(%A), type(%beta), type(%B) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param tA [in] operation applied on A - * @param atomic [in] true for atomic updates of B - * @param alpha [in] @f$\alpha@f$ - * @param A [in] A - * @param beta [in] @f$\beta@f$ - * @param B [in] B - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_bool_t atomic, tinytc_value_t alpha, - tinytc_value_t A, tinytc_value_t beta, - tinytc_value_t B, - const tinytc_location_t *loc); - -/** - * @brief Create expand instruction - * - * @code %value = expand %a[%mode -> %expand_shape] : type(%a) @endcode - * - * @param instr [out] pointer to the inst object created - * @param a [in] operand - * @param mode [in] expanded mode - * @param expand_shape_size [in] dimension of expand shape; must be at least 2 - * @param expand_shape [in][range(2, expand_shape_size)] expand shape array - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - int64_t mode, uint32_t expand_shape_size, - tinytc_value_t *expand_shape, - const tinytc_location_t *loc); - -/** - * @brief Create fuse instruction - * - * @code %value = fuse %a[%from, %to] : type(%a) @endcode - * - * @param instr [out] pointer to the inst object created - * @param a [in] operand - * @param from [in] first mode to fuse - * @param to [in] last mode to fuse - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - int64_t from, 
int64_t to, - const tinytc_location_t *loc); - -/** - * @brief Create load instruction - * - * @code %value = load %a[%index_list] : type(%a) @endcode - * - * @param instr [out] pointer to the inst object created - * @param a [in] operand - * @param index_list_size [in] number of indices - * @param index_list [in][range(0, index_list_size)] indices array; may be nullptr if - * index_list_size is 0 - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t index_list_size, - tinytc_value_t *index_list, - const tinytc_location_t *loc); -/** - * @brief Create group_id instruction - * - * @code %value = group_id @endcode - * - * @param instr [out] pointer to the inst object created - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, - const tinytc_location_t *loc); - -/** - * @brief Create group_size instruction - * - * @code %value = group_size @endcode - * - * @param instr [out] pointer to the inst object created - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, - const tinytc_location_t *loc); - -/** - * @brief Create GEMM instruction - * - * @code - * gemm... 
%alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param tA [in] operation applied on A - * @param tB [in] operation applied on B - * @param atomic [in] true for atomic updates of C - * @param alpha [in] @f$\alpha@f$ - * @param A [in] A - * @param B [in] B - * @param beta [in] @f$\beta@f$ - * @param C [in] C - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_transpose_t tB, tinytc_bool_t atomic, - tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, tinytc_value_t beta, - tinytc_value_t C, - const tinytc_location_t *loc); - -/** - * @brief Create GEMV instruction - * - * @code - * gemv.. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param tA [in] operation applied on A - * @param atomic [in] true for atomic updates of C - * @param alpha [in] @f$\alpha@f$ - * @param A [in] A - * @param B [in] B - * @param beta [in] @f$\beta@f$ - * @param C [in] C - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_gemv_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_bool_t atomic, tinytc_value_t alpha, - tinytc_value_t A, tinytc_value_t B, - tinytc_value_t beta, tinytc_value_t C, - const tinytc_location_t *loc); - -/** - * @brief Create GER instruction - * - * @code - * ger. 
%alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param atomic [in] true for atomic updates of C - * @param alpha [in] @f$\alpha@f$ - * @param A [in] A - * @param B [in] B - * @param beta [in] @f$\beta@f$ - * @param C [in] C - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_ger_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomic, - tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, tinytc_value_t beta, - tinytc_value_t C, - const tinytc_location_t *loc); - -/** - * @brief Create Hadamard instruction - * - * @code - * hadamard. %alpha, %A, %B, %beta, %C - * : type(%alpha), type(%A), type(%B), type(%beta), type(%C) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param atomic [in] true for atomic updates of C - * @param alpha [in] @f$\alpha@f$ - * @param A [in] A - * @param B [in] B - * @param beta [in] @f$\beta@f$ - * @param C [in] C - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_hadamard_inst_create( - tinytc_inst_t *instr, tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, const tinytc_location_t *loc); - -/** - * @brief Create size instruction - * - * @code %value = size %a[%mode] : type(%a) @endcode - * - * @param instr [out] pointer to the inst object created - * @param a [in] operand - * @param mode [in] mode for that the size is queried - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - 
int64_t mode, const tinytc_location_t *loc); - -/** - * @brief Create subview instruction - * - * @code %value = subview %a[%offset1:%size1,...,%offsetN:%sizeN] : type(%a) @endcode - * - * @param instr [out] pointer to the inst object created - * @param a [in] operand - * @param slice_list_size [in] number of slices - * @param offset_list [in][range(0, slice_list_size)] offset array; may be nullptr if - * slice_list_size is 0 - * @param size_list [in][range(0, slice_list_size)] size array; may be nullptr if slice_list_size - * is 0; size_list[i] may be nullptr if a single offset shall be passed instead of a range for the - * i-th mode - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t slice_list_size, - tinytc_value_t *offset_list, - tinytc_value_t *size_list, - const tinytc_location_t *loc); - -/** - * @brief Create store instruction - * - * @code store %val, %a[%index_list] : type(%a) @endcode - * - * @param instr [out] pointer to the inst object created - * @param val [in] value to store - * @param a [in] operand - * @param index_list_size [in] number of indices - * @param index_list [in][range(0, index_list_size)] indices array; may be nullptr if - * index_list_size is 0 - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, - tinytc_value_t a, uint32_t index_list_size, - tinytc_value_t *index_list, - const tinytc_location_t *loc); - -/** - * @brief Create sum instruction - * - * @code - * sum.. 
%alpha, %A, %beta, %B : type(%alpha), type(%A), type(%beta), type(%B) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param tA [in] operation applied on A - * @param atomic [in] true for atomic updates of C - * @param alpha [in] @f$\alpha@f$ - * @param A [in] A - * @param beta [in] @f$\beta@f$ - * @param B [in] B - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_bool_t atomic, tinytc_value_t alpha, - tinytc_value_t A, tinytc_value_t beta, - tinytc_value_t B, - const tinytc_location_t *loc); - -/** - * @brief Create for loop - * - * @code - * for %loop_var = %from, %to, %step : type(%loop_var) { %body } - * ; type(%loop_var) == type(%from) - * ; type(%loop_var) == type(%to) - * ; type(%loop_var) == type(%step) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param loop_var [in] loop variable - * @param from [in] loop begion - * @param to [in] loop bound - * @param step [in][optional] loop step; can be nullptr - * @param body [in] loop body - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_value_t step, tinytc_region_t body, - const tinytc_location_t *loc); - -/** - * @brief Create foreach loop - * - * @code - * foreach %loop_var = %from, %to : type(%loop_var) { %body } - * ; type(%loop_var) == type(%from) - * ; type(%loop_var) == type(%to) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param loop_var [in] loop variable - * @param from [in] loop begion - * @param to [in] loop bound - * @param body [in] loop body - * @param loc 
[in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, - tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_region_t body, - const tinytc_location_t *loc); - -/** - * @brief Create if condition - * - * @code - * if %condition { %then } else { %otherwise } - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param condition [in] condition - * @param then [in] region taken if condition is true - * @param otherwise [in][optional] region taken if condition is false; can be nullptr - * @param return_type_list_size [in] length of return type array - * @param return_type_list [in][range(0, return_type_list_size)] return type array; can be nullptr - * if return_type_list_size is 0 - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, - tinytc_region_t then, tinytc_region_t otherwise, - uint32_t return_type_list_size, - tinytc_scalar_type_t *return_type_list, - const tinytc_location_t *loc); - -/** - * @brief Create yield instruction - * - * @code - * yield %identifier1, ..., %identifierN : type(%identifier1), ..., type(%identifierN) - * @endcode - * - * @param instr [out] pointer to the inst object created - * @param yield_list_size [in] length of yielded values list; must be at least 1 - * @param yield_list [in][range(1, yield_list_size)] yielded values array - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, - uint32_t yield_list_size, - tinytc_value_t *yield_list, - const tinytc_location_t *loc); - -/** - * @brief Release 
inst object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param instr [inout] inst object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_release(tinytc_inst_t instr); - -/** - * @brief Increase reference count of inst object by 1 - * - * @param instr [inout] inst object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_retain(tinytc_inst_t instr); - -/** - * @brief Get value produced by instruction - * - * @param instr [in] inst object - * @param result [out] result value; may be set to nullptr if instruction does not return a value - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, - tinytc_value_t *result); - -/** - * @brief Get values produced by instruction - * - * Function can be called with result_list_size = 0 and result_list = nullptr in order to obtain - * the number of results - * - * @param instr [in] inst object - * @param result_list_size [inout] number of results to fetch; is updated with the actual value - * @param result_list [out][range(0, result_list_size)] user-provided memory for storing result - * handles; at most result_list_size values are written; can be nullptr if result_list_size is 0 - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, - uint32_t *result_list_size, - tinytc_value_t *result_list); - -//////////////////////////// -////////// Region ////////// -//////////////////////////// - -/** - * @brief Create region - * - * @param reg [out] pointer to the region object created - * @param instruction_list_size [in] length of instruction array - * @param instruction_list [in][range(0, instruction_list_size)] instruction array; can be nullptr - * if 
instruction_list_size is 0 - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_region_create(tinytc_region_t *reg, - uint32_t instruction_list_size, - tinytc_inst_t *instruction_list, - const tinytc_location_t *loc); -/** - * @brief Release region object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param reg [inout] region object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_region_release(tinytc_region_t reg); - -/** - * @brief Increase reference count of region object by 1 - * - * @param reg [inout] region object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_region_retain(tinytc_region_t reg); - -//////////////////////////// -/////////// Func /////////// -//////////////////////////// - -/** - * @brief Create function prototype - * - * @param fun [out] pointer to the func object created - * @param name [in] function name - * @param arg_list_size [in] length of argument array - * @param arg_list [in][range(0, arg_list_size)] argument array; can be nullptr if arg_list_size is - * 0 - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_function_prototype_create(tinytc_func_t *fun, char const *name, - uint32_t arg_list_size, - tinytc_value_t *arg_list, - const tinytc_location_t *loc); - -/** - * @brief Create function - * - * @param fun [out] pointer to the func object created - * @param prototype [in] function prototype - * @param body [in] function body - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t 
tinytc_function_create(tinytc_func_t *fun, tinytc_func_t prototype, - tinytc_region_t body, - const tinytc_location_t *loc); - -/** - * @brief Set work-group size - * - * @param fun [out] func object (must be the function definition, not the function prototype) - * @param x [in] number of rows in parallel grid; must be a multiple of the subgroup size - * @param y [in] number of columns in parallel grid - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, - int32_t y); -/** - * @brief Set subgroup size - * - * @param fun [out] func object (must be the function definition, not the function prototype) - * @param sgs [in] subgroup size; the supported values need to be queried from the compute device - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs); - -/** - * @brief Release function object - * - * Decreases reference count by 1, free memory if reference count is 0. 
- * - * @param fun [inout] function object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_func_release(tinytc_func_t fun); - -/** - * @brief Increase reference count of function object by 1 - * - * @param fun [inout] function object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_func_retain(tinytc_func_t fun); - -//////////////////////////// -/////////// Prog /////////// -//////////////////////////// - -/** - * @brief Create program - * - * @param prg [out] pointer to the prog object created - * @param fun_list_size [in] length of func array - * @param fun_list [in][range(0, fun_list_size)] func array; can be nullptr if fun_list_size is 0 - * @param loc [in][optional] Source code location; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, uint32_t fun_list_size, - tinytc_func_t *fun_list, - const tinytc_location_t *loc); - -/** - * @brief Release program object - * - * Decreases reference count by 1, free memory if reference count is 0. 
- * - * @param prg [inout] program object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_release(tinytc_prog_t prg); - -/** - * @brief Increase reference count of program object by 1 - * - * @param prg [inout] program object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_retain(tinytc_prog_t prg); - -//////////////////////////// -// Visitors and transforms / -//////////////////////////// - -/** - * @brief Dump program to stderr - * - * @param prg [in] program object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_dump(const_tinytc_prog_t prg); - -/** - * @brief Print program to file - * - * @param prg [in] program object - * @param filename [in] filename - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, - char const *filename); - -/** - * @brief Print program to string - * - * The user is responsible to dispose the string with tinytc_string_destroy. 
- * - * @param prg [in] program object - * @param str [out] pointer to string - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_prog_print_to_string(const_tinytc_prog_t prg, char **str); - -/** - * @brief Delete a (non-const) string returned from tinytc API - * - * @param str [in] string - */ -TINYTC_EXPORT void tinytc_string_destroy(char *str); - -//////////////////////////// -//////// Device info /////// -//////////////////////////// - -/** - * @brief Create core_info for a generic GPUs - * - * @param info [out] pointer to the core_info object created - * @param register_space [in] Size of register file per subgroup in bytes - * @param max_work_group_size [in] Maximum size of local work group - * @param sgs_size [in] Length of sgs array - * @param sgs [in] Allowed subgroup sizes - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_generic_create(tinytc_core_info_t *info, - int32_t register_space, - int32_t max_work_group_size, - uint32_t sgs_size, - int32_t const *sgs); - -/** - * @brief Look up core info for Intel GPU architecture - * - * @param info [out] pointer to the core_info object created - * @param arch [in] IP version - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create_from_arch( - tinytc_core_info_t *info, tinytc_intel_gpu_architecture_t arch); - -/** - * @brief Create core_info for Intel GPUs - * - * @param info [out] pointer to the core_info object created - * @param ip_version [in] IP version of architecture - * @param num_eus_per_subslice [in] Number of Execution Units (Xe Vector Engines) per subslice (Xe - * Core) - * @param num_threads_per_eu [in] Number of threads per Execution Unit (Xe Vector Engine) - * @param sgs_size [in] Length of sgs array - * @param sgs [in] Allowed subgroup sizes - * - * @return tinytc_status_success on success 
and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_intel_create(tinytc_core_info_t *info, - uint32_t ip_version, - int32_t num_eus_per_subslice, - int32_t num_threads_per_eu, - uint32_t sgs_size, int32_t const *sgs); - -/** - * @brief Returns available subgroup sizes - * - * @param info [in] core info object - * @param sgs_size [out] pointer to number of subgroup sizes - * @param sgs [out] pointer to subgroup size array; pointer is invalidated when core info is deleted - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_subgroup_sizes(const_tinytc_core_info_t info, - uint32_t *sgs_size, - int32_t const **sgs); - -/** - * @brief Returns register space per subgroup in bytes - * - * @param info [in] core info object - * @param space [out] pointer to register space - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_get_register_space(const_tinytc_core_info_t info, - int32_t *space); - -/** - * @brief Set core features - * - * @param info [in] core info object - * @param flags [in] set core features; must be 0 or a combination of tinytc_core_feature_flag_t - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_set_core_features(tinytc_core_info_t info, - tinytc_core_feature_flags_t flags); - -/** - * @brief Get core features - * - * @param info [in] core info object - * @param flags [out] pointer to core feature flags - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t -tinytc_core_info_get_core_features(tinytc_core_info_t info, tinytc_core_feature_flags_t *flags); - -/** - * @brief Release core info object - * - * Decreases reference count by 1, free memory if reference count is 0. 
- * - * @param obj [inout] core info object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_release(tinytc_core_info_t obj); - -/** - * @brief Increase reference count of core info object by 1 - * - * @param obj [inout] core info object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_core_info_retain(tinytc_core_info_t obj); - -//////////////////////////// -////////// Parser ////////// -//////////////////////////// - -/** - * @brief Parser tensor language source file and create prog - * - * @param prg [out] pointer to prog object created - * @param filename [in] path to source file - * @param ctx [inout][optional] source context object; stores error log; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, - tinytc_source_context_t ctx); - -/** - * @brief Parser tensor language source from stdin and create prog - * - * @param prg [out] pointer to prog object created - * @param ctx [inout][optional] source context object; stores error log; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t ctx); - -/** - * @brief Parser tensor language source from string and create prog - * - * @param prg [out] pointer to prog object created - * @param source_size [in] length of source string - * @param source [in] source string - * @param ctx [inout][optional] source context object; stores error log; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_parse_string(tinytc_prog_t *prg, size_t source_size, - char const *source, tinytc_source_context_t ctx); -/** - * @brief Create source context - * - * The source context 
stores the tensor language source and enhaces error messages with - * source code context. - * - * @param ctx [out] pointer to the source context object created - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_create(tinytc_source_context_t *ctx); - -/** - * @brief Add source context - * - * Manually add a source file to the source context that can be referenced in a tinytc_location. - * Useful to enhance error messages when using the builder methods and classes. - * - * @param ctx [in] source context object - * @param name [in] source name - * @param text [in] source text - * @param source_id [out] pointer to source id - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_add_source(tinytc_source_context_t ctx, - char const *name, char const *text, - int32_t *source_id); - -/** - * @brief Get error log - * - * The string's memory is owned by source context. - * Note that the pointer may invalidated by any function call involving the source context object, - * so the string should be copied or printed right after a call to this function. 
- * - * @param ctx [in] source context object - * @param log [out] pointer to string - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_get_error_log(const_tinytc_source_context_t ctx, - char const **log); - -/** - * @brief Report an error and augment the error with source context - * - * @param ctx [in] source context object - * @param location [in] source location - * @param what [in] error description - * @param append [in] true: append to error log, false: clear error log - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_report_error(tinytc_source_context_t ctx, - const tinytc_location_t *location, - char const *what, - tinytc_bool_t append); - -/** - * @brief Release source context object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param obj [inout] source context object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_release(tinytc_source_context_t obj); - -/** - * @brief Increase reference count of source context object by 1 - * - * @param obj [inout] source context object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_context_retain(tinytc_source_context_t obj); - -//////////////////////////// -///////// Compiler ///////// -//////////////////////////// - -/** - * @brief Compile tensor language to OpenCL-C - * - * @param src [out] pointer to the source object created - * @param prg [inout] tensor program; modified as compiler passes are run - * @param info [in] core info object - * @param ctx [inout][optional] source context object to save extended error messages that are - * enhanced with source code context; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT 
tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx); - -/** - * @brief Get source text - * - * @param src [in] source object - * @param length [out] pointer to code length - * @param code [out] code contains a pointer to the source text; the pointer is only valid as long - * as the source object is alive - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, - char const **code); - -/** - * @brief Get source location - * - * @param src [in] source object - * @param loc [out] pointer to location - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_location(const_tinytc_source_t src, - tinytc_location_t *loc); - -/** - * @brief Get core features - * - * @param src [in] source object - * @param core_features [out] pointer to core features - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_core_features( - const_tinytc_source_t src, tinytc_core_feature_flags_t *core_features); - -/** - * @brief Get required OpenCL extensions - * - * @param src [in] source object - * @param extensions_size [out] pointer to number of extensions - * @param extensions [out][range(0,extensions_size)] pointer to array of C-strings; array owned by - * source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t src, - uint32_t *extensions_size, - char const *const **extensions); - -/** - * @brief Create binary - * - * @param bin [out] pointer to binary object - * @param format [in] Bundle format (SPIR-V or Native) - * @param data_size [in] Size of data in bytes - * @param data [in][range(0, data_size)] Binary data; data is copied - * 
@param core_features [in][optional] requested core features; must be 0 (default) or a combination - * of tinytc_core_feature_flag_t - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, - tinytc_bundle_format_t format, size_t data_size, - uint8_t const *data, - tinytc_core_feature_flags_t core_features); - -/** - * @brief Get raw binary data - * - * @param bin [in] binary object - * @param format [out] binary format - * @param data_size [out] size of data - * @param data [out] data array; returned pointer is invalidated if the binary object is deleted - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_binary_get_raw(const_tinytc_binary_t bin, - tinytc_bundle_format_t *format, - size_t *data_size, uint8_t const **data); -/** - * @brief Get requested core features - * - * @param bin [in] binary object - * @param core_features [out] core features - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_binary_get_core_features( - const_tinytc_binary_t bin, tinytc_core_feature_flags_t *core_features); - -/** - * @brief Release source object - * - * Decreases reference count by 1, free memory if reference count is 0. - * - * @param obj [inout] source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_release(tinytc_source_t obj); - -/** - * @brief Increase reference count of source object by 1 - * - * @param obj [inout] source object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_source_retain(tinytc_source_t obj); - -/** - * @brief Release binary object - * - * Decreases reference count by 1, free memory if reference count is 0. 
- * - * @param bin [inout] binary object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_binary_release(tinytc_binary_t bin); - -/** - * @brief Increase reference count of binary object by 1 - * - * @param bin [inout] binary object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_binary_retain(tinytc_binary_t bin); - -//////////////////////////// -////////// Recipe ////////// -//////////////////////////// - -/** - * @brief Returns a small batched GEMM recipe - * - * The program contains a kernel for @f$\beta=0@f$ called "gemm_beta0" and a kernel for @f$\beta\neq - * 0@f$ called "gemm". All matrix shapes and strides are known at compile-time. - * - * The signature of the generated kernels gemm and gemm_beta0 is (if A and B are not transposed) - * - * @code - * func @{name}(%alpha: {ty.alpha}, - * %A: memref<{ty.A}x{M}x{K}x?,strided<1,{ldA},{strideA}>>, - * %B: memref<{ty.B}x{K}x{N}x?,strided<1,{ldB},{strideB}>>, - * %beta: {ty.beta}, - * %C: memref<{ty.C}x{M}x{N}x?,strided<1,{ldC},{strideC}>>) - * @endcode - * - * meaning that its kernels need arguments in the following order: - * - * @code - * alpha, A_ptr, howmany, B_ptr, howmany, beta, C_ptr, howmany - * @endcode - * - * @param recipe [out] pointer to the recipe object created - * @param info [in] core info object - * @param ty [in] Scalar types of alpha, A, B, beta, C - * @param tA [in] Transpose A - * @param tB [in] Transpose B - * @param M [in] Number of rows of A, C - * @param N [in] Number of columns of B, C - * @param K [in] Number columns of A, number of rows of B - * @param ldA [in] Leading dimension of A - * @param strideA [in] Number of elements between A-matrices - * @param ldB [in] Leading dimension of B - * @param strideB [in] Number of elements between B-matrices - * @param ldC [in] Leading dimension of C - * @param strideC [in] Number of elements between C-matrices - * 
@param ctx [inout][optional] source context object; saves error log; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_create( - tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, - tinytc_transpose_t tA, tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, int64_t ldA, - int64_t strideA, int64_t ldB, int64_t strideB, int64_t ldC, int64_t strideC, - tinytc_source_context_t ctx); - -/** - * @brief Set kernel arguments for small GEMM batched recipe - * - * @param handler [inout] Recipe handler object - * @param howmany [in] Group size - * @param alpha_size [in] Size of alpha argument - * @param alpha_value [in] Pointer to data used for alpha; data is copied - * @param A_type [in] Type of memory object used for A-matrix - * @param A_value [in] Memory object used for A-matrix - * @param B_type [in] Type of memory object used for B-matrix - * @param B_value [in] Memory object used for B-matrix - * @param beta_size [in] Size of beta argument - * @param beta_value [in] Pointer to data used for beta; data is copied - * @param C_type [in] Type of memory object used for C-matrix - * @param C_value [in] Memory object used for C-matrix - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( - tinytc_recipe_handler_t handler, int64_t howmany, size_t alpha_size, const void *alpha_value, - tinytc_mem_type_t A_type, const void *A_value, tinytc_mem_type_t B_type, const void *B_value, - size_t beta_size, const void *beta_value, tinytc_mem_type_t C_type, const void *C_value); - -/** - * @brief Returns a tall and skinny recipe - * - * The program contains a kernel for beta = 0 called "gemm_beta0" and a kernel for beta != 0 called - * "gemm". M (= number of rows of A, C) and strides are dynamic. 
- * - * The signature of the generated kernels gemm and gemm_beta0 is - * - * @code - * func @{name}(%alpha: {ty.alpha}, - * %A: memref<{ty.A}x?x{K},strided<1,?>>, - * %B: memref<{ty.B}x{K}x{N},strided<1,?>>, - * %beta: {ty.beta}, - * %C: memref<{ty.C}x?x{N},strided<1,?>>) - * @endcode - * - * meaning that its kernels need arguments in the following order: - * - * @code - * alpha, A_ptr, M, ldA, B_ptr, ldB, beta, C_ptr, M, ldC - * @endcode - * - * where ldA, ldB, ldC is the size of stride[1] of A, B, C, respectively. - * - * @param recipe [out] pointer to the recipe object created - * @param info [in] core info object - * @param ty [in] Scalar type of alpha, A, B, beta, C - * @param N [in] Number of columns of B, C - * @param K [in] Number columns of A, number of rows of B - * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the - * parameter auto-selected - * @param ctx [inout][optional] source context object; saves error log; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create( - tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t N, - int64_t K, int32_t M_block_size, tinytc_source_context_t ctx); - -/** - * @brief Returns a tall and skinny recipe with additional specialization constants - * - * Similar to tinytc_recipe_tall_and_skinny_create but with the additional specialization - * constants M, ldA, ldB, and ldC. - * The specializtion constants may be either set to a fixed value or to TINYTC_DYNAMIC. - * Note that if a specialization constant is set to a fixed value then the parameter with the same - * name in tinytc_recipe_tall_and_skinny_set_args is ignored. 
- * - * The generated kernels have the following signature: - * - * @code - * func @{name}(%alpha: {ty.alpha}, - * %A: memref<{ty.A}x{M}x{K},strided<1,{ldA}>>, - * %B: memref<{ty.B}x{K}x{N},strided<1,{ldB}>>, - * %beta: {ty.beta}, - * %C: memref<{ty.C}x{M}x{N},strided<1,{ldC}>>) - * @endcode - * - * @param recipe [out] pointer to the recipe object created - * @param info [in] core info object - * @param ty [in] Scalar type of alpha, A, B, beta, C - * @param M [in] Number of rows of A, C; can be TINYTC_DYNAMIC - * @param N [in] Number of columns of B, C - * @param K [in] Number columns of A, number of rows of B - * @param ldA [in] Leading dimension of A; can be TINYTC_DYNAMIC - * @param ldB [in] Leading dimension of B; can be TINYTC_DYNAMIC - * @param ldC [in] Leading dimension of C; can be TINYTC_DYNAMIC - * @param M_block_size [in][optional] Size of M block that each work group gets; pass 0 to have the - * parameter auto-selected - * @param ctx [inout][optional] source context object; saves error log; can be nullptr - * - * @return - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( - tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t M, - int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t M_block_size, - tinytc_source_context_t ctx); - -/** - * @brief Suggest an M block size for tall and skinny recipe - * - * @param info [in] core info object - * @param M_block_size [out] pointer to block size - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_suggest_block_size( - const_tinytc_core_info_t info, int32_t *M_block_size); - -/** - * @brief Set kernel arguments for tall and skinny GEMM recipe - * - * @param handler [inout] Recipe handler object - * @param M [in] Size of M-mode - * @param alpha_size [in] Size of alpha argument - * @param alpha_value [in] Pointer to data used for alpha; data is 
copied - * @param A_type [in] Type of memory object used for A-matrix - * @param A_value [in] Memory object used for A-matrix - * @param ldA [in] Leading dimension of A - * @param B_type [in] Type of memory object used for B-matrix - * @param B_value [in] Memory object used for B-matrix - * @param ldB [in] Leading dimension of B - * @param beta_size [in] Size of beta argument - * @param beta_value [in] Pointer to data used for beta; data is copied - * @param C_type [in] Type of memory object used for C-matrix - * @param C_value [in] Memory object used for C-matrix - * @param ldC [in] Leading dimension of C - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( - tinytc_recipe_handler_t handler, int64_t M, size_t alpha_size, const void *alpha_value, - tinytc_mem_type_t A_type, const void *A_value, int64_t ldA, tinytc_mem_type_t B_type, - const void *B_value, int64_t ldB, size_t beta_size, const void *beta_value, - tinytc_mem_type_t C_type, const void *C_value, int64_t ldC); - -/** - * @brief Get prog object - * - * @param recipe [in] recipe object - * @param prg [out] pointer to prog object; reference count is increased so the user needs to call - * tinytc_prog_release to clean up - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recipe, - tinytc_prog_t *prg); - -/** - * @brief Get source object - * - * @param recipe [in] recipe object - * @param src [out] pointer to source object; reference count is increased so the user needs to call - * tinytc_source_release to clean up - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_get_source(const_tinytc_recipe_t recipe, - tinytc_source_t *src); - -/** - * @brief Release recipe object - * - * Decreases reference count by 1, free memory if reference count is 0. 
- * - * @param obj [inout] recipe object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_release(tinytc_recipe_t obj); - -/** - * @brief Increase reference count of recipe object by 1 - * - * @param obj [inout] recipe object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_retain(tinytc_recipe_t obj); - -/** - * @brief Get recipe object - * - * @param handler [in] recipe handler object - * @param recipe [out] pointer to recipe object; reference count is increased so the user needs to - * call tinytc_recipe_release to clean up - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t -tinytc_recipe_handler_get_recipe(const_tinytc_recipe_handler_t handler, tinytc_recipe_t *recipe); - -/** - * @brief Release recipe handler object - * - * Decreases reference count by 1, free memory if reference count is 0. 
- * - * @param obj [inout] recipe handler object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_handler_release(tinytc_recipe_handler_t obj); - -/** - * @brief Increase reference count of recipe handler object by 1 - * - * @param obj [inout] recipe handler object - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_recipe_handler_retain(tinytc_recipe_handler_t obj); - -#ifdef __cplusplus -} -#endif - -#endif // TINYTC_20240409_H +#endif // TINYTC_20250704_H diff --git a/include/tinytc/tinytc.hpp b/include/tinytc/tinytc.hpp index 99b19b41..85ffa43b 100644 --- a/include/tinytc/tinytc.hpp +++ b/include/tinytc/tinytc.hpp @@ -1,2257 +1,13 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#ifndef TINYTC_20240403_HPP -#define TINYTC_20240403_HPP +#ifndef TINYTC_20250704_HPP +#define TINYTC_20250704_HPP -#include "tinytc/tinytc.h" +#include "tinytc/builder.hpp" +#include "tinytc/core.hpp" +#include "tinytc/export.h" #include "tinytc/types.hpp" +#include "tinytc/version.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tinytc { - -//////////////////////////// -/////////// Error ////////// -//////////////////////////// - -//! Convert error code to string -inline char const *error_string(status code) { - return ::tinytc_error_string(static_cast<::tinytc_status_t>(code)); -} - -//! Throw exception for unsuccessful call to C-API -inline void CHECK_STATUS(tinytc_status_t code) { - if (code != tinytc_status_success) { - throw status{std::underlying_type_t(code)}; - } -} - -//! Builder exception enhanced with location -class builder_error : public std::exception { - public: - //! ctor; taking location and status code - inline builder_error(status code, location const &loc) : code_(code), loc_(loc) {} - //! 
Get status code - inline auto code() const noexcept { return code_; } - //! Get location - inline auto loc() const noexcept -> location const & { return loc_; } - //! Get explanatory string - inline char const *what() const noexcept override { return error_string(code_); } - - private: - status code_; - location loc_; -}; - -//! Throw exception for unsuccessful call to C-API -inline void CHECK_STATUS_LOC(tinytc_status_t code, location const &loc) { - if (code != tinytc_status_success) { - throw builder_error(status{std::underlying_type_t(code)}, loc); - } -} - -//////////////////////////// -//////// Scalar type /////// -//////////////////////////// - -//! Convert scalar type to string -inline char const *to_string(scalar_type ty) { - return ::tinytc_scalar_type_to_string(static_cast(ty)); -} -//! Size of scalar type in bytes -inline std::size_t size(scalar_type ty) { - return ::tinytc_scalar_type_size(static_cast(ty)); -} - -/** - * Returns the scalar type corresponding to C++ type T - * - * Specializations exist for bool, (u)int8_t, (u)int16_t, (u)int32_t, (u)int64_t, float, double. - * The scalar_type is stored in the static constexpr member "value". - */ -template struct to_scalar_type; -//! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::i1; ///< value -}; -//! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::i8; ///< value -}; -//! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::i16; ///< value -}; -//! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::i32; ///< value -}; -//! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::i64; ///< value -}; -//! 
to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::f32; ///< value -}; -//! to_scalar_type specialization -template <> struct to_scalar_type { - static constexpr scalar_type value = scalar_type::f64; ///< value -}; -/** - * Convenience variable for to_scalar_type. - * - * Example: @code scalar_type ty = to_scalar_type_v; @endcode - */ -template inline constexpr scalar_type to_scalar_type_v = to_scalar_type::value; - -//////////////////////////// -// Shared / unique handle // -//////////////////////////// - -namespace internal { -//! Wraps retain / release calls for type T -template struct shared_handle_traits {}; - -//! Wraps destroy calls for type T -template struct unique_handle_traits {}; -} // namespace internal - -/** - * @brief Wraps a C handle in a reference-counted object - * - * @tparam T C handle type (handle type = pointer to opaque struct) - */ -template class shared_handle { - public: - //! Traits shortcut - using traits = internal::shared_handle_traits; - //! Typedef for native C handle - using native_type = T; - - //! Create empty (invalid) handle - shared_handle() : obj_{nullptr} {} - //! Create handle from C handle - explicit shared_handle(T obj, bool needs_retain = false) : obj_(obj) { - if (needs_retain) { - CHECK_STATUS(c_retain()); - } - } - //! Decrease reference count - ~shared_handle() { c_release(); } - //! Copy ctor - shared_handle(shared_handle const &other) : obj_(other.obj_) { CHECK_STATUS(c_retain()); } - //! Move ctor - shared_handle(shared_handle &&other) noexcept : obj_(other.obj_) { other.obj_ = nullptr; } - //! Copy operator - shared_handle &operator=(shared_handle const &other) { - if (obj_ != other.obj_) { - CHECK_STATUS(c_release()); - obj_ = other.obj_; - CHECK_STATUS(c_retain()); - } - return *this; - } - //! 
Move operator - shared_handle &operator=(shared_handle &&other) { - if (obj_ != other.obj_) { - CHECK_STATUS(c_release()); - obj_ = other.obj_; - other.obj_ = nullptr; - } - return *this; - } - - //! Dereference C handle and get reference to underlying type - auto operator*() const -> std::remove_pointer_t & { return *obj_; } - //! Convert handle to C handle - auto operator->() const -> T { return obj_; } - //! Returns C handle - auto get() const -> T { return obj_; } - //! Returns C handle and releases the ownership of the managed object - auto release() -> T { - auto tmp = obj_; - obj_ = nullptr; - return tmp; - } - - //! Check whether handle is non-empty (valid) - explicit operator bool() const noexcept { return obj_ != nullptr; } - - //! Check equality - bool operator==(shared_handle const &other) const { return obj_ == other.obj_; } - //! Check inequality - bool operator!=(shared_handle const &other) const { return !(*this == other); } - - protected: - //! Call retain in C-API if C handle is not NULL - auto c_retain() -> tinytc_status_t { - if (obj_ != nullptr) { - return traits::retain(obj_); - } - return tinytc_status_success; - } - //! Call release in C-API if C handle is not NULL - auto c_release() -> tinytc_status_t { - if (obj_ != nullptr) { - return traits::release(obj_); - } - return tinytc_status_success; - } - //! The C handle - T obj_; -}; - -/** - * @brief Wraps a C handle in a unique_ptr-alike object - * - * @tparam T C handle type (handle type = pointer to opaque struct) - */ -template class unique_handle { - public: - //! Traits shortcut - using traits = internal::unique_handle_traits; - //! Typedef for native C handle - using native_type = T; - - //! Create empty (invalid) handle - unique_handle() : obj_{nullptr} {} - //! Create handle from C handle - explicit unique_handle(T obj) : obj_(obj) {} - //! Destroy object - ~unique_handle() { - if (obj_) { - traits::destroy(obj_); - } - } - //! 
Copy ctor - unique_handle(unique_handle const &other) = delete; - //! Move ctor - unique_handle(unique_handle &&other) noexcept : obj_(other.obj_) { other.obj_ = nullptr; } - //! Copy operator - unique_handle &operator=(unique_handle const &other) = delete; - //! Move operator - unique_handle &operator=(unique_handle &&other) { - obj_ = other.obj_; - other.obj_ = nullptr; - return *this; - } - - //! Dereference C handle and get reference to underlying type - auto operator*() const -> std::remove_pointer_t & { return *obj_; } - //! Convert handle to C handle - auto operator->() const -> T { return obj_; } - //! Returns C handle - auto get() const -> T { return obj_; } - //! Returns C handle and releases the ownership of the managed object - auto release() -> T { - auto tmp = obj_; - obj_ = nullptr; - return tmp; - } - - //! Check whether handle is non-empty (valid) - explicit operator bool() const noexcept { return obj_ != nullptr; } - - //! Check equality - bool operator==(unique_handle const &other) const { return obj_ == other.obj_; } - //! Check inequality - bool operator!=(unique_handle const &other) const { return !(*this == other); } - - protected: - //! The C handle - T obj_; -}; - -//////////////////////////// -///////// Data type //////// -//////////////////////////// - -//! Check if mode i is dynamic ('?') -inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_data_type_t handle) -> tinytc_status_t { - return tinytc_data_type_retain(handle); - } - static auto release(tinytc_data_type_t handle) -> tinytc_status_t { - return tinytc_data_type_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_data_type_t -class data_type : public shared_handle { - public: - using shared_handle::shared_handle; -}; - -/** - * @brief Make a scalar data type - * - * Cf. 
\ref tinytc_scalar_type_create - * - * @param scalar_ty Scalar type - * @param loc Source code location - * - * @return Data type - */ -inline data_type make_scalar(scalar_type scalar_ty, location const &loc = {}) { - tinytc_data_type_t st; - CHECK_STATUS_LOC( - tinytc_scalar_type_create(&st, static_cast(scalar_ty), &loc), loc); - return data_type{st}; -} - -/** - * @brief Make a memref data type - * - * Cf. \ref tinytc_memref_type_create - * - * @param scalar_ty Element type - * @param shape Tensor shape - * @param stride Tensor stride - * @param loc Source code location - * - * @return Data type - */ -inline data_type make_memref(scalar_type scalar_ty, std::vector const &shape, - std::vector const &stride = {}, - location const &loc = {}) { - tinytc_data_type_t mt; - CHECK_STATUS_LOC(tinytc_memref_type_create(&mt, static_cast(scalar_ty), - shape.size(), shape.data(), stride.size(), - stride.data(), &loc), - loc); - return data_type{mt}; -} - -/** - * @brief Make a group data type - * - * @param memref_ty Memref data type - * @param offset Offset parameter - * @param loc Source code location - * - * @return Data type - */ -inline data_type make_group(data_type const &memref_ty, std::int64_t offset = 0, - location const &loc = {}) { - tinytc_data_type_t gt; - CHECK_STATUS_LOC(tinytc_group_type_create(>, memref_ty.get(), offset, &loc), loc); - return data_type{gt}; -} - -//////////////////////////// -/////////// Value ////////// -//////////////////////////// - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_value_t handle) -> tinytc_status_t { - return tinytc_value_retain(handle); - } - static auto release(tinytc_value_t handle) -> tinytc_status_t { - return tinytc_value_release(handle); - } -}; -} // namespace internal - -//! 
@brief Reference-counting wrapper for tinytc_value_t -class value : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get name - * - * @return Name - */ - inline auto get_name() const -> char const * { - char const *name; - CHECK_STATUS(tinytc_value_get_name(obj_, &name)); - return name; - } - /** - * @brief Set name - * - * @param name Name - */ - inline void name(std::string const &name) { - CHECK_STATUS(tinytc_value_set_name(obj_, name.c_str())); - } -}; - -namespace internal { -//! Is reinterpret_cast(&v) allowed, where v has type value -constexpr bool value_reinterpret_allowed = - std::is_standard_layout_v && sizeof(value) == sizeof(tinytc_value_t); -} // namespace internal - -/** - * @brief Make value - * - * @param ty Data type - * @param loc Source code location - * - * @return Value - */ -inline auto make_value(data_type const &ty, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_value_create(&val, ty.get(), &loc), loc); - return value{val}; -} - -/** - * @brief Make value - * - * @param scalar_ty Scalar type - * @param loc Source code location - * - * @return Value - */ -inline auto make_value(scalar_type scalar_ty, location const &loc = {}) -> value { - tinytc_value_t val; - auto ty = make_scalar(scalar_ty, loc); - CHECK_STATUS_LOC(tinytc_value_create(&val, ty.get(), &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is f32. 
- * - * @param imm Float value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(float imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_float_imm_create(&val, imm, tinytc_scalar_type_f32, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * @param imm Float value - * @param type Type of immediate value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(double imm, scalar_type type = scalar_type::f64, location const &loc = {}) - -> value { - tinytc_value_t val; - CHECK_STATUS_LOC( - tinytc_float_imm_create(&val, imm, static_cast(type), &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is i8. - * - * @param imm Int value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int8_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i8, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is i16. - * - * @param imm Int value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int16_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i16, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * Type is i32. 
- * - * @param imm Int value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int32_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_i32, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate value - * - * @param imm Int value - * @param type Type of immediate value - * @param loc Source code location - * - * @return Value - */ -inline auto make_imm(std::int64_t imm, scalar_type type = scalar_type::i64, - location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC( - tinytc_int_imm_create(&val, imm, static_cast(type), &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate index value - * - * @param imm index value - * @param loc Source code location - * - * @return Value - */ -inline auto make_index(std::int32_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_index, &loc), loc); - return value{val}; -} - -/** - * @brief Make immediate index value - * - * @param imm index value - * @param loc Source code location - * - * @return Value - */ -inline auto make_index(std::int64_t imm, location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, imm, tinytc_scalar_type_index, &loc), loc); - return value{val}; -} - -/** - * @brief Make dynamic ('?') - * - * @param loc Source code location - * - * @return Value - */ -inline auto make_dynamic(location const &loc = {}) -> value { - tinytc_value_t val; - CHECK_STATUS_LOC(tinytc_int_imm_create(&val, dynamic, tinytc_scalar_type_i64, &loc), loc); - return value{val}; -} - -//////////////////////////// -/////////// Inst /////////// -//////////////////////////// - -/** - * @brief Convert arithmetic operation type to string - * - * @param op Arithmetic operation type - * - * @return C-string - */ -inline char const 
*to_string(arithmetic op) { - return ::tinytc_arithmetic_to_string(static_cast<::tinytc_arithmetic_t>(op)); -} - -/** - * @brief Convert arithmetic operation type to string (unary) - * - * @param op Arithmetic operation type - * - * @return C-string - */ -inline char const *to_string(arithmetic_unary op) { - return ::tinytc_arithmetic_unary_to_string(static_cast<::tinytc_arithmetic_unary_t>(op)); -} - -/** - * @brief Convert cmp condition to string - * - * @param cond Condition - * - * @return C-string - */ -inline char const *to_string(cmp_condition cond) { - return ::tinytc_cmp_condition_to_string(static_cast<::tinytc_cmp_condition_t>(cond)); -} - -/** - * @brief Convert transpose to string - * - * @param t Transpose - * - * @return C-string - */ -inline char const *to_string(transpose t) { - return ::tinytc_transpose_to_string(static_cast(t)); -} - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_inst_t handle) -> tinytc_status_t { - return tinytc_inst_retain(handle); - } - static auto release(tinytc_inst_t handle) -> tinytc_status_t { - return tinytc_inst_release(handle); - } -}; -} // namespace internal - -//! 
@brief Reference-counting wrapper for tinytc_inst_t -class inst : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get result value - * - * @return Value; may be empty - */ - inline auto get_value() const -> value { - tinytc_value_t result; - CHECK_STATUS(tinytc_inst_get_value(obj_, &result)); - return value{result}; - } - - /** - * @brief Get result values - * - * @return Vector of values - */ - inline auto get_values() const -> std::vector { - static_assert(internal::value_reinterpret_allowed); - std::uint32_t result_list_size = 0; - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, nullptr)); - auto values = std::vector(result_list_size); - tinytc_value_t *result_list = reinterpret_cast(values.data()); - CHECK_STATUS(tinytc_inst_get_values(obj_, &result_list_size, result_list)); - return values; - } -}; - -namespace internal { -//! Is reinterpret_cast(&i) allowed, where i has type inst -constexpr bool inst_reinterpret_allowed = - std::is_standard_layout_v && sizeof(inst) == sizeof(tinytc_inst_t); -} // namespace internal - -//////////////////////////// -////////// Region ////////// -//////////////////////////// - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_region_t handle) -> tinytc_status_t { - return tinytc_region_retain(handle); - } - static auto release(tinytc_region_t handle) -> tinytc_status_t { - return tinytc_region_release(handle); - } -}; -} // namespace internal - -//! 
@brief Reference-counting wrapper for tinytc_region_t -class region : public shared_handle { - public: - using shared_handle::shared_handle; -}; - -/** - * @brief Make region - * - * @param instructions Vector of instructions - * @param loc Source code location - * - * @return Region - */ -inline region make_region(std::vector &instructions, location const &loc = {}) { - tinytc_region_t reg; - static_assert(internal::inst_reinterpret_allowed); - if (instructions.size() > std::numeric_limits::max()) { - throw std::out_of_range("Instruction list too long"); - } - CHECK_STATUS_LOC(tinytc_region_create(®, instructions.size(), - reinterpret_cast(instructions.data()), - &loc), - loc); - return region{reg}; -} - -//////////////////////////// -/////// Instructions /////// -//////////////////////////// - -/** - * @brief Make arithmetic instruction (binary) - * - * @param op Arithmetic operation type - * @param a First operand - * @param b Second operand - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_arith(arithmetic op, value const &a, value const &b, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_arith_inst_create(&instr, static_cast(op), a.get(), - b.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make arithmetic instruction (unary) - * - * @param op Arithmetic operation type - * @param a Operand - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_arith(arithmetic_unary op, value const &a, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_arith_unary_inst_create( - &instr, static_cast(op), a.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make cast instruction - * - * @param a Operand - * @param to_ty Target type - * @param loc Source code lcoation - * - * @return Instruction - */ -inline inst make_cast(value const &a, scalar_type to_ty, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC( - 
tinytc_cast_inst_create(&instr, a.get(), static_cast(to_ty), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make compare instruction - * - * @param cond Condition type - * @param a First operand - * @param b Second operand - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_cmp(cmp_condition cond, value const &a, value const &b, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_cmp_inst_create(&instr, static_cast(cond), - a.get(), b.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make alloca instruction - * - * @param ty Memref type of allocated variable - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_alloca(data_type const &ty, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_alloca_inst_create(&instr, ty.get(), &loc), loc); - return inst(instr); -} - -/** - * @brief Make AXPBY instruction - * - * @param tA Operation applied on A - * @param atomic true for atomic updates of B - * @param alpha @f$\alpha@f$ - * @param A A - * @param beta @f$\beta@f$ - * @param B B - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_axpby(transpose tA, bool atomic, value const &alpha, value const &A, - value const &beta, value const &B, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_axpby_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), beta.get(), B.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make expand instruction - * - * @param a Operand - * @param mode Expanded mode - * @param expand_shape New shape of mode - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_expand(value const &a, std::int64_t mode, std::vector const &expand_shape, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); - tinytc_inst_t instr; - auto len = expand_shape.size(); - if (len > 
std::numeric_limits::max()) { - throw std::out_of_range("expand shape too large"); - } - tinytc_value_t *eshape = - const_cast(reinterpret_cast(expand_shape.data())); - CHECK_STATUS_LOC(tinytc_expand_inst_create(&instr, a.get(), mode, len, eshape, &loc), loc); - return inst(instr); -} - -/** - * @brief Make fuse instruciton - * - * @param a Operand - * @param from First mode to fuse - * @param to Last mode to fuse - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_fuse(value const &a, std::int64_t from, std::int64_t to, - location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_fuse_inst_create(&instr, a.get(), from, to, &loc), loc); - return inst(instr); -} - -/** - * @brief Make load instruction - * - * @param a Operand - * @param index_list Vector of indices - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_load(value const &a, std::vector const &index_list, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); - tinytc_inst_t instr; - auto len = index_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("index list too long"); - } - tinytc_value_t *il = - const_cast(reinterpret_cast(index_list.data())); - CHECK_STATUS_LOC(tinytc_load_inst_create(&instr, a.get(), len, il, &loc), loc); - return inst(instr); -} - -/** - * @brief Make group id instruction - * - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_group_id(location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_id_inst_create(&instr, &loc), loc); - return inst(instr); -} - -/** - * @brief Make group size instruction - * - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_group_size(location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_group_size_inst_create(&instr, &loc), loc); - return inst(instr); -} - -/** - * @brief Make GEMM 
instruction - * - * @param tA Operation applied on A - * @param tB Operation applied on B - * @param atomic true for atomic updates of C - * @param alpha @f$\alpha@f$ - * @param A A - * @param B B - * @param beta @f$\beta@f$ - * @param C C - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_gemm(transpose tA, transpose tB, bool atomic, value const &alpha, value const &A, - value const &B, value const &beta, value const &C, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_gemm_inst_create(&instr, static_cast(tA), - static_cast(tB), atomic, - alpha.get(), A.get(), B.get(), beta.get(), C.get(), - &loc), - loc); - return inst(instr); -} - -/** - * @brief Make GEMV instruction - * - * @param tA Operation applied on A - * @param atomic true for atomic updates of C - * @param alpha @f$\alpha@f$ - * @param A A - * @param B B - * @param beta @f$\beta@f$ - * @param C C - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_gemv(transpose tA, bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_gemv_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), B.get(), beta.get(), C.get(), - &loc), - loc); - return inst(instr); -} - -/** - * @brief Make GER instruction - * - * @param atomic true for atomic updates of C - * @param alpha @f$\alpha@f$ - * @param A A - * @param B B - * @param beta @f$\beta@f$ - * @param C C - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_ger(bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_ger_inst_create(&instr, atomic, alpha.get(), A.get(), B.get(), - beta.get(), C.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make hadamard instruction 
- * - * @param atomic true for atomic updates of C - * @param alpha @f$\alpha@f$ - * @param A A - * @param B B - * @param beta @f$\beta@f$ - * @param C C - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_hadamard(bool atomic, value const &alpha, value const &A, value const &B, - value const &beta, value const &C, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_hadamard_inst_create(&instr, atomic, alpha.get(), A.get(), B.get(), - beta.get(), C.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make size instruction - * - * @param a Operand - * @param mode Mode - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_size(value const &a, std::int64_t mode, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_size_inst_create(&instr, a.get(), mode, &loc), loc); - return inst(instr); -} - -/** - * @brief Make subview instruction - * - * @param a Operand - * @param offset_list Vector of offsets - * @param size_list Vector of sizes; initialize with empty value if only offset is required - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_subview(value const &a, std::vector const &offset_list, - std::vector const &size_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); - tinytc_inst_t instr; - if (offset_list.size() != size_list.size()) { - throw std::invalid_argument("offset list must have the same length as the size list"); - } - auto len = offset_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("slice list too long"); - } - tinytc_value_t *ol = - const_cast(reinterpret_cast(offset_list.data())); - tinytc_value_t *sl = - const_cast(reinterpret_cast(size_list.data())); - CHECK_STATUS_LOC(tinytc_subview_inst_create(&instr, a.get(), len, ol, sl, &loc), loc); - return inst(instr); -} - -/** - * @brief Make store instruction - * - * @param 
val Value that is stored - * @param a Target memref - * @param index_list Vector of indices - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_store(value const &val, value const &a, std::vector const &index_list, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); - tinytc_inst_t instr; - auto len = index_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("index list too long"); - } - tinytc_value_t *il = - const_cast(reinterpret_cast(index_list.data())); - CHECK_STATUS_LOC(tinytc_store_inst_create(&instr, val.get(), a.get(), len, il, &loc), loc); - return inst(instr); -} - -/** - * @brief Make sum instruction - * - * @param tA Operation applied on A - * @param atomic true for atomic updates of B - * @param alpha @f$\alpha@f$ - * @param A A - * @param beta @f$\beta@f$ - * @param B B - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_sum(transpose tA, bool atomic, value const &alpha, value const &A, - value const &beta, value const &B, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_sum_inst_create(&instr, static_cast(tA), atomic, - alpha.get(), A.get(), beta.get(), B.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make for loop instruction - * - * @param loop_var Loop variable - * @param from Loop variable start - * @param to Loop variable bound - * @param step Loop variable step - * @param body Loop body - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_for(value const &loop_var, value const &from, value const &to, value const &step, - region const &body, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC(tinytc_for_inst_create(&instr, loop_var.get(), from.get(), to.get(), - step.get(), body.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make foreach loop instruction - * - * @param loop_var Loop variable - * 
@param from Loop variable start - * @param to Loop variable bound - * @param body Loop body - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_foreach(value const &loop_var, value const &from, value const &to, - region const &body, location const &loc = {}) { - tinytc_inst_t instr; - CHECK_STATUS_LOC( - tinytc_foreach_inst_create(&instr, loop_var.get(), from.get(), to.get(), body.get(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make if condition instruction - * - * @param condition Condition value (of type bool) - * @param then Then region - * @param otherwise Else region - * @param return_type_list Types of returned values - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_if(value const &condition, region const &then, region const &otherwise = region{}, - std::vector const &return_type_list = {}, - location const &loc = {}) { - tinytc_inst_t instr; - auto len = return_type_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("return type list too long"); - } - auto rl_vec = std::vector(); - rl_vec.resize(len); - for (auto const &rt : return_type_list) { - rl_vec.emplace_back(static_cast(rt)); - } - CHECK_STATUS_LOC(tinytc_if_inst_create(&instr, condition.get(), then.get(), otherwise.get(), - len, rl_vec.data(), &loc), - loc); - return inst(instr); -} - -/** - * @brief Make yield instruction - * - * @param yield_list Yielded values - * @param loc Source code location - * - * @return Instruction - */ -inline inst make_yield(std::vector const &yield_list, location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); - tinytc_inst_t instr; - auto len = yield_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("slice list too long"); - } - tinytc_value_t *yl = - const_cast(reinterpret_cast(yield_list.data())); - CHECK_STATUS_LOC(tinytc_yield_inst_create(&instr, len, yl, &loc), loc); - return inst(instr); 
-} - -//////////////////////////// -/////////// Func /////////// -//////////////////////////// - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_func_t handle) -> tinytc_status_t { - return tinytc_func_retain(handle); - } - static auto release(tinytc_func_t handle) -> tinytc_status_t { - return tinytc_func_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_func_t -class func : public shared_handle { - public: - using shared_handle::shared_handle; -}; - -namespace internal { -//! Is reinterpret_cast(&f) allowed, where f has type func -constexpr bool func_reinterpret_allowed = - std::is_standard_layout_v && sizeof(func) == sizeof(tinytc_func_t); -} // namespace internal - -/** - * @brief Make function prototype - * - * @param name Function name - * @param arg_list Argument list - * @param loc Source code location - * - * @return Function - */ -inline func make_function_prototype(char const *name, std::vector &arg_list, - location const &loc = {}) { - static_assert(internal::value_reinterpret_allowed); - tinytc_func_t fun; - auto len = arg_list.size(); - if (len > std::numeric_limits::max()) { - throw std::out_of_range("argument list too long"); - } - tinytc_value_t *al = reinterpret_cast(arg_list.data()); - CHECK_STATUS_LOC(tinytc_function_prototype_create(&fun, name, len, al, &loc), loc); - return func(fun); -} - -/** - * @brief Make function - * - * @param prototype Function prototype - * @param body Function body - * @param loc Source code location - * - * @return Function - */ -inline func make_function(func const &prototype, region const &body, location const &loc = {}) { - tinytc_func_t fun; - CHECK_STATUS_LOC(tinytc_function_create(&fun, prototype.get(), body.get(), &loc), loc); - return func(fun); -} - -/** - * @brief Set work-group size (x,y) - * - * @param fun Function object; must have been created with "make_function" - * @param x x - * @param y y - */ -inline 
void set_work_group_size(func &fun, std::int32_t x, std::int32_t y) { - CHECK_STATUS(tinytc_function_set_work_group_size(fun.get(), x, y)); -} - -/** - * @brief Set subgroup size - * - * @param fun Function object; must have been created with "make_function" - * @param sgs Subgroup size - */ -inline void set_subgroup_size(func &fun, std::int32_t sgs) { - CHECK_STATUS(tinytc_function_set_subgroup_size(fun.get(), sgs)); -} - -//////////////////////////// -/////////// Prog /////////// -//////////////////////////// - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_prog_t handle) -> tinytc_status_t { - return tinytc_prog_retain(handle); - } - static auto release(tinytc_prog_t handle) -> tinytc_status_t { - return tinytc_prog_release(handle); - } -}; -template <> struct unique_handle_traits { - static void destroy(char *obj) { tinytc_string_destroy(obj); } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_prog_t -class prog : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Dump program to stderr - */ - void dump() const { CHECK_STATUS(tinytc_prog_dump(obj_)); } - /** - * @brief Dump program to file - * - * @param filename Path to file - */ - void print_to_file(char const *filename) const { - CHECK_STATUS(tinytc_prog_print_to_file(obj_, filename)); - } - /** - * @brief Dump program to string - * - * @return C-string (unique handle) - */ - auto print_to_string() const -> unique_handle { - char *str; - CHECK_STATUS(tinytc_prog_print_to_string(obj_, &str)); - return unique_handle{str}; - } -}; - -/** - * @brief Make program - * - * @param fun_list Vector of functions - * @param loc Source code location - * - * @return Program - */ -inline prog make_program(std::vector &fun_list, location const &loc = {}) { - tinytc_prog_t prg; - static_assert(internal::func_reinterpret_allowed); - auto len = fun_list.size(); - if (len > std::numeric_limits::max()) { - throw 
std::out_of_range("function list too long"); - } - tinytc_func_t *fl = reinterpret_cast(fun_list.data()); - CHECK_STATUS_LOC(tinytc_program_create(&prg, len, fl, &loc), loc); - return prog{prg}; -} - -//////////////////////////// -////////// Builder ///////// -//////////////////////////// - -//! Builder for regions -class region_builder { - public: - /** - * @brief Returns built product - * - * @param loc Source code location - * - * @return Region - */ - inline auto get_product(location const &loc = {}) -> region { - return make_region(instructions_, loc); - } - - /** - * @brief Add instruction - * - * @param i Instruction - * @param name Result name - * - * @return Value returned by instruction; may be empty - */ - [[maybe_unused]] inline auto add(inst i, std::string const &name = "") -> value { - auto result = i.get_value(); - if (result && name.size() > 0) { - result.name(name); - } - instructions_.emplace_back(std::move(i)); - return result; - } - - /** - * @brief Add instruction that returns multiple values - * - * @param i Instruction - * @param name Result name - * - * @return Values returned by instruction - */ - [[maybe_unused]] inline auto add_multivalued(inst i, std::string const &name = "") - -> std::vector { - auto results = i.get_values(); - if (name.size() > 0) { - int counter = 0; - for (auto &result : results) { - result.name(name + std::to_string(counter++)); - } - } - instructions_.emplace_back(std::move(i)); - return results; - } - - /** - * @brief Build for-loop with functor f(region_builder&) -> void - * - * @tparam F Functor type - * @param loop_var_ty Type of loop variable - * @param from Loop variable start - * @param to Loop variable bound - * @param f Functor - * @param name Loop variable name - * @param loc Source code location - */ - template - void for_loop(scalar_type loop_var_ty, value const &from, value const &to, F &&f, - std::string const &name = "", location const &loc = {}) { - for_loop(std::move(loop_var_ty), std::move(from), 
std::move(to), value{nullptr}, - std::forward(f), name, loc); - } - /** - * @brief Build for-loop with functor f(region_builder&) -> void - * - * @tparam F Functor type - * @param loop_var_ty Type of loop variable - * @param from Loop variable start - * @param to Loop variable bound - * @param step Loop variable step - * @param f Functor - * @param name Loop variable name - * @param loc Source code location - */ - template - void for_loop(scalar_type loop_var_ty, value const &from, value const &to, value const &step, - F &&f, std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(loop_var_ty); - if (name.size() > 0) { - loop_var.name(name); - } - auto bb = region_builder{}; - f(bb); - add(::tinytc::make_for(std::move(loop_var), from, to, step, bb.get_product(), loc)); - } - /** - * @brief Build foreach-loop with functor f(region_builder&) -> void - * - * @tparam F Functor type - * @param loop_var_ty Type of loop variable - * @param from Loop variable start - * @param to Loop variable bound - * @param f functor - * @param name Loop variable name - * @param loc Source code location - */ - template - void foreach (data_type const &loop_var_ty, value const &from, value const &to, F && f, - std::string const &name = "", location const &loc = {}) { - auto loop_var = make_value(loop_var_ty); - if (name.size() > 0) { - loop_var.name(name); - } - auto bb = region_builder{}; - f(bb); - add(::tinytc::make_foreach(std::move(loop_var), from, to, bb.get_product(), loc)); - } - - /** - * @brief Build if with functor then(region_builder&) -> void - * - * @tparam F Functor type - * @param condition Condition value - * @param then Then region functor - * @param return_type_list Types of returned values - * @param loc Source code location - * - * @return Returned values - */ - template - auto if_condition(value const &condition, F &&then, - std::vector const &return_type_list = {}, - location const &loc = {}) -> std::vector { - auto bb = 
region_builder{}; - then(bb); - return add_multivalued(::tinytc::make_if(std::move(condition), bb.get_product(), region{}, - return_type_list, loc)); - } - /** - * @brief Build if/else with functors then(region_builder&) -> void and - * otherwise(region_builder&) -> void - * - * @tparam F "if" functor type - * @tparam G "else" functor type - * @param condition If condition - * @param then "if" functor - * @param otherwise "else" functor - * @param return_type_list List of types of returned values - * @param loc Source code location - * - * @return Returned values - */ - template - auto ifelse(value const &condition, F &&then, G &&otherwise, - std::vector const &return_type_list = {}, location const &loc = {}) - -> std::vector { - auto bb1 = region_builder{}; - then(bb1); - auto bb2 = region_builder{}; - otherwise(bb2); - return add_multivalued(::tinytc::make_if(std::move(condition), bb1.get_product(), - bb2.get_product(), return_type_list, loc)); - } - - private: - std::vector instructions_; -}; - -//! 
Builder for functions -class function_builder { - public: - /** - * @brief creates function \@name - * - * @param name Function name - */ - inline function_builder(std::string name) : name_(std::move(name)), body_{nullptr} {} - - /** - * @brief Returns built product - * - * @param loc Source code location - * - * @return Function - */ - inline func get_product(location const &loc = {}) { - auto proto = make_function_prototype(name_.c_str(), arguments_, loc); - auto fun = make_function(proto, body_); - if (x_ > 0 && y_ > 0) { - set_work_group_size(fun, x_, y_); - } - if (sgs_ > 0) { - set_subgroup_size(fun, sgs_); - } - return fun; - } - - /** - * @brief @code %name: %ty @endcode - * - * @param ty Argument type - * @param name Argument name - * @param loc Source code location - * - * @return Value - */ - inline value argument(data_type const &ty, std::string const &name = "", - location const &loc = {}) { - auto v = make_value(ty, loc); - if (name.size() > 0) { - v.name(name); - } - arguments_.emplace_back(std::move(v)); - return arguments_.back(); - } - - /** - * @brief @code work_group_size(%x, %y) @endcode - * - * @param x x - * @param y y - */ - inline void work_group_size(std::int32_t x, std::int32_t y) { - x_ = x; - y_ = y; - } - /** - * @brief @code subgroup_size(%subgroup_size) @endcode - * - * @param subgroup_size Subgroup size - */ - inline void subgroup_size(std::int32_t subgroup_size) { sgs_ = subgroup_size; } - - /** - * @brief Build function body with functor f(region_builder&) -> void - * - * @tparam F Functor type - * @param f Functor - * @param loc Source code location - */ - template void body(F &&f, location const &loc = {}) { - auto bb = region_builder{}; - f(bb); - body_ = bb.get_product(loc); - } - - private: - std::string name_; - region body_; - std::vector arguments_; - std::int32_t x_ = 0, y_ = 0, sgs_ = 0; -}; - -//! 
Builder for programs -class program_builder { - public: - /** - * @brief create function \@name with functor f(function_builder&) -> void - * - * @tparam F Functor type - * @param name Function name - * @param f Functor - * @param loc Source code location - */ - template void create(std::string name, F &&f, location const &loc = {}) { - auto fb = function_builder(std::move(name)); - f(fb); - add(fb.get_product(loc)); - } - /** - * @brief Add function - * - * @param f function - */ - inline void add(func f) { functions_.emplace_back(std::move(f)); } - /** - * @brief Returns built product - * - * @param loc Source code location - * - * @return Program - */ - inline prog get_product(location const &loc = {}) { return make_program(functions_, loc); } - - private: - std::vector functions_; -}; - -//////////////////////////// -//////// Device info /////// -//////////////////////////// - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_core_info_t handle) -> tinytc_status_t { - return tinytc_core_info_retain(handle); - } - static auto release(tinytc_core_info_t handle) -> tinytc_status_t { - return tinytc_core_info_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_core_info_t -class core_info : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get subgroup sizes - * - * Cf. 
@ref tinytc_core_info_get_subgroup_sizes - * - * @param sgs_size Pointer to size of subgroup size array - * @param sgs Pointer ot subgroup size array - */ - void get_subgroup_sizes(std::uint32_t *sgs_size, std::int32_t const **sgs) { - CHECK_STATUS(tinytc_core_info_get_subgroup_sizes(obj_, sgs_size, sgs)); - } - - /** - * @brief Get register space per subgroup in bytes - * - * @return Register space - */ - auto get_register_space() -> std::int32_t { - std::int32_t space; - CHECK_STATUS(tinytc_core_info_get_register_space(obj_, &space)); - return space; - } - - /** - * @brief Set core features - * - * @param flags set core features; must be 0 or a combination of tinytc_core_feature_flag_t - */ - void set_core_features(tinytc_core_feature_flags_t flags) { - CHECK_STATUS(tinytc_core_info_set_core_features(obj_, flags)); - } - - /** - * @brief Get core features - * - * @return Core features - */ - auto get_core_features() const -> tinytc_core_feature_flags_t { - tinytc_core_feature_flags_t flags; - CHECK_STATUS(tinytc_core_info_get_core_features(obj_, &flags)); - return flags; - } -}; - -/** - * @brief Create core info for generic GPUs manually - * - * @param register_space Size of register file per subgroup in bytes - * @param max_work_group_size Maximum size of local work group - * @param sgs Subgrouip sizes - * - * @return Core info - */ -inline auto make_core_info_generic(std::int32_t register_space, std::int32_t max_work_group_size, - std::vector sgs) -> core_info { - tinytc_core_info_t info; - CHECK_STATUS(tinytc_core_info_generic_create(&info, register_space, max_work_group_size, - sgs.size(), sgs.data())); - return core_info{info}; -} - -/** - * @brief Get core info for Intel GPUs from lookup table - * - * @param arch IP version - * - * @return Core info - */ -inline auto make_core_info_intel_from_arch(intel_gpu_architecture arch) -> core_info { - tinytc_core_info_t info; - CHECK_STATUS(tinytc_core_info_intel_create_from_arch( - &info, static_cast(arch))); - 
return core_info{info}; -} - -/** - * @brief Create core info for Intel GPUs manually - * - * @param ip_version IP version - * @param num_eus_per_subslice Number of EUs (XVEs) per subslice (XeCore) - * @param num_threads_per_eu Number of hardware threads per EU (XVE) - * @param sgs Subgrouip sizes - * - * @return Core info - */ -inline auto make_core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, - std::int32_t num_threads_per_eu, std::vector sgs) - -> core_info { - tinytc_core_info_t info; - CHECK_STATUS(tinytc_core_info_intel_create(&info, ip_version, num_eus_per_subslice, - num_threads_per_eu, sgs.size(), sgs.data())); - return core_info{info}; -} - -//////////////////////////// -////////// Parser ////////// -//////////////////////////// - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_source_context_t handle) -> tinytc_status_t { - return tinytc_source_context_retain(handle); - } - static auto release(tinytc_source_context_t handle) -> tinytc_status_t { - return tinytc_source_context_release(handle); - } -}; -} // namespace internal - -//! 
@brief Reference-counting wrapper for tinytc_source_context_t -class source_context : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Add source to context - * - * @param name File name - * @param text Source text - * - * @return Source id (should be set in position.source_id) - */ - inline auto add_source(char const *name, char const *text) -> std::int32_t { - std::int32_t source_id; - CHECK_STATUS(tinytc_source_context_add_source(obj_, name, text, &source_id)); - return source_id; - } - /** - * @brief Get error log - * - * @return C-string that is valid as long as source_context is not modified; empty string if - * source_context is empty - */ - inline auto get_error_log() const noexcept -> char const * { - if (obj_) { - char const *log; - // No need to call CHECK_STATUS, as the only possible error code is - // tinytc_status_invalid_arguments but we only pass valid arguments - tinytc_source_context_get_error_log(obj_, &log); - return log; - } - return ""; - } - /** - * @brief Enhance error message with source context; useful when builder is used - * - * @param loc Source location - * @param what Error description - * @param append True: append to error log; false: clear error log - */ - inline void report_error(location const &loc, char const *what, bool append = true) { - CHECK_STATUS(tinytc_source_context_report_error(obj_, &loc, what, - static_cast(append))); - } -}; - -/** - * @brief Create source context - * - * @return Source context - */ -inline auto make_source_context() -> source_context { - tinytc_source_context_t ctx; - CHECK_STATUS(tinytc_source_context_create(&ctx)); - return source_context{ctx}; -} - -/** - * @brief Parse source text from file - * - * @param filename Filename - * @param source_ctx Source context for improved error reporting - * - * @return Program - */ -inline auto parse_file(char const *filename, source_context source_ctx = {}) -> prog { - tinytc_prog_t prg; - 
CHECK_STATUS(tinytc_parse_file(&prg, filename, source_ctx.get())); - return prog(prg); -} - -/** - * @brief Parse source text from stdin - * - * @param source_ctx Source context for improved error reporting - * - * @return Program - */ -inline auto parse_stdin(source_context source_ctx = {}) -> prog { - tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_stdin(&prg, source_ctx.get())); - return prog(prg); -} -/** - * @brief Parse source text from string - * - * @param src Source text - * @param source_ctx Source context for improved error reporting - * - * @return Porgram - */ -inline auto parse_string(std::string const &src, source_context source_ctx = {}) -> prog { - tinytc_prog_t prg; - CHECK_STATUS(tinytc_parse_string(&prg, src.size(), src.c_str(), source_ctx.get())); - return prog(prg); -} - -//////////////////////////// -///////// Compiler ///////// -//////////////////////////// - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_source_t handle) -> tinytc_status_t { - return tinytc_source_retain(handle); - } - static auto release(tinytc_source_t handle) -> tinytc_status_t { - return tinytc_source_release(handle); - } -}; -} // namespace internal - -//! 
@brief Reference-counting wrapper for tinytc_source_t -class source : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get code - * - * @return Pointer to C-string that is bound to the lifetime of the source object - */ - inline auto get_code() const -> std::string_view { - char const *code = nullptr; - std::size_t length = 0; - CHECK_STATUS(tinytc_source_get_code(obj_, &length, &code)); - return std::string_view(code, length); - } - - /** - * @brief Get location - * - * @return Location - */ - inline auto get_location() const -> location { - location loc = {}; - CHECK_STATUS(tinytc_source_get_location(obj_, &loc)); - return loc; - } - - /** - * @brief Get OpenCL extension - * - * @param extensions_size Number of extensions - * @param extensions Array of extensions - */ - inline void get_extensions(std::uint32_t &extensions_size, - char const *const *&extensions) const { - CHECK_STATUS(tinytc_source_get_extensions(obj_, &extensions_size, &extensions)); - } -}; - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_binary_t handle) -> tinytc_status_t { - return tinytc_binary_retain(handle); - } - static auto release(tinytc_binary_t handle) -> tinytc_status_t { - return tinytc_binary_release(handle); - } -}; -} // namespace internal - -//! @brief Reference-counting wrapper for tinytc_binary_t -class binary : public shared_handle { - public: - using shared_handle::shared_handle; - - //! 
Container for raw data - struct raw { - bundle_format format; ///< Bundle format - std::size_t data_size; ///< Size of binary data in bytes - std::uint8_t const *data; ///< Pointer to binary data - }; - - /** - * @brief Get raw data - * - * @return Raw data - */ - inline auto get_raw() -> raw { - raw r; - tinytc_bundle_format_t f; - CHECK_STATUS(tinytc_binary_get_raw(obj_, &f, &r.data_size, &r.data)); - r.format = bundle_format{std::underlying_type_t(f)}; - return r; - } - /** - * @brief Get core features - * - * @return Core features - */ - inline auto get_core_features() -> tinytc_core_feature_flags_t { - tinytc_core_feature_flags_t cf; - CHECK_STATUS(tinytc_binary_get_core_features(obj_, &cf)); - return cf; - } -}; - -/** - * @brief Make binary - * - * @param format Bundle format (SPIR-V or Native) - * @param data_size Size of data in bytes - * @param data Binary data; data is copied - * @param core_features requested core features; must be 0 (default) or a combination of - * tinytc_core_feature_flag_t - * - * @return Binary - */ -inline auto make_binary(bundle_format format, std::size_t data_size, std::uint8_t const *data, - tinytc_core_feature_flags_t core_features) -> binary { - tinytc_binary_t bin; - CHECK_STATUS(tinytc_binary_create(&bin, static_cast(format), data_size, - data, core_features)); - return binary{bin}; -} - -/** - * @brief Compile program to OpenCL-C - * - * @param prg Program - * @param info Core info - * @param ctx Source context for improved error reporting - * - * @return Source - */ -inline auto compile_to_opencl(prog prg, core_info const &info, source_context ctx = {}) -> source { - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, prg.get(), info.get(), ctx.get())); - return source{src}; -} - -//////////////////////////// -////////// Recipe ////////// -//////////////////////////// - -/** - * @brief Guess memory type of memory object - * - * @tparam T memory object type - */ -template struct auto_mem_type; - -/** - 
* @brief True if T is either pointer to a fundamental type or a pointer to a pointer to a - * fundamental type - * - * @tparam T type - */ -template -constexpr bool usm_pointer_type = - std::is_pointer_v && - (std::is_fundamental_v> || - std::is_fundamental_v>>); - -/** - * @brief Specialize auto_mem_type for pointer to non-class types - * - * All pointers to scalars are assumed to be Unified Shared Memory pointers. - * (Automatic guessing for Shared Virtual Memory pointers not implemented.) - * - * @tparam T memory object type - */ -template struct auto_mem_type>> { - constexpr static mem_type value = mem_type::usm_pointer; ///< Pointer maps to USM pointer type -}; - -/** - * @brief Convenience wrapper for auto_mem_type - * - * @tparam T memory object type - */ -template inline constexpr auto auto_mem_type_v = auto_mem_type::value; - -//! Type-safe wrapper for memory objects -struct mem { - /** - * @brief ctor - * - * @tparam T pointer type or buffer type - * @param value USM / SVM pointer or cl_mem (cl_mem implicitly converts to void*) - * @param type memory object type - */ - template - inline mem(T const value, mem_type type = auto_mem_type_v) : value{value}, type{type} {} - - const void *value; ///< USM / SVM pointer or cl_mem (passed by value) - mem_type type; ///< Memory object type -}; - -namespace internal { -template <> struct shared_handle_traits { - static auto retain(tinytc_recipe_t handle) -> tinytc_status_t { - return tinytc_recipe_retain(handle); - } - static auto release(tinytc_recipe_t handle) -> tinytc_status_t { - return tinytc_recipe_release(handle); - } -}; -template <> struct shared_handle_traits { - static auto retain(tinytc_recipe_handler_t handle) -> tinytc_status_t { - return tinytc_recipe_handler_retain(handle); - } - static auto release(tinytc_recipe_handler_t handle) -> tinytc_status_t { - return tinytc_recipe_handler_release(handle); - } -}; -} // namespace internal - -//! 
@brief Reference-counting wrapper for tinytc_recipe_t -class recipe : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get program - * - * @return Program - */ - auto get_prog() const -> prog { - tinytc_prog_t prg; - CHECK_STATUS(tinytc_recipe_get_prog(obj_, &prg)); - return prog{prg}; - } - - /** - * @brief Get source - * - * @return Source - */ - auto get_source() const -> source { - tinytc_source_t src; - CHECK_STATUS(tinytc_recipe_get_source(obj_, &src)); - return source{src}; - } -}; - -//! @brief Reference-counting wrapper for tinytc_recipe_handler_t -class recipe_handler : public shared_handle { - public: - using shared_handle::shared_handle; - - /** - * @brief Get recipe - * - * @return Recipe - */ - auto get_recipe() const -> recipe { - tinytc_recipe_t rec; - CHECK_STATUS(tinytc_recipe_handler_get_recipe(obj_, &rec)); - return recipe{rec}; - } -}; - -//! @brief Reference-counting wrapper for tinytc_recipe_t -class small_gemm_batched : public recipe { - public: - using recipe::recipe; - - /** - * @brief Set kernel arguments - * - * @tparam T Scalar type; must match scalar_type passed to constructor - * @param handler Recipe handler - * @param howmany Batch size - * @param alpha @f$\alpha@f$ - * @param A Memory object used for A-matrix - * @param B Memory object used for B-matrix - * @param beta @f$\beta@f$ - * @param C Memory object used for C-matrix - */ - template - static void set_args(recipe_handler &handler, std::int64_t howmany, T alpha, mem A, mem B, - T beta, mem C) { - CHECK_STATUS(tinytc_recipe_small_gemm_batched_set_args( - handler.get(), howmany, sizeof(alpha), &alpha, static_cast(A.type), - A.value, static_cast(B.type), B.value, sizeof(beta), &beta, - static_cast(C.type), C.value)); - } -}; - -/** - * @brief Make small GEMM batched recipe - * - * Cf. 
@ref tinytc_recipe_small_gemm_batched_create - * - * @param info Core info - * @param ty Scalar type of @f$\alpha@f$, A, B, @f$\beta@f$, C - * @param tA Operation applied on A - * @param tB Operation applied on B - * @param M Number of rows of A and C - * @param N Number of columns of B and C - * @param K Number of columns of A, number of rows of B - * @param ldA Leading dimension of an A matrix - * @param strideA Stride of A-matrices - * @param ldB Leading dimension of an B matrix - * @param strideB Stride of B-matrices - * @param ldC Leading dimension of an C matrix - * @param strideC Stride of C-matrices - * @param ctx Source context for improved error reporting - * - * @return Small GEMM batched recipe - */ -inline auto make_small_gemm_batched(core_info const &info, scalar_type ty, transpose tA, - transpose tB, std::int64_t M, std::int64_t N, std::int64_t K, - std::int64_t ldA, std::int64_t strideA, std::int64_t ldB, - std::int64_t strideB, std::int64_t ldC, std::int64_t strideC, - source_context ctx = {}) -> small_gemm_batched { - tinytc_recipe_t rec; - CHECK_STATUS(tinytc_recipe_small_gemm_batched_create( - &rec, info.get(), static_cast(ty), - static_cast(tA), static_cast(tB), M, N, K, ldA, - strideA, ldB, strideB, ldC, strideC, ctx.get())); - return small_gemm_batched{rec}; -} - -//! 
@brief Reference-counting wrapper for tinytc_recipe_t -class tall_and_skinny : public recipe { - public: - using recipe::recipe; - - /** - * @brief Set kernel arguments - * - * @tparam T Scalar type; must match scalar_type passed to constructor - * @param handler Recipe handler - * @param M Number of rows of A and C - * @param alpha @f$\alpha@f$ - * @param A Memory object used for A-matrix - * @param ldA Leading dimension of A - * @param B Memory object used for B-matrix - * @param ldB Leading dimension of B - * @param beta @f$\beta@f$ - * @param C Memory object used for C-matrix - * @param ldC Leading dimension of C - */ - template - static void set_args(recipe_handler &handler, std::int64_t M, T alpha, mem A, std::int64_t ldA, - mem B, std::int64_t ldB, T beta, mem C, std::int64_t ldC) { - CHECK_STATUS(tinytc_recipe_tall_and_skinny_set_args( - handler.get(), M, sizeof(alpha), &alpha, static_cast(A.type), - A.value, ldA, static_cast(B.type), B.value, ldB, sizeof(beta), &beta, - static_cast(C.type), C.value, ldC)); - } -}; - -/** - * @brief Make tall and skinny recipe - * - * Cf. @ref tinytc_recipe_tall_and_skinny_create - * - * @param info Core info - * @param ty Scalar type of @f$\alpha@f$, A, B, @f$\beta@f$, C - * @param N Number of columns of B and C - * @param K Number of columns of A, number of rows of B - * @param M_block_size Chunk size for M-mode - * @param ctx Source context for improved error reporting - * - * @return Tall and skinny recipe - */ -inline auto make_tall_and_skinny(core_info const &info, scalar_type ty, std::int64_t N, - std::int64_t K, std::int32_t M_block_size = 0, - source_context ctx = {}) -> tall_and_skinny { - tinytc_recipe_t rec; - CHECK_STATUS(tinytc_recipe_tall_and_skinny_create( - &rec, info.get(), static_cast(ty), N, K, M_block_size, ctx.get())); - return tall_and_skinny{rec}; -} - -/** - * @brief Make tall and skinny recipe with additional specialization constants - * - * Cf. 
@ref tinytc_recipe_tall_and_skinny_create_specialized - * - * @param info Core info - * @param ty Scalar type of @f$\alpha@f$, A, B, @f$\beta@f$, C - * @param M Number of rows of A and C; can be dynamic - * @param N Number of columns of B and C - * @param K Number of columns of A, number of rows of B - * @param ldA Leading dimension of A; can be dynamic - * @param ldB Leading dimension of B; can be dynamic - * @param ldC Leading dimension of C; can be dynamic - * @param M_block_size Chunk size for M-mode - * @param ctx Source context for improved error reporting - * - * @return Tall and skinny recipe - */ -inline auto make_tall_and_skinny_specialized(core_info const &info, scalar_type ty, std::int64_t M, - std::int64_t N, std::int64_t K, std::int64_t ldA, - std::int64_t ldB, std::int64_t ldC, - std::int32_t M_block_size = 0, source_context ctx = {}) - -> tall_and_skinny { - tinytc_recipe_t rec; - CHECK_STATUS(tinytc_recipe_tall_and_skinny_create_specialized( - &rec, info.get(), static_cast(ty), M, N, K, ldA, ldB, ldC, - M_block_size, ctx.get())); - return tall_and_skinny{rec}; -} - -} // namespace tinytc - -#endif // TINYTC_20240403_HPP +#endif // TINYTC_20250704_HPP diff --git a/include/tinytc/tinytc_cl.h b/include/tinytc/tinytc_cl.h index 148958b2..402be1d9 100644 --- a/include/tinytc/tinytc_cl.h +++ b/include/tinytc/tinytc_cl.h @@ -17,13 +17,11 @@ extern "C" { /////////// Error ////////// //////////////////////////// -TINYTC_EXPORT tinytc_status_t tinytc_cl_convert_status(cl_int status); - #define TINYTC_CL_CHECK_STATUS(X) \ do { \ cl_int stat = X; \ if (stat != CL_SUCCESS) { \ - return tinytc_cl_convert_status(stat); \ + return tinytc_status_compute_runtime_error; \ } \ } while (0) @@ -57,22 +55,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *inf ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bundle [out] pointer to the kernel bundle (cl_program) 
object created - * @param context [in] context handle - * @param device [in] device handle - * @param src [in] source text and extensions - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source( - cl_program *bundle, cl_context context, cl_device_id device, const_tinytc_source_t src, - tinytc_source_context_t source_ctx); - /** * @brief Compile tensor program * @@ -82,14 +64,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_source( * @param prg [inout] tensor program; modified as compiler passes are run * @param core_features [in][optional] requested core features; must be 0 (default) or a combination * of tinytc_core_feature_flag_t - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( cl_program *bundle, cl_context context, cl_device_id device, tinytc_prog_t prg, - tinytc_core_feature_flags_t core_features, tinytc_source_context_t source_ctx); + tinytc_core_feature_flags_t core_features); /** * @brief Create an OpenCL program from a tinytc binary @@ -98,14 +78,13 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( * @param context [in] context handle * @param device [in] device handle * @param bin [in] binary object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary( - cl_program *bundle, 
cl_context context, cl_device_id device, const_tinytc_binary_t bin, - tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, + cl_context context, + cl_device_id device, + const_tinytc_binary_t bin); /** * @brief Get work group size for kernel @@ -121,11 +100,11 @@ TINYTC_EXPORT tinytc_status_t tinytc_cl_get_group_size(cl_kernel kernel, size_t /** * @brief Convert group size to opencl global range * - * @param howmany group size + * @param num_groups [in][range(0,3)] pointer to number of groups of size >= 3 * @param local_size [in][range(0,3)] pointer to local size array of size >= 3 * @param global_size [out][range(0,3)] pointer to global size array of size >= 3 */ -TINYTC_EXPORT void tinytc_cl_get_global_size(int64_t howmany, const size_t *local_size, +TINYTC_EXPORT void tinytc_cl_get_global_size(const size_t *num_groups, const size_t *local_size, size_t *global_size); //////////////////////////// @@ -139,16 +118,13 @@ TINYTC_EXPORT void tinytc_cl_get_global_size(int64_t howmany, const size_t *loca * @param context [in] context handle * @param device [in] device handle * @param recipe [in] recipe object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_cl_recipe_handler_create(tinytc_recipe_handler_t *handler, cl_context context, cl_device_id device, - tinytc_recipe_t recipe, - tinytc_source_context_t source_ctx); + tinytc_recipe_t recipe); /** * @brief Submit recipe to device diff --git a/include/tinytc/tinytc_cl.hpp b/include/tinytc/tinytc_cl.hpp index 93227d29..760781c4 100644 --- a/include/tinytc/tinytc_cl.hpp +++ b/include/tinytc/tinytc_cl.hpp @@ -20,7 +20,7 @@ namespace tinytc { //! 
Throw exception for unsuccessful call to C-API and convert result code to tinytc status inline void CL_CHECK_STATUS(cl_int stat) { if (stat != CL_SUCCESS) { - throw status{std::underlying_type_t(::tinytc_cl_convert_status(stat))}; + throw status::compute_runtime_error; } } @@ -48,10 +48,10 @@ inline auto get_support_level(cl_device_id device) -> support_level { * * @return core info */ -inline auto make_core_info(cl_device_id device) -> core_info { +inline auto create_core_info(cl_device_id device) -> shared_handle { tinytc_core_info_t info; CHECK_STATUS(::tinytc_cl_core_info_create(&info, device)); - return core_info{info}; + return shared_handle{info}; } //////////////////////////// @@ -59,42 +59,28 @@ inline auto make_core_info(cl_device_id device) -> core_info { //////////////////////////// namespace internal { +inline auto convert_status(cl_int stat) { + return stat != CL_SUCCESS ? tinytc_status_compute_runtime_error : tinytc_status_success; +} + template <> struct shared_handle_traits { static auto retain(cl_program handle) -> tinytc_status_t { - return ::tinytc_cl_convert_status(clRetainProgram(handle)); + return convert_status(clRetainProgram(handle)); } static auto release(cl_program handle) -> tinytc_status_t { - return ::tinytc_cl_convert_status(clReleaseProgram(handle)); + return convert_status(clReleaseProgram(handle)); } }; template <> struct shared_handle_traits { static auto retain(cl_kernel handle) -> tinytc_status_t { - return ::tinytc_cl_convert_status(clRetainKernel(handle)); + return convert_status(clRetainKernel(handle)); } static auto release(cl_kernel handle) -> tinytc_status_t { - return ::tinytc_cl_convert_status(clReleaseKernel(handle)); + return convert_status(clReleaseKernel(handle)); } }; } // namespace internal -/** - * @brief Make an OpenCL program from a tinytc source - * - * @param context Context - * @param device Device - * @param src Source - * @param source_ctx Source context for improved error reporting - * - * @return 
cl_program (shared handle) - */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, source const &src, - source_context source_ctx = {}) -> shared_handle { - cl_program obj; - CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_source(&obj, context, device, src.get(), - source_ctx.get())); - return shared_handle{obj}; -} - /** * @brief Make an OpenCL program from a tinytc program * @@ -103,16 +89,15 @@ inline auto make_kernel_bundle(cl_context context, cl_device_id device, source c * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return cl_program (shared handle) */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) -> shared_handle { +inline auto create_kernel_bundle(cl_context context, cl_device_id device, tinytc_prog_t prg, + tinytc_core_feature_flags_t core_features = 0) + -> shared_handle { cl_program obj; - CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_program(&obj, context, device, prg.get(), - core_features, source_ctx.get())); + CHECK_STATUS( + tinytc_cl_kernel_bundle_create_with_program(&obj, context, device, prg, core_features)); return shared_handle{obj}; } @@ -122,15 +107,13 @@ inline auto make_kernel_bundle(cl_context context, cl_device_id device, prog prg * @param context Context * @param device Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return cl_program (shared handle) */ -inline auto make_kernel_bundle(cl_context context, cl_device_id device, binary const &bin, - source_context source_ctx = {}) -> shared_handle { +inline auto create_kernel_bundle(cl_context context, cl_device_id device, const_tinytc_binary_t bin) + -> shared_handle { cl_program obj; - 
CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_binary(&obj, context, device, bin.get(), - source_ctx.get())); + CHECK_STATUS(tinytc_cl_kernel_bundle_create_with_binary(&obj, context, device, bin)); return shared_handle{obj}; } @@ -142,7 +125,7 @@ inline auto make_kernel_bundle(cl_context context, cl_device_id device, binary c * * @return cl_kernel (shared handle) */ -inline auto make_kernel(cl_program mod, char const *name) -> shared_handle { +inline auto create_kernel(cl_program mod, char const *name) -> shared_handle { cl_int err; cl_kernel obj = clCreateKernel(mod, name, &err); CL_CHECK_STATUS(err); @@ -165,15 +148,16 @@ inline auto get_group_size(cl_kernel kernel) -> std::array { /** * @brief Convert group size to opencl global range * - * @param howmany Group size + * @param num_groups Number of groups * @param local_size Work-group size * * @return Global size */ -inline auto get_global_size(std::int64_t howmany, std::array const &local_size) +inline auto get_global_size(std::array const &num_groups, + std::array const &local_size) -> std::array { auto global_size = std::array{}; - tinytc_cl_get_global_size(howmany, local_size.data(), global_size.data()); + tinytc_cl_get_global_size(num_groups.data(), local_size.data(), global_size.data()); return global_size; } @@ -184,10 +168,10 @@ inline auto get_global_size(std::int64_t howmany, std::array co namespace internal { template <> struct shared_handle_traits { static auto retain(cl_event handle) -> tinytc_status_t { - return ::tinytc_cl_convert_status(clRetainEvent(handle)); + return convert_status(clRetainEvent(handle)); } static auto release(cl_event handle) -> tinytc_status_t { - return ::tinytc_cl_convert_status(clReleaseEvent(handle)); + return convert_status(clReleaseEvent(handle)); } }; } // namespace internal @@ -200,41 +184,36 @@ template <> struct auto_mem_type { }; /** - * @brief Recipe handler for the OpenCL runtime + * @brief Submit recipe to queue + * + * @param handler Recipe handler + * @param 
queue Command queue + * @param num_wait_events Number of events to wait + * @param wait_events Array of num_wait_events events to wait on + * + * @return Event (cl_event wrapped in shared_handle -> cleans up automatically) */ -class opencl_recipe_handler : public recipe_handler { - public: - using recipe_handler::recipe_handler; - - /** - * @brief Submit recipe to queue - * - * @param queue Command queue - * @param num_wait_events Number of events to wait - * @param wait_events Array of num_wait_events events to wait on - * - * @return Event (cl_event wrapped in shared_handle -> cleans up automatically) - */ - inline auto submit(cl_command_queue queue, uint32_t num_wait_events = 0, - cl_event *wait_events = nullptr) -> shared_handle { - cl_event evt; - CHECK_STATUS( - tinytc_cl_recipe_handler_submit(obj_, queue, num_wait_events, wait_events, &evt)); - return shared_handle{evt}; - } - /** - * @brief Submit recipe to queue; does not return event - * - * @param queue Command queue - * @param num_wait_events Number of events to wait - * @param wait_events Array of num_wait_events events to wait on - */ - inline void submit_no_event(cl_command_queue queue, uint32_t num_wait_events = 0, - cl_event *wait_events = nullptr) { - CHECK_STATUS( - tinytc_cl_recipe_handler_submit(obj_, queue, num_wait_events, wait_events, NULL)); - } -}; +inline auto submit(tinytc_recipe_handler_t handler, cl_command_queue queue, + uint32_t num_wait_events = 0, cl_event *wait_events = nullptr) + -> shared_handle { + cl_event evt; + CHECK_STATUS( + tinytc_cl_recipe_handler_submit(handler, queue, num_wait_events, wait_events, &evt)); + return shared_handle{evt}; +} +/** + * @brief Submit recipe to queue; does not return event + * + * @param handler Recipe handler + * @param queue Command queue + * @param num_wait_events Number of events to wait + * @param wait_events Array of num_wait_events events to wait on + */ +inline void submit_no_event(tinytc_recipe_handler_t handler, cl_command_queue 
queue, + uint32_t num_wait_events = 0, cl_event *wait_events = nullptr) { + CHECK_STATUS( + tinytc_cl_recipe_handler_submit(handler, queue, num_wait_events, wait_events, NULL)); +} /** * @brief Make recipe handler @@ -242,16 +221,14 @@ class opencl_recipe_handler : public recipe_handler { * @param context Context * @param device Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return OpenCL recipe handler */ -inline auto make_recipe_handler(cl_context context, cl_device_id device, recipe const &rec, - source_context source_ctx = {}) -> opencl_recipe_handler { +inline auto create_recipe_handler(cl_context context, cl_device_id device, tinytc_recipe_t rec) + -> shared_handle { tinytc_recipe_handler_t handler; - CHECK_STATUS( - tinytc_cl_recipe_handler_create(&handler, context, device, rec.get(), source_ctx.get())); - return opencl_recipe_handler{handler}; + CHECK_STATUS(tinytc_cl_recipe_handler_create(&handler, context, device, rec)); + return shared_handle{handler}; } } // namespace tinytc diff --git a/include/tinytc/tinytc_sycl.hpp b/include/tinytc/tinytc_sycl.hpp index 52bc2b80..7bcebfb5 100644 --- a/include/tinytc/tinytc_sycl.hpp +++ b/include/tinytc/tinytc_sycl.hpp @@ -35,26 +35,12 @@ TINYTC_EXPORT auto get_support_level(sycl::device const &dev) -> support_level; * * @return core info */ -TINYTC_EXPORT auto make_core_info(sycl::device const &dev) -> core_info; +TINYTC_EXPORT auto create_core_info(sycl::device const &dev) -> shared_handle; //////////////////////////// ////////// Kernel ////////// //////////////////////////// -/** - * @brief Make SYCL kernel bundle from tinytc source - * - * @param ctx Context - * @param dev Device - * @param src Source - * @param source_ctx Source context for improved error reporting - * - * @return SYCL kernel bundle - */ -TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, - source const &src, source_context source_ctx = {}) - -> 
sycl::kernel_bundle; - /** * @brief Make SYCL kernel bundle from tinytc program * @@ -63,13 +49,12 @@ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device con * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return SYCL kernel bundle */ -TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) +TINYTC_EXPORT auto create_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, + tinytc_prog_t prg, + tinytc_core_feature_flags_t core_features = 0) -> sycl::kernel_bundle; /** @@ -78,12 +63,11 @@ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device con * @param ctx Context * @param dev Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return SYCL kernel bundle */ -TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, - binary const &bin, source_context source_ctx = {}) +TINYTC_EXPORT auto create_kernel_bundle(sycl::context const &ctx, sycl::device const &dev, + const_tinytc_binary_t bin) -> sycl::kernel_bundle; /** @@ -94,8 +78,8 @@ TINYTC_EXPORT auto make_kernel_bundle(sycl::context const &ctx, sycl::device con * * @return SYCL kernel */ -TINYTC_EXPORT auto make_kernel(sycl::kernel_bundle const &bundle, - char const *name) -> sycl::kernel; +TINYTC_EXPORT auto create_kernel(sycl::kernel_bundle const &bundle, + char const *name) -> sycl::kernel; /** * @brief Get work-group size @@ -109,23 +93,29 @@ TINYTC_EXPORT auto get_group_size(sycl::kernel const &krnl) -> sycl::range<3u>; /** * @brief Convert group size to SYCL range * - * @param howmany Group size + * **Important:** num_groups is in SYCL ZYX order, meaning that the range should contain + * 
{num_groups_z, num_groups_y, num_groups_x}. + * + * @param num_groups Number of groups * @param local_size Work-group size * * @return Global size */ -TINYTC_EXPORT auto get_global_size(std::int64_t howmany, sycl::range<3u> const &local_size) - -> sycl::range<3u>; +TINYTC_EXPORT auto get_global_size(sycl::range<3u> const &num_groups, + sycl::range<3u> const &local_size) -> sycl::range<3u>; /** * @brief Get SYCL nd_range * + * **Important:** num_groups is in SYCL ZYX order, meaning that the range should contain + * {num_groups_z, num_groups_y, num_groups_x}. + * * @param krnl Kernel - * @param howmany Group size + * @param num_groups Number of groups * * @return ND range */ -TINYTC_EXPORT auto get_execution_range(sycl::kernel const &krnl, std::int64_t howmany) +TINYTC_EXPORT auto get_execution_range(sycl::kernel const &krnl, sycl::range<3u> const &num_groups) -> sycl::nd_range<3u>; //////////////////////////// @@ -133,45 +123,43 @@ TINYTC_EXPORT auto get_execution_range(sycl::kernel const &krnl, std::int64_t ho //////////////////////////// /** - * @brief Recipe handler for the SYCL runtime + * @brief Launch recipe with submit call + * + * @param handler recipe handler + * @param cgh Handler + */ +TINYTC_EXPORT void parallel_for(tinytc_recipe_handler_t handler, sycl::handler &cgh); +/** + * @brief Submit recipe to queue + * + * @param handler recipe handler + * @param q Queue + * + * @return Event + */ +TINYTC_EXPORT auto submit(tinytc_recipe_handler_t handler, sycl::queue q) -> sycl::event; +/** + * @brief Submit recipe to queue + * + * @param handler recipe handler + * @param q Queue + * @param dep_event Event to wait on + * + * @return Event + */ +TINYTC_EXPORT auto submit(tinytc_recipe_handler_t handler, sycl::queue q, + sycl::event const &dep_event) -> sycl::event; +/** + * @brief Submit recipe to queue + * + * @param handler recipe handler + * @param q Queue + * @param dep_events Events to wait on + * + * @return Event */ -class TINYTC_EXPORT 
sycl_recipe_handler : public recipe_handler { - public: - using recipe_handler::recipe_handler; - - /** - * @brief Launch recipe with submit call - * - * @param h Handler - */ - void parallel_for(sycl::handler &h); - /** - * @brief Submit recipe to queue - * - * @param q Queue - * - * @return Event - */ - auto submit(sycl::queue q) -> sycl::event; - /** - * @brief Submit recipe to queue - * - * @param q Queue - * @param dep_event Event to wait on - * - * @return Event - */ - auto submit(sycl::queue q, sycl::event const &dep_event) -> sycl::event; - /** - * @brief Submit recipe to queue - * - * @param q Queue - * @param dep_events Events to wait on - * - * @return Event - */ - auto submit(sycl::queue q, std::vector const &dep_events) -> sycl::event; -}; +TINYTC_EXPORT auto submit(tinytc_recipe_handler_t handler, sycl::queue q, + std::vector const &dep_events) -> sycl::event; /** * @brief Make recipe handler @@ -179,24 +167,22 @@ class TINYTC_EXPORT sycl_recipe_handler : public recipe_handler { * @param ctx Context * @param dev Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return SYCL recipe handler */ -TINYTC_EXPORT auto make_recipe_handler(sycl::context const &ctx, sycl::device const &dev, - recipe const &rec, source_context source_ctx = {}) - -> sycl_recipe_handler; +TINYTC_EXPORT auto create_recipe_handler(sycl::context const &ctx, sycl::device const &dev, + tinytc_recipe_t rec) + -> shared_handle; /** * @brief Make recipe handler * * @param q Queue * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return SYCL recipe handler */ -TINYTC_EXPORT auto make_recipe_handler(sycl::queue const &q, recipe const &rec, - source_context source_ctx = {}) -> sycl_recipe_handler; +TINYTC_EXPORT auto create_recipe_handler(sycl::queue const &q, tinytc_recipe_t rec) + -> shared_handle; } // namespace tinytc diff --git a/include/tinytc/tinytc_ze.h b/include/tinytc/tinytc_ze.h index 
d5b4d45a..b39d82ad 100644 --- a/include/tinytc/tinytc_ze.h +++ b/include/tinytc/tinytc_ze.h @@ -17,13 +17,11 @@ extern "C" { /////////// Error ////////// //////////////////////////// -TINYTC_EXPORT tinytc_status_t tinytc_ze_convert_status(ze_result_t result); - #define TINYTC_ZE_CHECK_STATUS(X) \ do { \ ze_result_t result = X; \ if (result != ZE_RESULT_SUCCESS) { \ - return tinytc_ze_convert_status(result); \ + return tinytc_status_compute_runtime_error; \ } \ } while (0) @@ -57,38 +55,6 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_core_info_create(tinytc_core_info_t *inf ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bin [out] pointer to the binary object created - * @param src [in] source text - * @param ip_version [in] IP version (pass tinytc_intel_gpu_architecture_t here) - * @param format [in] binary format (SPIR-V or native) - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_source_compile_to_binary( - tinytc_binary_t *bin, const_tinytc_source_t src, uint32_t ip_version, - tinytc_bundle_format_t format, tinytc_source_context_t source_ctx); - -/** - * @brief Compile OpenCL-C source to device binary - * - * @param bundle [out] pointer to the kernel bundle (ze_module_handle_t) object created - * @param context [in] context handle - * @param device [in] device handle - * @param src [in] source text and extensions - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr - * - * @return tinytc_status_success on success and error otherwise - */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_source( - ze_module_handle_t *bundle, ze_context_handle_t 
context, ze_device_handle_t device, - const_tinytc_source_t src, tinytc_source_context_t source_ctx); - /** * @brief Compile tensor program * @@ -98,15 +64,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_source( * @param prg [inout] tensor program; modified as compiler passes are run * @param core_features [in][optional] requested core features; must be 0 (default) or a combination * of tinytc_core_feature_flag_t - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_program( ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - tinytc_prog_t prg, tinytc_core_feature_flags_t core_features, - tinytc_source_context_t source_ctx); + tinytc_prog_t prg, tinytc_core_feature_flags_t core_features); /** * @brief Create an OpenCL program from a tinytc binary @@ -115,14 +78,12 @@ TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_program( * @param context [in] context handle * @param device [in] device handle * @param bin [in] binary object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ -TINYTC_EXPORT tinytc_status_t tinytc_ze_kernel_bundle_create_with_binary( - ze_module_handle_t *bundle, ze_context_handle_t context, ze_device_handle_t device, - const_tinytc_binary_t bin, tinytc_source_context_t source_ctx); +TINYTC_EXPORT tinytc_status_t +tinytc_ze_kernel_bundle_create_with_binary(ze_module_handle_t *bundle, ze_context_handle_t context, + ze_device_handle_t device, const_tinytc_binary_t bin); /** * @brief Create a kernel and set group size @@ -149,15 +110,6 @@ TINYTC_EXPORT tinytc_status_t 
tinytc_ze_kernel_create(ze_kernel_handle_t *krnl, TINYTC_EXPORT tinytc_status_t tinytc_ze_get_group_size(ze_kernel_handle_t kernel, uint32_t *x, uint32_t *y, uint32_t *z); -/** - * @brief Convert group size to level zero group count - * - * @param howmany group size - * - * @return group count - */ -TINYTC_EXPORT ze_group_count_t tinytc_ze_get_group_count(int64_t howmany); - //////////////////////////// ////////// Recipe ////////// //////////////////////////// @@ -169,16 +121,13 @@ TINYTC_EXPORT ze_group_count_t tinytc_ze_get_group_count(int64_t howmany); * @param context [in] context handle * @param device [in] device handle * @param recipe [in] recipe object - * @param source_ctx [inout][optional] source context object to save extended error messages that - * are enhanced with source code context; can be nullptr * * @return tinytc_status_success on success and error otherwise */ TINYTC_EXPORT tinytc_status_t tinytc_ze_recipe_handler_create(tinytc_recipe_handler_t *handler, ze_context_handle_t context, ze_device_handle_t device, - tinytc_recipe_t recipe, - tinytc_source_context_t source_ctx); + tinytc_recipe_t recipe); /** * @brief Submit recipe to device diff --git a/include/tinytc/tinytc_ze.hpp b/include/tinytc/tinytc_ze.hpp index 57a2eac2..5a78689a 100644 --- a/include/tinytc/tinytc_ze.hpp +++ b/include/tinytc/tinytc_ze.hpp @@ -20,7 +20,7 @@ namespace tinytc { //! 
Throw exception for unsuccessful call to C-API and convert result code to tinytc status inline void ZE_CHECK_STATUS(ze_result_t result) { if (result != ZE_RESULT_SUCCESS) { - throw status{std::underlying_type_t(::tinytc_ze_convert_status(result))}; + throw status::compute_runtime_error; } } @@ -48,34 +48,16 @@ inline auto get_support_level(ze_device_handle_t device) -> support_level { * * @return core info */ -inline auto make_core_info(ze_device_handle_t device) -> core_info { +inline auto create_core_info(ze_device_handle_t device) -> shared_handle { tinytc_core_info_t info; CHECK_STATUS(::tinytc_ze_core_info_create(&info, device)); - return core_info{info}; + return shared_handle{info}; } //////////////////////////// ////////// Kernel ////////// //////////////////////////// -/** - * @brief Compile source to binary - * - * @param src Source object - * @param ip_version IP version (pass tinytc_intel_gpu_architecture_t here) - * @param format Bundle format (SPIR-V or Native) - * @param ctx Source context for improved error reporting - * - * @return Binary - */ -inline auto compile_to_binary(source const &src, std::uint32_t ip_version, bundle_format format, - source_context ctx = {}) -> binary { - tinytc_binary_t bin; - CHECK_STATUS(tinytc_ze_source_compile_to_binary( - &bin, src.get(), ip_version, static_cast(format), ctx.get())); - return binary{bin}; -} - namespace internal { template <> struct unique_handle_traits { static void destroy(ze_kernel_handle_t obj) { zeKernelDestroy(obj); } @@ -85,25 +67,6 @@ template <> struct unique_handle_traits { }; } // namespace internal -/** - * @brief Make a Level Zero module from a tinytc source - * - * @param context Context - * @param device Device - * @param src Source - * @param source_ctx Source context for improved error reporting - * - * @return Level Zero module (unique handle) - */ -inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, - source const &src, source_context source_ctx = 
{}) - -> unique_handle { - ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_source(&obj, context, device, src.get(), - source_ctx.get())); - return unique_handle{obj}; -} - /** * @brief Make a Level Zero module from a tinytc program * @@ -112,17 +75,15 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * @param prg Program * @param core_features requested core features; must be 0 (default) or a combination of * tinytc_core_feature_flag_t - * @param source_ctx Source context for improved error reporting * * @return Level Zero module (unique handle) */ -inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, prog prg, - tinytc_core_feature_flags_t core_features = 0, - source_context source_ctx = {}) +inline auto create_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, + tinytc_prog_t prg, tinytc_core_feature_flags_t core_features = 0) -> unique_handle { ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_program(&obj, context, device, prg.get(), - core_features, source_ctx.get())); + CHECK_STATUS( + tinytc_ze_kernel_bundle_create_with_program(&obj, context, device, prg, core_features)); return unique_handle{obj}; } @@ -132,16 +93,13 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * @param context Context * @param device Device * @param bin Binary - * @param source_ctx Source context for improved error reporting * * @return Level Zero module (unique handle) */ -inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, - binary const &bin, source_context source_ctx = {}) - -> unique_handle { +inline auto create_kernel_bundle(ze_context_handle_t context, ze_device_handle_t device, + const_tinytc_binary_t bin) -> unique_handle { ze_module_handle_t obj; - CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_binary(&obj, context, device, bin.get(), - source_ctx.get())); + 
CHECK_STATUS(tinytc_ze_kernel_bundle_create_with_binary(&obj, context, device, bin)); return unique_handle{obj}; } @@ -153,7 +111,7 @@ inline auto make_kernel_bundle(ze_context_handle_t context, ze_device_handle_t d * * @return Level Zero kernel (unique handle) */ -inline auto make_kernel(ze_module_handle_t mod, char const *name) +inline auto create_kernel(ze_module_handle_t mod, char const *name) -> unique_handle { ze_kernel_handle_t obj; CHECK_STATUS(tinytc_ze_kernel_create(&obj, mod, name)); @@ -173,44 +131,27 @@ inline auto get_group_size(ze_kernel_handle_t kernel) -> std::array ze_group_count_t { - return tinytc_ze_get_group_count(howmany); -} - //////////////////////////// ////////// Recipe ////////// //////////////////////////// /** - * @brief Recipe handler for the Level Zero runtime + * @brief Append recipe to command list + * + * Cf. @ref tinytc_ze_recipe_handler_submit + * + * @param handler Recipe handler + * @param list Command list + * @param signal_event Event to be signalled on completion + * @param num_wait_events Number of wait events to wait on + * @param wait_events Array of num_wait_events events to wait on */ -class level_zero_recipe_handler : public recipe_handler { - public: - using recipe_handler::recipe_handler; - - /** - * @brief Append recipe to command list - * - * Cf.
@ref tinytc_ze_recipe_handler_submit - * - * @param list Command list - * @param signal_event Event to be signalled on completetion - * @param num_wait_events Number of wait events to wait on - * @param wait_events Array of num_wait_events events to wait on - */ - inline void submit(ze_command_list_handle_t list, ze_event_handle_t signal_event = nullptr, - uint32_t num_wait_events = 0, ze_event_handle_t *wait_events = nullptr) { - CHECK_STATUS(tinytc_ze_recipe_handler_submit(obj_, list, signal_event, num_wait_events, - wait_events)); - } -}; +inline void submit(tinytc_recipe_handler_t handler, ze_command_list_handle_t list, + ze_event_handle_t signal_event = nullptr, uint32_t num_wait_events = 0, + ze_event_handle_t *wait_events = nullptr) { + CHECK_STATUS( + tinytc_ze_recipe_handler_submit(handler, list, signal_event, num_wait_events, wait_events)); +} /** * @brief Make recipe handler @@ -218,17 +159,14 @@ class level_zero_recipe_handler : public recipe_handler { * @param context Context * @param device Device * @param rec Recipe - * @param source_ctx Source context for improved error reporting * * @return Level Zero recipe handler */ -inline auto make_recipe_handler(ze_context_handle_t context, ze_device_handle_t device, - recipe const &rec, source_context source_ctx = {}) - -> level_zero_recipe_handler { +inline auto create_recipe_handler(ze_context_handle_t context, ze_device_handle_t device, + tinytc_recipe_t rec) -> shared_handle { tinytc_recipe_handler_t handler; - CHECK_STATUS( - tinytc_ze_recipe_handler_create(&handler, context, device, rec.get(), source_ctx.get())); - return level_zero_recipe_handler{handler}; + CHECK_STATUS(tinytc_ze_recipe_handler_create(&handler, context, device, rec)); + return shared_handle{handler}; } } // namespace tinytc diff --git a/include/tinytc/types.anko b/include/tinytc/types.anko new file mode 100644 index 00000000..fb796197 --- /dev/null +++ b/include/tinytc/types.anko @@ -0,0 +1,58 @@ +; Copyright (C) 2025 Intel 
Corporation +; SPDX-License-Identifier: BSD-3-Clause + +include "tinytc/enums.anko" + +type @boolean "Bool type" {} + +type @number "Number type" {} + +type @integer : @number "Integer type" {} +type @i8 : @integer "Signed 8 bit integer type" {} +type @i16 : @integer "Signed 16 bit integer type" {} +type @i32 : @integer "Signed 32 bit integer type" {} +type @i64 : @integer "Signed 64 bit integer type" {} +type @index : @integer "Integer type for indices" {} + +type @float : @number "Floating point type" {} +type @bf16 : @float "Brain floating point (16 bit)" {} +type @f16 : @float "Half precision floating point (16 bit)" {} +type @f32 : @float "Single precision floating point (32 bit)" {} +type @f64 : @float "Double precision floating point (64 bit)" {} + +type @complex : @number "Complex number type" {} +type @c32 : @complex "Single precision complex (2x32 bit)" {} +type @c64 : @complex "Double precision complex (2x64 bit)" {} + +type @coopmatrix "Coopmatrix type" { + prop %component_ty => type_t "component type" + prop %rows => i64 "number of rows" + prop %cols => i64 "number of columns" + prop %use => @matrix_use "matrix use" + cxx "inline auto shape(int mode) const -> std::int64_t { return mode == 1 ?
cols() : rows(); }" +} + +type @group "Group type" { + prop %element_ty => type_t "element type" + prop %size => i64 "group size" + prop %offset => i64 "offset added on element access" +} + +type @memref "Memref type" { + prop %element_ty => type_t "element type" + prop* %shape => i64 "tensor shape" + prop* %stride => i64 "tensor stride" + prop %addrspace => @address_space "address space" + cxx "static auto canonical_stride(array_view shape) -> std::vector;" + cxx "inline auto dim() const -> std::int64_t { return shape_.size(); }" + cxx "inline auto shape(std::int64_t i) const -> std::int64_t { return shape_[i]; }" + cxx "inline auto stride(std::int64_t i) const -> std::int64_t { return stride_[i]; }" + cxx "auto is_dynamic_shape() const -> bool;" + cxx "auto is_dynamic_stride() const -> bool;" + cxx "auto is_dynamic() const -> bool;" + cxx "auto is_canonical_stride() const -> bool;" + cxx "auto element_alignment() const -> std::int32_t;" + cxx "auto size_in_bytes() const -> std::int64_t;" +} + +type @void "Void type" {} diff --git a/include/tinytc/types.h b/include/tinytc/types.h deleted file mode 100644 index a404daa1..00000000 --- a/include/tinytc/types.h +++ /dev/null @@ -1,435 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef TYPES_20240410_H -#define TYPES_20240410_H - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -//////////////////////////// -///////// Constants //////// -//////////////////////////// - -#define TINYTC_DYNAMIC INT64_MIN - -//////////////////////////// -/////// Enumerations /////// -//////////////////////////// - -/** - * @brief Status codes - */ -typedef enum { - tinytc_status_success = 0x0, ///< Success - tinytc_status_bad_alloc = 0x1, ///< Failure to allocate storage - tinytc_status_invalid_arguments = 0x2, ///< Function got invalid arguments - tinytc_status_out_of_range = 0x3, ///< Element access out of bounds - tinytc_status_runtime_error = 0x4, ///< Runtime error - 
tinytc_status_internal_compiler_error = 0x5, ///< Internal compiler error - tinytc_status_unsupported_subgroup_size = 0x6, ///< Device does not support subgroup size - tinytc_status_unsupported_work_group_size = 0x7, ///< Device does not support work-group size - tinytc_status_compilation_error = 0x8, ///< Compilation error - tinytc_status_file_io_error = 0x9, ///< Error during File I/O - tinytc_status_parse_error = 0xa, ///< Error during parsing - tinytc_status_unavailable_extension = 0xb, ///< Unavailable runtime extension - tinytc_status_unsupported_backend = 0xc, ///< Unsupported backend (SYCL runtime) - tinytc_status_invalid_kernel_arguments = 0xd, ///< Kernel got invalid arguments - tinytc_status_unsupported_device = 0xe, ///< Unsupported device - // IR errors - tinytc_status_ir_out_of_bounds = 0x100, ///< Out of bounds access - tinytc_status_ir_invalid_shape = 0x101, ///< Invalid tensor shape - tinytc_status_ir_incompatible_shapes = 0x102, ///< Tensor shape requirements not satisfied - tinytc_status_ir_shape_stride_mismatch = 0x103, ///< Mismatch of shape and stride - tinytc_status_ir_scalar_mismatch = 0x104, ///< Mismatch of scalar types - tinytc_status_ir_invalid_number_of_indices = 0x105, /// Invalid number of indices - tinytc_status_ir_expected_scalar = 0x106, ///< Expected a value of scalar type - tinytc_status_ir_expected_memref = 0x107, ///< Expected a value of memref type - tinytc_status_ir_expected_memref_or_scalar = 0x108, ///< Expected memref or scalar type - tinytc_status_ir_expected_memref_or_group = 0x109, ///< Expected a value of memref or group type - tinytc_status_ir_expected_vector_or_matrix = 0x10a, ///< Expected a vector or marix - tinytc_status_ir_unexpected_yield = 0x10b, ///< Unexpected yield instruction - tinytc_status_ir_yield_mismatch = 0x10c, ///< Wrong number of yielded values - tinytc_status_ir_multiple_dynamic_modes = 0x10d, ///< At most one mode must be dynamic - tinytc_status_ir_invalid_slice = 0x10e, ///< Invalid slice - 
tinytc_status_ir_expand_shape_order_too_small = 0x10f, ///< Expand shape too small - tinytc_status_ir_expand_shape_mismatch = 0x110, ///< Invalid expand shape - tinytc_status_ir_collective_called_from_spmd = 0x111, ///< Collective instruction from SPMD - tinytc_status_ir_fp_unsupported = 0x112, ///< Instruction does not support floating type - // Level zero errors - tinytc_status_ze_result_not_ready = 0x10000, ///< ZE_RESULT_NOT_READY - tinytc_status_ze_result_error_device_lost = 0x10001, ///< ZE_RESULT_ERROR_DEVICE_LOST - tinytc_status_ze_result_error_out_of_host_memory = - 0x10002, ///< ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY - tinytc_status_ze_result_error_out_of_device_memory = - 0x10003, ///< ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY - tinytc_status_ze_result_error_module_build_failure = - 0x10004, ///< ZE_RESULT_ERROR_MODULE_BUILD_FAILURE - tinytc_status_ze_result_error_module_link_failure = - 0x10005, ///< ZE_RESULT_ERROR_MODULE_LINK_FAILURE - tinytc_status_ze_result_error_device_requires_reset = - 0x10006, ///< ZE_RESULT_ERROR_DEVICE_REQUIRES_RESET - tinytc_status_ze_result_error_device_in_low_power_state = - 0x10007, ///< ZE_RESULT_ERROR_DEVICE_IN_LOW_POWER_STATE - tinytc_status_ze_result_exp_error_device_is_not_vertex = - 0x10008, ///< ZE_RESULT_EXP_ERROR_DEVICE_IS_NOT_VERTEX - tinytc_status_ze_result_exp_error_vertex_is_not_device = - 0x10009, ///< ZE_RESULT_EXP_ERROR_VERTEX_IS_NOT_DEVICE - tinytc_status_ze_result_exp_error_remote_device = - 0x1000A, ///< ZE_RESULT_EXP_ERROR_REMOTE_DEVICE - tinytc_status_ze_result_exp_error_operands_incompatible = - 0x1000B, ///< ZE_RESULT_EXP_ERROR_OPERANDS_INCOMPATIBLE - tinytc_status_ze_result_exp_rtas_build_retry = 0x1000C, ///< ZE_RESULT_EXP_RTAS_BUILD_RETRY - tinytc_status_ze_result_exp_rtas_build_deferred = - 0x1000D, ///< ZE_RESULT_EXP_RTAS_BUILD_DEFERRED - tinytc_status_ze_result_error_insufficient_permissions = - 0x1000E, ///< ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS - tinytc_status_ze_result_error_not_available = 
0x1000F, ///< ZE_RESULT_ERROR_NOT_AVAILABLE - tinytc_status_ze_result_error_dependency_unavailable = - 0x10010, ///< ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE - tinytc_status_ze_result_warning_dropped_data = 0x10011, ///< ZE_RESULT_WARNING_DROPPED_DATA - tinytc_status_ze_result_error_uninitialized = 0x10012, ///< ZE_RESULT_ERROR_UNINITIALIZED - tinytc_status_ze_result_error_unsupported_version = - 0x10013, ///< ZE_RESULT_ERROR_UNSUPPORTED_VERSION - tinytc_status_ze_result_error_unsupported_feature = - 0x10014, ///< ZE_RESULT_ERROR_UNSUPPORTED_FEATURE - tinytc_status_ze_result_error_invalid_argument = 0x10015, ///< ZE_RESULT_ERROR_INVALID_ARGUMENT - tinytc_status_ze_result_error_invalid_null_handle = - 0x10016, ///< ZE_RESULT_ERROR_INVALID_NULL_HANDLE - tinytc_status_ze_result_error_handle_object_in_use = - 0x10017, ///< ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE - tinytc_status_ze_result_error_invalid_null_pointer = - 0x10018, ///< ZE_RESULT_ERROR_INVALID_NULL_POINTER - tinytc_status_ze_result_error_invalid_size = 0x10019, ///< ZE_RESULT_ERROR_INVALID_SIZE - tinytc_status_ze_result_error_unsupported_size = 0x1001A, ///< ZE_RESULT_ERROR_UNSUPPORTED_SIZE - tinytc_status_ze_result_error_unsupported_alignment = - 0x1001B, ///< ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT - tinytc_status_ze_result_error_invalid_synchronization_object = - 0x1001C, ///< ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT - tinytc_status_ze_result_error_invalid_enumeration = - 0x1001D, ///< ZE_RESULT_ERROR_INVALID_ENUMERATION - tinytc_status_ze_result_error_unsupported_enumeration = - 0x1001E, ///< ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION - tinytc_status_ze_result_error_unsupported_image_format = - 0x1001F, ///< ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT - tinytc_status_ze_result_error_invalid_native_binary = - 0x10020, ///< ZE_RESULT_ERROR_INVALID_NATIVE_BINARY - tinytc_status_ze_result_error_invalid_global_name = - 0x10021, ///< ZE_RESULT_ERROR_INVALID_GLOBAL_NAME - 
tinytc_status_ze_result_error_invalid_kernel_name = - 0x10022, ///< ZE_RESULT_ERROR_INVALID_KERNEL_NAME - tinytc_status_ze_result_error_invalid_function_name = - 0x10023, ///< ZE_RESULT_ERROR_INVALID_FUNCTION_NAME - tinytc_status_ze_result_error_invalid_group_size_dimension = - 0x10024, ///< ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION - tinytc_status_ze_result_error_invalid_global_width_dimension = - 0x10025, ///< ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION - tinytc_status_ze_result_error_invalid_kernel_argument_index = - 0x10026, ///< ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX - tinytc_status_ze_result_error_invalid_kernel_argument_size = - 0x10027, ///< ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE - tinytc_status_ze_result_error_invalid_kernel_attribute_value = - 0x10028, ///< ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE - tinytc_status_ze_result_error_invalid_module_unlinked = - 0x10029, ///< ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED - tinytc_status_ze_result_error_invalid_command_list_type = - 0x1002A, ///< ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE - tinytc_status_ze_result_error_overlapping_regions = - 0x1002B, ///< ZE_RESULT_ERROR_OVERLAPPING_REGIONS - tinytc_status_ze_result_warning_action_required = - 0x1002C, ///< ZE_RESULT_WARNING_ACTION_REQUIRED - tinytc_status_ze_result_error_unknown = 0x1002D, ///< ZE_RESULT_ERROR_UNKNOWN - // OpenCL errors - tinytc_status_cl_build_program_failure = 0x20000, ///< CL_BUILD_PROGRAM_FAILURE - tinytc_status_cl_compile_program_failure = 0x20001, ///< CL_COMPILE_PROGRAM_FAILURE - tinytc_status_cl_compiler_not_available = 0x20002, ///< CL_COMPILER_NOT_AVAILABLE - tinytc_status_cl_device_not_found = 0x20003, ///< CL_DEVICE_NOT_FOUND - tinytc_status_cl_device_not_available = 0x20004, ///< CL_DEVICE_NOT_AVAILABLE - tinytc_status_cl_device_partition_failed = 0x20005, ///< CL_DEVICE_PARTITION_FAILED - tinytc_status_cl_exec_status_error_for_events_in_wait_list = - 0x20006, ///< CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST - 
tinytc_status_cl_image_format_mismatch = 0x20007, ///< CL_IMAGE_FORMAT_MISMATCH - tinytc_status_cl_image_format_not_supported = 0x20008, ///< CL_IMAGE_FORMAT_NOT_SUPPORTED - tinytc_status_cl_invalid_arg_index = 0x20009, ///< CL_INVALID_ARG_INDEX - tinytc_status_cl_invalid_arg_size = 0x2000A, ///< CL_INVALID_ARG_SIZE - tinytc_status_cl_invalid_arg_value = 0x2000B, ///< CL_INVALID_ARG_VALUE - tinytc_status_cl_invalid_binary = 0x2000C, ///< CL_INVALID_BINARY - tinytc_status_cl_invalid_buffer_size = 0x2000D, ///< CL_INVALID_BUFFER_SIZE - tinytc_status_cl_invalid_build_options = 0x2000E, ///< CL_INVALID_BUILD_OPTIONS - tinytc_status_cl_invalid_command_queue = 0x2000F, ///< CL_INVALID_COMMAND_QUEUE - tinytc_status_cl_invalid_compiler_options = 0x20010, ///< CL_INVALID_COMPILER_OPTIONS - tinytc_status_cl_invalid_context = 0x20011, ///< CL_INVALID_CONTEXT - tinytc_status_cl_invalid_device = 0x20012, ///< CL_INVALID_DEVICE - tinytc_status_cl_invalid_device_partition_count = - 0x20013, ///< CL_INVALID_DEVICE_PARTITION_COUNT - tinytc_status_cl_invalid_device_queue = 0x20014, ///< CL_INVALID_DEVICE_QUEUE - tinytc_status_cl_invalid_device_type = 0x20015, ///< CL_INVALID_DEVICE_TYPE - tinytc_status_cl_invalid_event = 0x20016, ///< CL_INVALID_EVENT - tinytc_status_cl_invalid_event_wait_list = 0x20017, ///< CL_INVALID_EVENT_WAIT_LIST - tinytc_status_cl_invalid_global_offset = 0x20018, ///< CL_INVALID_GLOBAL_OFFSET - tinytc_status_cl_invalid_global_work_size = 0x20019, ///< CL_INVALID_GLOBAL_WORK_SIZE - tinytc_status_cl_invalid_host_ptr = 0x2001A, ///< CL_INVALID_HOST_PTR - tinytc_status_cl_invalid_image_descriptor = 0x2001B, ///< CL_INVALID_IMAGE_DESCRIPTOR - tinytc_status_cl_invalid_image_format_descriptor = - 0x2001C, ///< CL_INVALID_IMAGE_FORMAT_DESCRIPTOR - tinytc_status_cl_invalid_image_size = 0x2001D, ///< CL_INVALID_IMAGE_SIZE - tinytc_status_cl_invalid_kernel = 0x2001E, ///< CL_INVALID_KERNEL - tinytc_status_cl_invalid_kernel_args = 0x2001F, ///< CL_INVALID_KERNEL_ARGS - 
tinytc_status_cl_invalid_kernel_definition = 0x20020, ///< CL_INVALID_KERNEL_DEFINITION - tinytc_status_cl_invalid_kernel_name = 0x20021, ///< CL_INVALID_KERNEL_NAME - tinytc_status_cl_invalid_linker_options = 0x20022, ///< CL_INVALID_LINKER_OPTIONS - tinytc_status_cl_invalid_mem_object = 0x20023, ///< CL_INVALID_MEM_OBJECT - tinytc_status_cl_invalid_operation = 0x20024, ///< CL_INVALID_OPERATION - tinytc_status_cl_invalid_pipe_size = 0x20025, ///< CL_INVALID_PIPE_SIZE - tinytc_status_cl_invalid_platform = 0x20026, ///< CL_INVALID_PLATFORM - tinytc_status_cl_invalid_program = 0x20027, ///< CL_INVALID_PROGRAM - tinytc_status_cl_invalid_program_executable = 0x20028, ///< CL_INVALID_PROGRAM_EXECUTABLE - tinytc_status_cl_invalid_property = 0x20029, ///< CL_INVALID_PROPERTY - tinytc_status_cl_invalid_queue_properties = 0x2002A, ///< CL_INVALID_QUEUE_PROPERTIES - tinytc_status_cl_invalid_sampler = 0x2002B, ///< CL_INVALID_SAMPLER - tinytc_status_cl_invalid_spec_id = 0x2002C, ///< CL_INVALID_SPEC_ID - tinytc_status_cl_invalid_value = 0x2002D, ///< CL_INVALID_VALUE - tinytc_status_cl_invalid_work_dimension = 0x2002E, ///< CL_INVALID_WORK_DIMENSION - tinytc_status_cl_invalid_work_group_size = 0x2002F, ///< CL_INVALID_WORK_GROUP_SIZE - tinytc_status_cl_invalid_work_item_size = 0x20030, ///< CL_INVALID_WORK_ITEM_SIZE - tinytc_status_cl_kernel_arg_info_not_available = 0x20031, ///< CL_KERNEL_ARG_INFO_NOT_AVAILABLE - tinytc_status_cl_link_program_failure = 0x20032, ///< CL_LINK_PROGRAM_FAILURE - tinytc_status_cl_linker_not_available = 0x20033, ///< CL_LINKER_NOT_AVAILABLE - tinytc_status_cl_map_failure = 0x20034, ///< CL_MAP_FAILURE - tinytc_status_cl_mem_copy_overlap = 0x20035, ///< CL_MEM_COPY_OVERLAP - tinytc_status_cl_mem_object_allocation_failure = 0x20036, ///< CL_MEM_OBJECT_ALLOCATION_FAILURE - tinytc_status_cl_misaligned_sub_buffer_offset = 0x20037, ///< CL_MISALIGNED_SUB_BUFFER_OFFSET - tinytc_status_cl_out_of_host_memory = 0x20038, ///< CL_OUT_OF_HOST_MEMORY - 
tinytc_status_cl_out_of_resources = 0x20039, ///< CL_OUT_OF_RESOURCES - tinytc_status_cl_max_size_restriction_exceeded = 0x2003A, ///< CL_MAX_SIZE_RESTRICTION_EXCEEDED - tinytc_status_cl_profiling_info_not_available = 0x2003B, ///< CL_PROFILING_INFO_NOT_AVAILABLE - // The unknown error comes last - tinytc_status_unknown = 0x7fffffff ///< Unknown error occured -} tinytc_status_t; - -//! Scalar types -typedef enum { - tinytc_scalar_type_i1 = 0, ///< Signed 1 bit integer (boolean) - tinytc_scalar_type_i8 = 1, ///< Signed 8 bit integer - tinytc_scalar_type_i16 = 2, ///< Signed 16 bit integer - tinytc_scalar_type_i32 = 3, ///< Signed 32 bit integer - tinytc_scalar_type_i64 = 4, ///< Signed 64 bit integer - tinytc_scalar_type_index = 5, ///< Integer type for indices - tinytc_scalar_type_f32 = 6, ///< Single precision floating point (32 bit) - tinytc_scalar_type_f64 = 7 ///< Double precision floating point (64 bit) -} tinytc_scalar_type_t; - -//! Arithmetic operations -typedef enum { - tinytc_arithmetic_add = 0, ///< add - tinytc_arithmetic_sub = 1, ///< subtract - tinytc_arithmetic_mul = 2, ///< multiply - tinytc_arithmetic_div = 3, ///< divide - tinytc_arithmetic_rem = 4, ///< division remainder - tinytc_arithmetic_shl = 5, ///< left shift - tinytc_arithmetic_shr = 6, ///< arithmetic right shift - tinytc_arithmetic_and = 7, ///< bitwise and - tinytc_arithmetic_or = 8, ///< bitwise or - tinytc_arithmetic_xor = 9 ///< bitwise xor -} tinytc_arithmetic_t; - -//! Arithmetic operations (unary) -typedef enum { - tinytc_arithmetic_unary_neg = 0, ///< negation - tinytc_arithmetic_unary_not = 1 ///< bitwise not -} tinytc_arithmetic_unary_t; - -//! 
Compare operation -typedef enum { - tinytc_cmp_condition_eq = 0, ///< equals - tinytc_cmp_condition_ne = 1, ///< not equal - tinytc_cmp_condition_gt = 2, ///< greater than - tinytc_cmp_condition_ge = 3, ///< greather or equal than - tinytc_cmp_condition_lt = 4, ///< less than - tinytc_cmp_condition_le = 5 ///< less or equal than -} tinytc_cmp_condition_t; - -//! Transpose -typedef enum { - tinytc_transpose_N = 0, ///< No transpose - tinytc_transpose_T = 1 ///< Transpose -} tinytc_transpose_t; - -//! Core features that may be optionally enabled -typedef enum { - /** - * Request a large register file. - * On PVC this doubles the number of registers per vector engine - * but halves the number of available hardware threads. - * When this feature is activated, the kernel is compiled with - * the "-ze-opt-large-register-file" option. - */ - tinytc_core_feature_flag_large_register_file = 0x1 -} tinytc_core_feature_flag_t; - -//! Type for combination of core feature flags -typedef uint32_t tinytc_core_feature_flags_t; - -/** - * @brief IP versions for Intel GPUs - * - * Note: IP versions are extracted from - * * https://github.com/intel/compute-runtime/blob/4b5d5f235abf0ff67c9188f8096afd4da2e0574d/third_party/aot_config_headers/platforms.h - * * https://github.com/intel/llvm/blob/56e9067ba69809fb6ea1fd4328456ca3a009f984/sycl/source/detail/device_info.hpp#L619 - */ -typedef enum { - tinytc_intel_gpu_architecture_tgl = 0x03000000, ///< Tiger Lake - tinytc_intel_gpu_architecture_pvc = 0x030f0007 ///< Ponte Vecchio -} tinytc_intel_gpu_architecture_t; - -//! Target binary format -typedef enum { - tinytc_bundle_format_spirv = 0, ///< SPIR-V - tinytc_bundle_format_native = 1 ///< Native device binary -} tinytc_bundle_format_t; - -//! Memory object type -typedef enum { - tinytc_mem_type_buffer = 0x0, ///< Buffer object (e.g. 
cl_mem) - tinytc_mem_type_usm_pointer = 0x1, ///< Unified shared memory pointer - tinytc_mem_type_svm_pointer = 0x2, ///< Shared virtual memory pointer -} tinytc_mem_type_t; - -//! Support level of a device -typedef enum { - //! Device is unsupported (e.g. subgroups feature missing in OpenCL-C) - tinytc_support_level_none = 0x0, - //! Device provides necessary features but is not well tested - tinytc_support_level_basic = 0x1, - //! Device provides necessary features and is well tested - tinytc_support_level_tuned = 0x2 -} tinytc_support_level_t; - -//////////////////////////// -/////////// Types ////////// -//////////////////////////// - -//! @brief Bool type {0,1} -typedef uint8_t tinytc_bool_t; - -//! @struct tinytc_data_type -//! @brief Opaque struct for a data type -struct tinytc_data_type; -//! @brief data_type handle -typedef struct tinytc_data_type *tinytc_data_type_t; -//! @brief const data_type handle -typedef const struct tinytc_data_type *const_tinytc_data_type_t; - -//! @struct tinytc_value -//! @brief Opaque struct for a value -struct tinytc_value; -//! @brief value handle -typedef struct tinytc_value *tinytc_value_t; -//! @brief const value handle -typedef const struct tinytc_value *const_tinytc_value_t; - -//! @struct tinytc_inst -//! @brief Opaque struct for an instruction -struct tinytc_inst; -//! @brief inst handle -typedef struct tinytc_inst *tinytc_inst_t; -//! @brief const inst handle -typedef const struct tinytc_inst *const_tinytc_inst_t; - -//! @struct tinytc_region -//! @brief Opaque struct for a region -struct tinytc_region; -//! @brief region handle -typedef struct tinytc_region *tinytc_region_t; -//! @brief const region handle -typedef const struct tinytc_region *const_tinytc_region_t; - -//! @struct tinytc_func -//! @brief Opaque struct for a function -struct tinytc_func; -//! @brief func handle -typedef struct tinytc_func *tinytc_func_t; -//! @brief const func handle -typedef const struct tinytc_func *const_tinytc_func_t; - -//! 
@struct tinytc_prog -//! @brief Opaque struct for a program -struct tinytc_prog; -//! @brief prog handle -typedef struct tinytc_prog *tinytc_prog_t; -//! @brief const prog handle -typedef const struct tinytc_prog *const_tinytc_prog_t; - -//! @struct tinytc_core_info; -//! @brief Opaque struct for core information -struct tinytc_core_info; -//! @brief core_info handle -typedef struct tinytc_core_info *tinytc_core_info_t; -//! @brief const core_info handle -typedef const struct tinytc_core_info *const_tinytc_core_info_t; - -//! @struct tinytc_source; -//! @brief Opaque struct for source text -struct tinytc_source; -//! @brief source handle -typedef struct tinytc_source *tinytc_source_t; -//! @brief const source handle -typedef const struct tinytc_source *const_tinytc_source_t; - -//! @struct tintyc_source_context -//! @brief Opaque struct for source context -struct tinytc_source_context; -//! @brief source_context handle -typedef struct tinytc_source_context *tinytc_source_context_t; -//! @brief const source_context handle -typedef const struct tinytc_source_context *const_tinytc_source_context_t; - -//! @struct tinytc_binary; -//! @brief Opaque struct for a binary -struct tinytc_binary; -//! @brief binary handle -typedef struct tinytc_binary *tinytc_binary_t; -//! @brief const binary handle -typedef const struct tinytc_binary *const_tinytc_binary_t; - -//! @struct tinytc_recipe; -//! @brief Opaque struct for a recipe -struct tinytc_recipe; -//! @brief recipe handle -typedef struct tinytc_recipe *tinytc_recipe_t; -//! @brief const recipe handle -typedef const struct tinytc_recipe *const_tinytc_recipe_t; - -//! @struct tinytc_recipe_handler; -//! @brief Opaque struct for a recipe handler -struct tinytc_recipe_handler; -//! @brief recipe_handler handle -typedef struct tinytc_recipe_handler *tinytc_recipe_handler_t; -//! 
@brief const recipe_handler handle -typedef const struct tinytc_recipe_handler *const_tinytc_recipe_handler_t; - -//////////////////////////// -////////// Structs ///////// -//////////////////////////// - -//! @brief Source code position -typedef struct tinytc_position { - int32_t source_id; ///< Source file identifier; 0 is "unknown source" - int32_t line; ///< Line number; counting starts at 1 - int32_t column; ///< Column number; counting start at 1 -} tinytc_position_t; - -//! @brief Source code location -typedef struct tinytc_location { - tinytc_position_t begin; ///< Starting position - tinytc_position_t end; ///< End position -} tinytc_location_t; - -#ifdef __cplusplus -} -#endif - -#endif // TYPES_20240410_H diff --git a/include/tinytc/types.h.mochi b/include/tinytc/types.h.mochi new file mode 100644 index 00000000..502781d5 --- /dev/null +++ b/include/tinytc/types.h.mochi @@ -0,0 +1,335 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef TYPES_20240410_H +#define TYPES_20240410_H + +#include "tinytc/export.h" + +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +//////////////////////////// +///////// Constants //////// +//////////////////////////// + +#define TINYTC_DYNAMIC INT64_MIN + +//////////////////////////// +/////// Enumerations /////// +//////////////////////////// + +// もち enum_h "tinytc/enums.anko" + +//! Type for combination of address spaces +typedef uint32_t tinytc_address_spaces_t; + +//! Type for combination of core feature flags +typedef uint32_t tinytc_core_feature_flags_t; + +#define TINYTC_INTEL_GPU_ARCHITECTURE_SUB_VERSION_BITS 0xfff + +//////////////////////////// +/////////// Types ////////// +//////////////////////////// + +//! @brief Bool type {0,1} +typedef uint8_t tinytc_bool_t; + +//! @struct tinytc_attr +//! @brief Opaque struct for an attribute +struct tinytc_attr; // IWYU pragma: export +//! @brief attribute handle +typedef struct tinytc_attr *tinytc_attr_t; +//! 
@brief const attribute handle +typedef const struct tinytc_attr *const_tinytc_attr_t; + +//! @struct tinytc_type +//! @brief Opaque struct for a data type +struct tinytc_type; // IWYU pragma: export +//! @brief type handle +typedef struct tinytc_type *tinytc_type_t; +//! @brief const type handle +typedef const struct tinytc_type *const_tinytc_type_t; + +//! @struct tinytc_value +//! @brief Opaque struct for a value +struct tinytc_value; // IWYU pragma: export +//! @brief value handle +typedef struct tinytc_value *tinytc_value_t; +//! @brief const value handle +typedef const struct tinytc_value *const_tinytc_value_t; + +//! @struct tinytc_inst +//! @brief Opaque struct for an instruction +struct tinytc_inst; // IWYU pragma: export +//! @brief inst handle +typedef struct tinytc_inst *tinytc_inst_t; +//! @brief const inst handle +typedef const struct tinytc_inst *const_tinytc_inst_t; +/** + * @brief Delete inst object + * + * @param instr [inout] inst object + */ +TINYTC_EXPORT void tinytc_inst_destroy(tinytc_inst_t instr); + +//! @brief inst iterator handle +typedef struct tinytc_inst *tinytc_inst_iterator_t; + +//! @struct tinytc_region +//! @brief Opaque struct for a region +struct tinytc_region; // IWYU pragma: export +//! @brief region handle +typedef struct tinytc_region *tinytc_region_t; +//! @brief const region handle +typedef const struct tinytc_region *const_tinytc_region_t; + +//! @struct tinytc_func +//! @brief Opaque struct for a function +struct tinytc_func; // IWYU pragma: export +//! @brief func handle +typedef struct tinytc_func *tinytc_func_t; +//! @brief const func handle +typedef const struct tinytc_func *const_tinytc_func_t; +/** + * @brief Delete function object + * + * @param fun [inout] function object + */ +TINYTC_EXPORT void tinytc_func_destroy(tinytc_func_t fun); + +//! @struct tinytc_prog +//! @brief Opaque struct for a program +struct tinytc_prog; // IWYU pragma: export +//! 
@brief prog handle +typedef struct tinytc_prog *tinytc_prog_t; +//! @brief const prog handle +typedef const struct tinytc_prog *const_tinytc_prog_t; +/** + * @brief Release program object + * + * Decreases reference count by 1, free memory if reference count is 0. + * + * @param prg [inout] program object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_release(tinytc_prog_t prg); +/** + * @brief Increase reference count of program object by 1 + * + * @param prg [inout] program object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_prog_retain(tinytc_prog_t prg); + +//! @struct tinytc_spv_mod +//! @brief Opaque struct for a SPIR-V module +struct tinytc_spv_mod; // IWYU pragma: export +//! @brief spv_mod handle +typedef struct tinytc_spv_mod *tinytc_spv_mod_t; +//! @brief const spv_mod handle +typedef const struct tinytc_spv_mod *const_tinytc_spv_mod_t; +/** + * @brief Release SPIR-V module + * + * Decreases reference count by 1, free memory if reference count is 0. + * + * @param mod [inout] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_release(tinytc_spv_mod_t mod); +/** + * @brief Increase reference count of SPIR-V module by 1 + * + * @param mod [inout] SPIR-V module + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_spv_mod_retain(tinytc_spv_mod_t mod); + +//! @struct tinytc_core_info; +//! @brief Opaque struct for core information +struct tinytc_core_info; // IWYU pragma: export +//! @brief core_info handle +typedef struct tinytc_core_info *tinytc_core_info_t; +//! @brief const core_info handle +typedef const struct tinytc_core_info *const_tinytc_core_info_t; +/** + * @brief Release core info object + * + * Decreases reference count by 1, free memory if reference count is 0. 
+ * + * @param obj [inout] core info object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_release(tinytc_core_info_t obj); +/** + * @brief Increase reference count of core info object by 1 + * + * @param obj [inout] core info object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_core_info_retain(tinytc_core_info_t obj); + +//! @struct tinytc_compiler_context +//! @brief Opaque struct for compiler context +struct tinytc_compiler_context; // IWYU pragma: export +//! @brief compiler_context handle +typedef struct tinytc_compiler_context *tinytc_compiler_context_t; +//! @brief const compiler_context handle +typedef const struct tinytc_compiler_context *const_tinytc_compiler_context_t; +/** + * @brief Release context object + * + * Decreases reference count by 1, free memory if reference count is 0. + * + * @param obj [inout] context object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_release(tinytc_compiler_context_t obj); +/** + * @brief Increase reference count of context object by 1 + * + * @param obj [inout] context object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_compiler_context_retain(tinytc_compiler_context_t obj); + +//! @struct tinytc_binary; +//! @brief Opaque struct for a binary +struct tinytc_binary; // IWYU pragma: export +//! @brief binary handle +typedef struct tinytc_binary *tinytc_binary_t; +//! @brief const binary handle +typedef const struct tinytc_binary *const_tinytc_binary_t; +/** + * @brief Release binary object + * + * Decreases reference count by 1, free memory if reference count is 0. 
+ * + * @param bin [inout] binary object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_release(tinytc_binary_t bin); +/** + * @brief Increase reference count of binary object by 1 + * + * @param bin [inout] binary object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_binary_retain(tinytc_binary_t bin); + +//! @struct tinytc_recipe; +//! @brief Opaque struct for a recipe +struct tinytc_recipe; // IWYU pragma: export +//! @brief recipe handle +typedef struct tinytc_recipe *tinytc_recipe_t; +//! @brief const recipe handle +typedef const struct tinytc_recipe *const_tinytc_recipe_t; +/** + * @brief Release recipe object + * + * Decreases reference count by 1, free memory if reference count is 0. + * + * @param obj [inout] recipe object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_release(tinytc_recipe_t obj); +/** + * @brief Increase reference count of recipe object by 1 + * + * @param obj [inout] recipe object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_retain(tinytc_recipe_t obj); + +//! @struct tinytc_recipe_handler; +//! @brief Opaque struct for a recipe handler +struct tinytc_recipe_handler; // IWYU pragma: export +//! @brief recipe_handler handle +typedef struct tinytc_recipe_handler *tinytc_recipe_handler_t; +//! @brief const recipe_handler handle +typedef const struct tinytc_recipe_handler *const_tinytc_recipe_handler_t; +/** + * @brief Release recipe handler object + * + * Decreases reference count by 1, free memory if reference count is 0. 
+ * + * @param obj [inout] recipe handler object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_handler_release(tinytc_recipe_handler_t obj); +/** + * @brief Increase reference count of recipe handler object by 1 + * + * @param obj [inout] recipe handler object + * + * @return tinytc_status_success on success and error otherwise + */ +TINYTC_EXPORT tinytc_status_t tinytc_recipe_handler_retain(tinytc_recipe_handler_t obj); + +/** + * @brief Delete a (non-const) string returned from tinytc API + * + * @param str [in] string + */ +TINYTC_EXPORT void tinytc_string_destroy(char *str); + +//////////////////////////// +////////// Structs ///////// +//////////////////////////// + +//! @brief Named attribute +typedef struct tinytc_named_attr { + tinytc_attr_t name; ///< Name stored as string attribute + tinytc_attr_t attr; ///< Attribute +} tinytc_named_attr_t; + +//! @brief Source code position +typedef struct tinytc_position { + int32_t source_id; ///< Source file identifier; 0 is "unknown source" + int32_t line; ///< Line number; counting starts at 1 + int32_t column; ///< Column number; counting starts at 1 +} tinytc_position_t; + +//! 
@brief Source code location +typedef struct tinytc_location { + tinytc_position_t begin; ///< Starting position + tinytc_position_t end; ///< End position +} tinytc_location_t; + +//////////////////////////// +///////// Callbacks //////// +//////////////////////////// + +/** + * @brief Signature for error reporting callback + * + * @param what Error description + * @param location Source code location + * @param user_data user data that is passed on to callback + */ +typedef void (*tinytc_error_reporter_t)(char const *what, const tinytc_location_t *location, + void *user_data); + +#ifdef __cplusplus +} +#endif + +#endif // TYPES_20240410_H diff --git a/include/tinytc/types.hpp b/include/tinytc/types.hpp deleted file mode 100644 index 06514da5..00000000 --- a/include/tinytc/types.hpp +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef TYPES_20240410_HPP -#define TYPES_20240410_HPP - -#include "tinytc/types.h" - -#include - -namespace tinytc { - -//////////////////////////// -///////// Constants //////// -//////////////////////////// - -constexpr static std::int64_t dynamic = TINYTC_DYNAMIC; - -//////////////////////////// -/////// Enumerations /////// -//////////////////////////// - -/** - * @brief Cf. @ref tinytc_status_t - * - * A status is typically thrown as exception, hence one should wrap calls as following: - * - * @code{.cpp} - * try { - * ... - * } catch (tinytc::status const& st) { - * ... 
- * } - * @endcode - */ -enum class status { - success = tinytc_status_success, - bad_alloc = tinytc_status_bad_alloc, - invalid_arguments = tinytc_status_invalid_arguments, - out_of_range = tinytc_status_out_of_range, - runtime_error = tinytc_status_runtime_error, - internal_compiler_error = tinytc_status_internal_compiler_error, - unsupported_subgroup_size = tinytc_status_unsupported_subgroup_size, - unsupported_work_group_size = tinytc_status_unsupported_work_group_size, - compilation_error = tinytc_status_compilation_error, - file_io_error = tinytc_status_file_io_error, - parse_error = tinytc_status_parse_error, - unavailable_extension = tinytc_status_unavailable_extension, - unsupported_backend = tinytc_status_unsupported_backend, - invalid_kernel_arguments = tinytc_status_invalid_kernel_arguments, - unsupported_device = tinytc_status_unsupported_device, - // IR errors - ir_out_of_bounds = tinytc_status_ir_out_of_bounds, - ir_invalid_shape = tinytc_status_ir_invalid_shape, - ir_incompatible_shapes = tinytc_status_ir_incompatible_shapes, - ir_shape_stride_mismatch = tinytc_status_ir_shape_stride_mismatch, - ir_scalar_mismatch = tinytc_status_ir_scalar_mismatch, - ir_invalid_number_of_indices = tinytc_status_ir_invalid_number_of_indices, - ir_expected_scalar = tinytc_status_ir_expected_scalar, - ir_expected_memref = tinytc_status_ir_expected_memref, - ir_expected_memref_or_scalar = tinytc_status_ir_expected_memref_or_scalar, - ir_expected_memref_or_group = tinytc_status_ir_expected_memref_or_group, - ir_expected_vector_or_matrix = tinytc_status_ir_expected_vector_or_matrix, - ir_unexpected_yield = tinytc_status_ir_unexpected_yield, - ir_yield_mismatch = tinytc_status_ir_yield_mismatch, - ir_multiple_dynamic_modes = tinytc_status_ir_multiple_dynamic_modes, - ir_invalid_slice = tinytc_status_ir_invalid_slice, - ir_expand_shape_order_too_small = tinytc_status_ir_expand_shape_order_too_small, - ir_expand_shape_mismatch = tinytc_status_ir_expand_shape_mismatch, - 
ir_collective_called_from_spmd = tinytc_status_ir_collective_called_from_spmd, - ir_fp_unsupported = tinytc_status_ir_fp_unsupported, - // Level Zero errors - ze_result_not_ready = tinytc_status_ze_result_not_ready, - ze_result_error_device_lost = tinytc_status_ze_result_error_device_lost, - ze_result_error_out_of_host_memory = tinytc_status_ze_result_error_out_of_host_memory, - ze_result_error_out_of_device_memory = tinytc_status_ze_result_error_out_of_device_memory, - ze_result_error_module_build_failure = tinytc_status_ze_result_error_module_build_failure, - ze_result_error_module_link_failure = tinytc_status_ze_result_error_module_link_failure, - ze_result_error_device_requires_reset = tinytc_status_ze_result_error_device_requires_reset, - ze_result_error_device_in_low_power_state = - tinytc_status_ze_result_error_device_in_low_power_state, - ze_result_exp_error_device_is_not_vertex = - tinytc_status_ze_result_exp_error_device_is_not_vertex, - ze_result_exp_error_vertex_is_not_device = - tinytc_status_ze_result_exp_error_vertex_is_not_device, - ze_result_exp_error_remote_device = tinytc_status_ze_result_exp_error_remote_device, - ze_result_exp_error_operands_incompatible = - tinytc_status_ze_result_exp_error_operands_incompatible, - ze_result_exp_rtas_build_retry = tinytc_status_ze_result_exp_rtas_build_retry, - ze_result_exp_rtas_build_deferred = tinytc_status_ze_result_exp_rtas_build_deferred, - ze_result_error_insufficient_permissions = - tinytc_status_ze_result_error_insufficient_permissions, - ze_result_error_not_available = tinytc_status_ze_result_error_not_available, - ze_result_error_dependency_unavailable = tinytc_status_ze_result_error_dependency_unavailable, - ze_result_warning_dropped_data = tinytc_status_ze_result_warning_dropped_data, - ze_result_error_uninitialized = tinytc_status_ze_result_error_uninitialized, - ze_result_error_unsupported_version = tinytc_status_ze_result_error_unsupported_version, - ze_result_error_unsupported_feature = 
tinytc_status_ze_result_error_unsupported_feature, - ze_result_error_invalid_argument = tinytc_status_ze_result_error_invalid_argument, - ze_result_error_invalid_null_handle = tinytc_status_ze_result_error_invalid_null_handle, - ze_result_error_handle_object_in_use = tinytc_status_ze_result_error_handle_object_in_use, - ze_result_error_invalid_null_pointer = tinytc_status_ze_result_error_invalid_null_pointer, - ze_result_error_invalid_size = tinytc_status_ze_result_error_invalid_size, - ze_result_error_unsupported_size = tinytc_status_ze_result_error_unsupported_size, - ze_result_error_unsupported_alignment = tinytc_status_ze_result_error_unsupported_alignment, - ze_result_error_invalid_synchronization_object = - tinytc_status_ze_result_error_invalid_synchronization_object, - ze_result_error_invalid_enumeration = tinytc_status_ze_result_error_invalid_enumeration, - ze_result_error_unsupported_enumeration = tinytc_status_ze_result_error_unsupported_enumeration, - ze_result_error_unsupported_image_format = - tinytc_status_ze_result_error_unsupported_image_format, - ze_result_error_invalid_native_binary = tinytc_status_ze_result_error_invalid_native_binary, - ze_result_error_invalid_global_name = tinytc_status_ze_result_error_invalid_global_name, - ze_result_error_invalid_kernel_name = tinytc_status_ze_result_error_invalid_kernel_name, - ze_result_error_invalid_function_name = tinytc_status_ze_result_error_invalid_function_name, - ze_result_error_invalid_group_size_dimension = - tinytc_status_ze_result_error_invalid_group_size_dimension, - ze_result_error_invalid_global_width_dimension = - tinytc_status_ze_result_error_invalid_global_width_dimension, - ze_result_error_invalid_kernel_argument_index = - tinytc_status_ze_result_error_invalid_kernel_argument_index, - ze_result_error_invalid_kernel_argument_size = - tinytc_status_ze_result_error_invalid_kernel_argument_size, - ze_result_error_invalid_kernel_attribute_value = - 
tinytc_status_ze_result_error_invalid_kernel_attribute_value, - ze_result_error_invalid_module_unlinked = tinytc_status_ze_result_error_invalid_module_unlinked, - ze_result_error_invalid_command_list_type = - tinytc_status_ze_result_error_invalid_command_list_type, - ze_result_error_overlapping_regions = tinytc_status_ze_result_error_overlapping_regions, - ze_result_warning_action_required = tinytc_status_ze_result_warning_action_required, - ze_result_error_unknown = tinytc_status_ze_result_error_unknown, - // OpenCL errors - cl_build_program_failure = tinytc_status_cl_build_program_failure, - cl_compile_program_failure = tinytc_status_cl_compile_program_failure, - cl_compiler_not_available = tinytc_status_cl_compiler_not_available, - cl_device_not_found = tinytc_status_cl_device_not_found, - cl_device_not_available = tinytc_status_cl_device_not_available, - cl_device_partition_failed = tinytc_status_cl_device_partition_failed, - cl_exec_status_error_for_events_in_wait_list = - tinytc_status_cl_exec_status_error_for_events_in_wait_list, - cl_image_format_mismatch = tinytc_status_cl_image_format_mismatch, - cl_image_format_not_supported = tinytc_status_cl_image_format_not_supported, - cl_invalid_arg_index = tinytc_status_cl_invalid_arg_index, - cl_invalid_arg_size = tinytc_status_cl_invalid_arg_size, - cl_invalid_arg_value = tinytc_status_cl_invalid_arg_value, - cl_invalid_binary = tinytc_status_cl_invalid_binary, - cl_invalid_buffer_size = tinytc_status_cl_invalid_buffer_size, - cl_invalid_build_options = tinytc_status_cl_invalid_build_options, - cl_invalid_command_queue = tinytc_status_cl_invalid_command_queue, - cl_invalid_compiler_options = tinytc_status_cl_invalid_compiler_options, - cl_invalid_context = tinytc_status_cl_invalid_context, - cl_invalid_device = tinytc_status_cl_invalid_device, - cl_invalid_device_partition_count = tinytc_status_cl_invalid_device_partition_count, - cl_invalid_device_queue = tinytc_status_cl_invalid_device_queue, - 
cl_invalid_device_type = tinytc_status_cl_invalid_device_type, - cl_invalid_event = tinytc_status_cl_invalid_event, - cl_invalid_event_wait_list = tinytc_status_cl_invalid_event_wait_list, - cl_invalid_global_offset = tinytc_status_cl_invalid_global_offset, - cl_invalid_global_work_size = tinytc_status_cl_invalid_global_work_size, - cl_invalid_host_ptr = tinytc_status_cl_invalid_host_ptr, - cl_invalid_image_descriptor = tinytc_status_cl_invalid_image_descriptor, - cl_invalid_image_format_descriptor = tinytc_status_cl_invalid_image_format_descriptor, - cl_invalid_image_size = tinytc_status_cl_invalid_image_size, - cl_invalid_kernel = tinytc_status_cl_invalid_kernel, - cl_invalid_kernel_args = tinytc_status_cl_invalid_kernel_args, - cl_invalid_kernel_definition = tinytc_status_cl_invalid_kernel_definition, - cl_invalid_kernel_name = tinytc_status_cl_invalid_kernel_name, - cl_invalid_linker_options = tinytc_status_cl_invalid_linker_options, - cl_invalid_mem_object = tinytc_status_cl_invalid_mem_object, - cl_invalid_operation = tinytc_status_cl_invalid_operation, - cl_invalid_pipe_size = tinytc_status_cl_invalid_pipe_size, - cl_invalid_platform = tinytc_status_cl_invalid_platform, - cl_invalid_program = tinytc_status_cl_invalid_program, - cl_invalid_program_executable = tinytc_status_cl_invalid_program_executable, - cl_invalid_property = tinytc_status_cl_invalid_property, - cl_invalid_queue_properties = tinytc_status_cl_invalid_queue_properties, - cl_invalid_sampler = tinytc_status_cl_invalid_sampler, - cl_invalid_spec_id = tinytc_status_cl_invalid_spec_id, - cl_invalid_value = tinytc_status_cl_invalid_value, - cl_invalid_work_dimension = tinytc_status_cl_invalid_work_dimension, - cl_invalid_work_group_size = tinytc_status_cl_invalid_work_group_size, - cl_invalid_work_item_size = tinytc_status_cl_invalid_work_item_size, - cl_kernel_arg_info_not_available = tinytc_status_cl_kernel_arg_info_not_available, - cl_link_program_failure = tinytc_status_cl_link_program_failure, 
- cl_linker_not_available = tinytc_status_cl_linker_not_available, - cl_map_failure = tinytc_status_cl_map_failure, - cl_mem_copy_overlap = tinytc_status_cl_mem_copy_overlap, - cl_mem_object_allocation_failure = tinytc_status_cl_mem_object_allocation_failure, - cl_misaligned_sub_buffer_offset = tinytc_status_cl_misaligned_sub_buffer_offset, - cl_out_of_host_memory = tinytc_status_cl_out_of_host_memory, - cl_out_of_resources = tinytc_status_cl_out_of_resources, - cl_max_size_restriction_exceeded = tinytc_status_cl_max_size_restriction_exceeded, - cl_profiling_info_not_available = tinytc_status_cl_profiling_info_not_available, - // The unknown error comes last - unknown = tinytc_status_unknown -}; - -//! Scalar types -enum class scalar_type { - i1 = tinytc_scalar_type_i1, ///< Signed 1 bit integer (boolean) - i8 = tinytc_scalar_type_i8, ///< Signed 8 bit integer - i16 = tinytc_scalar_type_i16, ///< Signed 16 bit integer - i32 = tinytc_scalar_type_i32, ///< Signed 32 bit integer - i64 = tinytc_scalar_type_i64, ///< Signed 64 bit integer - index = tinytc_scalar_type_index, ///< Unsigned Integer type for indices - f32 = tinytc_scalar_type_f32, ///< Single precision floating point (32 bit) - f64 = tinytc_scalar_type_f64 ///< Double precision floating point (64 bit) -}; - -//! Arithmetic operations -enum class arithmetic { - add = tinytc_arithmetic_add, ///< add - sub = tinytc_arithmetic_sub, ///< subtract - mul = tinytc_arithmetic_mul, ///< multiply - div = tinytc_arithmetic_div, ///< divide - rem = tinytc_arithmetic_rem, ///< division remainder - shl = tinytc_arithmetic_shl, ///< left shift - shr = tinytc_arithmetic_shr, ///< arithmetic right shift - and_ = tinytc_arithmetic_and, ///< bitwise and - or_ = tinytc_arithmetic_or, ///< bitwise or - xor_ = tinytc_arithmetic_xor ///< bitwise xor -}; - -//! 
Arithmetic operations (unary) -enum class arithmetic_unary { - neg = tinytc_arithmetic_unary_neg, ///< negation - not_ = tinytc_arithmetic_unary_not ///< bitwise not -}; - -//! Compare operation -enum class cmp_condition { - eq = tinytc_cmp_condition_eq, ///< equals - ne = tinytc_cmp_condition_ne, ///< not equal - gt = tinytc_cmp_condition_gt, ///< greater than - ge = tinytc_cmp_condition_ge, ///< greather or equal than - lt = tinytc_cmp_condition_lt, ///< less than - le = tinytc_cmp_condition_le ///< less or equal than -}; -//! Transpose -enum class transpose { - N = tinytc_transpose_N, ///< no transpose - T = tinytc_transpose_T ///< transpose -}; - -//! @brief Cf. @ref tinytc_core_feature_flag_t -enum class core_feature_flag { large_register_file = tinytc_core_feature_flag_large_register_file }; - -//! @brief Cf. @ref tinytc_intel_gpu_architecture_t -enum class intel_gpu_architecture { - tgl = tinytc_intel_gpu_architecture_tgl, - pvc = tinytc_intel_gpu_architecture_pvc -}; - -//! Target binary format -enum class bundle_format { - spirv = tinytc_bundle_format_spirv, ///< SPIR-V - native = tinytc_bundle_format_native ///< Native device binary -}; - -//! Memory object type -enum class mem_type { - buffer = tinytc_mem_type_buffer, ///< Buffer object (e.g. cl_mem) - usm_pointer = tinytc_mem_type_usm_pointer, ///< Unified shared memory pointer - svm_pointer = tinytc_mem_type_svm_pointer, ///< Shared virtual memory pointer -}; - -//! Support level of a device -enum class support_level { - //! Device is unsupported (e.g. subgroups feature missing in OpenCL-C) - none = tinytc_support_level_none, - //! Device provides necessary features but is not well tested - basic = tinytc_support_level_basic, - //! Device provides necessary features and is well tested - tuned = tinytc_support_level_tuned -}; - -//////////////////////////// -/////// Type aliases /////// -//////////////////////////// - -//! 
@brief Alias for tinytc_position in namespace tinytc -using position = ::tinytc_position; -//! @brief Alias for tinytc_location in namespace tinytc -using location = ::tinytc_location; - -} // namespace tinytc - -#endif // TYPES_20240410_HPP diff --git a/include/tinytc/types.hpp.mochi b/include/tinytc/types.hpp.mochi new file mode 100644 index 00000000..28535e76 --- /dev/null +++ b/include/tinytc/types.hpp.mochi @@ -0,0 +1,348 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef TYPES_20240410_HPP +#define TYPES_20240410_HPP + +#include "tinytc/types.h" + +#include +#include + +namespace tinytc { + +//////////////////////////// +///////// Constants //////// +//////////////////////////// + +constexpr static std::int64_t dynamic = TINYTC_DYNAMIC; + +//! Check if mode i is dynamic ('?') +inline bool is_dynamic_value(std::int64_t i) { return i == dynamic; } + +//////////////////////////// +/////// Enumerations /////// +//////////////////////////// + +// もち enum_hpp "tinytc/enums.anko" + +//////////////////////////// +/////// Type aliases /////// +//////////////////////////// + +//! @brief Alias for tinytc_position in namespace tinytc +using position = ::tinytc_position; +//! @brief Alias for tinytc_location in namespace tinytc +using location = ::tinytc_location; + +//////////////////////////// +/////////// Error ////////// +//////////////////////////// + +//! Throw exception for unsuccessful call to C-API +inline void CHECK_STATUS(tinytc_status_t code) { + if (code != tinytc_status_success) { + throw status{std::underlying_type_t(code)}; + } +} + +//////////////////////////// +// Shared / unique handle // +//////////////////////////// + +namespace internal { +//! Wraps retain / release calls for type T +template struct shared_handle_traits {}; + +//! 
Wraps destroy calls for type T +template struct unique_handle_traits {}; +} // namespace internal + +/** + * @brief Wraps a C handle in a reference-counted object + * + * @tparam T C handle type (handle type = pointer to opaque struct) + */ +template class shared_handle { + public: + //! Traits shortcut + using traits = internal::shared_handle_traits; + //! Typedef for native C handle + using native_type = T; + + //! Create empty (invalid) handle + shared_handle() : obj_{nullptr} {} + //! Create handle from C handle + explicit shared_handle(T obj, bool needs_retain = false) : obj_(obj) { + if (needs_retain) { + CHECK_STATUS(c_retain()); + } + } + //! Decrease reference count + ~shared_handle() { c_release(); } + //! Copy ctor + shared_handle(shared_handle const &other) : obj_(other.obj_) { CHECK_STATUS(c_retain()); } + //! Move ctor + shared_handle(shared_handle &&other) noexcept : obj_(other.obj_) { other.obj_ = nullptr; } + //! Copy operator + shared_handle &operator=(shared_handle const &other) { + if (obj_ != other.obj_) { + CHECK_STATUS(c_release()); + obj_ = other.obj_; + CHECK_STATUS(c_retain()); + } + return *this; + } + //! Move operator + shared_handle &operator=(shared_handle &&other) { + if (obj_ != other.obj_) { + CHECK_STATUS(c_release()); + obj_ = other.obj_; + other.obj_ = nullptr; + } + return *this; + } + + //! Dereference C handle and get reference to underlying type + auto operator*() const -> std::remove_pointer_t & { return *obj_; } + //! Convert handle to C handle + auto operator->() const -> T { return obj_; } + //! Returns C handle + auto get() const -> T { return obj_; } + //! Returns C handle and releases the ownership of the managed object + auto release() -> T { + auto tmp = obj_; + obj_ = nullptr; + return tmp; + } + + //! Check whether handle is non-empty (valid) + explicit operator bool() const noexcept { return obj_ != nullptr; } + + //! 
Check equality + bool operator==(shared_handle const &other) const { return obj_ == other.obj_; } + //! Check inequality + bool operator!=(shared_handle const &other) const { return !(*this == other); } + + protected: + //! Call retain in C-API if C handle is not NULL + auto c_retain() -> tinytc_status_t { + if (obj_ != nullptr) { + return traits::retain(obj_); + } + return tinytc_status_success; + } + //! Call release in C-API if C handle is not NULL + auto c_release() -> tinytc_status_t { + if (obj_ != nullptr) { + return traits::release(obj_); + } + return tinytc_status_success; + } + //! The C handle + T obj_; +}; + +/** + * @brief Wraps a C handle in a unique_ptr-alike object + * + * @tparam T C handle type (handle type = pointer to opaque struct) + */ +template class unique_handle { + public: + //! Traits shortcut + using traits = internal::unique_handle_traits; + //! Typedef for native C handle + using native_type = T; + + //! Create empty (invalid) handle + unique_handle() : obj_{nullptr} {} + //! Create handle from C handle + explicit unique_handle(T obj) : obj_(obj) {} + //! Destroy object + ~unique_handle() { + if (obj_) { + traits::destroy(obj_); + } + } + //! Copy ctor + unique_handle(unique_handle const &other) = delete; + //! Move ctor + unique_handle(unique_handle &&other) noexcept : obj_(other.obj_) { other.obj_ = nullptr; } + //! Copy operator + unique_handle &operator=(unique_handle const &other) = delete; + //! Move operator + unique_handle &operator=(unique_handle &&other) { + obj_ = other.obj_; + other.obj_ = nullptr; + return *this; + } + + //! Dereference C handle and get reference to underlying type + auto operator*() const -> std::remove_pointer_t & { return *obj_; } + //! Convert handle to C handle + auto operator->() const -> T { return obj_; } + //! Returns C handle + auto get() const -> T { return obj_; } + //! 
Returns C handle and releases the ownership of the managed object + auto release() -> T { + auto tmp = obj_; + obj_ = nullptr; + return tmp; + } + + //! Check whether handle is non-empty (valid) + explicit operator bool() const noexcept { return obj_ != nullptr; } + + //! Check equality + bool operator==(unique_handle const &other) const { return obj_ == other.obj_; } + //! Check inequality + bool operator!=(unique_handle const &other) const { return !(*this == other); } + + protected: + //! The C handle + T obj_; +}; + +//////////////////////////// +/////////// Attr /////////// +//////////////////////////// + +class array_attr; // IWYU pragma: export +class boolean_attr; // IWYU pragma: export +class dictionary_attr; // IWYU pragma: export +class integer_attr; // IWYU pragma: export +class string_attr; // IWYU pragma: export + +//////////////////////////// +/////////// Type /////////// +//////////////////////////// + +// もち forward_hpp "tinytc/types.anko" + +//////////////////////////// +/////////// Inst /////////// +//////////////////////////// + +namespace internal { +template <> struct unique_handle_traits { + static void destroy(tinytc_inst_t handle) { return tinytc_inst_destroy(handle); } +}; +} // namespace internal + +// もち forward_hpp "tinytc/instructions.anko" + +//////////////////////////// +/////////// Func /////////// +//////////////////////////// + +namespace internal { +template <> struct unique_handle_traits { + static void destroy(tinytc_func_t handle) { return tinytc_func_destroy(handle); } +}; +} // namespace internal + +//////////////////////////// +/////////// Prog /////////// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_prog_t handle) -> tinytc_status_t { + return tinytc_prog_retain(handle); + } + static auto release(tinytc_prog_t handle) -> tinytc_status_t { + return tinytc_prog_release(handle); + } +}; +template <> struct unique_handle_traits { + static void 
destroy(char *obj) { tinytc_string_destroy(obj); } +}; +} // namespace internal + +//////////////////////////// +///// Compiler context ///// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_compiler_context_t handle) -> tinytc_status_t { + return tinytc_compiler_context_retain(handle); + } + static auto release(tinytc_compiler_context_t handle) -> tinytc_status_t { + return tinytc_compiler_context_release(handle); + } +}; +} // namespace internal + +//////////////////////////// +/////// SPIR-V Module ////// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_spv_mod_t handle) -> tinytc_status_t { + return tinytc_spv_mod_retain(handle); + } + static auto release(tinytc_spv_mod_t handle) -> tinytc_status_t { + return tinytc_spv_mod_release(handle); + } +}; +} // namespace internal + +//////////////////////////// +//////// Device info /////// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_core_info_t handle) -> tinytc_status_t { + return tinytc_core_info_retain(handle); + } + static auto release(tinytc_core_info_t handle) -> tinytc_status_t { + return tinytc_core_info_release(handle); + } +}; +} // namespace internal + +//////////////////////////// +///////// Compiler ///////// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_binary_t handle) -> tinytc_status_t { + return tinytc_binary_retain(handle); + } + static auto release(tinytc_binary_t handle) -> tinytc_status_t { + return tinytc_binary_release(handle); + } +}; +} // namespace internal + +//////////////////////////// +////////// Recipe ////////// +//////////////////////////// + +namespace internal { +template <> struct shared_handle_traits { + static auto retain(tinytc_recipe_t handle) -> tinytc_status_t { + 
return tinytc_recipe_retain(handle); + } + static auto release(tinytc_recipe_t handle) -> tinytc_status_t { + return tinytc_recipe_release(handle); + } +}; +template <> struct shared_handle_traits { + static auto retain(tinytc_recipe_handler_t handle) -> tinytc_status_t { + return tinytc_recipe_handler_retain(handle); + } + static auto release(tinytc_recipe_handler_t handle) -> tinytc_status_t { + return tinytc_recipe_handler_release(handle); + } +}; +} // namespace internal + +} // namespace tinytc + +#endif // TYPES_20240410_HPP diff --git a/include/tinytc/version.h.in b/include/tinytc/version.h.in index 5559caae..36cfcc36 100644 --- a/include/tinytc/version.h.in +++ b/include/tinytc/version.h.in @@ -5,12 +5,12 @@ #define VERSION_20240408_H // clang-format off -#define TINYTC_VERSION_MAJOR = @GIT_MAJOR_VERSION@; ///< Major version (X.x.x) -#define TINYTC_VERSION_MINOR = @GIT_MINOR_VERSION@; ///< Minor version (x.X.x) -#define TINYTC_VERSION_PATCH = @GIT_PATCH_VERSION@; ///< Patch version (x.x.X) -#define TINYTC_VERSION_HASH = "@GIT_COMMIT@"; ///< Git commit hash -#define TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE = @GIT_COMMITS_SINCE_RELEASE@; ///< Number of commits since last tag -#define TINYTC_VERSION_DESCRIPTION = "v@GIT_MAJOR_VERSION@.@GIT_MINOR_VERSION@.@GIT_PATCH_VERSION@-@GIT_COMMITS_SINCE_RELEASE@-@GIT_COMMIT@"; ///< Version string (vx.x.x-x-x) +#define TINYTC_VERSION_MAJOR @GIT_MAJOR_VERSION@ ///< Major version (X.x.x) +#define TINYTC_VERSION_MINOR @GIT_MINOR_VERSION@ ///< Minor version (x.X.x) +#define TINYTC_VERSION_PATCH @GIT_PATCH_VERSION@ ///< Patch version (x.x.X) +#define TINYTC_VERSION_HASH "@GIT_COMMIT@" ///< Git commit hash +#define TINYTC_VERSION_NUMBER_OF_COMMITS_SINCE_RELEASE @GIT_COMMITS_SINCE_RELEASE@ ///< Number of commits since last tag +#define TINYTC_VERSION_DESCRIPTION "v@GIT_MAJOR_VERSION@.@GIT_MINOR_VERSION@.@GIT_PATCH_VERSION@-@GIT_COMMITS_SINCE_RELEASE@-@GIT_COMMIT@" ///< Version string (vx.x.x-x-x) // clang-format on #endif 
// VERSION_20240408_H diff --git a/iwyu.imp b/iwyu.imp new file mode 100644 index 00000000..ed56ff03 --- /dev/null +++ b/iwyu.imp @@ -0,0 +1,4 @@ +[ + { "symbol": ["std::array", "private", "", "public"] }, + { "symbol": ["std::pair", "private", "", "public"] } +] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3d74f3db..11b4baf4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause include(CommonOptions) +include(GeneratedFiles) include(GenerateExportHeader) include(GitVersion) include(InstallLib) @@ -12,75 +13,119 @@ else () set(type static) endif () -find_package(clir 0.5.1 REQUIRED ${type}) -find_package(re2c REQUIRED) -find_package(BISON 3.8.2 REQUIRED) - set(SOURCES + analysis/aa_results.cpp + analysis/alias.cpp + analysis/cfg.cpp + analysis/gcd.cpp + analysis/stack.cpp binary.cpp codegen_tools.cpp compiler.cpp - data_type.cpp + compiler_context.cpp + compiler_context_cache.cpp + coopmatrix_layout.cpp device_info.cpp error.cpp - func.cpp - gemm_generator.cpp - inst.cpp + gemm_tools.cpp + half.cpp location.cpp - node/data_type_node.cpp - node/inst_node.cpp + matrix_ext_info.cpp + node/attr.cpp + node/func.cpp + node/inst.cpp + node/inst_view_impl.cpp + node/region.cpp + node/prog.cpp + node/value.cpp + number.cpp parser/parse_context.cpp parser.cpp - passes.cpp - precision_helper.cpp - prog.cpp + pass/check_ir.cpp + pass/clone.cpp + pass/constant_folding.cpp + pass/constant_propagation.cpp + pass/convert_to_spirv.cpp + pass/dead_code_elimination.cpp + pass/dump_cfg.cpp + pass/dump_def_use.cpp + pass/dump_gcd.cpp + pass/dump_ir.cpp + pass/insert_barrier.cpp + pass/insert_lifetime_stop.cpp + pass/lower_coopmatrix.cpp + pass/lower_foreach.cpp + pass/lower_linalg.cpp + pass/slot_tracker.cpp + pass/stack.cpp + pass/work_group_size.cpp recipe.cpp recipe/small_gemm_batched.cpp recipe/tall_and_skinny.cpp - region.cpp - required_extensions.cpp - scalar_type.cpp - source.cpp + spv/block2d_diy.cpp + 
spv/capex_util.cpp + spv/converter.cpp + spv/converter_aux.cpp + spv/coopmatrix_impl.cpp + spv/coopmatrix_impl_block.cpp + spv/coopmatrix_impl_dpas.cpp + spv/dope_vector.cpp + spv/inst_assembler.cpp + spv/matrix_walker.cpp + spv/module.cpp + spv/names.cpp + spv/opencl.std.cpp + spv/pass/assemble.cpp + spv/pass/assign_ids.cpp + spv/pass/dump_asm.cpp + spv/pass/capex.cpp + spv/uniquifier.cpp + support/temp_counter.cpp tiling.cpp - value.cpp - visitor/aa_results.cpp - visitor/alias_analysis.cpp - visitor/check_ir.cpp - visitor/dump_ir.cpp - visitor/equal.cpp - visitor/insert_barrier.cpp - visitor/lifetime_analysis.cpp - visitor/metadata.cpp - visitor/opencl_ast.cpp - visitor/slot_tracker.cpp - visitor/stack.cpp - visitor/work_group_size.cpp -) -set(RE2C_SOURCES - parser/lexer.re + support/walk.cpp ) -BISON_TARGET(parser parser/parser_impl.yy ${CMAKE_CURRENT_BINARY_DIR}/parser/parser_impl.cpp - DEFINES_FILE ${CMAKE_CURRENT_BINARY_DIR}/parser/parser_impl.hpp) + +add_library(tinytc-objects OBJECT ${SOURCES}) + +add_re2c_or_pregenerated_to_target(TARGET tinytc-objects SOURCES parser/lexer.re) +add_bison_or_pregenerated_to_target(TARGET tinytc-objects SOURCES parser/parser_impl.yy) +add_mochi_or_pregenerated_to_target(TARGET tinytc-objects + SOURCES + builder.cpp.mochi + enums.cpp.mochi + node/inst_kind.cpp.mochi + node/inst_view.cpp.mochi + node/inst_view.hpp.mochi + node/type.cpp.mochi + node/type.hpp.mochi + node/visit.hpp.mochi + ${PROJECT_SOURCE_DIR}/include/tinytc/builder.h.mochi + ${PROJECT_SOURCE_DIR}/include/tinytc/builder.hpp.mochi + ${PROJECT_SOURCE_DIR}/include/tinytc/types.h.mochi + ${PROJECT_SOURCE_DIR}/include/tinytc/types.hpp.mochi + DEPENDS + ${PROJECT_SOURCE_DIR}/include/tinytc/enums.anko + ${PROJECT_SOURCE_DIR}/include/tinytc/instructions.anko + ${PROJECT_SOURCE_DIR}/include/tinytc/types.anko + SEARCH_PATHS ${PROJECT_SOURCE_DIR}/include) + set(PUBLIC_HEADERS - tinytc.h - tinytc.hpp - types.h - types.hpp + ${PROJECT_BINARY_DIR}/include/tinytc/builder.h + 
${PROJECT_BINARY_DIR}/include/tinytc/builder.hpp + ${PROJECT_SOURCE_DIR}/include/tinytc/core.h + ${PROJECT_SOURCE_DIR}/include/tinytc/core.hpp + ${PROJECT_SOURCE_DIR}/include/tinytc/tinytc.h + ${PROJECT_SOURCE_DIR}/include/tinytc/tinytc.hpp + ${PROJECT_BINARY_DIR}/include/tinytc/types.h + ${PROJECT_BINARY_DIR}/include/tinytc/types.hpp ) -list(TRANSFORM PUBLIC_HEADERS PREPEND "${PROJECT_SOURCE_DIR}/include/tinytc/") - -add_flag_if_available_to_source_files(CXX "${BISON_parser_OUTPUTS}" "-Wno-unused-but-set-variable") -add_library(tinytc-objects OBJECT ${SOURCES} ${BISON_parser_OUTPUTS}) -add_re2c_to_target(TARGET tinytc-objects SOURCES ${RE2C_SOURCES}) set_cxx_common_options(tinytc-objects) -target_link_libraries(tinytc-objects PUBLIC clir::clir) target_compile_definitions(tinytc-objects PUBLIC "$<$>:TINYTC_STATIC_DEFINE>") add_library(tinytc $) add_library(tinytc::tinytc ALIAS tinytc) -target_link_libraries(tinytc PRIVATE clir::clir) set_cxx_common_options(tinytc) # Generate export header @@ -98,6 +143,7 @@ configure_file(${tinytc_version_header_in} ${tinytc_version_header}) target_include_directories(tinytc-objects PRIVATE "$" "$" + "$" ) target_include_directories(tinytc-objects PUBLIC "$" @@ -113,7 +159,6 @@ target_sources(tinytc PUBLIC FILE_SET HEADERS # install - install_lib(tinytc tinytc) # subdirs @@ -126,3 +171,9 @@ endif() if(BUILD_LEVEL_ZERO) add_subdirectory(ze) endif() + +# cpack + +include(GeneratedFiles) +write_generated_files(tinytc-objects) + diff --git a/src/visitor/aa_results.cpp b/src/analysis/aa_results.cpp similarity index 58% rename from src/visitor/aa_results.cpp rename to src/analysis/aa_results.cpp index dbfafba3..161b46d4 100644 --- a/src/visitor/aa_results.cpp +++ b/src/analysis/aa_results.cpp @@ -1,25 +1,24 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "visitor/aa_results.hpp" -#include "node/value_node.hpp" +#include "analysis/aa_results.hpp" #include namespace tinytc { 
-aa_results::aa_results(std::unordered_map alias, - std::unordered_map allocs) +aa_results::aa_results(std::unordered_map alias, + std::unordered_map allocs) : alias_(std::move(alias)), allocs_(std::move(allocs)) {} -auto aa_results::root(value_node const &a) -> value_node const * { +auto aa_results::root(tinytc_value const &a) const -> const_tinytc_value_t { auto root = &a; - if (alias_.find(root) != alias_.end()) { - root = alias_[root]; + if (auto it = alias_.find(root); it != alias_.end()) { + root = it->second; } return root; } -bool aa_results::alias(value_node const &a, value_node const &b) { +bool aa_results::alias(tinytc_value const &a, tinytc_value const &b) const { auto ra = root(a); auto rb = root(b); if (ra == rb) { diff --git a/src/visitor/aa_results.hpp b/src/analysis/aa_results.hpp similarity index 85% rename from src/visitor/aa_results.hpp rename to src/analysis/aa_results.hpp index e65f2613..3c553e64 100644 --- a/src/visitor/aa_results.hpp +++ b/src/analysis/aa_results.hpp @@ -13,21 +13,19 @@ namespace tinytc { class aa_results { public: - aa_results() = default; - auto root(::tinytc_value const &a) -> ::tinytc_value const *; - bool alias(::tinytc_value const &a, ::tinytc_value const &b); - - private: struct allocation { std::int64_t start, stop; }; aa_results(std::unordered_map<::tinytc_value const *, ::tinytc_value const *> alias, std::unordered_map<::tinytc_value const *, allocation> allocs); + + auto root(::tinytc_value const &a) const -> ::tinytc_value const *; + bool alias(::tinytc_value const &a, ::tinytc_value const &b) const; + + private: std::unordered_map<::tinytc_value const *, ::tinytc_value const *> alias_; std::unordered_map<::tinytc_value const *, allocation> allocs_; - - friend class alias_analyser; }; } // namespace tinytc diff --git a/src/analysis/alias.cpp b/src/analysis/alias.cpp new file mode 100644 index 00000000..67385460 --- /dev/null +++ b/src/analysis/alias.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2024 Intel Corporation 
+// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/alias.hpp" +#include "error.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "support/walk.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" + +#include +#include +#include + +namespace tinytc { + +class alias_analysis_visitor { + public: + void operator()(inst_view); + void operator()(alloca_inst a); + void operator()(expand_inst e); + void operator()(fuse_inst f); + void operator()(subview_inst s); + + auto get_result() && -> aa_results { return aa_results(std::move(alias_), std::move(allocs_)); } + + private: + std::unordered_map allocs_; + std::unordered_map alias_; +}; + +void alias_analysis_visitor::operator()(inst_view) {} +void alias_analysis_visitor::operator()(alloca_inst a) { + if (a.stack_ptr() >= 0) { + auto t = dyn_cast(a.result().ty()); + if (t == nullptr) { + throw compilation_error(a.loc(), status::ir_expected_memref); + } + allocs_[&a.result()] = + aa_results::allocation{a.stack_ptr(), a.stack_ptr() + t->size_in_bytes()}; + } +} +void alias_analysis_visitor::operator()(expand_inst e) { + const_tinytc_value_t source = &e.operand(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[&e.result()] = source; +} +void alias_analysis_visitor::operator()(fuse_inst f) { + const_tinytc_value_t source = &f.operand(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[&f.result()] = source; +} + +void alias_analysis_visitor::operator()(subview_inst s) { + const_tinytc_value_t source = &s.operand(); + while (alias_.find(source) != alias_.end()) { + source = alias_[source]; + } + alias_[&s.result()] = source; +} + +auto alias_analysis::run_on_function(tinytc_func &fn) -> aa_results { + auto visitor = alias_analysis_visitor{}; + + walk(fn, [&visitor](tinytc_inst &i) { visit(visitor, i); }); + + return std::move(visitor).get_result(); +} + 
+} // namespace tinytc diff --git a/src/analysis/alias.hpp b/src/analysis/alias.hpp new file mode 100644 index 00000000..7cc7017d --- /dev/null +++ b/src/analysis/alias.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ALIAS_20240912_HPP +#define ALIAS_20240912_HPP + +#include "analysis/aa_results.hpp" +#include "tinytc/types.h" + +namespace tinytc { + +class alias_analysis { + public: + auto run_on_function(tinytc_func &fn) -> aa_results; +}; + +} // namespace tinytc + +#endif // ALIAS_20240912_HPP diff --git a/src/analysis/cfg.cpp b/src/analysis/cfg.cpp new file mode 100644 index 00000000..2cbba24f --- /dev/null +++ b/src/analysis/cfg.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/cfg.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "util/ilist_base.hpp" + +#include +#include + +namespace tinytc { + +void control_flow_graph::insert_before(tinytc_inst_t before_inst, tinytc_inst_t new_inst) { + add_node(new_inst, adj_[before_inst].kind_max); + adj_[new_inst].pred = std::move(adj_[before_inst].pred); + add_edge(new_inst, before_inst); +} + +auto control_flow_graph::node_queue() const -> std::queue { + auto q = std::queue{}; + for (auto &[key, neighbors] : adj_) { + q.push(key); + } + return q; +} + +auto get_control_flow_graph(tinytc_region &topreg) -> control_flow_graph { + auto cfg = control_flow_graph{}; + + const auto add_region = + [&cfg](tinytc_region ®, region_kind kind_max, + auto &add_region_ref) -> std::pair> { + if (reg.empty()) { + return {}; + } + + auto pred_nodes = std::queue{}; + const auto visit_inst = [&](tinytc_inst_t node) { + bool empty_child_regions = true; + if (node->num_child_regions() > 0) { + for (auto &subreg : node->child_regions()) { + auto [substart, subexits] = + add_region_ref(subreg, std::max(kind_max, subreg.kind()), add_region_ref); + if (substart != nullptr && 
!subexits.empty()) { + empty_child_regions = false; + cfg.add_edge(node, substart); + if (isa(*node)) { + for (; !subexits.empty(); subexits.pop()) { + cfg.add_edge(subexits.front(), node); + } + pred_nodes.push(node); + } else { + for (; !subexits.empty(); subexits.pop()) { + pred_nodes.push(subexits.front()); + } + } + } + } + } + if (empty_child_regions) { + pred_nodes.push(node); + } + }; + + auto start = reg.begin().get(); + cfg.add_node(start, kind_max); + visit_inst(start); + + for (auto it = ++reg.begin(); it != reg.end(); ++it) { + tinytc_inst_t node = it.get(); + cfg.add_node(node, kind_max); + + for (; !pred_nodes.empty(); pred_nodes.pop()) { + cfg.add_edge(pred_nodes.front(), node); + } + + visit_inst(node); + } + + return std::make_pair(std::move(start), std::move(pred_nodes)); + }; + + add_region(topreg, topreg.kind(), add_region); + + return cfg; +} + +} // namespace tinytc diff --git a/src/analysis/cfg.hpp b/src/analysis/cfg.hpp new file mode 100644 index 00000000..f8590f49 --- /dev/null +++ b/src/analysis/cfg.hpp @@ -0,0 +1,61 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CFG_20240919_HPP +#define CFG_20240919_HPP + +#include "node/inst.hpp" +#include "node/region.hpp" +#include "tinytc/types.h" +#include "util/iterator.hpp" + +#include +#include +#include + +namespace tinytc { + +class control_flow_graph { + public: + inline void add_node(tinytc_inst_t a, region_kind kind_max) { + adj_[a] = adjacency_list{}; + adj_[a].kind_max = kind_max; + } + inline void add_edge(tinytc_inst_t a, tinytc_inst_t b) { + adj_[a].succ.push_back(b); + adj_[b].pred.push_back(a); + } + void insert_before(tinytc_inst_t before_inst, tinytc_inst_t new_inst); + + auto node_queue() const -> std::queue; + + inline auto kind_max(tinytc_inst_t a) -> region_kind { return adj_[a].kind_max; } + + inline auto pred_begin(tinytc_inst_t a) { return adj_[a].pred.begin(); } + inline auto pred_end(tinytc_inst_t a) { return 
adj_[a].pred.end(); } + inline auto predecessors(tinytc_inst_t a) + -> iterator_range_wrapper::iterator> { + return {pred_begin(a), pred_end(a)}; + } + + inline auto succ_begin(tinytc_inst_t a) { return adj_[a].succ.begin(); } + inline auto succ_end(tinytc_inst_t a) { return adj_[a].succ.end(); } + inline auto successors(tinytc_inst_t a) + -> iterator_range_wrapper::iterator> { + return {succ_begin(a), succ_end(a)}; + } + + private: + struct adjacency_list { + region_kind kind_max = region_kind::mixed; + std::vector pred; + std::vector succ; + }; + std::unordered_map adj_; +}; + +auto get_control_flow_graph(tinytc_region ®) -> control_flow_graph; + +} // namespace tinytc + +#endif // CFG_20240919_HPP diff --git a/src/analysis/gcd.cpp b/src/analysis/gcd.cpp new file mode 100644 index 00000000..0f63760a --- /dev/null +++ b/src/analysis/gcd.cpp @@ -0,0 +1,376 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/gcd.hpp" +#include "codegen_tools.hpp" +#include "error.hpp" +#include "node/attr.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "number.hpp" +#include "support/walk.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/iterator.hpp" +#include "util/overloaded.hpp" + +#include // IWYU pragma: keep +#include +#include +#include +#include + +namespace tinytc { + +memref_info::memref_info(std::int64_t offset_gcd, std::vector shape_gcd, + std::vector stride_gcd) + : offset_gcd_(offset_gcd), shape_gcd_(std::move(shape_gcd)), + stride_gcd_(std::move(stride_gcd)) {} + +auto gcd_analysis_result::get(::const_tinytc_value_t a) const -> std::int64_t { + const auto g = get_if(a); + return g ? 
*g : 1; +} +auto gcd_analysis_result::get(::tinytc_value const &a) const -> std::int64_t { return get(&a); } +auto gcd_analysis_result::get_if(::const_tinytc_value_t a) const -> std::optional { + if (auto it = gcd_.find(a); it != gcd_.end()) { + return it->second; + } + return std::nullopt; +} +auto gcd_analysis_result::get_if(::tinytc_value const &a) const -> std::optional { + return get_if(&a); +} +void gcd_analysis_result::set(::tinytc_value const &a, std::int64_t g) { gcd_[&a] = g; } + +auto gcd_analysis_result::get_memref_if(::const_tinytc_value_t a) const -> memref_info const * { + if (auto it = memref_info_.find(a); it != memref_info_.end()) { + return &it->second; + } + return nullptr; +} +auto gcd_analysis_result::get_memref_if(::tinytc_value const &a) const -> memref_info const * { + return get_memref_if(&a); +} +void gcd_analysis_result::set_memref(::tinytc_value const &a, memref_info g) { + memref_info_[&a] = std::move(g); +} + +class gcd_helper { + public: + inline gcd_helper(std::int32_t default_alignment) : default_alignment_{default_alignment} {} + + void operator()(inst_view in); + void operator()(alloca_inst in); + void operator()(arith_inst in); + void operator()(arith_unary_inst in); + void operator()(cast_inst in); + void operator()(constant_inst in); + void operator()(expand_inst in); + void operator()(for_inst in); + void operator()(fuse_inst in); + void operator()(load_inst in); + void operator()(size_inst in); + void operator()(subgroup_broadcast_inst in); + void operator()(subview_inst in); + + void set_from_attributes(tinytc_func &fn); + + auto get_result() && { return std::move(gcd_); } + + private: + std::int32_t default_alignment_; + gcd_analysis_result gcd_; +}; + +void gcd_helper::operator()(inst_view) {} +void gcd_helper::operator()(alloca_inst in) { + if (in.stack_ptr() >= 0) { + const auto rt = get_memref_type(in.result().ty()); + std::int32_t i = rt->element_alignment(); + while (i < default_alignment_) { + const auto i2 = 2 * i; 
+ if (in.stack_ptr() % i2 != 0) { + break; + } + i = i2; + } + // alloca shape/stride must be static, therefore we can set shape_gcd/stride_gcd to + // shape/stride + auto rt_number_size = size(rt->element_ty()); + gcd_.set_memref(in.result(), memref_info(i / rt_number_size, rt->shape(), rt->stride())); + } +} +void gcd_helper::operator()(arith_inst in) { + auto compute_gcd = [&]() -> std::optional { + const auto ga = gcd_.get(in.a()); + const auto gb = gcd_.get(in.b()); + switch (in.get().type_id()) { + case IK::IK_add: + return std::gcd(ga, gb); + case IK::IK_mul: + return ga * gb; + case IK::IK_div: { + return ga % gb == 0 ? ga / gb : 1; + } + default: + break; + } + return std::nullopt; + }; + auto g = compute_gcd(); + if (g) { + gcd_.set(in.result(), *g); + } +} +void gcd_helper::operator()(arith_unary_inst in) { + auto compute_gcd = [&]() -> std::optional { + switch (in.get().type_id()) { + case IK::IK_abs: + case IK::IK_not: + return gcd_.get(in.a()); + default: + break; + } + return std::nullopt; + }; + auto g = compute_gcd(); + if (g) { + gcd_.set(in.result(), *g); + } +} +void gcd_helper::operator()(cast_inst in) { + auto g = gcd_.get_if(in.a()); + if (g) { + gcd_.set(in.result(), *g); + } +} +void gcd_helper::operator()(constant_inst in) { + if (std::holds_alternative(in.value())) { + gcd_.set(in.result(), std::abs(std::get(in.value()))); + } +} +void gcd_helper::operator()(expand_inst in) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + const auto mt = get_memref_type(in.operand()); + const auto offset_gcd = mi->offset_gcd(); + auto shape_gcd = std::vector{}; + auto stride_gcd = std::vector{}; + + auto static_shape = in.static_expand_shape(); + auto dyn_shape = in.expand_shape(); + + shape_gcd.reserve(mt->dim() + static_shape.size() - 1); + stride_gcd.reserve(mt->dim() + static_shape.size() - 1); + + for (std::int64_t i = 0; i < in.expanded_mode(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } 
+ + auto get_shape = [&, j = std::size_t{0}](std::int64_t s) mutable { + if (is_dynamic_value(s)) { + return gcd_.get(dyn_shape[j++]); + } + return s; + }; + stride_gcd.push_back(mi->stride_gcd(in.expanded_mode())); + shape_gcd.push_back(get_shape(static_shape[0])); + for (std::size_t j = 1; j < static_shape.size(); ++j) { + stride_gcd.push_back(stride_gcd.back() * shape_gcd.back()); + shape_gcd.push_back(get_shape(static_shape[j])); + } + + for (std::int64_t i = in.expanded_mode() + 1; i < mt->dim(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } + + gcd_.set_memref(in.result(), memref_info(offset_gcd, shape_gcd, stride_gcd)); + } +} +void gcd_helper::operator()(for_inst in) { + if (in.has_step()) { + auto g = std::gcd(gcd_.get(in.from()), gcd_.get(in.step())); + gcd_.set(in.loop_var(), g); + } +} +void gcd_helper::operator()(fuse_inst in) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + const auto mt = get_memref_type(in.operand()); + const auto offset_gcd = mi->offset_gcd(); + auto shape_gcd = std::vector{}; + auto stride_gcd = std::vector{}; + + shape_gcd.reserve(mt->dim()); + stride_gcd.reserve(mt->dim()); + + std::int64_t i = 0; + for (; i < in.from(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } + std::int64_t prod = mi->shape_gcd(i++); + for (; i <= in.to(); ++i) { + prod *= mi->shape_gcd(i); + } + shape_gcd.push_back(prod); + stride_gcd.push_back(mi->stride_gcd(in.from())); + for (i = in.to() + 1; i < mt->dim(); ++i) { + shape_gcd.push_back(mi->shape_gcd(i)); + stride_gcd.push_back(mi->stride_gcd(i)); + } + + gcd_.set_memref(in.result(), memref_info(offset_gcd, shape_gcd, stride_gcd)); + } +} +void gcd_helper::operator()(load_inst in) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi && isa(*in.operand().ty())) { + gcd_.set_memref(in.result(), *mi); + } +} +void gcd_helper::operator()(size_inst in) { + const auto size = + 
visit(overloaded{[&](group_type &g) -> std::int64_t { + return !is_dynamic_value(g.size()) ? g.size() : 1; + }, + [&](memref_type &m) -> std::int64_t { + const auto s_i = m.shape(in.mode()); + if (is_dynamic_value(s_i)) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + return mi->shape_gcd(in.mode()); + } + return 1; + } + return s_i; + }, + [&](tinytc_type &) -> std::int64_t { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + }}, + *in.operand().ty()); + + gcd_.set(in.result(), size); +} +void gcd_helper::operator()(subgroup_broadcast_inst in) { + auto g = gcd_.get_if(in.a()); + if (g) { + gcd_.set(in.result(), *g); + } +} +void gcd_helper::operator()(subview_inst in) { + if (auto mi = gcd_.get_memref_if(in.operand()); mi) { + const auto mt = get_memref_type(in.operand()); + + auto shape_gcd = std::vector{}; + auto stride_gcd = std::vector{}; + + shape_gcd.reserve(mt->dim()); + stride_gcd.reserve(mt->dim()); + auto dyn_offsets = in.offsets(); + auto dyn_sizes = in.sizes(); + std::int64_t offset_gcd = mi->offset_gcd(); + for (std::int64_t i = 0, joffset = 0, jsize = 0; i < mt->dim(); ++i) { + const std::int64_t offset = in.static_offsets()[i]; + + auto const get_offset = [&]() -> std::int64_t { + if (is_dynamic_value(offset)) { + return gcd_.get(dyn_offsets[joffset++]); + } + return offset; + }; + offset_gcd = std::gcd(offset_gcd, get_offset() * mi->stride_gcd(i)); + + const std::int64_t size = in.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { + auto const get_size = [&]() -> std::int64_t { + if (is_dynamic_value(size)) { + return gcd_.get(dyn_sizes[jsize++]); + } + return size; + }; + shape_gcd.emplace_back(get_size()); + stride_gcd.emplace_back(mi->stride_gcd(i)); + } + } + + gcd_.set_memref(in.result(), memref_info(offset_gcd, shape_gcd, stride_gcd)); + } +} + +void gcd_helper::set_from_attributes(tinytc_func &fn) { + auto known_memref_info = [&](memref_type *mr, tinytc_attr_t dict) -> memref_info { + const 
std::int64_t alignment = [&]() -> std::int64_t { + if (auto alignment_attr = get_attr(dict, "alignment"); alignment_attr) { + auto ia = dyn_cast(alignment_attr); + if (ia) { + return ia->value(); + } + throw status::ir_expected_integer_attribute; + } + return default_alignment_; + }(); + + auto shape_gcd = [&]() -> std::vector { + if (auto shape_attr = get_attr(dict, "shape_gcd"); shape_attr) { + return get_array_attr_as(shape_attr); + } + return std::vector{}; + }(); + auto const dim = static_cast(mr->dim()); + if (shape_gcd.size() > dim) { + shape_gcd.resize(dim); + } else if (shape_gcd.size() < dim) { + std::size_t i = shape_gcd.size(); + shape_gcd.resize(mr->dim()); + for (; i < shape_gcd.size(); ++i) { + const auto s = mr->shape(i); + shape_gcd[i] = !is_dynamic_value(s) ? s : 1; + } + } + + auto stride_gcd = [&]() -> std::vector { + if (auto stride_attr = get_attr(dict, "stride_gcd"); stride_attr) { + return get_array_attr_as(stride_attr); + } + return std::vector{}; + }(); + if (stride_gcd.size() > dim) { + stride_gcd.resize(dim); + } else if (stride_gcd.size() < dim) { + std::size_t i = stride_gcd.size(); + stride_gcd.resize(mr->dim()); + for (; i < stride_gcd.size(); ++i) { + const auto s = mr->stride(i); + stride_gcd[i] = !is_dynamic_value(s) ? 
s : 1; + } + } + + auto mr_number_size = size(mr->element_ty()); + return memref_info(alignment / mr_number_size, std::move(shape_gcd), std::move(stride_gcd)); + }; + for (std::size_t arg_no = 0; arg_no < fn.num_params(); ++arg_no) { + auto ty = fn.params()[arg_no].ty(); + if (auto g = dyn_cast(ty); g) { + ty = g->element_ty(); + } + if (auto mr = dyn_cast(ty); mr) { + gcd_.set_memref(fn.params()[arg_no], known_memref_info(mr, fn.param_attr(arg_no))); + } + } +} + +auto gcd_analysis::run_on_function(tinytc_func &fn) -> gcd_analysis_result { + auto visitor = gcd_helper{default_alignment_}; + visitor.set_from_attributes(fn); + + walk(fn, [&visitor](tinytc_inst &i) { visit(visitor, i); }); + + return std::move(visitor).get_result(); +} + +} // namespace tinytc diff --git a/src/analysis/gcd.hpp b/src/analysis/gcd.hpp new file mode 100644 index 00000000..f166a5cc --- /dev/null +++ b/src/analysis/gcd.hpp @@ -0,0 +1,109 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GCD_20241203_HPP +#define GCD_20241203_HPP + +#include "tinytc/types.h" + +#include +#include +#include +#include +#include + +namespace tinytc { + +class memref_info { + public: + memref_info() = default; + memref_info(std::int64_t offset_gcd, std::vector shape_gcd, + std::vector stride_gcd); + + inline auto offset_gcd() const { return offset_gcd_; } + inline auto shape_gcd_begin() const { return shape_gcd_.begin(); } + inline auto shape_gcd_end() const { return shape_gcd_.end(); } + inline auto shape_gcd() const -> std::vector const & { return shape_gcd_; } + inline auto shape_gcd(std::size_t i) const -> std::int64_t { return shape_gcd_[i]; } + inline auto stride_gcd_begin() const { return stride_gcd_.begin(); } + inline auto stride_gcd_end() const { return stride_gcd_.end(); } + inline auto stride_gcd() const -> std::vector const & { return stride_gcd_; } + inline auto stride_gcd(std::size_t i) const -> std::int64_t { return stride_gcd_[i]; } + + private: + 
std::int64_t offset_gcd_; + std::vector shape_gcd_, stride_gcd_; +}; + +class gcd_analysis_result { + public: + auto get(::const_tinytc_value_t a) const -> std::int64_t; + auto get(::tinytc_value const &a) const -> std::int64_t; + auto get_if(::const_tinytc_value_t a) const -> std::optional; + auto get_if(::tinytc_value const &a) const -> std::optional; + void set(::tinytc_value const &a, std::int64_t g); + + auto get_memref_if(::const_tinytc_value_t a) const -> memref_info const *; + auto get_memref_if(::tinytc_value const &a) const -> memref_info const *; + void set_memref(::tinytc_value const &a, memref_info g); + + private: + std::unordered_map<::tinytc_value const *, std::int64_t> gcd_; + std::unordered_map<::tinytc_value const *, memref_info> memref_info_; +}; + +/** + * In the "GCD-analysis" we want to infer at compile time by which integers an SSA value is + * divisible. For example, for + * + * %0 = constant 32 : index + * %1 = arith.mul %0, %x : index + * + * we know that %1 is at least divisible by 2, 4, 8, 16, 32 without knowing anything about %x. + * + * For the analysis, let P(%x) be the set of known prime factors of a value %x. 
We define + * + * "%x = constant C": P(%x) := set of prime factors of |C| + * + * For the arithmetic instructions we define + * "%z = arith.add %x, %y": P(%z) := P(%x) ∩ P(%y) + * "%z = arith.sub %x, %y": P(%z) := P(%x) ∩ P(%y) + * "%z = arith.mul %x, %y": P(%z) := P(%x) ∪ P(%y) + * "%z = arith.div %x, %y": P(%z) := P(%x) ∖ P(%y) if P(%y) ⊆ P(%x) else {1} + * + * If nothing is known about %x we let + * P(%x) := {1} + * + * For efficiency, we encode the set of prime factors by its product, that is, by the integer + * + * p(%x) := \prod_{f\in P(%x)} f + * + * We can update p without having to resort to P as follows: + * + * "%x = constant C": p(%x) := |C| + * "%z = arith.add %x, %y": p(%z) := gcd(p(%x),p(%y)) + * "%z = arith.sub %x, %y": p(%z) := gcd(p(%x),p(%y)) + * "%z = arith.mul %x, %y": p(%z) := p(%x) * p(%y) + * "%z = arith.div %x, %y": p(%z) := p(%x) / p(%y) if p(%x) % p(%y) == 0 else 1 + * "%x = unknown": p(%x) := 1 + * + * where gcd is the greatest common divisor. + * + * One special case is when we have %zero = constant 0. By the constant rule p(%zero) = 0. 
+ * We then have automatically + * "%z = arith.add %zero, %x": p(%z) = p(%x) // Fine, 0 + x = x so we must have p(%z) = p(%x) + * "%z = arith.mul %zero, %x": p(%z) = 0 // Fine, 0 * x = 0 so we must have p(%z) = p(%zero) + */ +class gcd_analysis { + public: + inline gcd_analysis(std::int32_t default_alignment) : default_alignment_(default_alignment) {} + + auto run_on_function(tinytc_func &fn) -> gcd_analysis_result; + + private: + std::int32_t default_alignment_; +}; + +} // namespace tinytc + +#endif // GCD_20241203_HPP diff --git a/src/analysis/stack.cpp b/src/analysis/stack.cpp new file mode 100644 index 00000000..45d90a82 --- /dev/null +++ b/src/analysis/stack.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "analysis/stack.hpp" +#include "error.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "support/walk.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" + +#include +#include + +namespace tinytc { + +auto stack_high_water_mark::run_on_function(tinytc_func &fn) -> std::int64_t { + std::int64_t high_water_mark = 0; + + walk(fn, [&high_water_mark](tinytc_inst &i) { + if (auto a = dyn_cast(&i); a) { + auto t = dyn_cast(a.result().ty()); + if (t == nullptr) { + throw compilation_error(a.loc(), status::ir_expected_memref); + } + high_water_mark = std::max(high_water_mark, a.stack_ptr() + t->size_in_bytes()); + } + }); + + return high_water_mark; +} + +} // namespace tinytc diff --git a/src/analysis/stack.hpp b/src/analysis/stack.hpp new file mode 100644 index 00000000..3e840a76 --- /dev/null +++ b/src/analysis/stack.hpp @@ -0,0 +1,20 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef STACK_20241112_HPP +#define STACK_20241112_HPP + +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class stack_high_water_mark { + public: + auto run_on_function(tinytc_func &fn) -> std::int64_t; 
+}; + +} // namespace tinytc + +#endif // STACK_20241112_HPP diff --git a/src/binary.cpp b/src/binary.cpp index 360fc0ab..a6789c2b 100644 --- a/src/binary.cpp +++ b/src/binary.cpp @@ -3,9 +3,9 @@ #include "binary.hpp" #include "error.hpp" -#include "tinytc/tinytc.h" +#include "tinytc/core.h" #include "tinytc/types.h" -#include "util.hpp" +#include "util/casting.hpp" #include #include @@ -13,20 +13,24 @@ using namespace tinytc; -tinytc_binary::tinytc_binary(std::vector data, bundle_format format, +tinytc_binary::tinytc_binary(shared_handle ctx, + std::vector data, bundle_format format, tinytc_core_feature_flags_t core_features) - : data_(std::move(data)), format_(format), core_features_(core_features) {} + : ctx_(std::move(ctx)), data_(std::move(data)), format_(format), core_features_(core_features) { +} extern "C" { -tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, tinytc_bundle_format_t format, - size_t data_size, uint8_t const *data, +tinytc_status_t tinytc_binary_create(tinytc_binary_t *bin, tinytc_compiler_context_t ctx, + tinytc_bundle_format_t format, size_t data_size, + uint8_t const *data, tinytc_core_feature_flags_t core_features) { - if (bin == nullptr || data == nullptr) { + if (bin == nullptr || ctx == nullptr || data == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - *bin = std::make_unique(std::vector(data, data + data_size), + *bin = std::make_unique(shared_handle{ctx, true}, + std::vector(data, data + data_size), enum_cast(format), core_features) .release(); }); @@ -43,6 +47,14 @@ tinytc_status_t tinytc_binary_get_raw(const_tinytc_binary_t bin, tinytc_bundle_f return tinytc_status_success; } +tinytc_status_t tinytc_binary_get_compiler_context(const_tinytc_binary_t bin, + tinytc_compiler_context_t *ctx) { + if (bin == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ctx = bin->context(); }); +} + tinytc_status_t 
tinytc_binary_get_core_features(const_tinytc_binary_t bin, tinytc_core_feature_flags_t *core_features) { if (bin == nullptr || core_features == nullptr) { diff --git a/src/binary.hpp b/src/binary.hpp index 3607c0db..fdd048ac 100644 --- a/src/binary.hpp +++ b/src/binary.hpp @@ -20,14 +20,20 @@ struct tinytc_binary : tinytc::reference_counted { /** * @brief Create binary * + * @param ctx Compiler context * @param data Binary data * @param format Binary format (SPIR-V or native device binary) * @param metadata_map Dictionary kernel name -> kernel metadata * @param core_features Required core features */ - tinytc_binary(std::vector data, tinytc::bundle_format format, + tinytc_binary(tinytc::shared_handle ctx, + std::vector data, tinytc::bundle_format format, tinytc_core_feature_flags_t core_features); + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::shared_handle { + return ctx_; + } //! Get raw data inline auto data() const noexcept -> std::uint8_t const * { return data_.data(); } //! 
Get size of raw data @@ -40,6 +46,7 @@ struct tinytc_binary : tinytc::reference_counted { } private: + tinytc::shared_handle ctx_; std::vector data_; tinytc::bundle_format format_; tinytc_core_feature_flags_t core_features_; diff --git a/src/builder.cpp.mochi b/src/builder.cpp.mochi new file mode 100644 index 00000000..e2a5748a --- /dev/null +++ b/src/builder.cpp.mochi @@ -0,0 +1,130 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "error.hpp" +#include "location.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include "tinytc/builder.h" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/overloaded.hpp" + +#include +#include + +using namespace tinytc; + +extern "C" { +// もち api_builder_cpp "tinytc/instructions.anko" + +tinytc_status_t tinytc_constant_inst_create_boolean(tinytc_inst_t *instr, tinytc_bool_t value, + tinytc_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = constant_inst::create(value != 0, ty, get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_constant_inst_create_complex(tinytc_inst_t *instr, double value_re, + double value_im, tinytc_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *instr = + constant_inst::create(std::complex(value_re, value_im), ty, get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_constant_inst_create_float(tinytc_inst_t *instr, double value, + tinytc_type_t ty, const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = constant_inst::create(value, ty, get_optional(loc)).release(); }); +} + +tinytc_status_t 
tinytc_constant_inst_create_int(tinytc_inst_t *instr, int64_t value, + tinytc_type_t ty, const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *instr = constant_inst::create(value, ty, get_optional(loc)).release(); }); +} + +tinytc_status_t tinytc_constant_inst_create_one(tinytc_inst_t *instr, tinytc_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + + auto dispatch_ty = ty; + if (const auto *ct = dyn_cast(dispatch_ty); ct != nullptr) { + dispatch_ty = ct->component_ty(); + } + return exception_to_status_code([&] { + *instr = visit( + overloaded{ + [&](boolean_type &) -> tinytc_inst_t { + return constant_inst::create(true, ty, get_optional(loc)).release(); + }, + [&](integer_type &) -> tinytc_inst_t { + return constant_inst::create(std::int64_t{1}, ty, get_optional(loc)).release(); + }, + [&](float_type &) -> tinytc_inst_t { + return constant_inst::create(double{1}, ty, get_optional(loc)).release(); + }, + [&](complex_type &) -> tinytc_inst_t { + return constant_inst::create(std::complex{1}, ty, get_optional(loc)) + .release(); + }, + [&](tinytc_type &) -> tinytc_inst_t { + throw compilation_error(get_optional(loc), status::ir_expected_number); + }}, + *dispatch_ty); + }); +} + +tinytc_status_t tinytc_constant_inst_create_zero(tinytc_inst_t *instr, tinytc_type_t ty, + const tinytc_location_t *loc) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + + auto dispatch_ty = ty; + if (const auto *ct = dyn_cast(dispatch_ty); ct != nullptr) { + dispatch_ty = ct->component_ty(); + } + return exception_to_status_code([&] { + *instr = visit( + overloaded{ + [&](boolean_type &) -> tinytc_inst_t { + return constant_inst::create(false, ty, get_optional(loc)).release(); + }, + [&](integer_type &) -> tinytc_inst_t { + return constant_inst::create(std::int64_t{0}, ty, get_optional(loc)).release(); + 
}, + [&](float_type &) -> tinytc_inst_t { + return constant_inst::create(double{0}, ty, get_optional(loc)).release(); + }, + [&](complex_type &) -> tinytc_inst_t { + return constant_inst::create(std::complex{0}, ty, get_optional(loc)) + .release(); + }, + [&](tinytc_type &) -> tinytc_inst_t { + throw compilation_error(get_optional(loc), status::ir_expected_number); + }}, + *dispatch_ty); + }); +} +} diff --git a/src/cl/CMakeLists.txt b/src/cl/CMakeLists.txt index e0cf379d..1a2d4435 100644 --- a/src/cl/CMakeLists.txt +++ b/src/cl/CMakeLists.txt @@ -2,21 +2,17 @@ # SPDX-License-Identifier: BSD-3-Clause include(CommonOptions) +include(GeneratedFiles) include(GNUInstallDirs) include(InstallLib) find_package(OpenCL REQUIRED) -find_package(re2c REQUIRED) set(SOURCES device_info.cpp - error.cpp kernel.cpp recipe_handler.cpp ) -set(RE2C_SOURCES - device_info_helper.re -) set(PUBLIC_HEADERS tinytc_cl.h tinytc_cl.hpp @@ -24,13 +20,14 @@ set(PUBLIC_HEADERS list(TRANSFORM PUBLIC_HEADERS PREPEND "${PROJECT_SOURCE_DIR}/include/tinytc/") add_library(tinytc_cl-objects OBJECT ${SOURCES}) -add_re2c_to_target(TARGET tinytc_cl-objects SOURCES ${RE2C_SOURCES} FLAGS "--tags") set_cxx_common_options(tinytc_cl-objects) target_link_libraries(tinytc_cl-objects PUBLIC tinytc OpenCL::OpenCL) target_include_directories(tinytc_cl-objects PRIVATE "$" ) +add_re2c_or_pregenerated_to_target(TARGET tinytc_cl-objects SOURCES device_info_helper.re) + add_library(tinytc_cl $) add_library(tinytc::tinytc_cl ALIAS tinytc_cl) @@ -52,3 +49,8 @@ install(FILES "${PROJECT_SOURCE_DIR}/cmake/FindOpenCL.cmake" DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/tinytc_cl ) + +# cpack + +include(GeneratedFiles) +write_generated_files(tinytc_cl-objects) diff --git a/src/cl/argument_handler.hpp b/src/cl/argument_handler.hpp index 7cc8d5f2..fdb15fbd 100644 --- a/src/cl/argument_handler.hpp +++ b/src/cl/argument_handler.hpp @@ -19,13 +19,16 @@ class opencl_argument_handler { using clSetKernelArgMemPointerINTEL_t = cl_int 
(*)(cl_kernel kernel, cl_uint arg_index, const void *arg_value); //! ctor - inline opencl_argument_handler() : clSetKernelArgMemPointerINTEL_(nullptr) {} + inline opencl_argument_handler() = default; //! ctor; checks whether cl_intel_unified_shared_memory is available and gets //! clSetKernelArgMemPointerINTEL - inline opencl_argument_handler(cl_platform_id plat) - : clSetKernelArgMemPointerINTEL_( - (clSetKernelArgMemPointerINTEL_t)clGetExtensionFunctionAddressForPlatform( - plat, "clSetKernelArgMemPointerINTEL")) {} + inline opencl_argument_handler(cl_platform_id plat) { set_platform(plat); } + + inline void set_platform(cl_platform_id plat) { + clSetKernelArgMemPointerINTEL_ = + (clSetKernelArgMemPointerINTEL_t)clGetExtensionFunctionAddressForPlatform( + plat, "clSetKernelArgMemPointerINTEL"); + } /** * @brief Set single kernel argument @@ -48,8 +51,8 @@ class opencl_argument_handler { * @param mem Memory object */ inline void set_mem_arg(cl_kernel kernel, std::uint32_t arg_index, const void *value, - tinytc_mem_type_t type) const { - switch (type) { + tinytc_mem_type_t ty) const { + switch (ty) { case tinytc_mem_type_buffer: set_arg(kernel, arg_index, sizeof(value), &value); return; @@ -69,7 +72,7 @@ class opencl_argument_handler { } private: - clSetKernelArgMemPointerINTEL_t clSetKernelArgMemPointerINTEL_; + clSetKernelArgMemPointerINTEL_t clSetKernelArgMemPointerINTEL_ = nullptr; }; } // namespace tinytc diff --git a/src/cl/device_info.cpp b/src/cl/device_info.cpp index 76c7fe2b..ffa871f3 100644 --- a/src/cl/device_info.cpp +++ b/src/cl/device_info.cpp @@ -3,40 +3,112 @@ #include "../device_info.hpp" #include "device_info_helper.hpp" -#include "tinytc/tinytc.h" +#include "error.hpp" +#include "tinytc/core.h" #include "tinytc/tinytc_cl.h" #include "tinytc/types.h" +#include "tinytc/types.hpp" #include #include #include #include #include -#include #include +#ifndef CL_DEVICE_SINGLE_FP_ATOMIC_CAPABILITIES_EXT +#define CL_DEVICE_SINGLE_FP_ATOMIC_CAPABILITIES_EXT 
0x4231 +#endif +#ifndef CL_DEVICE_DOUBLE_FP_ATOMIC_CAPABILITIES_EXT +#define CL_DEVICE_DOUBLE_FP_ATOMIC_CAPABILITIES_EXT 0x4232 +#endif +#ifndef CL_DEVICE_HALF_FP_ATOMIC_CAPABILITIES_EXT +#define CL_DEVICE_HALF_FP_ATOMIC_CAPABILITIES_EXT 0x4233 +#endif + +#ifndef CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT +#define CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT (1 << 1) +#endif +#ifndef CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT +#define CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT (1 << 2) +#endif +#ifndef CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT +#define CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT (1 << 17) +#endif +#ifndef CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT +#define CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT (1 << 18) +#endif + +namespace tinytc { +void set_spirv_features(tinytc_core_info_t info, cl_device_id device) { + auto const set_feature = [&info](tinytc_spirv_feature_t feature, bool available) { + CHECK_STATUS(tinytc_core_info_set_spirv_feature(info, feature, available)); + }; + + const auto ocl_exts = get_opencl_extensions(device); + const auto ocl_version = get_opencl_version(device); + const auto max_num_subgroups = device_info(device, CL_DEVICE_MAX_NUM_SUB_GROUPS); + const auto double_fp_config = + device_info(device, CL_DEVICE_DOUBLE_FP_CONFIG); + + set_feature(tinytc_spirv_feature_float16, ocl_exts & opencl_ext_cl_khr_fp16); + set_feature(tinytc_spirv_feature_float64, + double_fp_config != 0 || ocl_exts & opencl_ext_cl_khr_fp64); + set_feature(tinytc_spirv_feature_groups, ocl_version.major >= 2 && max_num_subgroups != 0); + set_feature(tinytc_spirv_feature_subgroup_dispatch, + ocl_version.major >= 2 && max_num_subgroups != 0); + set_feature(tinytc_spirv_feature_subgroup_buffer_block_io, + ocl_exts & opencl_ext_cl_intel_spirv_subgroups); + set_feature(tinytc_spirv_feature_int64_atomics, + ocl_exts & (opencl_ext_cl_khr_int64_base_atomics | + opencl_ext_cl_khr_int64_extended_atomics)); + if (ocl_exts & opencl_ext_cl_ext_float_atomics) { + auto f16_flags = + device_info(device, 
CL_DEVICE_HALF_FP_ATOMIC_CAPABILITIES_EXT); + auto f32_flags = + device_info(device, CL_DEVICE_SINGLE_FP_ATOMIC_CAPABILITIES_EXT); + auto f64_flags = + device_info(device, CL_DEVICE_DOUBLE_FP_ATOMIC_CAPABILITIES_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_add_local, + f16_flags & CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_add_local, + f32_flags & CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_add_local, + f64_flags & CL_DEVICE_LOCAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_add_global, + f16_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_add_global, + f32_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_add_global, + f64_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_ADD_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_min_max_local, + f16_flags & CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_min_max_local, + f32_flags & CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_min_max_local, + f64_flags & CL_DEVICE_LOCAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float16_min_max_global, + f16_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float32_min_max_global, + f32_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT); + set_feature(tinytc_spirv_feature_atomic_float64_min_max_global, + f64_flags & CL_DEVICE_GLOBAL_FP_ATOMIC_MIN_MAX_EXT); + } +} +} // namespace tinytc + extern "C" { tinytc_status_t tinytc_cl_get_support_level(cl_device_id device, tinytc_support_level_t *level) { if (level == nullptr) { return tinytc_status_invalid_arguments; } - std::size_t extensions_size; - TINYTC_CL_CHECK_STATUS( - clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, nullptr, &extensions_size)); - std::string extensions; - 
extensions.resize(extensions_size); - TINYTC_CL_CHECK_STATUS( - clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, extensions_size, extensions.data(), nullptr)); - - bool has_subgroup = tinytc::has_subgroup_extension(extensions.size(), extensions.c_str()); + auto ocl_exts = tinytc::get_opencl_extensions(device); + bool has_subgroup = + ocl_exts & (tinytc::opencl_ext_cl_intel_subgroups | tinytc::opencl_ext_cl_khr_subgroups); if (!has_subgroup) { - char version_str[32]; - std::size_t version_str_size; - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_VERSION, sizeof(version_str) - 1, - version_str, &version_str_size)); - auto version = tinytc::get_opencl_version(version_str_size, version_str); + auto version = tinytc::get_opencl_version(device); if (version.major >= 3) { std::size_t features_size; TINYTC_CL_CHECK_STATUS( @@ -63,7 +135,11 @@ tinytc_status_t tinytc_cl_get_support_level(cl_device_id device, tinytc_support_ cl_version ip_ver = 0; cl_int err = clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, sizeof(ip_ver), &ip_ver, nullptr); - if (err == CL_SUCCESS && ip_ver == tinytc_intel_gpu_architecture_pvc) { + const auto is_arch = [&ip_ver](auto arch) { + return arch <= ip_ver && ip_ver <= arch + TINYTC_INTEL_GPU_ARCHITECTURE_SUB_VERSION_BITS; + }; + if (err == CL_SUCCESS && (is_arch(tinytc_intel_gpu_architecture_pvc) || + is_arch(tinytc_intel_gpu_architecture_bmg))) { *level = tinytc_support_level_tuned; } @@ -75,16 +151,18 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i return tinytc_status_invalid_arguments; } - cl_uint vendor_id; + cl_uint vendor_id, mem_base_addr_align; TINYTC_CL_CHECK_STATUS( clGetDeviceInfo(device, CL_DEVICE_VENDOR_ID, sizeof(vendor_id), &vendor_id, nullptr)); if (vendor_id == 0x8086) { - cl_version ip_ver; - cl_uint num_eus_per_subslice, num_threads_per_eu; + cl_device_type device_type; std::size_t subgroup_sizes_size = 0; + TINYTC_CL_CHECK_STATUS( + clGetDeviceInfo(device, CL_DEVICE_TYPE, 
sizeof(device_type), &device_type, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_SUB_GROUP_SIZES_INTEL, 0, nullptr, &subgroup_sizes_size)); auto subgroup_sizes_long = @@ -95,19 +173,36 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i auto subgroup_sizes = std::vector(subgroup_sizes_long.begin(), subgroup_sizes_long.end()); - TINYTC_CL_CHECK_STATUS( - clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, sizeof(ip_ver), &ip_ver, nullptr)); - - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, - sizeof(num_eus_per_subslice), &num_eus_per_subslice, - nullptr)); - TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_THREADS_PER_EU_INTEL, - sizeof(num_threads_per_eu), &num_threads_per_eu, - nullptr)); - - TINYTC_CHECK_STATUS(tinytc_core_info_intel_create(info, ip_ver, num_eus_per_subslice, - num_threads_per_eu, subgroup_sizes.size(), - subgroup_sizes.data())); + if (device_type == CL_DEVICE_TYPE_GPU) { + cl_version ip_ver; + cl_uint num_eus_per_subslice, num_threads_per_eu; + + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_IP_VERSION_INTEL, + sizeof(ip_ver), &ip_ver, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, + sizeof(num_eus_per_subslice), + &num_eus_per_subslice, nullptr)); + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_NUM_THREADS_PER_EU_INTEL, + sizeof(num_threads_per_eu), &num_threads_per_eu, + nullptr)); + + TINYTC_CHECK_STATUS(tinytc_core_info_intel_create( + info, ip_ver, num_eus_per_subslice, num_threads_per_eu, subgroup_sizes.size(), + subgroup_sizes.data())); + } else if (device_type == CL_DEVICE_TYPE_CPU) { + // 32 zmm registers + // @todo: need to do something smarter here + std::uint32_t register_space = 32 * 64; + size_t max_work_group_size; + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, + sizeof(max_work_group_size), + &max_work_group_size, 
nullptr)); + TINYTC_CHECK_STATUS( + tinytc_core_info_generic_create(info, register_space, max_work_group_size, + subgroup_sizes.size(), subgroup_sizes.data())); + } else { + return tinytc_status_unsupported_device; + } } else if (vendor_id == 0x1002) { // 512 KB / 32 wavefronts // @todo: can this info be queried? @@ -127,7 +222,13 @@ tinytc_status_t tinytc_cl_core_info_create(tinytc_core_info_t *info, cl_device_i return tinytc_status_unsupported_device; } - return tinytc_status_success; + TINYTC_CL_CHECK_STATUS(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, + sizeof(mem_base_addr_align), &mem_base_addr_align, + nullptr)); + // mem_base_addr_align is in bits -> convert to bytes + TINYTC_CHECK_STATUS(tinytc_core_info_set_default_alignment(*info, mem_base_addr_align / 8)); + + return tinytc::exception_to_status_code_cl([&] { set_spirv_features(*info, device); }); } } diff --git a/src/cl/device_info_helper.hpp b/src/cl/device_info_helper.hpp index c2e265a0..5cbfb337 100644 --- a/src/cl/device_info_helper.hpp +++ b/src/cl/device_info_helper.hpp @@ -4,7 +4,13 @@ #ifndef CL_DEVICE_INFO_HELPER_20240503_HPP #define CL_DEVICE_INFO_HELPER_20240503_HPP +#include "tinytc/tinytc_cl.hpp" + +#include + #include +#include +#include namespace tinytc { @@ -13,9 +19,43 @@ struct opencl_version { int minor; }; -bool has_subgroup_extension(std::size_t str_length, const char *str); -bool has_additional_subgroup_extensions(std::size_t str_length, const char *str); +enum opencl_ext_t { + opencl_ext_cl_khr_fp16 = 0x1, + opencl_ext_cl_khr_fp64 = 0x2, + opencl_ext_cl_khr_subgroups = 0x4, + opencl_ext_cl_intel_subgroups = 0x8, + opencl_ext_cl_intel_required_subgroup_size = 0x10, + opencl_ext_cl_intel_subgroups_long = 0x20, + opencl_ext_cl_intel_subgroups_short = 0x40, + opencl_ext_cl_khr_int64_base_atomics = 0x80, + opencl_ext_cl_khr_int64_extended_atomics = 0x100, + opencl_ext_cl_ext_float_atomics = 0x200, + opencl_ext_cl_intel_spirv_subgroups = 0x400, +}; +//! 
Type for combination of core feature flags +using opencl_exts_t = std::uint32_t; + +auto get_opencl_extensions(std::size_t str_length, const char *str) -> opencl_exts_t; +auto get_opencl_extensions(cl_device_id device) -> opencl_exts_t; auto get_opencl_version(std::size_t str_length, const char *str) -> opencl_version; +auto get_opencl_version(cl_device_id device) -> opencl_version; + +template auto device_info(cl_device_id device, cl_device_info param_name) -> T { + T val = {}; + CL_CHECK_STATUS(clGetDeviceInfo(device, param_name, sizeof(T), &val, nullptr)); + return val; +} + +template <> +inline auto device_info(cl_device_id device, cl_device_info param_name) + -> std::string { + std::string str; + std::size_t str_len; + CL_CHECK_STATUS(clGetDeviceInfo(device, param_name, 0, nullptr, &str_len)); + str.resize(str_len); + CL_CHECK_STATUS(clGetDeviceInfo(device, param_name, str_len, str.data(), nullptr)); + return str; +} } // namespace tinytc diff --git a/src/cl/device_info_helper.re b/src/cl/device_info_helper.re index 8004f97a..a554e13e 100644 --- a/src/cl/device_info_helper.re +++ b/src/cl/device_info_helper.re @@ -3,50 +3,67 @@ #include "device_info_helper.hpp" +#include + namespace tinytc { -bool has_subgroup_extension(std::size_t str_length, const char *str) { +auto get_opencl_extensions(std::size_t str_length, const char *str) -> opencl_exts_t { const char *YYLIMIT = str + str_length; const char *YYCURSOR = str; const char *YYMARKER; + opencl_exts_t result = 0; for (;;) { /*!re2c re2c:yyfill:enable = 0; re2c:define:YYCTYPE = char; re2c:eof = 0; + re2c:tags = 1; whitespace = [ \t\v\r]+; - "cl_intel_subgroups" { return true; } - "cl_khr_subgroups" { return true; } - whitespace { continue; } - * { continue; } - $ { break; } + + "cl_khr_fp16" { result |= opencl_ext_cl_khr_fp16; continue; } + "cl_khr_fp64" { result |= opencl_ext_cl_khr_fp64; continue; } + "cl_khr_subgroups" { result |= opencl_ext_cl_khr_subgroups; continue; } + "cl_intel_subgroups" { result |= 
opencl_ext_cl_intel_subgroups; continue; } + "cl_intel_required_subgroup_size" { + result |= opencl_ext_cl_intel_required_subgroup_size; continue; + } + "cl_intel_subgroups_long" { + result |= opencl_ext_cl_intel_subgroups_long; continue; + } + "cl_intel_subgroups_short" { + result |= opencl_ext_cl_intel_subgroups_short; continue; + } + "cl_intel_spirv_subgroups" { + result |= opencl_ext_cl_intel_spirv_subgroups; continue; + } + "cl_khr_int64_base_atomics" { + result |= opencl_ext_cl_khr_int64_base_atomics; continue; + } + "cl_khr_int64_extended_atomics" { + result |= opencl_ext_cl_khr_int64_extended_atomics; continue; + } + "cl_ext_float_atomics" { + result |= opencl_ext_cl_ext_float_atomics; continue; + } + whitespace { continue; } + * { + // skip remaining characters until we find whitespace + while (!std::isspace(*YYCURSOR) && YYCURSOR < YYLIMIT) { + ++YYCURSOR; + } + continue; + } + $ { break; } */ } - return false; + return result; } -bool has_additional_subgroup_extensions(std::size_t str_length, const char *str) { - const char *YYLIMIT = str + str_length; - const char *YYCURSOR = str; - const char *YYMARKER; - bool has_reqd_subgroup_size = false, has_subgroups_long = false, has_subgroups_short = false; - for (;;) { - /*!re2c - re2c:yyfill:enable = 0; - re2c:define:YYCTYPE = char; - re2c:eof = 0; - - "cl_intel_required_subgroup_size" { has_reqd_subgroup_size = true; continue; } - "cl_intel_subgroups_long" { has_subgroups_long = true; continue; } - "cl_intel_subgroups_short" { has_subgroups_short = true; continue; } - whitespace { continue; } - * { continue; } - $ { break; } - */ - } - return has_reqd_subgroup_size && has_subgroups_long && has_subgroups_short; +auto get_opencl_extensions(cl_device_id device) -> opencl_exts_t { + std::string extensions = device_info(device, CL_DEVICE_EXTENSIONS); + return get_opencl_extensions(extensions.size(), extensions.c_str()); } auto get_opencl_version(std::size_t str_length, const char *str) -> opencl_version { @@ -81,4 
+98,9 @@ ret: return {major, minor}; } +auto get_opencl_version(cl_device_id device) -> opencl_version { + std::string version = device_info(device, CL_DEVICE_VERSION); + return get_opencl_version(version.size(), version.c_str()); +} + } // namespace tinytc diff --git a/src/cl/error.cpp b/src/cl/error.cpp deleted file mode 100644 index 0843be8b..00000000 --- a/src/cl/error.cpp +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "tinytc/tinytc_cl.h" -#include "tinytc/types.h" - -#include -#include - -extern "C" { -tinytc_status_t tinytc_cl_convert_status(cl_int status) { - switch (status) { - case CL_SUCCESS: - return tinytc_status_success; - case CL_BUILD_PROGRAM_FAILURE: - return tinytc_status_cl_build_program_failure; - case CL_COMPILE_PROGRAM_FAILURE: - return tinytc_status_cl_compile_program_failure; - case CL_COMPILER_NOT_AVAILABLE: - return tinytc_status_cl_compiler_not_available; - case CL_DEVICE_NOT_FOUND: - return tinytc_status_cl_device_not_found; - case CL_DEVICE_NOT_AVAILABLE: - return tinytc_status_cl_device_not_available; - case CL_DEVICE_PARTITION_FAILED: - return tinytc_status_cl_device_partition_failed; - case CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST: - return tinytc_status_cl_exec_status_error_for_events_in_wait_list; - case CL_IMAGE_FORMAT_MISMATCH: - return tinytc_status_cl_image_format_mismatch; - case CL_IMAGE_FORMAT_NOT_SUPPORTED: - return tinytc_status_cl_image_format_not_supported; - case CL_INVALID_ARG_INDEX: - return tinytc_status_cl_invalid_arg_index; - case CL_INVALID_ARG_SIZE: - return tinytc_status_cl_invalid_arg_size; - case CL_INVALID_ARG_VALUE: - return tinytc_status_cl_invalid_arg_value; - case CL_INVALID_BINARY: - return tinytc_status_cl_invalid_binary; - case CL_INVALID_BUFFER_SIZE: - return tinytc_status_cl_invalid_buffer_size; - case CL_INVALID_BUILD_OPTIONS: - return tinytc_status_cl_invalid_build_options; - case CL_INVALID_COMMAND_QUEUE: - return 
tinytc_status_cl_invalid_command_queue; - case CL_INVALID_COMPILER_OPTIONS: - return tinytc_status_cl_invalid_compiler_options; - case CL_INVALID_CONTEXT: - return tinytc_status_cl_invalid_context; - case CL_INVALID_DEVICE: - return tinytc_status_cl_invalid_device; - case CL_INVALID_DEVICE_PARTITION_COUNT: - return tinytc_status_cl_invalid_device_partition_count; - case CL_INVALID_DEVICE_QUEUE: - return tinytc_status_cl_invalid_device_queue; - case CL_INVALID_DEVICE_TYPE: - return tinytc_status_cl_invalid_device_type; - case CL_INVALID_EVENT: - return tinytc_status_cl_invalid_event; - case CL_INVALID_EVENT_WAIT_LIST: - return tinytc_status_cl_invalid_event_wait_list; - case CL_INVALID_GLOBAL_OFFSET: - return tinytc_status_cl_invalid_global_offset; - case CL_INVALID_GLOBAL_WORK_SIZE: - return tinytc_status_cl_invalid_global_work_size; - case CL_INVALID_HOST_PTR: - return tinytc_status_cl_invalid_host_ptr; - case CL_INVALID_IMAGE_DESCRIPTOR: - return tinytc_status_cl_invalid_image_descriptor; - case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: - return tinytc_status_cl_invalid_image_format_descriptor; - case CL_INVALID_IMAGE_SIZE: - return tinytc_status_cl_invalid_image_size; - case CL_INVALID_KERNEL: - return tinytc_status_cl_invalid_kernel; - case CL_INVALID_KERNEL_ARGS: - return tinytc_status_cl_invalid_kernel_args; - case CL_INVALID_KERNEL_DEFINITION: - return tinytc_status_cl_invalid_kernel_definition; - case CL_INVALID_KERNEL_NAME: - return tinytc_status_cl_invalid_kernel_name; - case CL_INVALID_LINKER_OPTIONS: - return tinytc_status_cl_invalid_linker_options; - case CL_INVALID_MEM_OBJECT: - return tinytc_status_cl_invalid_mem_object; - case CL_INVALID_OPERATION: - return tinytc_status_cl_invalid_operation; - case CL_INVALID_PIPE_SIZE: - return tinytc_status_cl_invalid_pipe_size; - case CL_INVALID_PLATFORM: - return tinytc_status_cl_invalid_platform; - case CL_INVALID_PROGRAM: - return tinytc_status_cl_invalid_program; - case CL_INVALID_PROGRAM_EXECUTABLE: - return 
tinytc_status_cl_invalid_program_executable; - case CL_INVALID_PROPERTY: - return tinytc_status_cl_invalid_property; - case CL_INVALID_QUEUE_PROPERTIES: - return tinytc_status_cl_invalid_queue_properties; - case CL_INVALID_SAMPLER: - return tinytc_status_cl_invalid_sampler; - case CL_INVALID_SPEC_ID: - return tinytc_status_cl_invalid_spec_id; - case CL_INVALID_VALUE: - return tinytc_status_cl_invalid_value; - case CL_INVALID_WORK_DIMENSION: - return tinytc_status_cl_invalid_work_dimension; - case CL_INVALID_WORK_GROUP_SIZE: - return tinytc_status_cl_invalid_work_group_size; - case CL_INVALID_WORK_ITEM_SIZE: - return tinytc_status_cl_invalid_work_item_size; - case CL_KERNEL_ARG_INFO_NOT_AVAILABLE: - return tinytc_status_cl_kernel_arg_info_not_available; - case CL_LINK_PROGRAM_FAILURE: - return tinytc_status_cl_link_program_failure; - case CL_LINKER_NOT_AVAILABLE: - return tinytc_status_cl_linker_not_available; - case CL_MAP_FAILURE: - return tinytc_status_cl_map_failure; - case CL_MEM_COPY_OVERLAP: - return tinytc_status_cl_mem_copy_overlap; - case CL_MEM_OBJECT_ALLOCATION_FAILURE: - return tinytc_status_cl_mem_object_allocation_failure; - case CL_MISALIGNED_SUB_BUFFER_OFFSET: - return tinytc_status_cl_misaligned_sub_buffer_offset; - case CL_OUT_OF_HOST_MEMORY: - return tinytc_status_cl_out_of_host_memory; - case CL_OUT_OF_RESOURCES: - return tinytc_status_cl_out_of_resources; - case CL_MAX_SIZE_RESTRICTION_EXCEEDED: - return tinytc_status_cl_max_size_restriction_exceeded; - case CL_PROFILING_INFO_NOT_AVAILABLE: - return tinytc_status_cl_profiling_info_not_available; - } - return tinytc_status_unknown; -} -} diff --git a/src/cl/error.hpp b/src/cl/error.hpp index d1c31c9a..562b1d53 100644 --- a/src/cl/error.hpp +++ b/src/cl/error.hpp @@ -4,7 +4,7 @@ #ifndef CL_ERROR_20240423_HPP #define CL_ERROR_20240423_HPP -#include "tinytc/tinytc.hpp" +#include "tinytc/core.hpp" #include diff --git a/src/cl/kernel.cpp b/src/cl/kernel.cpp index 23814944..c098b666 100644 --- 
a/src/cl/kernel.cpp +++ b/src/cl/kernel.cpp @@ -2,76 +2,29 @@ // SPDX-License-Identifier: BSD-3-Clause #include "../compiler_options.hpp" -#include "tinytc/tinytc.h" +#include "tinytc/core.h" #include "tinytc/tinytc_cl.h" #include "tinytc/types.h" #include #include -#include #include #include -#include #include -#include +#include extern "C" { -tinytc_status_t tinytc_cl_kernel_bundle_create_with_source(cl_program *bundle, cl_context context, - cl_device_id device, - const_tinytc_source_t src, - tinytc_source_context_t source_ctx) { - if (bundle == nullptr || src == nullptr) { - return tinytc_status_invalid_arguments; - } - - size_t length = 0; - char const *code = nullptr; - tinytc_core_feature_flags_t core_features = 0; - TINYTC_CL_CHECK_STATUS(tinytc_source_get_code(src, &length, &code)); - TINYTC_CL_CHECK_STATUS(tinytc_source_get_core_features(src, &core_features)); - - cl_int err; - cl_program p = clCreateProgramWithSource(context, 1, &code, &length, &err); - TINYTC_CL_CHECK_STATUS(err); - - auto options = std::ostringstream{}; - for (auto const &opt : tinytc::default_compiler_options) { - options << opt << " "; - } - if (core_features & tinytc_core_feature_flag_large_register_file) { - options << tinytc::large_register_file_compiler_option_cl; - } - auto options_str = std::move(options).str(); - if (err = clBuildProgram(p, 1, &device, options_str.c_str(), nullptr, nullptr); - err != CL_SUCCESS) { - if (source_ctx) { - std::string log; - std::size_t log_size; - clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); - log.resize(log_size); - clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr); - - tinytc_location_t loc = {}; - tinytc_source_get_location(src, &loc); - tinytc_source_context_report_error(source_ctx, &loc, log.c_str(), true); - } - clReleaseProgram(p); - TINYTC_CL_CHECK_STATUS(err); - } - *bundle = p; - return tinytc_status_success; -} - -tinytc_status_t 
tinytc_cl_kernel_bundle_create_with_program( - cl_program *bundle, cl_context context, cl_device_id device, tinytc_prog_t prg, - tinytc_core_feature_flags_t core_features, tinytc_source_context_t source_ctx) { +tinytc_status_t +tinytc_cl_kernel_bundle_create_with_program(cl_program *bundle, cl_context context, + cl_device_id device, tinytc_prog_t prg, + tinytc_core_feature_flags_t core_features) { if (bundle == nullptr || prg == nullptr) { return tinytc_status_invalid_arguments; } tinytc_core_info_t info = nullptr; - tinytc_source_t src = nullptr; + tinytc_binary_t bin = nullptr; tinytc_status_t status = tinytc_status_success; if (status = tinytc_cl_core_info_create(&info, device); status != tinytc_status_success) { @@ -81,17 +34,16 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( status != tinytc_status_success) { goto err; } - if (status = tinytc_prog_compile_to_opencl(&src, prg, info, source_ctx); + if (status = tinytc_prog_compile_to_spirv_and_assemble(&bin, prg, info); status != tinytc_status_success) { goto err; } - if (status = - tinytc_cl_kernel_bundle_create_with_source(bundle, context, device, src, source_ctx); + if (status = tinytc_cl_kernel_bundle_create_with_binary(bundle, context, device, bin); status != tinytc_status_success) { goto err; } err: - tinytc_source_release(src); + tinytc_binary_release(bin); tinytc_core_info_release(info); return status; @@ -99,8 +51,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_program( tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, cl_context context, cl_device_id device, - const_tinytc_binary_t bin, - tinytc_source_context_t source_ctx) { + const_tinytc_binary_t bin) { if (bin == nullptr) { return tinytc_status_invalid_arguments; } @@ -121,12 +72,15 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, c tinytc_core_feature_flags_t core_features; TINYTC_CHECK_STATUS(tinytc_binary_get_core_features(bin, &core_features)); + 
tinytc_compiler_context_t ctx = nullptr; + TINYTC_CHECK_STATUS(tinytc_binary_get_compiler_context(bin, &ctx)); + char const *options = ""; if (core_features & tinytc_core_feature_flag_large_register_file) { options = tinytc::large_register_file_compiler_option_cl; } if (err = clBuildProgram(p, 1, &device, options, nullptr, nullptr); err != CL_SUCCESS) { - if (source_ctx) { + if (ctx) { std::string log; std::size_t log_size; clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); @@ -134,7 +88,7 @@ tinytc_status_t tinytc_cl_kernel_bundle_create_with_binary(cl_program *bundle, c clGetProgramBuildInfo(p, device, CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr); tinytc_location_t loc = {}; - tinytc_source_context_report_error(source_ctx, &loc, log.c_str(), true); + tinytc_compiler_context_report_error(ctx, &loc, log.c_str()); } clReleaseProgram(p); TINYTC_CL_CHECK_STATUS(err); @@ -147,19 +101,33 @@ tinytc_status_t tinytc_cl_get_group_size(cl_kernel kernel, size_t *local_size) { if (local_size == nullptr) { return tinytc_status_invalid_arguments; } + constexpr int short_dev_list = 4; cl_program p; cl_device_id d; + cl_uint num_devices; TINYTC_CL_CHECK_STATUS(clGetKernelInfo(kernel, CL_KERNEL_PROGRAM, sizeof(p), &p, nullptr)); - TINYTC_CL_CHECK_STATUS(clGetProgramInfo(p, CL_PROGRAM_DEVICES, sizeof(d), &d, nullptr)); - return tinytc_cl_convert_status( - clGetKernelWorkGroupInfo(kernel, d, CL_KERNEL_COMPILE_WORK_GROUP_SIZE, - 3 * sizeof(std::size_t), local_size, nullptr)); + TINYTC_CL_CHECK_STATUS( + clGetProgramInfo(p, CL_PROGRAM_NUM_DEVICES, sizeof(num_devices), &num_devices, nullptr)); + if (num_devices <= short_dev_list) { + cl_device_id dbuf[4]; + TINYTC_CL_CHECK_STATUS( + clGetProgramInfo(p, CL_PROGRAM_DEVICES, sizeof(dbuf), &dbuf, nullptr)); + d = dbuf[0]; + } else { + auto dbuf = std::vector(num_devices); + TINYTC_CL_CHECK_STATUS(clGetProgramInfo( + p, CL_PROGRAM_DEVICES, num_devices * sizeof(cl_device_id), dbuf.data(), nullptr)); + d = 
dbuf[0]; + } + TINYTC_CL_CHECK_STATUS(clGetKernelWorkGroupInfo(kernel, d, CL_KERNEL_COMPILE_WORK_GROUP_SIZE, + 3 * sizeof(std::size_t), local_size, nullptr)); + return tinytc_status_success; } -void tinytc_cl_get_global_size(int64_t howmany, const size_t *local_size, size_t *global_size) { +void tinytc_cl_get_global_size(const size_t *num_groups, const size_t *local_size, + size_t *global_size) { for (size_t i = 0; i < 3; ++i) { - global_size[i] = local_size[i]; + global_size[i] = num_groups[i] * local_size[i]; } - global_size[2] *= howmany; } } diff --git a/src/cl/recipe_handler.cpp b/src/cl/recipe_handler.cpp index e9f562f6..c1966b54 100644 --- a/src/cl/recipe_handler.cpp +++ b/src/cl/recipe_handler.cpp @@ -3,9 +3,8 @@ #include "recipe_handler.hpp" #include "../recipe.hpp" -#include "../reference_counted.hpp" #include "error.hpp" -#include "tinytc/tinytc.hpp" +#include "tinytc/builder.hpp" #include "tinytc/tinytc_cl.h" #include "tinytc/tinytc_cl.hpp" #include "tinytc/types.h" @@ -17,17 +16,17 @@ namespace tinytc { -cl_recipe_handler::cl_recipe_handler(cl_context context, cl_device_id device, recipe rec, - source_context source_ctx) +cl_recipe_handler::cl_recipe_handler(cl_context context, cl_device_id device, + shared_handle rec) : ::tinytc_recipe_handler(std::move(rec)) { - module_ = make_kernel_bundle(context, device, get_recipe().get_source(), std::move(source_ctx)); + module_ = create_kernel_bundle(context, device, get_binary(get_recipe()).get()); auto const num_kernels = get_recipe()->num_kernels(); kernels_.reserve(num_kernels); local_size_.reserve(num_kernels); for (int num = 0; num < num_kernels; ++num) { - kernels_.emplace_back(make_kernel(module_.get(), get_recipe()->kernel_name(num))); + kernels_.emplace_back(create_kernel(module_.get(), get_recipe()->kernel_name(num))); local_size_.emplace_back(get_group_size(kernels_.back().get())); } @@ -46,13 +45,13 @@ void cl_recipe_handler::active_kernel(int kernel_num) { void cl_recipe_handler::arg(std::uint32_t 
arg_index, std::size_t arg_size, const void *arg_value) { arg_handler_.set_arg(kernel(), arg_index, arg_size, arg_value); } -void cl_recipe_handler::mem_arg(std::uint32_t arg_index, const void *value, - tinytc_mem_type_t type) { - arg_handler_.set_mem_arg(kernel(), arg_index, value, type); +void cl_recipe_handler::mem_arg(std::uint32_t arg_index, const void *value, tinytc_mem_type_t ty) { + arg_handler_.set_mem_arg(kernel(), arg_index, value, ty); } void cl_recipe_handler::howmany(std::int64_t num) { - global_size_ = get_global_size(num, local_size()); + global_size_ = get_global_size( + std::array{static_cast(num), 1u, 1u}, local_size()); } auto cl_recipe_handler::kernel() -> cl_kernel { return kernels_[active_kernel_].get(); } @@ -69,15 +68,14 @@ extern "C" { tinytc_status_t tinytc_cl_recipe_handler_create(tinytc_recipe_handler_t *handler, cl_context context, cl_device_id device, - tinytc_recipe_t rec, - tinytc_source_context_t source_ctx) { + tinytc_recipe_t rec) { if (handler == nullptr || rec == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code_cl([&] { - *handler = std::make_unique(context, device, recipe(rec, true), - source_context(source_ctx, true)) - .release(); + *handler = + std::make_unique(context, device, shared_handle{rec, true}) + .release(); }); } diff --git a/src/cl/recipe_handler.hpp b/src/cl/recipe_handler.hpp index 6c4ebea7..fc0679d4 100644 --- a/src/cl/recipe_handler.hpp +++ b/src/cl/recipe_handler.hpp @@ -6,8 +6,8 @@ #include "../recipe.hpp" #include "argument_handler.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include "tinytc/types.hpp" #include #include @@ -19,12 +19,11 @@ namespace tinytc { struct cl_recipe_handler : ::tinytc_recipe_handler { public: - cl_recipe_handler(cl_context context, cl_device_id device, recipe rec, - source_context source_ctx); + cl_recipe_handler(cl_context context, cl_device_id device, shared_handle rec); void active_kernel(int kernel_num) override; void 
arg(std::uint32_t arg_index, std::size_t arg_size, const void *arg_value) override; - void mem_arg(std::uint32_t arg_index, const void *value, tinytc_mem_type_t type) override; + void mem_arg(std::uint32_t arg_index, const void *value, tinytc_mem_type_t ty) override; void howmany(std::int64_t num) override; auto kernel() -> cl_kernel; diff --git a/src/codegen_tools.cpp b/src/codegen_tools.cpp index 1ff7acf7..27b8d554 100644 --- a/src/codegen_tools.cpp +++ b/src/codegen_tools.cpp @@ -2,231 +2,402 @@ // SPDX-License-Identifier: BSD-3-Clause #include "codegen_tools.hpp" -#include "scalar_type.hpp" -#include "util.hpp" +#include "compiler_context.hpp" +#include "error.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "number.hpp" +#include "pass/constant_folding.hpp" +#include "tinytc/types.h" +#include "util/casting.hpp" +#include "util/ilist_base.hpp" +#include "util/overloaded.hpp" -#include -#include -#include -#include -#include -#include - -#include +#include +#include +#include #include - -using namespace clir; +#include namespace tinytc { -expr vload_helper(short vec_size, expr offset, expr ptr) { - switch (vec_size) { - case 1: - return ptr[std::move(offset)]; - case 2: - return vload2(std::move(offset), std::move(ptr)); - case 3: - return vload3(std::move(offset), std::move(ptr)); - case 4: - return vload4(std::move(offset), std::move(ptr)); - case 8: - return vload8(std::move(offset), std::move(ptr)); - case 16: - return vload16(std::move(offset), std::move(ptr)); - default: - break; +auto get_core_config_and_tiling(tinytc_func const &fn, const_tinytc_core_info_t info) + -> std::pair { + const auto get_core_config = [&]() -> core_config { + try { + return info->get_core_config(fn.subgroup_size()); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } }; - return nullptr; + 
core_config core_cfg = get_core_config(); + const auto wgs = fn.work_group_size(); + local_tiling tiling = {wgs[0] / core_cfg.subgroup_size, wgs[1]}; + return {core_cfg, tiling}; } -void store_helper(block_builder &bb, bool is_atomic, expr dst, scalar_type ty, address_space as, - expr value, expr beta) { - if (is_atomic) { - atomic_store_helper(bb, std::move(dst), ty, as, std::move(value), std::move(beta)); - } else { - bb.assign(dereference(dst), std::move(value) + std::move(beta) * dereference(dst)); - } -} - -void atomic_store_helper(block_builder &bb, expr dst, scalar_type ty, address_space as, expr value, - expr beta) { - int mode = -1; - visit(overloaded{ - [&](clir::internal::int_imm &c) { - mode = c.value() == 0 ? 0 : (c.value() == 1 ? 1 : -1); - }, - [&](clir::internal::uint_imm &c) { - mode = c.value() == 0u ? 0 : (c.value() == 1u ? 1 : -1); - }, - [&](clir::internal::float_imm &c) { - mode = c.value() == 0.0 ? 0 : (c.value() == 1.0 ? 1 : -1); - }, - [&](auto &) {}, - }, - *beta); - auto pointer_ty = pointer_to(to_clir_atomic_ty(ty, as, type_qualifier::volatile_t)); - auto atomic_dst = cast(std::move(pointer_ty), dst); - if (mode == 0) { - bb.add(call_builtin(builtin_function::atomic_store_explicit, - {std::move(atomic_dst), std::move(value), memory_order::relaxed, - memory_scope::work_group})); - } else if (mode == 1) { - bb.add(call_builtin(builtin_function::atomic_fetch_add_explicit, - {std::move(atomic_dst), std::move(value), memory_order::relaxed, - memory_scope::work_group})); - } else { - auto expected = bb.declare_assign(to_clir_ty(ty), "expected", dereference(dst)); - auto desired = bb.declare(to_clir_ty(ty), "desired"); - auto cmpxchg = - call_builtin(builtin_function::atomic_compare_exchange_strong_explicit, - {std::move(atomic_dst), address_of(std::move(expected)), desired, - memory_order::relaxed, memory_order::relaxed, memory_scope::work_group}); - bb.add(while_loop_builder(std::move(cmpxchg), true) - .body([&](block_builder &bb) { 
bb.assign(desired, value + beta * expected); }) - .get_product()); - } -} - -void dispatch_constant_dynamic(expr e, std::function const &const_case, - std::function const &dyn_case) { - visit( - overloaded{ - [&](clir::internal::int_imm &c) { const_case(c.value()); }, - [&](clir::internal::uint_imm &c) { const_case(static_cast(c.value())); }, - [&](auto &) { dyn_case(std::move(e)); }, - }, - *e); -} - -void tile_loop_by_sgs(block_builder &bb, expr loop_trip_count, unsigned sgs, unsigned num_tiles, - var sg_id, sgs_loop_body_builder const &body) { - dispatch_constant_dynamic( - std::move(loop_trip_count), - [&](std::int64_t c) { - tile_loop_by_sgs_constant(bb, c, sgs, num_tiles, std::move(sg_id), body); - }, - [&](expr d) { - tile_loop_by_sgs_dynamic(bb, std::move(d), sgs, num_tiles, std::move(sg_id), body); +void tile_loop_by_sgs(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, int num_tiles, + tinytc_value_t sg_id, sgs_loop_body_builder const &body, + tinytc_attr_t for_attributes) { + auto ity = loop_trip_count->ty(); + auto bool_ty = boolean_type::get(ity->context()); + auto c_sgs = bb.create(sgs, ity); + auto c_sgs_tiles = bb.create(sgs * num_tiles, ity); + auto c0 = bb.create(0, ity); + auto c_tiles_1 = bb.create(num_tiles - 1, ity); + + auto blocks = instant_constant_fold_add(bb, create(loop_trip_count, c_sgs, ity)); + auto rem = instant_constant_fold_add(bb, create(loop_trip_count, c_sgs, ity)); + + auto sg_id_cast = instant_constant_fold_add(bb, create(sg_id, ity)); + auto is_blocks_gt_0 = + instant_constant_fold_add(bb, create(blocks, c0, bool_ty)); + bb.if_condition(is_blocks_gt_0, [&](region_builder &bb) { + auto block_start = instant_constant_fold_add(bb, create(c_sgs, sg_id_cast, ity)); + auto block_end = instant_constant_fold_add(bb, create(c_sgs, blocks, ity)); + bb.for_loop( + std::move(block_start), std::move(block_end), c_sgs_tiles, + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, false, c_sgs); }, + for_attributes); + 
}); + + auto condition0 = instant_constant_fold_add(bb, create(rem, c0, bool_ty)); + bb.if_condition(condition0, [&](region_builder &bb) { + auto condition1 = + instant_constant_fold_add(bb, create(sg_id_cast, c_tiles_1, bool_ty)); + bb.if_condition(condition1, [&](region_builder &bb) { + auto block = instant_constant_fold_add(bb, create(blocks, c_sgs, ity)); + body(bb, block, true, rem); }); + }); } -void tile_loop_by_sgs_constant(block_builder &bb, unsigned loop_trip_count, unsigned sgs, - unsigned num_tiles, var sg_id, sgs_loop_body_builder const &body) { - auto blocks = loop_trip_count / sgs; - auto rem = loop_trip_count % sgs; - - auto block = bb.declare(generic_uint(), "blck"); - if (blocks > 0) { - bb.add(for_loop_builder(assignment(block, sgs * sg_id), block < sgs * blocks, - add_into(block, sgs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, false, sgs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - } - if (rem > 0) { - bb.assign(block, blocks * sgs); - bb.add(if_selection_builder(sg_id == num_tiles - 1u) - .then([&](block_builder &bb) { body(bb, block, true, rem); }) - .get_product()); - } -} - -void tile_loop_by_sgs_dynamic(block_builder &bb, expr loop_trip_count, unsigned sgs, - unsigned num_tiles, var sg_id, sgs_loop_body_builder const &body) { - auto blocks = bb.declare_assign(generic_uint(), "blocks", loop_trip_count / sgs); - auto rem = bb.declare_assign(generic_uint(), "rem", std::move(loop_trip_count) % sgs); - - auto block = bb.declare(generic_uint(), "blck"); - bb.add(for_loop_builder(assignment(block, sgs * sg_id), block < sgs * blocks, - add_into(block, sgs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, false, sgs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - - bb.add(if_selection_builder(rem > 0) - .then([&](block_builder &bb) { - bb.assign(block, blocks * sgs); - bb.add(if_selection_builder(sg_id == num_tiles - 1u) - .then([&](block_builder &bb) { body(bb, block, true, rem); }) - 
.get_product()); - }) - .get_product()); -} - -unsigned tile_loop_uniformly_max_block_size(unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles) { - auto blocks = 1 + (loop_trip_count - 1) / block_size; - blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; - auto bs = loop_trip_count / blocks; - auto rem = loop_trip_count % blocks; - return rem > 0 ? bs + 1 : bs; -} - -void tile_loop_uniformly(block_builder &bb, expr loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, uniform_loop_body_builder const &body) { - dispatch_constant_dynamic( - std::move(loop_trip_count), - [&](std::int64_t c) { - tile_loop_uniformly_constant(bb, c, block_size, num_tiles, std::move(sg_id), body); - }, - [&](expr d) { - tile_loop_uniformly_dynamic(bb, std::move(d), block_size, num_tiles, std::move(sg_id), - body); - }); +void tile_loop_uniformly(region_builder &bb, tinytc_value_t loop_trip_count, int block_size, + int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder const &body, + tinytc_attr_t for_attributes) { + auto ity = loop_trip_count->ty(); + auto bool_ty = boolean_type::get(ity->context()); + auto c0 = bb.create(0, ity); + auto c1 = bb.create(1, ity); + auto c_tiles = bb.create(num_tiles, ity); + + // Here we compute + // blocks = ceil(loop_trip_count / block_size) = 1 + (loop_trip_count - 1) / block_size + // blocks = ceil(blocks / num_tiles) * num_tiles = (1 + (blocks - 1) / num_tiles) * + // num_tiles + auto c_block_size = bb.create(block_size, ity); + auto blocks0 = instant_constant_fold_add(bb, create(loop_trip_count, c1, ity)); + auto blocks1 = instant_constant_fold_add(bb, create(blocks0, c_block_size, ity)); + auto blocks2 = instant_constant_fold_add(bb, create(blocks1, c_tiles, ity)); + auto blocks3 = instant_constant_fold_add(bb, create(c1, blocks2, ity)); + auto blocks = instant_constant_fold_add(bb, create(blocks3, c_tiles, ity)); + + auto bs = instant_constant_fold_add(bb, create(loop_trip_count, blocks, ity)); + auto bs_1 
= instant_constant_fold_add(bb, create(bs, c1, ity)); + auto rem = instant_constant_fold_add(bb, create(loop_trip_count, blocks, ity)); + + auto sg_id_cast = instant_constant_fold_add(bb, create(sg_id, ity)); + // The following if makes it easy to eliminate the remainder handler in optimization if rem + // == 0 is known at compile time. Without the if, we would need to prove that block_start_1 + // is non-negative to eliminate the for-loop. + auto is_rem_gt_0 = instant_constant_fold_add(bb, create(rem, c0, bool_ty)); + bb.if_condition(is_rem_gt_0, [&](region_builder &bb) { + auto block_start_1 = instant_constant_fold_add(bb, create(bs_1, sg_id_cast, ity)); + auto block_end_1 = instant_constant_fold_add(bb, create(bs_1, rem, ity)); + auto step_1 = instant_constant_fold_add(bb, create(bs_1, c_tiles, ity)); + bb.for_loop( + std::move(block_start_1), std::move(block_end_1), std::move(step_1), + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs_1); }, + for_attributes); + }); + + auto tmp0 = instant_constant_fold_add(bb, create(rem, c_tiles, ity)); + auto tmp1 = instant_constant_fold_add(bb, create(sg_id_cast, tmp0, ity)); + auto sg_id_1 = instant_constant_fold_add(bb, create(tmp1, c_tiles, ity)); + auto tmp2 = instant_constant_fold_add(bb, create(bs, sg_id_1, ity)); + auto tmp3 = instant_constant_fold_add(bb, create(bs_1, rem, ity)); + auto block_start = instant_constant_fold_add(bb, create(tmp3, tmp2, ity)); + auto step = instant_constant_fold_add(bb, create(bs, c_tiles, ity)); + bb.for_loop( + std::move(block_start), loop_trip_count, std::move(step), + [&](region_builder &bb, tinytc_value_t block) { body(bb, block, bs); }, for_attributes); } -void tile_loop_uniformly_constant(block_builder &bb, unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, - uniform_loop_body_builder const &body) { - // Find minimum number of blocks such that the block sizes are smaller or equal block_size - auto blocks = 1 + (loop_trip_count - 1) / 
block_size; - // Increase the number of blocks if such that the number of blocks is a multiple - // of the number of tiles - blocks = (1 + (blocks - 1) / num_tiles) * num_tiles; - auto bs = loop_trip_count / blocks; - auto bs_1 = bs + 1; - auto rem = loop_trip_count % blocks; - - auto block = bb.declare(generic_uint(), "blck"); - if (rem > 0) { - bb.add(for_loop_builder(assignment(block, bs_1 * sg_id), block < bs_1 * rem, - add_into(block, bs_1 * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs_1); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - } - - auto sg_id_1 = (std::move(sg_id) + rem % num_tiles) % num_tiles; - bb.add(for_loop_builder(assignment(block, bs_1 * rem + bs * std::move(sg_id_1)), - block < loop_trip_count, add_into(block, bs * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); -} - -void tile_loop_uniformly_dynamic(block_builder &bb, expr loop_trip_count, unsigned block_size, - unsigned num_tiles, var sg_id, - uniform_loop_body_builder const &body) { - auto blocks = - bb.declare_assign(generic_uint(), "blocks", 1 + (loop_trip_count - 1) / block_size); - bb.assign(blocks, (1 + (blocks - 1) / num_tiles) * num_tiles); - auto bs = bb.declare_assign(generic_uint(), "bs", loop_trip_count / blocks); - auto bs_1 = bb.declare_assign(generic_uint(), "bs_1", bs + 1); - auto rem = bb.declare_assign(generic_uint(), "rem", loop_trip_count % blocks); - - auto block = bb.declare(generic_uint(), "blck"); - bb.add(for_loop_builder(assignment(block, bs_1 * sg_id), block < bs_1 * rem, - add_into(block, bs_1 * num_tiles)) - .body([&](block_builder &bb) { body(bb, block, bs_1); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - - auto sg_id_1 = (std::move(sg_id) + rem % num_tiles) % num_tiles; - bb.add(for_loop_builder(assignment(block, bs_1 * rem + bs * std::move(sg_id_1)), - block < std::move(loop_trip_count), add_into(block, bs * num_tiles)) - 
.body([&](block_builder &bb) { body(bb, block, bs); }) - .attribute(opencl_unroll_hint(1)) - .get_product()); +auto promote_binop_operands(region_builder &bb, tinytc_type_t result_ty, tinytc_value_t a, + tinytc_value_t b, location const &loc) + -> std::pair { + if (a->ty() != result_ty || b->ty() != result_ty) { + if (!promotable(a->ty(), result_ty) || !promotable(b->ty(), result_ty)) { + throw compilation_error(loc, status::ir_forbidden_promotion); + } + if (a->ty() != result_ty) { + a = bb.create(a, result_ty, loc); + } + if (b->ty() != result_ty) { + b = bb.create(b, result_ty, loc); + } + } + return {a, b}; +} + +auto mixed_precision_coopmatrix_scale(region_builder &bb, tinytc_value_t a, tinytc_value_t b, + location const &loc) -> tinytc_value_t { + coopmatrix_type *bt = dyn_cast(b->ty()); + if (bt == nullptr) { + throw compilation_error(loc, status::ir_expected_coopmatrix); + } + const auto a_ty = a->ty(); + const auto b_ty = bt->component_ty(); + if (a_ty != b_ty) { + if (!promotable(a_ty, b_ty)) { + throw compilation_error(loc, status::ir_forbidden_promotion); + } + a = bb.create(a, b_ty, loc); + } + return bb.create(a, b, bt, loc); +} + +auto get_atomic_store_flag(tinytc_value_t beta) -> std::optional { + constant_inst beta_cst = dyn_cast(beta->defining_inst()); + if (beta_cst) { + if (beta_cst.is_zero()) { + return store_flag::atomic; + } else if (beta_cst.is_identity()) { + return store_flag::atomic_add; + } + } + return std::nullopt; +} +void blas_update(region_builder &bb, bool atomic, tinytc_value_t alpha, tinytc_value_t ab, + tinytc_value_t beta, tinytc_value_t C, array_view index_list, + location const &loc) { + memref_type *ct = dyn_cast(C->ty()); + if (ct == nullptr) { + throw compilation_error(loc, {C}, status::ir_expected_number); + } + auto alpha_ab = mixed_precision_arithmetic(bb, ct->element_ty(), alpha, ab, loc); + if (atomic) { + auto flag = get_atomic_store_flag(beta); + if (!flag) { + throw compilation_error(loc, status::ir_invalid_beta); 
+ } + bb.create(*flag, alpha_ab, C, index_list, loc); + } else { + auto c = bb.create(C, index_list, ct->element_ty(), loc); + auto beta_c = mixed_precision_arithmetic(bb, ct->element_ty(), beta, c, loc); + auto alpha_ab_plus_beta_c = + mixed_precision_arithmetic(bb, ct->element_ty(), alpha_ab, beta_c, loc); + bb.create(store_flag::regular, alpha_ab_plus_beta_c, C, index_list, loc); + } +} + +auto instant_constant_fold_add(region_builder &bb, unique_handle &&i) + -> tinytc_value_t { + auto ctx = i->context(); + if (!ctx) { + throw compilation_error(i->loc(), status::internal_compiler_error); + } + + auto fold = visit(constant_folding{ctx->opt_flag(optflag::unsafe_fp_math)}, *i); + auto val = std::visit(overloaded{[](tinytc_value_t &v) -> tinytc_value_t { return v; }, + [&bb](unique_handle &j) -> tinytc_value_t { + if (j) { + return bb.add(std::move(j)); + } + return nullptr; + }}, + fold); + if (val) { + return val; + } + return bb.add(std::move(i)); +} + +auto get_bool_constant(tinytc_value_t val) -> std::optional { + if (auto i = val->defining_inst(); i) { + if (auto ci = dyn_cast(i); ci) { + if (std::holds_alternative(ci.value())) { + return std::get(ci.value()); + } + } + } + return std::nullopt; +} + +auto get_int_constant(const_tinytc_value_t val) -> std::optional { + if (auto i = val->defining_inst(); i) { + if (auto ci = dyn_cast(i); ci) { + if (std::holds_alternative(ci.value())) { + return std::get(ci.value()); + } + } + } + return std::nullopt; +} + +auto get_int_constant(tinytc_value const &val) -> std::optional { + return get_int_constant(&val); +} + +auto get_coopmatrix_type(tinytc_value const &v) -> coopmatrix_type * { + auto ct = dyn_cast(v.ty()); + if (!ct) { + throw compilation_error(v.loc(), status::ir_expected_coopmatrix); + } + return ct; +} +auto get_memref_type(tinytc_value const &v) -> memref_type * { + auto mt = dyn_cast(v.ty()); + if (!mt) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return mt; +} +auto 
get_yield(location const &loc, tinytc_region ®) -> yield_inst { + auto y = yield_inst(nullptr); + if (auto it = reg.end(); --it != reg.end()) { + y = dyn_cast(it.get()); + } + if (!y) { + throw compilation_error(loc, status::ir_must_have_yield); + } + return y; +} + +auto add_check(checked_flag flag, checked_flag new_flag) -> checked_flag { + return checked_flag{std::underlying_type_t(flag) | + std::underlying_type_t(new_flag)}; +} + +work_group_op::work_group_op(std::int32_t num_tiles, std::int32_t subgroup_size, tinytc_type_t ty) + : num_tiles_{num_tiles}, subgroup_size_{subgroup_size}, ty_{ty}, tmp_{nullptr} {} + +void work_group_op::setup(region_builder &bb, location const &loc) { + if (num_tiles_ > 1) { + auto tmp_ty = get(ty_, array_view{num_tiles_}, + array_view{}, address_space::local); + tmp_ = bb.create(tmp_ty, loc); + } +} + +void work_group_op::teardown(region_builder &bb) { + if (tmp_) { + bb.create(tmp_, location{}); + } +} + +auto work_group_reduce::make(region_builder &bb, tinytc_value_t a, location const &loc) + -> tinytc_value_t { + auto a_reduced = bb.create(a, ty_, loc); + + if (num_tiles_ > 1) { + auto ctx = a->context(); + auto bool_ty = get(ctx); + auto i32_ty = get(ctx); + auto index_ty = get(ctx); + + auto sgid = bb.create(i32_ty, loc); + auto sglid = bb.create(i32_ty, loc); + auto c_zero = bb.constant_zero(i32_ty, loc); + auto is_sglid_0 = bb.create(sglid, c_zero, bool_ty, loc); + bb.if_condition( + is_sglid_0, + [&](region_builder &bb) { + auto sgid_index = bb.create(sgid, index_ty, loc); + bb.create(store_flag::regular, a_reduced, tmp_, array_view{sgid_index}, + loc); + }, + loc); + bb.create(static_cast(address_space::local), loc); + + auto is_lid_0 = bb.create(sgid, c_zero, bool_ty, loc); + bb.if_condition( + is_lid_0, + [&](region_builder &bb) { + auto c_num_tiles = bb.create(num_tiles_, i32_ty, loc); + auto c_sgs = bb.create(subgroup_size_, i32_ty, loc); + auto c_init = bb.constant_zero(ty_, loc); + auto acc = + bb.for_loop(sglid, 
c_num_tiles, c_sgs, {c_init}, {ty_}, + [&](region_builder &bb, array_view args) { + auto lv_index = bb.create(args[0], index_ty, loc); + auto a_sg_reduced = + bb.create(tmp_, array_view{lv_index}, ty_, loc); + auto sum = bb.create(args[1], a_sg_reduced, ty_, loc); + bb.create(array_view{sum}, loc); + }); + a_reduced = bb.create(acc[0], ty_, loc); + return a_reduced; + }, + loc); + } + return a_reduced; +} + +auto work_group_inclusive_scan::make(region_builder &bb, tinytc_value_t a, bool compute_sum, + location const &loc) + -> std::pair { + auto a_scan = bb.create(a, ty_, loc); + + auto ctx = a->context(); + auto i32_ty = get(ctx); + + if (num_tiles_ > 1) { + auto bool_ty = get(ctx); + auto index_ty = get(ctx); + + auto sgid = bb.create(i32_ty, loc); + auto sglid = bb.create(i32_ty, loc); + + auto c_sgs_1 = bb.create(subgroup_size_ - 1, i32_ty, loc); + auto is_last_sglid = bb.create(sglid, c_sgs_1, bool_ty, loc); + bb.if_condition( + is_last_sglid, + [&](region_builder &bb) { + auto sgid_index = bb.create(sgid, index_ty, loc); + bb.create(store_flag::regular, a_scan, tmp_, array_view{sgid_index}, + loc); + }, + loc); + bb.create(static_cast(address_space::local), loc); + + auto c_zero = bb.constant_zero(i32_ty, loc); + a_scan = bb.for_loop(c_zero, sgid, nullptr, {a_scan}, {ty_}, + [&](region_builder &bb, array_view args) { + auto lv_index = bb.create(args[0], index_ty, loc); + auto prefix = + bb.create(tmp_, array_view{lv_index}, ty_, loc); + auto scan = bb.create(args[1], prefix, ty_, loc); + bb.create(array_view{scan}, loc); + })[0]; + + if (compute_sum) { + auto c_num_tiles_1 = bb.create(num_tiles_ - 1, i32_ty, loc); + auto c_num_tiles_1_index = bb.create(c_num_tiles_1, index_ty, loc); + auto is_last_sgid = bb.create(sgid, c_num_tiles_1, bool_ty, loc); + auto is_last_work_item = bb.create(is_last_sglid, is_last_sgid, bool_ty, loc); + bb.if_condition( + is_last_work_item, + [&](region_builder &bb) { + bb.create(store_flag::regular, a_scan, tmp_, + 
array_view{c_num_tiles_1_index}, loc); + }, + loc); + bb.create(static_cast(address_space::local), + loc); + auto sum = bb.create(tmp_, array_view{c_num_tiles_1_index}, ty_, loc); + return {a_scan, sum}; + } + } else if (compute_sum) { + auto c_sgs_1 = bb.create(subgroup_size_ - 1, i32_ty, loc); + auto sum = bb.create(a_scan, c_sgs_1, ty_, loc); + return {a_scan, sum}; + } + return {a_scan, nullptr}; } } // namespace tinytc diff --git a/src/codegen_tools.hpp b/src/codegen_tools.hpp index 4825b0a5..1da223fa 100644 --- a/src/codegen_tools.hpp +++ b/src/codegen_tools.hpp @@ -4,53 +4,107 @@ #ifndef CODEGEN_TOOLS_20240229_HPP #define CODEGEN_TOOLS_20240229_HPP +#include "device_info.hpp" +#include "node/inst_view.hpp" +#include "tiling.hpp" +#include "tinytc/builder.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" #include #include - -#include -#include -#include -#include +#include +#include +#include namespace tinytc { -clir::expr vload_helper(short vec_size, clir::expr offset, clir::expr ptr); - -void store_helper(clir::block_builder &bb, bool is_atomic, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, clir::expr beta); -void atomic_store_helper(clir::block_builder &bb, clir::expr dst, scalar_type ty, - clir::address_space as, clir::expr value, clir::expr beta); - -void dispatch_constant_dynamic(clir::expr e, std::function const &const_case, - std::function const &dyn_case); +auto get_core_config_and_tiling(tinytc_func const &fn, const_tinytc_core_info_t info) + -> std::pair; using sgs_loop_body_builder = - std::function; + std::function; using uniform_loop_body_builder = - std::function; - -void tile_loop_by_sgs(clir::block_builder &bb, clir::expr loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, sgs_loop_body_builder const &body); -void tile_loop_by_sgs_constant(clir::block_builder &bb, unsigned loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, - 
sgs_loop_body_builder const &body); -void tile_loop_by_sgs_dynamic(clir::block_builder &bb, clir::expr loop_trip_count, unsigned sgs, - unsigned num_tiles, clir::var sg_id, - sgs_loop_body_builder const &body); - -unsigned tile_loop_uniformly_max_block_size(unsigned loop_trip_count, unsigned block_size, - unsigned num_tiles); -void tile_loop_uniformly(clir::block_builder &bb, clir::expr loop_trip_count, unsigned block_size, - unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); -void tile_loop_uniformly_constant(clir::block_builder &bb, unsigned loop_trip_count, - unsigned block_size, unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); -void tile_loop_uniformly_dynamic(clir::block_builder &bb, clir::expr loop_trip_count, - unsigned block_size, unsigned num_tiles, clir::var sg_id, - uniform_loop_body_builder const &body); + std::function; + +void tile_loop_by_sgs(region_builder &bb, tinytc_value_t loop_trip_count, int sgs, int num_tiles, + tinytc_value_t sg_id, sgs_loop_body_builder const &body, + tinytc_attr_t for_attributes = nullptr); + +void tile_loop_uniformly(region_builder &bb, tinytc_value_t loop_trip_count, int block_size, + int num_tiles, tinytc_value_t sg_id, uniform_loop_body_builder const &body, + tinytc_attr_t for_attributes = nullptr); + +auto promote_binop_operands(region_builder &bb, tinytc_type_t result_ty, tinytc_value_t a, + tinytc_value_t b, location const &loc) + -> std::pair; +template +auto mixed_precision_arithmetic(region_builder &bb, tinytc_type_t result_ty, tinytc_value_t a, + tinytc_value_t b, location const &loc) -> tinytc_value_t { + auto [a_p, b_p] = promote_binop_operands(bb, result_ty, a, b, loc); + return bb.create(a_p, b_p, result_ty, loc); +} +auto mixed_precision_coopmatrix_scale(region_builder &bb, tinytc_value_t a, tinytc_value_t b, + location const &loc) -> tinytc_value_t; + +auto get_atomic_store_flag(tinytc_value_t beta) -> std::optional; +void blas_update(region_builder 
&bb, bool atomic, tinytc_value_t alpha, tinytc_value_t ab, + tinytc_value_t beta, tinytc_value_t C, array_view index_list, + location const &loc); + +auto instant_constant_fold_add(region_builder &bb, unique_handle &&i) + -> tinytc_value_t; +auto get_bool_constant(tinytc_value_t val) -> std::optional; +auto get_int_constant(const_tinytc_value_t val) -> std::optional; +auto get_int_constant(tinytc_value const &val) -> std::optional; +auto get_coopmatrix_type(tinytc_value const &v) -> coopmatrix_type *; +auto get_memref_type(tinytc_value const &v) -> memref_type *; +auto get_yield(location const &loc, tinytc_region ®) -> yield_inst; + +template auto get_int_constants(T &&val_range) -> std::vector { + auto result = std::vector{}; + result.reserve(val_range.size()); + for (auto &val : val_range) { + const auto cst = get_int_constant(val); + result.emplace_back(cst ? *cst : dynamic); + } + return result; +} + +auto add_check(checked_flag flag, checked_flag new_flag) -> checked_flag; + +class work_group_op { + public: + work_group_op(std::int32_t num_tiles, std::int32_t subgroup_size, tinytc_type_t ty); + + void setup(region_builder &bb, location const &loc); + void teardown(region_builder &bb); + + inline auto num_tiles() const -> std::int32_t { return num_tiles_; } + inline auto subgroup_size() const -> std::int32_t { return subgroup_size_; } + inline auto ty() const -> tinytc_type_t { return ty_; } + + protected: + std::int32_t num_tiles_, subgroup_size_; + tinytc_type_t ty_; + tinytc_value_t tmp_; +}; + +class work_group_reduce : public work_group_op { + public: + using work_group_op::work_group_op; + + auto make(region_builder &bb, tinytc_value_t a, location const &loc) -> tinytc_value_t; +}; + +class work_group_inclusive_scan : public work_group_op { + public: + using work_group_op::work_group_op; + + auto make(region_builder &bb, tinytc_value_t a, bool compute_sum, location const &loc) + -> std::pair; +}; } // namespace tinytc diff --git a/src/compiler.cpp 
b/src/compiler.cpp index 8493d738..d9c8014f 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -1,59 +1,164 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause -#include "device_info.hpp" +#include "compiler_context.hpp" #include "error.hpp" -#include "node/program_node.hpp" -#include "parser.hpp" +#include "node/prog.hpp" +// IWYU pragma: begin_keep +#include "pass/dump_cfg.hpp" +#include "pass/dump_def_use.hpp" +#include "pass/dump_gcd.hpp" +#include "pass/dump_ir.hpp" +// IWYU pragma: end_keep +#include "pass/check_ir.hpp" +#include "pass/constant_propagation.hpp" +#include "pass/convert_to_spirv.hpp" +#include "pass/dead_code_elimination.hpp" +#include "pass/insert_barrier.hpp" +#include "pass/insert_lifetime_stop.hpp" +#include "pass/lower_coopmatrix.hpp" +#include "pass/lower_foreach.hpp" +#include "pass/lower_linalg.hpp" +#include "pass/stack.hpp" +#include "pass/work_group_size.hpp" #include "passes.hpp" -#include "reference_counted.hpp" -#include "required_extensions.hpp" -#include "source.hpp" -#include "tinytc/tinytc.h" +#include "spv/pass/assemble.hpp" +#include "spv/pass/assign_ids.hpp" +#include "tinytc/core.h" #include "tinytc/types.h" +#include "tinytc/types.hpp" -#include -#include - -#include -#include +#include +#include // IWYU pragma: keep #include -#include using namespace tinytc; +namespace tinytc { + +template struct optflag_setter { + PassT &pass; + tinytc_compiler_context_t ctx; + + template void operator()(Flags &&...flags) { + (pass.set_opt_flag(flags, ctx->opt_flag(flags)), ...); + } +}; + +void apply_default_optimization_pipeline(tinytc_prog_t prg, const_tinytc_core_info_t info) { + auto ctx = prg->context(); + const auto opt_level = ctx->opt_level(); + + // passes + auto cpp = constant_propagation_pass{}; + optflag_setter{cpp, ctx}(tinytc::optflag::unsafe_fp_math); + + run_function_pass(check_ir_pass{}, *prg); + + if (opt_level >= 1) { + // We run constant propagation + dead code elimination early 
to capture dead allocas + // (later on they are maybe "in use" due to the lifetime_stop instruction) + run_function_pass(cpp, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } + + run_function_pass(insert_lifetime_stop_pass{}, *prg); + run_function_pass(set_stack_ptr_pass{}, *prg); + run_function_pass(insert_barrier_pass{}, *prg); + run_function_pass(work_group_size_pass{info}, *prg); + + run_function_pass(lower_linalg_pass{info}, *prg); + // Run set stack ptr again as lower linalg may introduce allocas for the duration of the + // linalg op. Lower linalg is expected to insert lifetime_stop instructions, after it is done + // so we do not need to run the lifetime stop pass again. + run_function_pass(set_stack_ptr_pass{}, *prg); + run_function_pass(lower_foreach_pass{info}, *prg); + if (opt_level >= 1) { + run_function_pass(cpp, *prg); + run_function_pass(dead_code_elimination_pass{}, *prg); + } + run_function_pass(lower_coopmatrix_pass{info}, *prg); + + run_function_pass(check_ir_pass{}, *prg); +} + +} // namespace tinytc + extern "C" { -tinytc_status_t tinytc_prog_compile_to_opencl(tinytc_source_t *src, tinytc_prog_t prg, - const_tinytc_core_info_t info, - tinytc_source_context_t ctx) { - if (src == nullptr || prg == nullptr || info == nullptr) { +tinytc_status_t tinytc_run_function_pass(char const *pass_name, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (prg == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code( [&] { - // passes - check_ir(*prg); - insert_lifetime_stop_inst(*prg); - set_stack_ptrs(*prg); - insert_barriers(*prg); - set_work_group_size(*prg, *info); - // opencl - auto ast = generate_opencl_ast(*prg, *info); - clir::make_names_unique(ast); - - auto oss = std::ostringstream{}; - auto ext = required_extensions(ast); - for (auto const &e : ext) { - oss << "#pragma OPENCL EXTENSION " << e << " : enable" << std::endl; - } - - clir::generate_opencl(oss, std::move(ast)); - - *src = 
std::make_unique<::tinytc_source>(oss.str(), prg->loc(), std::move(ext), - info->core_features()) - .release(); +#define FUNCTION_PASS(NAME, CREATE_PASS, ...) \ + if (strcmp(NAME, pass_name) == 0) { \ + auto pass = CREATE_PASS; \ + optflag_setter{pass, prg->context()}(__VA_ARGS__); \ + return run_function_pass(std::move(pass), *prg); \ + } +#define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) \ + if (strcmp(NAME, pass_name) == 0) { \ + return run_function_pass(CREATE_PASS(info), *prg); \ + } +#include "passes.def" +#undef FUNCTION_PASS +#undef FUNCTION_PASS_WITH_INFO + throw status::unknown_pass_name; }, - ctx); + prg->context()); +} + +tinytc_status_t tinytc_list_function_passes(size_t *names_size, char const *const **names) { + if (names_size == nullptr || names == nullptr) { + return tinytc_status_invalid_arguments; + } +#define FUNCTION_PASS(NAME, CREATE_PASS, ...) NAME, +#define FUNCTION_PASS_WITH_INFO(NAME, CREATE_PASS) NAME, + static char const *const pass_names[] = { +#include "passes.def" + }; +#undef FUNCTION_PASS +#undef FUNCTION_PASS_WITH_INFO + *names_size = sizeof(pass_names) / sizeof(char const *); + *names = pass_names; + + return tinytc_status_success; +} + +tinytc_status_t tinytc_prog_compile_to_spirv(tinytc_spv_mod_t *mod, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (mod == nullptr || prg == nullptr || info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { + apply_default_optimization_pipeline(prg, info); + + *mod = convert_to_spirv_pass{info}.run_on_program(*prg).release(); + spv::id_assigner{}.run_on_module(**mod); + }, + prg->context()); +} + +tinytc_status_t tinytc_prog_compile_to_spirv_and_assemble(tinytc_binary_t *bin, tinytc_prog_t prg, + const_tinytc_core_info_t info) { + if (bin == nullptr || prg == nullptr || info == nullptr) { + return tinytc_status_invalid_arguments; + } + tinytc_spv_mod_t mod; + TINYTC_CHECK_STATUS(tinytc_prog_compile_to_spirv(&mod, prg, info)); + auto 
mod_ = shared_handle{mod}; // For clean-up + TINYTC_CHECK_STATUS(tinytc_spirv_assemble(bin, mod_.get())); + return tinytc_status_success; +} + +tinytc_status_t tinytc_spirv_assemble(tinytc_binary_t *bin, const_tinytc_spv_mod_t mod) { + if (bin == nullptr || mod == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *bin = spv::assembler{}.run_on_module(*mod).release(); }); } } diff --git a/src/compiler_context.cpp b/src/compiler_context.cpp new file mode 100644 index 00000000..07cf396d --- /dev/null +++ b/src/compiler_context.cpp @@ -0,0 +1,142 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context.hpp" +#include "compiler_context_cache.hpp" +#include "error.hpp" +#include "node/value.hpp" +#include "tinytc/core.h" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { +void default_error_reporter(char const *, const tinytc_location_t *, void *) {} +} // namespace tinytc + +using namespace tinytc; + +extern "C" { + +tinytc_compiler_context::tinytc_compiler_context() + : cache_{std::make_unique(this)} { + opt_flags_.fill(-1); +} + +auto tinytc_compiler_context::source_name(std::int32_t source_id) + -> std::pair { + if (has_source_id(source_id)) { + auto &si = sources_[source_id - 1]; + return {si.name.c_str(), si.name.size()}; + } + return {unavailable_source_name, sizeof(unavailable_source_name) / sizeof(char) - 1}; +} +auto tinytc_compiler_context::source_text(std::int32_t source_id) + -> std::pair { + if (has_source_id(source_id)) { + auto &si = sources_[source_id - 1]; + return {si.text.c_str(), si.text.size()}; + } + return {"", 0}; +} +void tinytc_compiler_context::report_error(location const &l, char const *what) { + report_error(l, {}, what); +} + +void tinytc_compiler_context::report_error(tinytc_location const &l, + array_view const &ref_values, + char const *what) { + auto [name, name_size] = 
source_name(l.begin.source_id); + auto [text, text_size] = source_text(l.begin.source_id); + auto err = report_error_with_context(text, text_size, name, l, what); + reporter_(err.c_str(), &l, user_data_); + for (auto &ref_value : ref_values) { + if (ref_value) { + auto err = report_error_with_context(text, text_size, name, ref_value->loc(), + "value defined here"); + reporter_(err.c_str(), &ref_value->loc(), user_data_); + } + } +} + +auto tinytc_compiler_context::opt_flag(tinytc_optflag_t flag) const -> bool { + const auto state = opt_flags_[flag]; + if (state >= 0) { + return state > 0; + } + const auto clamped_opt_level = std::min(2, std::max(0, opt_level_)); + return default_opt_flags[clamped_opt_level][flag]; +} + +tinytc_status_t tinytc_compiler_context_create(tinytc_compiler_context_t *ctx) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *ctx = std::make_unique().release(); }); +} + +tinytc_status_t tinytc_compiler_context_add_source(tinytc_compiler_context_t ctx, char const *name, + char const *text, int32_t *source_id) { + if (ctx == nullptr || name == nullptr || text == nullptr || source_id == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *source_id = ctx->add_source(name, text); }); +} + +tinytc_status_t tinytc_compiler_context_set_error_reporter(tinytc_compiler_context_t ctx, + tinytc_error_reporter_t reporter, + void *user_data) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->set_error_reporter(reporter, user_data); }); +} + +tinytc_status_t tinytc_compiler_context_set_optimization_flag(tinytc_compiler_context_t ctx, + tinytc_optflag_t flag, + int32_t state) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->opt_flag(flag, state); }); +} + +tinytc_status_t 
tinytc_compiler_context_set_optimization_level(tinytc_compiler_context_t ctx, + int32_t level) { + if (ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + ctx->opt_level(level); + return tinytc_status_success; +} + +tinytc_status_t tinytc_compiler_context_report_error(tinytc_compiler_context_t ctx, + const tinytc_location_t *location, + char const *what) { + if (ctx == nullptr || location == nullptr || what == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { ctx->report_error(*location, what); }); +} + +tinytc_status_t tinytc_compiler_context_release(tinytc_compiler_context_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + auto ref_count = obj->dec_ref(); + if (ref_count == 0) { + delete obj; + } + return tinytc_status_success; +} + +tinytc_status_t tinytc_compiler_context_retain(tinytc_compiler_context_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + obj->inc_ref(); + return tinytc_status_success; +} +} diff --git a/src/compiler_context.hpp b/src/compiler_context.hpp new file mode 100644 index 00000000..4665d5ce --- /dev/null +++ b/src/compiler_context.hpp @@ -0,0 +1,89 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COMPILER_CONTEXT_20240924_HPP +#define COMPILER_CONTEXT_20240924_HPP + +#include "compiler_context_cache.hpp" +#include "reference_counted.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { +enum class optflag; + +void default_error_reporter(char const *what, const tinytc_location_t *location, void *user_data); +} // namespace tinytc + +struct tinytc_compiler_context : tinytc::reference_counted { + public: + constexpr static const char unavailable_source_name[] = "Source name unavailable"; + constexpr static std::array, 3u> default_opt_flags = { + {{false}, {false}, {true}}}; + + 
tinytc_compiler_context(); + + inline auto cache() -> tinytc::compiler_context_cache * { return cache_.get(); } + + inline void set_error_reporter(tinytc_error_reporter_t reporter, void *user_data) { + reporter_ = reporter; + user_data_ = user_data; + } + + // source / error handling + inline auto add_source(std::string name, std::string text) -> std::int32_t { + sources_.emplace_back(source_input{std::move(name), std::move(text)}); + return static_cast(sources_.size()); + } + inline auto add_source(char const *name, char const *text) -> std::int32_t { + sources_.emplace_back(source_input{std::string(name), std::string(text)}); + return static_cast(sources_.size()); + } + auto source_name(std::int32_t source_id) -> std::pair; + auto source_text(std::int32_t source_id) -> std::pair; + void report_error(tinytc_location const &l, char const *what); + void report_error(tinytc_location const &l, + tinytc::array_view const &ref_values, char const *what); + void report_error(tinytc_location const &l, + tinytc::array_view const &ref_values); + + auto opt_flag(tinytc_optflag_t flag) const -> bool; + inline void opt_flag(tinytc_optflag_t flag, std::int32_t state) { opt_flags_[flag] = state; } + inline auto opt_flag(tinytc::optflag flag) const -> bool { + return opt_flag(static_cast(flag)); + } + inline void opt_flag(tinytc::optflag flag, std::int32_t state) { + opt_flag(static_cast(flag), state); + } + + inline auto opt_level() const noexcept -> std::int32_t { return opt_level_; } + inline void opt_level(std::int32_t level) noexcept { opt_level_ = level; } + + inline auto index_bit_width() const noexcept -> std::size_t { return 64; } + + private: + struct source_input { + std::string name, text; + }; + + inline bool has_source_id(std::int32_t source_id) const { + return source_id >= 1 && static_cast(source_id) <= sources_.size(); + } + + std::unique_ptr cache_; + tinytc_error_reporter_t reporter_ = &tinytc::default_error_reporter; + void *user_data_ = nullptr; + std::vector 
sources_; + std::array opt_flags_; + std::int32_t opt_level_ = 2; +}; + +#endif // COMPILER_CONTEXT_20240924_HPP diff --git a/src/compiler_context_cache.cpp b/src/compiler_context_cache.cpp new file mode 100644 index 00000000..09bde1f0 --- /dev/null +++ b/src/compiler_context_cache.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context_cache.hpp" +#include "node/attr.hpp" +#include "node/type.hpp" +#include "tinytc/types.hpp" + +namespace tinytc { + +compiler_context_cache::compiler_context_cache(tinytc_compiler_context_t ctx) { + bool_ty = std::unique_ptr(new boolean_type(ctx)); + void_ty = std::unique_ptr(new void_type(ctx)); + + i8_ty = std::unique_ptr(new i8_type(ctx)); + i16_ty = std::unique_ptr(new i16_type(ctx)); + i32_ty = std::unique_ptr(new i32_type(ctx)); + i64_ty = std::unique_ptr(new i64_type(ctx)); + index_ty = std::unique_ptr(new index_type(ctx)); + bf16_ty = std::unique_ptr(new bf16_type(ctx)); + f16_ty = std::unique_ptr(new f16_type(ctx)); + f32_ty = std::unique_ptr(new f32_type(ctx)); + f64_ty = std::unique_ptr(new f64_type(ctx)); + c32_ty = std::unique_ptr(new c32_type(ctx)); + c64_ty = std::unique_ptr(new c64_type(ctx)); + + false_attr = std::unique_ptr(new boolean_attr(ctx, false)); + true_attr = std::unique_ptr(new boolean_attr(ctx, true)); +} + +compiler_context_cache::~compiler_context_cache() {} + +} // namespace tinytc + diff --git a/src/compiler_context_cache.hpp b/src/compiler_context_cache.hpp new file mode 100644 index 00000000..5428257c --- /dev/null +++ b/src/compiler_context_cache.hpp @@ -0,0 +1,71 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COMPILER_CONTEXT_CACHE_20240925_HPP +#define COMPILER_CONTEXT_CACHE_20240925_HPP + +#include "tinytc/types.h" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include + +namespace std { +template <> class hash> { + public: + auto 
operator()(std::pair const &key) const -> std::size_t { + return tinytc::fnv1a_combine(key.first, key.second); + } +}; +} // namespace std + +namespace tinytc { + +template class unique_storage { + public: + ~unique_storage() { + for (auto &m : map_) { + delete m.second; + } + } + + template + auto get(std::uint64_t hash, EqualFun &&is_equal, MakeFun &&make) -> T { + auto range = map_.equal_range(hash); + for (auto it = range.first; it != range.second; ++it) { + if (is_equal(it->second)) { + return it->second; + } + } + + return map_.emplace(hash, make())->second; + } + + private: + std::unordered_multimap map_; +}; + +class compiler_context_cache { + public: + compiler_context_cache(tinytc_compiler_context_t ctx); + ~compiler_context_cache(); + + compiler_context_cache(compiler_context_cache const &) = delete; + compiler_context_cache &operator=(compiler_context_cache const &) = delete; + + std::unique_ptr void_ty, bool_ty; + std::unique_ptr i8_ty, i16_ty, i32_ty, i64_ty, index_ty, bf16_ty, f16_ty, f32_ty, + f64_ty, c32_ty, c64_ty; + unique_storage coopmatrix_tys, group_tys, memref_tys; + + unique_storage array_attrs, dictionary_attrs, integer_attrs, string_attrs; + std::unique_ptr false_attr, true_attr; +}; + +} // namespace tinytc + +#endif // COMPILER_CONTEXT_CACHE_20240925_HPP diff --git a/src/coopmatrix_layout.cpp b/src/coopmatrix_layout.cpp new file mode 100644 index 00000000..18ca76aa --- /dev/null +++ b/src/coopmatrix_layout.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "node/type.hpp" +#include "number.hpp" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { + +auto get_layout(core_config const &cfg, coopmatrix_type const *ct) -> coopmatrix_layout { + auto l = coopmatrix_layout{}; + l.sty = ct->component_ty(); + l.rows = std::min(ct->rows(), static_cast(cfg.subgroup_size)); + l.cols = (1 + (l.rows * ct->cols() - 1) 
/ cfg.subgroup_size) * cfg.subgroup_size / l.rows; + l.blocks = ct->rows() / l.rows; + l.length = l.rows * l.cols * l.blocks / cfg.subgroup_size; + l.shape1 = ct->cols(); + l.blocks1 = 1; + auto sty_size = size(l.sty); + if (ct->use() == matrix_use::b && l.blocks > 1) { + const auto omega_b = std::max(1, static_cast(2 / sty_size)); + l.blocks1 = omega_b; + } + l.ops_per_chan = 1; + if (ct->use() == matrix_use::a) { + const auto omega = std::max(1, static_cast(4 / sty_size)); + if (l.cols % l.ops_per_chan == 0) { + l.ops_per_chan = omega; + } + } + + return l; +} + +} // namespace tinytc diff --git a/src/coopmatrix_layout.hpp b/src/coopmatrix_layout.hpp new file mode 100644 index 00000000..4bb74a87 --- /dev/null +++ b/src/coopmatrix_layout.hpp @@ -0,0 +1,52 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_LAYOUT_20250428_HPP +#define COOPMATRIX_LAYOUT_20250428_HPP + +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/fnv1a.hpp" + +#include +#include +#include + +namespace tinytc { +class core_config; + +struct coopmatrix_layout { + tinytc_type_t sty; + std::int64_t rows, cols, blocks, length, shape1, blocks1; + std::int32_t ops_per_chan; + + inline auto operator==(coopmatrix_layout const &other) const { + return sty == other.sty && rows == other.rows && cols == other.cols && + blocks == other.blocks && length == other.length && shape1 == other.shape1 && + blocks1 == other.blocks1 && ops_per_chan == other.ops_per_chan; + } + + inline auto component_no(std::int64_t block1, std::int64_t col, std::int64_t block2) const + -> std::int64_t { + return block1 + col * blocks1 + block2 * blocks1 * (length / blocks); + } + + inline auto component_no(std::int64_t col, std::int64_t block) const -> std::int64_t { + return component_no(block % blocks1, col, block / blocks1); + } +}; + +auto get_layout(core_config const &cfg, coopmatrix_type const *ct) -> coopmatrix_layout; + +} // namespace tinytc + 
+namespace std { +template <> struct hash { + inline auto operator()(tinytc::coopmatrix_layout const &key) const -> std::size_t { + return tinytc::fnv1a_combine(key.sty, key.rows, key.cols, key.blocks, key.length, + key.shape1, key.blocks1, key.ops_per_chan); + } +}; +} // namespace std + +#endif // COOPMATRIX_LAYOUT_20250428_HPP diff --git a/src/data_type.cpp b/src/data_type.cpp deleted file mode 100644 index cb9827b3..00000000 --- a/src/data_type.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "error.hpp" -#include "location.hpp" -#include "node/data_type_node.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.h" -#include "tinytc/types.hpp" -#include "util.hpp" - -#include -#include -#include -#include - -using namespace tinytc; - -extern "C" { -tinytc_status_t tinytc_scalar_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t type, - const tinytc_location_t *loc) { - if (dt == nullptr) { - return tinytc_status_invalid_arguments; - } - - return exception_to_status_code([&] { - *dt = std::make_unique(enum_cast(type), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_memref_type_create(tinytc_data_type_t *dt, tinytc_scalar_type_t scalar_ty, - uint32_t shape_size, const int64_t *shape, - uint32_t stride_size, const int64_t *stride, - const tinytc_location_t *loc) { - if (dt == nullptr) { - return tinytc_status_invalid_arguments; - } - - return exception_to_status_code([&] { - auto shape_vec = std::vector(shape, shape + shape_size); - auto stride_vec = std::vector(); - if (stride_size > 0) { - stride_vec.insert(stride_vec.end(), stride, stride + stride_size); - } - *dt = std::make_unique(enum_cast(scalar_ty), - std::move(shape_vec), std::move(stride_vec), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_group_type_create(tinytc_data_type_t *dt, tinytc_data_type_t memref_ty, - int64_t offset, const 
tinytc_location_t *loc) { - if (dt == nullptr) { - return tinytc_status_invalid_arguments; - } - - return exception_to_status_code([&] { - *dt = - std::make_unique(data_type(memref_ty, true), offset, get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_data_type_release(tinytc_data_type_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_data_type_retain(tinytc_data_type_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/device_info.cpp b/src/device_info.cpp index 5d239378..92ae86e3 100644 --- a/src/device_info.cpp +++ b/src/device_info.cpp @@ -3,11 +3,15 @@ #include "device_info.hpp" #include "error.hpp" -#include "tinytc/tinytc.h" +#include "tinytc/core.h" +#include "tinytc/core.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/fnv1a.hpp" #include +#include #include #include #include @@ -22,7 +26,7 @@ core_info_generic::core_info_generic(std::int32_t register_space, std::int32_t m : register_space_(register_space), max_work_group_size_(max_work_group_size), subgroup_sizes_(std::move(subgroup_sizes)) {} -auto core_info_generic::subgroup_sizes() const -> std::vector const & { +auto core_info_generic::subgroup_sizes() const -> array_view { return subgroup_sizes_; } auto core_info_generic::register_space() const -> std::int32_t { return register_space_; } @@ -36,8 +40,9 @@ auto core_info_generic::get_core_config(std::int32_t subgroup_size) const -> tin subgroup_sizes_.end()) { throw std::out_of_range("Requested subgroup size not available"); } - return core_config{subgroup_size, max_work_group_size_, register_space_, false}; + return core_config{subgroup_size, max_work_group_size_, register_space_, &matrix_}; } +auto 
core_info_generic::matrix() const -> matrix_ext_info const & { return matrix_; } core_info_intel::core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, std::int32_t num_threads_per_eu, @@ -48,43 +53,58 @@ core_info_intel::core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_ std::sort(subgroup_sizes_.begin(), subgroup_sizes_.end()); register_size_ = 32; - if (ip_version_ >= static_cast(intel_gpu_architecture::pvc)) { + if (is_arch(tinytc_intel_gpu_architecture_pvc)) { register_size_ = 64; + set_spirv_feature(spirv_feature::bfloat16_conversion, true); + + const auto block_info = matrix_ext_block_io_info{.base_address_alignment = 64, + .min_stride = 64, + .max_stride = (1 << 24) - 1, + .pos0_alignment = 4, + .stride_alignment = 8, + .width_alignment = 4}; + matrix_ = matrix_ext_info(16, block_info, pvc_matrix_ext_types); + } else if (is_arch(tinytc_intel_gpu_architecture_bmg)) { + register_size_ = 64; + set_spirv_feature(spirv_feature::bfloat16_conversion, true); + + const auto block_info = matrix_ext_block_io_info{.base_address_alignment = 64, + .min_stride = 64, + .max_stride = (1 << 24) - 1, + .pos0_alignment = 4, + .stride_alignment = 16, + .width_alignment = 4}; + matrix_ = matrix_ext_info(16, block_info, pvc_matrix_ext_types); } - num_registers_per_thread_ = num_reg_small_grf(); } auto core_info_intel::num_reg_small_grf() const -> std::int32_t { return 128; } auto core_info_intel::num_reg_large_grf() const -> std::int32_t { - return ip_version_ >= static_cast(intel_gpu_architecture::pvc) - ? 256 - : num_reg_small_grf(); + if (is_arch(tinytc_intel_gpu_architecture_pvc) || is_arch(tinytc_intel_gpu_architecture_bmg)) { + return 256; + } + return num_reg_small_grf(); } -auto core_info_intel::subgroup_sizes() const -> std::vector const & { - return subgroup_sizes_; +auto core_info_intel::num_reg() const -> std::int32_t { + return core_features_ & tinytc_core_feature_flag_large_register_file ? 
num_reg_large_grf() + : num_reg_small_grf(); } -auto core_info_intel::register_space() const -> std::int32_t { - return register_size_ * num_registers_per_thread_; -} +auto core_info_intel::subgroup_sizes() const -> array_view { return subgroup_sizes_; } + +auto core_info_intel::register_space() const -> std::int32_t { return register_size_ * num_reg(); } auto core_info_intel::core_features() const -> tinytc_core_feature_flags_t { return core_features_; } -void core_info_intel::core_features(tinytc_core_feature_flags_t flags) { - if (flags & tinytc_core_feature_flag_large_register_file) { - num_registers_per_thread_ = num_reg_large_grf(); - } else { - num_registers_per_thread_ = num_reg_small_grf(); - } -} +void core_info_intel::core_features(tinytc_core_feature_flags_t flags) { core_features_ = flags; } auto core_info_intel::max_work_group_size(std::int32_t subgroup_size) const -> std::int32_t { auto const num_threads_per_eu_due_to_register_use = - num_threads_per_eu_ * num_reg_small_grf() / num_registers_per_thread_; + num_threads_per_eu_ * num_reg_small_grf() / num_reg(); auto const num_threads_per_eu_due_to_subgroup_size = num_threads_per_eu_ * subgroup_sizes_.front() / subgroup_size; auto const num_threads_per_eu = @@ -107,19 +127,19 @@ auto core_info_intel::get_core_config(std::int32_t subgroup_size) const -> core_ throw std::out_of_range("Requested subgroup size not available"); } - bool block_read_write_supported = !(subgroup_size == 32 && register_size_ == 32); - return core_config{subgroup_size, max_work_group_size(subgroup_size), register_space(), - block_read_write_supported}; + &matrix_}; } +auto core_info_intel::matrix() const -> matrix_ext_info const & { return matrix_; } + } // namespace tinytc using namespace tinytc; extern "C" { tinytc_status_t tinytc_core_info_generic_create(tinytc_core_info_t *info, int32_t register_space, - int32_t max_work_group_size, uint32_t sgs_size, + int32_t max_work_group_size, size_t sgs_size, int32_t const *sgs) { if 
(info == nullptr || sgs == nullptr) { return tinytc_status_invalid_arguments; @@ -142,11 +162,64 @@ tinytc_status_t tinytc_core_info_intel_create_from_arch(tinytc_core_info_t *info *info = std::make_unique(static_cast(arch), 8, 7, std::vector{8, 16, 32}) .release(); + (*info)->set_spirv_feature(spirv_feature::float16, true); + (*info)->set_spirv_feature(spirv_feature::float64, false); + (*info)->set_spirv_feature(spirv_feature::int64_atomics, true); + (*info)->set_spirv_feature(spirv_feature::groups, true); + (*info)->set_spirv_feature(spirv_feature::subgroup_dispatch, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_local, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_global, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_local, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_global, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_local, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_global, false); + (*info)->set_spirv_feature(spirv_feature::bfloat16_conversion, false); + (*info)->set_spirv_feature(spirv_feature::subgroup_buffer_block_io, true); break; case tinytc_intel_gpu_architecture_pvc: + case tinytc_intel_gpu_architecture_bmg: *info = std::make_unique(static_cast(arch), 8, 8, std::vector{16, 32}) .release(); + (*info)->set_spirv_feature(spirv_feature::float16, true); + (*info)->set_spirv_feature(spirv_feature::float64, true); + (*info)->set_spirv_feature(spirv_feature::int64_atomics, true); + (*info)->set_spirv_feature(spirv_feature::groups, true); + (*info)->set_spirv_feature(spirv_feature::subgroup_dispatch, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_local, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float16_add_global, false); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_local, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float32_add_global, 
true); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_local, true); + (*info)->set_spirv_feature(spirv_feature::atomic_float64_add_global, true); + (*info)->set_spirv_feature(spirv_feature::bfloat16_conversion, true); + (*info)->set_spirv_feature(spirv_feature::subgroup_buffer_block_io, true); + break; + default: + *info = nullptr; + throw status::invalid_arguments; + } + }); +} + +tinytc_status_t tinytc_core_info_intel_create_from_name(tinytc_core_info_t *info, + char const *name) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + switch (fnv1a(name, std::strlen(name))) { + case "tgl"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_tgl)); + break; + case "pvc"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_pvc)); + break; + case "bmg"_fnv1a: + CHECK_STATUS( + tinytc_core_info_intel_create_from_arch(info, tinytc_intel_gpu_architecture_bmg)); break; default: *info = nullptr; @@ -157,7 +230,7 @@ tinytc_status_t tinytc_core_info_intel_create_from_arch(tinytc_core_info_t *info tinytc_status_t tinytc_core_info_intel_create(tinytc_core_info_t *info, uint32_t ip_version, int32_t num_eus_per_subslice, - int32_t num_threads_per_eu, uint32_t sgs_size, + int32_t num_threads_per_eu, size_t sgs_size, int32_t const *sgs) { if (info == nullptr || sgs == nullptr) { return tinytc_status_invalid_arguments; @@ -170,8 +243,8 @@ tinytc_status_t tinytc_core_info_intel_create(tinytc_core_info_t *info, uint32_t }); } -tinytc_status_t tinytc_core_info_get_subgroup_sizes(const_tinytc_core_info_t info, - uint32_t *sgs_size, int32_t const **sgs) { +tinytc_status_t tinytc_core_info_get_subgroup_sizes(const_tinytc_core_info_t info, size_t *sgs_size, + int32_t const **sgs) { if (info == nullptr || sgs_size == nullptr || sgs == nullptr) { return tinytc_status_invalid_arguments; @@ -203,7 +276,7 @@ tinytc_status_t 
tinytc_core_info_set_core_features(tinytc_core_info_t info, return exception_to_status_code([&] { info->core_features(flags); }); } -tinytc_status_t tinytc_core_info_get_core_features(tinytc_core_info_t info, +tinytc_status_t tinytc_core_info_get_core_features(const_tinytc_core_info_t info, tinytc_core_feature_flags_t *flags) { if (info == nullptr || flags == nullptr) { return tinytc_status_invalid_arguments; @@ -211,6 +284,42 @@ tinytc_status_t tinytc_core_info_get_core_features(tinytc_core_info_t info, return exception_to_status_code([&] { *flags = info->core_features(); }); } +tinytc_status_t tinytc_core_info_set_spirv_feature(tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t available) { + + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { info->set_spirv_feature(enum_cast(feature), available); }); +} + +tinytc_status_t tinytc_core_info_have_spirv_feature(const_tinytc_core_info_t info, + tinytc_spirv_feature_t feature, + tinytc_bool_t *available) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *available = info->have_spirv_feature(enum_cast(feature)); }); +} + +tinytc_status_t tinytc_core_info_get_default_alignment(const_tinytc_core_info_t info, + int32_t *alignment) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *alignment = info->alignment(); }); +} + +tinytc_status_t tinytc_core_info_set_default_alignment(tinytc_core_info_t info, int32_t alignment) { + if (info == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { info->alignment(alignment); }); +} + tinytc_status_t tinytc_core_info_release(tinytc_core_info_t obj) { if (obj == nullptr) { return tinytc_status_invalid_arguments; diff --git a/src/device_info.hpp b/src/device_info.hpp index a0e638ba..070f673e 100644 --- a/src/device_info.hpp +++ 
b/src/device_info.hpp @@ -4,21 +4,26 @@ #ifndef DEVICE_INFO_20240304_HPP #define DEVICE_INFO_20240304_HPP +#include "matrix_ext_info.hpp" #include "reference_counted.hpp" +#include "tinytc/core.hpp" #include "tinytc/types.h" +#include #include #include namespace tinytc { +enum class spirv_feature; + //! Core parameters for a specific choice of subgroup size and core feature flags class core_config { public: std::int32_t subgroup_size; ///< Smallest unit of execution std::int32_t max_work_group_size; ///< Maximum size of local work group in number of works items std::int32_t register_space; ///< Size of register file in bytes - bool block_read_write_supported; ///< True if block reads / block writes are suppported + matrix_ext_info const *matrix; }; } // namespace tinytc @@ -26,8 +31,7 @@ class core_config { struct tinytc_core_info : tinytc::reference_counted { //! empty dtor virtual ~tinytc_core_info(); - //! Returns available subgroup sizes - virtual auto subgroup_sizes() const -> std::vector const & = 0; + virtual auto subgroup_sizes() const -> tinytc::array_view = 0; //! Returns availabe register space per subgroup virtual auto register_space() const -> std::int32_t = 0; //! Get core features @@ -39,29 +43,52 @@ struct tinytc_core_info : tinytc::reference_counted { virtual auto minmax_work_group_size() const -> std::int32_t = 0; //! 
Return core config for specific subgroup size and number of registers per tile virtual auto get_core_config(std::int32_t subgroup_size) const -> tinytc::core_config = 0; + virtual void set_spirv_feature(tinytc::spirv_feature f, bool available) = 0; + virtual auto have_spirv_feature(tinytc::spirv_feature f) const -> bool = 0; + virtual auto matrix() const -> tinytc::matrix_ext_info const & = 0; + virtual auto alignment() const -> std::int32_t = 0; + virtual void alignment(std::int32_t alignment) = 0; }; namespace tinytc { -class core_info_generic : public ::tinytc_core_info { +class core_info_common : public ::tinytc_core_info { + public: + inline void set_spirv_feature(spirv_feature f, bool available) override { + spv_feature_[static_cast(f)] = available; + } + inline auto have_spirv_feature(spirv_feature f) const -> bool override { + return spv_feature_[static_cast(f)]; + } + inline auto alignment() const -> std::int32_t override { return alignment_; } + inline void alignment(std::int32_t alignment) override { alignment_ = alignment; } + + private: + std::array spv_feature_ = {}; + std::int32_t alignment_ = 128; +}; + +class core_info_generic : public core_info_common { public: core_info_generic(std::int32_t register_space, std::int32_t max_work_group_size, std::vector subgroup_sizes); - auto subgroup_sizes() const -> std::vector const & override; + auto subgroup_sizes() const -> array_view override; auto register_space() const -> std::int32_t override; auto core_features() const -> tinytc_core_feature_flags_t override; void core_features(tinytc_core_feature_flags_t flags) override; auto minmax_work_group_size() const -> std::int32_t override; auto get_core_config(std::int32_t subgroup_size) const -> tinytc::core_config override; + auto matrix() const -> matrix_ext_info const & override; private: std::int32_t register_space_; std::int32_t max_work_group_size_; std::vector subgroup_sizes_; + matrix_ext_info matrix_; }; //! 
Set of core configurations for Intel GPUs -class core_info_intel : public ::tinytc_core_info { +class core_info_intel : public core_info_common { public: /** * @brief ctor @@ -75,8 +102,7 @@ class core_info_intel : public ::tinytc_core_info { core_info_intel(std::uint32_t ip_version, std::int32_t num_eus_per_subslice, std::int32_t num_threads_per_eu, std::vector subgroup_sizes); - //! @copydoc ::tinytc_core_info::subgroup_sizes - auto subgroup_sizes() const -> std::vector const & override; + auto subgroup_sizes() const -> array_view override; //! @copydoc ::tinytc_core_info::register_space auto register_space() const -> std::int32_t override; //! @copydoc ::tinytc_core_info::core_features() @@ -87,10 +113,16 @@ class core_info_intel : public ::tinytc_core_info { auto minmax_work_group_size() const -> std::int32_t override; //! @copydoc ::tinytc_core_info::get_core_config auto get_core_config(std::int32_t subgroup_size) const -> core_config override; + auto matrix() const -> matrix_ext_info const & override; private: + inline auto is_arch(tinytc_intel_gpu_architecture_t arch) const -> bool { + return arch <= ip_version_ && + ip_version_ <= arch + TINYTC_INTEL_GPU_ARCHITECTURE_SUB_VERSION_BITS; + } auto num_reg_small_grf() const -> std::int32_t; auto num_reg_large_grf() const -> std::int32_t; + auto num_reg() const -> std::int32_t; auto max_work_group_size(std::int32_t subgroup_size) const -> std::int32_t; std::uint32_t ip_version_; @@ -98,8 +130,8 @@ class core_info_intel : public ::tinytc_core_info { std::int32_t num_threads_per_eu_; std::vector subgroup_sizes_; std::int32_t register_size_; - std::int32_t num_registers_per_thread_; tinytc_core_feature_flags_t core_features_; + matrix_ext_info matrix_; }; } // namespace tinytc diff --git a/src/enums.cpp.mochi b/src/enums.cpp.mochi new file mode 100644 index 00000000..18a2f56c --- /dev/null +++ b/src/enums.cpp.mochi @@ -0,0 +1,8 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + 
+#include "tinytc/types.h" + +extern "C" { +// もち enum_cpp "tinytc/enums.anko" +} diff --git a/src/error.cpp b/src/error.cpp index 2431eecc..e34f0a83 100644 --- a/src/error.cpp +++ b/src/error.cpp @@ -3,8 +3,8 @@ #include "error.hpp" #include "location.hpp" -#include "tinytc/tinytc.h" +#include #include #include #include @@ -12,12 +12,28 @@ namespace tinytc { compilation_error::compilation_error(location const &loc, status code, std::string extra_info) - : loc_(loc), code_(code), extra_info_(std::move(extra_info)) {} + : loc_(loc), ref_values_{}, num_ref_values_{0}, code_(code), + extra_info_(std::move(extra_info)) {} + +compilation_error::compilation_error(location const &loc, + array_view ref_values, status code, + std::string extra_info) + : loc_(loc), code_(code), extra_info_(std::move(extra_info)) { + num_ref_values_ = std::min(error_max_ref, ref_values.size()); + for (std::size_t i = 0; i < num_ref_values_; ++i) { + ref_values_[i] = ref_values[i]; + } +} auto report_error_with_context(char const *code, std::size_t code_len, std::string const &file_name, location const &l, std::string const &what) -> std::string { constexpr int additional_context_lines = 2; + auto oerr = std::ostringstream{}; + oerr << file_name << ":"; + print_range(oerr, l.begin, l.end); + oerr << ": " << what; + int cur_line = 1; const char *begin = code; const char *limit = begin + code_len; @@ -27,7 +43,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri } ++begin; } - auto oerr = std::ostringstream{}; + char const *end = begin; int start_col = -1; while (cur_line <= l.end.line && *end != '\0' && end <= limit) { @@ -39,7 +55,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri if (start_col < 0) { start_col = static_cast(end - begin); } - oerr << std::string(begin, end) << std::endl; + oerr << std::endl << std::string(begin, end) << std::endl; if (cur_line >= l.begin.line) { int col_begin = 0; int num_col = 0; @@ -63,7 
+79,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri col_begin = l.begin.column > 1 ? l.begin.column - 1 : 0; num_col = l.end.column > l.begin.column ? l.end.column - l.begin.column : 1; } - oerr << std::string(col_begin, ' ') << std::string(num_col, '~') << std::endl; + oerr << std::string(col_begin, ' ') << std::string(num_col, '~'); } ++cur_line; start_col = -1; @@ -71,303 +87,7 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri } ++end; } - oerr << file_name << ":"; - print_range(oerr, l.begin, l.end); - oerr << ": " << what; - return oerr.str(); + return std::move(oerr).str(); } } // namespace tinytc - -extern "C" { -char const *tinytc_error_string(tinytc_status_t status) { - switch (status) { - case tinytc_status_success: - return "Success"; - case tinytc_status_bad_alloc: - return "Bad allocation"; - case tinytc_status_invalid_arguments: - return "Invalid arguments passed to function"; - case tinytc_status_out_of_range: - return "Out of range"; - case tinytc_status_runtime_error: - return "General runtime error"; - case tinytc_status_internal_compiler_error: - return "Internal compiler error"; - case tinytc_status_unsupported_subgroup_size: - return "Unsupported subgroup size"; - case tinytc_status_unsupported_work_group_size: - return "Work group size is larger than maximum work group size supported by device"; - case tinytc_status_compilation_error: - return "Compilation error"; - case tinytc_status_file_io_error: - return "I/O error occured in file operation"; - case tinytc_status_parse_error: - return "Parse error"; - case tinytc_status_unavailable_extension: - return "Required vendor extension is unavailable"; - case tinytc_status_unsupported_backend: - return "Unsupport backend"; - case tinytc_status_invalid_kernel_arguments: - return "Invalid arguments passed to kernel"; - case tinytc_status_unsupported_device: - return "Unsupported device"; - // IR - case 
tinytc_status_ir_out_of_bounds: - return "Argument is out of bounds"; - case tinytc_status_ir_invalid_shape: - return "Mode size must be non-negative"; - case tinytc_status_ir_incompatible_shapes: - return "Incompatible tensor shapes"; - case tinytc_status_ir_shape_stride_mismatch: - return "Dimension of shape and stride must match"; - case tinytc_status_ir_scalar_mismatch: - return "Scalar type mismatch"; - case tinytc_status_ir_invalid_number_of_indices: - return "Number of indices must match memref order or must be 1 for group types"; - case tinytc_status_ir_expected_scalar: - return "Expected scalar type"; - case tinytc_status_ir_expected_memref: - return "Expected memref type"; - case tinytc_status_ir_expected_memref_or_scalar: - return "Expected memref type or scalar type"; - case tinytc_status_ir_expected_memref_or_group: - return "Expected memref or group operand"; - case tinytc_status_ir_expected_vector_or_matrix: - return "Expected vector or matrix input"; - case tinytc_status_ir_unexpected_yield: - return "Yield encountered in non-yielding region"; - case tinytc_status_ir_yield_mismatch: - return "Number of yielded values does not match number of values yielded by region"; - case tinytc_status_ir_multiple_dynamic_modes: - return "At most one mode must be dynamic ('?')"; - case tinytc_status_ir_invalid_slice: - return "Offset must be non-negative and must not be '?'; size must be positive or '?'"; - case tinytc_status_ir_expand_shape_order_too_small: - return "Expand shape must have at least 2 entries"; - case tinytc_status_ir_expand_shape_mismatch: - return "Product of expand shape must equal mode size"; - case tinytc_status_ir_collective_called_from_spmd: - return "Collective instruction must not be called from SPMD region"; - case tinytc_status_ir_fp_unsupported: - return "Floating point type unsupported for instruction"; - // Level Zero - case tinytc_status_ze_result_not_ready: - return "ZE_RESULT_NOT_READY"; - case 
tinytc_status_ze_result_error_device_lost: - return "ZE_RESULT_ERROR_DEVICE_LOST"; - case tinytc_status_ze_result_error_out_of_host_memory: - return "ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY"; - case tinytc_status_ze_result_error_out_of_device_memory: - return "ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY"; - case tinytc_status_ze_result_error_module_build_failure: - return "ZE_RESULT_ERROR_MODULE_BUILD_FAILURE"; - case tinytc_status_ze_result_error_module_link_failure: - return "ZE_RESULT_ERROR_MODULE_LINK_FAILURE"; - case tinytc_status_ze_result_error_device_requires_reset: - return "ZE_RESULT_ERROR_DEVICE_REQUIRES_RESET"; - case tinytc_status_ze_result_error_device_in_low_power_state: - return "ZE_RESULT_ERROR_DEVICE_IN_LOW_POWER_STATE"; - case tinytc_status_ze_result_exp_error_device_is_not_vertex: - return "ZE_RESULT_EXP_ERROR_DEVICE_IS_NOT_VERTEX"; - case tinytc_status_ze_result_exp_error_vertex_is_not_device: - return "ZE_RESULT_EXP_ERROR_VERTEX_IS_NOT_DEVICE"; - case tinytc_status_ze_result_exp_error_remote_device: - return "ZE_RESULT_EXP_ERROR_REMOTE_DEVICE"; - case tinytc_status_ze_result_exp_error_operands_incompatible: - return "ZE_RESULT_EXP_ERROR_OPERANDS_INCOMPATIBLE"; - case tinytc_status_ze_result_exp_rtas_build_retry: - return "ZE_RESULT_EXP_RTAS_BUILD_RETRY"; - case tinytc_status_ze_result_exp_rtas_build_deferred: - return "ZE_RESULT_EXP_RTAS_BUILD_DEFERRED"; - case tinytc_status_ze_result_error_insufficient_permissions: - return "ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS"; - case tinytc_status_ze_result_error_not_available: - return "ZE_RESULT_ERROR_NOT_AVAILABLE"; - case tinytc_status_ze_result_error_dependency_unavailable: - return "ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE"; - case tinytc_status_ze_result_warning_dropped_data: - return "ZE_RESULT_WARNING_DROPPED_DATA"; - case tinytc_status_ze_result_error_uninitialized: - return "ZE_RESULT_ERROR_UNINITIALIZED"; - case tinytc_status_ze_result_error_unsupported_version: - return 
"ZE_RESULT_ERROR_UNSUPPORTED_VERSION"; - case tinytc_status_ze_result_error_unsupported_feature: - return "ZE_RESULT_ERROR_UNSUPPORTED_FEATURE"; - case tinytc_status_ze_result_error_invalid_argument: - return "ZE_RESULT_ERROR_INVALID_ARGUMENT"; - case tinytc_status_ze_result_error_invalid_null_handle: - return "ZE_RESULT_ERROR_INVALID_NULL_HANDLE"; - case tinytc_status_ze_result_error_handle_object_in_use: - return "ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE"; - case tinytc_status_ze_result_error_invalid_null_pointer: - return "ZE_RESULT_ERROR_INVALID_NULL_POINTER"; - case tinytc_status_ze_result_error_invalid_size: - return "ZE_RESULT_ERROR_INVALID_SIZE"; - case tinytc_status_ze_result_error_unsupported_size: - return "ZE_RESULT_ERROR_UNSUPPORTED_SIZE"; - case tinytc_status_ze_result_error_unsupported_alignment: - return "ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT"; - case tinytc_status_ze_result_error_invalid_synchronization_object: - return "ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT"; - case tinytc_status_ze_result_error_invalid_enumeration: - return "ZE_RESULT_ERROR_INVALID_ENUMERATION"; - case tinytc_status_ze_result_error_unsupported_enumeration: - return "ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION"; - case tinytc_status_ze_result_error_unsupported_image_format: - return "ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT"; - case tinytc_status_ze_result_error_invalid_native_binary: - return "ZE_RESULT_ERROR_INVALID_NATIVE_BINARY"; - case tinytc_status_ze_result_error_invalid_global_name: - return "ZE_RESULT_ERROR_INVALID_GLOBAL_NAME"; - case tinytc_status_ze_result_error_invalid_kernel_name: - return "ZE_RESULT_ERROR_INVALID_KERNEL_NAME"; - case tinytc_status_ze_result_error_invalid_function_name: - return "ZE_RESULT_ERROR_INVALID_FUNCTION_NAME"; - case tinytc_status_ze_result_error_invalid_group_size_dimension: - return "ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION"; - case tinytc_status_ze_result_error_invalid_global_width_dimension: - return 
"ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION"; - case tinytc_status_ze_result_error_invalid_kernel_argument_index: - return "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX"; - case tinytc_status_ze_result_error_invalid_kernel_argument_size: - return "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE"; - case tinytc_status_ze_result_error_invalid_kernel_attribute_value: - return "ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE"; - case tinytc_status_ze_result_error_invalid_module_unlinked: - return "ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED"; - case tinytc_status_ze_result_error_invalid_command_list_type: - return "ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE"; - case tinytc_status_ze_result_error_overlapping_regions: - return "ZE_RESULT_ERROR_OVERLAPPING_REGIONS"; - case tinytc_status_ze_result_warning_action_required: - return "ZE_RESULT_WARNING_ACTION_REQUIRED"; - case tinytc_status_ze_result_error_unknown: - return "ZE_RESULT_ERROR_UNKNOWN"; - // OpenCL - case tinytc_status_cl_build_program_failure: - return "CL_BUILD_PROGRAM_FAILURE"; - case tinytc_status_cl_compile_program_failure: - return "CL_COMPILE_PROGRAM_FAILURE"; - case tinytc_status_cl_compiler_not_available: - return "CL_COMPILER_NOT_AVAILABLE"; - case tinytc_status_cl_device_not_found: - return "CL_DEVICE_NOT_FOUND"; - case tinytc_status_cl_device_not_available: - return "CL_DEVICE_NOT_AVAILABLE"; - case tinytc_status_cl_device_partition_failed: - return "CL_DEVICE_PARTITION_FAILED"; - case tinytc_status_cl_exec_status_error_for_events_in_wait_list: - return "CL_EXEC_STATUS_ERROR_FOR_EVENTS_IN_WAIT_LIST"; - case tinytc_status_cl_image_format_mismatch: - return "CL_IMAGE_FORMAT_MISMATCH"; - case tinytc_status_cl_image_format_not_supported: - return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; - case tinytc_status_cl_invalid_arg_index: - return "CL_INVALID_ARG_INDEX"; - case tinytc_status_cl_invalid_arg_size: - return "CL_INVALID_ARG_SIZE"; - case tinytc_status_cl_invalid_arg_value: - return "CL_INVALID_ARG_VALUE"; - case 
tinytc_status_cl_invalid_binary: - return "CL_INVALID_BINARY"; - case tinytc_status_cl_invalid_buffer_size: - return "CL_INVALID_BUFFER_SIZE"; - case tinytc_status_cl_invalid_build_options: - return "CL_INVALID_BUILD_OPTIONS"; - case tinytc_status_cl_invalid_command_queue: - return "CL_INVALID_COMMAND_QUEUE"; - case tinytc_status_cl_invalid_compiler_options: - return "CL_INVALID_COMPILER_OPTIONS"; - case tinytc_status_cl_invalid_context: - return "CL_INVALID_CONTEXT"; - case tinytc_status_cl_invalid_device: - return "CL_INVALID_DEVICE"; - case tinytc_status_cl_invalid_device_partition_count: - return "CL_INVALID_DEVICE_PARTITION_COUNT"; - case tinytc_status_cl_invalid_device_queue: - return "CL_INVALID_DEVICE_QUEUE"; - case tinytc_status_cl_invalid_device_type: - return "CL_INVALID_DEVICE_TYPE"; - case tinytc_status_cl_invalid_event: - return "CL_INVALID_EVENT"; - case tinytc_status_cl_invalid_event_wait_list: - return "CL_INVALID_EVENT_WAIT_LIST"; - case tinytc_status_cl_invalid_global_offset: - return "CL_INVALID_GLOBAL_OFFSET"; - case tinytc_status_cl_invalid_global_work_size: - return "CL_INVALID_GLOBAL_WORK_SIZE"; - case tinytc_status_cl_invalid_host_ptr: - return "CL_INVALID_HOST_PTR"; - case tinytc_status_cl_invalid_image_descriptor: - return "CL_INVALID_IMAGE_DESCRIPTOR"; - case tinytc_status_cl_invalid_image_format_descriptor: - return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; - case tinytc_status_cl_invalid_image_size: - return "CL_INVALID_IMAGE_SIZE"; - case tinytc_status_cl_invalid_kernel: - return "CL_INVALID_KERNEL"; - case tinytc_status_cl_invalid_kernel_args: - return "CL_INVALID_KERNEL_ARGS"; - case tinytc_status_cl_invalid_kernel_definition: - return "CL_INVALID_KERNEL_DEFINITION"; - case tinytc_status_cl_invalid_kernel_name: - return "CL_INVALID_KERNEL_NAME"; - case tinytc_status_cl_invalid_linker_options: - return "CL_INVALID_LINKER_OPTIONS"; - case tinytc_status_cl_invalid_mem_object: - return "CL_INVALID_MEM_OBJECT"; - case 
tinytc_status_cl_invalid_operation: - return "CL_INVALID_OPERATION"; - case tinytc_status_cl_invalid_pipe_size: - return "CL_INVALID_PIPE_SIZE"; - case tinytc_status_cl_invalid_platform: - return "CL_INVALID_PLATFORM"; - case tinytc_status_cl_invalid_program: - return "CL_INVALID_PROGRAM"; - case tinytc_status_cl_invalid_program_executable: - return "CL_INVALID_PROGRAM_EXECUTABLE"; - case tinytc_status_cl_invalid_property: - return "CL_INVALID_PROPERTY"; - case tinytc_status_cl_invalid_queue_properties: - return "CL_INVALID_QUEUE_PROPERTIES"; - case tinytc_status_cl_invalid_sampler: - return "CL_INVALID_SAMPLER"; - case tinytc_status_cl_invalid_spec_id: - return "CL_INVALID_SPEC_ID"; - case tinytc_status_cl_invalid_value: - return "CL_INVALID_VALUE"; - case tinytc_status_cl_invalid_work_dimension: - return "CL_INVALID_WORK_DIMENSION"; - case tinytc_status_cl_invalid_work_group_size: - return "CL_INVALID_WORK_GROUP_SIZE"; - case tinytc_status_cl_invalid_work_item_size: - return "CL_INVALID_WORK_ITEM_SIZE"; - case tinytc_status_cl_kernel_arg_info_not_available: - return "CL_KERNEL_ARG_INFO_NOT_AVAILABLE"; - case tinytc_status_cl_link_program_failure: - return "CL_LINK_PROGRAM_FAILURE"; - case tinytc_status_cl_linker_not_available: - return "CL_LINKER_NOT_AVAILABLE"; - case tinytc_status_cl_map_failure: - return "CL_MAP_FAILURE"; - case tinytc_status_cl_mem_copy_overlap: - return "CL_MEM_COPY_OVERLAP"; - case tinytc_status_cl_mem_object_allocation_failure: - return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; - case tinytc_status_cl_misaligned_sub_buffer_offset: - return "CL_MISALIGNED_SUB_BUFFER_OFFSET"; - case tinytc_status_cl_out_of_host_memory: - return "CL_OUT_OF_HOST_MEMORY"; - case tinytc_status_cl_out_of_resources: - return "CL_OUT_OF_RESOURCES"; - case tinytc_status_cl_max_size_restriction_exceeded: - return "CL_MAX_SIZE_RESTRICTION_EXCEEDED"; - case tinytc_status_cl_profiling_info_not_available: - return "CL_PROFILING_INFO_NOT_AVAILABLE"; - case 
tinytc_status_unknown: - return "Unknown error"; - } - return "Unknown status code"; -} -} diff --git a/src/error.hpp b/src/error.hpp index 3b48d27c..1ac627f1 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -4,13 +4,13 @@ #ifndef ERROR_20240410_HPP #define ERROR_20240410_HPP -#include "parser.hpp" -#include "tinytc/tinytc.hpp" +#include "compiler_context.hpp" +#include "tinytc/core.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" +#include #include -#include #include #include #include @@ -24,19 +24,29 @@ auto report_error_with_context(char const *code, std::size_t code_len, std::stri //! Compilation error class compilation_error : public std::exception { public: + constexpr static std::size_t error_max_ref = 4; + //! ctor; taking location, status code, and expanatory string compilation_error(location const &loc, status code, std::string extra_info = {}); + compilation_error(location const &loc, array_view ref_values, status code, + std::string extra_info = {}); //! Get status code inline auto code() const noexcept { return code_; } //! Get location inline auto loc() const noexcept -> location const & { return loc_; } + inline auto ref_values() const noexcept -> array_view { + return array_view(ref_values_.data(), num_ref_values_); + } + inline auto num_ref_values() const noexcept -> std::size_t { return num_ref_values_; } //! Get explanatory string - inline char const *what() const noexcept override { return error_string(code_); } + inline char const *what() const noexcept override { return to_string(code_); } //! 
Get additional information inline auto extra_info() const -> std::string const & { return extra_info_; } private: location loc_; + std::array ref_values_; + std::size_t num_ref_values_; status code_; std::string extra_info_; }; @@ -47,7 +57,8 @@ class internal_compiler_error : public std::exception { }; template -auto exception_to_status_code(F &&f, tinytc_source_context_t context = nullptr) -> tinytc_status_t { +auto exception_to_status_code(F &&f, tinytc_compiler_context_t context = nullptr) + -> tinytc_status_t { try { f(); } catch (internal_compiler_error const &e) { @@ -64,8 +75,9 @@ auto exception_to_status_code(F &&f, tinytc_source_context_t context = nullptr) if (e.extra_info().size() > 0) { auto what = (std::ostringstream{} << e.what() << " (" << e.extra_info() << ')').str(); + context->report_error(e.loc(), e.ref_values(), what.c_str()); } else { - context->report_error(e.loc(), e.what()); + context->report_error(e.loc(), e.ref_values(), e.what()); } } return static_cast(e.code()); diff --git a/src/func.cpp b/src/func.cpp deleted file mode 100644 index 53d5fcfa..00000000 --- a/src/func.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "error.hpp" -#include "location.hpp" -#include "node/function_node.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.h" - -#include -#include -#include -#include -#include - -using namespace tinytc; - -extern "C" { - -tinytc_status_t tinytc_function_prototype_create(tinytc_func_t *fun, char const *name, - uint32_t arg_list_size, tinytc_value_t *arg_list, - const tinytc_location_t *loc) { - if (fun == nullptr || (arg_list_size > 0 && arg_list == nullptr)) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto arg_vec = std::vector(); - arg_vec.reserve(arg_list_size); - for (uint32_t i = 0; i < arg_list_size; ++i) { - arg_vec.emplace_back(value(arg_list[i], true)); - } - 
*fun = std::make_unique(std::string(name), std::move(arg_vec), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_function_create(tinytc_func_t *fun, tinytc_func_t prototype, - tinytc_region_t body, const tinytc_location_t *loc) { - if (fun == nullptr || prototype == nullptr || body == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *fun = - std::make_unique(func{prototype, true}, region{body, true}, get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_function_set_work_group_size(tinytc_func_t fun, int32_t x, int32_t y) { - function *f = dynamic_cast(fun); - if (f == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { f->work_group_size({x, y}); }); -} - -tinytc_status_t tinytc_function_set_subgroup_size(tinytc_func_t fun, int32_t sgs) { - function *f = dynamic_cast(fun); - if (f == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { f->subgroup_size(sgs); }); -} - -tinytc_status_t tinytc_func_release(tinytc_func_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_func_retain(tinytc_func_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/gemm_generator.cpp b/src/gemm_generator.cpp deleted file mode 100644 index f101d219..00000000 --- a/src/gemm_generator.cpp +++ /dev/null @@ -1,509 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "gemm_generator.hpp" -#include "codegen_tools.hpp" -#include "device_info.hpp" -#include "precision_helper.hpp" -#include "scalar_type.hpp" -#include "tiling.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include -#include -#include 
-#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -using namespace clir; - -namespace tinytc { - -gemm_scalar_type::gemm_scalar_type(scalar_type ty) : alpha(ty), A(ty), B(ty), beta(ty), C(ty) {} -gemm_scalar_type::gemm_scalar_type(scalar_type alphaAB, scalar_type betaC) - : alpha(alphaAB), A(alphaAB), B(alphaAB), beta(betaC), C(betaC) {} -gemm_scalar_type::gemm_scalar_type(scalar_type alpha, scalar_type A, scalar_type B, - scalar_type beta, scalar_type C) - : alpha(alpha), A(A), B(B), beta(beta), C(C) {} - -std::string gemm_configuration::identifier(std::string_view prefix) const { - std::ostringstream oss; - auto const dyn_val = [&oss](std::int64_t v) { - if (v == dynamic) { - oss << "d"; - } else { - oss << v; - } - }; - auto const stride = [&oss, &dyn_val](char X, std::array const &s) { - oss << "_" << X << "stride"; - dyn_val(s[0]); - oss << "_"; - dyn_val(s[1]); - }; - oss << prefix << "_"; - if (atomic) { - oss << "atomic_"; - } - oss << to_string(ty.alpha) << to_string(ty.A) << to_string(ty.B) << to_string(ty.beta) - << to_string(ty.C) << "_A" << to_string(transA) << "_B" << to_string(transB) << "_M"; - dyn_val(M); - oss << "_N"; - dyn_val(N); - oss << "_K"; - dyn_val(K); - stride('A', A_stride); - stride('B', B_stride); - stride('C', C_stride); - auto const format_optional = [&](std::optional const &val) { - if (val) { - auto f = oss.flags(); - auto v = *val; - oss << std::hex << std::bit_cast(v); - oss.flags(f); - } else { - oss << "d"; - } - }; - oss << "_alpha"; - format_optional(alpha); - oss << "_beta"; - format_optional(beta); - return oss.str(); -} - -constexpr static int max_K_unrolling = 8; - -auto max_register_block_gemm(std::uint32_t C_scalar_type_size_in_bytes, std::uint32_t sgs, - std::uint32_t register_space, - std::pair max_fill_fraction) - -> std::pair { - auto const arithmetic_intensity = [&sgs](std::uint32_t row_blocks, std::uint32_t cols) { - return (row_blocks * 
sgs * cols) / static_cast(row_blocks * sgs + cols); - }; - - auto const max_scalars = register_space * max_fill_fraction.first / - (max_fill_fraction.second * C_scalar_type_size_in_bytes); - - // The required number of scalars is given by - // row_blocks * sgs * (cols + max_K_unrolling) + cols * max_K_unrolling - auto const max_row_blocks = [&sgs, &max_scalars](std::uint32_t cols) { - return (max_scalars - cols * max_K_unrolling) / (sgs * (cols + max_K_unrolling)); - }; - auto const max_cols = [&sgs, &max_scalars](std::uint32_t row_blocks) { - return (max_scalars - row_blocks * sgs * max_K_unrolling) / - (row_blocks * sgs + max_K_unrolling); - }; - - double max_ai = 0.0; - std::uint32_t row_blocks = 1, cols = 1; - for (std::uint32_t r = 1; r <= max_row_blocks(1); ++r) { - for (std::uint32_t c = 1; c <= max_cols(r); ++c) { - auto const ai = arithmetic_intensity(r, c); - if (ai > max_ai) { - max_ai = ai; - row_blocks = r; - cols = c; - } - } - } - - return std::make_pair(row_blocks, cols); -} - -class generator { - public: - generator(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, address_space As, address_space Bs, address_space Cs) - : gemm_cfg(gemm_cfg), tiling(tiling), core_cfg(core_cfg), Aspace(As), Bspace(Bs), - Cspace(Cs) {} - void add_microkernel(block_builder &bb, bool is_remainder, expr M, expr N, var A, var B, var C, - expr C_offset, expr alpha, expr beta); - void add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, - expr beta); - void add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta); - ::clir::func function(std::string_view name); - - private: - gemm_configuration const gemm_cfg; - local_tiling const tiling; - core_config const core_cfg; - address_space Aspace, Bspace, Cspace; - unsigned row_blocks_in_register = 1; - unsigned cols_in_register = 1; - var c, m; - std::array MNK; - std::array A_stride, B_stride, C_stride; -}; - -void 
generator::add_microkernel(block_builder &bb, bool is_remainder, expr M, expr N, var A, var B, - var C, expr C_offset, expr alpha, expr beta) { - std::int64_t n_bs = 0; - bool is_N_constant = false; - dispatch_constant_dynamic( - N, - [&](std::int64_t n) { - n_bs = n; - is_N_constant = true; - }, - [&](expr) { - n_bs = static_cast(cols_in_register); - is_N_constant = false; - }); - std::int64_t const n_blocks = - 1 + (n_bs - 1) / static_cast(core_cfg.subgroup_size); - auto n = var("n"); - - auto my_row_blocks_in_register = row_blocks_in_register; - dispatch_constant_dynamic( - M, - [&](std::int64_t m) { - while (my_row_blocks_in_register > 1 && - m < static_cast(my_row_blocks_in_register) * - core_cfg.subgroup_size) { - --my_row_blocks_in_register; - } - }, - [&](expr) {}); - - auto const am = gemm_cfg.transA == transpose::T ? 1 : 0; - auto const ak = gemm_cfg.transA == transpose::T ? 0 : 1; - auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), "Ab", A); - auto const Aoffset = [&](unsigned m_block) { - return A_stride[am] * (m + m_block * core_cfg.subgroup_size); - }; - - auto const bn = gemm_cfg.transB == transpose::T ? 0 : 1; - auto const bk = gemm_cfg.transB == transpose::T ? 
1 : 0; - auto Bb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), "Bb", B); - auto const Boffset = [&](int n_block) { - return B_stride[bn] * (m + n_block * core_cfg.subgroup_size); - }; - - auto const cmn = [&](unsigned m_block, expr n) { - return c[m_block + row_blocks_in_register * std::move(n)]; - }; - - bb.add(for_loop_builder(declaration_assignment(generic_short(), n, 0), n < n_bs, ++n) - .body([&](block_builder &bb) { - for (std::size_t m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - bb.assign(cmn(m_block, n), precision_helper{gemm_cfg.ty.C}.zero()); - } - }) - .attribute(opencl_unroll_hint(n_bs)) - .get_product()); - - auto const compute_c = [&](block_builder &bb, std::int64_t Kb, ::clir::expr K0, - ::clir::expr K1) { - auto kb = var("kb"); - bb.add( - for_loop_builder(declaration_assignment(generic_short(), kb, std::move(K0)), - kb < std::move(K1), add_into(kb, Kb)) - .body([&](block_builder &bb) { - auto at = precision_helper{gemm_cfg.ty.A}; - auto a = bb.declare(array_of(at.type(), my_row_blocks_in_register * Kb), "a"); - auto amk = [&](unsigned m_block, unsigned k) { - return a[m_block + my_row_blocks_in_register * k]; - }; - bool const map_b_to_vec_type = - gemm_cfg.B_stride[bk] == 1 && - (Kb == 2 || Kb == 3 || Kb == 4 || Kb == 8 || Kb == 16); - int k_load_block_size = map_b_to_vec_type ? Kb : 1; - auto bt = precision_helper{gemm_cfg.ty.B}; - auto b = map_b_to_vec_type - ? bb.declare(array_of(bt.type(Kb), n_blocks), "b") - : bb.declare(array_of(bt.type(), n_blocks * Kb), "b"); - auto const read_A = [&](block_builder &bb, unsigned m_block, unsigned k, - bool check) { - auto condition = m + m_block * core_cfg.subgroup_size < M; - auto rhs = Ab[Aoffset(m_block)]; - auto rhs_checked = - check ? 
ternary_conditional(std::move(condition), rhs, 0) : rhs; - bb.assign(amk(m_block, k), std::move(rhs_checked)); - }; - auto block_read_A = [&](block_builder &bb, unsigned m_block, unsigned k) { - bb.assign( - amk(m_block, k), - at.sub_group_block_read(Ab + m_block * core_cfg.subgroup_size, Aspace)); - }; - for (unsigned k = 0; k < Kb; ++k) { - for (unsigned m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - if (!is_remainder && core_cfg.block_read_write_supported && - gemm_cfg.A_stride[am] == 1) { - block_read_A(bb, m_block, k); - } else { - read_A(bb, m_block, k, is_remainder); - } - } - bb.add(add_into(Ab, A_stride[ak])); - } - - auto const read_B = [&](block_builder &bb, int k, int n_block, bool check) { - auto condition = m + n_block * core_cfg.subgroup_size < N; - if (map_b_to_vec_type) { - auto rhs = vload_helper(Kb, 0, Bb + Boffset(n_block)); - if (rhs) { - auto rhs_checked = - check ? ternary_conditional(condition, rhs, - init_vector(bt.type(Kb), {0})) - : rhs; - bb.assign(b[n_block], std::move(rhs_checked)); - } else { - throw std::logic_error("Vload for native type missing"); - } - } else { - auto rhs = Bb[Boffset(n_block)]; - auto rhs_checked = check ? ternary_conditional(condition, rhs, 0) : rhs; - bb.assign(b[k + n_block * Kb], std::move(rhs_checked)); - } - }; - int first_n_block_with_check = - n_bs < n_blocks * static_cast(core_cfg.subgroup_size) - ? 
n_blocks - 1 - : n_blocks; - if (!is_N_constant) { - first_n_block_with_check = 0; - } - for (int k = 0; k < Kb; k += k_load_block_size) { - for (int n_block = 0; n_block < first_n_block_with_check; ++n_block) { - read_B(bb, k, n_block, false); - } - for (int n_block = first_n_block_with_check; n_block < n_blocks; - ++n_block) { - read_B(bb, k, n_block, true); - } - bb.add(add_into(Bb, k_load_block_size * B_stride[bk])); - } - - const int nbb = 4; - for (unsigned m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - for (std::int64_t nb = 0; nb < n_bs; nb += nbb) { - for (int k = 0; k < Kb; ++k) { - for (std::int64_t n = 0; n < nbb; ++n) { - if (nb + n < n_bs) { - auto const n_block = (nb + n) / core_cfg.subgroup_size; - auto const n_offset = (nb + n) % core_cfg.subgroup_size; - auto my_a = amk(m_block, k); - auto bkn = map_b_to_vec_type ? b[n_block].s(k) - : b[k + n_block * Kb]; - auto my_b = sub_group_broadcast(std::move(bkn), n_offset); - auto my_c = cmn(m_block, nb + n); - if (gemm_cfg.ty.A == gemm_cfg.ty.B && - gemm_cfg.ty.B == gemm_cfg.ty.C) { - bb.assign(my_c, - fma(std::move(my_a), std::move(my_b), my_c)); - } else { - bb.add(add_into(std::move(my_c), - std::move(my_a) * std::move(my_b))); - } - } - } - } - } - } - }) - .attribute(opencl_unroll_hint(1)) - .get_product()); - }; - dispatch_constant_dynamic( - MNK[2], - [&](std::int64_t K) { - static_assert(max_K_unrolling % 2 == 0, "max_K_unrolling must be a multiple of 2"); - auto Kb = max_K_unrolling; - while (K < Kb && Kb > 1) { - Kb /= 2; - } - auto KmultipleKb = (K / Kb) * Kb; - compute_c(bb, Kb, 0, KmultipleKb); - if (K - KmultipleKb > 0) { - compute_c(bb, 1, KmultipleKb, K); - } - }, - [&](expr K) { - auto KmultipleKb = bb.declare_assign(generic_uint(), "KmultipleKb", - (K / max_K_unrolling) * max_K_unrolling); - compute_c(bb, max_K_unrolling, 0, KmultipleKb); - bb.add(if_selection_builder(K - KmultipleKb > 0) - .then([&](block_builder &bb) { compute_c(bb, 1, KmultipleKb, K); }) - 
.get_product()); - }); - auto write_C = [&](block_builder &bb) { - auto n_to = is_N_constant ? n_bs : min(N, cast(generic_uint(), n_bs)); - auto n_unroll = is_N_constant ? n_bs : 1; -#if 0 - // We can use block writes if - // 1. They are supported (no block writes for SIMD32 before PVC) - // 2. We are not in a remainder loop - // 3. Data is adjacent in memory - // 4. The address is 16 byte aligned - // 5. We are not writing atomically - bool const use_block_write = - core_cfg.block_read_write_supported && !is_remainder && gemm_cfg.C_stride[0] == 1 && - gemm_cfg.C_stride[1] * size(gemm_cfg.ty.C) % 16 == 0 && !gemm_cfg.atomic -#endif - - // Block writes are disabled for now; would need to track memref alignment in subview / load - // instruction AND would need to impose alignment requirement in calling convention - constexpr bool use_block_write = false; - auto Cb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), "Cb", - C + C_offset); - if (!use_block_write) { - bb.add(add_into(Cb, C_stride[0] * m)); - } - bb.add( - for_loop_builder(declaration_assignment(generic_short(), n, 0), n < std::move(n_to), - ++n) - .body([&](block_builder &bb) { - for (std::size_t m_block = 0; m_block < my_row_blocks_in_register; ++m_block) { - auto my_c = alpha * cmn(m_block, n); - auto C_offset_m = C_stride[0] * (m_block * core_cfg.subgroup_size); - if (use_block_write) { - bb.add(precision_helper{gemm_cfg.ty.C}.sub_group_block_write( - Cb + m_block * core_cfg.subgroup_size, - std::move(my_c) + beta * Cb[std::move(C_offset_m)], Cspace)); - } else { - auto const write_C_mn = [&](block_builder &bb) { - store_helper(bb, gemm_cfg.atomic, Cb + C_offset_m, gemm_cfg.ty.C, - Cspace, my_c, beta); - }; - if (is_remainder) { - bb.add( - if_selection_builder(m + m_block * core_cfg.subgroup_size < M) - .then(write_C_mn) - .get_product()); - } else { - write_C_mn(bb); - } - } - } - bb.add(add_into(Cb, cast(generic_uint(), C_stride[1]))); - }) - 
.attribute(opencl_unroll_hint(n_unroll)) - .get_product()); - }; - write_C(bb); -} - -void generator::add_mloop(block_builder &bb, expr N, var A, var B, var C, expr C_offset, expr alpha, - expr beta) { - auto sg_m = bb.declare_assign(generic_uint(), "sg_m", get_sub_group_id() % tiling.m_tiles()); - tile_loop_by_sgs( - bb, MNK[0], core_cfg.subgroup_size * row_blocks_in_register, tiling.m_tiles(), - std::move(sg_m), - [&](block_builder &bb, expr block, bool is_remainder, expr inner_trip_count) { - auto Astride_m = gemm_cfg.transA == transpose::T ? A_stride[1] : A_stride[0]; - auto Ab = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), - "Ab", A + std::move(Astride_m) * block); - add_microkernel(bb, is_remainder, std::move(inner_trip_count), N, std::move(Ab), B, C, - C_stride[0] * std::move(block) + C_offset, alpha, beta); - }); -} - -void generator::add_function_body(block_builder &bb, var A, var B, var C, expr alpha, expr beta) { - m = bb.declare_assign(generic_uint(), "m", get_sub_group_local_id()); - c = var("c"); - - auto [max_row_blocks, max_cols] = max_register_block_gemm( - size(gemm_cfg.ty.C), core_cfg.subgroup_size, core_cfg.register_space); - row_blocks_in_register = max_row_blocks; - cols_in_register = max_cols; - if (!is_dynamic_value(gemm_cfg.M)) { - auto const row_blocks_needed_to_cover_M = 1 + (gemm_cfg.M - 1) / core_cfg.subgroup_size; - if (row_blocks_needed_to_cover_M < max_row_blocks) { - row_blocks_in_register = row_blocks_needed_to_cover_M; - } else { - auto blocks = gemm_cfg.M / row_blocks_in_register; - auto sg_blocks = 1 + (blocks - 1) / tiling.m_tiles(); - while (sg_blocks < tiling.m_tiles() && row_blocks_in_register >= 2) { - row_blocks_in_register /= 2; - blocks = gemm_cfg.M / row_blocks_in_register; - sg_blocks = 1 + (blocks - 1) / tiling.m_tiles(); - } - } - } - if (!is_dynamic_value(gemm_cfg.N)) { - cols_in_register = - tile_loop_uniformly_max_block_size(gemm_cfg.N, cols_in_register, tiling.n_tiles()); - } - 
bb.declare( - array_of(precision_helper{gemm_cfg.ty.C}.type(), row_blocks_in_register * cols_in_register), - c); - - auto sg_n = bb.declare_assign(generic_uint(), "sg_n", get_sub_group_id() / tiling.m_tiles()); - tile_loop_uniformly( - bb, MNK[1], max_cols, tiling.n_tiles(), std::move(sg_n), - [&](block_builder &bb, expr block, expr inner_trip_count) { - auto Bstride_n = gemm_cfg.transB == transpose::T ? B_stride[0] : B_stride[1]; - auto Bb = bb.declare_assign(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), - "Bb", B + std::move(Bstride_n) * block); - add_mloop(bb, std::move(inner_trip_count), A, std::move(Bb), C, - C_stride[1] * std::move(block), alpha, beta); - }); -} - -::clir::func generator::function(std::string_view name) { - auto A = var("A"); - auto B = var("B"); - auto C = var("C"); - - auto fb = ::clir::function_builder{std::string(name)}; - auto const scalar = [&](precision_helper const &fph, std::optional const &val, - std::string const &prefix) -> expr { - auto v = var{prefix}; - fb.argument(fph.type(), v); - return val ? fph.constant(*val) : v; - }; - auto const shape = [&](std::int64_t shape, expr &target, std::string const &prefix) { - auto v = var{prefix}; - fb.argument(to_clir_ty(scalar_type::index), v); - target = is_dynamic_value(shape) ? expr{std::move(v)} : expr{shape}; - }; - auto const stride = [&](std::array const &stride, std::array &target, - std::string const &prefix) { - for (std::size_t i = 0; i < stride.size(); ++i) { - auto v = var{prefix}; - fb.argument(to_clir_ty(scalar_type::index), v); - target[i] = is_dynamic_value(stride[i]) ? 
expr{std::move(v)} : expr{stride[i]}; - } - }; - - shape(gemm_cfg.M, MNK[0], "M"); - shape(gemm_cfg.N, MNK[1], "N"); - shape(gemm_cfg.K, MNK[2], "K"); - expr alpha = scalar(precision_helper{gemm_cfg.ty.alpha}, gemm_cfg.alpha, "alpha"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.A}.type(Aspace)), A); - stride(gemm_cfg.A_stride, A_stride, "A_stride"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.B}.type(Bspace)), B); - stride(gemm_cfg.B_stride, B_stride, "B_stride"); - expr beta = scalar(precision_helper{gemm_cfg.ty.beta}, gemm_cfg.beta, "beta"); - fb.argument(pointer_to(precision_helper{gemm_cfg.ty.C}.type(Cspace)), C); - stride(gemm_cfg.C_stride, C_stride, "C_stride"); - - fb.body([&](block_builder &bb) { add_function_body(bb, A, B, C, alpha, beta); }); - - auto f = fb.get_product(); - make_names_unique(f); - unsafe_simplify(f); - - return f; -} - -::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, std::string_view name, address_space As, - address_space Bs, address_space Cs) { - return generator{gemm_cfg, tiling, core_cfg, As, Bs, Cs}.function(name); -} - -} // namespace tinytc diff --git a/src/gemm_generator.hpp b/src/gemm_generator.hpp deleted file mode 100644 index 8646ef70..00000000 --- a/src/gemm_generator.hpp +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef GEMM_GENERATOR_20240314_HPP -#define GEMM_GENERATOR_20240314_HPP - -#include "device_info.hpp" -#include "tiling.hpp" -#include "tinytc/types.hpp" - -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace tinytc { - -//! Struct to handle mixed precision GEMMs -struct gemm_scalar_type { - //! alpha, A, B, beta, C all have the same type - gemm_scalar_type(scalar_type ty); - //! 
alpha's, A's, and B's type is different from beta's and C's type - gemm_scalar_type(scalar_type alphaAB, scalar_type betaC); - //! All operands potentially have a different type - gemm_scalar_type(scalar_type alpha, scalar_type A, scalar_type B, scalar_type beta, - scalar_type C); - scalar_type alpha, ///< @f$\alpha@f$ type - A, ///< A element type - B, ///< B element type - beta, ///< @f$\beta@f$ type - C; ///< C element type -}; - -/** - * @brief GEMM configuration struct - * - * The interface supports the operation - * - * C = alpha * opA(A) * opB(B) + beta * C, - * - * where - * - * opA/B(X) = transA/B == T ? X^T : X - * - * C is an MxN matrix, A is a MxK matrix, and B is a KxN matrix. - * - * The address of a matrix is calculated as following. Let X be element of {A,B,C}, then - * - * X(i,j) = X[i * X_stride[0] + j * X_stride[1]] - * - * If the atomic flag is set, C is updated atomically, either using - * - * * beta = 0: atomic store - * * beta = 1: atomic fetch add - * * general beta: atomic compare exchange - */ -struct gemm_configuration { - gemm_scalar_type ty; ///< scalar types of alpha, A, B, beta, C - transpose transA; ///< Transposition of A - transpose transB; ///< Transposition of B - std::int64_t M; ///< M, can be set to dynamic - std::int64_t N; ///< N, can be set to dynamic - std::int64_t K; ///< K, can be set to dynamic - std::array A_stride; ///< stride of A, entries can be set to dynamic - std::array B_stride; ///< stride of B, entries can be set to dynamic - std::array C_stride; ///< stride of C, entries can be set to dynamic - std::optional alpha; ///< fixed alpha if set; dynamic alpha if std::nullopt - std::optional beta; ///< fixed beta if set; dynamic beta if std::nullopt - bool atomic = false; ///< update C atomically - - std::string identifier( - std::string_view prefix = "gemm") const; ///< convert configuration to identification string -}; - -/** - * @brief Generate GEMM - * - * @param gemm_cfg configuration - * @param tiling Size of 
2D subgroup grid - * @param core_cfg Core configuration - * @param name Routine prefix - * @param As Memory space of A (global or local) - * @param Bs Memory space of B (global or local) - * @param Cs Memory space of C (global or local) - * - * @return OpenCL-C AST - */ -::clir::func generate_gemm(gemm_configuration const &gemm_cfg, local_tiling const &tiling, - core_config const &core_cfg, std::string_view name, - ::clir::address_space As = ::clir::address_space::global_t, - ::clir::address_space Bs = ::clir::address_space::global_t, - ::clir::address_space Cs = ::clir::address_space::global_t); - -/** - * @brief Calculate maximum register blocking size of GEMM - * - * @param C_scalar_type_size_in_bytes Size of scalar type of result matrix in bytes - * @param sgs Subgroup size - * @param register_space Size of register file per core in bytes - * @param max_fill_fraction Fraction of register file that shall be blocked at most - * - * @return {number of row-blocks (block size = subgroup size), number of columns} - */ -auto max_register_block_gemm(std::uint32_t C_scalar_type_size_in_bytes, std::uint32_t sgs, - std::uint32_t register_space, - std::pair max_fill_fraction = {1, 2}) - -> std::pair; - -} // namespace tinytc - -#endif // GEMM_GENERATOR_20240314_HPP diff --git a/src/gemm_tools.cpp b/src/gemm_tools.cpp new file mode 100644 index 00000000..5f1fe374 --- /dev/null +++ b/src/gemm_tools.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "gemm_tools.hpp" + +#include + +namespace tinytc { + +auto max_register_block_gemm(std::int32_t A_size, std::int32_t B_size, std::int32_t C_size, + std::int32_t subgroup_size, std::int32_t register_space, + std::int32_t C_blocks, + std::pair max_fill_fraction) + -> std::pair { + auto const arithmetic_intensity = [](std::int32_t rows, std::int32_t cols) { + return (rows * cols) / static_cast(rows + cols); + }; + + auto const max_bytes = register_space * 
max_fill_fraction.first / max_fill_fraction.second; + + constexpr std::int32_t max_K = standard_K_block_sizes.back(); + + // The required number of bytes is given by + // num_bytes = rows * (cols * C_blocks * C_size + max_K * A_size) + cols * max_K * B_size. + // Thus + // rows <= (max_bytes - cols * max_K * B_size) / (cols * C_blocks * C_size + max_K * A_size). + // Moreover, we require rows % subgroup_size = 0, so we set rows = k * subgroup_size and get + // k <= (max_bytes - cols * max_K * B_size) / + // (subgroup_size * (cols * C_blocks * C_size + max_K * A_size)). + auto const max_rows = [&](std::int32_t cols) { + const auto k = (max_bytes - cols * max_K * B_size) / + (subgroup_size * (cols * C_blocks * C_size + max_K * A_size)); + return k * subgroup_size; + }; + // Here, we have + // cols <= (max_bytes - rows * max_K * A_size) / (rows * C_blocks * C_size + max_K * B_size). + auto const max_cols = [&](std::int32_t rows) { + return (max_bytes - rows * max_K * A_size) / (rows * C_blocks * C_size + max_K * B_size); + }; + + double max_ai = 0.0; + std::int32_t rows = subgroup_size, cols = 1; + for (std::int32_t r = subgroup_size; r <= max_rows(1); r += subgroup_size) { + for (std::int32_t c = 1; c <= max_cols(r); ++c) { + auto const ai = arithmetic_intensity(r, c); + if (ai > max_ai) { + max_ai = ai; + rows = r; + cols = c; + } + } + } + + return std::make_pair(rows, cols); +} + +// We have block_size(k) = k * subgroup_size, where k is a positive integer, +// and num_blocks(k) = ceil(size / block_size(k)) +// We want to solve +// max_k block_size(k) s.t. 
+// block_size(k) <= max_block_size ; must not exceed max block size +// and num_blocks(k) % num_tiles == 0 ; no load imbalance +// and block_size(k) - size < sgs ; no excessive block size +// +// If the above optimization does not have a solution, the minimum block size (= subgroup size) is +// returned) +// +auto compute_m_block_size(std::int32_t subgroup_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t { + auto const block_size = [&subgroup_size](std::int32_t k) -> std::int32_t { + return k * subgroup_size; + }; + auto const num_blocks = [&block_size, &size](std::int32_t k) -> std::int64_t { + return 1 + (size - 1) / block_size(k); + }; + std::int32_t k = max_block_size / subgroup_size; + while (k > 1 && (num_blocks(k) % num_tiles != 0 || block_size(k) - size >= subgroup_size)) { + --k; + } + return k * subgroup_size; +} + +auto choose_block_size_multiple(std::int32_t min_block_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t { + auto const block_size = [&min_block_size](std::int32_t k) -> std::int32_t { + return k * min_block_size; + }; + auto const num_blocks = [&block_size, &size](std::int32_t k) -> std::int64_t { + return 1 + (size - 1) / block_size(k); + }; + std::int32_t k = 1; + while (2 * k * min_block_size < max_block_size) { + k *= 2; + } + while (k > 1 && (num_blocks(k) % num_tiles != 0 || block_size(k) - size >= min_block_size)) { + k /= 2; + } + return k; +} + +// Similar as compute_m_block_size for fixed sizes +// block_sizes array must be sorted in ascending order +auto choose_block_size(array_view block_sizes, std::int32_t num_tiles, + std::int64_t size) -> std::int32_t { + auto const num_blocks = [&size](std::int32_t block_size) -> std::int64_t { + return 1 + (size - 1) / block_size; + }; + std::size_t k = block_sizes.size() - 1; + while (k > 1 && (num_blocks(block_sizes[k]) % num_tiles != 0 || + block_sizes[k] - size >= block_sizes[0])) { + --k; + } + 
return block_sizes[k]; +} + +auto choose_k_block_size(array_view block_sizes, std::int64_t K) -> std::int32_t { + std::int64_t j = static_cast(block_sizes.size()) - 1; + for (; K < block_sizes[j] && j > 0; --j) { + } + return block_sizes[j]; +} + +} // namespace tinytc diff --git a/src/gemm_tools.hpp b/src/gemm_tools.hpp new file mode 100644 index 00000000..8d5de16f --- /dev/null +++ b/src/gemm_tools.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GEMM_TOOLS_20241022_HPP +#define GEMM_TOOLS_20241022_HPP + +#include "tinytc/core.hpp" + +#include +#include +#include + +namespace tinytc { + +constexpr static std::array standard_K_block_sizes = {1, 2, 4, 8}; + +/** + * @brief Calculate maximum register blocking size of GEMM + * + * @param A_size Size of scalar type of A matrix in bytes + * @param B_size Size of scalar type of B matrix in bytes + * @param C_size Size of scalar type of result matrix in bytes + * @param subgroup_size Subgroup size + * @param register_space Size of register file per core in bytes + * @param C_blocks Number of register blocks needed for C, usually 1, 2 for complex + * @param max_fill_fraction Fraction of register file that shall be blocked at most + * + * @return {number of rows, number of columns} + */ +auto max_register_block_gemm(std::int32_t A_size, std::int32_t B_size, std::int32_t C_size, + std::int32_t subgroup_size, std::int32_t register_space, + std::int32_t C_blocks = 1, + std::pair max_fill_fraction = {1, 2}) + -> std::pair; + +auto compute_m_block_size(std::int32_t subgroup_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t; +auto choose_block_size_multiple(std::int32_t min_block_size, std::int32_t max_block_size, + std::int32_t num_tiles, std::int64_t size) -> std::int32_t; +auto choose_block_size(array_view block_sizes, std::int32_t num_tiles, + std::int64_t size) -> std::int32_t; +auto choose_k_block_size(array_view 
block_sizes, std::int64_t K) -> std::int32_t; + +} // namespace tinytc + +#endif // GEMM_TOOLS_20241022_HPP diff --git a/src/half.cpp b/src/half.cpp new file mode 100644 index 00000000..edcf38a8 --- /dev/null +++ b/src/half.cpp @@ -0,0 +1,31 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "tinytc/core.h" +#include "tinytc/core.hpp" + +#include +#include + +using namespace tinytc; + +extern "C" { + +uint16_t tinytc_f32_to_f16_as_ui16(float x) { + return ieee754_truncate(std::bit_cast(x)); +} + +float tinytc_f16_as_ui16_to_f32(uint16_t x) { + const auto y = ieee754_extend(x); + return std::bit_cast(y); +} + +uint16_t tinytc_f32_to_bf16_as_ui16(float x) { + return ieee754_truncate(std::bit_cast(x)); +} + +float tinytc_bf16_as_ui16_to_f32(uint16_t x) { + const auto y = ieee754_extend(x); + return std::bit_cast(y); +} +} diff --git a/src/inst.cpp b/src/inst.cpp deleted file mode 100644 index 06f41460..00000000 --- a/src/inst.cpp +++ /dev/null @@ -1,468 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "error.hpp" -#include "location.hpp" -#include "node/inst_node.hpp" -#include "slice.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.h" -#include "tinytc/types.hpp" -#include "util.hpp" - -#include -#include -#include -#include -#include -#include -#include - -using namespace tinytc; - -extern "C" { -char const *tinytc_arithmetic_to_string(tinytc_arithmetic_t op) { - switch (op) { - case tinytc_arithmetic_add: - return "add"; - case tinytc_arithmetic_sub: - return "sub"; - case tinytc_arithmetic_mul: - return "mul"; - case tinytc_arithmetic_div: - return "div"; - case tinytc_arithmetic_rem: - return "rem"; - case tinytc_arithmetic_shl: - return "shl"; - case tinytc_arithmetic_shr: - return "shr"; - case tinytc_arithmetic_and: - return "and"; - case tinytc_arithmetic_or: - return "or"; - case tinytc_arithmetic_xor: - return "xor"; - } 
- return "unknown"; -} - -char const *tinytc_arithmetic_unary_to_string(tinytc_arithmetic_unary_t op) { - switch (op) { - case tinytc_arithmetic_unary_neg: - return "neg"; - case tinytc_arithmetic_unary_not: - return "not"; - } - return "unknown"; -} - -char const *tinytc_cmp_condition_to_string(tinytc_cmp_condition_t cond) { - switch (cond) { - case tinytc_cmp_condition_eq: - return "eq"; - case tinytc_cmp_condition_ne: - return "ne"; - case tinytc_cmp_condition_gt: - return "gt"; - case tinytc_cmp_condition_ge: - return "ge"; - case tinytc_cmp_condition_lt: - return "lt"; - case tinytc_cmp_condition_le: - return "le"; - } - return "unknown"; -} - -char const *tinytc_transpose_to_string(tinytc_transpose_t t) { - switch (t) { - case tinytc_transpose_T: - return "t"; - case tinytc_transpose_N: - return "n"; - } - return "unknown"; -} - -tinytc_status_t tinytc_arith_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_t op, - tinytc_value_t a, tinytc_value_t b, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), value(a, true), - value(b, true), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_arith_unary_inst_create(tinytc_inst_t *instr, tinytc_arithmetic_unary_t op, - tinytc_value_t a, const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(op), value(a, true), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_cast_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - tinytc_scalar_type_t to_ty, const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), enum_cast(to_ty), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t 
tinytc_cmp_inst_create(tinytc_inst_t *instr, tinytc_cmp_condition_t cond, - tinytc_value_t a, tinytc_value_t b, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(cond), value(a, true), - value(b, true), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_alloca_inst_create(tinytc_inst_t *instr, tinytc_data_type_t ty, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(data_type(ty, true), get_optional(loc)).release(); - }); -} - -tinytc_status_t tinytc_axpby_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_bool_t atomic, tinytc_value_t alpha, - tinytc_value_t A, tinytc_value_t beta, tinytc_value_t B, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(beta, true), value(B, true), - bool(atomic), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_expand_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, - uint32_t expand_shape_size, tinytc_value_t *expand_shape, - const tinytc_location_t *loc) { - if (instr == nullptr || expand_shape == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto eshape_vec = std::vector(); - eshape_vec.reserve(expand_shape_size); - for (uint32_t i = 0; i < expand_shape_size; ++i) { - eshape_vec.emplace_back(value(expand_shape[i], true)); - } - *instr = std::make_unique(value(a, true), mode, std::move(eshape_vec), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_fuse_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t from, - int64_t to, const tinytc_location_t *loc) { 
- if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), from, to, get_optional(loc)).release(); - }); -} - -tinytc_status_t tinytc_load_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t index_list_size, tinytc_value_t *index_list, - const tinytc_location_t *loc) { - if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto il_vec = std::vector(); - il_vec.reserve(index_list_size); - for (uint32_t i = 0; i < index_list_size; ++i) { - il_vec.emplace_back(value(index_list[i], true)); - } - *instr = std::make_unique(value(a, true), std::move(il_vec), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_group_id_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); -} - -tinytc_status_t tinytc_group_size_inst_create(tinytc_inst_t *instr, const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *instr = std::make_unique(get_optional(loc)).release(); }); -} - -tinytc_status_t tinytc_gemm_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_transpose_t tB, tinytc_bool_t atomic, - tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, - tinytc_value_t beta, tinytc_value_t C, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), enum_cast(tB), - value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t 
tinytc_gemv_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_bool_t atomic, tinytc_value_t alpha, - tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, - tinytc_value_t C, const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(B, true), value(beta, true), - value(C, true), bool(atomic), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_ger_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomic, - tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, - tinytc_value_t beta, tinytc_value_t C, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_hadamard_inst_create(tinytc_inst_t *instr, tinytc_bool_t atomic, - tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t B, tinytc_value_t beta, tinytc_value_t C, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(value(alpha, true), value(A, true), value(B, true), - value(beta, true), value(C, true), bool(atomic), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_size_inst_create(tinytc_inst_t *instr, tinytc_value_t a, int64_t mode, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(value(a, true), mode, get_optional(loc)).release(); - }); -} - -tinytc_status_t tinytc_subview_inst_create(tinytc_inst_t *instr, tinytc_value_t a, - uint32_t slice_list_size, 
tinytc_value_t *offset_list, - tinytc_value_t *size_list, - const tinytc_location_t *loc) { - if (instr == nullptr || - (slice_list_size > 0 && (offset_list == nullptr || size_list == nullptr))) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto slice_vec = std::vector(); - slice_vec.reserve(slice_list_size); - for (uint32_t i = 0; i < slice_list_size; ++i) { - slice_vec.emplace_back(value(offset_list[i], true), value(size_list[i], true)); - } - *instr = - std::make_unique(value(a, true), std::move(slice_vec), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_store_inst_create(tinytc_inst_t *instr, tinytc_value_t val, tinytc_value_t a, - uint32_t index_list_size, tinytc_value_t *index_list, - const tinytc_location_t *loc) { - if (instr == nullptr || (index_list_size > 0 && index_list == nullptr)) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto il_vec = std::vector(); - il_vec.reserve(index_list_size); - for (uint32_t i = 0; i < index_list_size; ++i) { - il_vec.emplace_back(value(index_list[i], true)); - } - *instr = std::make_unique(value(val, true), value(a, true), std::move(il_vec), - get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_sum_inst_create(tinytc_inst_t *instr, tinytc_transpose_t tA, - tinytc_bool_t atomic, tinytc_value_t alpha, tinytc_value_t A, - tinytc_value_t beta, tinytc_value_t B, - const tinytc_location_t *loc) { - if (instr == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = std::make_unique(enum_cast(tA), value(alpha, true), - value(A, true), value(beta, true), value(B, true), - bool(atomic), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_for_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, tinytc_value_t step, - tinytc_region_t body, const tinytc_location_t *loc) { - if (instr == nullptr 
|| loop_var == nullptr || from == nullptr || to == nullptr || - body == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = - std::make_unique(value(loop_var, true), value(from, true), value(to, true), - value(step, true), region(body, true), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_foreach_inst_create(tinytc_inst_t *instr, tinytc_value_t loop_var, - tinytc_value_t from, tinytc_value_t to, - tinytc_region_t body, const tinytc_location_t *loc) { - if (instr == nullptr || loop_var == nullptr || from == nullptr || to == nullptr || - body == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *instr = - std::make_unique(value(loop_var, true), value(from, true), - value(to, true), region(body, true), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_if_inst_create(tinytc_inst_t *instr, tinytc_value_t condition, - tinytc_region_t then, tinytc_region_t otherwise, - uint32_t return_type_list_size, - tinytc_scalar_type_t *return_type_list, - const tinytc_location_t *loc) { - if (instr == nullptr || condition == nullptr || then == nullptr || - (return_type_list_size > 0 && return_type_list == nullptr)) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto rt = std::vector(); - rt.reserve(return_type_list_size); - for (uint32_t i = 0; i < return_type_list_size; ++i) { - rt.emplace_back(enum_cast(return_type_list[i])); - } - *instr = - std::make_unique(value(condition, true), region(then, true), - region(otherwise, true), std::move(rt), get_optional(loc)) - .release(); - }); -} - -tinytc_status_t tinytc_yield_inst_create(tinytc_inst_t *instr, uint32_t yield_list_size, - tinytc_value_t *yield_list, const tinytc_location_t *loc) { - if (instr == nullptr || yield_list_size == 0 || yield_list == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - 
auto yl = std::vector(); - yl.reserve(yield_list_size); - for (uint32_t i = 0; i < yield_list_size; ++i) { - yl.emplace_back(value(yield_list[i], true)); - } - *instr = std::make_unique(std::move(yl), get_optional(loc)).release(); - }); -} - -tinytc_status_t tinytc_inst_release(tinytc_inst_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_inst_retain(tinytc_inst_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} - -tinytc_status_t tinytc_inst_get_value(const_tinytc_inst_t instr, tinytc_value_t *result) { - if (instr == nullptr || result == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *result = instr->result().release(); }); -} - -tinytc_status_t tinytc_inst_get_values(const_tinytc_inst_t instr, uint32_t *result_list_size, - tinytc_value_t *result_list) { - if (instr == nullptr || result_list_size == nullptr || - (*result_list_size > 0 && result_list == nullptr)) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto const num_results = instr->num_results(); - if (num_results > std::numeric_limits::max()) { - throw std::out_of_range("too many results"); - } - auto const num = static_cast(num_results); - if (*result_list_size > 0) { - auto results = instr->results(); - if (results.size() != num_results) { - throw internal_compiler_error(); - } - auto const limit = std::min(num, *result_list_size); - for (uint32_t i = 0; i < limit; ++i) { - result_list[i] = results[i].release(); - } - } - *result_list_size = num; - }); -} -} diff --git a/src/matrix_ext_info.cpp b/src/matrix_ext_info.cpp new file mode 100644 index 00000000..ac4c328d --- /dev/null +++ b/src/matrix_ext_info.cpp @@ -0,0 +1,169 @@ +// Copyright (C) 2024 Intel Corporation 
+// SPDX-License-Identifier: BSD-3-Clause + +#include "matrix_ext_info.hpp" +#include "node/type.hpp" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +matrix_ext_type::matrix_ext_type(TK a, TK b, std::vector acc, std::vector mnk) + : a_{a}, b_{b}, acc_(std::move(acc)), mnk_(std::move(mnk)) {} + +auto matrix_ext_type::have_acc(TK acc) const -> bool { + return std::find(acc_.begin(), acc_.end(), acc) != acc_.end(); +} +auto matrix_ext_type::have_type(TK sty, std::int64_t rows, std::int64_t cols, matrix_use use) const + -> bool { + auto find_shape = [&rows, &cols](array_view mnks, std::int64_t gemm_mnk::*shape0, + std::int64_t gemm_mnk::*shape1) { + for (auto const &mnk : mnks) { + if (mnk.*shape0 == rows && mnk.*shape1 == cols) { + return true; + } + } + return false; + }; + switch (use) { + case matrix_use::a: + return a() == sty && find_shape(mnk(), &gemm_mnk::M, &gemm_mnk::K); + case matrix_use::b: + return b() == sty && find_shape(mnk(), &gemm_mnk::K, &gemm_mnk::N); + case matrix_use::acc: + return have_acc(sty) && find_shape(mnk(), &gemm_mnk::M, &gemm_mnk::N); + } + throw status::internal_compiler_error; +} + +template +auto block_sizes(std::vector const &mnks, Get get) -> std::vector { + auto bs = std::vector{}; + bs.reserve(mnks.size()); + for (auto &mnk : mnks) { + auto val = get(mnk); + if (val) { + bs.push_back(*val); + } + } + std::sort(bs.begin(), bs.end()); + auto end = std::unique(bs.begin(), bs.end()); + bs.erase(end, bs.end()); + return bs; +} + +auto matrix_ext_type::M_block_sizes() const -> std::vector { + return block_sizes(mnk_, + [](gemm_mnk const &mnk) -> std::optional { return mnk.M; }); +} +auto matrix_ext_type::N_block_sizes(std::int32_t M) const -> std::vector { + return block_sizes(mnk_, [&M](gemm_mnk const &mnk) -> std::optional { + return mnk.M == M ? 
std::make_optional(mnk.N) : std::nullopt; + }); +} +auto matrix_ext_type::K_block_sizes(std::int32_t M, std::int32_t N) const + -> std::vector { + return block_sizes(mnk_, [&M, N](gemm_mnk const &mnk) -> std::optional { + return mnk.M == M && mnk.N == N ? std::make_optional(mnk.K) : std::nullopt; + }); +} + +auto matrix_ext_info::get_precision(TK a, TK b, TK acc) const -> matrix_ext_type const * { + for (auto const &mat_type : mat_types_) { + if (mat_type.a() == a && mat_type.b() == b && mat_type.have_acc(acc)) { + return &mat_type; + } + } + return nullptr; +} + +auto matrix_ext_info::have_gemm(TK a, TK b, TK c, TK d, std::int64_t M, std::int64_t N, + std::int64_t K) const -> bool { + for (auto const &mat_type : mat_types_) { + if (mat_type.have_type(a, M, K, matrix_use::a) && + mat_type.have_type(b, K, N, matrix_use::b) && + mat_type.have_type(c, M, N, matrix_use::acc) && + mat_type.have_type(d, M, N, matrix_use::acc)) { + return true; + } + } + return false; +} + +auto matrix_ext_info::have_precision(TK a, TK b, TK acc) const -> bool { + return get_precision(a, b, acc) != nullptr; +} + +auto matrix_ext_info::have_type(TK sty, std::int64_t rows, std::int64_t cols, matrix_use use) const + -> bool { + for (auto const &mat_type : mat_types_) { + if (mat_type.have_type(sty, rows, cols, use)) { + return true; + } + } + return false; +} + +auto matrix_ext_info::have_type(const coopmatrix_type *ty) const -> bool { + return have_type(ty->component_ty()->type_id(), ty->rows(), ty->cols(), ty->use()); +} + +const std::array pvc_matrix_ext_types = {{{TK::TK_i8, + TK::TK_i8, + {TK::TK_i32}, + {{16, 8, 32}, + {32, 8, 32}, + {64, 8, 32}, + {16, 16, 32}, + {32, 16, 32}, + {64, 16, 32}, + {16, 32, 32}, + {32, 32, 32}, + {64, 32, 32}, + {16, 8, 64}, + {32, 8, 64}, + {64, 8, 64}, + {16, 16, 64}, + {32, 16, 64}, + {64, 16, 64}, + {16, 32, 64}, + {32, 32, 64}, + {64, 32, 64}}}, + {TK::TK_f16, + TK::TK_f16, + {TK::TK_f16, TK::TK_f32}, + {{16, 8, 16}, + {32, 8, 16}, + {16, 16, 16}, + 
{32, 16, 16}, + {16, 24, 16}, + {32, 24, 16}, + {16, 32, 16}, + {32, 32, 16}, + {16, 8, 32}, + {32, 8, 32}, + {16, 16, 32}, + {32, 16, 32}, + {16, 24, 32}, + {32, 24, 32}, + {16, 32, 32}, + {32, 32, 32}}}, + {TK::TK_bf16, + TK::TK_bf16, + {TK::TK_bf16, TK::TK_f32}, + {{16, 8, 16}, + {32, 8, 16}, + {16, 16, 16}, + {32, 16, 16}, + {16, 32, 16}, + {32, 32, 16}, + {16, 8, 32}, + {32, 8, 32}, + {16, 16, 32}, + {32, 16, 32}, + {16, 32, 32}, + {32, 32, 32}}}}}; + +} // namespace tinytc diff --git a/src/matrix_ext_info.hpp b/src/matrix_ext_info.hpp new file mode 100644 index 00000000..318dee8b --- /dev/null +++ b/src/matrix_ext_info.hpp @@ -0,0 +1,84 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef MATRIX_EXT_INFO_20241204_HPP +#define MATRIX_EXT_INFO_20241204_HPP + +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +enum class TK; + +struct gemm_mnk { + std::int64_t M, N, K; +}; + +class matrix_ext_type { + public: + matrix_ext_type(TK a, TK b, std::vector acc, std::vector mnk); + + inline auto a() const -> TK { return a_; } + inline auto b() const -> TK { return b_; } + inline auto acc() const -> array_view { return acc_; } + inline auto mnk() const -> array_view { return mnk_; } + + auto M_block_sizes() const -> std::vector; + auto N_block_sizes(std::int32_t M) const -> std::vector; + auto K_block_sizes(std::int32_t M, std::int32_t N) const -> std::vector; + + auto have_acc(TK acc) const -> bool; + auto have_type(TK sty, std::int64_t rows, std::int64_t cols, matrix_use use) const -> bool; + + private: + TK a_, b_; + std::vector acc_; + std::vector mnk_; +}; + +struct matrix_ext_block_io_info { + std::int32_t base_address_alignment; + std::int32_t min_stride; + std::int32_t max_stride; + std::int32_t pos0_alignment; + std::int32_t stride_alignment; + std::int32_t width_alignment; +}; + +class matrix_ext_info { + public: + matrix_ext_info() = default; + 
inline matrix_ext_info(std::int32_t required_subgroup_size, matrix_ext_block_io_info block_io, + array_view mat_types) + : required_sgs_{required_subgroup_size}, block_io_{block_io}, + mat_types_(std::move(mat_types)) {} + + auto get_precision(TK a, TK b, TK acc) const -> matrix_ext_type const *; + auto have_gemm(TK a, TK b, TK c, TK d, std::int64_t M, std::int64_t N, std::int64_t K) const + -> bool; + auto have_precision(TK a, TK b, TK acc) const -> bool; + auto have_type(TK sty, std::int64_t rows, std::int64_t cols, matrix_use use) const -> bool; + auto have_type(const coopmatrix_type *ty) const -> bool; + + inline auto required_subgroup_size() const -> std::int32_t { return required_sgs_; } + inline auto block_io() const -> matrix_ext_block_io_info const & { return block_io_; } + + inline auto have_dpas() const { return mat_types_.size() > 0; } + + private: + std::int32_t required_sgs_; + matrix_ext_block_io_info block_io_; + array_view mat_types_; +}; + +extern const std::array pvc_matrix_ext_types; + +} // namespace tinytc + +#endif // MATRIX_EXT_INFO_20241204_HPP diff --git a/src/node/attr.cpp b/src/node/attr.cpp new file mode 100644 index 00000000..9cebc5f4 --- /dev/null +++ b/src/node/attr.cpp @@ -0,0 +1,229 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/attr.hpp" +#include "compiler_context.hpp" +#include "compiler_context_cache.hpp" +#include "error.hpp" +#include "support/fnv1a_array_view.hpp" // IWYU pragma: keep +#include "tinytc/builder.h" +#include "tinytc/core.h" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "util/casting.hpp" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include + +using namespace tinytc; + +namespace tinytc { + +auto array_attr::get(tinytc_compiler_context_t ctx, array_view values) + -> tinytc_attr_t { + const auto hash = fnv1a_combine(values); + const auto is_equal = [&](tinytc_attr_t a) { + const auto aa = dyn_cast(a); 
+ return aa && + std::equal(values.begin(), values.end(), aa->values().begin(), aa->values().end()); + }; + const auto make = [&]() { return new array_attr(ctx, values); }; + + auto &attrs = ctx->cache()->array_attrs; + return attrs.get(hash, std::move(is_equal), std::move(make)); +} + +array_attr::array_attr(tinytc_compiler_context_t ctx, std::vector values) + : tinytc_attr(AK::AK_array, ctx), values_{std::move(values)} {} + +auto boolean_attr::get(tinytc_compiler_context_t ctx, bool value) -> tinytc_attr_t { + auto cache = ctx->cache(); + return value ? cache->true_attr.get() : cache->false_attr.get(); +} + +boolean_attr::boolean_attr(tinytc_compiler_context_t ctx, bool value) + : tinytc_attr(AK::AK_boolean, ctx), value_{value} {} + +auto dictionary_attr::get(tinytc_compiler_context_t ctx, + array_view sorted_attrs) -> tinytc_attr_t { + const auto hash = [&] { + auto h = fnv1a0(); + for (auto const &na : sorted_attrs) { + h = fnv1a_step(h, na.name); + h = fnv1a_step(h, na.attr); + } + return h; + }; + const auto is_equal = [&](tinytc_attr_t a) { + const auto da = dyn_cast(a); + return da && std::equal(sorted_attrs.begin(), sorted_attrs.end(), da->begin(), da->end(), + [](tinytc_named_attr_t const &a, tinytc_named_attr_t const &b) { + return a.name == b.name && a.attr == b.attr; + }); + }; + const auto make = [&]() { return new dictionary_attr(ctx, sorted_attrs); }; + + auto &attrs = ctx->cache()->dictionary_attrs; + return attrs.get(hash(), std::move(is_equal), std::move(make)); +} + +auto dictionary_attr::get_name_string(tinytc_attr_t name) -> std::string_view { + auto stra = dyn_cast(name); + if (stra) { + return stra->str(); + } + throw status::ir_expected_string_attribute; +} + +void dictionary_attr::sort(mutable_array_view unsorted_attrs) { + if (unsorted_attrs.empty()) { + return; + } + std::sort(unsorted_attrs.begin(), unsorted_attrs.end(), + [](tinytc_named_attr_t const &a, tinytc_named_attr_t const &b) { + return get_name_string(a.name) < 
get_name_string(b.name); + }); + for (std::size_t i = 1; i < unsorted_attrs.size(); ++i) { + if (unsorted_attrs[i - 1].name == unsorted_attrs[i].name) { + throw status::ir_duplicate_key_in_dictionary; + } + } +} + +dictionary_attr::dictionary_attr(tinytc_compiler_context_t ctx, + std::vector sorted_attrs) + : tinytc_attr(AK::AK_dictionary, ctx), attrs_{std::move(sorted_attrs)} {} + +auto dictionary_attr::find(tinytc_attr_t name) -> tinytc_attr_t { + if (!attrs_.empty() && name) { + auto namestr = get_name_string(name); + std::size_t lb = 0; + std::size_t ub = attrs_.size(); + + do { + std::size_t mid = (lb + ub) / 2; + auto cmp = namestr.compare(get_name_string(attrs_[mid].name)); + if (cmp == 0) { + return attrs_[mid].attr; + } else if (cmp < 0) { + ub = mid; + } else { + lb = mid + 1; + } + } while (ub > lb); + } + return nullptr; +} +auto dictionary_attr::find(std::string_view name) -> tinytc_attr_t { + return find(string_attr::get(context(), name)); +} + +auto integer_attr::get(tinytc_compiler_context_t ctx, std::int64_t value) -> tinytc_attr_t { + const auto hash = fnv1a_combine(value); + const auto is_equal = [&](tinytc_attr_t a) { + const auto ia = dyn_cast(a); + return ia && value == ia->value(); + }; + const auto make = [&]() { return new integer_attr(ctx, value); }; + + auto &attrs = ctx->cache()->integer_attrs; + return attrs.get(hash, std::move(is_equal), std::move(make)); +} + +integer_attr::integer_attr(tinytc_compiler_context_t ctx, std::int64_t value) + : tinytc_attr(AK::AK_integer, ctx), value_{value} {} + +auto string_attr::get(tinytc_compiler_context_t ctx, std::string_view str) -> tinytc_attr_t { + const auto hash = fnv1a_combine(str); + const auto is_equal = [&](tinytc_attr_t a) { + const auto stra = dyn_cast(a); + return stra && str == stra->str(); + }; + const auto make = [&]() { return new string_attr(ctx, std::string{str}); }; + + auto &attrs = ctx->cache()->string_attrs; + return attrs.get(hash, std::move(is_equal), std::move(make)); +} + 
+string_attr::string_attr(tinytc_compiler_context_t ctx, std::string str) + : tinytc_attr(AK::AK_string, ctx), str_{std::move(str)} {} + +auto get_attr(tinytc_attr_t dict, tinytc_attr_t name) -> tinytc_attr_t { + if (auto da = dyn_cast(dict); da) { + return da->find(name); + } + return nullptr; +} +auto get_attr(tinytc_attr_t dict, std::string_view name) -> tinytc_attr_t { + if (auto da = dyn_cast(dict); da) { + return da->find(name); + } + return nullptr; +} + +} // namespace tinytc + +extern "C" { + +tinytc_status_t tinytc_array_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + size_t array_size, const tinytc_attr_t *array) { + if (attr == nullptr || ctx == nullptr || (array_size != 0 && array == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *attr = array_attr::get(ctx, array_view(array, array_size)); }); +} + +tinytc_status_t tinytc_boolean_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + tinytc_bool_t value) { + if (attr == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *attr = boolean_attr::get(ctx, value); }); +} + +tinytc_status_t tinytc_dictionary_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + size_t items_size, tinytc_named_attr_t *items) { + TINYTC_CHECK_STATUS(tinytc_dictionary_attr_sort(items_size, items)); + return tinytc_dictionary_attr_get_with_sorted(attr, ctx, items_size, items); +} + +tinytc_status_t tinytc_dictionary_attr_get_with_sorted(tinytc_attr_t *attr, + tinytc_compiler_context_t ctx, + size_t items_size, + const tinytc_named_attr_t *items) { + if (attr == nullptr || ctx == nullptr || (items_size != 0 && items == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *attr = dictionary_attr::get(ctx, array_view(items, items_size)); + }); +} + +tinytc_status_t tinytc_dictionary_attr_sort(size_t items_size, tinytc_named_attr_t *items) 
{ + return exception_to_status_code( + [&] { dictionary_attr::sort(mutable_array_view(items, items_size)); }); +} + +tinytc_status_t tinytc_integer_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + int64_t value) { + if (attr == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *attr = integer_attr::get(ctx, value); }); +} + +tinytc_status_t tinytc_string_attr_get(tinytc_attr_t *attr, tinytc_compiler_context_t ctx, + size_t str_length, char const *str) { + if (attr == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *attr = string_attr::get(ctx, std::string_view(str, str_length)); }); +} +} diff --git a/src/node/attr.hpp b/src/node/attr.hpp new file mode 100644 index 00000000..02b642a3 --- /dev/null +++ b/src/node/attr.hpp @@ -0,0 +1,151 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ATTR_20250626_HPP +#define ATTR_20250626_HPP + +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { +enum class AK; +} // namespace tinytc + +struct tinytc_attr { + public: + inline tinytc_attr(tinytc::AK tid, tinytc_compiler_context_t ctx) : tid_(tid), ctx_(ctx) {} + virtual ~tinytc_attr() = default; + inline auto type_id() const -> tinytc::AK { return tid_; } + inline auto context() const -> tinytc_compiler_context_t { return ctx_; } + + private: + tinytc::AK tid_; + tinytc_compiler_context_t ctx_; +}; + +namespace tinytc { + +enum class AK { AK_array, AK_boolean, AK_dictionary, AK_integer, AK_string }; + +class array_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::AK_array; } + static auto get(tinytc_compiler_context_t ctx, array_view values) + -> tinytc_attr_t; + + inline auto 
begin() const -> std::vector::const_iterator { + return values_.begin(); + } + inline auto end() const -> std::vector::const_iterator { return values_.end(); } + inline auto size() const { return values_.size(); } + inline auto const &values() const { return values_; } + inline auto value(std::size_t i) const -> tinytc_attr_t { return values_[i]; } + + protected: + array_attr(tinytc_compiler_context_t ctx, std::vector values); + + private: + std::vector values_; +}; + +class boolean_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::AK_boolean; } + static auto get(tinytc_compiler_context_t ctx, bool value) -> tinytc_attr_t; + + inline auto value() const { return value_; } + + protected: + boolean_attr(tinytc_compiler_context_t ctx, bool value); + friend class compiler_context_cache; + + private: + bool value_; +}; + +class dictionary_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::AK_dictionary; } + static auto get(tinytc_compiler_context_t ctx, array_view sorted_attrs) + -> tinytc_attr_t; + static void sort(mutable_array_view unsorted_attrs); + + inline auto begin() const -> std::vector::const_iterator { + return attrs_.begin(); + } + inline auto end() const -> std::vector::const_iterator { + return attrs_.end(); + } + inline auto const &attrs() const { return attrs_; } + + auto find(tinytc_attr_t name) -> tinytc_attr_t; + auto find(std::string_view name) -> tinytc_attr_t; + + protected: + dictionary_attr(tinytc_compiler_context_t ctx, std::vector sorted_attrs); + + private: + static auto get_name_string(tinytc_attr_t name) -> std::string_view; + + std::vector attrs_; +}; + +class integer_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::AK_integer; } + static auto get(tinytc_compiler_context_t ctx, std::int64_t value) -> tinytc_attr_t; + + inline auto value() const 
-> std::int64_t { return value_; } + + protected: + integer_attr(tinytc_compiler_context_t ctx, std::int64_t value); + + private: + std::int64_t value_; +}; + +class string_attr : public tinytc_attr { + public: + inline static bool classof(tinytc_attr const &a) { return a.type_id() == AK::AK_string; } + static auto get(tinytc_compiler_context_t ctx, std::string_view str) -> tinytc_attr_t; + + inline auto str() const -> std::string_view { return str_; } + + protected: + string_attr(tinytc_compiler_context_t ctx, std::string str); + + private: + std::string str_; +}; + +auto get_attr(tinytc_attr_t dict, tinytc_attr_t name) -> tinytc_attr_t; +auto get_attr(tinytc_attr_t dict, std::string_view name) -> tinytc_attr_t; + +template auto get_array_attr_as(tinytc_attr_t a) -> std::vector { + auto aa = dyn_cast(a); + if (!aa) { + throw status::ir_expected_array_attribute; + } + auto result = std::vector{}; + result.reserve(aa->size()); + for (auto const &va : *aa) { + auto val = dyn_cast_or_throw(va, [&] { + return status::ir_expected_integer_attribute; + })->value(); + result.emplace_back(static_cast(val)); + } + return result; +} + +} // namespace tinytc + +#endif // ATTR_20250626_HPP diff --git a/src/node/data_type_node.cpp b/src/node/data_type_node.cpp deleted file mode 100644 index df570f75..00000000 --- a/src/node/data_type_node.cpp +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "node/data_type_node.hpp" -#include "error.hpp" -#include "tinytc/types.hpp" - -#include -#include - -namespace tinytc { - -memref_data_type::memref_data_type(scalar_type type, std::vector shape, - std::vector stride, location const &lc) - : element_ty_(std::move(type)), shape_(std::move(shape)), stride_(std::move(stride)) { - loc(lc); - for (auto const &s : shape_) { - if (s < 0 && !is_dynamic_value(s)) { - throw compilation_error(loc(), status::ir_invalid_shape); - } - } - if (stride_.empty()) { - stride_ = 
canonical_stride(); - } else { - for (auto const &s : stride_) { - if (s < 0 && !is_dynamic_value(s)) { - throw compilation_error(loc(), status::ir_invalid_shape); - } - } - } - if (stride_.size() != shape_.size()) { - throw compilation_error(loc(), status::ir_shape_stride_mismatch); - } -} - -auto memref_data_type::canonical_stride() const -> std::vector { - if (shape_.empty()) { - return {}; - } - auto stride = std::vector(shape_.size(), dynamic); - stride[0] = 1; - for (std::size_t i = 0; i < shape_.size() - 1 && !is_dynamic_value(shape_[i]); ++i) { - stride[i + 1] = stride[i] * shape_[i]; - } - return stride; -} - -} // namespace tinytc diff --git a/src/node/data_type_node.hpp b/src/node/data_type_node.hpp deleted file mode 100644 index 8321554c..00000000 --- a/src/node/data_type_node.hpp +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef DATA_TYPE_NODE_20230309_HPP -#define DATA_TYPE_NODE_20230309_HPP - -#include "reference_counted.hpp" -#include "scalar_type.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include -#include -#include - -#include -#include -#include -#include - -namespace tinytc { -using data_type_nodes = clir::virtual_type_list; -} - -struct tinytc_data_type : tinytc::reference_counted, tinytc::data_type_nodes { - public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - - private: - tinytc::location loc_; -}; - -namespace tinytc { - -using data_type_node = ::tinytc_data_type; - -class group_data_type : public clir::visitable { - public: - inline group_data_type(data_type ty, std::int64_t offset = 0, location const &lc = {}) - : ty_(std::move(ty)), offset_(offset) { - loc(lc); - } - - inline auto ty() const -> data_type const & { return ty_; } - inline auto offset() const -> std::int64_t { return offset_; } - - private: - data_type ty_; - 
std::int64_t offset_; -}; - -class void_data_type : public clir::visitable {}; - -class memref_data_type : public clir::visitable { - public: - memref_data_type(scalar_type type, std::vector shape, - std::vector stride = {}, location const &lc = {}); - - inline scalar_type element_ty() const { return element_ty_; } - inline clir::data_type clir_element_ty() const { return to_clir_ty(element_ty_, addrspace_); } - inline clir::data_type clir_atomic_element_ty() const { - return to_clir_atomic_ty(element_ty_, addrspace_); - } - inline std::int64_t dim() const { return shape_.size(); } - inline auto const &shape() const { return shape_; } - inline std::int64_t shape(std::int64_t i) const { return shape_[i]; } - inline auto const &stride() const { return stride_; } - inline std::int64_t stride(std::int64_t i) const { return stride_[i]; } - inline std::int64_t size_in_bytes() const { - return is_dynamic() ? dynamic : size(element_ty_) * stride_.back() * shape_.back(); - } - inline clir::address_space addrspace() const { return addrspace_; } - inline void addrspace(clir::address_space space) { addrspace_ = space; } - - inline bool is_dynamic_shape() const { - return std::any_of(shape_.begin(), shape_.end(), is_dynamic_value); - } - inline bool is_dynamic_stride() const { - return std::any_of(stride_.begin(), stride_.end(), is_dynamic_value); - } - inline bool is_dynamic() const { return is_dynamic_shape() || is_dynamic_stride(); } - inline bool is_canonical_stride() const { return stride_ == canonical_stride(); } - - private: - auto canonical_stride() const -> std::vector; - - scalar_type element_ty_; - std::vector shape_, stride_; - clir::address_space addrspace_ = clir::address_space::global_t; -}; - -class scalar_data_type : public clir::visitable { - public: - inline scalar_data_type(scalar_type type, location const &lc) : ty_(type) { loc(lc); } - - inline scalar_type ty() const { return ty_; } - inline clir::data_type clir_ty() const { return to_clir_ty(ty_); } - - 
private: - scalar_type ty_; -}; - -} // namespace tinytc - -#endif // DATA_TYPE_NODE_20230309_HPP diff --git a/src/node/func.cpp b/src/node/func.cpp new file mode 100644 index 00000000..11a4e980 --- /dev/null +++ b/src/node/func.cpp @@ -0,0 +1,116 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/func.hpp" +#include "error.hpp" +#include "location.hpp" +#include "node/attr.hpp" +#include "tinytc/builder.h" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" + +#include +#include +#include +#include +#include + +using namespace tinytc; + +tinytc_func::tinytc_func(std::string name, tinytc::array_view params, + tinytc_type_t ty, tinytc_location const &lc) + : name_(std::move(name)), ty_{ty}, loc_{lc} { + body_.kind(tinytc::region_kind::collective); + body_.loc(loc_); + body_.set_params(std::move(params)); +} + +void tinytc_func::param_attr(std::size_t param_no, tinytc_attr_t a) { + if (param_no >= num_params()) { + throw compilation_error(loc(), status::invalid_arguments); + } + if (param_attr_.size() != num_params()) { + param_attr_.resize(num_params(), nullptr); + } + param_attr_[param_no] = a; +} +auto tinytc_func::param_attr(std::size_t param_no) const -> tinytc_attr_t { + if (param_no >= num_params()) { + throw compilation_error(loc(), status::invalid_arguments); + } + if (param_attr_.empty()) { + return nullptr; + } + return param_attr_[param_no]; +} + +auto tinytc_func::subgroup_size() const -> std::int32_t { + if (auto sgs_attr = get_attr(attr_, "subgroup_size"); sgs_attr) { + auto sgs = dyn_cast_or_throw(sgs_attr, [&] { + return compilation_error(loc_, status::ir_expected_integer_attribute); + }); + return sgs->value(); + } + throw compilation_error(loc_, status::internal_compiler_error, "Subgroup size is missing"); +} + +auto tinytc_func::work_group_size() const -> std::array { + if (auto wgs_attr = get_attr(attr_, "work_group_size"); wgs_attr) 
{ + auto wgs_array = dyn_cast_or_throw( + wgs_attr, [&] { return compilation_error(loc_, status::ir_expected_array_attribute); }); + if (wgs_array->size() != 2) { + throw compilation_error(loc_, status::ir_unexpected_array_attribute_size, + "Work group size attribute must have 2 entries"); + } + auto wgs = std::array{}; + for (std::size_t i = 0; i < 2; ++i) { + wgs[i] = dyn_cast_or_throw(wgs_array->value(i), [&] { + return compilation_error(loc_, status::ir_expected_integer_attribute); + })->value(); + } + return wgs; + } + throw compilation_error(loc_, status::internal_compiler_error, "Work group size is missing"); +} + +extern "C" { + +tinytc_status_t tinytc_func_create(tinytc_func_t *fun, size_t name_length, char const *name, + size_t num_params, const tinytc_type_t *param_type_list, + tinytc_type_t ty, const tinytc_location_t *loc) { + if (fun == nullptr || (num_params > 0 && param_type_list == nullptr) || ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + *fun = std::make_unique(std::string(name, name_length), + array_view(param_type_list, num_params), ty, + get_optional(loc)) + .release(); + }); +} + +tinytc_status_t tinytc_func_set_parameter_attr(tinytc_func_t fun, size_t arg_no, tinytc_attr_t a) { + if (fun == nullptr || arg_no >= fun->num_params()) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { fun->param_attr(arg_no, a); }); +} + +tinytc_status_t tinytc_func_set_attr(tinytc_func_t fun, tinytc_attr_t a) { + if (fun == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { fun->attr(a); }); +} + +tinytc_status_t tinytc_func_get_body(tinytc_func_t fun, tinytc_region_t *body) { + if (fun == nullptr || body == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *body = &fun->body(); }); +} + +void tinytc_func_destroy(tinytc_func_t obj) { delete obj; } +} diff --git 
a/src/node/func.hpp b/src/node/func.hpp new file mode 100644 index 00000000..0c256b1e --- /dev/null +++ b/src/node/func.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FUNC_20250626_HPP +#define FUNC_20250626_HPP + +#include "node/region.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" + +#include +#include +#include +#include +#include +#include + +struct tinytc_func final { + public: + tinytc_func(std::string name, tinytc::array_view params, tinytc_type_t ty, + tinytc_location const &lc = {}); + + inline auto loc() const noexcept -> tinytc_location const & { return loc_; } + inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } + + inline auto ty() const noexcept -> tinytc_type_t { return ty_; } + + inline auto params() { return body_.params(); } + inline auto params() const { return body_.params(); } + inline auto num_params() const noexcept { return body_.num_params(); } + + inline auto name() const -> std::string_view { return name_; } + inline auto body() -> tinytc_region & { return body_; } + inline auto body() const -> tinytc_region const & { return body_; } + + inline void attr(tinytc_attr_t a) { attr_ = a; } + inline auto attr() const -> tinytc_attr_t { return attr_; } + + void param_attr(std::size_t param_no, tinytc_attr_t a); + auto param_attr(std::size_t param_no) const -> tinytc_attr_t; + + auto subgroup_size() const -> std::int32_t; + auto work_group_size() const -> std::array; + + private: + std::string name_; + tinytc_type_t ty_; + tinytc_region body_; + tinytc_location loc_; + tinytc_attr_t attr_ = nullptr; + std::vector param_attr_; +}; + +#endif // FUNC_20250626_HPP diff --git a/src/node/function_node.hpp b/src/node/function_node.hpp deleted file mode 100644 index aa964558..00000000 --- a/src/node/function_node.hpp +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef 
FUNCTION_NODE_20230310_HPP -#define FUNCTION_NODE_20230310_HPP - -#include "location.hpp" -#include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" - -#include - -#include -#include -#include -#include -#include - -namespace tinytc { -using function_nodes = clir::virtual_type_list; -} - -struct tinytc_func : tinytc::reference_counted, tinytc::function_nodes { - public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - - virtual auto name() const -> std::string_view = 0; - - private: - tinytc::location loc_; -}; - -namespace tinytc { - -using function_node = ::tinytc_func; - -class prototype : public clir::visitable { - public: - inline prototype(std::string name, std::vector args = {}, location const &lc = {}) - : name_(std::move(name)), args_(std::move(args)) { - loc(lc); - } - - inline auto name() const -> std::string_view override { return name_; } - inline auto args() const -> std::vector const & { return args_; } - - private: - std::string name_; - std::vector args_; -}; - -class function : public clir::visitable { - public: - inline function(func prototype, region body, location const &lc = {}) - : prototype_(std::move(prototype)), body_(std::move(body)), work_group_size_{0, 0}, - subgroup_size_{0} { - loc(lc); - } - - inline auto name() const -> std::string_view override { return prototype_->name(); } - - inline auto prototype() const -> func const & { return prototype_; } - inline auto body() const -> region const & { return body_; } - inline auto work_group_size() const -> std::array { return work_group_size_; } - - inline void work_group_size(std::array const &work_group_size) { - work_group_size_ = work_group_size; - } - inline auto subgroup_size() const -> std::int32_t { return subgroup_size_; } - inline void subgroup_size(std::int32_t subgroup_size) { subgroup_size_ = subgroup_size; } - - private: - func prototype_; - region body_; - std::array 
work_group_size_; - std::int32_t subgroup_size_; -}; - -} // namespace tinytc - -#endif // FUNCTION_NODE_20230310_HPP diff --git a/src/node/inst.cpp b/src/node/inst.cpp new file mode 100644 index 00000000..5c2abf72 --- /dev/null +++ b/src/node/inst.cpp @@ -0,0 +1,218 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/inst.hpp" +#include "error.hpp" +#include "node/region.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "tinytc/builder.h" +#include "tinytc/types.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tinytc; + +static_assert(alignof(tinytc_value) == alignof(tinytc_inst)); +static_assert(alignof(use) == alignof(tinytc_inst)); +static_assert(alignof(tinytc_region) == alignof(tinytc_inst)); +static_assert(alignof(tinytc_inst) <= alignof(std::max_align_t)); + +auto tinytc_inst::create(IK tid, inst_layout layout, tinytc_location const &lc) -> tinytc_inst_t { + std::size_t size = 0; + size += sizeof(tinytc_value) * layout.num_results; + size += sizeof(tinytc_inst); + size += sizeof(use) * layout.num_operands; + size += layout.sizeof_properties; + size += sizeof(tinytc_region) * layout.num_child_regions; + + struct del { + void operator()(std::uint8_t *p) const { std::free(p); } + }; + auto raw_mem = + std::unique_ptr(static_cast(std::malloc(size))); + if (raw_mem.get() == nullptr) { + throw status::bad_alloc; + } + + // initialize results + tinytc_value_t first_result = reinterpret_cast(raw_mem.get()); + tinytc_value_t last_result = first_result + layout.num_results; + for (; first_result != last_result; ++first_result) { + new (first_result) tinytc_value(); + } + + // initialize inst + tinytc_inst_t in = reinterpret_cast(last_result); + new (in) tinytc_inst(tid, layout, lc); + + // initialize uses + use *first_use = reinterpret_cast(in + 1); + use *last_use = first_use + layout.num_operands; + for (; 
first_use != last_use; ++first_use) { + new (first_use) use(in); + } + + // properties + std::uint8_t *first_prop = reinterpret_cast(last_use); + std::uint8_t *last_prop = first_prop + layout.sizeof_properties; + if (layout.sizeof_properties > 0) { + visit(overloaded{[&](auto view) { + std::construct_at(reinterpret_cast(first_prop)); + }}, + *in); + } + + // child regions + tinytc_region_t first_region = reinterpret_cast(last_prop); + tinytc_region_t last_region = first_region + layout.num_child_regions; + for (; first_region != last_region; ++first_region) { + new (first_region) tinytc_region(in); + } + + raw_mem.release(); + return in; +} + +void tinytc_inst::destroy(tinytc_inst_t in) { + void *raw_mem = reinterpret_cast(in) - in->layout_.num_results; + in->~tinytc_inst(); + std::free(raw_mem); +} + +tinytc_inst::~tinytc_inst() { + // child regions + for (tinytc_region_t r = child_region_ptr(0); r != child_region_ptr(layout_.num_child_regions); + ++r) { + std::destroy_at(r); + } + + // properties + if (layout_.sizeof_properties > 0) { + visit(overloaded{[&](auto view) { + std::destroy_at(static_cast(props())); + }}, + *this); + } + + // uses + for (use *u = use_ptr(0); u != use_ptr(layout_.num_operands); ++u) { + std::destroy_at(u); + } + + // results + for (tinytc_value_t r = result_ptr(0); r != result_ptr(layout_.num_results); --r) { + std::destroy_at(r); + } +} + +auto tinytc_inst::context() -> tinytc_compiler_context_t { + if (num_results() > 0) { + return result(0).context(); + } else if (num_operands() > 0) { + return op(0).context(); + } + return nullptr; +} + +void tinytc_inst::subs(tinytc_value_t old_value, tinytc_value_t new_value, bool recursive) { + for (auto u = use_ptr(0); u != use_ptr(layout_.num_operands); ++u) { + if (u->get() == old_value) { + u->set(new_value); + } + } + if (recursive) { + for (auto ® : child_regions()) { + for (auto &in : reg) { + in.subs(old_value, new_value, true); + } + } + } +} + +void tinytc_inst::op(std::size_t pos, 
tinytc_value_t val) { + if (val == nullptr) { + throw compilation_error(loc(), status::invalid_arguments); + } + *use_ptr(pos) = val; +} + +void tinytc_inst::result(std::size_t pos, tinytc_type_t ty) { + if (ty == nullptr) { + throw compilation_error(loc(), status::invalid_arguments); + } + *result_ptr(pos) = tinytc_value{ty, this, loc()}; +} + +extern "C" { + +void tinytc_inst_destroy(tinytc_inst_t obj) { tinytc_inst::destroy(obj); } + +tinytc_status_t tinytc_inst_get_parent_region(tinytc_inst_t instr, tinytc_region_t *parent) { + if (instr == nullptr || parent == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *parent = instr->parent(); }); +} + +tinytc_status_t tinytc_inst_get_values(tinytc_inst_t instr, size_t *result_list_size, + tinytc_value_t *result_list) { + if (instr == nullptr || result_list_size == nullptr || + (*result_list_size > 0 && result_list == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const num_results = instr->num_results(); + if (num_results < 0) { + throw std::out_of_range("Number of results must not be negative"); + } + auto num = static_cast(num_results); + if (*result_list_size > 0) { + num = std::min(num, *result_list_size); + auto results = instr->result_begin(); + for (uint32_t i = 0; i < num; ++i) { + result_list[i] = &results[i]; + } + } + *result_list_size = num; + }); +} + +tinytc_status_t tinytc_inst_get_regions(tinytc_inst_t instr, size_t *result_list_size, + tinytc_region_t *result_list) { + if (instr == nullptr || result_list_size == nullptr || + (*result_list_size > 0 && result_list == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const num_results = instr->num_child_regions(); + if (num_results < 0) { + throw std::out_of_range("Number of results must not be negative"); + } + auto num = static_cast(num_results); + if (*result_list_size > 0) { + auto results = 
instr->child_regions_begin(); + num = std::min(num, *result_list_size); + for (uint32_t i = 0; i < num; ++i) { + result_list[i] = &results[i]; + } + } + *result_list_size = num; + }); +} + +tinytc_status_t tinytc_inst_set_attr(tinytc_inst_t instr, tinytc_attr_t a) { + if (instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { instr->attr(a); }); +} +} diff --git a/src/node/inst.hpp b/src/node/inst.hpp new file mode 100644 index 00000000..a453d21a --- /dev/null +++ b/src/node/inst.hpp @@ -0,0 +1,134 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INST_20250626_HPP +#define INST_20250626_HPP + +#include "node/region.hpp" +#include "node/value.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +//! Instruction classification +enum class inst_execution_kind { + mixed, ///< mixed instruction on uniform or varying data + collective, ///< collective instruction on uniform data, distributed among work-items + spmd ///< SPMD instruction on varying data +}; + +using result_iterator = std::reverse_iterator; +using result_range = iterator_range_wrapper; + +using op_iterator = + tinytc::indirect_random_access_iterator; +using op_range = tinytc::iterator_range_wrapper; +static_assert(std::random_access_iterator); +static_assert(std::ranges::random_access_range); + +using region_range = iterator_range_wrapper; + +struct inst_layout { + std::int32_t num_results; + std::int32_t num_operands; + std::uint32_t sizeof_properties; + std::int32_t num_child_regions; +}; + +enum class IK; + +} // namespace tinytc + +struct alignas(8) tinytc_inst : tinytc::ilist_node_with_parent { + public: + static auto create(tinytc::IK tid, tinytc::inst_layout layout, tinytc_location const &lc) + -> tinytc_inst_t; + static void destroy(tinytc_inst_t in); + + auto 
context() -> tinytc_compiler_context_t; + inline auto type_id() const -> tinytc::IK { return tid_; } + + inline auto attr() const noexcept -> tinytc_attr_t { return attr_; } + inline void attr(tinytc_attr_t attr) noexcept { attr_ = attr; } + + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } + inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + + // Iterator over operands + inline auto op_begin() -> tinytc::op_iterator { return {use_ptr(0)}; } + inline auto op_end() -> tinytc::op_iterator { return {use_ptr(layout_.num_operands)}; } + inline auto operands() -> tinytc::op_range { return {op_begin(), op_end()}; } + inline auto op(std::size_t pos) -> tinytc_value & { return *use_ptr(pos)->get(); } + inline auto get_use(std::size_t pos) -> tinytc::use & { return *use_ptr(pos); } + inline auto num_operands() const -> std::int32_t { return layout_.num_operands; } + void op(std::size_t pos, tinytc_value_t val); + + void subs(tinytc_value_t old_value, tinytc_value_t new_value, bool recursive = true); + + // Iterator over results + inline auto result_begin() -> tinytc::result_iterator { + return tinytc::result_iterator(reinterpret_cast(this)); + } + inline auto result_end() -> tinytc::result_iterator { + return tinytc::result_iterator(reinterpret_cast(this) - + layout_.num_results); + } + inline auto results() -> tinytc::result_range { return {result_begin(), result_end()}; } + inline auto result(std::size_t pos) -> tinytc_value & { return *result_ptr(pos); } + inline auto num_results() const -> std::int32_t { return layout_.num_results; } + void result(std::size_t pos, tinytc_type_t ty); + + // Properties + inline auto props() -> void * { return use_ptr(layout_.num_operands); } + + // Iterator over regions + inline auto child_regions_begin() -> tinytc_region_t { return child_region_ptr(0); } + inline auto child_regions_end() -> tinytc_region_t { + return child_region_ptr(layout_.num_child_regions); + } + inline auto 
child_regions() -> tinytc::region_range { + return tinytc::region_range{child_regions_begin(), child_regions_end()}; + } + auto child_region(std::size_t pos) -> tinytc_region & { return *child_region_ptr(pos); } + auto num_child_regions() const -> std::int32_t { return layout_.num_child_regions; } + + auto kind() -> tinytc::inst_execution_kind; + + auto layout() const -> tinytc::inst_layout const & { return layout_; } + + private: + inline tinytc_inst(tinytc::IK tid, tinytc::inst_layout layout, tinytc_location const &lc) + : tid_(tid), layout_(layout), loc_{lc} {} + ~tinytc_inst(); + + tinytc_inst(tinytc_inst const &other) = delete; + tinytc_inst(tinytc_inst &&other) = delete; + tinytc_inst &operator=(tinytc_inst const &other) = delete; + tinytc_inst &operator=(tinytc_inst &&other) = delete; + + inline auto result_ptr(std::int32_t no) -> tinytc_value_t { + return reinterpret_cast(this) - ++no; + } + inline auto use_ptr(std::int32_t no) -> tinytc::use * { + return reinterpret_cast(this + 1) + no; + } + inline auto child_region_ptr(std::int32_t no) -> tinytc_region_t { + std::uint8_t *props_end = static_cast(props()) + layout_.sizeof_properties; + return reinterpret_cast(props_end) + no; + } + + tinytc::IK tid_; + tinytc::inst_layout layout_; + tinytc::location loc_; + tinytc_attr_t attr_ = nullptr; +}; + +#endif // INST_20250626_HPP diff --git a/src/node/inst_kind.cpp.mochi b/src/node/inst_kind.cpp.mochi new file mode 100644 index 00000000..11c1f2e1 --- /dev/null +++ b/src/node/inst_kind.cpp.mochi @@ -0,0 +1,10 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "error.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" + +using namespace tinytc; + +// もち inst_kind_cpp "tinytc/instructions.anko" diff --git a/src/node/inst_node.cpp b/src/node/inst_node.cpp deleted file mode 100644 index 0fa368ea..00000000 --- a/src/node/inst_node.cpp +++ /dev/null @@ -1,533 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// 
SPDX-License-Identifier: BSD-3-Clause - -#include "node/inst_node.hpp" -#include "error.hpp" -#include "node/data_type_node.hpp" -#include "node/value_node.hpp" -#include "scalar_type.hpp" -#include "tinytc/types.hpp" -#include "util.hpp" - -#include -#include - -#include -#include -#include - -namespace tinytc { - -scalar_data_type *get_scalar_type(location const &loc, value &v) { - auto m = dynamic_cast(v->ty().get()); - if (m == nullptr) { - throw compilation_error(loc, status::ir_expected_scalar); - } - return m; -} - -memref_data_type *get_memref_type(location const &loc, value &v) { - auto m = dynamic_cast(v->ty().get()); - if (m == nullptr) { - throw compilation_error(loc, status::ir_expected_memref); - } - return m; -} - -blas_a2_inst::blas_a2_inst(value alpha, value A, value beta, value B, bool atomic) - : alpha_(std::move(alpha)), A_(std::move(A)), beta_(std::move(beta)), B_(std::move(B)), - atomic_(atomic) {} - -blas_a3_inst::blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic) - : alpha_(std::move(alpha)), A_(std::move(A)), B_(std::move(B)), beta_(std::move(beta)), - C_(std::move(C)), atomic_(atomic) {} - -loop_inst::loop_inst(value loop_var, value from, value to, region body, location const &lc) - : loop_inst(std::move(loop_var), std::move(from), std::move(to), {}, std::move(body), lc) {} - -loop_inst::loop_inst(value loop_var, value from, value to, value step, region body, - location const &lc) - : loop_var_(std::move(loop_var)), from_(std::move(from)), to_(std::move(to)), - step_(std::move(step)), body_(std::move(body)) { - loc(lc); - auto lvt = get_scalar_type(loc(), loop_var_); - auto fromt = get_scalar_type(loc(), from_); - auto tot = get_scalar_type(loc(), to_); - bool step_ok = true; - if (step_) { - auto stept = get_scalar_type(loc(), step_); - step_ok = lvt->ty() == stept->ty(); - } - - if (lvt->ty() != fromt->ty() || lvt->ty() != tot->ty() || !step_ok) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - 
} -} - -alloca_inst::alloca_inst(data_type ty, location const &lc) - : result_{make_value(std::move(ty))}, stack_ptr_{-1} { - loc(lc); - auto memref = dynamic_cast(result_->ty().get()); - if (memref == nullptr) { - throw compilation_error(loc(), status::ir_expected_memref); - } - memref->addrspace(clir::address_space::local_t); -} - -axpby_inst::axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(beta), std::move(B), atomic), tA_(tA) { - loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - - bool shape_equal = false; - if (tA_ == transpose::T && a->dim() == 2 && b->dim() == 2) { - shape_equal = a->shape()[1] == b->shape()[0] && a->shape()[0] == b->shape()[1]; - } else { - shape_equal = a->shape() == b->shape(); - } - - if (!shape_equal) { - throw compilation_error(loc(), status::ir_incompatible_shapes); - } - - if (b->dim() > 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix); - } -} - -arith_inst::arith_inst(arithmetic op, value a, value b, location const &lc) - : op_(op), a_(std::move(a)), b_(std::move(b)) { - loc(lc); - - auto at = get_scalar_type(loc(), a_); - auto bt = get_scalar_type(loc(), b_); - - if (at->ty() != bt->ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - } - bool inst_supports_fp = false; - switch (op) { - case arithmetic::add: - case arithmetic::sub: - case arithmetic::mul: - case arithmetic::div: - case arithmetic::rem: - inst_supports_fp = true; - break; - case arithmetic::shl: - case arithmetic::shr: - case arithmetic::and_: - case arithmetic::or_: - case arithmetic::xor_: - inst_supports_fp = false; - break; - } - if (!inst_supports_fp && is_floating_type(at->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); - } - result_ = make_value(at->ty()); -} - -arith_unary_inst::arith_unary_inst(arithmetic_unary op, value a, location const &lc) - : 
op_(op), a_(std::move(a)) { - loc(lc); - - auto at = get_scalar_type(loc(), a_); - bool inst_supports_fp = false; - switch (op) { - case arithmetic_unary::neg: - inst_supports_fp = true; - break; - case arithmetic_unary::not_: - inst_supports_fp = false; - break; - } - if (!inst_supports_fp && is_floating_type(at->ty())) { - throw compilation_error(loc(), status::ir_fp_unsupported); - } - result_ = make_value(at->ty()); -} - -cast_inst::cast_inst(value a, scalar_type to_ty, location const &lc) - : a_(std::move(a)), result_{make_value(to_ty)} { - loc(lc); -} - -compare_inst::compare_inst(cmp_condition cond, value a, value b, location const &lc) - : cond_(cond), a_(std::move(a)), b_(std::move(b)), result_{make_value(scalar_type::i1)} { - loc(lc); - - auto at = get_scalar_type(loc(), a_); - auto bt = get_scalar_type(loc(), b_); - - if (at->ty() != bt->ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - } -} - -gemm_inst::gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, - bool atomic, location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic), - tA_(tA), tB_(tB) { - loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - - if (a->dim() != 2 || b->dim() != 2 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "gemm only supported for memref of order 2 (matrices)"); - } - - auto ak = tA_ == transpose::T ? 0 : 1; - auto bk = tB_ == transpose::T ? 
1 : 0; - auto M = c->shape(0); - auto N = c->shape(1); - auto K = a->shape(ak); - if (a->shape(1 - ak) != M || b->shape(bk) != K || b->shape(1 - bk) != N) { - std::ostringstream oss; - oss << "Got "; - oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; - oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; - oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - -gemv_inst::gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic), - tA_(tA) { - loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - - if (a->dim() != 2 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "gemv only supports matrix-vector products"); - } - - auto ak = tA_ == transpose::T ? 
0 : 1; - auto M = c->shape(0); - auto K = a->shape(ak); - if (a->shape(1 - ak) != M || b->shape(0) != K) { - std::ostringstream oss; - oss << "Got "; - oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "c=" << c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - -ger_inst::ger_inst(value alpha, value A, value B, value beta, value C, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic) { - loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 2) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "ger requires two vectors as input and one matrix as output"); - } - - auto M = c->shape(0); - auto N = c->shape(1); - if (a->shape(0) != M || b->shape(0) != N) { - std::ostringstream oss; - oss << "Got "; - oss << "a=" << a->shape(0) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "C=" << c->shape(0) << "x" << c->shape(1); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - -hadamard_inst::hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(B), std::move(beta), std::move(C), atomic) { - loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - auto c = get_memref_type(loc(), C_); - - if (a->dim() != 1 || b->dim() != 1 || c->dim() != 1) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix, - "hadamard requires two vectors as input and one vector as output"); - } - - auto M = c->shape(0); - if (a->shape(0) != M || b->shape(0) != M) { - std::ostringstream oss; - oss << "Got "; - oss << "a=" << a->shape(0) << ", "; - oss << "b=" << b->shape(0) << ", "; - oss << "c=" 
<< c->shape(0); - throw compilation_error(loc(), status::ir_incompatible_shapes, oss.str()); - } -} - -expand_inst::expand_inst(value op, std::int64_t mode, std::vector expand_shape, - location const &lc) - : op_(std::move(op)), mode_(mode), expand_shape_(std::move(expand_shape)) { - loc(lc); - - auto m = get_memref_type(loc(), op_); - bool const range_ok = 0 <= mode_ && mode_ < m->dim(); - if (!range_ok) { - throw compilation_error(loc(), status::ir_out_of_bounds); - } - - if (expand_shape_.size() < 2) { - throw compilation_error(loc(), status::ir_expand_shape_order_too_small); - } - - auto known_expand_shape = std::vector(); - known_expand_shape.reserve(expand_shape_.size()); - std::size_t dyn_count = 0, non_imm_count = 0; - for (auto &s : expand_shape_) { - visit(overloaded{[&](int_imm &i) { - if (is_dynamic_value(i.value())) { - known_expand_shape.push_back(dynamic); - ++dyn_count; - return; - } - if (i.value() < 0) { - throw compilation_error(loc(), status::ir_invalid_shape); - } - known_expand_shape.push_back(i.value()); - }, - [&](auto &) { - known_expand_shape.push_back(dynamic); - ++non_imm_count; - }}, - *s); - } - - if (dyn_count > 1) { - throw compilation_error(loc(), status::ir_multiple_dynamic_modes); - } - - auto size = m->shape(mode_); - if (!is_dynamic_value(size) && non_imm_count == 0) { - std::int64_t prod = 1; - std::int64_t dyn_mode = -1; - for (std::size_t i = 0; i < known_expand_shape.size(); ++i) { - auto const s = known_expand_shape[i]; - if (is_dynamic_value(s)) { - dyn_mode = i; - } else { - prod *= s; - } - } - if (dyn_mode >= 0) { - std::int64_t const s = size / prod; - known_expand_shape[dyn_mode] = s; - expand_shape_[dyn_mode] = make_imm(s); - prod *= s; - } - if (prod != size) { - throw compilation_error(loc(), status::ir_expand_shape_mismatch); - } - } - - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim() + known_expand_shape.size() - 1); - stride.reserve(m->dim() + known_expand_shape.size() - 1); - 
for (std::int64_t i = 0; i < mode_; ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); - } - - stride.push_back(m->stride(mode_)); - shape.push_back(known_expand_shape[0]); - for (std::size_t j = 1; j < known_expand_shape.size(); ++j) { - stride.push_back(is_dynamic_value(stride.back()) || is_dynamic_value(shape.back()) - ? dynamic - : stride.back() * shape.back()); - shape.push_back(known_expand_shape[j]); - } - for (std::int64_t i = mode_ + 1; i < m->dim(); ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); - } - auto r = std::make_unique(m->element_ty(), shape, stride); - - r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); -} - -fuse_inst::fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc) - : op_(std::move(op)), from_(from), to_(to) { - loc(lc); - auto m = get_memref_type(loc(), op_); - bool const range_ok = 0 <= from_ && from_ < to_ && to_ < m->dim(); - if (!range_ok) { - throw compilation_error(loc(), status::ir_out_of_bounds); - } - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim()); - stride.reserve(m->dim()); - std::int64_t i = 0; - for (; i < from_; ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); - } - std::int64_t prod = 1; - for (; i <= to_; ++i) { - if (is_dynamic_value(m->shape(i))) { - prod = dynamic; - break; - } - prod *= m->shape(i); - } - shape.push_back(prod); - stride.push_back(m->stride(from_)); - for (i = to_ + 1; i < m->dim(); ++i) { - shape.push_back(m->shape(i)); - stride.push_back(m->stride(i)); - } - auto r = std::make_unique(m->element_ty(), shape, stride); - - r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); -} - -if_inst::if_inst(value condition, region then, region otherwise, - std::vector const &return_types, location const &lc) - : condition_(std::move(condition)), then_(std::move(then)), otherwise_(std::move(otherwise)) { - loc(lc); - for (auto &ty 
: return_types) { - results_.push_back(make_value(ty)); - } -} - -load_inst::load_inst(value op, std::vector index_list, location const &lc) - : op_(std::move(op)), index_list_(std::move(index_list)) { - loc(lc); - visit(overloaded{ - [&](group_data_type &g) { - if (static_cast(index_list_.size()) != 1) { - throw compilation_error(loc(), status::ir_invalid_number_of_indices); - } - result_ = make_value(g.ty()); - }, - [&](memref_data_type &m) { - if (m.dim() != static_cast(index_list_.size())) { - throw compilation_error(loc(), status::ir_invalid_number_of_indices); - } - result_ = make_value(m.element_ty()); - }, - [&](auto &) { throw compilation_error(loc(), status::ir_expected_memref_or_group); }}, - *op_->ty()); -} - -size_inst::size_inst(value op, std::int64_t mode, location const &lc) - : op_(std::move(op)), mode_(mode) { - loc(lc); - auto m = get_memref_type(loc(), op_); - bool const range_ok = 0 <= mode_ && mode_ < m->dim(); - if (!range_ok) { - throw compilation_error(loc(), status::ir_out_of_bounds); - } - - result_ = make_value(scalar_type::index); -} - -subview_inst::subview_inst(value op, std::vector slices, location const &lc) - : op_(std::move(op)), slices_(std::move(slices)) { - loc(lc); - auto m = get_memref_type(loc(), op_); - if (m->dim() != static_cast(slices_.size())) { - throw compilation_error(loc(), status::ir_invalid_number_of_indices); - } - - auto shape = std::vector{}; - auto stride = std::vector{}; - shape.reserve(m->dim()); - stride.reserve(m->dim()); - for (std::int64_t i = 0; i < m->dim(); ++i) { - auto &slice = slices_[i]; - visit(overloaded{[&](int_imm &i) { - if (i.value() < 0) { - throw compilation_error(loc(), status::ir_invalid_slice); - } - }, - [](auto &) {}}, - *slice.first); - if (slice.second) { // if size is given - visit(overloaded{[&](int_imm &i) { - if (i.value() < 1 && !is_dynamic_value(i.value())) { - throw compilation_error(loc(), status::ir_invalid_slice); - } - }, - [](auto &) {}}, - *slice.second); - auto size = 
visit(overloaded{[&](int_imm &offset, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return is_dynamic_value(m->shape(i)) - ? dynamic - : m->shape(i) - offset.value(); - } - return size.value(); - }, - [&](val &, int_imm &size) -> std::int64_t { - if (is_dynamic_value(size.value())) { - return dynamic; - } - return size.value(); - }, - [](auto &, auto &) -> std::int64_t { return dynamic; }}, - *slice.first, *slice.second); - shape.push_back(size); - stride.push_back(m->stride(i)); - } - } - auto r = std::make_unique(m->element_ty(), shape, stride); - - r->addrspace(m->addrspace()); - result_ = make_value(data_type(r.release())); -} - -store_inst::store_inst(value val, value op, std::vector index_list, location const &lc) - : val_(std::move(val)), op_(std::move(op)), index_list_(std::move(index_list)) { - loc(lc); - auto v = get_scalar_type(loc(), val_); - auto o = get_memref_type(loc(), op_); - - if (v->ty() != o->element_ty()) { - throw compilation_error(loc(), status::ir_scalar_mismatch); - } - - if (o->dim() != static_cast(index_list_.size())) { - throw compilation_error(loc(), status::ir_invalid_number_of_indices); - } -} - -sum_inst::sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic, - location const &lc) - : super(std::move(alpha), std::move(A), std::move(beta), std::move(B), atomic), tA_(tA) { - loc(lc); - auto a = get_memref_type(loc(), A_); - auto b = get_memref_type(loc(), B_); - - bool const size_ok = (a->dim() == 2 && b->dim() == 1) || (a->dim() == 1 && b->dim() == 0); - if (!size_ok) { - throw compilation_error(loc(), status::ir_expected_vector_or_matrix); - } - - if (a->dim() == 2) { - if (a->shape(tA_ == transpose::T ? 
1 : 0) != b->shape(0)) { - throw compilation_error(loc(), status::ir_incompatible_shapes); - } - } -} - -} // namespace tinytc diff --git a/src/node/inst_node.hpp b/src/node/inst_node.hpp deleted file mode 100644 index 1dc54846..00000000 --- a/src/node/inst_node.hpp +++ /dev/null @@ -1,434 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef INST_NODE_20230327_HPP -#define INST_NODE_20230327_HPP - -#include "reference_counted.hpp" -#include "slice.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" - -#include - -#include -#include -#include -#include - -namespace tinytc { - -//! Instruction classification -enum class inst_kind { - replicated, ///< replicated instruction executed in every work-item - collective ///< collective instruction distributed among work-items -}; - -using inst_nodes = - clir::virtual_type_list; - -} // namespace tinytc - -struct tinytc_inst : tinytc::reference_counted, tinytc::inst_nodes { - public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - - virtual tinytc::value result() const = 0; - inline virtual auto results() const -> std::vector { - if (auto r = result(); r) { - return {std::move(r)}; - } - return {}; - } - inline virtual auto num_results() const -> std::size_t { return result() ? 
1u : 0u; } - virtual tinytc::inst_kind kind() const = 0; - - private: - tinytc::location loc_; -}; - -namespace tinytc { - -using inst_node = ::tinytc_inst; - -class scalar_inst : public inst_node {}; - -class blas_a2_inst : public inst_node { - public: - blas_a2_inst(value alpha, value A, value beta, value B, bool atomic); - - inline bool atomic() const { return atomic_; } - inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return alpha_; } - inline auto A() const -> value const & { return A_; } - inline auto beta() const -> value const & { return beta_; } - inline auto B() const -> value const & { return B_; } - inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } - - protected: - value alpha_, A_, beta_, B_; - bool atomic_; -}; - -class blas_a3_inst : public inst_node { - public: - blas_a3_inst(value alpha, value A, value B, value beta, value C, bool atomic); - - inline bool atomic() const { return atomic_; } - inline void atomic(bool a) { atomic_ = a; } - inline auto alpha() const -> value const & { return alpha_; } - inline auto A() const -> value const & { return A_; } - inline auto B() const -> value const & { return B_; } - inline auto beta() const -> value const & { return beta_; } - inline auto C() const -> value const & { return C_; } - inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } - - protected: - value alpha_, A_, B_, beta_, C_; - bool atomic_; -}; - -class loop_inst : public inst_node { - public: - loop_inst(value loop_var, value from, value to, region body, location const &loc = {}); - loop_inst(value loop_var, value from, value to, value step, region body, - location const &loc = {}); - inline auto loop_var() const -> value const & { return loop_var_; } - inline auto from() const -> value const & { return from_; } - inline auto to() const -> value 
const & { return to_; } - inline auto step() const -> value const & { return step_; } - inline auto body() const -> region const & { return body_; } - inline value result() const override { return value{}; } - - private: - value loop_var_, from_, to_, step_; - region body_; -}; - -class alloca_inst : public clir::visitable { - public: - alloca_inst(data_type ty, location const &loc = {}); - - inline value result() const override { return result_; } - inline std::int64_t stack_ptr() const { return stack_ptr_; } - inline void stack_ptr(std::int64_t ptr) { stack_ptr_ = ptr; } - inline inst_kind kind() const override { return inst_kind::collective; } - - private: - value result_; - std::int64_t stack_ptr_; -}; - -class axpby_inst : public clir::visitable { - public: - using super = clir::visitable; - axpby_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, - location const &lc = {}); - - inline transpose tA() const { return tA_; } - - private: - transpose tA_; -}; - -class arith_inst : public clir::visitable { - public: - arith_inst(arithmetic op, value a, value b, location const &lc = {}); - - inline arithmetic op() const { return op_; } - inline auto a() const -> value const & { return a_; } - inline auto b() const -> value const & { return b_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - arithmetic op_; - value a_, b_, result_; -}; - -class arith_unary_inst : public clir::visitable { - public: - arith_unary_inst(arithmetic_unary op, value a, location const &lc = {}); - - inline arithmetic_unary op() const { return op_; } - inline auto a() const -> value const & { return a_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - arithmetic_unary op_; - value a_, result_; -}; - -class barrier_inst : public clir::visitable { - public: - 
inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } -}; - -class cast_inst : public clir::visitable { - public: - cast_inst(value a, scalar_type to_ty, location const &lc = {}); - inline auto a() const -> value const & { return a_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value a_, result_; -}; - -class compare_inst : public clir::visitable { - public: - compare_inst(cmp_condition cond, value a, value b, location const &lc = {}); - - inline cmp_condition cond() const { return cond_; } - inline auto a() const -> value const & { return a_; } - inline auto b() const -> value const & { return b_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - cmp_condition cond_; - value a_, b_, result_; -}; - -class expand_inst : public clir::visitable { - public: - expand_inst(value op, std::int64_t mode, std::vector expand_shape, - location const &lc = {}); - - inline auto operand() const -> value const & { return op_; } - inline std::int64_t mode() const { return mode_; } - inline auto expand_shape() const -> std::vector const & { return expand_shape_; } - inline auto expand_shape(std::int64_t i) const -> value const & { return expand_shape_[i]; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value op_, result_; - std::int64_t mode_; - std::vector expand_shape_; -}; - -class fuse_inst : public clir::visitable { - public: - fuse_inst(value op, std::int64_t from, std::int64_t to, location const &lc = {}); - - inline auto operand() const -> value const & { return op_; } - inline std::int64_t from() const { return from_; } - inline std::int64_t to() const { return to_; } - inline value result() 
const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value op_, result_; - std::int64_t from_, to_; -}; - -class load_inst : public clir::visitable { - public: - load_inst(value op, std::vector index_list, location const &lc = {}); - - inline auto operand() const -> value const & { return op_; } - inline auto index_list() const -> std::vector const & { return index_list_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value op_; - std::vector index_list_; - value result_; -}; - -class group_id_inst : public clir::visitable { - public: - inline group_id_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { - loc(lc); - } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value result_; -}; - -class group_size_inst : public clir::visitable { - public: - inline group_size_inst(location const &lc = {}) : result_{make_value(scalar_type::index)} { - loc(lc); - } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value result_; -}; - -class lifetime_stop_inst : public clir::visitable { - public: - inline lifetime_stop_inst(value obj) : obj_(std::move(obj)) {} - inline auto object() const -> value const & { return obj_; } - inline value result() const override { return value{}; } - inline inst_kind kind() const override { return inst_kind::collective; } - - private: - value obj_; -}; - -class gemm_inst : public clir::visitable { - public: - using super = clir::visitable; - gemm_inst(transpose tA, transpose tB, value alpha, value A, value B, value beta, value C, - bool atomic = false, location const &lc = {}); - - inline transpose tA() const { return tA_; } - inline transpose tB() const { 
return tB_; } - - private: - transpose tA_, tB_; -}; - -class gemv_inst : public clir::visitable { - public: - using super = clir::visitable; - gemv_inst(transpose tA, value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); - - inline transpose tA() const { return tA_; } - - private: - transpose tA_; -}; - -class ger_inst : public clir::visitable { - public: - using super = clir::visitable; - ger_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); -}; - -class for_inst : public clir::visitable { - public: - using super = clir::visitable; - using super::super; - inline inst_kind kind() const override { return inst_kind::replicated; } -}; - -class foreach_inst : public clir::visitable { - public: - using super = clir::visitable; - inline foreach_inst(value loop_var, value from, value to, region body, location const &loc = {}) - : super(std::move(loop_var), std::move(from), std::move(to), std::move(body), loc) {} - inline inst_kind kind() const override { return inst_kind::collective; } -}; - -class hadamard_inst : public clir::visitable { - public: - using super = clir::visitable; - hadamard_inst(value alpha, value A, value B, value beta, value C, bool atomic = false, - location const &lc = {}); -}; - -class if_inst : public clir::visitable { - public: - if_inst(value condition, region then, region otherwise = {}, - std::vector const &return_types = {}, location const &lc = {}); - inline auto condition() const -> value const & { return condition_; } - inline auto then() const -> region const & { return then_; } - inline auto otherwise() const -> region const & { return otherwise_; } - inline value result() const override { - return results_.size() > 0 ? 
results_.front() : value{}; - } - inline auto results() const -> std::vector override { return results_; } - inline auto num_results() const -> std::size_t override { return results_.size(); } - inline auto results_ref() -> std::vector & { return results_; } - inline auto results_ref() const -> std::vector const & { return results_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value condition_; - region then_, otherwise_; - std::vector results_; -}; - -class size_inst : public clir::visitable { - public: - size_inst(value op, std::int64_t mode, location const &lc = {}); - - inline auto operand() const -> value const & { return op_; } - inline std::int64_t mode() const { return mode_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value op_, result_; - std::int64_t mode_; -}; - -class subview_inst : public clir::visitable { - public: - subview_inst(value op, std::vector slices, location const &lc = {}); - - inline auto slices() const -> std::vector const & { return slices_; } - inline auto operand() const -> value const & { return op_; } - inline value result() const override { return result_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value op_; - std::vector slices_; - value result_; -}; - -class store_inst : public clir::visitable { - public: - store_inst(value val, value op, std::vector index_list, location const &lc = {}); - - inline auto val() const -> value const & { return val_; } - inline auto operand() const -> value const & { return op_; } - inline auto index_list() const -> std::vector const & { return index_list_; } - inline value result() const override { return {}; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - value val_, op_; - std::vector index_list_; -}; - -class sum_inst : public clir::visitable { - 
public: - using super = clir::visitable; - sum_inst(transpose tA, value alpha, value A, value beta, value B, bool atomic = false, - location const &lc = {}); - - inline transpose tA() const { return tA_; } - - private: - transpose tA_; -}; - -class yield_inst : public clir::visitable { - public: - inline yield_inst(std::vector vals, location const &lc = {}) : vals_(std::move(vals)) { - loc(lc); - } - inline value result() const override { return value{}; } - inline auto vals() const -> std::vector const & { return vals_; } - inline inst_kind kind() const override { return inst_kind::replicated; } - - private: - std::vector vals_; -}; - -} // namespace tinytc - -#endif // INST_NODE_20230327_HPP diff --git a/src/node/inst_view.cpp.mochi b/src/node/inst_view.cpp.mochi new file mode 100644 index 00000000..52bd7149 --- /dev/null +++ b/src/node/inst_view.cpp.mochi @@ -0,0 +1,30 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/inst_view.hpp" + +#include +#include +#include + +namespace tinytc { + +template void safe_increase(std::int32_t &value, T by) { + if (by < 0 || by > std::numeric_limits::max() - value) { + throw status::out_of_range; + } + value += by; +} + +template void safe_increase(std::int32_t &value, T by) { + using T_s = std::make_signed_t; + if (by > static_cast(std::numeric_limits::max()) || + static_cast(by) > std::numeric_limits::max() - value) { + throw status::out_of_range; + } + value += by; +} + +// もち inst_cpp "tinytc/instructions.anko" + +} // namespace tinytc diff --git a/src/node/inst_view.hpp.mochi b/src/node/inst_view.hpp.mochi new file mode 100644 index 00000000..d3960fc9 --- /dev/null +++ b/src/node/inst_view.hpp.mochi @@ -0,0 +1,77 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INST_VIEW_HPP_20250616 +#define INST_VIEW_HPP_20250616 + +#include "node/inst.hpp" +#include "node/region.hpp" +#include "tinytc/core.hpp" +#include 
"tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +using constant_value_type = std::variant>; + +enum support_flag { + supports_bool = 0x1, + supports_int = 0x2, + supports_float = 0x4, + supports_complex = 0x8 +}; +using support_flags = std::uint32_t; + +class inst_view { + public: + struct properties {}; + + inline inst_view(tinytc_inst_t in) : in_{in} {} + + inline auto get() -> tinytc_inst & { return *in_; } + inline auto loc() noexcept -> tinytc::location const & { return get().loc(); } + + inline explicit operator bool() const { return in_ != nullptr; } + + private: + tinytc_inst_t in_; +}; + +template +requires(std::is_base_of_v>) +auto isa(tinytc_inst &obj) -> bool { + return To::classof(obj); +} + +template +requires(std::is_base_of_v>) +auto dyn_cast(tinytc_inst_t obj) -> To { + if (obj != nullptr && isa(*obj)) { + return To(obj); + } + return To(nullptr); +} + +template +requires(std::is_base_of_v>) +auto dyn_cast_or_throw(tinytc_inst_t obj, F &&make_exception) -> To * { + if (auto c = dyn_cast(obj); c) { + return c; + } + throw make_exception(); +} + +// もち inst_hpp "tinytc/instructions.anko" + +} // namespace tinytc + +#endif // INST_VIEW_HPP_20250616 diff --git a/src/node/inst_view_impl.cpp b/src/node/inst_view_impl.cpp new file mode 100644 index 00000000..d6269494 --- /dev/null +++ b/src/node/inst_view_impl.cpp @@ -0,0 +1,1169 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "error.hpp" +#include "node/inst_view.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "number.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/iterator.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include 
+#include +#include +#include +#include +#include +#include + +namespace tinytc { + +coopmatrix_type *get_coopmatrix_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); + if (m == nullptr) { + throw compilation_error(loc, {&v}, status::ir_expected_coopmatrix); + } + return m; +} + +number_type *get_scalar_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); + if (m == nullptr) { + throw compilation_error(loc, {&v}, status::ir_expected_number); + } + return m; +} + +memref_type *get_memref_type(location const &loc, tinytc_value const &v) { + auto m = dyn_cast(v.ty()); + if (m == nullptr) { + throw compilation_error(loc, {&v}, status::ir_expected_memref); + } + return m; +} + +void check_index_ty(location const &loc, tinytc_value const &v) { + if (!isa(*v.ty())) { + throw compilation_error(loc, {&v}, status::ir_expected_index); + } +} + +void check_memref_shape(memref_type *rt, std::int64_t ri, memref_type *ot, std::int64_t oi, + location const &loc) { + if (rt->shape(ri) != ot->shape(oi)) { + auto extra_info = std::ostringstream{} << "Size of mode " << ri + << " does not match operand mode " << oi << " [" + << rt->shape(ri) << "!=" << ot->shape(oi) << "]"; + throw compilation_error(loc, status::ir_invalid_shape, std::move(extra_info).str()); + } +} +void check_memref_stride(memref_type *rt, std::int64_t ri, memref_type *ot, std::int64_t oi, + location const &loc) { + if (!is_dynamic_value(rt->stride(ri)) && rt->stride(ri) != ot->stride(oi)) { + auto extra_info = std::ostringstream{} << "Stride of mode " << ri + << " does not match operand stride " << oi << " [" + << rt->stride(ri) << "!=" << ot->stride(oi) << "]"; + throw compilation_error(loc, status::ir_invalid_stride, std::move(extra_info).str()); + } +} + +void check_memref_mode(memref_type *rt, std::int64_t ri, memref_type *ot, std::int64_t oi, + location const &loc) { + check_memref_shape(rt, ri, ot, oi, loc); + check_memref_stride(rt, ri, ot, oi, loc); +} 
+ +auto get_and_check_memref_type_addrspace(tinytc_value const &operand, tinytc_type_t ty, + location const &loc) + -> std::pair { + auto rt = dyn_cast(ty); + if (!rt) { + throw compilation_error(loc, status::ir_expected_memref); + } + auto ot = get_memref_type(loc, operand); + if (rt->element_ty() != ot->element_ty()) { + throw compilation_error(loc, {&operand}, status::ir_number_mismatch); + } + if (rt->addrspace() != ot->addrspace()) { + throw compilation_error(loc, {&operand}, status::ir_address_space_mismatch); + } + return {ot, rt}; +} + +void alloca_inst::setup_and_check() { + auto memref = dyn_cast(result().ty()); + if (memref == nullptr) { + throw compilation_error(loc(), status::ir_expected_memref); + } + if (memref->addrspace() != address_space::local) { + throw compilation_error(loc(), status::ir_expected_local_address_space); + } + + stack_ptr(-1); +} + +void barrier_inst::setup_and_check() {} + +auto barrier_inst::has_fence(address_space as) -> bool { + return (fence_flags() & static_cast(as)) > 0; +} + +void cast_inst::setup_and_check() { + auto to_ty = result().ty(); + + if (auto rt = dyn_cast(to_ty); rt) { + auto ct = dyn_cast(a().ty()); + if (!ct) { + throw compilation_error(loc(), {&a()}, status::ir_expected_coopmatrix); + } + if (ct->rows() != rt->rows() || ct->cols() != rt->cols()) { + throw compilation_error(loc(), {&a()}, status::ir_forbidden_cast); + } + const bool use_matches = ct->use() == rt->use(); + const bool use_conversion_allowed = + ct->use() == matrix_use::acc && + (rt->use() == matrix_use::a || rt->use() == matrix_use::b); + if (!use_matches && !use_conversion_allowed) { + throw compilation_error(loc(), {&a()}, status::ir_forbidden_cast); + } + if (!is_cast_allowed(ct->component_ty(), rt->component_ty())) { + throw compilation_error(loc(), {&a()}, status::ir_forbidden_cast); + } + } else { + if (!isa(*to_ty)) { + throw compilation_error(loc(), status::ir_expected_number); + } + + if (!is_cast_allowed(a().ty(), to_ty)) { + throw 
compilation_error(loc(), {&a()}, status::ir_forbidden_cast); + } + } +} + +void constant_inst::setup_and_check() { + auto ty = result().ty(); + + const auto type_ok = [](constant_value_type const &val, tinytc_type_t ty) { + return (isa(*ty) && std::holds_alternative(val)) || + (isa(*ty) && std::holds_alternative(val)) || + (isa(*ty) && std::holds_alternative>(val)); + }; + + if (auto bt = dyn_cast(ty); bt) { + if (!std::holds_alternative(value())) { + throw compilation_error(loc(), status::ir_constant_mismatch); + } + } else if (auto st = dyn_cast(ty); st) { + if (!type_ok(value(), st)) { + throw compilation_error(loc(), status::ir_constant_mismatch); + } + } else if (auto ct = dyn_cast(ty); ct) { + if (!type_ok(value(), ct->component_ty())) { + throw compilation_error(loc(), status::ir_constant_mismatch); + } + } else { + throw compilation_error(loc(), status::ir_expected_coopmatrix_number_or_boolean); + } +} + +auto constant_inst::is_zero() -> bool { + return std::visit([](auto const &v) { return v == decltype(v){0}; }, value()); +} +auto constant_inst::is_identity() -> bool { + return std::visit([](auto const &v) { return v == decltype(v){1}; }, value()); +} + +void cooperative_matrix_apply_inst::setup_and_check() { + auto ty = result().ty(); + + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } + + auto at = get_coopmatrix_type(loc(), a()); + + auto i32_ty = i32_type::get(at->context()); + body().loc(loc()); + body().kind(region_kind::spmd); + body().set_num_params(3); + body().set_param(0, i32_ty); + body().set_param(1, i32_ty); + body().set_param(2, at->component_ty()); +} + +void cooperative_matrix_extract_inst::setup_and_check() { + auto ty = result().ty(); + + auto matt = get_coopmatrix_type(loc(), mat()); + if (matt->component_ty() != ty) { + throw compilation_error(loc(), {&mat()}, status::ir_number_mismatch); + } +} + +void cooperative_matrix_insert_inst::setup_and_check() { + auto ty = 
result().ty(); + + if (mat().ty() != ty) { + throw compilation_error(loc(), {&mat()}, status::ir_operand_type_must_match_return_type); + } + + auto valt = get_scalar_type(loc(), val()); + auto matt = get_coopmatrix_type(loc(), mat()); + if (matt->component_ty() != valt) { + throw compilation_error(loc(), {&val(), &mat()}, status::ir_number_mismatch); + } +} + +void cooperative_matrix_load_inst::setup_and_check() { + auto rt = dyn_cast(result().ty()); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_coopmatrix); + } + + auto ot = get_memref_type(loc(), operand()); + if (ot->element_ty() != rt->component_ty()) { + throw compilation_error(loc(), {&operand()}, status::ir_number_mismatch); + } + if (ot->dim() != 2) { + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); + } + + check_index_ty(loc(), pos0()); + check_index_ty(loc(), pos1()); +} + +void cooperative_matrix_mul_add_inst::setup_and_check() { + auto rt = dyn_cast(result().ty()); + if (!rt) { + throw compilation_error(loc(), status::ir_expected_memref); + } + if (rt->use() != matrix_use::acc) { + throw compilation_error(loc(), status::ir_invalid_matrix_use); + } + + auto at = get_coopmatrix_type(loc(), a()); + auto bt = get_coopmatrix_type(loc(), b()); + auto ct = get_coopmatrix_type(loc(), c()); + if (at->use() != matrix_use::a) { + throw compilation_error(loc(), {&a()}, status::ir_invalid_matrix_use); + } + if (bt->use() != matrix_use::b) { + throw compilation_error(loc(), {&b()}, status::ir_invalid_matrix_use); + } + if (ct->use() != matrix_use::acc) { + throw compilation_error(loc(), {&c()}, status::ir_invalid_matrix_use); + } + + auto M = rt->rows(); + auto N = rt->cols(); + auto K = at->cols(); + if (rt->rows() != M || rt->cols() != N || ct->rows() != M || ct->cols() != N || + at->rows() != M || bt->rows() != K || bt->cols() != N) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << at->rows() << "x" << at->cols() << ", "; + oss << "B=" << 
bt->rows() << "x" << bt->cols() << ", "; + oss << "C=" << ct->rows() << "x" << ct->cols() << ", "; + oss << "result=" << rt->rows() << "x" << rt->cols(); + throw compilation_error(loc(), {&a(), &b(), &c()}, status::ir_incompatible_shapes, + oss.str()); + } + + const auto AB_ty = promote(at->component_ty(), bt->component_ty()); + if (!AB_ty) { + throw compilation_error(loc(), {&a(), &b()}, status::ir_forbidden_promotion); + } + if (!promotable(AB_ty, ct->component_ty())) { + throw compilation_error(loc(), {&a(), &b(), &c()}, status::ir_forbidden_promotion); + } + if (!is_cast_allowed(ct->component_ty(), rt->component_ty())) { + throw compilation_error(loc(), {&c()}, status::ir_forbidden_cast); + } +} + +auto cooperative_matrix_mul_add_inst::is_c_zero() -> bool { + if (auto c_def = c().defining_inst(); c_def) { + if (auto c_def_const = dyn_cast(c_def); c_def_const) { + return std::visit( + overloaded{[](bool) { return false; }, [](auto v) { return v == decltype(v){0}; }}, + c_def_const.value()); + } + } + return false; +} + +void cooperative_matrix_prefetch_inst::setup_and_check() { + auto ot = get_memref_type(loc(), operand()); + if (ot->dim() != 2) { + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); + } + if (rows() <= 0 || cols() <= 0) { + throw compilation_error(loc(), {}, status::ir_invalid_shape); + } + + check_index_ty(loc(), pos0()); + check_index_ty(loc(), pos1()); +} + +void cooperative_matrix_reduce_inst::setup_and_check() { + auto at = get_coopmatrix_type(loc(), a()); + auto rt = get_coopmatrix_type(loc(), result().ty()); + if (at->component_ty() != rt->component_ty()) { + throw compilation_error(loc(), {&a()}, status::ir_number_mismatch); + } + if (at->use() != rt->use()) { + throw compilation_error(loc(), {&a()}, status::ir_invalid_matrix_use); + } + const int m = mode() == reduce_mode::column ? 
0 : 1; + if (rt->shape(1 - m) != at->shape(1 - m) || rt->shape(m) != 1) { + throw compilation_error(loc(), {&a()}, status::ir_invalid_shape); + } +} +void cooperative_matrix_reduce_add_inst::setup_and_check() { + cooperative_matrix_reduce_inst::setup_and_check(); +} +void cooperative_matrix_reduce_max_inst::setup_and_check() { + cooperative_matrix_reduce_inst::setup_and_check(); +} +void cooperative_matrix_reduce_min_inst::setup_and_check() { + cooperative_matrix_reduce_inst::setup_and_check(); +} + +void cooperative_matrix_scale_inst::setup_and_check() { + auto ty = result().ty(); + + if (b().ty() != ty) { + throw compilation_error(loc(), {&b()}, status::ir_operand_type_must_match_return_type); + } + + auto bt = get_coopmatrix_type(loc(), b()); + + if (a().ty() != bt->component_ty()) { + throw compilation_error(loc(), {&a(), &b()}, status::ir_number_mismatch); + } +} + +void cooperative_matrix_store_inst::setup_and_check() { + auto vt = get_coopmatrix_type(loc(), val()); + auto ot = get_memref_type(loc(), operand()); + if (vt->component_ty() != ot->element_ty()) { + throw compilation_error(loc(), {&val(), &operand()}, status::ir_number_mismatch); + } + if (ot->dim() != 2) { + throw compilation_error(loc(), {&operand()}, status::ir_expected_memref_order_2); + } + + check_index_ty(loc(), pos0()); + check_index_ty(loc(), pos1()); +} + +void expand_inst::setup_and_check() { + for (auto &es : expand_shape()) { + check_index_ty(loc(), es); + } + + auto ty = result().ty(); + + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); + + bool const range_ok = 0 <= expanded_mode() && expanded_mode() < ot->dim(); + if (!range_ok) { + throw compilation_error(loc(), {&operand()}, status::ir_out_of_bounds); + } + + if (static_expand_shape().size() < 2) { + throw compilation_error(loc(), status::ir_expand_shape_order_too_small); + } + if (std::count(static_expand_shape().begin(), static_expand_shape().end(), dynamic) != + expand_shape().size()) { + throw 
compilation_error(loc(), status::ir_expand_shape_mismatch); + } + + for (std::int64_t i = 0; i < expanded_mode(); ++i) { + check_memref_mode(rt, i, ot, i, loc()); + } + auto stride = ot->stride(expanded_mode()); + for (std::size_t i = 0; i < static_expand_shape().size(); ++i) { + const auto mode = expanded_mode() + i; + if (rt->shape(mode) != static_expand_shape()[i]) { + auto extra_info = std::ostringstream{} + << "Size of mode " << mode << " does not match static expand shape (" + << rt->shape(mode) << "!=" << static_expand_shape()[i] << ")"; + throw compilation_error(loc(), status::ir_invalid_shape, std::move(extra_info).str()); + } + if (!is_dynamic_value(rt->stride(mode)) && rt->stride(mode) != stride) { + auto extra_info = std::ostringstream{} << "Stride of mode " << mode << " is invalid (" + << rt->stride(mode) << "!=" << stride << ")"; + throw compilation_error(loc(), status::ir_invalid_stride, std::move(extra_info).str()); + } + stride = is_dynamic_value(stride) || is_dynamic_value(rt->shape(mode)) + ? 
dynamic + : stride * rt->shape(mode); + } + for (std::int64_t i = expanded_mode() + 1; i < ot->dim(); ++i) { + check_memref_mode(rt, i + static_expand_shape().size() - 1, ot, i, loc()); + } +} + +void fuse_inst::setup_and_check() { + auto ty = result().ty(); + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); + + bool const range_ok = 0 <= from() && from() < to() && to() < ot->dim(); + if (!range_ok) { + throw compilation_error(loc(), status::ir_out_of_bounds); + } + + for (std::int64_t i = 0; i < from(); ++i) { + check_memref_mode(rt, i, ot, i, loc()); + } + + std::int64_t prod = 1; + for (std::int64_t i = from(); i <= to(); ++i) { + if (is_dynamic_value(ot->shape(i))) { + prod = dynamic; + break; + } + prod *= ot->shape(i); + } + if (rt->shape(from()) != prod) { + auto extra_info = std::ostringstream{} << "Size of mode " << from() + << " does not match shape product (" + << rt->shape(from()) << "!=" << prod << ")"; + throw compilation_error(loc(), status::ir_invalid_shape, std::move(extra_info).str()); + } + check_memref_stride(rt, from(), ot, from(), loc()); + + for (std::int64_t i = to() + 1; i < ot->dim(); ++i) { + check_memref_mode(rt, i - to() + from(), ot, i, loc()); + } +} + +void if_inst::setup_and_check() { + then().loc(loc()); + otherwise().loc(loc()); + if (!isa(*condition().ty())) { + throw compilation_error(loc(), {&condition()}, status::ir_expected_boolean); + } + + for (auto &r : results()) { + auto &ty = *r.ty(); + if (!isa(ty) && !isa(ty) && !isa(ty)) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_number_or_boolean); + } + } +} + +auto if_inst::is_otherwise_empty() -> bool { return otherwise().insts().empty(); } + +void lifetime_stop_inst::setup_and_check() {} + +void load_inst::setup_and_check() { + auto ty = result().ty(); + + visit( + overloaded{[&](group_type &g) { + if (g.element_ty() != ty) { + throw compilation_error(loc(), {&operand()}, + status::ir_operand_type_must_match_return_type); + } + if 
(static_cast(index_list().size()) != 1) { + throw compilation_error(loc(), status::ir_invalid_number_of_indices); + } + }, + [&](memref_type &m) { + if (m.element_ty() != ty) { + throw compilation_error(loc(), {&operand()}, + status::ir_operand_type_must_match_return_type); + } + if (m.dim() != static_cast(index_list().size())) { + throw compilation_error(loc(), status::ir_invalid_number_of_indices); + } + }, + [&](tinytc_type &) { + throw compilation_error(loc(), status::ir_expected_memref_or_group); + }}, + *operand().ty()); +} + +void parallel_inst::setup_and_check() { + body().kind(region_kind::spmd); + body().loc(loc()); +} + +void size_inst::setup_and_check() { + if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_index); + } + + const bool range_ok = + visit(overloaded{[&](group_type &) -> bool { return 0 <= mode() && mode() < 1; }, + [&](memref_type &m) -> bool { return 0 <= mode() && mode() < m.dim(); }, + [&](tinytc_type &) -> bool { + throw compilation_error(loc(), status::ir_expected_memref_or_group); + }}, + *operand().ty()); + if (!range_ok) { + throw compilation_error(loc(), status::ir_out_of_bounds); + } +} + +void subgroup_broadcast_inst::setup_and_check() { + auto ty = result().ty(); + if (!isa(*ty)) { + throw compilation_error(loc(), status::ir_expected_number); + } + + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } + + if (!isa(*idx().ty())) { + throw compilation_error(loc(), {&idx()}, status::ir_expected_i32); + } +} + +void subview_inst::setup_and_check() { + for (auto &val : offsets()) { + check_index_ty(loc(), val); + } + for (auto &val : sizes()) { + check_index_ty(loc(), val); + } + + auto ty = result().ty(); + auto [ot, rt] = get_and_check_memref_type_addrspace(operand(), ty, loc()); + + if (ot->dim() != static_cast(static_offsets().size()) || + ot->dim() != static_cast(static_sizes().size())) { + throw compilation_error(loc(), 
status::ir_invalid_number_of_indices); + } + if (std::count(static_offsets().begin(), static_offsets().end(), dynamic) != offsets().size() || + std::count(static_sizes().begin(), static_sizes().end(), dynamic) != sizes().size()) { + throw compilation_error(loc(), status::ir_subview_mismatch); + } + + std::int64_t ri = 0; + for (std::int64_t i = 0; i < ot->dim(); ++i) { + auto offset = static_offsets()[i]; + auto size = static_sizes()[i]; + if ((offset < 0 && !is_dynamic_value(offset)) || (size < 0 && !is_dynamic_value(size))) { + throw compilation_error(loc(), status::ir_invalid_slice); + } + if (size > 0 || is_dynamic_value(size)) { + if (rt->shape(ri) != size) { + auto extra_info = std::ostringstream{} << "Size of mode " << ri + << " does not match slice size [" + << rt->shape(ri) << "!=" << size << "]"; + throw compilation_error(loc(), status::ir_invalid_shape, + std::move(extra_info).str()); + } + check_memref_stride(rt, ri, ot, i, loc()); + ++ri; + } + } +} + +void store_inst::setup_and_check() { + for (auto &val : index_list()) { + check_index_ty(loc(), val); + } + + auto o = get_memref_type(loc(), operand()); + + if (val().ty() != o->element_ty()) { + throw compilation_error(loc(), {&val(), &operand()}, status::ir_number_mismatch); + } + + if (o->dim() != static_cast(index_list().size())) { + throw compilation_error(loc(), {&operand()}, status::ir_invalid_number_of_indices); + } +} + +void yield_inst::setup_and_check() {} + +void arith_inst::setup_and_check() {} +void arith_inst::setup_and_check(support_flags support) { + auto ty = result().ty(); + + if (a().ty() != ty) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } + if (b().ty() != ty) { + throw compilation_error(loc(), {&b()}, status::ir_operand_type_must_match_return_type); + } + + if (isa(*ty)) { + if (!(support & supports_bool)) { + throw compilation_error(loc(), status::ir_boolean_unsupported); + } + } else { + auto const check_scalar_ty = 
[&](tinytc_type_t ty) { + if (!(support & supports_float) && isa(*ty)) { + throw compilation_error(loc(), status::ir_fp_unsupported); + } + if (!(support & supports_complex) && isa(*ty)) { + throw compilation_error(loc(), status::ir_complex_unsupported); + } + }; + + if (auto ct = dyn_cast(ty); ct) { + check_scalar_ty(ct->component_ty()); + } else if (isa(*ty)) { + check_scalar_ty(ty); + } else { + throw compilation_error(loc(), status::ir_expected_coopmatrix_or_number); + } + } +} +void add_inst::setup_and_check() { + arith_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void sub_inst::setup_and_check() { + arith_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void mul_inst::setup_and_check() { + arith_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void div_inst::setup_and_check() { + arith_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void rem_inst::setup_and_check() { arith_inst::setup_and_check(supports_int | supports_float); } +void max_inst::setup_and_check() { arith_inst::setup_and_check(supports_int | supports_float); } +void min_inst::setup_and_check() { arith_inst::setup_and_check(supports_int | supports_float); } +void shl_inst::setup_and_check() { arith_inst::setup_and_check(supports_int); } +void shr_inst::setup_and_check() { arith_inst::setup_and_check(supports_int); } +void and_inst::setup_and_check() { arith_inst::setup_and_check(supports_bool | supports_int); } +void or_inst::setup_and_check() { arith_inst::setup_and_check(supports_bool | supports_int); } +void xor_inst::setup_and_check() { arith_inst::setup_and_check(supports_bool | supports_int); } + +void arith_unary_inst::setup_and_check() {} +void arith_unary_inst::setup_and_check(support_flags support, bool component_type_match) { + auto ty = result().ty(); + + if (isa(*ty)) { + if (!(support & supports_bool)) { + throw compilation_error(loc(), 
status::ir_boolean_unsupported); + } + } else { + auto const check_scalar_ty = [&](tinytc_type_t a_ty, tinytc_type_t r_ty) { + if (component_type_match) { + if (r_ty != component_type(a_ty)) { + throw compilation_error(loc(), {&a()}, + status::ir_operand_type_must_match_return_type); + } + } else { + if (a_ty != r_ty) { + throw compilation_error(loc(), {&a()}, + status::ir_operand_type_must_match_return_type); + } + } + if (!(support & supports_int) && isa(*a_ty)) { + throw compilation_error(loc(), {&a()}, status::ir_int_unsupported); + } + if (!(support & supports_float) && isa(*a_ty)) { + throw compilation_error(loc(), {&a()}, status::ir_fp_unsupported); + } + if (!(support & supports_complex) && isa(*a_ty)) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } + }; + + auto ct = dyn_cast(a().ty()); + auto rt = dyn_cast(ty); + if (ct && rt) { + check_scalar_ty(ct->component_ty(), rt->component_ty()); + } else if (isa(*a().ty()) && isa(*ty)) { + check_scalar_ty(a().ty(), ty); + } else { + throw compilation_error(loc(), {&a()}, status::ir_expected_coopmatrix_or_number); + } + } +} +void abs_inst::setup_and_check() { + arith_unary_inst::setup_and_check(supports_int | supports_float | supports_complex, true); +} +void neg_inst::setup_and_check() { + arith_unary_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void not_inst::setup_and_check() { + arith_unary_inst::setup_and_check(supports_bool | supports_int); +} +void conj_inst::setup_and_check() { arith_unary_inst::setup_and_check(supports_complex); } +void im_inst::setup_and_check() { arith_unary_inst::setup_and_check(supports_complex, true); } +void re_inst::setup_and_check() { arith_unary_inst::setup_and_check(supports_complex, true); } + +void blas_a2_inst::setup_and_check() { + auto At = get_memref_type(loc(), A()); + auto Bt = get_memref_type(loc(), B()); + + if (!promotable(alpha().ty(), At->element_ty())) { + throw compilation_error(loc(), {&alpha(), 
&A()}, status::ir_forbidden_promotion); + } + if (!promotable(At->element_ty(), Bt->element_ty())) { + throw compilation_error(loc(), {&A(), &B()}, status::ir_forbidden_promotion); + } + if (!promotable(beta().ty(), Bt->element_ty())) { + throw compilation_error(loc(), {&beta(), &B()}, status::ir_forbidden_promotion); + } +} + +void axpby_inst::setup_and_check() { + blas_a2_inst::setup_and_check(); + + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + + if (b->dim() < 0 || b->dim() > 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_0_1_or_2); + } + + bool shape_equal = false; + if (tA() == transpose::T && a->dim() == 2 && b->dim() == 2) { + shape_equal = a->shape()[1] == b->shape()[0] && a->shape()[0] == b->shape()[1]; + } else { + shape_equal = a->shape() == b->shape(); + } + + if (!shape_equal) { + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); + } +} + +void cumsum_inst::setup_and_check() { + blas_a2_inst::setup_and_check(); + + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + + if (a->dim() < 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_non_scalar_memref); + } + if (mode() >= a->dim()) { + throw compilation_error(loc(), {&A()}, status::ir_out_of_bounds); + } + + bool shape_equal = a->dim() == b->dim(); + if (shape_equal) { + for (std::int64_t i = 0; i < a->dim(); ++i) { + shape_equal = shape_equal && a->shape(i) == b->shape(i); + } + } + + if (!shape_equal) { + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); + } +} + +void sum_inst::setup_and_check() { + blas_a2_inst::setup_and_check(); + + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + + if (b->dim() == 1 && a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() == 0 && a->dim() != 1) { + throw compilation_error(loc(), {&A()}, 
status::ir_expected_memref_order_1); + } + if (b->dim() != 0 && b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_0_or_1); + } + + if (a->dim() == 2) { + if (a->shape(tA() == transpose::T ? 1 : 0) != b->shape(0)) { + throw compilation_error(loc(), {&A(), &B()}, status::ir_incompatible_shapes); + } + } +} + +void blas_a3_inst::setup_and_check() { + auto At = get_memref_type(loc(), A()); + auto Bt = get_memref_type(loc(), B()); + auto Ct = get_memref_type(loc(), C()); + + const auto AB_ty = promote(At->element_ty(), Bt->element_ty()); + if (!AB_ty) { + throw compilation_error(loc(), {&A(), &B()}, status::ir_forbidden_promotion); + } + if (!promotable(alpha().ty(), AB_ty)) { + throw compilation_error(loc(), {&alpha(), &A(), &B()}, status::ir_forbidden_promotion); + } + if (!promotable(AB_ty, Ct->element_ty())) { + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_forbidden_promotion); + } + if (!promotable(beta().ty(), Ct->element_ty())) { + throw compilation_error(loc(), {&beta(), &C()}, status::ir_forbidden_promotion); + } +} + +void gemm_inst::setup_and_check() { + blas_a3_inst::setup_and_check(); + + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() != 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_2); + } + if (c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_2); + } + + auto ak = tA() == transpose::T ? 0 : 1; + auto bk = tB() == transpose::T ? 
1 : 0; + auto M = c->shape(0); + auto N = c->shape(1); + auto K = a->shape(ak); + if (a->shape(1 - ak) != M || b->shape(bk) != K || b->shape(1 - bk) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } +} + +void gemv_inst::setup_and_check() { + blas_a3_inst::setup_and_check(); + + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_2); + } + if (b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1); + } + if (c->dim() != 1) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_1); + } + + auto ak = tA() == transpose::T ? 
0 : 1; + auto M = c->shape(0); + auto K = a->shape(ak); + if (a->shape(1 - ak) != M || b->shape(0) != K) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "c=" << c->shape(0); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } +} + +void ger_inst::setup_and_check() { + blas_a3_inst::setup_and_check(); + + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 1) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1); + } + if (b->dim() != 1) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1); + } + if (c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_2); + } + + auto M = c->shape(0); + auto N = c->shape(1); + if (a->shape(0) != M || b->shape(0) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "a=" << a->shape(0) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } +} + +void hadamard_inst::setup_and_check() { + blas_a3_inst::setup_and_check(); + + auto a = get_memref_type(loc(), A()); + auto b = get_memref_type(loc(), B()); + auto c = get_memref_type(loc(), C()); + + if (a->dim() != 1 && a->dim() != 2) { + throw compilation_error(loc(), {&A()}, status::ir_expected_memref_order_1_or_2); + } + if (b->dim() != 1 && b->dim() != 2) { + throw compilation_error(loc(), {&B()}, status::ir_expected_memref_order_1_or_2); + } + if (c->dim() != 1 && c->dim() != 2) { + throw compilation_error(loc(), {&C()}, status::ir_expected_memref_order_1_or_2); + } + if (c->dim() != a->dim() || c->dim() != b->dim()) { + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes); + } + + 
auto M = c->shape(0); + if (c->dim() == 1) { + if (a->shape(0) != M || b->shape(0) != M) { + std::ostringstream oss; + oss << "Got "; + oss << "a=" << a->shape(0) << ", "; + oss << "b=" << b->shape(0) << ", "; + oss << "c=" << c->shape(0); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } + } else if (c->dim() == 2) { + auto N = c->shape(1); + if (a->shape(0) != M || a->shape(1) != N || b->shape(0) != M || b->shape(1) != N) { + std::ostringstream oss; + oss << "Got "; + oss << "A=" << a->shape(0) << "x" << a->shape(1) << ", "; + oss << "B=" << b->shape(0) << "x" << b->shape(1) << ", "; + oss << "C=" << c->shape(0) << "x" << c->shape(1); + throw compilation_error(loc(), {&A(), &B(), &C()}, status::ir_incompatible_shapes, + oss.str()); + } + } +} + +void builtin_inst::setup_and_check() {} + +void group_id_inst::setup_and_check() { + builtin_inst::setup_and_check(); + if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_index); + } +} +void num_groups_inst::setup_and_check() { + builtin_inst::setup_and_check(); + if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_index); + } +} +void num_subgroups_inst::setup_and_check() { + builtin_inst::setup_and_check(); + if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_i32); + } +} +void subgroup_size_inst::setup_and_check() { + builtin_inst::setup_and_check(); + if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_i32); + } +} +void subgroup_id_inst::setup_and_check() { + builtin_inst::setup_and_check(); + if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_i32); + } +} +void subgroup_linear_id_inst::setup_and_check() { + builtin_inst::setup_and_check(); + if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_i32); + } +} +void subgroup_local_id_inst::setup_and_check() { + builtin_inst::setup_and_check(); 
+ if (!isa(*result().ty())) { + throw compilation_error(loc(), status::ir_expected_i32); + } +} + +void compare_inst::setup_and_check() {} +void compare_inst::setup_and_check(support_flags support) { + auto ty = result().ty(); + + if (!isa(*ty)) { + throw compilation_error(loc(), status::ir_expected_boolean); + } + + if (!isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_expected_number); + } + if (!(support & supports_complex) && isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } + + if (a().ty() != b().ty()) { + throw compilation_error(loc(), {&a(), &b()}, status::ir_number_mismatch); + } +} +void equal_inst::setup_and_check() { + compare_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void not_equal_inst::setup_and_check() { + compare_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void greater_than_inst::setup_and_check() { + compare_inst::setup_and_check(supports_int | supports_float); +} +void greater_than_equal_inst::setup_and_check() { + compare_inst::setup_and_check(supports_int | supports_float); +} +void less_than_inst::setup_and_check() { + compare_inst::setup_and_check(supports_int | supports_float); +} +void less_than_equal_inst::setup_and_check() { + compare_inst::setup_and_check(supports_int | supports_float); +} + +void loop_inst::setup_and_check() {} + +void for_inst::setup_and_check() { + loop_inst::setup_and_check(); + + if (!isa(*from().ty())) { + throw compilation_error(loc(), {&from()}, status::ir_expected_int); + } + if (from().ty() != to().ty()) { + throw compilation_error(loc(), {&from(), &to()}, status::ir_number_mismatch); + } + if (has_step()) { + if (from().ty() != step().ty()) { + throw compilation_error(loc(), {&from(), &step()}, status::ir_number_mismatch); + } + } + + auto res = results(); + body().set_num_params(1 + res.size()); + body().set_param(0, from().ty()); + + auto init = iter_init(); + if (init.size() 
!= res.size()) { + throw compilation_error(loc(), status::ir_init_return_type_mismatch); + } + for (std::int64_t i = 0; i < res.size(); ++i) { + auto ty = res[i].ty(); + if (init[i].ty() != ty) { + throw compilation_error(loc(), {&init[i]}, status::ir_init_return_type_mismatch); + } + if (!isa(*ty) && !isa(*ty) && !isa(*ty)) { + throw compilation_error(loc(), status::ir_expected_coopmatrix_number_or_boolean); + } + body().set_param(1 + i, ty); + } +} + +void foreach_inst::setup_and_check() { + loop_inst::setup_and_check(); + + auto from_ = from(); + auto to_ = to(); + if (from_.size() == 0 || from_.size() != to_.size()) { + throw compilation_error(loc(), status::ir_from_to_mismatch); + } + + auto num_lv = from_.size(); + body().kind(region_kind::spmd); + body().set_num_params(num_lv); + for (std::int64_t i = 0; i < num_lv; ++i) { + if (!isa(*from_[i].ty())) { + throw compilation_error(loc(), {&from_[i]}, status::ir_expected_int); + } + if (from_[i].ty() != to_[i].ty()) { + throw compilation_error(loc(), {&from_[i], &to_[i]}, status::ir_number_mismatch); + } + body().set_param(i, from_[i].ty()); + } +} + +void math_unary_inst::setup_and_check() {} +void math_unary_inst::setup_and_check(support_flags support) { + if (!isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_expected_number); + } + + if (!(support & supports_int) && isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_int_unsupported); + } else if (!(support & supports_float) && isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_fp_unsupported); + } else if (!(support & supports_complex) && isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } + + if (a().ty() != result().ty()) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } +} +void cos_inst::setup_and_check() { math_unary_inst::setup_and_check(supports_float); } +void sin_inst::setup_and_check() { 
math_unary_inst::setup_and_check(supports_float); } +void exp_inst::setup_and_check() { + math_unary_inst::setup_and_check(supports_float | supports_complex); +} +void exp2_inst::setup_and_check() { + math_unary_inst::setup_and_check(supports_float | supports_complex); +} +void native_cos_inst::setup_and_check() { math_unary_inst::setup_and_check(supports_float); } +void native_sin_inst::setup_and_check() { math_unary_inst::setup_and_check(supports_float); } +void native_exp_inst::setup_and_check() { + math_unary_inst::setup_and_check(supports_float | supports_complex); +} +void native_exp2_inst::setup_and_check() { + math_unary_inst::setup_and_check(supports_float | supports_complex); +} + +void subgroup_operation_inst::setup_and_check() {} +void subgroup_operation_inst::setup_and_check(support_flags support) { + if (!isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_expected_number); + } + if (!(support & supports_complex) && isa(*a().ty())) { + throw compilation_error(loc(), {&a()}, status::ir_complex_unsupported); + } + + if (a().ty() != result().ty()) { + throw compilation_error(loc(), {&a()}, status::ir_operand_type_must_match_return_type); + } +} +void subgroup_exclusive_scan_add_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void subgroup_exclusive_scan_max_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float); +} +void subgroup_exclusive_scan_min_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float); +} +void subgroup_inclusive_scan_add_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void subgroup_inclusive_scan_max_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float); +} +void subgroup_inclusive_scan_min_inst::setup_and_check() { + 
subgroup_operation_inst::setup_and_check(supports_int | supports_float); +} +void subgroup_reduce_add_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float | supports_complex); +} +void subgroup_reduce_max_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float); +} +void subgroup_reduce_min_inst::setup_and_check() { + subgroup_operation_inst::setup_and_check(supports_int | supports_float); +} + +} // namespace tinytc diff --git a/src/prog.cpp b/src/node/prog.cpp similarity index 55% rename from src/prog.cpp rename to src/node/prog.cpp index 7ee2c714..a1482a7a 100644 --- a/src/prog.cpp +++ b/src/node/prog.cpp @@ -1,16 +1,16 @@ // Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: BSD-3-Clause +#include "node/prog.hpp" #include "error.hpp" #include "location.hpp" -#include "node/program_node.hpp" +#include "pass/dump_ir.hpp" #include "passes.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" +#include "tinytc/builder.h" +#include "tinytc/core.h" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include #include #include #include @@ -18,27 +18,33 @@ #include #include #include -#include using namespace tinytc; +tinytc_prog::tinytc_prog(shared_handle ctx, tinytc_location const &lc) + : ctx_{std::move(ctx)} { + loc(lc); +} + extern "C" { -tinytc_status_t tinytc_program_create(tinytc_prog_t *prg, uint32_t fun_list_size, - tinytc_func_t *fun_list, const tinytc_location_t *loc) { - if (prg == nullptr || (fun_list_size > 0 && fun_list == nullptr)) { +tinytc_status_t tinytc_prog_create(tinytc_prog_t *prg, tinytc_compiler_context_t ctx, + const tinytc_location_t *loc) { + if (prg == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto fun_vec = std::vector(); - fun_vec.reserve(fun_list_size); - for (uint32_t i = 0; i < fun_list_size; ++i) { - fun_vec.emplace_back(func(fun_list[i], true)); - } - *prg = 
std::make_unique(std::move(fun_vec), get_optional(loc)).release(); + *prg = std::make_unique(shared_handle{ctx, true}, get_optional(loc)).release(); }); } +tinytc_status_t tinytc_prog_add_function(tinytc_prog_t prg, tinytc_func_t fun) { + if (prg == nullptr || fun == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { prg->push_back(tinytc::unique_handle(fun)); }); +} + tinytc_status_t tinytc_prog_release(tinytc_prog_t obj) { if (obj == nullptr) { return tinytc_status_invalid_arguments; @@ -58,14 +64,22 @@ tinytc_status_t tinytc_prog_retain(tinytc_prog_t obj) { return tinytc_status_success; } -tinytc_status_t tinytc_prog_dump(const_tinytc_prog_t prg) { +tinytc_status_t tinytc_prog_dump(tinytc_prog_t prg) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } - return exception_to_status_code([&] { dump_ir(std::cerr, *prg); }); + return exception_to_status_code([&] { run_function_pass(dump_ir_pass{std::cerr}, *prg); }); +} + +tinytc_status_t tinytc_prog_get_compiler_context(const_tinytc_prog_t prg, + tinytc_compiler_context_t *ctx) { + if (prg == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ctx = prg->context(); }); } -tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *filename) { +tinytc_status_t tinytc_prog_print_to_file(tinytc_prog_t prg, char const *filename) { if (prg == nullptr || filename == nullptr) { return tinytc_status_invalid_arguments; } @@ -74,18 +88,18 @@ tinytc_status_t tinytc_prog_print_to_file(const_tinytc_prog_t prg, char const *f if (!stream.good()) { throw status::file_io_error; } - dump_ir(stream, *prg); + run_function_pass(dump_ir_pass{stream}, *prg); }); } -tinytc_status_t tinytc_prog_print_to_string(const_tinytc_prog_t prg, char **str) { +tinytc_status_t tinytc_prog_print_to_string(tinytc_prog_t prg, char **str) { if (prg == nullptr || str == nullptr) { return 
tinytc_status_invalid_arguments; } return exception_to_status_code([&] { auto const text = [&] { auto oss = std::ostringstream{}; - dump_ir(oss, *prg); + run_function_pass(dump_ir_pass{oss}, *prg); return std::move(oss).str(); }(); auto const length = text.size() + 1; // Need to include terminating null character diff --git a/src/node/prog.hpp b/src/node/prog.hpp new file mode 100644 index 00000000..2316ac04 --- /dev/null +++ b/src/node/prog.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef PROG_20250626_HPP +#define PROG_20250626_HPP + +#include "reference_counted.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/iterator.hpp" + +#include +#include + +struct tinytc_prog final : tinytc::reference_counted { + public: + using container_t = std::vector>; + + using iterator = tinytc::indirect_random_access_iterator; + using const_iterator = tinytc::indirect_random_access_iterator; + + tinytc_prog(tinytc::shared_handle ctx, + tinytc_location const &lc = {}); + + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::shared_handle { + return ctx_; + } + + inline auto loc() const noexcept -> tinytc_location const & { return loc_; } + inline void loc(tinytc_location const &loc) noexcept { loc_ = loc; } + + inline auto begin() -> iterator { return iterator{funcs_.begin()}; } + inline auto end() -> iterator { return iterator{funcs_.end()}; } + inline auto begin() const -> const_iterator { return const_iterator{funcs_.begin()}; } + inline auto end() const -> const_iterator { return const_iterator{funcs_.end()}; } + inline void push_back(tinytc::unique_handle &&fun) { + funcs_.push_back(std::move(fun)); + } + + private: + tinytc::shared_handle ctx_; + container_t funcs_; + tinytc_location loc_; +}; + +namespace tinytc { + +using program_node = ::tinytc_prog; + +} // namespace tinytc + +#endif // PROG_20250626_HPP 
diff --git a/src/node/program_node.hpp b/src/node/program_node.hpp deleted file mode 100644 index 372c41d7..00000000 --- a/src/node/program_node.hpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef PROGRAM_NODE_20240208_HPP -#define PROGRAM_NODE_20240208_HPP - -#include "location.hpp" -#include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" - -#include - -#include -#include - -namespace tinytc { -using program_nodes = clir::virtual_type_list; -} - -struct tinytc_prog : tinytc::reference_counted, tinytc::program_nodes { - public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - - private: - tinytc::location loc_; -}; - -namespace tinytc { - -using program_node = ::tinytc_prog; - -class program : public clir::visitable { - public: - inline program(std::vector decls, location const &lc = {}) : decls_(std::move(decls)) { - loc(lc); - } - inline auto declarations() -> std::vector & { return decls_; } - inline auto declarations() const -> std::vector const & { return decls_; } - - private: - std::vector decls_; -}; - -} // namespace tinytc - -#endif // PROGRAM_NODE_20240208_HPP diff --git a/src/node/region.cpp b/src/node/region.cpp new file mode 100644 index 00000000..e9998941 --- /dev/null +++ b/src/node/region.cpp @@ -0,0 +1,143 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/region.hpp" +#include "error.hpp" +#include "node/inst.hpp" +#include "node/inst.hpp" // IWYU pragma: keep +#include "tinytc/builder.h" +#include "tinytc/types.h" +#include "util/ilist.hpp" + +#include +#include +#include + +using namespace tinytc; + +namespace tinytc { + +auto ilist_callbacks::get_parent_region() -> tinytc_region * { + return reinterpret_cast(reinterpret_cast(this) - + tinytc_region::inst_list_offset()); +} + +void 
ilist_callbacks::node_added(tinytc_inst_t node) { + node->parent(get_parent_region()); +} +void ilist_callbacks::node_moved(tinytc_inst_t node) { + node->parent(get_parent_region()); +} +void ilist_callbacks::node_removed(tinytc_inst_t node) { tinytc_inst_destroy(node); } + +} // namespace tinytc + +tinytc_region::tinytc_region(tinytc_inst_t def_inst) + : def_inst_{def_inst}, kind_{region_kind::mixed} {} + +tinytc_region::~tinytc_region() {} + +void tinytc_region::loc(location const &loc) { + loc_ = loc; + for (auto &param : params_) { /* by reference: the stored parameter itself receives the new location */ + param.loc(loc_); + } +} +void tinytc_region::defining_inst(tinytc_inst_t def_inst) { + def_inst_ = def_inst; + for (auto &param : params_) { /* by reference: the stored parameter itself receives the new defining instruction */ + param.defining_inst(def_inst_); + } +} + +void tinytc_region::set_params(array_view param_types) { + params_.resize(param_types.size()); + for (std::size_t i = 0; i < param_types.size(); ++i) { + set_param(i, param_types[i]); + } +} + +void tinytc_region::set_num_params(std::size_t num_params) { params_.resize(num_params); } +void tinytc_region::set_param(std::size_t idx, tinytc_type_t param_type) { + params_[idx] = tinytc_value{param_type, def_inst_, loc_}; +} + +extern "C" { + +tinytc_status_t tinytc_region_append(tinytc_region_t reg, tinytc_inst_t instr) { + if (reg == nullptr || instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { reg->insts().push_back(instr); }); +} + +tinytc_status_t tinytc_region_begin(tinytc_region_t reg, tinytc_inst_iterator_t *iterator) { + if (reg == nullptr || iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *iterator = reg->insts().begin().get(); }); +} + +tinytc_status_t tinytc_region_end(tinytc_region_t reg, tinytc_inst_iterator_t *iterator) { + if (reg == nullptr || iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *iterator = reg->insts().end().get(); }); +} + +tinytc_status_t 
tinytc_region_erase(tinytc_region_t reg, tinytc_inst_iterator_t *iterator) { + if (reg == nullptr || iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *iterator = reg->insts().erase(*iterator).get(); }); +} + +tinytc_status_t tinytc_region_insert(tinytc_region_t reg, tinytc_inst_iterator_t *iterator, + tinytc_inst_t instr) { + if (reg == nullptr || iterator == nullptr || instr == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *iterator = reg->insts().insert(*iterator, instr).get(); }); +} + +tinytc_status_t tinytc_next_inst(tinytc_inst_iterator_t *iterator) { + if (iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *iterator = static_cast((*iterator)->next()); }); +} + +tinytc_status_t tinytc_prev_inst(tinytc_inst_iterator_t *iterator) { + if (iterator == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code( + [&] { *iterator = static_cast((*iterator)->prev()); }); +} + +tinytc_status_t tinytc_region_get_parameters(tinytc_region_t reg, size_t *result_list_size, + tinytc_value_t *result_list) { + + if (reg == nullptr || result_list_size == nullptr || + (*result_list_size > 0 && result_list == nullptr)) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const num_results = reg->num_params(); + if (num_results < 0) { + throw std::out_of_range("number of results must not be negative"); + } + auto num = static_cast(num_results); + if (*result_list_size > 0) { + auto results = reg->param_begin(); + num = std::min(num, *result_list_size); + for (size_t i = 0; i < num; ++i) { + result_list[i] = &results[i]; + } + } + *result_list_size = num; + }); +} +} diff --git a/src/node/region.hpp b/src/node/region.hpp new file mode 100644 index 00000000..0a5ad17c --- /dev/null +++ b/src/node/region.hpp @@ -0,0 +1,93 @@ +// Copyright 
(C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef REGION_20250626_HPP +#define REGION_20250626_HPP + +#include "node/value.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/iterator.hpp" + +#include +#include +#include + +namespace tinytc { + +//! Region classification +enum class region_kind { mixed = 0x0, collective = 0x1, spmd = 0x2 }; + +template <> struct ilist_callbacks { + auto get_parent_region() -> tinytc_region *; + void node_added(tinytc_inst_t node); + void node_moved(tinytc_inst_t node); + void node_removed(tinytc_inst_t node); +}; + +} // namespace tinytc + +struct alignas(8) tinytc_region final { + public: + using iterator = tinytc::ilist::iterator; + using const_iterator = tinytc::ilist::const_iterator; + + tinytc_region(tinytc_inst_t def_inst = nullptr); + ~tinytc_region(); + + tinytc_region(tinytc_region const &) = delete; + tinytc_region(tinytc_region &&) = delete; + tinytc_region &operator=(tinytc_region const &) = delete; + tinytc_region &operator=(tinytc_region &&) = delete; + + inline auto kind() const noexcept -> tinytc::region_kind { return kind_; } + inline void kind(tinytc::region_kind kind) noexcept { kind_ = kind; } + + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } + void loc(tinytc::location const &loc); + + // Can be nullptr, e.g. 
if the region is the body of a function + inline auto defining_inst() const -> tinytc_inst_t { return def_inst_; } + void defining_inst(tinytc_inst_t def_inst); + + inline auto begin() -> iterator { return insts_.begin(); } + inline auto end() -> iterator { return insts_.end(); } + inline auto insts() -> tinytc::ilist & { return insts_; } + inline auto begin() const -> const_iterator { return insts_.cbegin(); } + inline auto end() const -> const_iterator { return insts_.cend(); } + inline auto insts() const -> tinytc::ilist const & { return insts_; } + inline auto empty() const -> bool { return insts_.empty(); } + + inline auto param_begin() { return params_.begin(); } + inline auto param_end() { return params_.end(); } + inline auto params() { return tinytc::iterator_range_wrapper{param_begin(), param_end()}; } + inline auto param_begin() const { return params_.begin(); } + inline auto param_end() const { return params_.end(); } + inline auto param(std::size_t pos) -> tinytc_value & { return params_[pos]; } + inline auto param(std::size_t pos) const -> tinytc_value const & { return params_[pos]; } + inline auto params() const { + return tinytc::iterator_range_wrapper{param_begin(), param_end()}; + } + inline auto num_params() const noexcept -> std::size_t { return params_.size(); } + void set_params(tinytc::array_view param_types); + void set_num_params(std::size_t num_params); + void set_param(std::size_t idx, tinytc_type_t param_type); + + private: + static auto inst_list_offset() -> std::size_t { + static_assert(std::is_standard_layout_v, "offsetof not guaranteed to work"); + return offsetof(tinytc_region, insts_); + } + friend struct tinytc::ilist_callbacks; + + tinytc_inst_t def_inst_; + tinytc::region_kind kind_; + tinytc::location loc_; + // params_ must come before insts_ such that the dtors are called in the correct order + std::vector params_; + tinytc::ilist insts_; +}; + +#endif // REGION_20250626_HPP diff --git a/src/node/region_node.hpp 
b/src/node/region_node.hpp deleted file mode 100644 index bcf1d983..00000000 --- a/src/node/region_node.hpp +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef REGION_NODE_20230908_HPP -#define REGION_NODE_20230908_HPP - -#include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" - -#include - -#include -#include - -namespace tinytc { -using region_nodes = clir::virtual_type_list; -} - -struct tinytc_region : tinytc::reference_counted, tinytc::region_nodes { - public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - - private: - tinytc::location loc_; -}; - -namespace tinytc { - -using region_node = ::tinytc_region; - -class rgn : public clir::visitable { - public: - inline rgn(std::vector insts = {}, location const &lc = {}) : insts_(std::move(insts)) { - loc(lc); - } - - inline auto insts() -> std::vector & { return insts_; } - inline auto insts() const -> std::vector const & { return insts_; } - inline void insts(std::vector insts) { insts_ = std::move(insts); } - - private: - std::vector insts_; -}; - -} // namespace tinytc - -#endif // REGION_NODE_20230908_HPP diff --git a/src/node/type.cpp.mochi b/src/node/type.cpp.mochi new file mode 100644 index 00000000..d0ce7984 --- /dev/null +++ b/src/node/type.cpp.mochi @@ -0,0 +1,250 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "compiler_context.hpp" +#include "compiler_context_cache.hpp" +#include "error.hpp" +#include "node/type.hpp" +#include "number.hpp" +#include "support/fnv1a_array_view.hpp" // IWYU pragma: keep +#include "tinytc/builder.h" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/fnv1a.hpp" +#include "util/math.hpp" + +#include +#include +#include +#include +#include + +using namespace tinytc; 
+ +namespace tinytc { + +boolean_type::boolean_type(tinytc_compiler_context_t ctx) : tinytc_type(TK::TK_boolean, ctx) {} + +auto boolean_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->bool_ty.get(); +} + +auto coopmatrix_type::get(tinytc_type_t component_ty, std::int64_t rows, std::int64_t cols, + matrix_use use) -> tinytc_type_t { + const auto hash = fnv1a_combine(component_ty, rows, cols, use); + const auto is_equal = [&](tinytc_type_t ty) { + const auto ct = dyn_cast(ty); + return ct && component_ty == ct->component_ty() && rows == ct->rows() && + cols == ct->cols() && use == ct->use(); + }; + const auto make = [&]() { return new coopmatrix_type(component_ty, rows, cols, use); }; + + auto &tys = component_ty->context()->cache()->coopmatrix_tys; + return tys.get(hash, is_equal, make); +} + +coopmatrix_type::coopmatrix_type(tinytc_type_t component_ty, std::int64_t rows0, std::int64_t cols0, + matrix_use use) + : tinytc_type(TK::TK_coopmatrix, component_ty->context()), + component_ty_(std::move(component_ty)), rows_{rows0}, cols_{cols0}, use_(use) { + if (!isa(*component_ty_)) { + throw status::ir_expected_number; + } + if (rows() < 0 || is_dynamic_value(rows())) { + throw status::ir_invalid_shape; + } + if (!is_positive_power_of_two(rows())) { + throw status::ir_unsupported_coopmatrix_shape; + } + if (cols() < 0 || is_dynamic_value(cols())) { + throw status::ir_invalid_shape; + } +} + +auto group_type::get(tinytc_type_t memref_ty, std::int64_t size, std::int64_t offset) + -> tinytc_type_t { + const auto hash = fnv1a_combine(memref_ty, size, offset); + const auto is_equal = [&](tinytc_type_t ty) { + const auto gt = dyn_cast(ty); + return gt && memref_ty == gt->element_ty() && size == gt->size() && offset == gt->offset(); + }; + const auto make = [&]() { return new group_type(memref_ty, size, offset); }; + + auto &tys = memref_ty->context()->cache()->group_tys; + return tys.get(hash, std::move(is_equal), std::move(make)); +} + 
+group_type::group_type(tinytc_type_t element_ty, std::int64_t size, std::int64_t offset) + : tinytc_type(TK::TK_group, element_ty->context()), element_ty_(std::move(element_ty)), + size_(size), offset_(offset) { + if (!isa(*element_ty_)) { + throw status::ir_expected_memref; + } + if (size < 0 && !is_dynamic_value(size)) { + throw status::ir_invalid_shape; + } + if (offset < 0 && !is_dynamic_value(offset)) { + throw status::ir_invalid_offset; + } +} + +memref_type::memref_type(tinytc_type_t element_ty, array_view shape, + array_view stride, address_space addrspace) + : tinytc_type(TK::TK_memref, element_ty->context()), element_ty_(element_ty), shape_(shape), + stride_(stride), addrspace_(addrspace) { + if (!isa(*element_ty_)) { + throw status::ir_expected_number; + } + if (stride_.size() != shape_.size()) { + throw status::ir_shape_stride_mismatch; + } + for (auto const &s : shape_) { + if (s < 0 && !is_dynamic_value(s)) { + throw status::ir_invalid_shape; + } + } + for (auto const &s : stride_) { + if (s < 0 && !is_dynamic_value(s)) { + throw status::ir_invalid_shape; + } + } +} + +bool memref_type::is_dynamic_shape() const { + return std::any_of(shape_.begin(), shape_.end(), is_dynamic_value); +} +bool memref_type::is_dynamic_stride() const { + return std::any_of(stride_.begin(), stride_.end(), is_dynamic_value); +} +bool memref_type::is_dynamic() const { return is_dynamic_shape() || is_dynamic_stride(); } +bool memref_type::is_canonical_stride() const { return stride_ == canonical_stride(shape_); } + +auto memref_type::element_alignment() const -> std::int32_t { + return ::tinytc::alignment(element_ty()); +} +auto memref_type::size_in_bytes() const -> std::int64_t { + if (is_dynamic()) { + return dynamic; + } + std::size_t s = size(element_ty()); + if (dim() > 0) { + s *= stride_.back() * shape_.back(); + } + return s; +} + +auto memref_type::get(tinytc_type_t element_ty, array_view shape, + array_view stride, address_space addrspace) -> tinytc_type_t { + + 
auto stride_buffer = std::vector{}; + if (stride.empty()) { + stride_buffer = canonical_stride(shape); + stride = array_view{stride_buffer}; + } + + const auto hash = fnv1a_combine(element_ty, shape, stride, addrspace); + const auto is_equal = [&](tinytc_type_t ty) { + const auto mt = dyn_cast(ty); + return mt && element_ty == mt->element_ty() && addrspace == mt->addrspace() && + std::equal(shape.begin(), shape.end(), mt->shape().begin(), mt->shape().end()) && + std::equal(stride.begin(), stride.end(), mt->stride().begin(), mt->stride().end()); + }; + const auto make = [&]() { + if (!stride_buffer.empty()) { + return new memref_type(element_ty, shape, std::move(stride_buffer), addrspace); + } + return new memref_type(element_ty, shape, stride, addrspace); + }; + + auto &tys = element_ty->context()->cache()->memref_tys; + return tys.get(hash, std::move(is_equal), std::move(make)); +} + +auto memref_type::canonical_stride(array_view shape) -> std::vector { + if (shape.empty()) { + return {}; + } + auto stride = std::vector(shape.size(), dynamic); + stride[0] = 1; + for (std::size_t i = 0; i < shape.size() - 1 && !is_dynamic_value(shape[i]); ++i) { + stride[i + 1] = stride[i] * shape[i]; + } + return stride; +} + +number_type::number_type(TK tid, tinytc_compiler_context_t ctx) : tinytc_type(tid, ctx) {} + +integer_type::integer_type(TK tid, tinytc_compiler_context_t ctx) : number_type(tid, ctx) {} +i8_type::i8_type(tinytc_compiler_context_t ctx) : integer_type(TK::TK_i8, ctx) {} +i16_type::i16_type(tinytc_compiler_context_t ctx) : integer_type(TK::TK_i16, ctx) {} +i32_type::i32_type(tinytc_compiler_context_t ctx) : integer_type(TK::TK_i32, ctx) {} +i64_type::i64_type(tinytc_compiler_context_t ctx) : integer_type(TK::TK_i64, ctx) {} +index_type::index_type(tinytc_compiler_context_t ctx) : integer_type(TK::TK_index, ctx) {} + +float_type::float_type(TK tid, tinytc_compiler_context_t ctx) : number_type(tid, ctx) {} +bf16_type::bf16_type(tinytc_compiler_context_t ctx) : 
float_type(TK::TK_bf16, ctx) {} +f16_type::f16_type(tinytc_compiler_context_t ctx) : float_type(TK::TK_f16, ctx) {} +f32_type::f32_type(tinytc_compiler_context_t ctx) : float_type(TK::TK_f32, ctx) {} +f64_type::f64_type(tinytc_compiler_context_t ctx) : float_type(TK::TK_f64, ctx) {} + +complex_type::complex_type(TK tid, tinytc_compiler_context_t ctx) : number_type(tid, ctx) {} +c32_type::c32_type(tinytc_compiler_context_t ctx) : complex_type(TK::TK_c32, ctx) {} +c64_type::c64_type(tinytc_compiler_context_t ctx) : complex_type(TK::TK_c64, ctx) {} + +auto i8_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->i8_ty.get(); +} +auto i16_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->i16_ty.get(); +} +auto i32_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->i32_ty.get(); +} +auto i64_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->i64_ty.get(); +} +auto index_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->index_ty.get(); +} +auto bf16_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->bf16_ty.get(); +} +auto f16_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->f16_ty.get(); +} +auto f32_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->f32_ty.get(); +} +auto f64_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->f64_ty.get(); +} +auto c32_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->c32_ty.get(); +} +auto c64_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->c64_ty.get(); +} + +void_type::void_type(tinytc_compiler_context_t ctx) : tinytc_type(TK::TK_void, ctx) {} + +auto void_type::get(tinytc_compiler_context_t ctx) -> tinytc_type_t { + return ctx->cache()->void_ty.get(); +} + +// もち type_cpp 
"tinytc/types.anko" + +} // namespace tinytc + +extern "C" { +// もち api_builder_cpp "tinytc/types.anko" + +tinytc_status_t tinytc_type_get_compiler_context(const_tinytc_type_t ty, + tinytc_compiler_context_t *ctx) { + if (ty == nullptr || ctx == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ctx = ty->context(); }); +} +} diff --git a/src/node/type.hpp.mochi b/src/node/type.hpp.mochi new file mode 100644 index 00000000..483995e0 --- /dev/null +++ b/src/node/type.hpp.mochi @@ -0,0 +1,37 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef TYPE_20250626_HPP +#define TYPE_20250626_HPP + +#include "tinytc/core.hpp" +#include "tinytc/types.h" + +#include +#include + +namespace tinytc { +enum class address_space; +enum class matrix_use; +enum class TK; +} // namespace tinytc + +struct tinytc_type { + public: + inline tinytc_type(tinytc::TK tid, tinytc_compiler_context_t ctx) : tid_(tid), ctx_(ctx) {} + virtual ~tinytc_type() = default; + inline auto type_id() const -> tinytc::TK { return tid_; } + inline auto context() const -> tinytc_compiler_context_t { return ctx_; } + + private: + tinytc::TK tid_; + tinytc_compiler_context_t ctx_; +}; + +namespace tinytc { + +// もち type_hpp "tinytc/types.anko" + +} // namespace tinytc + +#endif // TYPE_20250626_HPP diff --git a/src/node/value.cpp b/src/node/value.cpp new file mode 100644 index 00000000..b3549537 --- /dev/null +++ b/src/node/value.cpp @@ -0,0 +1,136 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "node/value.hpp" +#include "error.hpp" +#include "tinytc/builder.h" +#include "tinytc/types.h" + +#include +#include + +using namespace tinytc; + +tinytc_value::tinytc_value(tinytc_type_t ty, tinytc_inst_t def_inst, location const &lc) + : ty_{std::move(ty)}, loc_{lc}, def_inst_{def_inst} {} + +tinytc_value::~tinytc_value() { + assert(!has_uses() && "Destructor called for value that 
still has uses"); +} + +auto tinytc_value::use_begin() -> use_iterator { return {first_use_}; } +auto tinytc_value::use_end() -> use_iterator { return {nullptr}; } +auto tinytc_value::uses() -> iterator_range_wrapper { + return {use_begin(), use_end()}; +} +auto tinytc_value::use_begin() const -> const_use_iterator { return {first_use_}; } +auto tinytc_value::use_end() const -> const_use_iterator { return {nullptr}; } +auto tinytc_value::uses() const -> iterator_range_wrapper { + return {use_begin(), use_end()}; +} +auto tinytc_value::has_uses() const -> bool { return first_use_ != nullptr; } + +namespace tinytc { + +use::use(tinytc_inst_t owner) : owner_{owner} {} + +use::~use() { + if (value_) { + remove_use_from_current_list(); + } +} + +use &use::operator=(tinytc_value_t val) { + set(val); + return *this; +} + +void use::set(tinytc_value_t value) { + if (value_) { + remove_use_from_current_list(); + } + value_ = value; + if (value_) { + add_use_to_list(&value_->first_use_); + } +} + +/* + * Let next = &A.n and we have + * + * ...A|.p|.n-->B|.p|.n-->C|.p|.n... + * ...----| ^-------| ^-------| + * + * After inserting T (T = this) we want the following new or adjusted pointers + * + * ...A|.p|.n==>T|.p|.n==>B|.p|.n-->C|.p|.n... + * ...---| ^======| ^======| ^------| + * + * We need to set + * next_ = T.n -> B = *next + * next_->prev_ = B.p -> &T.n = &next_ + * prev_ = T.p -> &A.n = next + * *next = A.n -> T = this + */ +void use::add_use_to_list(use **next) { + next_ = *next; + if (next_) { + next_->prev_ = &next_; + } + prev_ = next; + *next = this; +} + +/* + * We want to remove T (T = this): + * + * ...A|.p|.n-->T|.p|.n-->B|.p|.n-->C|.p|.n... + * ...---| ^------| ^------| ^------| + * + * After removing T we want the following adjusted pointers + * + * ...A|.p|.n==>B|.p|.n-->C|.p|.n... 
+ * ...---| ^======| ^------| + * + * We need to set + * next_->prev_ = B.p -> &A.n = prev_ + * *prev_ = A.n -> B = next_ + */ +void use::remove_use_from_current_list() { + if (next_) { + next_->prev_ = prev_; + } + *prev_ = next_; +} + +} // namespace tinytc + +extern "C" { +tinytc_status_t tinytc_value_set_name(tinytc_value_t vl, char const *name) { + if (vl == nullptr || name == nullptr) { /* std::string(name) must not be given a null pointer */ + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { vl->name(std::string(name)); }); +} + +tinytc_status_t tinytc_value_set_name_n(tinytc_value_t vl, size_t name_length, char const *name) { + if (vl == nullptr || (name_length > 0 && name == nullptr)) { /* a null name is only accepted together with length 0 */ + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { vl->name(std::string(name, name_length)); }); +} + +tinytc_status_t tinytc_value_get_name(const_tinytc_value_t vl, char const **name) { + if (vl == nullptr || name == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *name = vl->name(); }); +} + +tinytc_status_t tinytc_value_get_type(const_tinytc_value_t vl, tinytc_type_t *ty) { + if (vl == nullptr || ty == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { *ty = vl->ty(); }); +} +} diff --git a/src/node/value.hpp b/src/node/value.hpp new file mode 100644 index 00000000..b9ca4890 --- /dev/null +++ b/src/node/value.hpp @@ -0,0 +1,141 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef VALUE_20250626_HPP +#define VALUE_20250626_HPP + +#include "node/type.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc { +class use; +class use_iterator; +class const_use_iterator; +}; // namespace tinytc + +struct alignas(8) tinytc_value final { + public: + tinytc_value(tinytc_type_t ty = nullptr, tinytc_inst_t def_inst_ = nullptr, + tinytc::location const &lc = {}); + ~tinytc_value(); + + 
tinytc_value(tinytc_value const &) = delete; + tinytc_value(tinytc_value &&) = default; + tinytc_value &operator=(tinytc_value const &) = delete; + tinytc_value &operator=(tinytc_value &&) = default; + + inline auto loc() const noexcept -> tinytc::location const & { return loc_; } + inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } + + inline auto ty() const -> tinytc_type_t { return ty_; } + + inline auto context() const -> tinytc_compiler_context_t { return ty_->context(); } + + inline auto name() const -> char const * { return name_.c_str(); } + inline void name(std::string name) { name_ = std::move(name); } + auto has_name() const -> bool { return !name_.empty(); } + + auto use_begin() -> tinytc::use_iterator; + auto use_end() -> tinytc::use_iterator; + auto uses() -> tinytc::iterator_range_wrapper; + auto use_begin() const -> tinytc::const_use_iterator; + auto use_end() const -> tinytc::const_use_iterator; + auto uses() const -> tinytc::iterator_range_wrapper; + auto has_uses() const -> bool; + + // Can be nullptr, e.g. 
if value is a region parameter + inline auto defining_inst() const -> tinytc_inst_t { return def_inst_; } + inline void defining_inst(tinytc_inst_t def_inst) { def_inst_ = def_inst; } + + private: + tinytc_type_t ty_; + tinytc::location loc_; + tinytc_inst_t def_inst_ = nullptr; + std::string name_; + + friend class tinytc::use; + tinytc::use *first_use_ = nullptr; +}; + +namespace tinytc { + +class alignas(8) use final { + public: + use() = default; + use(tinytc_inst_t owner); + ~use(); + + use(use &&other) = delete; + use(use const &other) = delete; + use &operator=(use &&other) = delete; + use &operator=(use const &other) = delete; + + use &operator=(tinytc_value_t val); + + inline auto get() -> tinytc_value_t { return value_; } + inline auto get() const -> const_tinytc_value_t { return value_; } + void set(tinytc_value_t value); + + inline auto owner() const -> tinytc_inst_t { return owner_; } + inline void owner(tinytc_inst_t owner) { owner_ = owner; } + + inline auto next() -> use * { return next_; } + inline auto next() const -> use const * { return next_; } + + private: + void add_use_to_list(use **next); + void remove_use_from_current_list(); + + tinytc_inst_t owner_ = nullptr; + tinytc_value_t value_ = nullptr; + use **prev_ = nullptr; + use *next_ = nullptr; +}; + +namespace detail { +template class use_iterator_base { + public: + using value_type = std::conditional_t; + using pointer = value_type *; + using reference = value_type &; + using difference_type = std::ptrdiff_t; + + use_iterator_base() : pos_{nullptr} {} + use_iterator_base(pointer pos) : pos_{std::move(pos)} {} + + auto operator*() const -> reference { return *pos_; } + auto operator->() const -> pointer { return pos_; } + auto operator++() -> Derived & { + pos_ = pos_->next(); + return *static_cast(this); + } + auto operator++(int) -> Derived { + auto old_pos = pos_; + pos_ = pos_->next(); + return Derived{old_pos}; + } + auto operator==(use_iterator_base const &other) const -> bool { 
return pos_ == other.pos_; } + auto operator!=(use_iterator_base const &other) const -> bool { return pos_ != other.pos_; } + + private: + pointer pos_; +}; +} // namespace detail + +class use_iterator : public detail::use_iterator_base {}; +class const_use_iterator : public detail::use_iterator_base {}; + +static_assert(std::forward_iterator); +static_assert(std::forward_iterator); + +} // namespace tinytc + +#endif // VALUE_20250626_HPP diff --git a/src/node/value_node.hpp b/src/node/value_node.hpp deleted file mode 100644 index 461bbf8d..00000000 --- a/src/node/value_node.hpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef VALUE_NODE_20230309_HPP -#define VALUE_NODE_20230309_HPP - -#include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" - -#include - -#include -#include -#include - -namespace tinytc { -using value_nodes = clir::virtual_type_list; -} - -struct tinytc_value : tinytc::reference_counted, tinytc::value_nodes { - public: - inline auto loc() const noexcept -> tinytc::location const & { return loc_; } - inline void loc(tinytc::location const &loc) noexcept { loc_ = loc; } - - virtual tinytc::data_type ty() const = 0; - virtual void ty(tinytc::data_type ty) = 0; - virtual auto name() const -> char const * = 0; - virtual void name(std::string name) = 0; - virtual auto has_name() const -> bool = 0; - - private: - tinytc::location loc_; -}; - -namespace tinytc { - -using value_node = ::tinytc_value; - -class float_imm : public clir::visitable { - public: - inline float_imm(double v, scalar_type ty = scalar_type::f64) - : ty_{make_scalar(ty)}, value_(v) {} - - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } - inline auto name() const -> char const * override { return ""; } - inline void name(std::string) override {} - auto has_name() const -> bool override { return false; } - - inline double value() const { 
return value_; } - - private: - data_type ty_; - double value_; -}; - -class int_imm : public clir::visitable { - public: - inline int_imm(std::int64_t v, scalar_type ty = scalar_type::i64) - : ty_{make_scalar(ty)}, value_(v) {} - - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } - inline auto name() const -> char const * override { return ""; } - inline void name(std::string) override {} - auto has_name() const -> bool override { return false; } - - inline std::int64_t value() const { return value_; } - - private: - data_type ty_; - std::int64_t value_; -}; - -class val : public clir::visitable { - public: - inline val(data_type ty) : ty_(std::move(ty)) {} - - inline data_type ty() const override { return ty_; } - inline void ty(data_type ty) override { ty_ = std::move(ty); } - inline auto name() const -> char const * override { return name_.c_str(); } - inline void name(std::string name) override { name_ = std::move(name); } - virtual auto has_name() const -> bool override { return !name_.empty(); } - - private: - data_type ty_; - std::string name_; -}; - -} // namespace tinytc - -#endif // VALUE_NODE_20230309_HPP diff --git a/src/node/visit.hpp.mochi b/src/node/visit.hpp.mochi new file mode 100644 index 00000000..db5c71c2 --- /dev/null +++ b/src/node/visit.hpp.mochi @@ -0,0 +1,42 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef VISIT_HPP_20250618_MOCHI +#define VISIT_HPP_20250618_MOCHI + +#include "node/attr.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" + +namespace tinytc { + +template auto visit(Visitor &&visitor, tinytc_attr &a) { + switch (a.type_id()) { + case AK::AK_array: { + return visitor(*static_cast(&a)); + } + case AK::AK_boolean: { + return visitor(*static_cast(&a)); + } + case AK::AK_dictionary: { + return visitor(*static_cast(&a)); + } + case AK::AK_integer: { + return visitor(*static_cast(&a)); + } + case 
AK::AK_string: { + return visitor(*static_cast(&a)); + } + default: + break; + } + throw status::internal_compiler_error; +} + +// もち visit_hpp "tinytc/types.anko" + +// もち visit_hpp "tinytc/instructions.anko" + +} // namespace tinytc + +#endif // VISIT_HPP_20250618_MOCHI diff --git a/src/number.cpp b/src/number.cpp new file mode 100644 index 00000000..e8ada48a --- /dev/null +++ b/src/number.cpp @@ -0,0 +1,103 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "number.hpp" +#include "compiler_context.hpp" +#include "error.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/overloaded.hpp" + +#include + +namespace tinytc { + +auto acc_type(tinytc_type_t ty) -> tinytc_type_t { + return visit( + overloaded{[](i8_type &ty) -> tinytc_type_t { return i32_type::get(ty.context()); }, + [](bf16_type &ty) -> tinytc_type_t { return f32_type::get(ty.context()); }, + [](f16_type &ty) -> tinytc_type_t { return f32_type::get(ty.context()); }, + [](tinytc_type &ty) -> tinytc_type_t { return &ty; }}, + *ty); +} + +auto component_count(tinytc_type_t ty) -> vector_size { + return isa(*ty) ? 
vector_size::v2 : vector_size::v1; +} +auto component_type(tinytc_type_t ty) -> tinytc_type_t { + if (isa(*ty)) { + return f32_type::get(ty->context()); + } else if (isa(*ty)) { + return f64_type::get(ty->context()); + } else if (isa(*ty)) { + return ty; + } + // only call component type for number types + throw status::ir_expected_number; +} + +auto promotable(tinytc_type_t a_ty, tinytc_type_t b_ty) -> bool { + if (a_ty == b_ty) { + return true; + } + const auto a_cc = static_cast(component_count(a_ty)); + const auto b_cc = static_cast(component_count(b_ty)); + const auto a_ct = component_type(a_ty); + const auto b_ct = component_type(b_ty); + return (isa(*a_ct) || !isa(*b_ct)) && + (size(a_ct) < size(b_ct) || a_ct == b_ct) && a_cc <= b_cc; +} + +auto promote(tinytc_type_t a_ty, tinytc_type_t b_ty) -> tinytc_type_t { + if (promotable(a_ty, b_ty)) { + return b_ty; + } else if (promotable(b_ty, a_ty)) { + return a_ty; + } + return nullptr; +} + +auto promote_or_throw(tinytc_type_t a_ty, tinytc_type_t b_ty, location const &loc) + -> tinytc_type_t { + if (auto res = promote(a_ty, b_ty); res) { + return res; + } + throw compilation_error(loc, status::ir_forbidden_promotion); +} + +auto alignment(tinytc_type_t ty, vector_size count) -> std::int32_t { + const std::int32_t scale = count == vector_size::v3 ? 
4 : static_cast(count); + return scale * size(ty); +} + +auto is_cast_allowed(tinytc_type_t from_ty, tinytc_type_t to_ty) -> bool { + return isa(*from_ty) && isa(*to_ty) && + (!isa(*from_ty) || isa(*to_ty)); +} + +auto size(tinytc_type_t ty) -> std::size_t { + return visit(overloaded{[](i8_type &) -> std::size_t { return 1; }, // + [](i16_type &) -> std::size_t { return 2; }, // + [](i32_type &) -> std::size_t { return 4; }, // + [](i64_type &) -> std::size_t { return 8; }, // + [](index_type &ty) -> std::size_t { + return ty.context()->index_bit_width() / 8; + }, // + [](bf16_type &) -> std::size_t { return 2; }, // + [](f16_type &) -> std::size_t { return 2; }, // + [](f32_type &) -> std::size_t { return 4; }, // + [](f64_type &) -> std::size_t { return 8; }, // + [](c32_type &) -> std::size_t { return 8; }, // + [](c64_type &) -> std::size_t { return 16; }, // + [](tinytc_type &) -> std::size_t { + // only call size for number types + throw status::ir_expected_number; + }}, + *ty); +} + +} // namespace tinytc + diff --git a/src/number.hpp b/src/number.hpp new file mode 100644 index 00000000..428a3355 --- /dev/null +++ b/src/number.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef NUMBER_20250702_HPP +#define NUMBER_20250702_HPP + +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +enum class vector_size { v1 = 1, v2 = 2, v3 = 3, v4 = 4, v8 = 8, v16 = 16 }; + +auto acc_type(tinytc_type_t ty) -> tinytc_type_t; +auto component_count(tinytc_type_t ty) -> vector_size; +auto component_type(tinytc_type_t ty) -> tinytc_type_t; +auto promotable(tinytc_type_t a_ty, tinytc_type_t b_ty) -> bool; +auto promote(tinytc_type_t a_ty, tinytc_type_t b_ty) -> tinytc_type_t; +auto promote_or_throw(tinytc_type_t a_ty, tinytc_type_t b_ty, location const &loc) -> tinytc_type_t; +auto is_cast_allowed(tinytc_type_t from_ty, tinytc_type_t to_ty) -> bool; +auto 
alignment(tinytc_type_t ty, vector_size count = vector_size::v1) -> std::int32_t; +auto size(tinytc_type_t ty) -> std::size_t; +auto bit_width(tinytc_type_t ty) -> std::size_t; + +} // namespace tinytc + +#endif // NUMBER_20250702_HPP diff --git a/src/number_dispatch.hpp b/src/number_dispatch.hpp new file mode 100644 index 00000000..1bf7fe97 --- /dev/null +++ b/src/number_dispatch.hpp @@ -0,0 +1,103 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause +// +#ifndef NUMBER_DISPATCH_20250702_HPP +#define NUMBER_DISPATCH_20250702_HPP + +#include "compiler_context.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include "util/casting.hpp" +#include "util/overloaded.hpp" + +#include + +namespace tinytc { + +template +auto dispatch_int_to_native(tinytc_type_t ty, F &&f, T &&...args) { + using RetT = + std::common_type_t(std::declval()...)), + decltype(f.template operator()(std::declval()...)), + decltype(f.template operator()(std::declval()...)), + decltype(f.template operator()(std::declval()...))>; + return visit( + overloaded{[&](i8_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [&](i16_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [&](i32_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [&](i64_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [&](index_type &ty) -> RetT { + const auto idx_width = ty.context()->index_bit_width(); + if (idx_width == 64) { + return f.template operator()(std::forward(args)...); + } else if (idx_width == 32) { + return f.template operator()(std::forward(args)...); + } + throw status::not_implemented; + }, + [](tinytc_type &) -> RetT { throw status::ir_expected_int; }}, + *ty); +} + +template +auto dispatch_float_to_native(tinytc_type_t ty, F &&f, T &&...args) { + using RetT = + std::common_type_t(std::declval()...)), + decltype(f.template operator()(std::declval()...)), 
+ decltype(f.template operator()(std::declval()...)), + decltype(f.template operator()(std::declval()...))>; + return visit(overloaded{[&](bf16_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [&](f16_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [&](f32_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [&](f64_type &) -> RetT { + return f.template operator()(std::forward(args)...); + }, + [](tinytc_type &) -> RetT { throw status::ir_expected_float; }}, + *ty); +} + +template +auto dispatch_complex_to_native(tinytc_type_t ty, F &&f, T &&...args) { + using RetT = std::common_type_t< + decltype(f.template operator()>(std::declval()...)), + decltype(f.template operator()>(std::declval()...))>; + return visit( + overloaded{[&](c32_type &) -> RetT { + return f.template operator()>(std::forward(args)...); + }, + [&](c64_type &) -> RetT { + return f.template operator()>(std::forward(args)...); + }, + [](tinytc_type &) -> RetT { throw status::ir_expected_complex; }}, + *ty); +} + +template +auto dispatch_number_to_native(tinytc_type_t ty, F &&f, T &&...args) { + if (isa(*ty)) { + return dispatch_int_to_native(ty, std::forward(f), std::forward(args)...); + } else if (isa(*ty)) { + return dispatch_float_to_native(ty, std::forward(f), std::forward(args)...); + } else if (isa(*ty)) { + return dispatch_complex_to_native(ty, std::forward(f), std::forward(args)...); + } + throw status::ir_expected_number; +} + +} // namespace tinytc + +#endif // NUMBER_DISPATCH_20250702_HPP diff --git a/src/parser.cpp b/src/parser.cpp index 9abe3fb7..19386734 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -3,12 +3,13 @@ #include "parser.hpp" +#include "compiler_context.hpp" #include "error.hpp" -#include "location.hpp" #include "parser/lexer.hpp" #include "parser/parse_context.hpp" #include "parser/parser_impl.hpp" -#include "tinytc/tinytc.h" +#include "tinytc/core.h" +#include "tinytc/core.hpp" 
#include "tinytc/types.h" #include "tinytc/types.hpp" @@ -17,69 +18,32 @@ #include #include #include -#include -#include +#include #include namespace tinytc { -auto parse(std::uint64_t size, char const *input) -> prog { - auto const initial_loc = location{position{0, 1, 1}, position{0, 1, 1}}; - auto lex = lexer(size, input, initial_loc); - auto ctx = parse_context{}; - auto p = parser(lex, ctx); - if (p() == 0) { - return ctx.program(); - } - return prog{}; -} -} // namespace tinytc - -using namespace tinytc; -extern "C" { - -tinytc_source_context::tinytc_source_context() {} - -auto tinytc_source_context::parse(std::string name, std::string text) -> prog { - sources_.emplace_back(source_input{std::move(name), std::move(text)}); - std::int32_t source_id = static_cast(sources_.size()); +auto parse(std::string name, std::string text, + shared_handle compiler_ctx) -> shared_handle { + std::int32_t source_id = compiler_ctx->add_source(std::move(name), std::move(text)); auto const initial_loc = location{position{source_id, 1, 1}, position{source_id, 1, 1}}; - auto const &input = sources_.back(); - auto lex = lexer(input.text.size(), input.text.c_str(), initial_loc); - auto ctx = parse_context{}; - auto p = parser(lex, ctx); + auto [ir, ir_size] = compiler_ctx->source_text(source_id); + auto lex = lexer(ir_size, ir, initial_loc); + auto parse_ctx = parse_context{std::move(compiler_ctx)}; + auto p = parser(lex, parse_ctx); if (p() == 0) { - return ctx.program(); - } - last_error_log_.clear(); - for (auto const &err : ctx.errors()) { - last_error_log_ = report_error_with_context(input.text.c_str(), input.text.size(), - input.name, err.first, err.second); + return parse_ctx.program(); } - return prog{}; + return {}; } -void tinytc_source_context::report_error(location const &l, char const *what, bool append) { - auto err = std::string{}; - if (l.begin.source_id >= 1 && static_cast(l.begin.source_id) <= sources_.size()) { - auto const &src = sources_[l.begin.source_id - 1]; 
- err = report_error_with_context(src.text.c_str(), src.text.size(), src.name, l, what); - } else { - err = (std::ostringstream{} << "\n" - << l << ": " << what) - .str(); - } - if (append) { - last_error_log_ += std::move(err); - } else { - last_error_log_ = std::move(err); - } -} +} // namespace tinytc + +using namespace tinytc; tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, - tinytc_source_context_t source_ctx) { + tinytc_compiler_context_t ctx) { if (prg == nullptr || filename == nullptr) { return tinytc_status_invalid_arguments; } @@ -89,9 +53,8 @@ tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, throw status::file_io_error; } auto ir = std::string(std::istreambuf_iterator{ir_stream}, {}); - - auto prog = source_ctx ? source_ctx->parse(std::string(filename), std::move(ir)) - : parse(ir.size(), ir.c_str()); + auto ctx_ = ctx ? shared_handle{ctx, true} : create_compiler_context(); + auto prog = parse(std::string(filename), std::move(ir), ctx_); if (!prog) { throw status::parse_error; } @@ -99,14 +62,14 @@ tinytc_status_t tinytc_parse_file(tinytc_prog_t *prg, char const *filename, }); } -tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t source_ctx) { +tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_compiler_context_t ctx) { if (prg == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { auto ir = std::string(std::istreambuf_iterator{std::cin}, {}); - auto prog = - source_ctx ? source_ctx->parse("", std::move(ir)) : parse(ir.size(), ir.c_str()); + auto ctx_ = ctx ? 
shared_handle{ctx, true} : create_compiler_context(); + auto prog = parse("", std::move(ir), ctx_); if (!prog) { throw status::parse_error; } @@ -115,71 +78,16 @@ tinytc_status_t tinytc_parse_stdin(tinytc_prog_t *prg, tinytc_source_context_t s } tinytc_status_t tinytc_parse_string(tinytc_prog_t *prg, size_t source_size, char const *source, - tinytc_source_context_t source_ctx) { + tinytc_compiler_context_t ctx) { if (prg == nullptr || source_size == 0 || source == nullptr) { return tinytc_status_invalid_arguments; } return exception_to_status_code([&] { - auto prog = source_ctx - ? source_ctx->parse("", std::string(source, source + source_size)) - : parse(source_size, source); + auto ctx_ = ctx ? shared_handle{ctx, true} : create_compiler_context(); + auto prog = parse("", std::string(source, source + source_size), ctx_); if (!prog) { throw status::parse_error; } *prg = prog.release(); }); } - -tinytc_status_t tinytc_source_context_create(tinytc_source_context_t *ctx) { - if (ctx == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code( - [&] { *ctx = std::make_unique().release(); }); -} - -tinytc_status_t tinytc_source_context_add_source(tinytc_source_context_t ctx, char const *name, - char const *text, int32_t *source_id) { - if (ctx == nullptr || name == nullptr || text == nullptr || source_id == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *source_id = ctx->add_source(name, text); }); -} - -tinytc_status_t tinytc_source_context_get_error_log(const_tinytc_source_context_t ctx, - char const **log) { - if (ctx == nullptr || log == nullptr) { - return tinytc_status_invalid_arguments; - } - *log = ctx->last_error_log().c_str(); // last_error_log and c_str are noexcept - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_context_report_error(tinytc_source_context_t ctx, - const tinytc_location_t *location, - char const *what, tinytc_bool_t append) { - if (ctx == 
nullptr || location == nullptr || what == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { ctx->report_error(*location, what, bool(append)); }); -} - -tinytc_status_t tinytc_source_context_release(tinytc_source_context_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_context_retain(tinytc_source_context_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/parser.hpp b/src/parser.hpp index 0d1cd1a9..02b1b841 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -4,48 +4,13 @@ #ifndef PARSER_20230614_HPP #define PARSER_20230614_HPP -#include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.h" +#include "tinytc/types.hpp" #include -#include -#include namespace tinytc { -auto parse(std::uint64_t size, char const *input) -> prog; +auto parse(std::uint64_t size, char const *input) -> shared_handle; } -/** - * @brief Source manager - * - * The source manager can parse tensor programs from files, stdin, or memory. - * Source code is stored in the manager such that error messages can be enhanced - * with code context. - */ -struct tinytc_source_context : tinytc::reference_counted { - public: - //! @brief ctor - tinytc_source_context(); - - auto parse(std::string name, std::string text) -> tinytc::prog; - - inline auto add_source(char const *name, char const *text) -> std::int32_t { - sources_.emplace_back(source_input{std::string(name), std::string(text)}); - return static_cast(sources_.size()); - } - - //! Annotate context to error message - void report_error(tinytc_location const &l, char const *what, bool append = false); - //! 
Return error log of last parse call - inline auto last_error_log() const noexcept -> std::string const & { return last_error_log_; } - - private: - struct source_input { - std::string name, text; - }; - std::vector sources_; - std::string last_error_log_; -}; - #endif // PARSER_20230614_HPP diff --git a/src/parser/lexer.hpp b/src/parser/lexer.hpp index b49ad4d4..ad30db76 100644 --- a/src/parser/lexer.hpp +++ b/src/parser/lexer.hpp @@ -26,8 +26,6 @@ class lexer { std::uint64_t lex_number(char const *s, char const *e); std::int64_t lex_integer_constant(char const *s, char const *e); double lex_floating_constant(char const *s, char const *e); - scalar_type lex_integer_type(char const *s, char const *e); - scalar_type lex_floating_type(char const *s, char const *e); char const *input_; std::size_t len_; diff --git a/src/parser/lexer.re b/src/parser/lexer.re index e76b2b87..000ae125 100644 --- a/src/parser/lexer.re +++ b/src/parser/lexer.re @@ -25,6 +25,7 @@ lex: char const *b = YYCURSOR; step(loc_); auto const adv_loc = [&]() { columns(loc_, YYCURSOR - b); }; + // clang-format off /*!re2c re2c:define:YYCTYPE = char; re2c:yyfill:enable = 0; @@ -35,13 +36,12 @@ lex: newline = "\r"? "\n"; whitespace = [ \t\v\r]+; - identifier = [0-9]+ | ([a-zA-Z] [a-zA-Z0-9_]*); - local_identifier = "%" identifier; - global_identifier = "@" identifier; - - integer_type = "i" ("1" | "8" | "16" | "32" | "64") | "index"; - floating_type = "f" ("32" | "64"); - + unnamed_identifier = [0-9]+; + named_identifier = [a-zA-Z] [a-zA-Z0-9_]*; + local_unnamed_identifier = "%" unnamed_identifier; + local_named_identifier = "%" named_identifier; + global_identifier = "@" (unnamed_identifier | named_identifier); + string = "\"" [^\"]* "\""; digit = [0-9]; hexdigit = [0-9a-fA-F]; integer_constant = [-+]? digit+; @@ -49,12 +49,17 @@ lex: mantissa_hex = (hexdigit* "." hexdigit+ | hexdigit+ "."); exponent = [eE] [-+]? digit+; exponent_hex = [pP] [-+]? digit+; - floating_constant = [-+]? (mantissa exponent? 
| digit+ exponent); + floating_constant = [-+]? (mantissa exponent? | digit+ exponent | "inf" | "nan"); floating_constant_hex = [-+]? "0x" (mantissa_hex exponent_hex? | hexdigit+ exponent_hex); // identifier - local_identifier { + local_unnamed_identifier { + adv_loc(); + std::int64_t id = lex_integer_constant(b + 1, YYCURSOR); + return parser::make_LOCAL_IDENTIFIER(std::move(id), loc_); + } + local_named_identifier { adv_loc(); auto id = std::string(b + 1, YYCURSOR); return parser::make_LOCAL_IDENTIFIER(std::move(id), loc_); @@ -81,17 +86,28 @@ lex: // keywords "func" { adv_loc(); return parser::make_FUNC(loc_); } - "work_group_size" { adv_loc(); return parser::make_WORK_GROUP_SIZE(loc_); } - "subgroup_size" { adv_loc(); return parser::make_SUBGROUP_SIZE(loc_); } - "->" { adv_loc(); return parser::make_RETURNS(loc_); } + "->" { adv_loc(); return parser::make_ARROW(loc_); } "?" { adv_loc(); return parser::make_DYNAMIC(loc_); } - ".n" { adv_loc(); return parser::make_NOTRANS(loc_); } - ".t" { adv_loc(); return parser::make_TRANS(loc_); } + ".n" { adv_loc(); return parser::make_TRANSPOSE(transpose::N, loc_); } + ".t" { adv_loc(); return parser::make_TRANSPOSE(transpose::T, loc_); } ".atomic" { adv_loc(); return parser::make_ATOMIC(loc_); } + ".atomic_add" { adv_loc(); return parser::make_ATOMIC_ADD(loc_); } + ".atomic_max" { adv_loc(); return parser::make_ATOMIC_MAX(loc_); } + ".atomic_min" { adv_loc(); return parser::make_ATOMIC_MIN(loc_); } + "init" { adv_loc(); return parser::make_INIT(loc_); } + "local" { adv_loc(); return parser::make_LOCAL(loc_); } + "global" { adv_loc(); return parser::make_GLOBAL(loc_); } + ".local" { adv_loc(); return parser::make_LOCAL_ATTR(loc_); } + ".global" { adv_loc(); return parser::make_GLOBAL_ATTR(loc_); } + ".x" { adv_loc(); return parser::make_COMP3(comp3::x, loc_); } + ".y" { adv_loc(); return parser::make_COMP3(comp3::y, loc_); } + ".z" { adv_loc(); return parser::make_COMP3(comp3::z, loc_); } + ".row" { adv_loc(); return 
parser::make_REDUCE_MODE(reduce_mode::row, loc_); } + ".column" { adv_loc(); return parser::make_REDUCE_MODE(reduce_mode::column, loc_); } // constants - "true" { adv_loc(); return parser::make_INTEGER_CONSTANT(1, loc_); } - "false" { adv_loc(); return parser::make_INTEGER_CONSTANT(0, loc_); } + "true" { adv_loc(); return parser::make_BOOLEAN_CONSTANT(true, loc_); } + "false" { adv_loc(); return parser::make_BOOLEAN_CONSTANT(false, loc_); } integer_constant { adv_loc(); auto i = lex_integer_constant(b, YYCURSOR); @@ -103,17 +119,26 @@ lex: return parser::make_FLOATING_CONSTANT(f, loc_); } - // types - integer_type { - adv_loc(); - auto t = lex_integer_type(b, YYCURSOR); - return parser::make_INTEGER_TYPE(t, loc_); - } - floating_type { - adv_loc(); - auto t = lex_floating_type(b, YYCURSOR); - return parser::make_FLOATING_TYPE(t, loc_); + // attributes + "attributes" { adv_loc(); return parser::make_ATTRIBUTES(loc_); } + "alignment" | "shape_gcd" | "stride_gcd" | "unroll" | "work_group_size" { + adv_loc(); return parser::make_ATTR_NAME(std::string(b, YYCURSOR), loc_); } + + // types + "bool" { adv_loc(); return parser::make_BOOLEAN(loc_); } + "i8" { adv_loc(); return parser::make_I8_TYPE(loc_); } + "i16" { adv_loc(); return parser::make_I16_TYPE(loc_); } + "i32" { adv_loc(); return parser::make_I32_TYPE(loc_); } + "i64" { adv_loc(); return parser::make_I64_TYPE(loc_); } + "index" { adv_loc(); return parser::make_INDEX_TYPE(loc_); } + "bf16" { adv_loc(); return parser::make_BF16_TYPE(loc_); } + "f16" { adv_loc(); return parser::make_F16_TYPE(loc_); } + "f32" { adv_loc(); return parser::make_F32_TYPE(loc_); } + "f64" { adv_loc(); return parser::make_F64_TYPE(loc_); } + "c32" { adv_loc(); return parser::make_C32_TYPE(loc_); } + "c64" { adv_loc(); return parser::make_C64_TYPE(loc_); } + "coopmatrix" { adv_loc(); return parser::make_COOPMATRIX(loc_); } "memref" { adv_loc(); return parser::make_MEMREF(loc_); } "group" { adv_loc(); return parser::make_GROUP(loc_); } @@ -121,60 +146,118 @@ lex: 
"offset" { adv_loc(); return parser::make_OFFSET(loc_); } "strided" { adv_loc(); return parser::make_STRIDED(loc_); } + // matrix use + "matrix_a" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::a, loc_); } + "matrix_b" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::b, loc_); } + "matrix_acc" { adv_loc(); return parser::make_MATRIX_USE(matrix_use::acc, loc_); } + + // checked flag + ".rows_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::rows, loc_); } + ".cols_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::cols, loc_); } + ".both_checked" { adv_loc(); return parser::make_CHECKED(checked_flag::both, loc_); } + // instructions - "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } - "arith" { adv_loc(); return parser::make_ARITH(loc_); } - "gemm" { adv_loc(); return parser::make_GEMM(loc_); } - "gemv" { adv_loc(); return parser::make_GEMV(loc_); } - "ger" { adv_loc(); return parser::make_GER(loc_); } - "hadamard" { adv_loc(); return parser::make_HADAMARD(loc_); } - "alloca" { adv_loc(); return parser::make_ALLOCA(loc_); } - "cast" { adv_loc(); return parser::make_CAST(loc_); } - "cmp" { adv_loc(); return parser::make_CMP(loc_); } - "expand" { adv_loc(); return parser::make_EXPAND(loc_); } - "fuse" { adv_loc(); return parser::make_FUSE(loc_); } - "load" { adv_loc(); return parser::make_LOAD(loc_); } - "group_id" { adv_loc(); return parser::make_GROUP_ID(loc_); } - "group_size" { adv_loc(); return parser::make_GROUP_SIZE(loc_); } - "for" { adv_loc(); return parser::make_FOR(loc_); } - "foreach" { adv_loc(); return parser::make_FOREACH(loc_); } - "if" { adv_loc(); return parser::make_IF(loc_); } - "else" { adv_loc(); return parser::make_ELSE(loc_); } - "size" { adv_loc(); return parser::make_SIZE(loc_); } - "subview" { adv_loc(); return parser::make_SUBVIEW(loc_); } - "store" { adv_loc(); return parser::make_STORE(loc_); } - "sum" { adv_loc(); return parser::make_SUM(loc_); } - "yield" { adv_loc(); return 
parser::make_YIELD(loc_); } + "axpby" { adv_loc(); return parser::make_AXPBY(loc_); } + "barrier" { adv_loc(); return parser::make_BARRIER(loc_); } + "cumsum" { adv_loc(); return parser::make_CUMSUM(loc_); } + "gemm" { adv_loc(); return parser::make_GEMM(loc_); } + "gemv" { adv_loc(); return parser::make_GEMV(loc_); } + "ger" { adv_loc(); return parser::make_GER(loc_); } + "hadamard" { adv_loc(); return parser::make_HADAMARD(loc_); } + "alloca" { adv_loc(); return parser::make_ALLOCA(loc_); } + "cast" { adv_loc(); return parser::make_CAST(loc_); } + "constant" { adv_loc(); return parser::make_CONSTANT(loc_); } + "cooperative_matrix_apply" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_APPLY(loc_); } + "cooperative_matrix_extract" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_EXTRACT(loc_); } + "cooperative_matrix_insert" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_INSERT(loc_); } + "cooperative_matrix_load" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_LOAD(loc_); } + "cooperative_matrix_mul_add" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_MUL_ADD(loc_); } + "cooperative_matrix_prefetch" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_PREFETCH(loc_); } + "cooperative_matrix_scale" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_SCALE(loc_); } + "cooperative_matrix_store" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_STORE(loc_); } + "expand" { adv_loc(); return parser::make_EXPAND(loc_); } + "fuse" { adv_loc(); return parser::make_FUSE(loc_); } + "load" { adv_loc(); return parser::make_LOAD(loc_); } + "for" { adv_loc(); return parser::make_FOR(loc_); } + "foreach" { adv_loc(); return parser::make_FOREACH(loc_); } + "if" { adv_loc(); return parser::make_IF(loc_); } + "parallel" { adv_loc(); return parser::make_PARALLEL(loc_); } + "else" { adv_loc(); return parser::make_ELSE(loc_); } + "size" { adv_loc(); return parser::make_SIZE(loc_); } + "subgroup_broadcast" { adv_loc(); return parser::make_SUBGROUP_BROADCAST(loc_); } 
+ "subview" { adv_loc(); return parser::make_SUBVIEW(loc_); } + "store" { adv_loc(); return parser::make_STORE(loc_); } + "sum" { adv_loc(); return parser::make_SUM(loc_); } + "yield" { adv_loc(); return parser::make_YIELD(loc_); } // binary op - ".add" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::add, loc_); } - ".sub" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::sub, loc_); } - ".mul" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::mul, loc_); } - ".div" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::div, loc_); } - ".rem" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::rem, loc_); } - ".shl" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::shl, loc_); } - ".shr" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::shr, loc_); } - ".and" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::and_, loc_); } - ".or" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::or_, loc_); } - ".xor" { adv_loc(); return parser::make_ARITHMETIC(arithmetic::xor_, loc_); } + "add" { adv_loc(); return parser::make_ADD(loc_); } + "sub" { adv_loc(); return parser::make_SUB(loc_); } + "mul" { adv_loc(); return parser::make_MUL(loc_); } + "div" { adv_loc(); return parser::make_DIV(loc_); } + "rem" { adv_loc(); return parser::make_REM(loc_); } + "shl" { adv_loc(); return parser::make_SHL(loc_); } + "shr" { adv_loc(); return parser::make_SHR(loc_); } + "and" { adv_loc(); return parser::make_AND(loc_); } + "or" { adv_loc(); return parser::make_OR (loc_); } + "xor" { adv_loc(); return parser::make_XOR(loc_); } + "min" { adv_loc(); return parser::make_MIN(loc_); } + "max" { adv_loc(); return parser::make_MAX(loc_); } // unary op - ".neg" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::neg, loc_); } - ".not" { adv_loc(); return parser::make_ARITHMETIC_UNARY(arithmetic_unary::not_, loc_); } + "abs" { adv_loc(); return parser::make_ABS(loc_); } + "neg" { adv_loc(); return parser::make_NEG(loc_); } + "not" { 
adv_loc(); return parser::make_NOT(loc_); } + "conj" { adv_loc(); return parser::make_CONJ(loc_); } + "im" { adv_loc(); return parser::make_IM(loc_); } + "re" { adv_loc(); return parser::make_RE(loc_); } + + // builtin + "group_id" { adv_loc(); return parser::make_GROUP_ID(loc_); } + "num_groups" { adv_loc(); return parser::make_NUM_GROUPS(loc_); } + "num_subgroups" { adv_loc(); return parser::make_NUM_SUBGROUPS(loc_); } + "subgroup_size" { adv_loc(); return parser::make_SUBGROUP_SIZE(loc_); } + "subgroup_id" { adv_loc(); return parser::make_SUBGROUP_ID(loc_); } + "subgroup_linear_id" { adv_loc(); return parser::make_SUBGROUP_LINEAR_ID(loc_); } + "subgroup_local_id" { adv_loc(); return parser::make_SUBGROUP_LOCAL_ID(loc_); } // comparison condition - ".eq" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::eq, - loc_); } - ".ne" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ne, - loc_); } - ".gt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::gt, - loc_); } - ".ge" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::ge, - loc_); } - ".lt" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::lt, - loc_); } - ".le" { adv_loc(); return parser::make_CMP_CONDITION(cmp_condition::le, - loc_); } + "equal" { adv_loc(); return parser::make_EQUAL(loc_); } + "not_equal" { adv_loc(); return parser::make_NOT_EQUAL(loc_); } + "greater_than" { adv_loc(); return parser::make_GREATER_THAN(loc_); } + "greater_than_equal" { adv_loc(); return parser::make_GREATER_THAN_EQUAL(loc_); } + "less_than" { adv_loc(); return parser::make_LESS_THAN(loc_); } + "less_than_equal" { adv_loc(); return parser::make_LESS_THAN_EQUAL(loc_); } + + // math op + "cos" { adv_loc(); return parser::make_COS(loc_); } + "sin" { adv_loc(); return parser::make_SIN(loc_); } + "exp" { adv_loc(); return parser::make_EXP(loc_); } + "exp2" { adv_loc(); return parser::make_EXP2(loc_); } + "native_cos" { adv_loc(); return parser::make_NATIVE_COS(loc_); } + 
"native_sin" { adv_loc(); return parser::make_NATIVE_SIN(loc_); } + "native_exp" { adv_loc(); return parser::make_NATIVE_EXP(loc_); } + "native_exp2" { adv_loc(); return parser::make_NATIVE_EXP2(loc_); } + + // coopmatrix reduce + "cooperative_matrix_reduce_add" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE_ADD(loc_); } + "cooperative_matrix_reduce_max" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE_MAX(loc_); } + "cooperative_matrix_reduce_min" { adv_loc(); return parser::make_COOPERATIVE_MATRIX_REDUCE_MIN(loc_); } + + // subgroup op + "subgroup_exclusive_scan_add" { adv_loc(); return parser::make_SUBGROUP_EXCLUSIVE_SCAN_ADD(loc_); } + "subgroup_exclusive_scan_max" { adv_loc(); return parser::make_SUBGROUP_EXCLUSIVE_SCAN_MAX(loc_); } + "subgroup_exclusive_scan_min" { adv_loc(); return parser::make_SUBGROUP_EXCLUSIVE_SCAN_MIN(loc_); } + "subgroup_inclusive_scan_add" { adv_loc(); return parser::make_SUBGROUP_INCLUSIVE_SCAN_ADD(loc_); } + "subgroup_inclusive_scan_max" { adv_loc(); return parser::make_SUBGROUP_INCLUSIVE_SCAN_MAX(loc_); } + "subgroup_inclusive_scan_min" { adv_loc(); return parser::make_SUBGROUP_INCLUSIVE_SCAN_MIN(loc_); } + "subgroup_reduce_add" { adv_loc(); return parser::make_SUBGROUP_REDUCE_ADD(loc_); } + "subgroup_reduce_max" { adv_loc(); return parser::make_SUBGROUP_REDUCE_MAX(loc_); } + "subgroup_reduce_min" { adv_loc(); return parser::make_SUBGROUP_REDUCE_MIN(loc_); } + + // other strings + string { adv_loc(); return parser::make_STRING(std::string(b+1, YYCURSOR-1), loc_); } + whitespace { adv_loc(); goto lex; } comment { adv_loc(); goto lex; } @@ -186,6 +269,7 @@ lex: throw parser::syntax_error(loc_, "Unknown token"); } */ + // clang-format on } std::uint64_t lexer::lex_number(char const *s, char const *e) { @@ -235,35 +319,4 @@ double lexer::lex_floating_constant(char const *s, char const *e) { return d; } -scalar_type lexer::lex_integer_type(char const *s, char const *) { - char const *YYMARKER; - /*!re2c - 
re2c:yyfill:enable = 0; - re2c:define:YYCURSOR = s; - - "i1" { return scalar_type::i1; } - "i8" { return scalar_type::i8; } - "i16" { return scalar_type::i16; } - "i32" { return scalar_type::i32; } - "i64" { return scalar_type::i64; } - "index" { return scalar_type::index; } - $ { return {}; } - * { return {}; } - */ - return scalar_type{}; -} -scalar_type lexer::lex_floating_type(char const *s, char const *) { - char const *YYMARKER; - /*!re2c - re2c:yyfill:enable = 0; - re2c:define:YYCURSOR = s; - - "f32" { return scalar_type::f32; } - "f64" { return scalar_type::f64; } - $ { return {}; } - * { return {}; } - */ - return scalar_type{}; -} - } // namespace tinytc diff --git a/src/parser/parse_context.cpp b/src/parser/parse_context.cpp index 336c0226..f830ec4a 100644 --- a/src/parser/parse_context.cpp +++ b/src/parser/parse_context.cpp @@ -2,9 +2,9 @@ // SPDX-License-Identifier: BSD-3-Clause #include "parse_context.hpp" +#include "compiler_context.hpp" #include "location.hpp" -#include "node/function_node.hpp" -#include "node/value_node.hpp" +#include "node/value.hpp" #include "parser/parser_impl.hpp" #include @@ -12,48 +12,81 @@ namespace tinytc { -void parse_context::push_scope() { id_map_.push_back({}); } -void parse_context::pop_scope() { id_map_.pop_back(); } +parse_context::parse_context(shared_handle compiler_ctx) + : compiler_ctx_(std::move(compiler_ctx)) {} -void parse_context::val(std::string const &id, value val, location const &l) { - for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { - if (auto other = it->find(id); other != it->end()) { - auto oss = std::ostringstream{}; - oss << "Identifier %" << id << " was already used at " << other->second->loc(); - throw parser::syntax_error(l, oss.str()); - } - } - val->loc(l); - id_map_.back()[id] = std::move(val); +void parse_context::push_scope() { + unnamed_id_map_.push_back({}); + named_id_map_.push_back({}); } +void parse_context::pop_scope() { + named_id_map_.pop_back(); + 
unnamed_id_map_.pop_back(); +} + +void parse_context::push_region(tinytc_region_t r) { regions_.push(r); } +void parse_context::pop_region() { regions_.pop(); } +auto parse_context::top_region() -> tinytc_region_t { return regions_.top(); } +auto parse_context::has_regions() -> bool { return !regions_.empty(); } -value parse_context::val(std::string const &id, location const &l) { - for (auto it = id_map_.rbegin(); it != id_map_.rend(); ++it) { - if (auto j = it->find(id); j != it->end()) { - return j->second; - } +void parse_context::add_global_name(std::string const &name, location const &l) { + if (auto other = global_names_.find(name); other != global_names_.end()) { + auto oss = std::ostringstream{}; + oss << "Identifier @" << name << " was already used at " << other->second; + throw parser::syntax_error(l, std::move(oss).str()); } - throw parser::syntax_error(l, "Undefined identifier %" + id); + global_names_[name] = l; } -void parse_context::prototype(std::string const &id, func p) { - if (auto other = prototype_map_.find(id); other != prototype_map_.end()) { - auto oss = std::ostringstream{}; - oss << "Identifier @" << id << " was already used at " << other->second->loc(); - throw parser::syntax_error(p->loc(), oss.str()); +void parse_context::val(std::variant const &id, tinytc_value &val, + location const &l) { + const auto handle_val = + [&l, &val](KeyT const &id, + std::vector> &map) { + if (map.empty()) { + throw parser::syntax_error(l, "No active scope"); + } + for (auto it = map.rbegin(); it != map.rend(); ++it) { + if (auto other = it->find(id); other != it->end()) { + auto oss = std::ostringstream{}; + oss << "Identifier %" << id << " was already used at " << other->second->loc(); + throw parser::syntax_error(l, std::move(oss).str()); + } + } + val.loc(l); + map.back()[id] = &val; + }; + if (std::holds_alternative(id)) { + handle_val(std::get(id), unnamed_id_map_); + } else { + auto const &sid = std::get(id); + handle_val(sid, named_id_map_); + 
val.name(sid); } - prototype_map_[id] = std::move(p); } -func parse_context::prototype(std::string const &id, location const &l) { - if (auto j = prototype_map_.find(id); j != prototype_map_.end()) { - return j->second; +auto parse_context::val(std::variant const &id, location const &l) + -> tinytc_value_t { + const auto handle_val = + [&l](KeyT const &id, + std::vector> &map) { + for (auto it = map.rbegin(); it != map.rend(); ++it) { + if (auto j = it->find(id); j != it->end()) { + return j->second; + } + } + auto oss = std::ostringstream{}; + oss << "Undefined identifier %" << id; + throw parser::syntax_error(l, std::move(oss).str()); + }; + if (std::holds_alternative(id)) { + return handle_val(std::get(id), unnamed_id_map_); } - throw parser::syntax_error(l, "Undefined identifier @" + id); + return handle_val(std::get(id), named_id_map_); } -void parse_context::add_error(location const &loc, std::string const &what) { - errors_.emplace_back(std::make_pair(loc, what)); +void parse_context::report_error(location const &loc, std::string const &what) { + compiler_ctx_->report_error(loc, what.c_str()); } } // namespace tinytc diff --git a/src/parser/parse_context.hpp b/src/parser/parse_context.hpp index 832cc3a8..dad6f838 100644 --- a/src/parser/parse_context.hpp +++ b/src/parser/parse_context.hpp @@ -4,41 +4,51 @@ #ifndef PARSE_CONTEXT_20231221_HPP #define PARSE_CONTEXT_20231221_HPP -#include "tinytc/tinytc.hpp" +#include "tinytc/types.h" #include "tinytc/types.hpp" +#include +#include #include #include #include +#include #include namespace tinytc { class parse_context { public: - inline parse_context() { id_map_.push_back({}); } + parse_context(shared_handle compiler_ctx); inline auto program() { return program_; } - inline void program(prog p) { program_ = std::move(p); } + inline void program(shared_handle p) { program_ = std::move(p); } + + void val(std::variant const &id, tinytc_value &val, + location const &l); + auto val(std::variant const &id, location const 
&l) + -> tinytc_value_t; + + void report_error(location const &loc, std::string const &what); + + auto cctx() -> tinytc_compiler_context_t { return compiler_ctx_.get(); } void push_scope(); void pop_scope(); - void val(std::string const &id, value val, location const &l); - value val(std::string const &id, location const &l); - - void prototype(std::string const &id, func p); - func prototype(std::string const &id, location const &l); - void add_error(location const &loc, std::string const &what); + void push_region(tinytc_region_t r); + void pop_region(); + auto top_region() -> tinytc_region_t; + auto has_regions() -> bool; - inline auto errors() const -> std::vector> const & { - return errors_; - } + void add_global_name(std::string const &name, location const &l); private: - std::vector> id_map_; - std::unordered_map prototype_map_; - prog program_; - std::vector> errors_; + shared_handle compiler_ctx_; + std::vector> unnamed_id_map_; + std::vector> named_id_map_; + std::stack regions_; + std::unordered_map global_names_; + shared_handle program_; }; } // namespace tinytc diff --git a/src/parser/parser_impl.yy b/src/parser/parser_impl.yy index 07ff319e..80aea2a5 100644 --- a/src/parser/parser_impl.yy +++ b/src/parser/parser_impl.yy @@ -5,68 +5,80 @@ %language "c++" %code requires { - #include "node/function_node.hpp" - #include "slice.hpp" - #include "tinytc/tinytc.hpp" + #include "tinytc/types.h" #include "tinytc/types.hpp" - #include + #include #include #include + #include #include #include + #include + #include namespace tinytc { class parse_context; class lexer; + + using int_or_val = std::variant; + + using identifier = std::variant; + struct param_attrs { + identifier id; + location loc; + tinytc_attr_t dict; + }; } } %code { + #include "compiler_context.hpp" #include "error.hpp" - #include "node/data_type_node.hpp" - #include "node/inst_node.hpp" - #include "node/program_node.hpp" - #include "node/region_node.hpp" - #include "node/value_node.hpp" + 
#include "node/attr.hpp" + #include "node/func.hpp" + #include "node/inst.hpp" + #include "node/inst_view.hpp" + #include "node/region.hpp" + #include "node/value.hpp" #include "parser/lexer.hpp" #include "parser/parse_context.hpp" - #include "passes.hpp" - #include "util.hpp" - - #include + #include "tinytc/builder.hpp" + #include "tinytc/core.hpp" + #include "util/ilist.hpp" + #include "util/iterator.hpp" + #include "util/overloaded.hpp" - #include + #include #include - #include - #include - #include + #include + #include + #include #include + #include namespace tinytc { - void check_scalar_type(value & val, scalar_type const& sty, location & loc1, - location & loc2) { - clir::visit( - overloaded{[&](int_imm &i) { i.ty(make_scalar(sty)); }, - [&](float_imm &i) { i.ty(make_scalar(sty)); }, - [&](auto &) { - if (!val->ty() || !is_equal(*val->ty(), *make_scalar(sty))) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error( - loc, "Type of SSA value does not match operand type"); - } - }}, - *val); - } - void check_type(value & val, data_type & ty, location & loc1, - location & loc2) { - if (!val->ty() || !is_equal(*val->ty(), *ty)) { - auto loc = loc1; - loc.end = loc2.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); + void report_error(tinytc_compiler_context_t cctx, compilation_error const& e) { + if (e.extra_info().size() > 0) { + auto what = (std::ostringstream{} << e.what() << " (" << e.extra_info() << ')').str(); + cctx->report_error(e.loc(), e.ref_values(), what.c_str()); + } else { + cctx->report_error(e.loc(), e.ref_values(), e.what()); } - }; } + template + void yytry(parse_context& ctx, F&& f, location const& loc = {}) { + try { + f(); + } catch (compilation_error const &e) { + report_error(ctx.cctx(), e); + throw parser::syntax_error({}, ""); + } catch (status st) { + throw parser::syntax_error(loc, to_string(st)); + } catch (std::exception const &e) { + throw parser::syntax_error(loc, e.what()); + 
} + } + } // namespace tinytc } %header @@ -98,125 +110,199 @@ LSQBR "[" RSQBR "]" FUNC "func" - WORK_GROUP_SIZE "work_group_size" - SUBGROUP_SIZE "subgroup_size" - RETURNS "->" + ATTRIBUTES "attributes" + ARROW "->" DYNAMIC "?" - NOTRANS ".n" - TRANS ".t" ATOMIC ".atomic" + ATOMIC_ADD ".atomic_add" + ATOMIC_MAX ".atomic_max" + ATOMIC_MIN ".atomic_min" + INIT "init" + LOCAL "local" + GLOBAL "global" + LOCAL_ATTR ".local" + GLOBAL_ATTR ".global" + BOOLEAN "bool" + COOPMATRIX "coopmatrix" MEMREF "memref" GROUP "group" OFFSET "offset" STRIDED "strided" - AXPBY "axpby" - ARITH "arith" - GEMM "gemm" - GEMV "gemv" - GER "ger" - HADAMARD "hadamard" - ALLOCA "alloca" - CAST "cast" - CMP "cmp" - EXPAND "expand" - FUSE "fuse" - LOAD "load" - FOR "for" - FOREACH "foreach" - IF "if" - ELSE "else" - GROUP_ID "group_id" - GROUP_SIZE "group_size" - SIZE "size" - SUBVIEW "subview" - STORE "store" - SUM "sum" - YIELD "yield" -; -%token LOCAL_IDENTIFIER +; +%token + I8_TYPE "i8" + I16_TYPE "i16" + I32_TYPE "i32" + I64_TYPE "i64" + INDEX_TYPE "index" + BF16_TYPE "bf16" + F16_TYPE "f16" + F32_TYPE "f32" + F64_TYPE "f64" + C32_TYPE "c32" + C64_TYPE "c64" +; + +%token + ALLOCA "alloca" + BARRIER "barrier" + CAST "cast" + CONSTANT "constant" + COOPERATIVE_MATRIX_APPLY "cooperative_matrix_apply" + COOPERATIVE_MATRIX_EXTRACT "cooperative_matrix_extract" + COOPERATIVE_MATRIX_INSERT "cooperative_matrix_insert" + COOPERATIVE_MATRIX_LOAD "cooperative_matrix_load" + COOPERATIVE_MATRIX_MUL_ADD "cooperative_matrix_mul_add" + COOPERATIVE_MATRIX_PREFETCH "cooperative_matrix_prefetch" + COOPERATIVE_MATRIX_REDUCE_ADD "cooperative_matrix_reduce_add" + COOPERATIVE_MATRIX_REDUCE_MAX "cooperative_matrix_reduce_max" + COOPERATIVE_MATRIX_REDUCE_MIN "cooperative_matrix_reduce_min" + COOPERATIVE_MATRIX_SCALE "cooperative_matrix_scale" + COOPERATIVE_MATRIX_STORE "cooperative_matrix_store" + EXPAND "expand" + FUSE "fuse" + LOAD "load" + IF "if" + ELSE "else" + PARALLEL "parallel" + SIZE "size" + 
SUBGROUP_BROADCAST "subgroup_broadcast" + SUBVIEW "subview" + STORE "store" + YIELD "yield" + ADD "add" + SUB "sub" + MUL "mul" + DIV "div" + REM "rem" + SHL "shl" + SHR "shr" + AND "and" + OR "or" + XOR "xor" + MIN "min" + MAX "max" + ABS "abs" + NEG "neg" + NOT "not" + CONJ "conj" + IM "im" + RE "re" + AXPBY "axpby" + CUMSUM "cumsum" + SUM "sum" + GEMM "gemm" + GEMV "gemv" + GER "ger" + HADAMARD "hadamard" + GROUP_ID "group_id" + NUM_GROUPS "num_groups" + NUM_SUBGROUPS "num_subgroups" + SUBGROUP_SIZE "subgroup_size" + SUBGROUP_ID "subgroup_id" + SUBGROUP_LINEAR_ID "subgroup_linear_id" + SUBGROUP_LOCAL_ID "subgroup_local_id" + EQUAL "equal" + NOT_EQUAL "not_equal" + GREATER_THAN "greater_than" + GREATER_THAN_EQUAL "greater_than_equal" + LESS_THAN "less_than" + LESS_THAN_EQUAL "less_than_equal" + FOR "for" + FOREACH "foreach" + COS "cos" + SIN "sin" + EXP "exp" + EXP2 "exp2" + NATIVE_COS "native_cos" + NATIVE_SIN "native_sin" + NATIVE_EXP "native_exp" + NATIVE_EXP2 "native_exp2" + SUBGROUP_EXCLUSIVE_SCAN_ADD "subgroup_exclusive_scan_add" + SUBGROUP_EXCLUSIVE_SCAN_MAX "subgroup_exclusive_scan_max" + SUBGROUP_EXCLUSIVE_SCAN_MIN "subgroup_exclusive_scan_min" + SUBGROUP_INCLUSIVE_SCAN_ADD "subgroup_inclusive_scan_add" + SUBGROUP_INCLUSIVE_SCAN_MAX "subgroup_inclusive_scan_max" + SUBGROUP_INCLUSIVE_SCAN_MIN "subgroup_inclusive_scan_min" + SUBGROUP_REDUCE_ADD "subgroup_reduce_add" + SUBGROUP_REDUCE_MAX "subgroup_reduce_max" + SUBGROUP_REDUCE_MIN "subgroup_reduce_min" +; +%token LOCAL_IDENTIFIER %token GLOBAL_IDENTIFIER +%token ATTR_NAME +%token STRING +%token BOOLEAN_CONSTANT %token INTEGER_CONSTANT %token FLOATING_CONSTANT -%token INTEGER_TYPE -%token FLOATING_TYPE -%token ARITHMETIC -%token ARITHMETIC_UNARY -%token CMP_CONDITION - -%nterm prog -%nterm > func_list -%nterm func -%nterm > arguments -%nterm <::tinytc::value> argument -%nterm >> attributes -%nterm > attribute -%nterm data_type -%nterm scalar_type -%nterm memref_type +%token COMP3 +%token REDUCE_MODE +%token 
MATRIX_USE +%token CHECKED +%token TRANSPOSE + +%nterm >> func_list +%nterm > func +%nterm ,std::vector>> parameters +%nterm > parameter +%nterm function_attributes +%nterm attribute +%nterm array_attribute +%nterm > attribute_list +%nterm dictionary_attribute +%nterm > named_attribute_list +%nterm named_attribute +%nterm attribute_name +%nterm optional_dictionary_attribute +%nterm data_type +%nterm boolean_type +%nterm scalar_type +%nterm coopmatrix_type +%nterm memref_type +%nterm optional_address_space %nterm > mode_list %nterm > optional_stride_list %nterm > stride_list %nterm constant_or_dynamic -%nterm group_type +%nterm group_type %nterm group_offset -%nterm memref_or_group_type -%nterm region -%nterm <::tinytc::value> var -%nterm > instructions -%nterm instruction -%nterm axpby_inst +%nterm var +%nterm > instruction %nterm atomic -%nterm <::tinytc::value> identifier_or_constant -%nterm > optional_identifier_or_constant_list -%nterm > identifier_or_constant_list -%nterm gemm_inst -%nterm gemv_inst -%nterm ger_inst -%nterm transpose -%nterm for_inst -%nterm <::tinytc::value> optional_step -%nterm foreach_inst -%nterm hadamard_inst -%nterm if_inst -%nterm > optional_returned_values -%nterm > optional_scalar_type_list -%nterm > scalar_type_list -%nterm else_region -%nterm sum_inst -%nterm yield_inst -%nterm for_loop_var_type -%nterm var_definition -%nterm > identifier_list -%nterm valued_inst -%nterm alloca_inst -%nterm arith_inst -%nterm arith_unary_inst -%nterm cast_inst -%nterm compare_inst -%nterm expand_inst -%nterm <::tinytc::value> constant_or_dynamic_or_identifier -%nterm > expand_shape -%nterm fuse_inst -%nterm load_inst -%nterm > optional_index_list -%nterm > index_list -%nterm <::tinytc::value> index_identifier_or_const -%nterm group_id_inst -%nterm group_size_inst -%nterm size_inst -%nterm store_inst -%nterm subview_inst -%nterm > optional_slice_list -%nterm > slice_list -%nterm slice -%nterm <::tinytc::value> slice_size +%nterm > 
optional_value_list +%nterm > value_list +%nterm optional_global_attr +%nterm optional_local_attr +%nterm transpose_opt +%nterm > transpose_opt2 +%nterm > for_inst +%nterm , std::vector, std::vector>> optional_loop_carried_values +%nterm , std::vector>> init_value_list +%nterm > init_value +%nterm optional_step +%nterm > if_inst +%nterm > optional_returned_values +%nterm > optional_return_type_list +%nterm > return_type_list +%nterm > identifier_list +%nterm > valued_inst +%nterm checked +%nterm integer_constant_or_identifier +%nterm > expand_shape +%nterm store_flag +%nterm >> optional_slice_list +%nterm >> slice_list +%nterm > slice +%nterm slice_size %% prog: func_list { - auto p = prog { std::make_unique(std::move($func_list), @prog).release() }; - ctx.program(p); - $$ = std::move(p); + auto p = create_prog(ctx.cctx(), @prog); + for (auto& f : $func_list) { + add_function(p.get(), std::move(f)); + } + ctx.program(std::move(p)); } ; @@ -225,106 +311,166 @@ func_list: | func_list func { $$ = std::move($1); $$.emplace_back(std::move($func)); } func: - FUNC { - ctx.push_scope(); - } GLOBAL_IDENTIFIER LPAREN arguments RPAREN attributes region { + FUNC GLOBAL_IDENTIFIER LPAREN parameters RPAREN + function_attributes >{ auto loc = @FUNC; loc.end = @RPAREN.end; - auto proto = func{ - std::make_unique($GLOBAL_IDENTIFIER, std::move($arguments), loc).release()}; - ctx.prototype($GLOBAL_IDENTIFIER, proto); - auto func_node = - std::make_unique(std::move(proto), std::move($region), @func).release(); - for (auto &attr : $attributes) { - attr(*func_node); - } - $func = func{func_node}; + yytry( + ctx, + [&] { + ctx.add_global_name($GLOBAL_IDENTIFIER, loc); + auto void_ty = get(ctx.cctx()); + auto func_node = create_func($GLOBAL_IDENTIFIER, $parameters.second, void_ty, loc); + func_node->attr($function_attributes); + ctx.push_scope(); + auto name_it = $parameters.first.begin(); + for (auto &p : func_node->params()) { + ctx.val(name_it->id, p, name_it->loc); + if (name_it->dict 
!= 0) { + func_node->param_attr(name_it - $parameters.first.begin(), name_it->dict); + } + ++name_it; + } + ctx.push_region(&func_node->body()); + $$ = std::move(func_node); + }, + loc); + }[prototype] region { + ctx.pop_region(); ctx.pop_scope(); + $$ = std::move($prototype); } ; -arguments: +parameters: %empty {} - | argument { $$.emplace_back(std::move($argument)); } - | arguments COMMA argument { $$ = std::move($1); $$.emplace_back(std::move($argument)); } + | parameter { + $$.first.emplace_back(std::move($parameter.first)); + $$.second.emplace_back(std::move($parameter.second)); + } + | parameters COMMA parameter { + $$.first = std::move($1.first); + $$.second = std::move($1.second); + $$.first.emplace_back(std::move($parameter.first)); + $$.second.emplace_back(std::move($parameter.second)); + } ; -argument: - LOCAL_IDENTIFIER COLON data_type { - auto v = make_value(std::move($data_type)); - v.name($LOCAL_IDENTIFIER); - ctx.val($LOCAL_IDENTIFIER, v, @LOCAL_IDENTIFIER); - $$ = std::move(v); +parameter: + LOCAL_IDENTIFIER COLON data_type optional_dictionary_attribute[dict] { + $$ = std::make_pair(param_attrs{$LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER, $dict}, $data_type); } ; -attributes: +function_attributes: %empty {} - | attributes attribute { $$ = std::move($1); $$.emplace_back(std::move($attribute)); } + | ATTRIBUTES dictionary_attribute { $$ = $dictionary_attribute; } ; attribute: - WORK_GROUP_SIZE LPAREN INTEGER_CONSTANT[m] COMMA INTEGER_CONSTANT[n] RPAREN { - if ($m <= 0) { - throw parser::syntax_error(@m, "Must be a non-negative number"); - } - if ($n <= 0) { - throw parser::syntax_error(@n, "Must be a non-negative number"); - } - auto const wgs = std::array{static_cast($m), - static_cast($n)}; - $$ = [=](function &f) { f.work_group_size(wgs); }; + array_attribute { $$ = $array_attribute; } + | BOOLEAN_CONSTANT { $$ = boolean_attr::get(ctx.cctx(), $BOOLEAN_CONSTANT); } + | dictionary_attribute { $$ = $dictionary_attribute; } + | INTEGER_CONSTANT { $$ = 
integer_attr::get(ctx.cctx(), $INTEGER_CONSTANT); } + | STRING { $$ = string_attr::get(ctx.cctx(), $STRING); } +; + +array_attribute: + LSQBR RSQBR { $$ = array_attr::get(ctx.cctx(), {}); } + | LSQBR attribute_list RSQBR { $$ = array_attr::get(ctx.cctx(), $attribute_list); } + +attribute_list: + attribute { $$.push_back($attribute); } + | attribute_list COMMA attribute { $$ = std::move($1); $$.push_back($attribute); } +; + +dictionary_attribute: + LBRACE RBRACE { $$ = dictionary_attr::get(ctx.cctx(), {}); } + | LBRACE named_attribute_list RBRACE { + dictionary_attr::sort($named_attribute_list); + $$ = dictionary_attr::get(ctx.cctx(), $named_attribute_list); } - | SUBGROUP_SIZE LPAREN INTEGER_CONSTANT RPAREN { - if ($INTEGER_CONSTANT <= 0) { - throw parser::syntax_error(@INTEGER_CONSTANT, "Must be a non-negative number"); - } - auto const sgs = static_cast($INTEGER_CONSTANT); - $$ = [=](function &f) { f.subgroup_size(sgs); }; +; + +named_attribute_list: + named_attribute { $$.push_back($named_attribute); } + | named_attribute_list COMMA named_attribute { $$ = std::move($1); $$.push_back($named_attribute); } +; + +named_attribute: + attribute_name EQUALS attribute { + $$ = tinytc_named_attr_t{$attribute_name, $attribute}; } ; +attribute_name: + ATTR_NAME { $$ = string_attr::get(ctx.cctx(), $ATTR_NAME); } + | SUBGROUP_SIZE { $$ = string_attr::get(ctx.cctx(), "subgroup_size"); } + | STRING { $$ = string_attr::get(ctx.cctx(), $STRING); } + +optional_dictionary_attribute: + %empty { $$ = nullptr; } + | dictionary_attribute { $$ = $dictionary_attribute; } +; + data_type: - scalar_type { $$ = make_scalar($scalar_type); $$->loc(@scalar_type); } - | memref_type + boolean_type + | coopmatrix_type | group_type + | memref_type + | scalar_type +; + +boolean_type: + BOOLEAN { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @boolean_type); } ; scalar_type: - INTEGER_TYPE - | FLOATING_TYPE + I8_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | I16_TYPE { yytry(ctx, 
[&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | I32_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | I64_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | INDEX_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | BF16_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | F16_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | F32_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | F64_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | C32_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } + | C64_TYPE { yytry(ctx, [&] { $$ = get(ctx.cctx()); }, @scalar_type); } +; + +coopmatrix_type: + COOPMATRIX LCHEV scalar_type TIMES INTEGER_CONSTANT[rows] TIMES INTEGER_CONSTANT[cols] COMMA MATRIX_USE RCHEV { + yytry( + ctx, [&] { $$ = get($scalar_type, $rows, $cols, $MATRIX_USE); }, + @coopmatrix_type); + } ; memref_type: - MEMREF LCHEV scalar_type mode_list RCHEV { - try { - $$ = data_type { - std::make_unique($scalar_type, std::move($mode_list), - std::vector{}, @memref_type) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } + MEMREF LCHEV scalar_type mode_list optional_address_space RCHEV { + yytry( + ctx, + [&] { + auto empty = array_view{}; + $$ = get($scalar_type, $mode_list, empty, $optional_address_space); + }, + @memref_type); } - | MEMREF LCHEV scalar_type mode_list COMMA STRIDED LCHEV optional_stride_list RCHEV RCHEV { + | MEMREF LCHEV scalar_type mode_list COMMA STRIDED LCHEV optional_stride_list RCHEV optional_address_space RCHEV { if ($mode_list.size() != $optional_stride_list.size()) { auto loc = @scalar_type; loc.end = @optional_stride_list.end; throw syntax_error(loc, "Shape and stride list must have the same length"); } - try { - $$ = data_type { - std::make_unique($scalar_type, std::move($mode_list), - std::move($optional_stride_list), @memref_type) - .release() 
- }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } + yytry( + ctx, + [&] { + $$ = get($scalar_type, $mode_list, $optional_stride_list, + $optional_address_space); + }, + @memref_type); } ; @@ -333,6 +479,12 @@ mode_list: | mode_list TIMES constant_or_dynamic { $$ = std::move($1); $$.push_back($constant_or_dynamic); } ; +optional_address_space: + %empty { $$ = address_space::global; } + | COMMA GLOBAL { $$ = address_space::global; } + | COMMA LOCAL { $$ = address_space::local; } +; + optional_stride_list: %empty {} | stride_list { $$ = std::move($1); } @@ -349,9 +501,10 @@ constant_or_dynamic: ; group_type: - GROUP LCHEV memref_type group_offset RCHEV { - $$ = make_group(std::move($memref_type), $group_offset); - $$->loc(@group_type); + GROUP LCHEV memref_type TIMES constant_or_dynamic[group_size] group_offset RCHEV { + yytry( + ctx, [&] { $$ = get(std::move($memref_type), $group_size, $group_offset); }, + @group_type); } ; @@ -360,65 +513,32 @@ group_offset: | COMMA OFFSET COLON constant_or_dynamic { $$ = $constant_or_dynamic; } ; -memref_or_group_type: - memref_type - | group_type +var: + LOCAL_IDENTIFIER { $$ = ctx.val($LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER); } ; region: - LBRACE { - ctx.push_scope(); - } instructions RBRACE { - $$ = region{std::make_unique(std::move($instructions), @region).release()}; - ctx.pop_scope(); - } -; - -var: - LOCAL_IDENTIFIER { $$ = ctx.val($LOCAL_IDENTIFIER, @LOCAL_IDENTIFIER); } + LBRACE { ctx.push_scope(); } instructions { ctx.pop_scope(); } RBRACE {} ; instructions: %empty {} | instructions instruction { - $$ = std::move($1); - $$.emplace_back(std::move($instruction)); + if (!ctx.has_regions()) { + error(@instruction, "Internal error: missing region"); + YYERROR; + } + tinytc_region_t reg = ctx.top_region(); + reg->insts().push_back($instruction.release()); } ; instruction: - axpby_inst - | gemm_inst - | gemv_inst - | ger_inst - | for_inst - | foreach_inst - | hadamard_inst - | if_inst - | 
var_definition - | store_inst - | sum_inst - | yield_inst -; - -axpby_inst: - AXPBY transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($b, $mb, @b, @mb); - try { - $$ = inst { - std::make_unique($ta, std::move($alpha), std::move($a), - std::move($beta), std::move($b), $atomic, @axpby_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } + AXPBY atomic transpose_opt[ta] var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] { + yytry(ctx, [&] { + $$ = axpby_inst::create($atomic, $ta, std::move($alpha), std::move($a), std::move($beta), + std::move($b), @instruction); + }); } ; @@ -427,176 +547,172 @@ atomic: | ATOMIC { $$ = true; } ; -identifier_or_constant: - var { $$ = $var; } - | INTEGER_CONSTANT { $$ = make_imm($INTEGER_CONSTANT); $$->loc(@INTEGER_CONSTANT); } - | FLOATING_CONSTANT { $$ = make_imm($FLOATING_CONSTANT); $$->loc(@FLOATING_CONSTANT); } -; - -optional_identifier_or_constant_list: +optional_value_list: %empty {} - | identifier_or_constant_list { $$ = std::move($1); } + | value_list { $$ = std::move($1); } +; -identifier_or_constant_list: - identifier_or_constant { $$.push_back(std::move($identifier_or_constant)); } - | identifier_or_constant_list COMMA identifier_or_constant { +value_list: + var { $$.push_back(std::move($var)); } + | value_list COMMA var { $$ = std::move($1); - $$.push_back(std::move($identifier_or_constant)); + $$.push_back(std::move($var)); } ; -gemm_inst: - GEMM transpose[ta] transpose[tb] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] 
COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); - try { - $$ = inst { - std::make_unique($ta, $tb, std::move($alpha), std::move($a), - std::move($b), std::move($beta), std::move($c), $atomic, - @gemm_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +instruction: + BARRIER optional_global_attr optional_local_attr { + int32_t fence_flags = 0; + fence_flags |= $optional_global_attr; + fence_flags |= $optional_local_attr; + yytry(ctx, [&] { $$ = barrier_inst::create(fence_flags, @instruction); }); } ; -gemv_inst: - GEMV transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); - try { - $$ = inst { - std::make_unique($ta, std::move($alpha), std::move($a), std::move($b), - std::move($beta), std::move($c), $atomic, @gemv_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +optional_global_attr: + %empty { $$ = 0; } + | GLOBAL_ATTR { $$ = tinytc_address_space_global; } +; + +optional_local_attr: + %empty { $$ = 0; } + | LOCAL_ATTR { $$ = tinytc_address_space_local; } +; + +instruction: + CUMSUM atomic var[alpha] COMMA var[a] COMMA INTEGER_CONSTANT[mode] COMMA var[beta] COMMA var[b] { + yytry(ctx, [&] { + $$ = cumsum_inst::create($atomic, $mode, std::move($alpha), std::move($a), std::move($beta), + std::move($b), @instruction); + }); } ; -transpose: - NOTRANS { $$ = 
transpose::N; } - | TRANS { $$ = transpose::T; } +instruction: + GEMM atomic transpose_opt2[tr] var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { + yytry(ctx, [&] { + $$ = gemm_inst::create($atomic, $tr.first, $tr.second, std::move($alpha), std::move($a), + std::move($b), std::move($beta), std::move($c), @instruction); + }); + } ; -ger_inst: - GER atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); - try { - $$ = inst { - std::make_unique(std::move($alpha), std::move($a), std::move($b), - std::move($beta), std::move($c), $atomic, @ger_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +instruction: + GEMV atomic transpose_opt[ta] var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { + yytry(ctx, [&] { + $$ = gemv_inst::create($atomic, $ta, std::move($alpha), std::move($a), std::move($b), + std::move($beta), std::move($c), @instruction); + }); + } +; + +transpose_opt: + %empty { $$ = transpose::N; } + | TRANSPOSE { $$ = $TRANSPOSE; } +; + +transpose_opt2: + %empty { $$ = std::make_pair(transpose::N, transpose::N); } + | TRANSPOSE transpose_opt { $$ = std::make_pair($TRANSPOSE, $transpose_opt); } +; + +instruction: + GER atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { + yytry(ctx, [&] { + $$ = ger_inst::create($atomic, std::move($alpha), std::move($a), std::move($b), + std::move($beta), std::move($c), @instruction); + }); } ; +instruction: for_inst { $$ = std::move($1); } ; +valued_inst: for_inst { $$ = std::move($1); } ; for_inst: - FOR LOCAL_IDENTIFIER[loop_var] - EQUALS 
identifier_or_constant[from] COMMA identifier_or_constant[to] optional_step - for_loop_var_type { - check_scalar_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type($to, $for_loop_var_type, @to, @for_loop_var_type); - if ($optional_step) { - check_scalar_type($optional_step, $for_loop_var_type, @optional_step, @for_loop_var_type); - } - auto v = make_value($for_loop_var_type); - v.name($loop_var); - ctx.val($loop_var, std::move(v), @loop_var); - } region { - try { - $$ = inst { - std::make_unique(ctx.val($loop_var, @loop_var), $from, $to, - $optional_step, std::move($region), @for_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } + FOR LOCAL_IDENTIFIER[loop_var] EQUALS var[from] COMMA var[to] optional_step optional_loop_carried_values[lcv] > { + yytry(ctx, [&] { + auto &[lcv_id, lcv_init, lcv_type] = $lcv; + location loc = @FOR; + loc.end = @lcv.end; + $$ = for_inst::create($from, $to, $optional_step, lcv_init, lcv_type, loc); + auto inode = for_inst($$.get()); + ctx.push_scope(); + auto &loop_var = inode.loop_var(); + ctx.val($loop_var, loop_var, @loop_var); + for (std::int32_t i = 0; i < inode.get().num_results(); ++i) { + ctx.val(lcv_id[i], inode.iter_arg(i), @lcv); + } + ctx.push_region(&inode.body()); + }); + }[loop_header] region optional_dictionary_attribute { + ctx.pop_region(); + ctx.pop_scope(); + $loop_header->attr($optional_dictionary_attribute); + $$ = std::move($loop_header); } ; optional_step: %empty { $$ = {}; } - | COMMA identifier_or_constant { $$ = $identifier_or_constant; } - -foreach_inst: - FOREACH LOCAL_IDENTIFIER[loop_var] - EQUALS identifier_or_constant[from] COMMA identifier_or_constant[to] for_loop_var_type { - check_scalar_type($from, $for_loop_var_type, @from, @for_loop_var_type); - check_scalar_type($to, $for_loop_var_type, @to, @for_loop_var_type); - auto v = make_value($for_loop_var_type); - v.name($loop_var); - ctx.val($loop_var, std::move(v), 
@loop_var); - } region { - try { - $$ = inst { - std::make_unique(ctx.val($loop_var, @loop_var), $from, $to, - std::move($region), @foreach_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } + | COMMA var { $$ = $var; } +; + +optional_loop_carried_values: + %empty { $$ = {}; } + | INIT LPAREN init_value_list RPAREN ARROW LPAREN return_type_list RPAREN { + $$ = std::make_tuple(std::move($init_value_list.first), std::move($init_value_list.second), + std::move($return_type_list)); + } +; + +init_value_list: + init_value { + $$.first.emplace_back($init_value.first); + $$.second.emplace_back($init_value.second); + } + | init_value_list COMMA init_value { + $$ = std::move($1); + $$.first.emplace_back($init_value.first); + $$.second.emplace_back($init_value.second); } ; -for_loop_var_type: - %empty { $$ = scalar_type::index; } - | COLON INTEGER_TYPE { $$ = $INTEGER_TYPE; } +init_value: + LOCAL_IDENTIFIER EQUALS var { $$ = std::make_pair($LOCAL_IDENTIFIER, $var); } ; -var_definition: +instruction: + FOREACH LPAREN identifier_list[loop_var] RPAREN EQUALS + LPAREN value_list[from] RPAREN COMMA LPAREN value_list[to] RPAREN[header_end] >{ + yytry(ctx, [&] { + location loc = @FOREACH; + loc.end = @header_end.end; + $$ = foreach_inst::create($from, $to, loc); + auto inode = foreach_inst($$.get()); + ctx.push_scope(); + auto loop_vars = inode.loop_vars().begin(); + for (std::int64_t i = 0; i < inode.dim(); ++i) { + ctx.val($loop_var[i], loop_vars[i], @loop_var); + } + ctx.push_region(&inode.body()); + }); + }[loop_header] region { + ctx.pop_region(); + ctx.pop_scope(); + $$ = std::move($loop_header); + } +; + +instruction: identifier_list EQUALS valued_inst { $$ = std::move($valued_inst); - if ($identifier_list.size() == 1) { - if (!$$->result()) { - throw syntax_error(@identifier_list, "Instruction does not return value"); - } - $$->result()->name($identifier_list[0]); - ctx.val($identifier_list[0], $$->result(), 
@identifier_list); - } else { - auto results = $$->results(); - if (results.size() != $identifier_list.size()) { - throw syntax_error( - @identifier_list, - "Number of identifiers does not equal number of returned values"); - } - for (std::size_t i = 0; i < results.size(); ++i) { - results[i]->name($identifier_list[i]); - ctx.val($identifier_list[i], results[i], @identifier_list); - } + if (static_cast($identifier_list.size()) != $$->num_results()) { + throw syntax_error( + @identifier_list, + "Number of identifiers does not equal number of returned values"); + } + auto results = $$->result_begin(); + for (std::int64_t i = 0; i < $$->num_results(); ++i) { + ctx.val($identifier_list[i], results[i], @identifier_list); } } ; @@ -607,344 +723,399 @@ identifier_list: ; -hadamard_inst: - HADAMARD atomic - identifier_or_constant[alpha] COMMA var[a] COMMA var[b] COMMA identifier_or_constant[beta] COMMA var[c] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA memref_type[mb] COMMA scalar_type[fbeta] - COMMA memref_type[mc] { - check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_type($b, $mb, @b, @mb); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($c, $mc, @c, @mc); - try { - $$ = inst { - std::make_unique(std::move($alpha), std::move($a), std::move($b), - std::move($beta), std::move($c), $atomic, - @hadamard_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +instruction: + HADAMARD atomic var[alpha] COMMA var[a] COMMA var[b] COMMA var[beta] COMMA var[c] { + yytry(ctx, [&] { + $$ = hadamard_inst::create($atomic, std::move($alpha), std::move($a), std::move($b), + std::move($beta), std::move($c), @instruction); + }); } ; -sum_inst: - SUM transpose[ta] atomic - identifier_or_constant[alpha] COMMA var[a] COMMA identifier_or_constant[beta] COMMA var[b] - COLON scalar_type[falpha] COMMA memref_type[ma] COMMA scalar_type[fbeta] COMMA memref_type[mb] { - 
check_scalar_type($alpha, $falpha, @alpha, @falpha); - check_type($a, $ma, @a, @ma); - check_scalar_type($beta, $fbeta, @beta, @fbeta); - check_type($b, $mb, @b, @mb); - try { - $$ = inst { - std::make_unique($ta, std::move($alpha), std::move($a), std::move($beta), - std::move($b), $atomic, @sum_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +instruction: + SUM atomic transpose_opt[ta] var[alpha] COMMA var[a] COMMA var[beta] COMMA var[b] { + yytry(ctx, [&] { + $$ = sum_inst::create($atomic, $ta, std::move($alpha), std::move($a), std::move($beta), + std::move($b), @instruction); + }); } ; -yield_inst: - YIELD optional_identifier_or_constant_list[vals] COLON optional_scalar_type_list[tys] { - if ($vals.size() != $tys.size()) { - location loc = @vals; - loc.end = @tys.end; - throw syntax_error(loc, "Identifier and scalar type list must have the same length"); - } - for (std::size_t i = 0; i < $vals.size(); ++i) { - check_scalar_type($vals[i], $tys[i], @vals, @tys); - } - $$ = inst{std::make_unique(std::move($vals)).release()}; +instruction: + YIELD LPAREN optional_value_list[vals] RPAREN { + yytry(ctx, [&] { $$ = yield_inst::create(std::move($vals), @instruction); }); } ; valued_inst: - alloca_inst - | arith_inst - | arith_unary_inst - | cast_inst - | compare_inst - | expand_inst - | fuse_inst - | if_inst - | load_inst - | group_id_inst - | group_size_inst - | size_inst - | subview_inst -; - -alloca_inst: - ALLOCA RETURNS memref_type { - try { - $$ = inst { - std::make_unique(std::move($memref_type), @alloca_inst).release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } + ALLOCA optional_dictionary_attribute[dict] COLON memref_type { + yytry(ctx, [&] { + $$ = alloca_inst::create(std::move($memref_type), @valued_inst); + $$->attr($dict); + }); } ; -arith_inst: - ARITH ARITHMETIC identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { - 
check_scalar_type($a, $ty, @a, @ty); - check_scalar_type($b, $ty, @b, @ty); - try { - $$ = inst { - std::make_unique($ARITHMETIC, std::move($a), std::move($b), @arith_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +valued_inst: ADD var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = add_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: SUB var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = sub_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: MUL var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = mul_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: DIV var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = div_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: REM var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = rem_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: SHL var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = shl_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: SHR var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = shr_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: AND var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = and_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: OR var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = or_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: XOR var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = xor_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: MIN var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = min_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: MAX var[a] COMMA var[b] COLON data_type[ty] { yytry(ctx, [&] { $$ = max_inst::create($a, $b, $ty, @valued_inst); }); }; + +valued_inst: ABS var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = abs_inst::create($a, $ty, @valued_inst); }); }; 
+valued_inst: NEG var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = neg_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: NOT var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = not_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: CONJ var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = conj_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: IM var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = im_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: RE var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = re_inst::create($a, $ty, @valued_inst); }); }; + +valued_inst: GROUP_ID COMP3 COLON data_type[ty] { yytry(ctx, [&] { $$ = group_id_inst::create($COMP3, $ty, @valued_inst); }); }; +valued_inst: NUM_GROUPS COMP3 COLON data_type[ty] { yytry(ctx, [&] { $$ = num_groups_inst::create($COMP3, $ty, @valued_inst); }); }; +valued_inst: NUM_SUBGROUPS COMP3 COLON data_type[ty] { yytry(ctx, [&] { $$ = num_subgroups_inst::create($COMP3, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_SIZE COLON data_type[ty] { yytry(ctx, [&] { $$ = subgroup_size_inst::create($ty, @valued_inst); }); }; +valued_inst: SUBGROUP_ID COMP3 COLON data_type[ty] { yytry(ctx, [&] { $$ = subgroup_id_inst::create($COMP3, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_LINEAR_ID COLON data_type[ty] { yytry(ctx, [&] { $$ = subgroup_linear_id_inst::create($ty, @valued_inst); }); }; +valued_inst: SUBGROUP_LOCAL_ID COLON data_type[ty] { yytry(ctx, [&] { $$ = subgroup_local_id_inst::create($ty, @valued_inst); }); }; + +valued_inst: + CAST var[a] COLON data_type[to] { + yytry(ctx, [&] { $$ = cast_inst::create(std::move($a), $to, @valued_inst); }); } ; -arith_unary_inst: - ARITH ARITHMETIC_UNARY identifier_or_constant[a] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); - try { - $$ = inst { - std::make_unique($ARITHMETIC_UNARY, std::move($a), - @arith_unary_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } 
+valued_inst: EQUAL var[a] COMMA var[b] COLON boolean_type[ty] { yytry(ctx, [&] { $$ = equal_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: NOT_EQUAL var[a] COMMA var[b] COLON boolean_type[ty] { yytry(ctx, [&] { $$ = not_equal_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: GREATER_THAN var[a] COMMA var[b] COLON boolean_type[ty] { yytry(ctx, [&] { $$ = greater_than_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: GREATER_THAN_EQUAL var[a] COMMA var[b] COLON boolean_type[ty] { yytry(ctx, [&] { $$ = greater_than_equal_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: LESS_THAN var[a] COMMA var[b] COLON boolean_type[ty] { yytry(ctx, [&] { $$ = less_than_inst::create($a, $b, $ty, @valued_inst); }); }; +valued_inst: LESS_THAN_EQUAL var[a] COMMA var[b] COLON boolean_type[ty] { yytry(ctx, [&] { $$ = less_than_equal_inst::create($a, $b, $ty, @valued_inst); }); }; + +valued_inst: + CONSTANT LSQBR FLOATING_CONSTANT[re] COMMA FLOATING_CONSTANT[im] RSQBR COLON data_type { + yytry(ctx, [&] { + $$ = constant_inst::create(std::complex{$re, $im}, $data_type, @valued_inst); + }); + } + | CONSTANT FLOATING_CONSTANT COLON data_type { + yytry(ctx, [&] { + $$ = constant_inst::create($FLOATING_CONSTANT, $data_type, @valued_inst); + }); + } + | CONSTANT INTEGER_CONSTANT COLON data_type { + yytry(ctx, [&] { + $$ = constant_inst::create($INTEGER_CONSTANT, $data_type, @valued_inst); + }); + } + | CONSTANT BOOLEAN_CONSTANT COLON data_type { + yytry(ctx, [&] { + $$ = constant_inst::create($BOOLEAN_CONSTANT, $data_type, @valued_inst); + }); } ; +valued_inst: + COOPERATIVE_MATRIX_APPLY + LPAREN LOCAL_IDENTIFIER[row] COMMA LOCAL_IDENTIFIER[col] COMMA LOCAL_IDENTIFIER[val] RPAREN + EQUALS var ARROW data_type[result_ty] > { + yytry(ctx, [&] { + location loc = @COOPERATIVE_MATRIX_APPLY; + loc.end = @result_ty.end; + $$ = cooperative_matrix_apply_inst::create($var, $result_ty, loc); + auto inode = cooperative_matrix_apply_inst($$.get()); + 
ctx.push_scope(); + auto &row = inode.row(); + ctx.val($row, row, @row); + auto &col = inode.col(); + ctx.val($col, col, @col); + auto &val = inode.val(); + ctx.val($val, val, @val); + ctx.push_region(&inode.body()); + }); + }[apply_header] region { + ctx.pop_region(); + ctx.pop_scope(); + $$ = std::move($apply_header); + } +; -cast_inst: - CAST identifier_or_constant[a] COLON scalar_type[from] RETURNS scalar_type[to] { - check_scalar_type($a, $from, @a, @from); - try { - $$ = inst { std::make_unique(std::move($a), $to, @cast_inst).release() }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +valued_inst: + COOPERATIVE_MATRIX_EXTRACT var[mat] LSQBR INTEGER_CONSTANT[index] RSQBR COLON data_type[ty] { + yytry(ctx, [&] { + $$ = cooperative_matrix_extract_inst::create($index, std::move($mat), std::move($ty), + @valued_inst); + }); } ; -compare_inst: - CMP CMP_CONDITION identifier_or_constant[a] COMMA identifier_or_constant[b] COLON scalar_type[ty] { - check_scalar_type($a, $ty, @a, @ty); - check_scalar_type($b, $ty, @b, @ty); - try { - $$ = inst { - std::make_unique($CMP_CONDITION, std::move($a), std::move($b), - @compare_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +valued_inst: + COOPERATIVE_MATRIX_INSERT var[val] COMMA var[mat] LSQBR INTEGER_CONSTANT[index] RSQBR COLON data_type[ty] { + yytry(ctx, [&] { + $$ = cooperative_matrix_insert_inst::create($index, std::move($val), std::move($mat), + std::move($ty), @valued_inst); + }); } ; -expand_inst: - EXPAND var LSQBR INTEGER_CONSTANT[mode] RETURNS expand_shape RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - try { - $$ = inst { - std::make_unique(std::move($var), $mode, std::move($expand_shape), - @expand_inst) - .release() - }; - } catch 
(compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +valued_inst: + COOPERATIVE_MATRIX_LOAD transpose_opt[ta] checked var[op] LSQBR var[p0] COMMA var[p1] RSQBR COLON data_type[result_ty] { + yytry(ctx, [&] { + $$ = cooperative_matrix_load_inst::create($ta, $checked, std::move($op), std::move($p0), + std::move($p1), std::move($result_ty), @valued_inst); + }); } ; -expand_shape: - constant_or_dynamic_or_identifier[a] TIMES constant_or_dynamic_or_identifier[b] { - $$ = std::vector{$a, $b}; +checked: + %empty { $$ = checked_flag::none; } + | CHECKED { $$ = $CHECKED; } +; + +valued_inst: + COOPERATIVE_MATRIX_MUL_ADD var[a] COMMA var[b] COMMA var[c] COLON data_type[to_ty] { + yytry(ctx, [&] { + $$ = cooperative_matrix_mul_add_inst::create(std::move($a), std::move($b), std::move($c), + std::move($to_ty), @valued_inst); + }); } - | expand_shape TIMES constant_or_dynamic_or_identifier[a] { $$ = std::move($1); $$.push_back($a); } ; -constant_or_dynamic_or_identifier: - var { - check_scalar_type($var, scalar_type::index, @var, @var); - $$ = $var; +instruction: + COOPERATIVE_MATRIX_PREFETCH INTEGER_CONSTANT[cache_level] COMMA var[op] LSQBR var[p0] COMMA var[p1] RSQBR COMMA INTEGER_CONSTANT[rows] COMMA INTEGER_CONSTANT[cols] { + yytry(ctx, [&] { + $$ = cooperative_matrix_prefetch_inst::create($cache_level, $rows, $cols, std::move($op), + std::move($p0), std::move($p1), @instruction); + }); } - | INTEGER_CONSTANT { $$ = make_index($INTEGER_CONSTANT); $$->loc(@INTEGER_CONSTANT); } - | DYNAMIC { $$ = make_dynamic(); $$->loc(@DYNAMIC); } ; -fuse_inst: - FUSE var LSQBR INTEGER_CONSTANT[from] COMMA INTEGER_CONSTANT[to] RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - try { - $$ = inst { - std::make_unique(std::move($var), $from, $to, @fuse_inst).release() - }; - } catch 
(compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } + +valued_inst: COOPERATIVE_MATRIX_REDUCE_ADD REDUCE_MODE var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = cooperative_matrix_reduce_add_inst::create($REDUCE_MODE, $a, $ty, @valued_inst); }); }; +valued_inst: COOPERATIVE_MATRIX_REDUCE_MAX REDUCE_MODE var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = cooperative_matrix_reduce_max_inst::create($REDUCE_MODE, $a, $ty, @valued_inst); }); }; +valued_inst: COOPERATIVE_MATRIX_REDUCE_MIN REDUCE_MODE var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = cooperative_matrix_reduce_min_inst::create($REDUCE_MODE, $a, $ty, @valued_inst); }); }; + +valued_inst: + COOPERATIVE_MATRIX_SCALE var[a] COMMA var[b] COLON data_type[ty] { + yytry(ctx, [&] { + $$ = cooperative_matrix_scale_inst::create(std::move($a), std::move($b), std::move($ty), + @valued_inst); + }); } ; -load_inst: - LOAD var LSQBR optional_index_list RSQBR COLON memref_or_group_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_or_group_type)) { - auto loc = @var; - loc.end = @memref_or_group_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - try { - $$ = inst { - std::make_unique(std::move($var), std::move($optional_index_list), - @load_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +instruction: + COOPERATIVE_MATRIX_STORE transpose_opt[ta] checked store_flag var[val] COMMA var[op] LSQBR var[p0] COMMA var[p1] RSQBR { + yytry(ctx, [&] { + $$ = cooperative_matrix_store_inst::create($ta, $checked, $store_flag, std::move($val), + std::move($op), std::move($p0), std::move($p1), + @instruction); + }); } ; -optional_index_list: - %empty {} - | index_list { $$ = std::move($1); } +valued_inst: + EXPAND var LSQBR INTEGER_CONSTANT[expanded_mode] ARROW expand_shape RSQBR COLON memref_type[ty] { + yytry(ctx, [&] { + auto static_shape = std::vector{}; + 
static_shape.reserve($expand_shape.size()); + auto dynamic_shape = std::vector{}; + dynamic_shape.reserve($expand_shape.size()); + for (auto &s : $expand_shape) { + std::visit(overloaded{ + [&](std::int64_t i) { static_shape.push_back(i); }, + [&](tinytc_value_t v) { + static_shape.push_back(dynamic); + dynamic_shape.push_back(v); + }, + }, + s); + } + $$ = expand_inst::create($expanded_mode, std::move(static_shape), std::move($var), + std::move(dynamic_shape), $ty, @valued_inst); + }); + } ; -index_list: - index_identifier_or_const { $$.push_back($index_identifier_or_const); } - | index_list COMMA index_identifier_or_const { $$ = std::move($1); $$.push_back($index_identifier_or_const); } +expand_shape: + integer_constant_or_identifier[a] TIMES integer_constant_or_identifier[b] { + $$ = std::vector{$a, $b}; + } + | expand_shape TIMES integer_constant_or_identifier[a] { $$ = std::move($1); $$.push_back($a); } ; -index_identifier_or_const: +integer_constant_or_identifier: var { - check_scalar_type($var, scalar_type::index, @var, @var); $$ = $var; } | INTEGER_CONSTANT { - $$ = make_index($INTEGER_CONSTANT); - $$->loc(@INTEGER_CONSTANT); + $$ = $INTEGER_CONSTANT; } ; -store_inst: - STORE var[a] COMMA var[b] LSQBR optional_index_list RSQBR COLON memref_type { - if (!$b->ty() || !is_equal(*$b->ty(), *$memref_type)) { - auto loc = @b; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - try { - $$ = inst { - std::make_unique(std::move($a), std::move($b), - std::move($optional_index_list), @store_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +valued_inst: + FUSE var LSQBR INTEGER_CONSTANT[from] COMMA INTEGER_CONSTANT[to] RSQBR COLON memref_type[ty] { + yytry(ctx, [&] { + $$ = fuse_inst::create($from, $to, std::move($var), $ty, @valued_inst); + }); + } +; + +valued_inst: + LOAD var LSQBR optional_value_list RSQBR COLON data_type { + yytry(ctx, [&] 
{ + $$ = load_inst::create(std::move($var), std::move($optional_value_list), std::move($data_type), + @valued_inst); + }); } ; -group_id_inst: - GROUP_ID { $$ = inst{std::make_unique().release()}; } +instruction: + STORE store_flag var[a] COMMA var[b] LSQBR optional_value_list RSQBR { + yytry(ctx, [&] { + $$ = store_inst::create($store_flag, std::move($a), std::move($b), + std::move($optional_value_list), @instruction); + }); + } ; -group_size_inst: - GROUP_SIZE { $$ = inst{std::make_unique().release()}; } +store_flag: + %empty { $$ = store_flag::regular; } + | ATOMIC { $$ = store_flag::atomic; } + | ATOMIC_ADD { $$ = store_flag::atomic_add; } + | ATOMIC_MAX { $$ = store_flag::atomic_max; } + | ATOMIC_MIN { $$ = store_flag::atomic_min; } ; +instruction: if_inst { $$ = std::move($1); } ; +valued_inst: if_inst { $$ = std::move($1); } ; if_inst: - IF identifier_or_constant[condition] optional_returned_values region else_region { - check_scalar_type($condition, scalar_type::i1, @condition, @condition); - $$ = inst{std::make_unique(std::move($condition), std::move($region), - std::move($else_region), - std::move($optional_returned_values)) - .release()}; - $$->loc(@if_inst); + IF var[condition] optional_returned_values >{ + yytry(ctx, [&] { + auto loc = @IF; + loc.end = @optional_returned_values.end; + $$ = if_inst::create(std::move($condition), std::move($optional_returned_values), loc); + auto inode = if_inst($$.get()); + ctx.push_region(&inode.then()); + }); + }[header] region { + ctx.pop_region(); + auto inode = if_inst($header.get()); + ctx.push_region(&inode.otherwise()); + } else_region { + ctx.pop_region(); + $$ = std::move($header); } ; else_region: - %empty { $$ = {}; } - | ELSE region{ $$ = std::move($region); } + %empty {} + | ELSE region {} ; optional_returned_values: %empty { $$ = {}; } - | RETURNS LPAREN optional_scalar_type_list[tys] RPAREN { $$ = std::move($tys); } + | ARROW LPAREN optional_return_type_list[tys] RPAREN { $$ = std::move($tys); } ; 
-optional_scalar_type_list: +optional_return_type_list: %empty {} - | scalar_type_list { $$ = std::move($1); } + | return_type_list { $$ = std::move($1); } ; -scalar_type_list: - scalar_type { $$.push_back($scalar_type); } - | scalar_type_list COMMA scalar_type { $$ = std::move($1); $$.push_back($scalar_type); } +return_type_list: + data_type { $$.push_back($data_type); } + | return_type_list COMMA data_type { + $$ = std::move($1); $$.push_back($data_type); + } ; +valued_inst: COS var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = cos_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SIN var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = sin_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: EXP var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = exp_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: EXP2 var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = exp2_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: NATIVE_COS var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = native_cos_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: NATIVE_SIN var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = native_sin_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: NATIVE_EXP var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = native_exp_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: NATIVE_EXP2 var[a] COLON data_type[ty] { yytry(ctx, [&] { $$ = native_exp2_inst::create($a, $ty, @valued_inst); }); }; -size_inst: - SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - try { - $$ = inst { std::make_unique(std::move($var), $mode, @size_inst).release() }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +instruction: + PARALLEL >{ + yytry(ctx, [&] { + $$ = 
parallel_inst::create(@PARALLEL); + auto inode = parallel_inst($$.get()); + ctx.push_region(&inode.body()); + }); + }[header] region { + ctx.pop_region(); + $$ = std::move($header); } ; -subview_inst: - SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type { - if (!$var->ty() || !is_equal(*$var->ty(), *$memref_type)) { - auto loc = @var; - loc.end = @memref_type.end; - throw parser::syntax_error(loc, "Type of SSA value does not match operand type"); - } - try { - $$ = inst { - std::make_unique(std::move($var), std::move($optional_slice_list), - @subview_inst) - .release() - }; - } catch (compilation_error const &e) { - error(e.loc(), e.what()); - YYERROR; - } +valued_inst: + SIZE var LSQBR INTEGER_CONSTANT[mode] RSQBR COLON scalar_type { + yytry(ctx, [&] { + $$ = size_inst::create($mode, std::move($var), $scalar_type, @valued_inst); + }); + } +; + +valued_inst: + SUBGROUP_BROADCAST var[a] COMMA var[idx] COLON scalar_type { + yytry(ctx, [&] { + $$ = + subgroup_broadcast_inst::create(std::move($a), std::move($idx), $scalar_type, @valued_inst); + }); + } +; + +valued_inst: SUBGROUP_EXCLUSIVE_SCAN_ADD var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_exclusive_scan_add_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_EXCLUSIVE_SCAN_MAX var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_exclusive_scan_max_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_EXCLUSIVE_SCAN_MIN var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_exclusive_scan_min_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_INCLUSIVE_SCAN_ADD var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_inclusive_scan_add_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_INCLUSIVE_SCAN_MAX var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_inclusive_scan_max_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_INCLUSIVE_SCAN_MIN var[a] COLON scalar_type[ty] { yytry(ctx, 
[&] { $$ = subgroup_inclusive_scan_min_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_REDUCE_ADD var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_reduce_add_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_REDUCE_MAX var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_reduce_max_inst::create($a, $ty, @valued_inst); }); }; +valued_inst: SUBGROUP_REDUCE_MIN var[a] COLON scalar_type[ty] { yytry(ctx, [&] { $$ = subgroup_reduce_min_inst::create($a, $ty, @valued_inst); }); }; + +valued_inst: + SUBVIEW var LSQBR optional_slice_list RSQBR COLON memref_type[ty] { + yytry(ctx, [&] { + auto static_offsets = std::vector{}; + auto static_sizes = std::vector{}; + auto offsets = std::vector{}; + auto sizes = std::vector{}; + static_offsets.reserve($optional_slice_list.size()); + static_sizes.reserve($optional_slice_list.size()); + offsets.reserve($optional_slice_list.size()); + sizes.reserve($optional_slice_list.size()); + for (auto &s : $optional_slice_list) { + std::visit(overloaded{ + [&](std::int64_t i) { static_offsets.push_back(i); }, + [&](tinytc_value_t v) { + static_offsets.push_back(dynamic); + offsets.push_back(v); + }, + }, + s.first); + std::visit(overloaded{ + [&](std::int64_t i) { static_sizes.push_back(i); }, + [&](tinytc_value_t v) { + static_sizes.push_back(dynamic); + sizes.push_back(v); + }, + }, + s.second); + } + $$ = subview_inst::create(std::move(static_offsets), std::move(static_sizes), std::move($var), + std::move(offsets), std::move(sizes), std::move($ty), @valued_inst); + }); } ; @@ -954,25 +1125,30 @@ optional_slice_list: ; slice_list: - slice { $$.push_back($slice); } - | slice_list COMMA slice { $$ = std::move($1); $$.push_back($slice); } + slice { + $$.emplace_back(std::move($slice)); + } + | slice_list COMMA slice { + $$ = std::move($1); + $$.emplace_back(std::move($slice)); + } ; slice: - COLON { $$ = slice(make_index(0), make_dynamic()); } - | index_identifier_or_const slice_size { $$ 
= slice(std::move($1), std::move($2)); } + integer_constant_or_identifier slice_size { $$ = std::make_pair(std::move($1), std::move($2)); } ; slice_size: %empty { $$ = {}; } - | COLON index_identifier_or_const { $$ = $2; } - | COLON DYNAMIC { $$ = make_dynamic(); } + | COLON integer_constant_or_identifier { $$ = $2; } ; %% namespace tinytc { void parser::error(location_type const& l, std::string const& m) { - ctx.add_error(l, m); + if (m.size() > 0 || l.begin.line > 0) { + ctx.report_error(l, m); + } } } diff --git a/src/pass/check_ir.cpp b/src/pass/check_ir.cpp new file mode 100644 index 00000000..b15af27c --- /dev/null +++ b/src/pass/check_ir.cpp @@ -0,0 +1,80 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/check_ir.hpp" +#include "error.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "support/walk.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include + +namespace tinytc { + +void check_ir_pass::check_yield(tinytc_region &reg, tinytc_inst &in, status yield_missing_status) { + auto last_inst = --reg.end(); + if (last_inst == reg.end()) { + throw compilation_error(reg.loc(), yield_missing_status); + } + auto yield = dyn_cast(last_inst.get()); + if (!yield) { + throw compilation_error(reg.loc(), yield_missing_status); + } + if (yield.get().num_operands() != in.num_results()) { + throw compilation_error(yield.loc(), status::ir_yield_mismatch); + } + for (std::int64_t i = 0; i < in.num_results(); ++i) { + if (yield.get().op(i).ty() != in.result(i).ty()) { + throw compilation_error(yield.loc(), {&yield.get().op(i)}, status::ir_yield_mismatch); + } + } +} + +void check_ir_pass::operator()(inst_view) {} +void check_ir_pass::operator()(for_inst in) { + if (in.get().num_results() > 0) { + check_yield(in.body(), in.get()); + } +} +void check_ir_pass::operator()(if_inst in) { + 
if (in.get().num_results() > 0) { + check_yield(in.then(), in.get()); + check_yield(in.otherwise(), in.get(), status::ir_yield_in_else_branch_missing); + } +} + +void check_ir_pass::run_on_function(tinytc_func &fn) { + auto inside_spmd_region = std::stack{}; + inside_spmd_region.push(false); + walk(fn, [&](tinytc_inst &i, walk_stage const &stage) { + const bool child_region_is_spmd_region = + i.num_child_regions() > 0 && i.child_region(0).kind() == region_kind::spmd; + + if (stage.is_before_all_regions()) { + if (i.kind() == inst_execution_kind::collective && inside_spmd_region.top()) { + throw compilation_error(i.loc(), status::ir_collective_called_from_spmd); + } else if (i.kind() == inst_execution_kind::spmd && !inside_spmd_region.top()) { + throw compilation_error(i.loc(), status::ir_spmd_called_from_collective); + } + + if (child_region_is_spmd_region) { + inside_spmd_region.push(true); + } + } + + if (child_region_is_spmd_region && stage.is_after_all_regions()) { + inside_spmd_region.pop(); + } + + visit(*this, i); + }); +} + +} // namespace tinytc diff --git a/src/pass/check_ir.hpp b/src/pass/check_ir.hpp new file mode 100644 index 00000000..1ba63c00 --- /dev/null +++ b/src/pass/check_ir.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CHECK_IR_20240222_HPP +#define CHECK_IR_20240222_HPP + +#include "node/inst_view.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +namespace tinytc { + +class check_ir_pass { + public: + void operator()(inst_view in); + void operator()(for_inst in); + void operator()(if_inst in); + + void run_on_function(tinytc_func &fn); + + private: + void check_yield(tinytc_region ®, tinytc_inst &in, + status yield_missing_status = status::ir_must_have_yield); +}; + +} // namespace tinytc + +#endif // CHECK_IR_20240222_HPP diff --git a/src/pass/clone.cpp b/src/pass/clone.cpp new file mode 100644 index 00000000..ded98994 --- /dev/null +++ b/src/pass/clone.cpp @@ 
-0,0 +1,76 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/clone.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/visit.hpp" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include + +namespace tinytc { + +void inst_cloner::reset_subs() { subs_map_.clear(); } +void inst_cloner::set_subs(tinytc_value_t in_val, tinytc_value_t out_val) { + subs_map_[in_val] = out_val; +} +auto inst_cloner::subs(tinytc_value_t val) -> tinytc_value_t { + if (auto it = subs_map_.find(val); it != subs_map_.end()) { + return it->second; + } + return val; +} + +auto inst_cloner::clone_instruction(tinytc_inst &in) -> unique_handle { + auto cloned = visit( + [&](auto view) { + auto tid = view.get().type_id(); + auto layout = view.get().layout(); + auto lc = view.get().loc(); + auto clone = unique_handle{tinytc_inst::create(tid, layout, lc)}; + for (std::int32_t ret_no = 0; ret_no < layout.num_results; ++ret_no) { + clone->result(ret_no) = + tinytc_value{view.get().result(ret_no).ty(), clone.get(), lc}; + } + for (std::int32_t op_no = 0; op_no < layout.num_operands; ++op_no) { + clone->op(op_no, subs(&view.get().op(op_no))); + } + + auto clone_view = decltype(view)(clone.get()); + clone_view.props() = view.props(); + clone_view.setup_and_check(); + + return clone; + }, + in); + + for (auto res_orig = in.result_begin(), res_cloned = cloned->result_begin(); + res_orig != in.result_end() && res_cloned != cloned->result_end(); + ++res_orig, ++res_cloned) { + set_subs(&(*res_orig), &(*res_cloned)); + } + for (auto reg_orig = in.child_regions_begin(), reg_cloned = cloned->child_regions_begin(); + reg_orig != in.child_regions_end() && reg_cloned != cloned->child_regions_end(); + ++reg_orig, ++reg_cloned) { + for (auto p_orig = reg_orig->param_begin(), p_cloned = reg_cloned->param_begin(); + p_orig != reg_orig->param_end() && p_cloned != 
reg_cloned->param_end(); + ++p_orig, ++p_cloned) { + set_subs(&(*p_orig), &(*p_cloned)); + } + clone_region(*reg_orig, *reg_cloned); + } + return cloned; +} + +void inst_cloner::clone_region(tinytc_region &source, tinytc_region &target) { + for (auto &in_orig : source.insts()) { + target.insts().push_back(clone_instruction(in_orig).release()); + } +} + +} // namespace tinytc diff --git a/src/pass/clone.hpp b/src/pass/clone.hpp new file mode 100644 index 00000000..0dd941a5 --- /dev/null +++ b/src/pass/clone.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CLONE_20241118_HPP +#define CLONE_20241118_HPP + +#include "node/value.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include + +namespace tinytc { + +class inst_cloner { + public: + void reset_subs(); + void set_subs(tinytc_value_t in_val, tinytc_value_t out_val); + auto subs(tinytc_value_t val) -> tinytc_value_t; + + auto clone_instruction(tinytc_inst &in) -> unique_handle; + void clone_region(tinytc_region &source, tinytc_region &target); + + private: + template auto subs_value_range(T &&range) { + auto vec = std::vector(); + vec.reserve(range.size()); + for (auto &r : range) { + vec.emplace_back(subs(&r)); + } + return vec; + } + + std::unordered_map subs_map_; +}; + +} // namespace tinytc + +#endif // CLONE_20241118_HPP diff --git a/src/pass/constant_folding.cpp b/src/pass/constant_folding.cpp new file mode 100644 index 00000000..60d25361 --- /dev/null +++ b/src/pass/constant_folding.cpp @@ -0,0 +1,262 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/constant_folding.hpp" +#include "error.hpp" +#include "node/inst.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "number_dispatch.hpp" +#include "util/casting.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +template 
class op_dispatcher { + private: + tinytc_type_t dispatch_ty; + F computer; + + public: + op_dispatcher(tinytc_type_t ty, F &&f) : dispatch_ty{ty}, computer{std::forward(f)} {} + + template + requires(std::is_same_v && ...) + auto operator()(T &&...ops) -> fold_result { + return dispatch_int_to_native(dispatch_ty, [&]() { + return computer.template operator()(std::forward(ops)...); + }); + } + template + requires(std::is_same_v && ...) + auto operator()(T &&...ops) -> fold_result { + return dispatch_float_to_native(dispatch_ty, [&]() { + return computer.template operator()(std::forward(ops)...); + }); + } + template + requires(std::is_same_v> && ...) + auto operator()(T &&...ops) -> fold_result { + return dispatch_complex_to_native(dispatch_ty, [&]() { + return computer.template operator()(std::forward(ops)...); + }); + } + template auto operator()(T &&...) -> fold_result { + throw compilation_error(computer.loc, status::ir_number_mismatch); + } +}; + +constant_folding::constant_folding(bool unsafe_fp_math) : unsafe_fp_math_(unsafe_fp_math) {} + +auto constant_folding::get_memref_type(tinytc_value const &v) const -> const memref_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return t; +} + +auto constant_folding::operator()(inst_view) -> fold_result { return tinytc_value_t{}; } + +auto constant_folding::operator()(arith_inst in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst a_const = dyn_cast(op_a.defining_inst()); + constant_inst b_const = dyn_cast(op_b.defining_inst()); + + if (isa(*op_a.ty())) { + if ((a_const && !std::holds_alternative(a_const.value())) || + (b_const && !std::holds_alternative(b_const.value()))) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + if (a_const && b_const) { + return compute_binary_op{in.get().type_id(), op_a.ty(), in.loc()}( + std::get(a_const.value()), std::get(b_const.value())); + } 
else if (a_const) { + return compute_binop_identities{unsafe_fp_math_, in.get().type_id(), op_b, true, + in.loc()}(std::get(a_const.value())); + } else if (b_const) { + return compute_binop_identities{unsafe_fp_math_, in.get().type_id(), op_a, false, + in.loc()}(std::get(b_const.value())); + } + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_number_or_boolean); + } + at = dyn_cast(ct->component_ty()); + } + + if (a_const && b_const) { + auto computer = compute_binary_op{in.get().type_id(), op_a.ty(), in.loc()}; + auto dispatcher = op_dispatcher{at, std::move(computer)}; + return std::visit(std::move(dispatcher), a_const.value(), b_const.value()); + } else if (a_const) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.get().type_id(), op_b, true, in.loc()}; + auto dispatcher = op_dispatcher{at, std::move(computer)}; + return std::visit(std::move(dispatcher), a_const.value()); + } else if (b_const) { + auto computer = + compute_binop_identities{unsafe_fp_math_, in.get().type_id(), op_a, false, in.loc()}; + auto dispatcher = op_dispatcher{at, std::move(computer)}; + return std::visit(std::move(dispatcher), b_const.value()); + } + return tinytc_value_t{}; +} + +auto constant_folding::operator()(arith_unary_inst in) -> fold_result { + auto &op_a = in.a(); + + constant_inst a_const = dyn_cast(op_a.defining_inst()); + if (!a_const) { + return tinytc_value_t{}; + } + + if (isa(*op_a.ty())) { + if (!std::holds_alternative(a_const.value())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + return compute_unary_op{in.get().type_id(), op_a.ty(), + 
in.loc()}(std::get(a_const.value())); + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + // Arithmetic on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. + auto ct = dyn_cast(op_a.ty()); + if (ct == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_coopmatrix_or_number); + } + at = dyn_cast(ct->component_ty()); + } + + auto computer = compute_unary_op{in.get().type_id(), op_a.ty(), in.loc()}; + auto dispatcher = op_dispatcher{at, std::move(computer)}; + return std::visit(std::move(dispatcher), a_const.value()); +} + +auto constant_folding::operator()(cast_inst in) -> fold_result { + auto &op_a = in.a(); + + constant_inst a_const = dyn_cast(op_a.defining_inst()); + if (!a_const) { + return tinytc_value_t{}; + } + + auto rt = dyn_cast(in.result().ty()); + if (rt == nullptr) { + // Cast on coopmatrix is component-wise and if a coopmatrix is constant, then all + // elements have the same value. Thus, constant folding on coopmatrix types is identical to + // constant folding on scalar types. 
+ auto ct = dyn_cast(in.result().ty()); + if (ct == nullptr) { + throw compilation_error(in.result().loc(), status::ir_expected_coopmatrix_or_number); + } + rt = dyn_cast(ct->component_ty()); + } + + return std::visit( + overloaded{[&](auto A) -> fold_result { return compute_cast(rt, A, in.loc()); }}, + a_const.value()); +} + +auto constant_folding::operator()(compare_inst in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst a_const = dyn_cast(op_a.defining_inst()); + constant_inst b_const = dyn_cast(op_b.defining_inst()); + if (!a_const || !b_const) { + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_number); + } + + auto computer = compute_compare{in.get().type_id(), in.result().ty(), in.loc()}; + auto dispatcher = op_dispatcher{at, std::move(computer)}; + return std::visit(std::move(dispatcher), a_const.value(), b_const.value()); +} + +auto constant_folding::operator()(cooperative_matrix_scale_inst in) -> fold_result { + auto &op_a = in.a(); + auto &op_b = in.b(); + + constant_inst a_const = dyn_cast(op_a.defining_inst()); + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + throw compilation_error(op_a.loc(), status::ir_expected_number); + } + + if (a_const) { + auto computer = compute_binop_identities{unsafe_fp_math_, IK::IK_mul, op_b, true, in.loc()}; + auto dispatcher = op_dispatcher{at, std::move(computer)}; + return std::visit(std::move(dispatcher), a_const.value()); + } + return tinytc_value_t{}; +} + +auto constant_folding::operator()(math_unary_inst in) -> fold_result { + auto &op_a = in.a(); + + constant_inst a_const = dyn_cast(op_a.defining_inst()); + if (!a_const) { + return tinytc_value_t{}; + } + + auto at = dyn_cast(op_a.ty()); + if (at == nullptr) { + return tinytc_value_t{}; + } + + auto computer = compute_math_unary_op{in.get().type_id(), op_a.ty(), in.loc()}; + auto dispatcher = op_dispatcher{at, 
std::move(computer)}; + return std::visit(std::move(dispatcher), a_const.value()); +} + +auto constant_folding::operator()(size_inst in) -> fold_result { + auto mode_size = + visit(overloaded{[&](group_type const &g) -> std::int64_t { return g.size(); }, + [&](memref_type const &m) -> std::int64_t { return m.shape(in.mode()); }, + [&](auto const &) -> std::int64_t { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + }}, + *in.operand().ty()); + + if (!is_dynamic_value(mode_size)) { + return create(mode_size, index_type::get(in.operand().context()), in.loc()); + } + return tinytc_value_t{}; +} + +auto constant_folding::operator()(subgroup_broadcast_inst in) -> fold_result { + auto &op_a = in.a(); + + constant_inst a_const = dyn_cast(op_a.defining_inst()); + if (a_const) { + return &a_const.result(); + } + return tinytc_value_t{}; +} + +} // namespace tinytc diff --git a/src/pass/constant_folding.hpp b/src/pass/constant_folding.hpp new file mode 100644 index 00000000..9b7d9c82 --- /dev/null +++ b/src/pass/constant_folding.hpp @@ -0,0 +1,539 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONSTANT_FOLDING_HELPER_20241011_HPP +#define CONSTANT_FOLDING_HELPER_20241011_HPP + +#include "error.hpp" +#include "node/inst_view.hpp" +#include "node/value.hpp" +#include "number.hpp" +#include "number_dispatch.hpp" +#include "support/fp_util.hpp" // IWYU pragma: keep +#include "tinytc/builder.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +class number_type; +template class lp_float; + +using fold_result = std::variant>; + +class constant_folding { + public: + constant_folding(bool unsafe_fp_math); + + auto operator()(inst_view) -> fold_result; + auto operator()(arith_inst) -> fold_result; + auto operator()(arith_unary_inst) -> fold_result; + auto operator()(cooperative_matrix_scale_inst) -> fold_result; + 
auto operator()(cast_inst) -> fold_result; + auto operator()(compare_inst) -> fold_result; + auto operator()(math_unary_inst) -> fold_result; + auto operator()(size_inst in) -> fold_result; + auto operator()(subgroup_broadcast_inst in) -> fold_result; + + private: + auto get_memref_type(tinytc_value const &v) const -> const memref_type *; + + bool unsafe_fp_math_; +}; + +struct compute_unary_op { + IK operation; + tinytc_type_t ty; + location const &loc; + + auto operator()(bool a) -> fold_result { + bool val = false; + switch (operation) { + case IK::IK_not: + val = !a; + break; + default: + throw compilation_error(loc, status::ir_boolean_unsupported); + } + return create(val, ty, loc); + } + + template + requires(std::is_integral_v && !std::is_same_v) + auto operator()(T a) -> fold_result { + T val = 0; + switch (operation) { + case IK::IK_abs: + val = a < 0 ? -a : a; + break; + case IK::IK_neg: + val = -a; + break; + case IK::IK_not: + val = ~a; + break; + default: + throw compilation_error(loc, status::ir_int_unsupported); + } + return create(val, ty, loc); + } + + template + requires(is_floating_point_or_lp_float_v) + auto operator()(T a) -> fold_result { + T val = 0; + switch (operation) { + case IK::IK_abs: + val = a < T{0} ? 
-a : a; + break; + case IK::IK_neg: + val = -a; + break; + default: + throw compilation_error(loc, status::ir_fp_unsupported); + } + return create(val, ty, loc); + } + + template + requires(is_complex_v) + auto operator()(U const &A) -> fold_result { + const auto neg_conj = [&](T const &a) -> unique_handle { + T val = {}; + switch (operation) { + case IK::IK_neg: + val = -a; + break; + case IK::IK_conj: + val = std::conj(a); + break; + default: + return {}; + } + return create(val, ty, loc); + }; + const auto abs_im_re = [&](T const &a) -> unique_handle { + typename T::value_type val = {}; + switch (operation) { + case IK::IK_abs: + val = std::abs(a); + break; + case IK::IK_im: + val = std::imag(a); + break; + case IK::IK_re: + val = std::real(a); + break; + default: + return {}; + } + auto cst_ty = component_type(ty); + return create(val, cst_ty, loc); + }; + + const auto a = static_cast(A); + auto result = neg_conj(a); + if (result) { + return result; + } + result = abs_im_re(a); + if (result) { + return result; + } + throw compilation_error(loc, status::ir_complex_unsupported); + } +}; + +struct compute_binary_op { + IK operation; + tinytc_type_t ty; + location const &loc; + + auto operator()(bool a, bool b) -> fold_result { + bool val = false; + switch (operation) { + case IK::IK_and: + val = a && b; + break; + case IK::IK_or: + val = a || b; + break; + case IK::IK_xor: + val = a != b; + break; + default: + throw compilation_error(loc, status::ir_boolean_unsupported); + } + return create(val, ty, loc); + } + + template + requires(std::is_integral_v && !std::is_same_v) + auto operator()(T a, T b) -> fold_result { + T val = 0; + switch (operation) { + case IK::IK_add: + val = a + b; + break; + case IK::IK_sub: + val = a - b; + break; + case IK::IK_mul: + val = a * b; + break; + case IK::IK_div: + val = a / b; + break; + case IK::IK_rem: + val = a % b; + break; + case IK::IK_shl: + val = a << b; + break; + case IK::IK_shr: + val = a >> b; + break; + case 
IK::IK_and: + val = a & b; + break; + case IK::IK_or: + val = a | b; + break; + case IK::IK_xor: + val = a ^ b; + break; + case IK::IK_min: + val = std::min(a, b); + break; + case IK::IK_max: + val = std::max(a, b); + break; + default: + throw compilation_error(loc, status::internal_compiler_error); + } + return create(val, ty, loc); + } + + template + requires(!std::is_integral_v) + auto operator()(U const &A, U const &B) -> fold_result { + const auto a = static_cast(A); + const auto b = static_cast(B); + T val = {}; + switch (operation) { + case IK::IK_add: + val = a + b; + break; + case IK::IK_sub: + val = a - b; + break; + case IK::IK_mul: + val = a * b; + break; + case IK::IK_div: + val = a / b; + break; + case IK::IK_rem: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } else { + val = std::fmod(a, b); + } + break; + case IK::IK_min: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } else { + val = std::min(a, b); + } + break; + case IK::IK_max: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } else { + val = std::max(a, b); + } + break; + default: + if constexpr (is_complex_v) { + throw compilation_error(loc, status::ir_complex_unsupported); + } + throw compilation_error(loc, status::ir_fp_unsupported); + break; + } + return create(val, ty, loc); + } +}; + +struct compute_binop_identities { + bool unsafe_fp_math; + IK operation; + tinytc_value &operand; + bool is_second_operand; + location const &loc; + + auto operator()(bool a) -> fold_result { + switch (operation) { + case IK::IK_and: + if (!a) { + return create(false, operand.ty(), loc); + } + break; + case IK::IK_or: + case IK::IK_xor: + if (!a) { + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } + + template + requires(std::is_integral_v && !std::is_same_v) + auto operator()(T a) -> fold_result { + switch (operation) { + 
case IK::IK_add: + if (a == T{0}) { // operand + 0 or 0 + operand + return &operand; + } + break; + case IK::IK_sub: + if (a == T{0} && !is_second_operand) { // operand - 0 + return &operand; + } + break; + case IK::IK_mul: + if (a == T{0}) { // operand * 0 or 0 * operand + return create(T{0}, operand.ty(), loc); + } else if (a == T{1}) { // operand * 1 or 1 * operand + return &operand; + } + break; + case IK::IK_div: + if (a == T{1} && !is_second_operand) { // operand / 1 + return &operand; + } + break; + case IK::IK_rem: + if (a == T{1} && !is_second_operand) { // operand % 1 + return create(T{0}, operand.ty(), loc); + } + break; + case IK::IK_shl: + case IK::IK_shr: + if (a == T{0}) { + if (is_second_operand) { // 0 << operand + return create(T{0}, operand.ty(), loc); + } else { // operand << 0 + return &operand; + } + } + break; + case IK::IK_and: + if (a == T{0}) { + return create(T{0}, operand.ty(), loc); + } + break; + case IK::IK_or: + case IK::IK_xor: + if (a == T{0}) { + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } + + template + requires(!std::is_integral_v) + auto operator()(U const &A) -> fold_result { + const auto a = static_cast(A); + switch (operation) { + case IK::IK_add: + if (a == T{0}) { // operand + 0 or 0 + operand + return &operand; + } + break; + case IK::IK_sub: + if (a == T{0} && !is_second_operand) { // operand - 0 + return &operand; + } + break; + case IK::IK_mul: + if (unsafe_fp_math && a == T{0}) { // operand * 0 or 0 * operand + return create(T{0}, operand.ty(), loc); + } else if (a == T{1}) { // operand * 1 or 1 * operand + return &operand; + } + break; + case IK::IK_div: + if (a == T{1} && !is_second_operand) { // operand / 1 + return &operand; + } + break; + default: + break; + } + return tinytc_value_t{}; + } +}; + +struct compute_compare { + IK cond; + tinytc_type_t ty; + location const &loc; + + template + requires(std::is_integral_v || is_floating_point_or_lp_float_v) + auto operator()(T 
a, T b) -> fold_result { + bool val = false; + switch (cond) { + case IK::IK_equal: + val = (a == b); + break; + case IK::IK_not_equal: + val = (a != b); + break; + case IK::IK_greater_than: + val = (a > b); + break; + case IK::IK_greater_than_equal: + val = (a >= b); + break; + case IK::IK_less_than: + val = (a < b); + break; + case IK::IK_less_than_equal: + val = (a <= b); + break; + default: + throw compilation_error(loc, status::internal_compiler_error); + }; + return create(val, ty, loc); + } + + template + auto operator()(std::complex const &A, std::complex const &B) -> fold_result { + const auto a = static_cast(A); + const auto b = static_cast(B); + bool val = false; + switch (cond) { + case IK::IK_equal: + val = (a == b); + break; + case IK::IK_not_equal: + val = (a != b); + break; + default: + throw compilation_error(loc, status::ir_complex_unsupported); + break; + }; + return create(val, ty, loc); + } +}; + +template struct value_cast_impl { + auto operator()(U const &u) { return static_cast(u); } +}; + +template struct value_cast_impl, U> { + auto operator()(U const &u) { return std::complex{static_cast(u), static_cast(0)}; } +}; + +template struct value_cast_impl, std::complex> { + auto operator()(std::complex const &u) { return static_cast>(u); } +}; + +template struct value_cast_impl { + auto operator()(U const &u) { return u != U{}; } +}; + +template struct value_cast_impl> { + auto operator()(std::complex const &) -> bool { throw status::ir_forbidden_cast; } +}; + +template struct value_cast_impl> { + auto operator()(std::complex const &) -> T { throw status::ir_forbidden_cast; } +}; + +template auto value_cast(U const &u) { return value_cast_impl{}(u); } + +template +auto compute_cast(number_type *to_ty, T A, location const &loc) -> fold_result { + return dispatch_number_to_native( + to_ty, [&]() { return create(value_cast(A), to_ty, loc); }); +}; + +struct compute_math_unary_op { + IK operation; + tinytc_type_t ty; + location const &loc; + + 
template + requires(std::is_integral_v) + auto operator()(T) -> fold_result { + throw compilation_error(loc, status::ir_int_unsupported); + } + + template + requires(is_floating_point_or_lp_float_v) + auto operator()(T a) -> fold_result { + T val = {}; + switch (operation) { + case IK::IK_cos: + case IK::IK_native_cos: + val = std::cos(a); + break; + case IK::IK_sin: + case IK::IK_native_sin: + val = std::sin(a); + break; + case IK::IK_exp: + case IK::IK_native_exp: + val = std::exp(a); + break; + case IK::IK_exp2: + case IK::IK_native_exp2: + val = std::exp2(a); + break; + default: + throw compilation_error(loc, status::ir_fp_unsupported); + } + return create(val, ty, loc); + } + + template + requires(is_complex_v) + auto operator()(U const &a) -> fold_result { + T val = {}; + switch (operation) { + case IK::IK_exp: + case IK::IK_native_exp: + val = std::exp(a); + break; + case IK::IK_exp2: + case IK::IK_native_exp2: + val = std::pow(T{std::complex{2.0, 0.0}}, a); + break; + default: + throw compilation_error(loc, status::ir_complex_unsupported); + } + return create(val, ty, loc); + } +}; + +} // namespace tinytc + +#endif // CONSTANT_FOLDING_HELPER_20241011_HPP diff --git a/src/pass/constant_propagation.cpp b/src/pass/constant_propagation.cpp new file mode 100644 index 00000000..eeeab845 --- /dev/null +++ b/src/pass/constant_propagation.cpp @@ -0,0 +1,73 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/constant_propagation.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "pass/constant_folding.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/overloaded.hpp" + +#include +#include + +namespace tinytc { + +void constant_propagation_pass::run_on_function(tinytc_func &fn) { run_on_region(fn.body()); } + +void 
constant_propagation_pass::run_on_region(tinytc_region ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + /* Fold inside nested regions before looking at the instruction itself. */ + for (auto &subreg : it->child_regions()) { + run_on_region(subreg); + } + + /* Redirects every use of the single result of *it to the value with. */ + const auto update_uses = [&it](tinytc_value_t with) { + if (it->num_results() != 1) { + throw status::internal_compiler_error; + } + auto r = it->result_begin(); + auto u = r->use_begin(); + /* NOTE(review): presumably set() removes the use from r's use list, so the loop ends once no use is left — confirm */ + while (r->has_uses()) { + u->set(with); + u = r->use_begin(); + } + if (r->has_uses()) { + throw status::internal_compiler_error; + } + }; + + /* The folder returns either an existing value or a newly created instruction; a null value or empty handle leaves *it untouched. */ + fold_result fr = visit(constant_folding{unsafe_fp_math_}, *it); + std::visit(overloaded{[&](tinytc_value_t val) { + if (val) { + update_uses(val); + } + }, + [&](unique_handle &new_constant) { + if (new_constant) { + if (new_constant->num_results() != 1) { + throw status::internal_compiler_error; + } + update_uses(&*new_constant->result_begin()); + // insert new constant + it = reg.insts().insert(it, new_constant.release()); + // skip over constant + ++it; + } + }}, + fr); + } +} + +/* Only optflag::unsafe_fp_math is acted upon; any other flag leaves the pass unchanged. */ +void constant_propagation_pass::set_opt_flag(tinytc::optflag flag, bool enabled) { + if (flag == tinytc::optflag::unsafe_fp_math) { + unsafe_fp_math_ = enabled; + } +} + +} // namespace tinytc diff --git a/src/pass/constant_propagation.hpp b/src/pass/constant_propagation.hpp new file mode 100644 index 00000000..9bda83f3 --- /dev/null +++ b/src/pass/constant_propagation.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONSTANT_PROPAGATION_20240807_HPP +#define CONSTANT_PROPAGATION_20240807_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +enum class optflag; + +class constant_propagation_pass { + public: + void run_on_function(::tinytc_func &fn); + void run_on_region(::tinytc_region ®); + + void set_opt_flag(tinytc::optflag flag, bool enabled); + + private: + bool unsafe_fp_math_ = false; +}; + +} // namespace tinytc + +#endif // CONSTANT_PROPAGATION_20240807_HPP diff --git
a/src/pass/convert_to_spirv.cpp b/src/pass/convert_to_spirv.cpp new file mode 100644 index 00000000..f125e89c --- /dev/null +++ b/src/pass/convert_to_spirv.cpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/convert_to_spirv.hpp" +#include "spv/converter.hpp" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { + +convert_to_spirv_pass::convert_to_spirv_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { + if (info_ == nullptr) { + throw status::invalid_arguments; + } +} + +auto convert_to_spirv_pass::run_on_program(tinytc_prog &p) -> shared_handle { + return spv::convert_prog_to_spirv(p, *info_); +} + +} // namespace tinytc diff --git a/src/pass/convert_to_spirv.hpp b/src/pass/convert_to_spirv.hpp new file mode 100644 index 00000000..f819a943 --- /dev/null +++ b/src/pass/convert_to_spirv.hpp @@ -0,0 +1,24 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONVERT_TO_SPIRV_20241029_HPP +#define CONVERT_TO_SPIRV_20241029_HPP + +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +namespace tinytc { + +class convert_to_spirv_pass { + public: + convert_to_spirv_pass(::tinytc_core_info const *info); + + auto run_on_program(tinytc_prog &p) -> shared_handle; + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // CONVERT_TO_SPIRV_20241029_HPP diff --git a/src/pass/dead_code_elimination.cpp b/src/pass/dead_code_elimination.cpp new file mode 100644 index 00000000..2123bd66 --- /dev/null +++ b/src/pass/dead_code_elimination.cpp @@ -0,0 +1,86 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dead_code_elimination.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/region.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include 
"util/ilist_base.hpp" + +#include +#include +#include + +namespace tinytc { + +/* Visitor that answers whether an instruction can be erased; one overload per instruction kind that needs special treatment. */ +class dead_code_analysis { + public: + auto operator()(inst_view in) -> bool; + auto operator()(if_inst in) -> bool; + auto operator()(for_inst in) -> bool; +}; + +auto dead_code_analysis::operator()(inst_view in) -> bool { + /* An instruction is treated as having side effects if either of the following is true + * + * - It has at least one child region (if, for, foreach, parallel, ...) + * - It does not have results (barrier, GEMM, GER, ...) + * + * Without side effects it is dead when none of its results has a use. + */ + const bool has_side_effects = in.get().num_child_regions() > 0 || in.get().num_results() == 0; + + bool any_result_has_uses = false; + for (auto &res : in.get().results()) { + any_result_has_uses = any_result_has_uses || res.has_uses(); + } + + return !has_side_effects && !any_result_has_uses; +} + +/* An if without results is dead when its condition is the constant false and its else-branch holds no instructions. */ +auto dead_code_analysis::operator()(if_inst in) -> bool { + constant_inst cond_const = dyn_cast(in.condition().defining_inst()); + if (in.get().num_results() == 0 && cond_const) { + /* A false condition only rules out the then-branch; the else-branch would still run, so it must be empty before the whole instruction may be erased. + * NOTE(review): assumes an if written without else still owns an (empty) otherwise region — confirm. */ + auto &else_region = in.otherwise(); + return std::holds_alternative(cond_const.value()) && + std::get(cond_const.value()) == false && + else_region.begin() == else_region.end(); + } + + return false; +} + +/* A for-loop without results is dead when both bounds are constants and from >= to. */ +auto dead_code_analysis::operator()(for_inst in) -> bool { + constant_inst from_const = dyn_cast(in.from().defining_inst()); + constant_inst to_const = dyn_cast(in.to().defining_inst()); + if (in.get().num_results() == 0 && from_const && to_const) { + // For-instruction is dead if from >= to + return std::holds_alternative(from_const.value()) && + std::holds_alternative(to_const.value()) && + std::get(from_const.value()) >= + std::get(to_const.value()); + } + return false; +} + +void dead_code_elimination_pass::run_on_function(tinytc_func &fn) { run_on_region(fn.body()); } + +void dead_code_elimination_pass::run_on_region(tinytc_region ®) { + auto prev_it = reg.end(); + while (prev_it != reg.begin()) { + auto it = --prev_it; + auto is_dead = visit(dead_code_analysis{}, *it); + if (is_dead) { +
prev_it = reg.insts().erase(it); + } else { + for (auto &subreg : it->child_regions()) { + run_on_region(subreg); + } + } + } +} + +} // namespace tinytc diff --git a/src/pass/dead_code_elimination.hpp b/src/pass/dead_code_elimination.hpp new file mode 100644 index 00000000..e56d27f8 --- /dev/null +++ b/src/pass/dead_code_elimination.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DEAD_CODE_ELIMINATION_20241007_HPP +#define DEAD_CODE_ELIMINATION_20241007_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class dead_code_elimination_pass { + public: + void run_on_function(::tinytc_func &fn); + void run_on_region(::tinytc_region ®); +}; + +} // namespace tinytc + +#endif // DEAD_CODE_ELIMINATION_20241007_HPP diff --git a/src/pass/dump_cfg.cpp b/src/pass/dump_cfg.cpp new file mode 100644 index 00000000..4fb271e1 --- /dev/null +++ b/src/pass/dump_cfg.cpp @@ -0,0 +1,43 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_cfg.hpp" +#include "analysis/cfg.hpp" +#include "node/func.hpp" +#include "pass/dump_ir.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc { + +dump_cfg_pass::dump_cfg_pass(std::ostream &os) : os_(&os) {} + +void dump_cfg_pass::run_on_function(tinytc_func &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + + *os_ << "digraph " << fn.name() << " {" << std::endl; + + auto cfg = get_control_flow_graph(fn.body()); + auto q = cfg.node_queue(); + for (; !q.empty(); q.pop()) { + auto &node = q.front(); + + *os_ << reinterpret_cast(node) << " [label=\""; + dump_ir.run_on_instruction(*node); + *os_ << "\"]" << std::endl; + + for (auto &neigh : cfg.successors(node)) { + *os_ << reinterpret_cast(node) << " -> " + << reinterpret_cast(neigh) << std::endl; + } + } + + *os_ << "}" << std::endl; +} + +} // namespace tinytc diff --git a/src/pass/dump_cfg.hpp b/src/pass/dump_cfg.hpp new file 
mode 100644 index 00000000..6646942b --- /dev/null +++ b/src/pass/dump_cfg.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_BACKWARD_CFG_20240919_HPP +#define DUMP_BACKWARD_CFG_20240919_HPP + +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class dump_cfg_pass { + public: + dump_cfg_pass(std::ostream &os); + + void run_on_function(tinytc_func &fn); + + private: + std::ostream *os_; +}; + +} // namespace tinytc + +#endif // DUMP_BACKWARD_CFG_20240919_HPP diff --git a/src/pass/dump_def_use.cpp b/src/pass/dump_def_use.cpp new file mode 100644 index 00000000..61573078 --- /dev/null +++ b/src/pass/dump_def_use.cpp @@ -0,0 +1,57 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_def_use.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "pass/dump_ir.hpp" +#include "support/walk.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc { + +dump_def_use_pass::dump_def_use_pass(std::ostream &os) : os_(&os) {} + +void dump_def_use_pass::run_on_function(tinytc_func &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + dump_ir.init_slot_tracker(fn); + + *os_ << "Def-use in @" << fn.name() << std::endl; + walk(fn, [&](tinytc_inst &i) { + if (i.num_results() > 0 || i.num_child_regions() > 0) { + *os_ << "> "; + visit(dump_ir, i); + *os_ << std::endl; + auto const def_use = [&](tinytc_value const &v) { + *os_ << " def "; + dump_ir.dump_val(v); + *os_ << std::endl; + for (auto &u : v.uses()) { + *os_ << " > "; + visit(dump_ir, *u.owner()); + *os_ << std::endl; + } + }; + for (auto &res : i.results()) { + def_use(res); + } + for (auto ® : i.child_regions()) { + for (auto &p : reg.params()) { + def_use(p); + } + } + } + }); + *os_ << std::endl; +} + +} // namespace tinytc diff --git 
a/src/pass/dump_def_use.hpp b/src/pass/dump_def_use.hpp new file mode 100644 index 00000000..1d8405ab --- /dev/null +++ b/src/pass/dump_def_use.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_DEF_USE_20241002_HPP +#define DUMP_DEF_USE_20241002_HPP + +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class dump_def_use_pass { + public: + dump_def_use_pass(std::ostream &os); + + void run_on_function(tinytc_func &fn); + + private: + std::ostream *os_; +}; + +} // namespace tinytc + +#endif // DUMP_DEF_USE_20241002_HPP diff --git a/src/pass/dump_gcd.cpp b/src/pass/dump_gcd.cpp new file mode 100644 index 00000000..87cb6511 --- /dev/null +++ b/src/pass/dump_gcd.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_gcd.hpp" +#include "analysis/gcd.hpp" +#include "device_info.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/visit.hpp" +#include "pass/dump_ir.hpp" +#include "support/walk.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +dump_gcd_pass::dump_gcd_pass(std::ostream &os, ::tinytc_core_info const *info) + : os_(&os), info_{info} {} + +void dump_gcd_pass::run_on_function(tinytc_func &fn) { + auto dump_ir = dump_ir_pass(*os_, 0); + dump_ir.init_slot_tracker(fn); + auto gcd = gcd_analysis{info_->alignment()}.run_on_function(fn); + + auto const dump_range = [&](auto begin, auto end) { + *os_ << "["; + for (auto it = begin; it != end; ++it) { + if (it != begin) { + *os_ << ","; + } + *os_ << *it; + } + *os_ << "]"; + }; + auto const dump_gcd = [&](tinytc_value const &v) { + auto g = gcd.get_if(v); + if (g) { + *os_ << " gcd("; + dump_ir.dump_val(v); + *os_ << ") = " << *g << std::endl; + } + auto mi = gcd.get_memref_if(v); + if (mi) { + *os_ << " offset_gcd("; + dump_ir.dump_val(v); + *os_ 
<< ") = " << mi->offset_gcd() << std::endl; + *os_ << " shape_gcd("; + dump_ir.dump_val(v); + *os_ << ") = "; + dump_range(mi->shape_gcd_begin(), mi->shape_gcd_end()); + *os_ << std::endl << " stride_gcd("; + dump_ir.dump_val(v); + *os_ << ") = "; + dump_range(mi->stride_gcd_begin(), mi->stride_gcd_end()); + *os_ << std::endl; + } + }; + + *os_ << "GCD in @" << fn.name() << std::endl; + for (auto &p : fn.params()) { + dump_gcd(p); + } + walk(fn, [&](tinytc_inst &i) { + if (i.num_results() > 0 || i.num_child_regions() > 0) { + *os_ << "> "; + visit(dump_ir, i); + *os_ << std::endl; + for (auto &res : i.results()) { + dump_gcd(res); + } + for (auto ® : i.child_regions()) { + for (auto &p : reg.params()) { + dump_gcd(p); + } + } + } + }); + *os_ << std::endl; +} + +} // namespace tinytc diff --git a/src/pass/dump_gcd.hpp b/src/pass/dump_gcd.hpp new file mode 100644 index 00000000..ae386216 --- /dev/null +++ b/src/pass/dump_gcd.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_GCD_20241203_HPP +#define DUMP_GCD_20241203_HPP + +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class dump_gcd_pass { + public: + dump_gcd_pass(std::ostream &os, ::tinytc_core_info const *info); + + void run_on_function(tinytc_func &fn); + + private: + std::ostream *os_; + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // DUMP_GCD_20241203_HPP diff --git a/src/pass/dump_ir.cpp b/src/pass/dump_ir.cpp new file mode 100644 index 00000000..df24ebad --- /dev/null +++ b/src/pass/dump_ir.cpp @@ -0,0 +1,741 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/dump_ir.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/fnv1a.hpp" +#include 
"util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +dump_ir_pass::dump_ir_pass(std::ostream &os, int level_limit) : os_(&os), lvl_limit_(level_limit) {} + +/* Attribute nodes */ +void dump_ir_pass::operator()(array_attr const &a) { + *os_ << "["; + do_with_infix(a.begin(), a.end(), [&](auto const &a) { visit(*this, *a); }); + *os_ << "]"; +} +void dump_ir_pass::operator()(boolean_attr const &a) { *os_ << (a.value() ? "true" : "false"); } +void dump_ir_pass::operator()(dictionary_attr const &a) { + auto const is_keyword = [](std::string_view str) { + switch (fnv1a(str)) { + case "alignment"_fnv1a: + case "shape_gcd"_fnv1a: + case "stride_gcd"_fnv1a: + case "subgroup_size"_fnv1a: + case "unroll"_fnv1a: + case "work_group_size"_fnv1a: + return true; + default: + return false; + } + }; + auto const dump_name = [&](tinytc_attr_t a) { + if (auto s = dyn_cast(a); s) { + if (is_keyword(s->str())) { + *os_ << s->str(); + } else { + this->operator()(*s); + } + } else { + throw status::ir_expected_string_attribute; + } + }; + *os_ << "{"; + do_with_infix( + a.begin(), a.end(), + [&](auto const &a) { + dump_name(a.name); + *os_ << "="; + visit(*this, *a.attr); + }, + ", "); + *os_ << "}"; +} +void dump_ir_pass::operator()(integer_attr const &a) { *os_ << a.value(); } +void dump_ir_pass::operator()(string_attr const &a) { *os_ << "\"" << a.str() << "\""; } + +/* Data type nodes */ +void dump_ir_pass::operator()(void_type const &) { *os_ << "void"; } +void dump_ir_pass::operator()(boolean_type const &) { *os_ << "bool"; } +void dump_ir_pass::operator()(coopmatrix_type const &ct) { + *os_ << "coopmatrix<"; + visit(*this, *ct.component_ty()); + *os_ << "x" << ct.rows() << "x" << ct.cols() << "," << to_string(ct.use()) << ">"; +} +void dump_ir_pass::operator()(group_type const &g) { + auto const val = [&](std::int64_t v) -> std::ostream & { + if 
(is_dynamic_value(v)) { + return *os_ << "?"; + } + return *os_ << v; + }; + *os_ << "group<"; + visit(*this, *g.element_ty()); + *os_ << "x"; + val(g.size()); + if (g.offset() != 0) { + *os_ << ", offset: "; + val(g.offset()); + } + *os_ << ">"; +} +void dump_ir_pass::operator()(memref_type const &d) { + auto const val = [&](std::int64_t v) -> std::ostream & { + if (is_dynamic_value(v)) { + return *os_ << "?"; + } + return *os_ << v; + }; + *os_ << "memref<"; + visit(*this, *d.element_ty()); + for (auto const &s : d.shape()) { + *os_ << "x"; + val(s); + } + if (!d.is_canonical_stride()) { + *os_ << ",strided<"; + do_with_infix(d.stride().begin(), d.stride().end(), [&](auto const &a) { val(a); }); + *os_ << ">"; + } + if (d.addrspace() != address_space::global) { + *os_ << "," << to_string(d.addrspace()); + } + *os_ << ">"; +} +void dump_ir_pass::operator()(number_type const &t) { *os_ << to_string(t.type_id()); } + +/* Value nodes */ +void dump_ir_pass::dump_val(tinytc_value const &v) { + *os_ << "%" << v.name(); + auto const slot = tracker_.get_slot(v); + if (slot >= 0) { + *os_ << slot; + } +} + +/* Inst nodes */ +void dump_ir_pass::dump_blas_a2(blas_a2_inst g) { + *os_ << ' '; + dump_val(g.alpha()); + *os_ << ", "; + dump_val(g.A()); + *os_ << ", "; + dump_val(g.beta()); + *os_ << ", "; + dump_val(g.B()); +} + +void dump_ir_pass::dump_blas_a3(blas_a3_inst g) { + *os_ << ' '; + dump_val(g.alpha()); + *os_ << ", "; + dump_val(g.A()); + *os_ << ", "; + dump_val(g.B()); + *os_ << ", "; + dump_val(g.beta()); + *os_ << ", "; + dump_val(g.C()); +} + +void dump_ir_pass::operator()(alloca_inst a) { + dump_val(a.result()); + *os_ << " = alloca : "; + visit(*this, *a.result().ty()); +} + +void dump_ir_pass::operator()(axpby_inst a) { + *os_ << "axpby"; + if (a.atomic()) { + *os_ << ".atomic"; + } + if (a.tA() != transpose::N) { + *os_ << "." 
<< to_string(a.tA()); + } + dump_blas_a2(static_cast(a)); +} + +void dump_ir_pass::operator()(arith_inst a) { + dump_val(a.result()); + *os_ << " = " << to_string(a.get().type_id()) << " "; + dump_val(a.a()); + *os_ << ", "; + dump_val(a.b()); + *os_ << " : "; + visit(*this, *a.result().ty()); +} + +void dump_ir_pass::operator()(arith_unary_inst a) { + dump_val(a.result()); + *os_ << " = " << to_string(a.get().type_id()) << " "; + dump_val(a.a()); + *os_ << " : "; + visit(*this, *a.result().ty()); +} + +void dump_ir_pass::operator()(barrier_inst b) { + *os_ << "barrier"; + if (b.has_fence(address_space::global)) { + *os_ << ".global"; + } + if (b.has_fence(address_space::local)) { + *os_ << ".local"; + } +} + +void dump_ir_pass::operator()(cast_inst c) { + dump_val(c.result()); + *os_ << " = cast "; + dump_val(c.a()); + *os_ << " : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(compare_inst a) { + dump_val(a.result()); + *os_ << " = " << to_string(a.get().type_id()) << " "; + dump_val(a.a()); + *os_ << ", "; + dump_val(a.b()); + *os_ << " : "; + visit(*this, *a.result().ty()); +} + +void dump_ir_pass::operator()(constant_inst c) { + dump_val(c.result()); + *os_ << " = constant "; + std::visit(overloaded{ + [&](bool b) { *os_ << (b ? 
"true" : "false"); }, + [&](std::int64_t i) { + if (is_dynamic_value(i)) { + *os_ << "?"; + } else { + *os_ << i; + } + }, + [&](double d) { + auto flags = os_->flags(); + *os_ << std::hexfloat << d; + os_->flags(flags); + }, + [&](std::complex d) { + auto flags = os_->flags(); + *os_ << std::hexfloat << "[" << d.real() << "," << d.imag() << "]"; + os_->flags(flags); + }, + }, + c.value()); + *os_ << " : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_apply_inst c) { + dump_val(c.result()); + *os_ << " = cooperative_matrix_apply ("; + dump_val(c.row()); + *os_ << ","; + dump_val(c.col()); + *os_ << ","; + dump_val(c.val()); + *os_ << ") in "; + dump_val(c.a()); + *os_ << " -> "; + visit(*this, *c.result().ty()); + dump_region(c.body()); +} + +void dump_ir_pass::operator()(cooperative_matrix_extract_inst c) { + dump_val(c.result()); + *os_ << " = cooperative_matrix_extract "; + dump_val(c.mat()); + *os_ << "[" << c.index() << "] : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_insert_inst c) { + dump_val(c.result()); + *os_ << " = cooperative_matrix_insert "; + dump_val(c.val()); + *os_ << ", "; + dump_val(c.mat()); + *os_ << "[" << c.index() << "] : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_load_inst c) { + dump_val(c.result()); + *os_ << " = cooperative_matrix_load"; + if (c.t() != transpose::N) { + *os_ << "." << to_string(c.t()); + } + if (c.checked() != checked_flag::none) { + *os_ << "." 
<< to_string(c.checked()); + } + *os_ << " "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "] : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_mul_add_inst c) { + dump_val(c.result()); + *os_ << " = cooperative_matrix_mul_add "; + dump_val(c.a()); + *os_ << ", "; + dump_val(c.b()); + *os_ << ", "; + dump_val(c.c()); + *os_ << " : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_prefetch_inst c) { + *os_ << "cooperative_matrix_prefetch "; + *os_ << c.cache_level(); + *os_ << ", "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "], "; + *os_ << c.rows(); + *os_ << ", "; + *os_ << c.cols(); +} + +void dump_ir_pass::operator()(cooperative_matrix_reduce_inst c) { + dump_val(c.result()); + *os_ << " = "; + *os_ << to_string(c.get().type_id()) << "." << to_string(c.mode()) << " "; + dump_val(c.a()); + *os_ << " : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_scale_inst c) { + dump_val(c.result()); + *os_ << " = cooperative_matrix_scale "; + dump_val(c.a()); + *os_ << ", "; + dump_val(c.b()); + *os_ << " : "; + visit(*this, *c.result().ty()); +} + +void dump_ir_pass::operator()(cooperative_matrix_store_inst c) { + *os_ << "cooperative_matrix_store"; + if (c.t() != transpose::N) { + *os_ << "." << to_string(c.t()); + } + if (c.checked() != checked_flag::none) { + *os_ << "." << to_string(c.checked()); + } + if (c.flag() != store_flag::regular) { + *os_ << '.' 
<< to_string(c.flag()); + } + *os_ << " "; + dump_val(c.val()); + *os_ << ", "; + dump_val(c.operand()); + *os_ << "["; + dump_val(c.pos0()); + *os_ << ","; + dump_val(c.pos1()); + *os_ << "]"; +} + +void dump_ir_pass::operator()(cumsum_inst in) { + *os_ << "cumsum"; + if (in.atomic()) { + *os_ << ".atomic"; + } + *os_ << ' '; + dump_val(in.alpha()); + *os_ << ", "; + dump_val(in.A()); + *os_ << ", " << in.mode() << ", "; + dump_val(in.beta()); + *os_ << ", "; + dump_val(in.B()); +} + +void dump_ir_pass::operator()(expand_inst e) { + dump_val(e.result()); + *os_ << " = expand "; + dump_val(e.operand()); + *os_ << "[" << e.expanded_mode() << "->"; + auto const &ses = e.static_expand_shape(); + auto es = e.expand_shape(); + for (std::size_t i = 0, j = 0; i < ses.size(); ++i) { + if (i != 0) { + *os_ << " x "; + } + if (is_dynamic_value(ses[i])) { + dump_val(es[j++]); + } else { + *os_ << ses[i]; + } + } + *os_ << "] : "; + visit(*this, *e.result().ty()); +} + +void dump_ir_pass::operator()(fuse_inst f) { + dump_val(f.result()); + *os_ << " = fuse "; + dump_val(f.operand()); + *os_ << "[" << f.from() << "," << f.to() << "]"; + *os_ << " : "; + visit(*this, *f.result().ty()); +} + +void dump_ir_pass::operator()(load_inst e) { + dump_val(e.result()); + *os_ << " = load "; + dump_val(e.operand()); + *os_ << "["; + do_with_infix(e.index_list().begin(), e.index_list().end(), + [this](auto const &i) { dump_val(i); }); + *os_ << "] : "; + visit(*this, *e.result().ty()); +} + +void dump_ir_pass::operator()(lifetime_stop_inst l) { + *os_ << "lifetime_stop "; + dump_val(l.object()); +} + +void dump_ir_pass::operator()(gemm_inst g) { + *os_ << "gemm"; + if (g.atomic()) { + *os_ << ".atomic"; + } + if (g.tA() != transpose::N || g.tB() != transpose::N) { + *os_ << "." << to_string(g.tA()); + } + if (g.tB() != transpose::N) { + *os_ << "." 
<< to_string(g.tB()); + } + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(gemv_inst g) { + *os_ << "gemv"; + if (g.atomic()) { + *os_ << ".atomic"; + } + if (g.tA() != transpose::N) { + *os_ << "." << to_string(g.tA()); + } + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(ger_inst g) { + *os_ << "ger"; + if (g.atomic()) { + *os_ << ".atomic"; + } + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(for_inst in) { + auto results = in.results(); + if (results.size() > 0) { + do_with_infix(results.begin(), results.end(), [this](auto const &i) { dump_val(i); }); + *os_ << " = "; + } + *os_ << "for "; + dump_val(in.loop_var()); + *os_ << "="; + dump_val(in.from()); + *os_ << ","; + dump_val(in.to()); + if (in.has_step()) { + *os_ << ","; + dump_val(in.step()); + } + if (results.size() > 0) { + auto iter_init = in.iter_init(); + *os_ << " init("; + for (std::int64_t i = 0; i < results.size(); ++i) { + if (i != 0) { + *os_ << ","; + } + dump_val(in.iter_arg(i)); + *os_ << "="; + dump_val(iter_init[i]); + } + *os_ << ") -> ("; + do_with_infix(results.begin(), results.end(), + [this](auto const &i) { visit(*this, *i.ty()); }); + *os_ << ")"; + } + *os_ << " "; + dump_region(in.body()); + if (in.get().attr()) { + *os_ << " "; + visit(*this, *in.get().attr()); + } +} + +void dump_ir_pass::operator()(foreach_inst in) { + *os_ << "foreach ("; + do_with_infix(in.loop_vars().begin(), in.loop_vars().end(), + [this](auto const &i) { dump_val(i); }); + *os_ << ")=("; + do_with_infix(in.from().begin(), in.from().end(), [this](auto const &i) { dump_val(i); }); + *os_ << "),("; + do_with_infix(in.to().begin(), in.to().end(), [this](auto const &i) { dump_val(i); }); + *os_ << ") "; + dump_region(in.body()); +} + +void dump_ir_pass::operator()(hadamard_inst g) { + *os_ << "hadamard"; + if (g.atomic()) { + *os_ << ".atomic"; + } + dump_blas_a3(static_cast(g)); +} + +void dump_ir_pass::operator()(if_inst in) { + auto results = 
in.results(); + if (results.size() > 0) { + do_with_infix(results.begin(), results.end(), [this](auto const &i) { dump_val(i); }); + *os_ << " = "; + } + *os_ << "if "; + dump_val(in.condition()); + *os_ << " "; + if (results.size() > 0) { + *os_ << "-> ("; + do_with_infix(results.begin(), results.end(), + [this](auto const &i) { visit(*this, *i.ty()); }); + *os_ << ") "; + } + dump_region(in.then()); + if (!in.is_otherwise_empty()) { + *os_ << " else "; + dump_region(in.otherwise()); + } +} + +void dump_ir_pass::operator()(math_unary_inst in) { + dump_val(in.result()); + *os_ << " = " << to_string(in.get().type_id()) << " "; + dump_val(in.a()); + *os_ << " : "; + visit(*this, *in.result().ty()); +} + +void dump_ir_pass::operator()(parallel_inst p) { + *os_ << "parallel "; + dump_region(p.body()); +} + +void dump_ir_pass::operator()(size_inst s) { + dump_val(s.result()); + *os_ << " = size "; + dump_val(s.operand()); + *os_ << "[" << s.mode() << "]"; + *os_ << " : "; + visit(*this, *s.result().ty()); +} + +void dump_ir_pass::operator()(subgroup_broadcast_inst in) { + dump_val(in.result()); + *os_ << " = subgroup_broadcast "; + dump_val(in.a()); + *os_ << ", "; + dump_val(in.idx()); + *os_ << " : "; + visit(*this, *in.result().ty()); +} + +void dump_ir_pass::operator()(subgroup_operation_inst in) { + dump_val(in.result()); + *os_ << " = " << to_string(in.get().type_id()) << " "; + dump_val(in.a()); + *os_ << " : "; + visit(*this, *in.result().ty()); +} + +void dump_ir_pass::operator()(subview_inst s) { + dump_val(s.result()); + *os_ << " = subview "; + dump_val(s.operand()); + *os_ << "["; + auto dyn_offsets = s.offsets(); + auto dyn_sizes = s.sizes(); + for (std::size_t i = 0, joffset = 0, jsize = 0; i < s.static_offsets().size(); ++i) { + if (i != 0) { + *os_ << ","; + } + auto offset = s.static_offsets()[i]; + if (is_dynamic_value(offset)) { + dump_val(dyn_offsets[joffset++]); + } else { + *os_ << offset; + } + auto size = s.static_sizes()[i]; + if (size > 0 || 
is_dynamic_value(size)) { + *os_ << ":"; + if (is_dynamic_value(size)) { + dump_val(dyn_sizes[jsize++]); + } else { + *os_ << size; + } + } + } + *os_ << "] : "; + visit(*this, *s.result().ty()); +} + +void dump_ir_pass::operator()(store_inst e) { + *os_ << "store"; + if (e.flag() != store_flag::regular) { + *os_ << '.' << to_string(e.flag()); + } + *os_ << ' '; + dump_val(e.val()); + *os_ << ", "; + dump_val(e.operand()); + *os_ << "["; + do_with_infix(e.index_list().begin(), e.index_list().end(), + [this](auto const &i) { dump_val(i); }); + *os_ << "]"; +} + +void dump_ir_pass::operator()(sum_inst a) { + *os_ << "sum"; + if (a.atomic()) { + *os_ << ".atomic"; + } + if (a.tA() != transpose::N) { + *os_ << "." << to_string(a.tA()); + } + dump_blas_a2(static_cast(a)); +} + +void dump_ir_pass::operator()(yield_inst y) { + *os_ << "yield ("; + auto vals = y.yielded_vals(); + if (vals.size() > 0) { + do_with_infix(vals.begin(), vals.end(), [this](auto const &i) { dump_val(i); }, ", "); + } + *os_ << ")"; +} + +void dump_ir_pass::operator()(group_id_inst in) { + dump_val(in.result()); + *os_ << " = group_id." << to_string(in.mode()) << " : "; + visit(*this, *in.result().ty()); +} +void dump_ir_pass::operator()(num_groups_inst in) { + dump_val(in.result()); + *os_ << " = num_groups." << to_string(in.mode()) << " : "; + visit(*this, *in.result().ty()); +} +void dump_ir_pass::operator()(num_subgroups_inst in) { + dump_val(in.result()); + *os_ << " = num_subgroups." << to_string(in.mode()) << " : "; + visit(*this, *in.result().ty()); +} +void dump_ir_pass::operator()(subgroup_size_inst in) { + dump_val(in.result()); + *os_ << " = subgroup_size : "; + visit(*this, *in.result().ty()); +} +void dump_ir_pass::operator()(subgroup_id_inst in) { + dump_val(in.result()); + *os_ << " = subgroup_id." 
<< to_string(in.mode()) << " : "; + visit(*this, *in.result().ty()); +} +void dump_ir_pass::operator()(subgroup_linear_id_inst in) { + dump_val(in.result()); + *os_ << " = subgroup_linear_id : "; + visit(*this, *in.result().ty()); +} +void dump_ir_pass::operator()(subgroup_local_id_inst in) { + dump_val(in.result()); + *os_ << " = subgroup_local_id : "; + visit(*this, *in.result().ty()); +} + +void dump_ir_pass::dump_region(tinytc_region ®) { + if (lvl_ < lvl_limit_) { + *os_ << "{" << std::endl; + ++lvl_; + auto ind = indent(); + for (auto &i : reg) { + *os_ << ind; + visit(*this, i); + *os_ << std::endl; + } + --lvl_; + *os_ << indent() << "}"; + } else { + *os_ << "{...}"; + } +} + +void dump_ir_pass::run_on_function(tinytc_func &fn) { + init_slot_tracker(fn); + + *os_ << "func @" << fn.name() << "("; + std::string infix = ",\n "; + infix += std::string(fn.name().size(), ' '); + do_with_infix_enumerated( + fn.params().begin(), fn.params().end(), + [this, &fn](std::int32_t arg_no, auto const &a) { + dump_val(a); + *os_ << ": "; + visit(*this, *a.ty()); + if (auto pa = fn.param_attr(arg_no); pa) { + *os_ << " "; + visit(*this, *pa); + } + }, + infix); + *os_ << ")"; + if (fn.attr()) { + *os_ << " attributes"; + visit(*this, *fn.attr()); + } + *os_ << " "; + dump_region(fn.body()); + *os_ << std::endl; +} + +void dump_ir_pass::run_on_region(tinytc_region ®) { dump_region(reg); } +void dump_ir_pass::run_on_instruction(tinytc_inst &in) { visit(*this, in); } + +void dump_ir_pass::init_slot_tracker(tinytc_func &fn) { + tracker_ = slot_tracker{}; + tracker_.run_on_function(fn); +} + +} // namespace tinytc diff --git a/src/pass/dump_ir.hpp b/src/pass/dump_ir.hpp new file mode 100644 index 00000000..1e30f514 --- /dev/null +++ b/src/pass/dump_ir.hpp @@ -0,0 +1,128 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_IR_20230330_HPP +#define DUMP_IR_20230330_HPP + +#include "node/attr.hpp" // IWYU pragma: keep +#include 
"node/inst_view.hpp" +#include "node/type.hpp" +#include "pass/slot_tracker.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include + +namespace tinytc { + +class dump_ir_pass { + public: + dump_ir_pass(std::ostream &os, int level_limit = std::numeric_limits::max()); + + /* Attribute nodes */ + void operator()(array_attr const &a); + void operator()(boolean_attr const &a); + void operator()(dictionary_attr const &a); + void operator()(integer_attr const &a); + void operator()(string_attr const &a); + + /* Data type nodes */ + void operator()(void_type const &); + void operator()(boolean_type const &); + void operator()(coopmatrix_type const &ct); + void operator()(group_type const &g); + void operator()(memref_type const &m); + void operator()(number_type const &s); + + /* Inst nodes */ + void operator()(alloca_inst a); + void operator()(axpby_inst a); + void operator()(arith_inst a); + void operator()(arith_unary_inst a); + void operator()(barrier_inst b); + void operator()(cast_inst c); + void operator()(compare_inst c); + void operator()(constant_inst c); + void operator()(cooperative_matrix_apply_inst c); + void operator()(cooperative_matrix_extract_inst c); + void operator()(cooperative_matrix_insert_inst c); + void operator()(cooperative_matrix_load_inst c); + void operator()(cooperative_matrix_mul_add_inst c); + void operator()(cooperative_matrix_prefetch_inst c); + void operator()(cooperative_matrix_reduce_inst c); + void operator()(cooperative_matrix_scale_inst c); + void operator()(cooperative_matrix_store_inst c); + void operator()(cumsum_inst a); + void operator()(expand_inst e); + void operator()(fuse_inst f); + void operator()(load_inst e); + void operator()(lifetime_stop_inst l); + void operator()(gemm_inst g); + void operator()(gemv_inst g); + void operator()(ger_inst g); + void operator()(for_inst p); + void operator()(foreach_inst p); + void operator()(hadamard_inst g); + void operator()(if_inst in); + void 
operator()(math_unary_inst in); + void operator()(parallel_inst p); + void operator()(size_inst s); + void operator()(subgroup_broadcast_inst in); + void operator()(subgroup_operation_inst in); + void operator()(subview_inst s); + void operator()(store_inst s); + void operator()(sum_inst s); + void operator()(yield_inst y); + void operator()(group_id_inst in); + void operator()(num_groups_inst in); + void operator()(num_subgroups_inst in); + void operator()(subgroup_size_inst in); + void operator()(subgroup_id_inst in); + void operator()(subgroup_linear_id_inst in); + void operator()(subgroup_local_id_inst in); + + void run_on_function(tinytc_func &fn); + void run_on_region(tinytc_region ®); + void run_on_instruction(tinytc_inst &in); + + void dump_val(tinytc_value const &v); + void init_slot_tracker(tinytc_func &fn); + + private: + void dump_region(tinytc_region ®); + void dump_blas_a2(blas_a2_inst g); + void dump_blas_a3(blas_a3_inst g); + + template + void do_with_infix(Iterator begin, Iterator end, Action a, std::string const &infix = ",") { + for (auto it = begin; it != end; ++it) { + if (it != begin) { + *os_ << infix; + } + a(*it); + } + } + template + void do_with_infix_enumerated(Iterator begin, Iterator end, Action a, + std::string const &infix = ",") { + for (auto it = begin; it != end; ++it) { + if (it != begin) { + *os_ << infix; + } + a(it - begin, *it); + } + } + + inline auto indent() { return std::string(2 * lvl_, ' '); } + std::ostream *os_; + int lvl_limit_; + int lvl_ = 0; + + slot_tracker tracker_; +}; + +} // namespace tinytc + +#endif // DUMP_IR_20230330_HPP diff --git a/src/pass/insert_barrier.cpp b/src/pass/insert_barrier.cpp new file mode 100644 index 00000000..546be3b9 --- /dev/null +++ b/src/pass/insert_barrier.cpp @@ -0,0 +1,228 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/insert_barrier.hpp" +#include "analysis/aa_results.hpp" +#include "analysis/alias.hpp" +#include 
"analysis/cfg.hpp" +#include "error.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc { + +auto intersects(std::unordered_set<::tinytc_value const *> const &a, + std::unordered_set<::tinytc_value const *> const &b, aa_results const &aa) { + for (auto &av : a) { + for (auto &bv : b) { + if (aa.alias(*av, *bv)) { + return true; + } + } + } + return false; +} + +void insert_barrier_pass::reads_writes::clear() { + for (auto &rd : reads) { + rd.clear(); + } + for (auto &wr : writes) { + wr.clear(); + } +} + +void insert_barrier_pass::reads_writes::clear(address_space as) { + const auto space = address_space_to_index(as); + reads[space].clear(); + writes[space].clear(); +} + +void insert_barrier_pass::reads_writes::merge(reads_writes const &other) { + for (std::size_t i = 0; i < reads.size(); ++i) { + reads[i].insert(other.reads[i].begin(), other.reads[i].end()); + } + for (std::size_t i = 0; i < writes.size(); ++i) { + writes[i].insert(other.writes[i].begin(), other.writes[i].end()); + } +} + +void insert_barrier_pass::reads_writes::merge(reads_writes &&other) { + for (std::size_t i = 0; i < reads.size(); ++i) { + reads[i].merge(std::move(other.reads[i])); + } + for (std::size_t i = 0; i < writes.size(); ++i) { + writes[i].merge(std::move(other.writes[i])); + } +} + +void insert_barrier_pass::reads_writes::merge(address_space as, reads_writes const &other) { + const auto space = address_space_to_index(as); + reads[space].insert(other.reads[space].begin(), other.reads[space].end()); + writes[space].insert(other.writes[space].begin(), other.writes[space].end()); +} + +void 
insert_barrier_pass::reads_writes::emplace_read(address_space as, ::tinytc_value const *val) { + const auto space = address_space_to_index(as); + reads[space].emplace(val); +} +void insert_barrier_pass::reads_writes::emplace_write(address_space as, ::tinytc_value const *val) { + const auto space = address_space_to_index(as); + writes[space].emplace(val); +} +auto insert_barrier_pass::reads_writes::read_cardinal(address_space as) const -> std::size_t { + const auto space = address_space_to_index(as); + return reads[space].size(); +} +auto insert_barrier_pass::reads_writes::write_cardinal(address_space as) const -> std::size_t { + const auto space = address_space_to_index(as); + return writes[space].size(); +} + +bool insert_barrier_pass::reads_writes::raw(address_space as, reads_writes const &rw, + aa_results const &aa) const { + const auto space = address_space_to_index(as); + return intersects(rw.reads[space], writes[space], aa); +} +bool insert_barrier_pass::reads_writes::war(address_space as, reads_writes const &rw, + aa_results const &aa) const { + const auto space = address_space_to_index(as); + return intersects(rw.writes[space], reads[space], aa); +} +bool insert_barrier_pass::reads_writes::waw(address_space as, reads_writes const &rw, + aa_results const &aa) const { + const auto space = address_space_to_index(as); + return intersects(rw.writes[space], writes[space], aa); +} +bool insert_barrier_pass::reads_writes::raw_war_or_waw(address_space as, reads_writes const &rw, + aa_results const &aa) const { + return raw(as, rw, aa) || war(as, rw, aa) || waw(as, rw, aa); +} + +auto insert_barrier_pass::reads_writes::address_space_to_index(address_space as) -> std::size_t { + for (std::size_t i = 0; i < address_spaces.size(); ++i) { + if (as == address_spaces[i]) { + return i; + } + } + throw internal_compiler_error{}; +} + +void insert_barrier_pass::run_on_region(tinytc_region ®, aa_results const &aa) { + // irw = reads and writes invisible to other threads + auto 
irw_in = std::unordered_map{}; + auto irw_out = std::unordered_map{}; + + auto const get_rw = [](tinytc_inst &in) -> reads_writes { + auto rw = reads_writes{}; + auto const emplace_read = [&rw](tinytc_value const &v) { + if (auto *m = dyn_cast(v.ty()); m) { + rw.emplace_read(m->addrspace(), &v); + } + }; + auto const emplace_write = [&rw](tinytc_value const &v) { + if (auto *m = dyn_cast(v.ty()); m) { + rw.emplace_write(m->addrspace(), &v); + } + }; + + visit(overloaded{[&](blas_a2_inst in) { + emplace_read(in.A()); + emplace_write(in.B()); + }, + [&](blas_a3_inst in) { + emplace_read(in.A()); + emplace_read(in.B()); + emplace_write(in.C()); + }, + [&](load_inst in) { emplace_read(in.operand()); }, + [&](store_inst in) { emplace_write(in.operand()); }, [](inst_view) {}}, + in); + return rw; + }; + + auto const get_cardinal = [](reads_writes const &rw) { + return std::array{ + rw.read_cardinal(address_space::global), rw.read_cardinal(address_space::local), + rw.write_cardinal(address_space::global), rw.write_cardinal(address_space::local)}; + }; + + auto cfg = get_control_flow_graph(reg); + auto q = cfg.node_queue(); + while (!q.empty()) { + auto n = q.front(); + q.pop(); + + const bool insert_barriers = cfg.kind_max(n) < region_kind::spmd; + + auto &in = irw_in[n]; + auto &out = irw_out[n]; + + in.clear(); + for (auto &p : cfg.predecessors(n)) { + in.merge(irw_out[p]); + } + + auto out_size_before_update = get_cardinal(out); + + if (auto barrier = dyn_cast(n); insert_barriers && barrier) { + for (auto &as : reads_writes::address_spaces) { + if (!barrier.has_fence(as)) { + out.merge(as, in); + } + } + } else { + out = get_rw(*n); + + std::int32_t fence_flags = 0; + for (auto &as : reads_writes::address_spaces) { + if (insert_barriers && in.raw_war_or_waw(as, out, aa)) { + fence_flags |= static_cast(as); + } else { + out.merge(as, in); + } + } + if (fence_flags != 0) { + tinytc_region *subreg = n->parent(); + auto new_barrier = + subreg->insts() + 
.insert(n->iterator(), barrier_inst::create(fence_flags, {}).release()) + .get(); + // update cfg + cfg.insert_before(n, new_barrier); + q.push(new_barrier); + } + } + + // out has changed, need to enqueue successors again + if (out_size_before_update != get_cardinal(out)) { + for (auto &s : cfg.successors(n)) { + q.push(s); + } + } + } +} + +/* Function nodes */ +void insert_barrier_pass::run_on_function(tinytc_func &fn) { + auto aa = alias_analysis{}.run_on_function(fn); + run_on_region(fn.body(), aa); +} + +} // namespace tinytc diff --git a/src/pass/insert_barrier.hpp b/src/pass/insert_barrier.hpp new file mode 100644 index 00000000..c56dab46 --- /dev/null +++ b/src/pass/insert_barrier.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INSERT_BARRIER_20230310_HPP +#define INSERT_BARRIER_20230310_HPP + +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include + +namespace tinytc { + +class aa_results; + +class insert_barrier_pass { + public: + void run_on_function(tinytc_func &fn); + + private: + class reads_writes { + public: + constexpr static std::array address_spaces = {address_space::global, + address_space::local}; + + void clear(); + void clear(address_space as); + void merge(reads_writes const &other); + void merge(reads_writes &&other); + void merge(address_space as, reads_writes const &other); + void emplace_read(address_space as, ::tinytc_value const *val); + void emplace_write(address_space as, ::tinytc_value const *val); + auto read_cardinal(address_space as) const -> std::size_t; + auto write_cardinal(address_space as) const -> std::size_t; + + bool raw(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool war(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool waw(address_space as, reads_writes const &rw, aa_results const &aa) const; + bool raw_war_or_waw(address_space as, reads_writes const &rw, 
aa_results const &aa) const; + + private: + static auto address_space_to_index(address_space as) -> std::size_t; + + std::array, address_spaces.size()> reads, writes; + }; + + void run_on_region(tinytc_region ®, aa_results const &aa); +}; + +} // namespace tinytc + +#endif // INSERT_BARRIER_20230310_HPP diff --git a/src/pass/insert_lifetime_stop.cpp b/src/pass/insert_lifetime_stop.cpp new file mode 100644 index 00000000..33c5e20e --- /dev/null +++ b/src/pass/insert_lifetime_stop.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/insert_lifetime_stop.hpp" +#include "analysis/aa_results.hpp" +#include "analysis/alias.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" + +#include +#include + +namespace tinytc { + +auto insert_lifetime_stop_pass::run_on_region(tinytc_region ®, aa_results const &aa) + -> std::unordered_set { + if (reg.empty()) { + return {}; + } + + auto allocas = std::vector{}; + for (auto &i : reg) { + if (auto alloca = dyn_cast(&i); alloca) { + allocas.emplace_back(&alloca.result()); + } + } + + auto rgn_ops = std::unordered_set{}; + auto prev_it = reg.end(); + while (prev_it != reg.begin()) { + auto &i = *(--prev_it); + for (auto &subreg : i.child_regions()) { + rgn_ops.merge(run_on_region(subreg, aa)); + } + for (auto &v : i.operands()) { + if (isa(*v.ty())) { + rgn_ops.insert(aa.root(v)); + } + } + for (auto &v : i.results()) { + if (isa(*v.ty())) { + rgn_ops.insert(aa.root(v)); + } + } + + auto alloca_it = allocas.begin(); + while (alloca_it != allocas.end()) { + if (rgn_ops.contains(*alloca_it)) { + prev_it = reg.insts().insert_after( + prev_it, lifetime_stop_inst::create(*alloca_it, 
{}).release()); + --prev_it; + alloca_it = allocas.erase(alloca_it); + } else { + ++alloca_it; + } + } + } + return rgn_ops; +} + +void insert_lifetime_stop_pass::run_on_function(tinytc_func &fn) { + auto aa = alias_analysis{}.run_on_function(fn); + run_on_region(fn.body(), aa); +} + +} // namespace tinytc diff --git a/src/pass/insert_lifetime_stop.hpp b/src/pass/insert_lifetime_stop.hpp new file mode 100644 index 00000000..51c88682 --- /dev/null +++ b/src/pass/insert_lifetime_stop.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INSERT_LIFETIME_STOP_20240912_HPP +#define INSERT_LIFETIME_STOP_20240912_HPP + +#include "node/value.hpp" +#include "tinytc/types.h" + +#include + +namespace tinytc { + +class aa_results; + +class insert_lifetime_stop_pass { + public: + void run_on_function(tinytc_func &fn); + + private: + auto run_on_region(tinytc_region ®, aa_results const &aa) + -> std::unordered_set<::tinytc_value const *>; +}; + +} // namespace tinytc + +#endif // INSERT_LIFETIME_STOP_20240912_HPP diff --git a/src/pass/lower_coopmatrix.cpp b/src/pass/lower_coopmatrix.cpp new file mode 100644 index 00000000..50bb707d --- /dev/null +++ b/src/pass/lower_coopmatrix.cpp @@ -0,0 +1,195 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_coopmatrix.hpp" +#include "codegen_tools.hpp" +#include "coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "pass/clone.hpp" +#include "tinytc/builder.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc { + +class 
coopmatrix_code_generator { + public: + coopmatrix_code_generator(core_config core_cfg, tinytc_region ®) + : core_cfg_{std::move(core_cfg)}, bb_{®} {} + // Returns true if instruction was replaced + bool operator()(inst_view in); + bool operator()(cooperative_matrix_apply_inst in); + + void run_on_region(tinytc_region ®); + + private: + auto needs_coopmatrix_vector_impl(tinytc_inst &in); + + core_config core_cfg_; + region_builder bb_; +}; + +bool coopmatrix_code_generator::operator()(inst_view) { return false; } + +bool coopmatrix_code_generator::operator()(cooperative_matrix_apply_inst in) { + if (in.body().empty()) { + throw compilation_error(in.loc(), status::ir_must_have_yield); + } + + auto bool_ty = boolean_type::get(in.get().context()); + auto i32_ty = i32_type::get(in.get().context()); + + auto cloner = inst_cloner{}; + + auto ct = get_coopmatrix_type(in.a()); + auto cl = get_layout(core_cfg_, ct); + + auto p = bb_.create(i32_ty, in.loc()); + auto i = p; + auto j0 = tinytc_value_t{nullptr}; + if (cl.rows < core_cfg_.subgroup_size) { + auto cI = bb_.create(cl.rows, i32_ty, in.loc()); + i = bb_.create(p, cI, i32_ty, in.loc()); + j0 = bb_.create(p, cI, i32_ty, in.loc()); + } + const auto col_inc_factor = core_cfg_.subgroup_size / cl.rows; + + auto copy = &in.a(); + for (std::int64_t v = 0; v < cl.length; ++v) { + const auto k1 = v % cl.blocks1; + const auto u = v / cl.blocks1 % cl.cols; + const auto k2 = v / (cl.blocks1 * cl.cols); + + auto row = i; + const auto block_offset = k1 * cl.rows + k2 * cl.rows * cl.blocks1; + if (block_offset) { + auto cblock_offset = bb_.create(block_offset, i32_ty, in.loc()); + row = bb_.create(i, cblock_offset, i32_ty, in.loc()); + } + auto j1 = bb_.create(u * col_inc_factor, i32_ty, in.loc()); + auto col = j0 ? 
bb_.create(j0, j1, i32_ty, in.loc()) : j1; + auto val = + bb_.create(v, &in.a(), ct->component_ty(), in.loc()); + + cloner.set_subs(&in.row(), row); + cloner.set_subs(&in.col(), col); + cloner.set_subs(&in.val(), val); + + auto modified_val = tinytc_value_t{}; + if ((u + 1) * col_inc_factor > cl.shape1) { + auto cshape1 = bb_.create(cl.shape1, i32_ty, in.loc()); + auto cond = bb_.create(col, cshape1, bool_ty, in.loc()); + modified_val = bb_.ifelse( + cond, + [&](region_builder &bb) { + cloner.clone_region(in.body(), *bb.get_region()); + }, + [&](region_builder &bb) { + auto c0 = bb.constant_zero(ct->component_ty(), in.loc()); + bb.create(array_view{c0}); + }, + {ct->component_ty()}, in.loc()) + .front(); + } else { + cloner.clone_region(in.body(), *bb_.get_region()); + + auto last_inst = --bb_.get_region()->end(); + if (last_inst != bb_.get_region()->end() && isa(*last_inst)) { + auto vals = dyn_cast(last_inst.get()).yielded_vals(); + if (vals.size() != 1) { + throw compilation_error(in.loc(), status::ir_yield_mismatch); + } + modified_val = &vals[0]; + bb_.get_region()->insts().erase(last_inst); + } else { + throw compilation_error(in.loc(), status::ir_must_have_yield); + } + } + copy = bb_.create(v, modified_val, copy, in.result().ty(), + in.loc()); + } + for (auto &r : in.get().results()) { + auto u = r.use_begin(); + while (r.has_uses()) { + u->set(copy); + u = r.use_begin(); + } + } + return true; +} + +void coopmatrix_code_generator::run_on_region(tinytc_region ®) { + // Move all instructions to a temporary ilist. + // We later move the instructions back, except those that are lowered remain in old_ilist + // and are cleaned up at the end of the function. 
+ auto old_ilist = std::move(reg.insts()); + + auto old_reg = bb_.get_region(); + bb_ = region_builder{®}; + + auto it = old_ilist.begin(); + while (it != old_ilist.end()) { + bool replaced = visit(*this, *it); + if (!replaced) { + auto instr = it.get(); + it = old_ilist.unlink(it); + reg.insts().push_back(instr); + for (auto &subreg : instr->child_regions()) { + run_on_region(subreg); + } + } else { + ++it; + } + } + it = old_ilist.end(); + while (it != old_ilist.begin()) { + --it; + for (auto &result : it->results()) { + if (result.has_uses()) { + throw compilation_error(result.loc(), status::ir_value_still_has_uses); + } + } + it = old_ilist.erase(it); + } + + bb_ = region_builder{old_reg}; +} + +lower_coopmatrix_pass::lower_coopmatrix_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void lower_coopmatrix_pass::run_on_function(tinytc_func &fn) { + auto const subgroup_size = fn.subgroup_size(); + core_config core_cfg = {}; + try { + core_cfg = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + + auto gen = coopmatrix_code_generator{core_cfg, fn.body()}; + gen.run_on_region(fn.body()); +} + +} // namespace tinytc diff --git a/src/pass/lower_coopmatrix.hpp b/src/pass/lower_coopmatrix.hpp new file mode 100644 index 00000000..c1190912 --- /dev/null +++ b/src/pass/lower_coopmatrix.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_COOPMATRIX_20241206_HPP +#define LOWER_COOPMATRIX_20241206_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class lower_coopmatrix_pass { + public: + lower_coopmatrix_pass(::tinytc_core_info const *info); + + void run_on_function(::tinytc_func &fn); + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // 
LOWER_COOPMATRIX_20241206_HPP diff --git a/src/pass/lower_foreach.cpp b/src/pass/lower_foreach.cpp new file mode 100644 index 00000000..cfffd3c3 --- /dev/null +++ b/src/pass/lower_foreach.cpp @@ -0,0 +1,171 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_foreach.hpp" +#include "codegen_tools.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "pass/clone.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" +#include "tinytc/builder.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +template +void make_loop0(region_builder &bb, tinytc_value_t from, tinytc_value_t to, tinytc_value_t sg_id, + int sgs, int num_tiles, F &&make_body, location const &loc) { + auto ity = from->ty(); + auto ctx = sg_id->context(); + auto bool_ty = get(ctx); + auto i32_ty = get(ctx); + auto sg_lid_i32 = bb.create(i32_ty, loc); + auto sg_lid = bb.create(sg_lid_i32, ity, loc); + auto size = instant_constant_fold_add(bb, create(to, from, ity, loc)); + auto work_item_offset = bb.create(from, sg_lid, ity, loc); + tile_loop_by_sgs( + bb, size, sgs, num_tiles, sg_id, + [&](region_builder &bb, tinytc_value_t block, bool is_remainder, + tinytc_value_t trip_count) { + auto loop_var0 = bb.create(block, work_item_offset, ity, loc); + if (is_remainder) { + auto cond = bb.create(sg_lid, trip_count, bool_ty, loc); + bb.if_condition(cond, [&](region_builder &bb) { make_body(bb, loop_var0); }, loc); + } else { + make_body(bb, loop_var0); + } + }); +} + +class foreach_generator { + public: + foreach_generator(local_tiling tiling, 
core_config core_cfg) + : tiling_{std::move(tiling)}, core_cfg_{std::move(core_cfg)} {} + auto operator()(inst_view) -> unique_handle { return {}; } + auto operator()(foreach_inst in) -> unique_handle; + + private: + local_tiling tiling_ = {}; + core_config core_cfg_ = {}; +}; + +auto foreach_generator::operator()(foreach_inst in) -> unique_handle { + const int block_size0 = core_cfg_.subgroup_size; + + auto parallel = create(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto i32_ty = i32_type::get(in.get().context()); + + auto cloner = inst_cloner{}; + auto loop_vars = in.loop_vars().begin(); + auto from = in.from().begin(); + auto to = in.to().begin(); + auto ity = (*from).ty(); + + if (in.dim() > 1) { + auto const make_inner_loop_nest = [&](region_builder &bb, tinytc_value_t from1, + tinytc_value_t to1) { + tinytc_region_t current_region = bb.get_region(); + for (std::int64_t i = in.dim() - 1; i > 1; --i) { + auto for_i = + create(&from[i], &to[i], nullptr, array_view{}, + array_view{}, in.loc()); + auto for_i_view = for_inst(for_i.get()); + cloner.set_subs(&loop_vars[i], &for_i_view.loop_var()); + tinytc_region_t next_region = &for_i_view.body(); + current_region->insts().push_back(for_i.release()); + current_region = next_region; + } + region_builder{current_region}.for_loop( + from1, to1, + [&](region_builder &bb, tinytc_value_t loop_var1) { + cloner.set_subs(&loop_vars[1], loop_var1); + cloner.clone_region(in.body(), *bb.get_region()); + }, + nullptr, in.loc()); + }; + + auto sg_id0 = bb.create(comp3::x, i32_ty, in.loc()); + auto sg_id1 = bb.create(comp3::y, i32_ty, in.loc()); + + auto size1 = bb.create(&to[1], &from[1], ity, in.loc()); + tile_loop_uniformly( + bb, size1, core_cfg_.subgroup_size, tiling_.n_tiles(), sg_id1, + [&](region_builder &bb, tinytc_value_t block, tinytc_value_t trip_count1) { + auto from1 = bb.create(&from[1], block, ity, in.loc()); + auto to1 = bb.create(from1, trip_count1, ity, 
in.loc()); + make_loop0( + bb, &from[0], &to[0], sg_id0, block_size0, tiling_.m_tiles(), + [&](region_builder &bb, tinytc_value_t loop_var0) { + cloner.set_subs(&loop_vars[0], loop_var0); + make_inner_loop_nest(bb, from1, to1); + }, + in.loc()); + }); + } else if (in.dim() == 1) { + auto sg_id = bb.create(i32_ty, in.loc()); + make_loop0( + bb, &from[0], &to[0], sg_id, block_size0, tiling_.m_tiles() * tiling_.n_tiles(), + [&](region_builder &bb, tinytc_value_t loop_var0) { + cloner.set_subs(&loop_vars[0], loop_var0); + cloner.clone_region(in.body(), *bb.get_region()); + }, + in.loc()); + } + + return parallel; +} + +lower_foreach_pass::lower_foreach_pass(::tinytc_core_info const *info) : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void lower_foreach_pass::run_on_function(tinytc_func &fn) { + auto const subgroup_size = fn.subgroup_size(); + core_config core_cfg = {}; + try { + core_cfg = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + auto const work_group_size = fn.work_group_size(); + local_tiling tiling = {}; + tiling[0] = work_group_size[0] / subgroup_size; + tiling[1] = work_group_size[1]; + + walk(fn, [&](tinytc_region ®) { + for (auto it = reg.begin(); it != reg.end(); ++it) { + auto lowered_inst = visit(foreach_generator{tiling, core_cfg}, *it); + if (lowered_inst) { + it = reg.insts().erase(it); + it = reg.insts().insert(it, lowered_inst.release()); + } + } + }); +} + +} // namespace tinytc diff --git a/src/pass/lower_foreach.hpp b/src/pass/lower_foreach.hpp new file mode 100644 index 00000000..0486dd03 --- /dev/null +++ b/src/pass/lower_foreach.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_FOREACH_20241118_HPP +#define LOWER_FOREACH_20241118_HPP + +#include "tinytc/types.h" + +namespace tinytc { + 
+class lower_foreach_pass { + public: + lower_foreach_pass(::tinytc_core_info const *info); + + void run_on_function(::tinytc_func &fn); + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // LOWER_FOREACH_20241118_HPP diff --git a/src/pass/lower_linalg.cpp b/src/pass/lower_linalg.cpp new file mode 100644 index 00000000..07de2731 --- /dev/null +++ b/src/pass/lower_linalg.cpp @@ -0,0 +1,820 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/lower_linalg.hpp" +#include "codegen_tools.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "gemm_tools.hpp" +#include "matrix_ext_info.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "number.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" +#include "tinytc/builder.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +void gemm_microkernel(region_builder &bb, transpose tA, transpose tB, bool atomic, + tinytc_value_t alpha, tinytc_value_t A, tinytc_value_t B, tinytc_value_t beta, + tinytc_value_t C, tinytc_value_t K, tinytc_value_t m_block, + std::int32_t m_block_size, std::int32_t num_m_blocks, bool m_check, + tinytc_value_t n_block, std::int32_t n_block_size, std::int32_t num_n_blocks, + bool n_check, array_view K_block_sizes, tinytc_type_t a_ty, + tinytc_type_t b_ty, tinytc_type_t c_ty, tinytc_attr_t for_attributes, + location const &loc) { + auto ctx = m_block->context(); + auto bool_ty = boolean_type::get(ctx); + auto index_ty = index_type::get(ctx); + + const auto check_a = m_check ? checked_flag::rows : checked_flag::none; + const auto check_b = n_check ? 
checked_flag::cols : checked_flag::none; + const auto check_c = [&] { + if (m_check && n_check) { + return checked_flag::both; + } else if (m_check) { + return checked_flag::rows; + } else if (n_check) { + return checked_flag::cols; + } + return checked_flag::none; + }(); + auto c_m_block_size = bb.create(m_block_size, index_ty, loc); + auto c_n_block_size = bb.create(n_block_size, index_ty, loc); + + const auto c_acc_ty = [&c_ty, &loc]() { + if (!isa(*c_ty)) { + throw compilation_error(loc, status::ir_expected_number); + } + return acc_type(c_ty); + }(); + + auto coopmatrix_c_ty = get(c_ty, m_block_size, n_block_size, matrix_use::acc); + auto coopmatrix_c_acc_ty = + get(c_acc_ty, m_block_size, n_block_size, matrix_use::acc); + auto const compute_c_step = [&](region_builder &bb, std::int32_t k_block_size, tinytc_value_t k, + array_view const &c_acc, + array_view const &c_acc_tys, + bool check_k = false) { + tinytc_value_t pos_a[2] = {m_block, k}; + int amode = 0; + if (tA == transpose::T) { + std::swap(pos_a[0], pos_a[1]); + amode = 1 - amode; + } + auto coopmatrix_a_ty = + get(a_ty, m_block_size, k_block_size, matrix_use::a); + const auto my_check_a = check_k ? add_check(check_a, checked_flag::cols) : check_a; + auto a = std::vector{}; + a.reserve(num_m_blocks); + for (std::int32_t i = 0; i < num_m_blocks; ++i) { + a.emplace_back(bb.create(tA, my_check_a, A, pos_a[0], + pos_a[1], coopmatrix_a_ty)); + if (i + 1 < num_m_blocks) { + pos_a[amode] = bb.create(pos_a[amode], c_m_block_size, index_ty, loc); + } + } + + tinytc_value_t pos_b[2] = {k, n_block}; + int bmode = 1; + if (tB == transpose::T) { + std::swap(pos_b[0], pos_b[1]); + bmode = 1 - bmode; + } + auto coopmatrix_b_ty = + get(b_ty, k_block_size, n_block_size, matrix_use::b); + const auto my_check_b = check_k ? 
add_check(check_b, checked_flag::rows) : check_b; + auto b = std::vector{}; + b.reserve(num_n_blocks); + for (std::int32_t i = 0; i < num_n_blocks; ++i) { + b.emplace_back(bb.create(tB, my_check_b, B, pos_b[0], + pos_b[1], coopmatrix_b_ty)); + if (i + 1 < num_n_blocks) { + pos_b[bmode] = bb.create(pos_b[bmode], c_n_block_size, index_ty, loc); + } + } + + auto c_next = std::vector{}; + c_next.reserve(num_m_blocks * num_n_blocks); + for (std::int32_t n = 0; n < num_n_blocks; ++n) { + for (std::int32_t m = 0; m < num_m_blocks; ++m) { + c_next.emplace_back(bb.create( + a[m], b[n], c_acc[m + n * num_m_blocks], c_acc_tys[m + n * num_m_blocks], loc)); + } + } + return c_next; + }; + auto const compute_c = [&](region_builder &bb, std::int32_t k_block_size, tinytc_value_t K0, + tinytc_value_t K1, std::vector const &c_acc, + std::vector const &c_acc_tys, + bool check_k = false) -> std::vector { + auto c_step = bb.create(k_block_size, index_ty, loc); + auto return_values = bb.for_loop( + K0, K1, c_step, c_acc, c_acc_tys, + [&](region_builder &bb, array_view p) { + const auto k = p[0]; + auto c_acc_iter = array_view(p.begin() + 1, p.end()); + auto c_next = compute_c_step(bb, k_block_size, k, c_acc_iter, c_acc_tys, check_k); + bb.create(c_next, loc); + }, + for_attributes); + return return_values; + }; + + auto c_acc = std::vector{}; + c_acc.reserve(num_m_blocks * num_n_blocks); + for (std::int32_t i = 0; i < num_m_blocks * num_n_blocks; ++i) { + c_acc.emplace_back(bb.constant_zero(coopmatrix_c_acc_ty, loc)); + } + auto c_acc_tys = std::vector(c_acc.size()); + for (auto &ty : c_acc_tys) { + ty = coopmatrix_c_acc_ty; + } + auto c_tys = std::vector(c_acc.size()); + for (auto &ty : c_tys) { + ty = coopmatrix_c_ty; + } + + auto k_block_size = K_block_sizes.back(); + + const auto const_K = get_int_constant(K); + if (const_K) { + k_block_size = choose_k_block_size(K_block_sizes, *const_K); + } + + auto c_zero = bb.constant_zero(index_ty, loc); + auto c_k_block_size = 
bb.create(k_block_size, index_ty, loc); + auto tmp = instant_constant_fold_add(bb, create(K, c_k_block_size, index_ty, loc)); + auto K0 = instant_constant_fold_add(bb, create(tmp, c_k_block_size, index_ty, loc)); + auto needs_remainder = + instant_constant_fold_add(bb, create(K0, K, bool_ty, loc)); + auto r = get_bool_constant(needs_remainder); + if (r) { + if (*r != 0) { + c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc, c_acc_tys); + const auto K_block_size = K_block_sizes.front(); + c_acc = compute_c(bb, K_block_size, K0, K, c_acc, c_acc_tys, K_block_size > 1); + } else { + c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc, c_acc_tys); + } + } else { + c_acc = compute_c(bb, k_block_size, c_zero, K0, c_acc, c_acc_tys); + auto remainder = bb.ifelse( + needs_remainder, + [&](region_builder &bb) { + const auto K_block_size = K_block_sizes.front(); + auto c_next = + compute_c(bb, K_block_size, K0, K, c_acc, c_acc_tys, K_block_size > 1); + bb.create(c_next, loc); + }, + [&](region_builder &bb) { bb.create(c_acc, loc); }, c_acc_tys, loc); + c_acc = std::move(remainder); + } + + for (auto &a : c_acc) { + a = mixed_precision_coopmatrix_scale(bb, alpha, a, loc); + } + + const bool needs_final_cast = coopmatrix_c_ty != coopmatrix_c_acc_ty; + if (atomic) { + auto flag = get_atomic_store_flag(beta); + if (!flag) { + throw compilation_error(loc, status::ir_invalid_beta); + } + for (std::int32_t n = 0; n < num_n_blocks; ++n) { + auto pos1_offset = bb.create(n * n_block_size, index_ty, loc); + auto pos1 = bb.create(n_block, pos1_offset, index_ty, loc); + for (std::int32_t m = 0; m < num_m_blocks; ++m) { + auto pos0_offset = bb.create(m * m_block_size, index_ty, loc); + auto pos0 = bb.create(m_block, pos0_offset, index_ty, loc); + auto alpha_ab_mn = c_acc[m + n * num_m_blocks]; + if (needs_final_cast) { + alpha_ab_mn = bb.create(alpha_ab_mn, coopmatrix_c_ty, loc); + } + bb.create(transpose::N, check_c, *flag, alpha_ab_mn, + C, pos0, pos1, loc); + } + } + } else { + for 
(std::int32_t n = 0; n < num_n_blocks; ++n) { + auto pos1_offset = bb.create(n * n_block_size, index_ty, loc); + auto pos1 = bb.create(n_block, pos1_offset, index_ty, loc); + for (std::int32_t m = 0; m < num_m_blocks; ++m) { + auto pos0_offset = bb.create(m * m_block_size, index_ty, loc); + auto pos0 = bb.create(m_block, pos0_offset, index_ty, loc); + auto c_load = bb.create(transpose::N, check_c, C, + pos0, pos1, coopmatrix_c_ty); + auto &alpha_ab_mn = c_acc[m + n * num_m_blocks]; + auto alpha_ab_plus_beta_c = [&] { + if (needs_final_cast) { + auto c_load_acc = bb.create(c_load, coopmatrix_c_acc_ty, loc); + auto beta_c = mixed_precision_coopmatrix_scale(bb, beta, c_load_acc, loc); + auto alpha_ab_plus_beta_c = + bb.create(alpha_ab_mn, beta_c, alpha_ab_mn->ty(), loc); + return bb.create(alpha_ab_plus_beta_c, coopmatrix_c_ty, loc); + } else { + auto beta_c = mixed_precision_coopmatrix_scale(bb, beta, c_load, loc); + return bb.create(alpha_ab_mn, beta_c, alpha_ab_mn->ty(), loc); + } + }(); + bb.create(transpose::N, check_c, store_flag::regular, + alpha_ab_plus_beta_c, C, pos0, pos1, loc); + } + } + } +} + +class linalg_generator { + public: + linalg_generator(local_tiling const &tiling, core_config const &core_cfg, tinytc_region ®, + tinytc_inst_iterator_t ip) + : tiling_{tiling}, core_cfg_{core_cfg}, bb_{®, ip} {} + inline void operator()(inst_view in) { + throw compilation_error(in.loc(), status::not_implemented); + } + void operator()(axpby_inst in); + void operator()(cumsum_inst in); + void operator()(gemm_inst in); + void operator()(gemv_inst in); + void operator()(ger_inst in); + void operator()(hadamard_inst in); + void operator()(sum_inst in); + + inline auto insertion_point() -> tinytc_inst_iterator_t { return bb_.get_insertion_point(); } + + private: + auto get_memref_type(tinytc_value const &v) const -> const memref_type *; + + local_tiling const &tiling_; + core_config const &core_cfg_; + region_builder bb_; +}; + +auto 
linalg_generator::get_memref_type(tinytc_value const &v) const -> const memref_type * { + auto t = dyn_cast(v.ty()); + if (t == nullptr) { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + return t; +} + +void linalg_generator::operator()(axpby_inst in) { + auto ctx = in.alpha().context(); + auto bool_ty = get(ctx); + auto index_ty = get(ctx); + + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + if (bt->dim() == 0) { + auto parallel = create(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto i32_ty = get(ctx); + auto sg_id = bb.create(i32_ty, in.loc()); + auto sg_lid = bb.create(i32_ty, in.loc()); + auto c0 = bb.create(0, i32_ty); + auto cond0 = bb.create(sg_id, c0, bool_ty, in.loc()); + auto cond1 = bb.create(sg_lid, c0, bool_ty, in.loc()); + auto cond = bb.create(cond0, cond1, cond0->ty()); + bb.if_condition(cond, [&](region_builder &bb) { + auto a = bb.create(&in.A(), array_view{}, at->element_ty(), + in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {}, in.loc()); + }); + + bb_.add(std::move(parallel)); + } else if (bt->dim() == 1) { + auto c0 = bb_.constant_zero(index_ty, in.loc()); + auto c_shape0 = + instant_constant_fold_add(bb_, create(0, &in.B(), index_ty, in.loc())); + bb_.foreach_loop( + {c0}, {c_shape0}, + [&](region_builder &bb, auto loop_vars) { + auto a = bb.create(&in.A(), array_view{loop_vars[0]}, at->element_ty(), + in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), {loop_vars[0]}, + in.loc()); + }, + in.loc()); + } else if (bt->dim() == 2) { + auto c0 = bb_.constant_zero(index_ty, in.loc()); + auto c_shape0 = + instant_constant_fold_add(bb_, create(0, &in.B(), index_ty, in.loc())); + auto c_shape1 = + instant_constant_fold_add(bb_, create(1, &in.B(), index_ty, in.loc())); + bb_.foreach_loop( + {c0, c0}, {c_shape0, c_shape1}, + [&](region_builder &bb, auto loop_vars) { + auto a_idx = 
std::array{loop_vars[0], loop_vars[1]}; + if (in.tA() == transpose::T) { + std::swap(a_idx[0], a_idx[1]); + } + auto a = bb.create(&in.A(), a_idx, at->element_ty(), in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), a, &in.beta(), &in.B(), + {loop_vars[0], loop_vars[1]}, in.loc()); + }, + in.loc()); + } +} + +void linalg_generator::operator()(cumsum_inst in) { + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + + const auto num_tiles = tiling_.m_tiles() * tiling_.n_tiles(); + auto ctx = in.alpha().context(); + auto bool_ty = get(ctx); + auto i32_ty = get(ctx); + auto index_ty = get(ctx); + const auto &loc = in.loc(); + + auto const scan_loop_1d = [&](region_builder &bb, work_group_inclusive_scan &scan, + tinytc_value_t a_sub, tinytc_value_t b_sub) { + auto c_sgs = bb.create(scan.subgroup_size(), i32_ty, loc); + auto sglid = bb.create(i32_ty, loc); + auto from_index = [&]() -> tinytc_value_t { + if (scan.num_tiles() > 1) { + auto sgid = bb.create(i32_ty, loc); + auto from0 = bb.create(sgid, c_sgs, i32_ty, loc); + auto from1 = bb.create(from0, sglid, i32_ty, loc); + return bb.create(from1, index_ty, loc); + } else { + return bb.create(sglid, index_ty, loc); + } + }(); + + auto c_step = + bb.create(scan.subgroup_size() * scan.num_tiles(), index_ty, loc); + + auto c_1 = bb.constant_one(index_ty, loc); + auto shape0 = instant_constant_fold_add(bb, create(0, a_sub, index_ty, loc)); + auto tr0 = instant_constant_fold_add(bb, create(shape0, c_1, index_ty, loc)); + auto tr1 = instant_constant_fold_add(bb, create(tr0, c_step, index_ty, loc)); + auto tr2 = instant_constant_fold_add(bb, create(tr1, c_1, index_ty, loc)); + auto trip_count = + instant_constant_fold_add(bb, create(tr2, c_step, index_ty, loc)); + + auto c_init = bb.constant_zero(bt->element_ty(), loc); + auto a_scan = bb.for_loop( + from_index, trip_count, c_step, {c_init}, {bt->element_ty()}, + [&](region_builder &bb, array_view args) { + auto is_in_bounds = bb.create(args[0], shape0, 
bool_ty, loc); + auto a = bb.ifelse( + is_in_bounds, + [&](region_builder &bb) { + auto a = + bb.create(a_sub, array_view{args[0]}, at->element_ty(), loc); + if (at->element_ty() != bt->element_ty()) { + a = bb.create(a, bt->element_ty(), loc); + } + bb.create(array_view{a}, loc); + }, + [&](region_builder &bb) { bb.create(array_view{c_init}, loc); }, + {bt->element_ty()}, loc); + auto [a_scan, next_prefix] = scan.make(bb, a[0], true, loc); + a_scan = bb.create(args[1], a_scan, bt->element_ty(), loc); + next_prefix = bb.create(args[1], next_prefix, bt->element_ty(), loc); + bb.if_condition( + is_in_bounds, + [&](region_builder &bb) { + blas_update(bb, in.atomic(), &in.alpha(), a_scan, &in.beta(), b_sub, + {args[0]}, loc); + }, + loc); + bb.create(array_view{next_prefix}, loc); + }); + }; + + if (bt->dim() == 1) { + auto parallel = create(loc); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto scan = work_group_inclusive_scan(num_tiles, core_cfg_.subgroup_size, bt->element_ty()); + scan.setup(bb_, loc); + + scan_loop_1d(bb, scan, &in.A(), &in.B()); + + bb_.add(std::move(parallel)); + scan.teardown(bb_); + } else if (bt->dim() >= 2 && in.mode() == 0) { + auto scan = work_group_inclusive_scan(1, core_cfg_.subgroup_size, bt->element_ty()); + scan.setup(bb_, loc); + + auto parallel = create(loc); + + auto c_zero = bb_.constant_zero(index_ty, loc); + tinytc_region_t parent_region = ¶llel->child_region(0); + auto offsets = std::vector(bt->dim() - 1, nullptr); + for (std::int64_t i = bt->dim() - 1; i > 1; --i) { + auto bb = region_builder{parent_region}; + auto shape_i = bb.create(i, &in.B(), index_ty, loc); + auto for_i = create(c_zero, shape_i, nullptr, array_view{}, + array_view{}, loc); + auto for_i_view = for_inst(for_i.get()); + offsets[i - 1] = &for_i_view.body().param(0); + parent_region = &for_i_view.body(); + bb.add(std::move(for_i)); + } + + auto bb = region_builder{parent_region}; + auto sgid = bb.create(i32_ty, loc); + 
auto sgid_index = bb.create(sgid, index_ty, loc); + + auto shape0 = bb.create(0, &in.B(), index_ty, loc); + auto shape1 = bb.create(1, &in.B(), index_ty, loc); + auto c_num_tiles = bb.create(num_tiles, index_ty, loc); + bb.for_loop( + sgid_index, shape1, c_num_tiles, + [&](region_builder &bb, array_view args) { + auto static_offset = std::vector(bt->dim(), dynamic); + auto static_size = std::vector(bt->dim(), 0); + static_offset[0] = 0; + static_size[0] = dynamic; + auto a_sub_ty = get(at->element_ty(), array_view{dynamic}, + array_view{at->stride(0)}, at->addrspace()); + auto b_sub_ty = get(bt->element_ty(), array_view{dynamic}, + array_view{bt->stride(0)}, bt->addrspace()); + offsets[0] = args[0]; + auto a_sub = bb.create(static_offset, static_size, &in.A(), offsets, + array_view{shape0}, a_sub_ty, loc); + auto b_sub = bb.create(static_offset, static_size, &in.B(), offsets, + array_view{shape0}, b_sub_ty, loc); + scan_loop_1d(bb, scan, a_sub, b_sub); + }); + + bb_.add(std::move(parallel)); + scan.teardown(bb_); + } else if (bt->dim() >= 2) { + auto c_zero = bb_.constant_zero(index_ty, loc); + auto lb = std::vector(bt->dim() - 1, c_zero); + auto ub = std::vector{}; + ub.reserve(bt->dim() - 1); + for (std::int64_t i = 0; i < bt->dim(); ++i) { + if (i != in.mode()) { + ub.emplace_back( + instant_constant_fold_add(bb_, create(i, &in.B(), index_ty, loc))); + } + } + + auto J = bb_.create(in.mode(), &in.B(), index_ty, loc); + bb_.foreach_loop( + lb, ub, + [&](region_builder &bb, auto loop_vars) { + auto static_offset = std::vector(bt->dim(), dynamic); + auto static_size = std::vector(bt->dim(), 0); + static_offset[in.mode()] = 0; + static_size[in.mode()] = dynamic; + auto a_sub_ty = + get(at->element_ty(), array_view{dynamic}, + array_view{at->stride(in.mode())}, at->addrspace()); + auto a_sub = bb.create(static_offset, static_size, &in.A(), loop_vars, + array_view{J}, a_sub_ty, loc); + auto b_sub_ty = + get(bt->element_ty(), array_view{dynamic}, + 
array_view{bt->stride(in.mode())}, bt->addrspace()); + auto b_sub = bb.create(static_offset, static_size, &in.B(), loop_vars, + array_view{J}, b_sub_ty, loc); + + auto c_init = bb.constant_zero(bt->element_ty()); + auto acc = bb.for_loop(c_zero, J, {}, {c_init}, {bt->element_ty()}, + [&](region_builder &bb, array_view args) { + auto a = bb.create(a_sub, array_view{args[0]}, + at->element_ty(), loc); + auto prefix = mixed_precision_arithmetic( + bb, bt->element_ty(), args[1], a, loc); + blas_update(bb, in.atomic(), &in.alpha(), prefix, + &in.beta(), b_sub, {args[0]}, loc); + bb.create(array_view{prefix}, loc); + }); + }, + loc); + } +} + +void linalg_generator::operator()(gemm_inst in) { + auto parallel = create(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto ct = get_memref_type(in.C()); + + auto ctx = in.alpha().context(); + auto i32_ty = get(ctx); + auto index_ty = get(ctx); + + auto sg_m = bb.create(comp3::x, i32_ty, in.loc()); + auto sg_n = bb.create(comp3::y, i32_ty, in.loc()); + + auto [max_rows, max_cols] = max_register_block_gemm( + size(at->element_ty()), size(bt->element_ty()), size(acc_type(ct->element_ty())), + core_cfg_.subgroup_size, core_cfg_.register_space, + isa(*ct->element_ty()) ? 2 : 1); + + auto c_shape0 = + instant_constant_fold_add(bb, create(0, &in.C(), index_ty, in.loc())); + auto c_shape1 = + instant_constant_fold_add(bb, create(1, &in.C(), index_ty, in.loc())); + auto K = instant_constant_fold_add( + bb, create(in.tA() == transpose::T ? 
0 : 1, &in.A(), index_ty, in.loc())); + + auto const_shape0 = get_int_constant(c_shape0); + auto const_shape1 = get_int_constant(c_shape1); + + const auto [block_size0, num_blocks0, block_size1, num_blocks1, do_tile_uniformly, + K_block_sizes] = + [&]() -> std::tuple> { + if (auto ext_type = core_cfg_.matrix->get_precision(at->element_ty()->type_id(), + bt->element_ty()->type_id(), + ct->element_ty()->type_id()); + ext_type) { + const auto M_bs = ext_type->M_block_sizes(); + // @todo Think about what do if we have multiple sizes for M + const auto block_size0 = M_bs.back(); + const auto shape0 = const_shape0 ? *const_shape0 : max_rows; + const auto num_blocks0 = + choose_block_size_multiple(block_size0, max_rows, tiling_.m_tiles(), shape0); + + // @todo Think about what do for multiple N sizes + const auto N_bs = ext_type->N_block_sizes(block_size0); + const auto block_size1 = N_bs.back(); + const auto shape1 = const_shape1 ? *const_shape1 : max_cols; + const auto num_blocks1 = + choose_block_size_multiple(block_size1, max_cols, tiling_.n_tiles(), shape1); + const auto K_bs = ext_type->K_block_sizes(block_size0, block_size1); + + return std::make_tuple(block_size0, num_blocks0, block_size1, num_blocks1, false, K_bs); + } + + const auto block_size0 = core_cfg_.subgroup_size; + const auto shape0 = const_shape0 ? 
*const_shape0 : max_rows; + const auto num_blocks0 = + choose_block_size_multiple(block_size0, max_rows, tiling_.m_tiles(), shape0); + const auto block_size1 = max_cols; + const auto num_blocks1 = 1; + + return std::make_tuple(block_size0, num_blocks0, block_size1, num_blocks1, + const_shape1.has_value(), + std::vector(standard_K_block_sizes.begin(), + standard_K_block_sizes.end())); + }(); + + if (do_tile_uniformly) { + tile_loop_uniformly( + bb, c_shape1, block_size1 * num_blocks1, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, tinytc_value_t n_block, tinytc_value_t trip_count) { + auto const_trip_count = get_int_constant(trip_count); + if (!const_trip_count) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + tile_loop_by_sgs( + bb, c_shape0, block_size0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, tinytc_value_t m_block, bool m_check, tinytc_value_t) { + gemm_microkernel(bb, in.tA(), in.tB(), in.atomic(), &in.alpha(), &in.A(), + &in.B(), &in.beta(), &in.C(), K, m_block, block_size0, + num_blocks0, m_check, n_block, *const_trip_count, + num_blocks1, false, K_block_sizes, at->element_ty(), + bt->element_ty(), ct->element_ty(), nullptr, in.loc()); + }); + }); + } else { + auto no_unroll = get_dictionary_attr_with_sorted( + ctx, + tinytc_named_attr_t{get(ctx, "unroll"), get(ctx, false)}); + tile_loop_by_sgs( + bb, c_shape1, block_size1 * num_blocks1, tiling_.n_tiles(), sg_n, + [&](region_builder &bb, tinytc_value_t n_block, bool n_check, tinytc_value_t) { + tile_loop_by_sgs( + bb, c_shape0, block_size0 * num_blocks0, tiling_.m_tiles(), sg_m, + [&](region_builder &bb, tinytc_value_t m_block, bool m_check, tinytc_value_t) { + gemm_microkernel(bb, in.tA(), in.tB(), in.atomic(), &in.alpha(), &in.A(), + &in.B(), &in.beta(), &in.C(), K, m_block, block_size0, + num_blocks0, m_check, n_block, block_size1, num_blocks1, + n_check, K_block_sizes, at->element_ty(), bt->element_ty(), + ct->element_ty(), no_unroll, in.loc()); + }, + 
no_unroll); + }, + no_unroll); + } + + bb_.add(std::move(parallel)); +} + +void linalg_generator::operator()(gemv_inst in) { + auto index_ty = index_type::get(in.alpha().context()); + auto c0 = bb_.constant_zero(index_ty, in.loc()); + auto c_shape0 = + instant_constant_fold_add(bb_, create(0, &in.C(), index_ty, in.loc())); + auto ct = get_memref_type(in.C()); + bb_.foreach_loop( + {c0}, {c_shape0}, + [&](region_builder &bb, auto loop_vars) { + auto c_init = bb.constant_zero(ct->element_ty()); + auto K = + bb.create(in.tA() == transpose::T ? 0 : 1, &in.A(), index_ty, in.loc()); + auto c_acc = bb.for_loop( + c0, K, {}, {c_init}, {ct->element_ty()}, + [&](region_builder &bb, array_view p) { + auto a_idx = std::array{loop_vars[0], p[0]}; + if (in.tA() == transpose::T) { + std::swap(a_idx[0], a_idx[1]); + } + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto a = bb.create(&in.A(), a_idx, at->element_ty(), in.loc()); + auto b = + bb.create(&in.B(), array_view{p[0]}, bt->element_ty(), in.loc()); + auto ab = + mixed_precision_arithmetic(bb, ct->element_ty(), a, b, in.loc()); + auto ab_c = mixed_precision_arithmetic(bb, ct->element_ty(), p[1], ab, + in.loc()); + bb.create(array_view{ab_c}, in.loc()); + }); + blas_update(bb, in.atomic(), &in.alpha(), c_acc[0], &in.beta(), &in.C(), {loop_vars[0]}, + in.loc()); + }, + in.loc()); +} + +void linalg_generator::operator()(ger_inst in) { + auto index_ty = index_type::get(in.alpha().context()); + auto c0 = bb_.constant_zero(index_ty, in.loc()); + auto c_shape0 = + instant_constant_fold_add(bb_, create(0, &in.C(), index_ty, in.loc())); + auto c_shape1 = + instant_constant_fold_add(bb_, create(1, &in.C(), index_ty, in.loc())); + bb_.foreach_loop( + {c0, c0}, {c_shape0, c_shape1}, + [&](region_builder &bb, auto loop_vars) { + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto ct = get_memref_type(in.C()); + auto a = + bb.create(&in.A(), array_view{loop_vars[0]}, at->element_ty(), 
in.loc()); + auto b = + bb.create(&in.B(), array_view{loop_vars[1]}, bt->element_ty(), in.loc()); + auto ab = mixed_precision_arithmetic(bb, ct->element_ty(), a, b, in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), + {loop_vars[0], loop_vars[1]}, in.loc()); + }, + in.loc()); +} + +void linalg_generator::operator()(hadamard_inst in) { + auto index_ty = index_type::get(in.alpha().context()); + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + auto ct = get_memref_type(in.C()); + + auto lb = std::vector(ct->dim()); + auto ub = std::vector(ct->dim()); + + auto c0 = bb_.constant_zero(index_ty, in.loc()); + for (std::int64_t i = 0; i < ct->dim(); ++i) { + lb[i] = c0; + ub[i] = instant_constant_fold_add(bb_, create(i, &in.C(), index_ty, in.loc())); + } + + bb_.foreach_loop( + lb, ub, + [&](region_builder &bb, auto loop_vars) { + auto a = bb.create(&in.A(), loop_vars, at->element_ty(), in.loc()); + auto b = bb.create(&in.B(), loop_vars, bt->element_ty(), in.loc()); + auto ab = mixed_precision_arithmetic(bb, ct->element_ty(), a, b, in.loc()); + blas_update(bb, in.atomic(), &in.alpha(), ab, &in.beta(), &in.C(), loop_vars, in.loc()); + }, + in.loc()); +} + +void linalg_generator::operator()(sum_inst in) { + auto at = get_memref_type(in.A()); + auto bt = get_memref_type(in.B()); + + auto ctx = in.alpha().context(); + auto bool_ty = get(ctx); + auto i32_ty = get(ctx); + auto index_ty = get(ctx); + + if (bt->dim() == 0) { + const auto num_tiles = tiling_.m_tiles() * tiling_.n_tiles(); + auto reducer = work_group_reduce(num_tiles, core_cfg_.subgroup_size, bt->element_ty()); + reducer.setup(bb_, in.loc()); + + auto parallel = create(in.loc()); + tinytc_region_t body = ¶llel->child_region(0); + auto bb = region_builder{body}; + + auto c_sgs = bb.create(core_cfg_.subgroup_size, i32_ty, in.loc()); + auto sgid = bb.create(i32_ty, in.loc()); + auto m = bb.create(i32_ty, in.loc()); + auto from0 = bb.create(sgid, c_sgs, i32_ty, 
in.loc()); + auto from1 = bb.create(from0, m, i32_ty, in.loc()); + auto from_index = bb.create(from1, index_ty, in.loc()); + + auto c_trip_count = + instant_constant_fold_add(bb, create(0, &in.A(), index_ty, in.loc())); + auto c_step = + bb.create(core_cfg_.subgroup_size * num_tiles, index_ty, in.loc()); + auto c_init = bb.constant_zero(bt->element_ty(), in.loc()); + + auto acc = bb.for_loop(from_index, c_trip_count, c_step, {c_init}, {bt->element_ty()}, + [&](region_builder &bb, array_view args) { + auto a = bb.create(&in.A(), array_view{args[0]}, + at->element_ty(), in.loc()); + auto sum = mixed_precision_arithmetic( + bb, bt->element_ty(), args[1], a, in.loc()); + bb.create(array_view{sum}, in.loc()); + }); + auto acc_reduced = reducer.make(bb, acc[0], in.loc()); + + auto c_zero = bb.constant_zero(i32_ty, in.loc()); + auto is_first_work_item = bb.create(from1, c_zero, bool_ty, in.loc()); + bb.if_condition( + is_first_work_item, + [&](region_builder &bb) { + blas_update(bb, in.atomic(), &in.alpha(), acc_reduced, &in.beta(), &in.B(), {}, + in.loc()); + }, + in.loc()); + + bb_.add(std::move(parallel)); + reducer.teardown(bb_); + } else if (bt->dim() == 1) { + auto c0 = bb_.constant_zero(index_ty, in.loc()); + auto c_shape0 = + instant_constant_fold_add(bb_, create(0, &in.B(), index_ty, in.loc())); + bb_.foreach_loop( + array_view{c0}, array_view{c_shape0}, + [&](region_builder &bb, auto loop_vars) { + auto K = bb.create(in.tA() == transpose::T ? 
0 : 1, &in.A(), index_ty, + in.loc()); + auto c_init = bb.constant_zero(bt->element_ty()); + auto acc = bb.for_loop( + c0, K, {}, {c_init}, {bt->element_ty()}, + [&](region_builder &bb, array_view args) { + auto index_list = std::array{loop_vars[0], args[0]}; + if (in.tA() == transpose::T) { + std::swap(index_list[0], index_list[1]); + } + auto a = + bb.create(&in.A(), index_list, at->element_ty(), in.loc()); + auto sum = mixed_precision_arithmetic(bb, bt->element_ty(), + args[1], a, in.loc()); + bb.create(array_view{sum}, in.loc()); + }); + blas_update(bb, in.atomic(), &in.alpha(), acc[0], &in.beta(), &in.B(), + {loop_vars[0]}, in.loc()); + }, + in.loc()); + } +} + +lower_linalg_pass::lower_linalg_pass(::tinytc_core_info const *info) : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void lower_linalg_pass::run_on_function(tinytc_func &fn) { + auto [core_cfg, tiling] = get_core_config_and_tiling(fn, info_); + + walk(fn, [&](tinytc_region ®) { + auto it = reg.begin(); + while (it != reg.end()) { + if (isa(*it) || isa(*it)) { + auto gen = linalg_generator{tiling, core_cfg, reg, it.get()}; + visit(gen, *it); + it = reg.insts().erase(gen.insertion_point()); + } else { + ++it; + } + } + }); +} + +} // namespace tinytc diff --git a/src/pass/lower_linalg.hpp b/src/pass/lower_linalg.hpp new file mode 100644 index 00000000..e16b92f5 --- /dev/null +++ b/src/pass/lower_linalg.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LOWER_LINALG_20240801_HPP +#define LOWER_LINALG_20240801_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class lower_linalg_pass { + public: + lower_linalg_pass(::tinytc_core_info const *info); + + void run_on_function(::tinytc_func &fn); + + private: + ::tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // LOWER_LINALG_20240801_HPP diff --git a/src/pass/slot_tracker.cpp 
b/src/pass/slot_tracker.cpp new file mode 100644 index 00000000..146f0521 --- /dev/null +++ b/src/pass/slot_tracker.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/slot_tracker.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/region.hpp" +#include "node/value.hpp" +#include "support/walk.hpp" +#include "util/iterator.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +void slot_tracker::set_slot(tinytc_value const &v) { + if (!v.has_name()) { + slot_map_[&v] = slot_++; + } +} + +void slot_tracker::run_on_function(tinytc_func &fn) { + slot_ = 0; + for (auto const &arg : fn.params()) { + set_slot(arg); + } + walk(fn, [this](tinytc_inst &i) { + for (auto const ® : i.child_regions()) { + for (auto const &p : reg.params()) { + set_slot(p); + } + } + for (auto const &result : i.results()) { + set_slot(result); + } + }); +} + +auto slot_tracker::get_slot(tinytc_value const &v) -> std::int64_t { + auto it = slot_map_.find(&v); + return it != slot_map_.end() ? 
it->second : -1; +} + +} // namespace tinytc diff --git a/src/pass/slot_tracker.hpp b/src/pass/slot_tracker.hpp new file mode 100644 index 00000000..9fc2e7b0 --- /dev/null +++ b/src/pass/slot_tracker.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef SLOT_TRACKER_20240418_HPP +#define SLOT_TRACKER_20240418_HPP + +#include "tinytc/types.h" + +#include +#include + +namespace tinytc { + +class slot_tracker { + public: + void run_on_function(tinytc_func &fn); + + auto get_slot(tinytc_value const &v) -> std::int64_t; + + private: + void set_slot(tinytc_value const &v); + + std::int64_t slot_ = 0; + std::unordered_map slot_map_; +}; + +} // namespace tinytc + +#endif // SLOT_TRACKER_20240418_HPP diff --git a/src/pass/stack.cpp b/src/pass/stack.cpp new file mode 100644 index 00000000..75b10794 --- /dev/null +++ b/src/pass/stack.cpp @@ -0,0 +1,81 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/stack.hpp" +#include "error.hpp" +#include "node/attr.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "support/walk.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +void set_stack_ptr_pass::run_on_function(tinytc_func &fn) { + struct allocation { + tinytc_value_t value; + std::int64_t start, stop; + }; + std::list allocs; + + walk(fn, [&allocs](tinytc_inst &i) { + visit(overloaded{ + [&allocs](alloca_inst a) { + auto t = dyn_cast(a.result().ty()); + if (t == nullptr) { + throw compilation_error(a.loc(), status::ir_expected_memref); + } + const auto alignment = [&]() -> std::int32_t { + if (auto aa = get_attr(a.get().attr(), "alignment"); aa) { + auto val = dyn_cast_or_throw(aa, [&] { + return status::ir_expected_integer_attribute; + })->value(); + 
return val; + } + return t->element_alignment(); + }(); + auto size = t->size_in_bytes(); + std::int64_t stack_ptr = 0; + auto it = allocs.begin(); + for (; it != allocs.end(); ++it) { + if (it->start - stack_ptr >= size) { + break; + } + stack_ptr = (1 + (it->stop - 1) / alignment) * alignment; + } + allocs.insert(it, allocation{&a.result(), stack_ptr, stack_ptr + size}); + a.stack_ptr(stack_ptr); + }, + [&allocs](lifetime_stop_inst s) { + int num = 0; + auto &v = s.object(); + for (auto it = allocs.begin(); it != allocs.end();) { + if (it->value == &v) { + it = allocs.erase(it); + ++num; + } else { + ++it; + } + } + if (num != 1) { + throw compilation_error( + s.loc(), status::internal_compiler_error, + "Incorrect lifetime_stop: value not found in list of allocations"); + } + }, + [](inst_view) {}}, + i); + }); +} + +} // namespace tinytc diff --git a/src/pass/stack.hpp b/src/pass/stack.hpp new file mode 100644 index 00000000..466ff212 --- /dev/null +++ b/src/pass/stack.hpp @@ -0,0 +1,18 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef STACK_20230413_HPP +#define STACK_20230413_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class set_stack_ptr_pass { + public: + void run_on_function(tinytc_func &fn); +}; + +} // namespace tinytc + +#endif // STACK_20230413_HPP diff --git a/src/pass/work_group_size.cpp b/src/pass/work_group_size.cpp new file mode 100644 index 00000000..7aeac8a0 --- /dev/null +++ b/src/pass/work_group_size.cpp @@ -0,0 +1,126 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "pass/work_group_size.hpp" +#include "codegen_tools.hpp" +#include "device_info.hpp" +#include "error.hpp" +#include "node/attr.hpp" +#include "node/func.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include "support/walk.hpp" +#include "tiling.hpp" +#include "tinytc/builder.hpp" +#include "tinytc/types.hpp" +#include 
"util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { + +auto get_shapes(tinytc_func &fn) -> std::vector { + auto shape_set = std::unordered_set{}; + + walk(fn, [&shape_set](tinytc_inst &i) { + visit(overloaded{ + [&](blas_a2_inst in) { + auto a = get_memref_type(in.A())->element_ty(); + auto b = get_memref_type(in.B()); + if (b->dim() == 1) { + shape_set.insert(blas_shape{a, a, b->element_ty(), {b->shape(0), 0}}); + } else if (b->dim() >= 2) { + shape_set.insert( + blas_shape{a, a, b->element_ty(), {b->shape(0), b->shape(1)}}); + } + }, + [&](blas_a3_inst in) { + auto a = get_memref_type(in.A())->element_ty(); + auto b = get_memref_type(in.B())->element_ty(); + auto c = get_memref_type(in.C()); + if (c->dim() == 1) { + shape_set.insert(blas_shape{a, b, c->element_ty(), {c->shape(0), 0}}); + } else if (c->dim() >= 2) { + shape_set.insert(blas_shape{a, + b, + c->element_ty(), + {c->shape(0), c->shape(1)}, + isa(in.get())}); + } + }, + [](inst_view) {}}, + i); + }); + + return std::vector(shape_set.begin(), shape_set.end()); +} + +work_group_size_pass::work_group_size_pass(::tinytc_core_info const *info) + : info_(std::move(info)) { + if (info_ == nullptr) { + throw std::invalid_argument("info must not be nullptr"); + } +} + +void work_group_size_pass::run_on_function(tinytc_func &fn) { + auto sgs_attr = get_attr(fn.attr(), "subgroup_size"); + auto wgs_attr = get_attr(fn.attr(), "work_group_size"); + + if (sgs_attr && wgs_attr) { + return; + } + + const auto shapes = get_shapes(fn); + + auto ctx = fn.ty()->context(); + const auto subgroup_size = [&] { + if (!sgs_attr) { + auto sgs = suggest_subgroup_size(shapes, *info_); + sgs_attr = get(ctx, sgs); + return sgs; + } else { + return fn.subgroup_size(); + } + }(); + + core_config cfg = {}; + try { + cfg = info_->get_core_config(subgroup_size); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), 
status::unsupported_subgroup_size); + } + + const auto work_group_size = [&] { + if (!wgs_attr) { + auto tiling = suggest_local_tiling(shapes, cfg); + auto wgs = std::array{tiling[0] * subgroup_size, tiling[1]}; + wgs_attr = get( + ctx, array_view{get(ctx, wgs[0]), get(ctx, wgs[1])}); + return wgs; + } else { + return fn.work_group_size(); + } + }(); + + if (work_group_size[0] % subgroup_size != 0) { + throw compilation_error(fn.loc(), status::unsupported_work_group_size, + "First work-group size mode must be divisible by subgroup size"); + } + if (work_group_size[0] * work_group_size[1] > cfg.max_work_group_size) { + throw compilation_error(fn.loc(), status::unsupported_work_group_size); + } + + fn.attr(get_dictionary_attr_with_sorted( + ctx, {tinytc_named_attr_t{get(ctx, "subgroup_size"), sgs_attr}, + tinytc_named_attr_t{get(ctx, "work_group_size"), wgs_attr}})); +} + +} // namespace tinytc diff --git a/src/pass/work_group_size.hpp b/src/pass/work_group_size.hpp new file mode 100644 index 00000000..2b3bfeee --- /dev/null +++ b/src/pass/work_group_size.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef WORK_GROUP_SIZE_20240311_HPP +#define WORK_GROUP_SIZE_20240311_HPP + +#include "tinytc/types.h" + +namespace tinytc { + +class work_group_size_pass { + public: + work_group_size_pass(tinytc_core_info const *info); + + void run_on_function(tinytc_func &fn); + + private: + tinytc_core_info const *info_; +}; + +} // namespace tinytc + +#endif // WORK_GROUP_SIZE_20240311_HPP diff --git a/src/passes.cpp b/src/passes.cpp deleted file mode 100644 index ef27bad0..00000000 --- a/src/passes.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "passes.hpp" -#include "device_info.hpp" -#include "kernel_metadata.hpp" -#include "node/data_type_node.hpp" -#include "node/function_node.hpp" -#include "node/program_node.hpp" -#include 
"visitor/check_ir.hpp" -#include "visitor/dump_ir.hpp" -#include "visitor/equal.hpp" -#include "visitor/insert_barrier.hpp" -#include "visitor/lifetime_analysis.hpp" -#include "visitor/metadata.hpp" -#include "visitor/opencl_ast.hpp" -#include "visitor/stack.hpp" -#include "visitor/work_group_size.hpp" - -#include - -using clir::visit; - -namespace tinytc { - -void check_ir(tinytc_prog const &p) { return visit(ir_checker{}, p); } - -void dump_ir(std::ostream &os, tinytc_func const &f) { visit(ir_dumper{os}, f); } -void dump_ir(std::ostream &os, tinytc_prog const &p) { visit(ir_dumper{os}, p); } - -clir::prog generate_opencl_ast(tinytc_prog const &p, ::tinytc_core_info const &info) { - return visit(opencl_ast{&info}, p); -} - -auto get_metadata(tinytc_prog const &p) -> std::unordered_map { - auto v = metadata{}; - visit(v, p); - return v.get_result(); -} - -void insert_barriers(tinytc_func &f) { visit(insert_barrier{}, f); } -void insert_barriers(tinytc_prog &p) { visit(insert_barrier{}, p); } - -void insert_lifetime_stop_inst(tinytc_func &f) { visit(lifetime_inserter{}, f); } -void insert_lifetime_stop_inst(tinytc_prog &p) { visit(lifetime_inserter{}, p); } - -bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b) { return visit(equal{}, a, b); } - -void set_stack_ptrs(tinytc_func &f) { visit(stack_ptr{}, f); } -void set_stack_ptrs(tinytc_prog &p) { visit(stack_ptr{}, p); } - -void set_work_group_size(tinytc_func &f, ::tinytc_core_info const &info) { - visit(work_group_size{&info}, f); -} -void set_work_group_size(tinytc_prog &p, ::tinytc_core_info const &info) { - visit(work_group_size{&info}, p); -} - -} // namespace tinytc - diff --git a/src/passes.def b/src/passes.def new file mode 100644 index 00000000..b1317da8 --- /dev/null +++ b/src/passes.def @@ -0,0 +1,17 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +FUNCTION_PASS("check-ir", check_ir_pass{}) +FUNCTION_PASS("constant-propagation", 
constant_propagation_pass{}, tinytc::optflag::unsafe_fp_math) +FUNCTION_PASS("dead-code-elimination", dead_code_elimination_pass{}) +FUNCTION_PASS("dump-control-flow-graph", dump_cfg_pass{std::cout}) +FUNCTION_PASS("dump-def-use", dump_def_use_pass{std::cout}) +FUNCTION_PASS("dump-ir", dump_ir_pass{std::cout}) +FUNCTION_PASS("insert-barrier", insert_barrier_pass{}) +FUNCTION_PASS("insert-lifetime-stop", insert_lifetime_stop_pass{}) +FUNCTION_PASS("set-stack-ptr", set_stack_ptr_pass{}) +FUNCTION_PASS_WITH_INFO("dump-gcd", [](tinytc_core_info const* info) { return dump_gcd_pass(std::cout, info); }) +FUNCTION_PASS_WITH_INFO("lower-coopmatrix", [](tinytc_core_info const* info) { return lower_coopmatrix_pass{info}; }) +FUNCTION_PASS_WITH_INFO("lower-foreach", [](tinytc_core_info const* info) { return lower_foreach_pass{info}; }) +FUNCTION_PASS_WITH_INFO("lower-linalg", [](tinytc_core_info const* info) { return lower_linalg_pass{info}; }) +FUNCTION_PASS_WITH_INFO("work-group-size", [](tinytc_core_info const* info) { return work_group_size_pass{info}; }) diff --git a/src/passes.hpp b/src/passes.hpp index 19515e18..9f297d8b 100644 --- a/src/passes.hpp +++ b/src/passes.hpp @@ -4,44 +4,22 @@ #ifndef PASSES_20240314_HPP #define PASSES_20240314_HPP -#include "kernel_metadata.hpp" +#include "node/prog.hpp" #include "tinytc/types.h" -#include -#include -#include -#include - namespace tinytc { -//! Check whether some IR rules are respected -void check_ir(tinytc_prog const &p); -//! Dump IR to ostream -void dump_ir(std::ostream &os, tinytc_func const &f); -//! Dump IR to ostream -void dump_ir(std::ostream &os, tinytc_prog const &p); -//! Generate OpenCL AST -clir::prog generate_opencl_ast(tinytc_prog const &p, tinytc_core_info const &info); -//! Get kernel metadata -auto get_metadata(tinytc_prog const &p) -> std::unordered_map; -//! Insert barriers where necessary -void insert_barriers(tinytc_func &f); -//! 
Insert barriers where necessary -void insert_barriers(tinytc_prog &p); -//! Insert lifetime stop instructions for set_stack_ptrs pass -void insert_lifetime_stop_inst(tinytc_func &f); -//! Insert lifetime stop instructions for set_stack_ptrs pass -void insert_lifetime_stop_inst(tinytc_prog &p); -//! Check whether data types a and b are equal -bool is_equal(tinytc_data_type const &a, tinytc_data_type const &b); -//! Manage temporary memory requested by alloca -void set_stack_ptrs(tinytc_func &f); -//! Manage temporary memory requested by alloca -void set_stack_ptrs(tinytc_prog &p); -//! Choose work group and subgroup size heuristically if not given explicitly -void set_work_group_size(tinytc_func &f, tinytc_core_info const &info); -//! Choose work group and subgroup size heuristically if not given explicitly -void set_work_group_size(tinytc_prog &p, tinytc_core_info const &info); +template void run_function_pass(FunctionPass &&pass, tinytc_prog &p) { + for (auto &fun : p) { + pass.run_on_function(fun); + } +} + +template void run_function_pass(FunctionPass &&pass, tinytc_prog const &p) { + for (auto const &fun : p) { + pass.run_on_function(fun); + } +} } // namespace tinytc diff --git a/src/precision_helper.cpp b/src/precision_helper.cpp deleted file mode 100644 index e5498a03..00000000 --- a/src/precision_helper.cpp +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "precision_helper.hpp" -#include "scalar_type.hpp" -#include "tinytc/tinytc.hpp" - -#include -#include -#include -#include - -#include - -using clir::address_space; -using clir::as_char; -using clir::as_double; -using clir::as_float; -using clir::as_int; -using clir::as_long; -using clir::as_short; -using clir::as_uchar; -using clir::as_uint; -using clir::as_ulong; -using clir::builtin_type; -using clir::cast; -using clir::expr; -using clir::get_sub_group_local_id; -using clir::intel_sub_group_block_read_ui; -using 
clir::intel_sub_group_block_read_ul; -using clir::intel_sub_group_block_read_us; -using clir::intel_sub_group_block_write_ui; -using clir::intel_sub_group_block_write_ul; -using clir::intel_sub_group_block_write_us; -using clir::pointer_to; - -namespace tinytc { - -precision_helper::precision_helper(scalar_type ty) : ty_(ty) {} -builtin_type precision_helper::base_type() const { return to_clir_builtin_ty(ty_); } -builtin_type precision_helper::block_rw_base_type() const { - auto bt = base_type(); - switch (bt) { - case builtin_type::short_t: - return builtin_type::ushort_t; - case builtin_type::int_t: - case builtin_type::float_t: - return builtin_type::uint_t; - case builtin_type::long_t: - case builtin_type::double_t: - return builtin_type::ulong_t; - default: - break; - } - return bt; -} -expr precision_helper::as_type(builtin_type ty, expr e) const { - switch (ty) { - case builtin_type::char_t: - return as_char(std::move(e)); - case builtin_type::uchar_t: - return as_uchar(std::move(e)); - case builtin_type::short_t: - return as_short(std::move(e)); - case builtin_type::ushort_t: - return as_ushort(std::move(e)); - case builtin_type::int_t: - return as_int(std::move(e)); - case builtin_type::uint_t: - return as_uint(std::move(e)); - case builtin_type::long_t: - return as_long(std::move(e)); - case builtin_type::ulong_t: - return as_ulong(std::move(e)); - case builtin_type::float_t: - return as_float(std::move(e)); - case builtin_type::double_t: - return as_double(std::move(e)); - default: - break; - } - return e; -} -short precision_helper::bits() const { return size(ty_) * 8; } -clir::data_type precision_helper::type(address_space as) const { - return clir::data_type(base_type(), as); -} -clir::data_type precision_helper::type(short size, address_space as) const { - return clir::data_type(base_type(), size, as); -} -// TODO: Think of something for integer constants -expr precision_helper::constant(double value) const { return expr(value, bits()); } -expr 
precision_helper::zero() const { return constant(0.0); } - -expr precision_helper::sub_group_block_read(expr address, address_space as) const { - auto const make_read = [](builtin_type bt, expr address) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_read_us(std::move(address)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_read_ui(std::move(address)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_read_ul(std::move(address)); - default: - break; - } - return address[get_sub_group_local_id()]; - }; - auto bt = block_rw_base_type(); - address = cast(pointer_to(clir::data_type(bt, as)), std::move(address)); - auto inst = make_read(bt, std::move(address)); - if (bt != base_type()) { - return as_type(base_type(), std::move(inst)); - } - return inst; -} -expr precision_helper::sub_group_block_write(expr address, expr data, address_space as) const { - auto const make_write = [](builtin_type bt, expr address, expr data) -> expr { - switch (bt) { - case builtin_type::short_t: - case builtin_type::ushort_t: - return intel_sub_group_block_write_us(std::move(address), std::move(data)); - case builtin_type::int_t: - case builtin_type::uint_t: - case builtin_type::float_t: - return intel_sub_group_block_write_ui(std::move(address), std::move(data)); - case builtin_type::long_t: - case builtin_type::ulong_t: - case builtin_type::double_t: - return intel_sub_group_block_write_ul(std::move(address), std::move(data)); - default: - break; - } - return address[get_sub_group_local_id()] = std::move(data); - }; - auto bt = block_rw_base_type(); - address = cast(pointer_to(clir::data_type(bt, as)), std::move(address)); - if (bt != base_type()) { - data = as_type(bt, std::move(data)); - } - return make_write(bt, std::move(address), std::move(data)); -} - -} // namespace 
tinytc diff --git a/src/precision_helper.hpp b/src/precision_helper.hpp deleted file mode 100644 index 32445697..00000000 --- a/src/precision_helper.hpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef PRECISION_HELPER_20230214_HPP -#define PRECISION_HELPER_20230214_HPP - -#include "tinytc/types.hpp" - -#include "clir/builtin_type.hpp" -#include "clir/data_type.hpp" -#include "clir/expr.hpp" - -namespace tinytc { - -class precision_helper { - public: - precision_helper(scalar_type ty); - clir::builtin_type base_type() const; - clir::builtin_type block_rw_base_type() const; - clir::expr as_type(clir::builtin_type ty, clir::expr e) const; - short bits() const; - clir::data_type type(clir::address_space as = clir::address_space::generic_t) const; - clir::data_type type(short size, clir::address_space as = clir::address_space::generic_t) const; - clir::expr constant(double value) const; - clir::expr zero() const; - clir::expr sub_group_block_read(clir::expr address, - clir::address_space as = clir::address_space::generic_t) const; - clir::expr sub_group_block_write(clir::expr address, clir::expr data, - clir::address_space as = clir::address_space::generic_t) const; - - private: - scalar_type ty_; -}; - -} // namespace tinytc - -#endif // PRECISION_HELPER_20230214_HPP diff --git a/src/recipe.cpp b/src/recipe.cpp index 75c9f18b..b5d1df3b 100644 --- a/src/recipe.cpp +++ b/src/recipe.cpp @@ -3,39 +3,24 @@ #include "recipe.hpp" #include "error.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/types.hpp" +#include "number_dispatch.hpp" +#include "tinytc/core.h" #include +#include #include namespace tinytc { -template bool is_argument_zero(std::size_t arg_size, const void *arg_value) { - T v; - memcpy(&v, arg_value, std::min(sizeof(v), arg_size)); - return v == T(0); -} -auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_value) -> bool { - switch (type) { - case 
scalar_type::index: - return is_argument_zero(arg_size, arg_value); - case scalar_type::i8: - return is_argument_zero(arg_size, arg_value); - case scalar_type::i16: - return is_argument_zero(arg_size, arg_value); - case scalar_type::i32: - return is_argument_zero(arg_size, arg_value); - case scalar_type::i64: - return is_argument_zero(arg_size, arg_value); - case scalar_type::f32: - return is_argument_zero(arg_size, arg_value); - case scalar_type::f64: - return is_argument_zero(arg_size, arg_value); - case scalar_type::i1: - break; - }; - throw status::invalid_arguments; +auto is_argument_zero(tinytc_type_t ty, std::size_t arg_size, const void *arg_value) -> bool { + return dispatch_number_to_native( + ty, + [](std::size_t arg_size, const void *arg_value) { + T v; + memcpy(&v, arg_value, std::min(sizeof(v), arg_size)); + return v == T(0); + }, + arg_size, arg_value); } } // namespace tinytc @@ -46,16 +31,14 @@ tinytc_status_t tinytc_recipe_get_prog(const_tinytc_recipe_t recipe, tinytc_prog if (recipe == nullptr || prg == nullptr) { return tinytc_status_invalid_arguments; } - return tinytc::exception_to_status_code( - [&] { *prg = tinytc::prog(recipe->get_program()).release(); }); + return tinytc::exception_to_status_code([&] { *prg = recipe->get_program(); }); } -tinytc_status_t tinytc_recipe_get_source(const_tinytc_recipe_t recipe, tinytc_source_t *src) { - if (recipe == nullptr || src == nullptr) { +tinytc_status_t tinytc_recipe_get_binary(const_tinytc_recipe_t recipe, tinytc_binary_t *bin) { + if (recipe == nullptr || bin == nullptr) { return tinytc_status_invalid_arguments; } - return tinytc::exception_to_status_code( - [&] { *src = tinytc::source(recipe->get_source()).release(); }); + return tinytc::exception_to_status_code([&] { *bin = recipe->get_binary(); }); } tinytc_status_t tinytc_recipe_release(tinytc_recipe_t obj) { @@ -82,8 +65,7 @@ tinytc_status_t tinytc_recipe_handler_get_recipe(const_tinytc_recipe_handler_t h if (handler == nullptr || recipe == 
nullptr) { return tinytc_status_invalid_arguments; } - return tinytc::exception_to_status_code( - [&] { *recipe = tinytc::recipe(handler->get_recipe()).release(); }); + return tinytc::exception_to_status_code([&] { *recipe = handler->get_recipe(); }); } tinytc_status_t tinytc_recipe_handler_release(tinytc_recipe_handler_t obj) { diff --git a/src/recipe.hpp b/src/recipe.hpp index 879245a0..0a055617 100644 --- a/src/recipe.hpp +++ b/src/recipe.hpp @@ -5,7 +5,6 @@ #define RECIPE_20240419_HPP #include "reference_counted.hpp" -#include "tinytc/tinytc.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" @@ -14,40 +13,42 @@ #include namespace tinytc { -auto is_argument_zero(scalar_type type, std::size_t arg_size, const void *arg_value) -> bool; +auto is_argument_zero(tinytc_type_t ty, std::size_t arg_size, const void *arg_value) -> bool; } // namespace tinytc struct tinytc_recipe : tinytc::reference_counted { public: - inline tinytc_recipe(tinytc::prog prg, tinytc::source src) - : prg_(std::move(prg)), src_(std::move(src)) {} + inline tinytc_recipe(tinytc::shared_handle prg, + tinytc::shared_handle bin) + : prg_(std::move(prg)), bin_(std::move(bin)) {} virtual ~tinytc_recipe() = default; - inline auto get_program() const -> tinytc::prog const & { return prg_; } - inline auto get_source() const -> tinytc::source const & { return src_; } + inline auto get_program() const -> tinytc_prog_t { return prg_.get(); } + inline auto get_binary() const -> tinytc_binary_t { return bin_.get(); } virtual auto num_kernels() const -> int = 0; virtual auto kernel_name(int kernel_num) const -> char const * = 0; private: - tinytc::prog prg_; - tinytc::source src_; + tinytc::shared_handle prg_; + tinytc::shared_handle bin_; }; struct tinytc_recipe_handler : tinytc::reference_counted { public: - inline tinytc_recipe_handler(tinytc::recipe recipe) : recipe_(std::move(recipe)) {} + inline tinytc_recipe_handler(tinytc::shared_handle recipe) + : recipe_(std::move(recipe)) {} virtual 
~tinytc_recipe_handler() = default; - inline auto get_recipe() const -> tinytc::recipe const & { return recipe_; } + inline auto get_recipe() const -> tinytc_recipe_t { return recipe_.get(); } virtual void active_kernel(int kernel_num) = 0; virtual void arg(std::uint32_t arg_index, std::size_t arg_size, const void *arg_value) = 0; - virtual void mem_arg(std::uint32_t arg_index, const void *value, tinytc_mem_type_t type) = 0; + virtual void mem_arg(std::uint32_t arg_index, const void *value, tinytc_mem_type_t ty) = 0; virtual void howmany(std::int64_t num) = 0; private: - tinytc::recipe recipe_; + tinytc::shared_handle recipe_; }; #endif // RECIPE_20240419_HPP diff --git a/src/recipe/small_gemm_batched.cpp b/src/recipe/small_gemm_batched.cpp index 7ace3283..683a2db1 100644 --- a/src/recipe/small_gemm_batched.cpp +++ b/src/recipe/small_gemm_batched.cpp @@ -3,21 +3,22 @@ #include "small_gemm_batched.hpp" #include "error.hpp" -#include "parser.hpp" +#include "node/type.hpp" +#include "number.hpp" #include "recipe.hpp" -#include "reference_counted.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" +#include "tinytc/builder.hpp" +#include "tinytc/core.h" +#include "tinytc/core.hpp" #include "tinytc/types.h" #include "tinytc/types.hpp" -#include "util.hpp" +#include "util/casting.hpp" +#include #include #include #include #include #include -#include namespace tinytc { @@ -32,8 +33,10 @@ auto small_gemm_batched_kernel_name(small_gemm_batched_kernel k) -> char const * } throw status::invalid_arguments; } -small_gemm_batched_recipe::small_gemm_batched_recipe(prog prg, source src, scalar_type ty) - : ::tinytc_recipe(std::move(prg), std::move(src)), ty_(ty) {} +small_gemm_batched_recipe::small_gemm_batched_recipe(shared_handle prg, + shared_handle bin, + tinytc_type_t ty) + : ::tinytc_recipe(std::move(prg), std::move(bin)), ty_(ty) {} auto small_gemm_batched_recipe::num_kernels() const -> int { return static_cast(small_gemm_batched_kernel::num_kernels); } @@ -46,24 
+49,21 @@ auto small_gemm_batched_recipe::kernel_name(int kernel_num) const -> char const using namespace tinytc; extern "C" { -tinytc_status_t -tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_core_info_t info, - tinytc_scalar_type_t ty, tinytc_transpose_t tA, - tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, - int64_t ldA, int64_t strideA, int64_t ldB, int64_t strideB, - int64_t ldC, int64_t strideC, tinytc_source_context_t ctx) { - if (recipe == nullptr || info == nullptr || M == TINYTC_DYNAMIC || N == TINYTC_DYNAMIC || - K == TINYTC_DYNAMIC || ldA == TINYTC_DYNAMIC || strideA == TINYTC_DYNAMIC || - ldB == TINYTC_DYNAMIC || strideB == TINYTC_DYNAMIC || ldC == TINYTC_DYNAMIC || - strideC == TINYTC_DYNAMIC) { +tinytc_status_t tinytc_recipe_small_gemm_batched_create( + tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_type_t ty, tinytc_transpose_t tA, + tinytc_transpose_t tB, int64_t M, int64_t N, int64_t K, int64_t ldA, int64_t strideA, + int64_t ldB, int64_t strideB, int64_t ldC, int64_t strideC) { + if (recipe == nullptr || info == nullptr || ty == nullptr || M == TINYTC_DYNAMIC || + N == TINYTC_DYNAMIC || K == TINYTC_DYNAMIC || ldA == TINYTC_DYNAMIC || + strideA == TINYTC_DYNAMIC || ldB == TINYTC_DYNAMIC || strideB == TINYTC_DYNAMIC || + ldC == TINYTC_DYNAMIC || strideC == TINYTC_DYNAMIC) { return tinytc_status_invalid_arguments; } + auto ctx = ty->context(); std::int32_t source_id = 0; - if (ctx) { - TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "small gemm batched recipe", "", &source_id)); - } + TINYTC_CHECK_STATUS( + tinytc_compiler_context_add_source(ctx, "recipe/small_gemm_batched.cpp", "", &source_id)); auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { auto l = location{}; @@ -74,58 +74,77 @@ tinytc_recipe_small_gemm_batched_create(tinytc_recipe_t *recipe, const_tinytc_co ++l.end.column; return l; }; - - auto const selA = [&](std::int64_t N1, 
std::int64_t N2) { - return tA == tinytc_transpose_T ? N2 : N1; - }; - auto const selB = [&](std::int64_t N1, std::int64_t N2) { - return tB == tinytc_transpose_T ? N2 : N1; + auto const make_static_sizes = [](transpose t, std::int64_t A, std::int64_t B) { + auto s = std::array{A, B, 0}; + if (t == transpose::T) { + std::swap(s[0], s[1]); + } + return s; }; + return exception_to_status_code( [&] { - auto const ty_ = enum_cast(ty); + auto const index_ty = get(ctx); + auto const void_ty = get(ctx); auto const tA_ = enum_cast(tA); auto const tB_ = enum_cast(tB); - auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = fb.argument(make_memref(ty_, {selA(M, K), selA(K, M), dynamic}, - {1, ldA, strideA}, my_loc()), - "A", my_loc()); - auto B = fb.argument(make_memref(ty_, {selB(K, N), selB(N, K), dynamic}, - {1, ldB, strideB}, my_loc()), - "B", my_loc()); - auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = fb.argument(make_memref(ty_, {M, N, dynamic}, {1, ldC, strideC}, my_loc()), - "C", my_loc()); - - auto beta = is_beta_nonzero ? 
std::move(beta_arg) : make_imm(0.0, ty_, my_loc()); - fb.body( - [&](region_builder &bb) { - auto gid = bb.add(make_group_id(my_loc())); - auto offsets = std::vector{make_index(0, my_loc()), - make_index(0, my_loc()), gid}; - auto size = std::vector{make_dynamic(my_loc()), - make_dynamic(my_loc()), value{}}; - auto a = bb.add(make_subview(A, offsets, size, my_loc())); - auto b = bb.add(make_subview(B, offsets, size, my_loc())); - auto c = bb.add(make_subview(C, offsets, size, my_loc())); - bb.add(make_gemm(tA_, tB_, false, alpha, std::move(a), std::move(b), beta, - std::move(c), my_loc())); - }, - my_loc()); + auto const kernel = [&](char const *name, bool is_beta_nonzero) { + auto const static_offsets = std::array{0, 0, dynamic}; + auto const A_static_sizes = make_static_sizes(tA_, M, K); + auto const B_static_sizes = make_static_sizes(tB_, K, N); + auto const C_static_sizes = make_static_sizes(transpose::N, M, N); + + const auto A_shape = std::array{A_static_sizes[0], A_static_sizes[1], dynamic}; + const auto B_shape = std::array{B_static_sizes[0], B_static_sizes[1], dynamic}; + const auto C_shape = std::array{M, N, dynamic}; + const auto A_stride = std::array{std::int64_t{1}, ldA, strideA}; + const auto B_stride = std::array{std::int64_t{1}, ldB, strideB}; + const auto C_stride = std::array{std::int64_t{1}, ldC, strideC}; + auto A_ty = get(ty, A_shape, A_stride, address_space::global); + auto B_ty = get(ty, B_shape, B_stride, address_space::global); + auto C_ty = get(ty, C_shape, C_stride, address_space::global); + auto f = create_func(name, {ty, A_ty, B_ty, ty, C_ty}, void_ty, my_loc()); + auto fn_body = get_body(f.get()); + auto params = std::array{}; + get_parameters(fn_body, params); + set_name(params[0], "alpha"); + set_name(params[1], "A"); + set_name(params[2], "B"); + set_name(params[3], "beta"); + set_name(params[4], "C"); + + auto bb = region_builder{fn_body}; + + auto gid = bb.create(comp3::x, index_ty, my_loc()); + auto at = get(ty, 
array_view(A_static_sizes.data(), 2), + array_view{A_stride.data(), 2}, address_space::global); + auto bt = get(ty, array_view(B_static_sizes.data(), 2), + array_view{B_stride.data(), 2}, address_space::global); + auto ct = get(ty, array_view(C_static_sizes.data(), 2), + array_view{C_stride.data(), 2}, address_space::global); + auto empty = array_view{}; + auto a = bb.create(static_offsets, A_static_sizes, params[1], + array_view{gid}, empty, at, my_loc()); + auto b = bb.create(static_offsets, B_static_sizes, params[2], + array_view{gid}, empty, bt, my_loc()); + auto c = bb.create(static_offsets, C_static_sizes, params[4], + array_view{gid}, empty, ct, my_loc()); + auto beta = is_beta_nonzero ? params[3] : bb.constant_zero(ty, my_loc()); + bb.create(false, tA_, tB_, params[0], std::move(a), std::move(b), beta, + std::move(c), my_loc()); + + return f; }; - auto pb = program_builder{}; - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); - auto p = pb.get_product(my_loc()); - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); - *recipe = std::make_unique(std::move(p), source(src), ty_) + auto p = create_prog(ctx, my_loc()); + add_function( + p.get(), + kernel(small_gemm_batched_kernel_name(small_gemm_batched_kernel::gemm), true)); + add_function(p.get(), kernel(small_gemm_batched_kernel_name( + small_gemm_batched_kernel::gemm_beta0), + false)); + auto bin = compile_to_spirv_and_assemble(p.get(), info); + *recipe = std::make_unique(std::move(p), std::move(bin), ty) .release(); }, ctx); @@ -138,7 +157,7 @@ tinytc_status_t tinytc_recipe_small_gemm_batched_set_args( if (handler == nullptr) { return tinytc_status_invalid_arguments; } - auto recipe = dynamic_cast(handler->get_recipe().get()); + 
auto recipe = dynamic_cast(handler->get_recipe()); if (recipe == nullptr) { return tinytc_status_invalid_arguments; } diff --git a/src/recipe/small_gemm_batched.hpp b/src/recipe/small_gemm_batched.hpp index 0ffdc2da..e24a98d5 100644 --- a/src/recipe/small_gemm_batched.hpp +++ b/src/recipe/small_gemm_batched.hpp @@ -5,24 +5,26 @@ #define SMALL_GEMM_BATCHED_20240419_HPP #include "../recipe.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" +#include "tinytc/types.h" namespace tinytc { +template class shared_handle; + enum class small_gemm_batched_kernel : int { gemm = 0, gemm_beta0 = 1, num_kernels = 2 }; auto small_gemm_batched_kernel_name(small_gemm_batched_kernel k) -> char const *; struct small_gemm_batched_recipe : ::tinytc_recipe { public: - small_gemm_batched_recipe(prog prg, source src, scalar_type ty); + small_gemm_batched_recipe(shared_handle prg, shared_handle bin, + tinytc_type_t ty); auto num_kernels() const -> int override; auto kernel_name(int kernel_num) const -> char const * override; - inline auto ty() const -> scalar_type { return ty_; } + inline auto ty() const -> tinytc_type_t { return ty_; } private: - scalar_type ty_; + tinytc_type_t ty_; }; } // namespace tinytc diff --git a/src/recipe/tall_and_skinny.cpp b/src/recipe/tall_and_skinny.cpp index 90b2505f..170e7c86 100644 --- a/src/recipe/tall_and_skinny.cpp +++ b/src/recipe/tall_and_skinny.cpp @@ -4,13 +4,15 @@ #include "tall_and_skinny.hpp" #include "device_info.hpp" #include "error.hpp" -#include "parser.hpp" +#include "node/type.hpp" +#include "number.hpp" #include "recipe.hpp" -#include "reference_counted.hpp" #include "tiling.hpp" -#include "tinytc/tinytc.h" +#include "tinytc/builder.hpp" +#include "tinytc/core.h" +#include "tinytc/core.hpp" #include "tinytc/types.h" -#include "util.hpp" +#include "tinytc/types.hpp" #include #include @@ -19,7 +21,6 @@ #include #include #include -#include namespace tinytc { @@ -34,10 +35,11 @@ auto 
tall_and_skinny_kernel_name(tall_and_skinny_kernel k) -> char const * { } throw status::invalid_arguments; } -tall_and_skinny_recipe::tall_and_skinny_recipe(prog prg, source src, scalar_type ty, std::int64_t M, - std::int64_t ldA, std::int64_t ldB, std::int64_t ldC, - std::int32_t M_block_size) - : ::tinytc_recipe(std::move(prg), std::move(src)), ty_(ty), M_dyn_(is_dynamic_value(M)), +tall_and_skinny_recipe::tall_and_skinny_recipe(shared_handle prg, + shared_handle bin, tinytc_type_t ty, + std::int64_t M, std::int64_t ldA, std::int64_t ldB, + std::int64_t ldC, std::int32_t M_block_size) + : ::tinytc_recipe(std::move(prg), std::move(bin)), ty_(ty), M_dyn_(is_dynamic_value(M)), ldA_dyn_(is_dynamic_value(ldA)), ldB_dyn_(is_dynamic_value(ldB)), ldC_dyn_(is_dynamic_value(ldC)), M_block_size_(M_block_size) {} auto tall_and_skinny_recipe::num_kernels() const -> int { @@ -54,27 +56,26 @@ using namespace tinytc; extern "C" { tinytc_status_t tinytc_recipe_tall_and_skinny_create(tinytc_recipe_t *recipe, const_tinytc_core_info_t info, - tinytc_scalar_type_t ty, int64_t N, int64_t K, - int32_t M_block_size, - tinytc_source_context_t ctx) { + tinytc_type_t ty, int64_t N, int64_t K, + int32_t M_block_size) { return tinytc_recipe_tall_and_skinny_create_specialized(recipe, info, ty, TINYTC_DYNAMIC, N, K, TINYTC_DYNAMIC, TINYTC_DYNAMIC, - TINYTC_DYNAMIC, M_block_size, ctx); + TINYTC_DYNAMIC, 0, 0, 0, M_block_size); } tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( - tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_scalar_type_t ty, int64_t M, - int64_t N, int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t M_block_size, - tinytc_source_context_t ctx) { - if (recipe == nullptr || info == nullptr || N == TINYTC_DYNAMIC || K == TINYTC_DYNAMIC) { + tinytc_recipe_t *recipe, const_tinytc_core_info_t info, tinytc_type_t ty, int64_t M, int64_t N, + int64_t K, int64_t ldA, int64_t ldB, int64_t ldC, int32_t alignA, int32_t alignB, + int32_t alignC, 
int32_t M_block_size) { + if (recipe == nullptr || info == nullptr || ty == nullptr || N == TINYTC_DYNAMIC || + K == TINYTC_DYNAMIC) { return tinytc_status_invalid_arguments; } + auto ctx = ty->context(); std::int32_t source_id = 0; - if (ctx) { - TINYTC_CHECK_STATUS( - tinytc_source_context_add_source(ctx, "tall and skinny recipe", "", &source_id)); - } + TINYTC_CHECK_STATUS( + tinytc_compiler_context_add_source(ctx, "recipe/tall_and_skinny.cpp", "", &source_id)); auto const my_loc = [&](std::source_location const loc = std::source_location::current()) { auto l = location{}; @@ -86,16 +87,18 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( return l; }; - if (M_block_size == 0u) { + if (M_block_size == 0) { TINYTC_CHECK_STATUS(tinytc_recipe_tall_and_skinny_suggest_block_size(info, &M_block_size)); } return exception_to_status_code( [&] { - auto const ty_ = enum_cast(ty); + auto const bool_ty = get(ctx); + auto const void_ty = get(ctx); + auto const index_ty = get(ctx); - auto const shapes = std::vector{blas_shape{ty_, {M_block_size, N}}}; - auto [sgs, tiling] = suggest_subgroup_size_and_tiling(shapes, *info); + auto const bshape = blas_shape{ty, ty, ty, {M_block_size, N}, true}; + auto [sgs, tiling] = suggest_subgroup_size_and_tiling(array_view(bshape), *info); // We want to avoid working on too many columns in parallel as there is a high // chance to trash the cache due to the large pitch @@ -103,66 +106,109 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_create_specialized( tiling[1] /= 2; } - auto const body = [&](region_builder &bb, value &alpha, value &A, value &B, value &beta, - value &C) { - auto const gemm = [&](region_builder &bb, std::vector const &offsets, - value const &block_size) { - auto a = bb.add( - make_subview(A, offsets, {block_size, make_index(K, my_loc())}, my_loc())); - auto c = bb.add( - make_subview(C, offsets, {block_size, make_index(N, my_loc())}, my_loc())); - bb.add(make_gemm(transpose::N, transpose::N, false, 
alpha, a, B, beta, c, - my_loc())); - }; + auto const A_stride = std::array{1, ldA}; + auto const B_stride = std::array{1, ldB}; + auto const C_stride = std::array{1, ldC}; - auto const block_size_imm = make_index(M_block_size, my_loc()); - auto gid = bb.add(make_group_id(my_loc())); - auto m = bb.add( - make_arith(arithmetic::mul, gid, make_index(M_block_size, my_loc()), my_loc())); - auto const offsets = std::vector{m, make_index(0, my_loc())}; + auto const body = [&](region_builder &bb, tinytc_value_t alpha, tinytc_value_t A, + tinytc_value_t B, bool is_beta_nonzero, tinytc_value_t beta_arg, + tinytc_value_t C) { + auto c_M_block_size = bb.create(M_block_size, index_ty, my_loc()); + auto gid = bb.create(comp3::x, index_ty, my_loc()); + auto m = bb.create(gid, c_M_block_size, get_type(gid), my_loc()); + auto beta = is_beta_nonzero ? beta_arg : bb.constant_zero(ty, my_loc()); + + auto const static_offsets = std::array{dynamic, 0}; + auto const offsets = array_view(m); + + auto const static_gemm = [&](region_builder &bb) { + auto const A_static_sizes = std::array{M_block_size, K}; + auto const C_static_sizes = std::array{M_block_size, N}; + auto at = get(ty, A_static_sizes, A_stride, address_space::global); + auto ct = get(ty, C_static_sizes, C_stride, address_space::global); + auto a = bb.create(static_offsets, A_static_sizes, A, offsets, + array_view{}, at, my_loc()); + auto c = bb.create(static_offsets, C_static_sizes, C, offsets, + array_view{}, ct, my_loc()); + bb.create(false, transpose::N, transpose::N, alpha, a, B, beta, c, + my_loc()); + }; + auto const dynamic_gemm = [&](region_builder &bb, tinytc_value_t dyn_block_size) { + auto const A_static_sizes = std::array{dynamic, K}; + auto const C_static_sizes = std::array{dynamic, N}; + auto const sizes = array_view(dyn_block_size); + auto at = get(ty, A_static_sizes, A_stride, address_space::global); + auto ct = get(ty, C_static_sizes, C_stride, address_space::global); + auto a = bb.create(static_offsets, 
A_static_sizes, A, offsets, + sizes, at, my_loc()); + auto c = bb.create(static_offsets, C_static_sizes, C, offsets, + sizes, ct, my_loc()); + bb.create(false, transpose::N, transpose::N, alpha, a, B, beta, c, + my_loc()); + }; if (!is_dynamic_value(M) && M % M_block_size == 0) { - gemm(bb, offsets, block_size_imm); + static_gemm(bb); } else { - auto M_val = - is_dynamic_value(M) ? bb.add(make_size(C, 0, my_loc())) : make_index(M); - auto M_val_sub_m = bb.add(make_arith(arithmetic::sub, M_val, m, my_loc())); - auto cond = bb.add(make_cmp(cmp_condition::lt, M_val_sub_m, - make_index(M_block_size, my_loc()), my_loc())); - auto const dynamic_imm = make_dynamic(my_loc()); + + auto M_val = bb.create(0, C, index_ty, my_loc()); + auto M_val_sub_m = bb.create(M_val, m, get_type(m), my_loc()); + auto cond = + bb.create(M_val_sub_m, c_M_block_size, bool_ty, my_loc()); bb.ifelse( - cond, [&](region_builder &bb) { gemm(bb, offsets, dynamic_imm); }, - [&](region_builder &bb) { gemm(bb, offsets, block_size_imm); }, {}, - my_loc()); + cond, [&](region_builder &bb) { dynamic_gemm(bb, M_val_sub_m); }, + [&](region_builder &bb) { static_gemm(bb); }, {}, my_loc()); } }; - auto const kernel = [&](function_builder &fb, bool is_beta_nonzero) { - auto alpha = fb.argument(make_scalar(ty_, my_loc()), "alpha", my_loc()); - auto A = fb.argument(make_memref(ty_, {M, K}, {1, ldA}, my_loc()), "A", my_loc()); - auto B = fb.argument(make_memref(ty_, {K, N}, {1, ldB}, my_loc()), "B", my_loc()); - auto beta_arg = fb.argument(make_scalar(ty_, my_loc()), "beta", my_loc()); - auto C = fb.argument(make_memref(ty_, {M, N}, {1, ldC}, my_loc()), "C", my_loc()); - fb.subgroup_size(sgs); + auto const kernel = [&](char const *name, bool is_beta_nonzero) { + auto A_shape = std::array{M, K}; + auto B_shape = std::array{K, N}; + auto C_shape = std::array{M, N}; + auto A_ty = get(ty, A_shape, A_stride, address_space::global); + auto B_ty = get(ty, B_shape, B_stride, address_space::global); + auto C_ty = get(ty, 
C_shape, C_stride, address_space::global); + auto f = create_func(name, {ty, A_ty, B_ty, ty, C_ty}, void_ty, my_loc()); + + auto alignments = std::array, 3u>{ + {{1, alignA}, {2, alignB}, {4, alignC}}}; + auto align_attr = tinytc_named_attr_t{get(ctx, "align"), nullptr}; + for (auto &[param_no, alignment] : alignments) { + if (alignment > 0) { + align_attr.attr = get(ctx, alignment); + set_parameter_attr(f.get(), param_no, + get_dictionary_attr_with_sorted(ctx, align_attr)); + } + } + + auto fn_body = get_body(f.get()); + auto params = std::array{}; + get_parameters(fn_body, params); + set_name(params[0], "alpha"); + set_name(params[1], "A"); + set_name(params[2], "B"); + set_name(params[3], "beta"); + set_name(params[4], "C"); auto const wgs = tiling.work_group_size(sgs); - fb.work_group_size(wgs[0], wgs[1]); + auto const wgs_attr = tinytc_named_attr_t{ + get(ctx, "work_group_size"), + get(ctx, array_view{get(ctx, wgs[0]), + get(ctx, wgs[1])})}; + set_attr(f.get(), get_dictionary_attr_with_sorted(ctx, wgs_attr)); - auto beta = is_beta_nonzero ? 
beta_arg : make_imm(0.0, ty_, my_loc()); - fb.body([&](region_builder &bb) { body(bb, alpha, A, B, beta, C); }, my_loc()); + auto bb = region_builder{fn_body}; + body(bb, params[0], params[1], params[2], is_beta_nonzero, params[3], params[4]); + return f; }; - auto pb = program_builder{}; - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), - [&](function_builder &fb) { kernel(fb, true); }, my_loc()); - pb.create( - tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), - [&](function_builder &fb) { kernel(fb, false); }, my_loc()); - - auto p = pb.get_product(my_loc()); - tinytc_source_t src; - CHECK_STATUS(tinytc_prog_compile_to_opencl(&src, p.get(), info, ctx)); - *recipe = std::make_unique(std::move(p), source(src), ty_, M, + auto p = create_prog(ctx, my_loc()); + add_function(p.get(), + kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm), true)); + add_function( + p.get(), + kernel(tall_and_skinny_kernel_name(tall_and_skinny_kernel::gemm_beta0), false)); + auto bin = compile_to_spirv_and_assemble(p.get(), info); + *recipe = std::make_unique(std::move(p), std::move(bin), ty, M, ldA, ldB, ldC, M_block_size) .release(); }, @@ -174,9 +220,12 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_suggest_block_size(const_tinytc_co if (info == nullptr || M_block_size == nullptr) { return tinytc_status_invalid_arguments; } - - return tinytc::exception_to_status_code( - [&] { *M_block_size = std::min(128, info->minmax_work_group_size()); }); + return tinytc::exception_to_status_code([&] { + if (info->minmax_work_group_size() <= 0) { + throw tinytc::status::invalid_core_info; + } + *M_block_size = std::min(128, info->minmax_work_group_size()); + }); } tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( @@ -187,7 +236,7 @@ tinytc_status_t tinytc_recipe_tall_and_skinny_set_args( if (handler == nullptr) { return tinytc_status_invalid_arguments; } - auto recipe = dynamic_cast(handler->get_recipe().get()); + auto recipe = 
dynamic_cast(handler->get_recipe()); if (recipe == nullptr) { return tinytc_status_invalid_arguments; } diff --git a/src/recipe/tall_and_skinny.hpp b/src/recipe/tall_and_skinny.hpp index dd2aaca4..ef099c00 100644 --- a/src/recipe/tall_and_skinny.hpp +++ b/src/recipe/tall_and_skinny.hpp @@ -5,24 +5,26 @@ #define TALL_AND_SKINNY_20240422_HPP #include "../recipe.hpp" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.hpp" +#include "tinytc/types.h" #include namespace tinytc { +template class shared_handle; + enum class tall_and_skinny_kernel : int { gemm = 0, gemm_beta0 = 1, num_kernels = 2 }; auto tall_and_skinny_kernel_name(tall_and_skinny_kernel k) -> char const *; struct tall_and_skinny_recipe : ::tinytc_recipe { public: - tall_and_skinny_recipe(prog prg, source src, scalar_type ty, std::int64_t M, std::int64_t ldA, - std::int64_t ldB, std::int64_t ldC, std::int32_t M_block_size); + tall_and_skinny_recipe(shared_handle prg, shared_handle bin, + tinytc_type_t ty, std::int64_t M, std::int64_t ldA, std::int64_t ldB, + std::int64_t ldC, std::int32_t M_block_size); auto num_kernels() const -> int override; auto kernel_name(int kernel_num) const -> char const * override; - inline auto ty() const -> scalar_type { return ty_; } + inline auto ty() const -> tinytc_type_t { return ty_; } inline auto M_block_size() const -> std::int32_t { return M_block_size_; } inline auto is_M_dynamic() const -> bool { return M_dyn_; } @@ -31,7 +33,7 @@ struct tall_and_skinny_recipe : ::tinytc_recipe { inline auto is_ldC_dynamic() const -> bool { return ldC_dyn_; } private: - scalar_type ty_; + tinytc_type_t ty_; bool M_dyn_, ldA_dyn_, ldB_dyn_, ldC_dyn_; std::int32_t M_block_size_; }; diff --git a/src/region.cpp b/src/region.cpp deleted file mode 100644 index ea9a507d..00000000 --- a/src/region.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "error.hpp" -#include "location.hpp" -#include 
"node/region_node.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/tinytc.hpp" -#include "tinytc/types.h" - -#include -#include -#include -#include - -using namespace tinytc; - -extern "C" { - -tinytc_status_t tinytc_region_create(tinytc_region_t *reg, uint32_t instruction_list_size, - tinytc_inst_t *instruction_list, - const tinytc_location_t *loc) { - if (reg == nullptr || (instruction_list_size > 0 && instruction_list == nullptr)) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - auto inst_vec = std::vector(); - inst_vec.reserve(instruction_list_size); - for (uint32_t i = 0; i < instruction_list_size; ++i) { - inst_vec.emplace_back(inst(instruction_list[i], true)); - } - *reg = std::make_unique(std::move(inst_vec), get_optional(loc)).release(); - }); -} - -tinytc_status_t tinytc_region_release(tinytc_region_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); - if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_region_retain(tinytc_region_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/required_extensions.cpp b/src/required_extensions.cpp deleted file mode 100644 index ea4d3edc..00000000 --- a/src/required_extensions.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "required_extensions.hpp" - -#include -#include - -#include - -namespace tinytc { - -auto ext_list(std::vector const &ext) -> std::vector { - auto result = std::vector{}; - result.reserve(ext.size() + 1); - for (auto const &e : ext) { - result.emplace_back(clir::to_string(e)); - } - result.emplace_back("cl_khr_fp64"); - return result; -} - -auto required_extensions(clir::func f) -> std::vector { - return ext_list(clir::get_required_extensions(std::move(f))); -} -auto 
required_extensions(clir::prog p) -> std::vector { - return ext_list(clir::get_required_extensions(std::move(p))); -} - -} // namespace tinytc diff --git a/src/required_extensions.hpp b/src/required_extensions.hpp deleted file mode 100644 index e4c5a23a..00000000 --- a/src/required_extensions.hpp +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef REQUIRED_EXTENSIONS_20240416_HPP -#define REQUIRED_EXTENSIONS_20240416_HPP - -#include - -#include -#include - -namespace tinytc { - -auto required_extensions(clir::func f) -> std::vector; -auto required_extensions(clir::prog p) -> std::vector; - -} // namespace tinytc - -#endif // REQUIRED_EXTENSIONS_20240416_HPP diff --git a/src/scalar_type.cpp b/src/scalar_type.cpp deleted file mode 100644 index af4ef4d1..00000000 --- a/src/scalar_type.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "scalar_type.hpp" -#include "tinytc/tinytc.h" -#include "tinytc/types.h" -#include "tinytc/types.hpp" - -#include - -namespace tinytc { - -bool is_floating_type(scalar_type ty) { - switch (ty) { - case scalar_type::f32: - case scalar_type::f64: - return true; - default: - break; - } - return false; -} - -clir::builtin_type to_clir_builtin_ty(scalar_type ty) { - switch (ty) { - case scalar_type::i1: - return clir::builtin_type::bool_t; - case scalar_type::i8: - return clir::builtin_type::char_t; - case scalar_type::i16: - return clir::builtin_type::short_t; - case scalar_type::i32: - return clir::builtin_type::int_t; - case scalar_type::i64: - return clir::builtin_type::long_t; - case scalar_type::index: - return clir::builtin_type::long_t; - case scalar_type::f32: - return clir::builtin_type::float_t; - case scalar_type::f64: - return clir::builtin_type::double_t; - } - return clir::builtin_type::void_t; -} - -clir::data_type to_clir_ty(scalar_type ty, clir::address_space as, 
clir::type_qualifier q) { - return clir::data_type(to_clir_builtin_ty(ty), as, q); -} - -clir::builtin_type to_clir_atomic_builtin_ty(scalar_type ty) { - switch (ty) { - case scalar_type::i32: - return clir::builtin_type::atomic_int_t; - case scalar_type::i64: - return clir::builtin_type::atomic_long_t; - case scalar_type::index: - return clir::builtin_type::atomic_long_t; - case scalar_type::f32: - return clir::builtin_type::atomic_float_t; - case scalar_type::f64: - return clir::builtin_type::atomic_double_t; - default: - break; - } - return clir::builtin_type::void_t; -} - -clir::data_type to_clir_atomic_ty(scalar_type ty, clir::address_space as, clir::type_qualifier q) { - return clir::data_type(to_clir_atomic_builtin_ty(ty), as, q); -} - -} // namespace tinytc - -char const *tinytc_scalar_type_to_string(tinytc_scalar_type_t ty) { - switch (ty) { - case tinytc_scalar_type_i1: - return "i1"; - case tinytc_scalar_type_i8: - return "i8"; - case tinytc_scalar_type_i16: - return "i16"; - case tinytc_scalar_type_i32: - return "i32"; - case tinytc_scalar_type_i64: - return "i64"; - case tinytc_scalar_type_index: - return "index"; - case tinytc_scalar_type_f32: - return "f32"; - case tinytc_scalar_type_f64: - return "f64"; - } - return "unknown"; -} -size_t tinytc_scalar_type_size(tinytc_scalar_type_t ty) { - switch (ty) { - case tinytc_scalar_type_i1: - case tinytc_scalar_type_i8: - return 1; - case tinytc_scalar_type_i16: - return 2; - case tinytc_scalar_type_i32: - case tinytc_scalar_type_f32: - return 4; - case tinytc_scalar_type_i64: - case tinytc_scalar_type_index: - case tinytc_scalar_type_f64: - return 8; - } - return 0; -} diff --git a/src/scalar_type.hpp b/src/scalar_type.hpp deleted file mode 100644 index ff0428c1..00000000 --- a/src/scalar_type.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause -// -#ifndef SCALAR_TYPE_20240411_HPP -#define SCALAR_TYPE_20240411_HPP - -#include 
"tinytc/types.hpp" - -#include -#include - -namespace tinytc { - -bool is_floating_type(scalar_type ty); -clir::builtin_type to_clir_builtin_ty(scalar_type ty); -clir::builtin_type to_clir_atomic_builtin_ty(scalar_type ty); -clir::data_type to_clir_ty(scalar_type ty, clir::address_space as = clir::address_space::generic_t, - clir::type_qualifier q = clir::type_qualifier::none); -clir::data_type to_clir_atomic_ty(scalar_type ty, - clir::address_space as = clir::address_space::generic_t, - clir::type_qualifier q = clir::type_qualifier::none); - -} // namespace tinytc - -#endif // SCALAR_TYPE_20240411_HPP diff --git a/src/slice.hpp b/src/slice.hpp deleted file mode 100644 index d545242e..00000000 --- a/src/slice.hpp +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef SLICE_20240412_HPP -#define SLICE_20240412_HPP - -#include "tinytc/tinytc.hpp" - -#include - -namespace tinytc { - -//! Slice storing offset:size -class slice : public std::pair { - public: - //! 
ctor - inline slice(value offset = {}, value size = {}) - : std::pair{std::move(offset), std::move(size)} {} -}; - -} // namespace tinytc - -#endif // SLICE_20240412_HPP diff --git a/src/source.cpp b/src/source.cpp deleted file mode 100644 index 18038a0d..00000000 --- a/src/source.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#include "source.hpp" -#include "error.hpp" -#include "tinytc/tinytc.h" - -#include - -using namespace tinytc; - -extern "C" { -tinytc_status_t tinytc_source_get_code(const_tinytc_source_t src, size_t *length, - char const **code) { - if (src == nullptr || length == nullptr || code == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *length = src->size(); - *code = src->code(); - }); -} - -tinytc_status_t tinytc_source_get_location(const_tinytc_source_t src, tinytc_location_t *loc) { - if (src == nullptr || loc == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *loc = src->code_loc(); }); -} - -tinytc_status_t tinytc_source_get_core_features(const_tinytc_source_t src, - tinytc_core_feature_flags_t *core_features) { - if (src == nullptr || core_features == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { *core_features = src->core_features(); }); -} - -tinytc_status_t tinytc_source_get_extensions(const_tinytc_source_t src, uint32_t *extensions_size, - char const *const **extensions) { - if (src == nullptr || extensions_size == nullptr || extensions == nullptr) { - return tinytc_status_invalid_arguments; - } - return exception_to_status_code([&] { - *extensions_size = src->required_extensions().size(); - *extensions = src->required_extensions().data(); - }); -} - -tinytc_status_t tinytc_source_release(tinytc_source_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - auto ref_count = obj->dec_ref(); 
- if (ref_count == 0) { - delete obj; - } - return tinytc_status_success; -} - -tinytc_status_t tinytc_source_retain(tinytc_source_t obj) { - if (obj == nullptr) { - return tinytc_status_invalid_arguments; - } - obj->inc_ref(); - return tinytc_status_success; -} -} diff --git a/src/source.hpp b/src/source.hpp deleted file mode 100644 index d07017e7..00000000 --- a/src/source.hpp +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: BSD-3-Clause - -#ifndef SOURCE_20240412_HPP -#define SOURCE_20240412_HPP - -#include "reference_counted.hpp" -#include "tinytc/types.h" - -#include -#include -#include -#include - -struct tinytc_source : tinytc::reference_counted { - public: - inline tinytc_source(std::string code, tinytc_location const &code_loc, - std::vector required_extensions, - tinytc_core_feature_flags_t core_features) - : code_(std::move(code)), code_loc_(code_loc), - required_extensions_(std::move(required_extensions)), core_features_(core_features) {} - - inline auto code() const -> char const * { return code_.c_str(); } - inline auto code_loc() const -> tinytc_location const & { return code_loc_; } - inline auto size() const -> std::size_t { return code_.size(); } - inline auto const &required_extensions() const { return required_extensions_; } - inline auto core_features() const noexcept -> tinytc_core_feature_flags_t { - return core_features_; - } - - private: - std::string code_; - tinytc_location code_loc_; - std::vector required_extensions_; - tinytc_core_feature_flags_t core_features_; -}; - -#endif // SOURCE_20240412_HPP diff --git a/src/spv/block2d_diy.cpp b/src/spv/block2d_diy.cpp new file mode 100644 index 00000000..c09c0229 --- /dev/null +++ b/src/spv/block2d_diy.cpp @@ -0,0 +1,320 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/block2d_diy.hpp" +#include "compiler_context.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include 
"spv/xe_constants.hpp" +#include "support/temp_counter.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/math.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto block_config::block_size_in_bytes() const -> std::int32_t { + return element_size * array_length * rows * cols; +} +auto block_config::block_size_in_num_grf() const -> std::int32_t { + return block_size_in_bytes() / xe::grf_size; +} +auto block_config::byte_offset(std::int32_t row, std::int32_t col, std::int32_t array_idx, + std::int32_t col_block, std::int32_t row_block) const + -> std::int32_t { + std::int32_t offset = 0; + if (transpose) { + offset = col_block; + offset = row_block + offset * row_blocks; + } else { + offset = row_block; + offset = col_block + offset * col_blocks; + } + offset = array_idx + offset * array_length; + offset = col + offset * cols; + offset = row + offset * rows; + return offset * element_size; +} +auto block_config::origin(std::int32_t row, std::int32_t col, std::int32_t array_idx, + std::int32_t col_block, std::int32_t row_block) const + -> std::array { + const auto offset = byte_offset(row, col, array_idx, col_block, row_block); + return region_origin(element_size, offset); +} +auto block_config::total_rows() const -> std::int32_t { return array_length * rows * row_blocks; } + +auto lsc_data_size(std::int32_t element_size) -> std::int32_t { + if (!is_positive_power_of_two(element_size) || element_size > 8) { + throw status::internal_compiler_error; + } + return ilog2(element_size); +} +auto region_origin(std::int32_t element_size, std::int32_t byte_offset) + -> std::array { + return {byte_offset / xe::grf_size, byte_offset % xe::grf_size / element_size}; +} +auto visa_type(tinytc_type_t ty) -> char const * { + return visit( + overloaded{[&](i8_type &) { return "b"; }, // + [&](i16_type &) { return "w"; }, // + [&](i32_type &) { return "d"; }, // + [&](i64_type &) { 
return "q"; }, // + [&](index_type &ty) { + const auto idx_width = ty.context()->index_bit_width(); + if (idx_width == 64) { + return "q"; + } else if (idx_width == 32) { + return "d"; + } + throw status::not_implemented; + }, + [&](bf16_type &) { return "bf"; }, // + [&](f16_type &) { return "hf"; }, // + [&](f32_type &) { return "f"; }, // + [&](f64_type &) { return "df"; }, // + [](tinytc_type &) -> char const * { throw status::internal_compiler_error; }}, + *ty); +} + +/** + * This routine generates transpose code for 8x8 matrices of d32 type. Multiple 8x8 matrices may be + * packed side-by-side. E.g. for num_8x8_blocks=2 the GRF layout, say starting at reg r93, + * is expected to be + * + * r93: a_11 ... a_18 b_11 ... b_18 + * ...: ... ... ... ... + * r100: a_81 ... a_88 b_81 ... b_88 + * + * The transposition is done in-place and we should get + * + * r93: a_11 ... a_81 b_11 ... b_81 + * ...: ... ... ... ... + * r100: a_18 ... a_88 b_18 ... b_88 + * + * This routine can also be called for 16x8 half types or 32x8 i8 types. + * Then, the routine generates transpose + VNNI transform. 
+ * + */ +auto make_d32_transpose8x8(std::ostream &oasm, char const *matrix, std::size_t offset, + temp_counter &make_tmp, std::int32_t num_8x8_blocks = 1) { + constexpr std::int32_t element_size = 4; + const std::int32_t stride = 8 * element_size * num_8x8_blocks; + const std::int32_t num_elements = 8 * stride / element_size; + + auto dst_d = make_tmp("dst_d"); + auto dst_q = make_tmp("dst_q"); + oasm << ".decl " << dst_d << " v_type=G type=d num_elts=" << num_elements + << " align=wordx32 alias=<" << matrix << "," << offset << ">\n"; + oasm << ".decl " << dst_q << " v_type=G type=q num_elts=" << num_elements / 2 + << " align=wordx32 alias=<" << matrix << "," << offset << ">\n"; + + // 2x2 transpose + const std::int32_t exec_size = 4 * num_8x8_blocks; + for (std::int32_t r = 0; r < 4; ++r) { + auto ttmp = make_tmp("ttmp_d"); + oasm << ".decl " << ttmp << " v_type=G type=d num_elts=" << exec_size << " align=wordx32\n"; + auto [R1, C1] = region_origin(element_size, 2 * r * stride + element_size); + auto [R2, C2] = region_origin(element_size, 2 * r * stride + stride); + oasm << "mov (M1," << exec_size << ") " << ttmp << "(0,0)<1> " << dst_d << "(" << R1 << "," + << C1 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_d << "(" << R1 << "," << C1 << ")<2> " + << dst_d << "(" << R2 << "," << C2 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_d << "(" << R2 << "," << C2 << ")<2> " + << ttmp << "(0,0)<1;1,0>\n"; + } + // 4x4 transpose + for (std::int32_t r = 0; r < 4; r += 2) { + auto ttmp = make_tmp("ttmp_q"); + oasm << ".decl " << ttmp << " v_type=G type=q num_elts=" << exec_size << " align=wordx32\n"; + auto [R1, C1] = region_origin(2 * element_size, 2 * r * stride + 2 * element_size); + auto [R2, C2] = region_origin(2 * element_size, 2 * (r + 1) * stride); + oasm << "mov (M1," << exec_size << ") " << ttmp << "(0,0)<1> " << dst_q << "(" << R1 << "," + << C1 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_q << "(" << 
R1 << "," << C1 << ")<2> " + << dst_q << "(" << R2 << "," << C2 << ")<2;1,0>\n"; + oasm << "mov (M1," << exec_size << ") " << dst_q << "(" << R2 << "," << C2 << ")<2> " + << ttmp << "(0,0)<1;1,0>\n"; + } + // 8x8 transpose + for (std::int32_t r = 0; r < 4; ++r) { + auto ttmp = make_tmp("ttmp_d"); + auto [R1, C1] = region_origin(element_size, r * stride + 4 * element_size); + auto [R2, C2] = region_origin(element_size, (r + 4) * stride); + oasm << ".decl " << ttmp << " v_type=G type=d num_elts=" << exec_size << " align=wordx32\n"; + oasm << "mov (M1," << exec_size << ") " << ttmp << "(0,0)<1> " << dst_d << "(" << R1 << "," + << C1 << ")<8;4,1>\n"; + for (std::int32_t b = 0; b < 8 * num_8x8_blocks; b += 8) { + oasm << "mov (M1,4) " << dst_d << "(" << R1 << "," << C1 + b << ")<1> " << dst_d << "(" + << R2 << "," << C2 + b << ")<1;1,0>\n"; + oasm << "mov (M1,4) " << dst_d << "(" << R2 << "," << C2 + b << ")<1> " << ttmp << "(0," + << b / 2 << ")<1;1,0>\n"; + } + } +} + +struct block2d_native_helper { + block2d_native_helper(std::ostream &oasm, block_config const &cfg, temp_counter &make_tmp, + std::int32_t first_address_operand); + void header(); + template void walk(F &&io); + + inline auto temp(std::int32_t m, std::int32_t n) -> std::string const & { + return temps[n + m * cfg.col_blocks]; + } + + std::ostream &oasm; + block_config const &cfg; + std::vector temps; + std::string tempq; + std::int32_t first_; +}; + +block2d_native_helper::block2d_native_helper(std::ostream &oasm, block_config const &cfg, + temp_counter &make_tmp, + std::int32_t first_address_operand) + : oasm(oasm), cfg(cfg), temps(cfg.row_blocks * cfg.col_blocks), tempq{make_tmp("tempq")}, + first_{first_address_operand} { + std::generate(temps.begin(), temps.end(), [&]() { return make_tmp("temp"); }); +} + +void block2d_native_helper::header() { + const std::uint32_t block_size = + ((cfg.array_length - 1) << 16) | ((cfg.cols - 1) << 8) | (cfg.rows - 1); + const auto &tmp0 = temp(0, 0); + oasm << 
".decl " << tmp0 << " v_type=G type=ud num_elts=8 align=wordx32\n" + << ".decl " << tempq << " v_type=G type=uq num_elts=4 align=wordx32 alias=<" << tmp0 + << ",0>\n" + << "mov (M1,1) " << tempq << "(0,0)<1> $" << first_ << "(0,0)<0;1,0>\n" + << "add (M1,1) " << tmp0 << "(0,2)<1> $" << first_ + 1 << "(0,0)<0;1,0> -1:d\n" + << "add (M1,1) " << tmp0 << "(0,3)<1> $" << first_ + 2 << "(0,0)<0;1,0> -1:d\n" + << "add (M1,1) " << tmp0 << "(0,4)<1> $" << first_ + 3 << "(0,0)<0;1,0> -1:d\n"; + if (cfg.pos0_shr) { + oasm << "shr (M1,1) " << tmp0 << "(0,5)<1> $" << first_ + 4 << "(0,0)<0;1,0> " + << cfg.pos0_shr << ":d\n"; + } else { + oasm << "mov (M1,1) " << tmp0 << "(0,5)<1> $" << first_ + 4 << "(0,0)<0;1,0>\n"; + } + oasm << "mov (M1,1) " << tmp0 << "(0,6)<1> $" << first_ + 5 << "(0,0)<0;1,0>\n" + << "mov (M1,1) " << tmp0 << "(0,7)<1> 0x" << std::hex << block_size << ":ud\n"; + for (std::int32_t m = 0; m < cfg.row_blocks; ++m) { + for (std::int32_t n = 0; n < cfg.col_blocks; ++n) { + const auto &tmp = temp(m, n); + if (m > 0 || n > 0) { + oasm << ".decl " << tmp << " v_type=G type=ud num_elts=8 align=wordx32\n"; + oasm << "mov (M1,8) " << tmp << "(0,0)<1> " << tmp0 << "(0,0)<1;1,0>\n"; + } + if (m > 0) { + oasm << "add (M1,1) " << tmp << "(0,5)<1> " << tmp << "(0,5)<0;1,0> 0x" + << m * cfg.rows * cfg.array_length << ":ud\n"; + } + if (n > 0) { + oasm << "add (M1,1) " << tmp << "(0,6)<1> " << tmp << "(0,6)<0;1,0> 0x" + << n * cfg.cols << ":ud\n"; + } + } + } +} + +template void block2d_native_helper::walk(F &&io) { + for (std::int32_t m = 0; m < cfg.row_blocks; ++m) { + for (std::int32_t n = 0; n < cfg.col_blocks; ++n) { + io(m, n); + } + } +} + +auto load_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string { + const std::uint32_t num_dst = std::min(31, cfg.block_size_in_num_grf()); + const std::uint32_t desc = [&] { + const std::uint32_t data_size = lsc_data_size(cfg.element_size); + std::uint32_t d = 3; + if (cfg.vnni) { + d |= 1 << 7; + } + if 
(cfg.transpose && !cfg.vnni) { + d |= 1 << 15; + } + d |= data_size << 9; + d |= num_dst << 20; + d |= 1 << 25; + return d; + }(); + + auto oasm = std::ostringstream{}; + auto h = block2d_native_helper(oasm, cfg, make_tmp, 1); + + oasm << "{\n"; + h.header(); + h.walk([&](std::int32_t m, std::int32_t n) { + const auto offset = cfg.byte_offset(0, 0, 0, n, m); + oasm << std::dec << "raw_sends.15.1.0." << num_dst << " (M1, 1) 0x0:ud 0x" << std::hex + << desc << ":ud " << h.temp(m, n) << ".0 %null.0 $0." << std::dec << offset << "\n"; + + if (cfg.vnni && cfg.transpose) { + for (std::int32_t array_idx = 0; array_idx < cfg.array_length; ++array_idx) { + make_d32_transpose8x8(oasm, "$0", cfg.byte_offset(0, 0, array_idx, n, m), make_tmp); + } + } + }); + oasm << "}\n"; + + return std::move(oasm).str(); +} +auto prefetch_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string { + const std::uint32_t desc = [&] { + const std::uint32_t data_size = lsc_data_size(cfg.element_size); + const std::uint32_t cache_control = cfg.cache_level == 1 ? 
2 : 4; + std::uint32_t d = 3; + d |= data_size << 9; + d |= cache_control << 17; + d |= 1 << 25; + return d; + }(); + + auto oasm = std::ostringstream{}; + auto h = block2d_native_helper(oasm, cfg, make_tmp, 0); + + oasm << "{\n"; + h.header(); + h.walk([&](std::int32_t m, std::int32_t n) { + oasm << std::dec << "raw_sends.15.1.0.0 (M1, 1) 0x0:ud 0x" << std::hex << desc << ":ud " + << h.temp(m, n) << ".0 %null.0 %null.0\n"; + }); + oasm << "}\n"; + + return std::move(oasm).str(); +} +auto store_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string { /* Returns the assembly text for a 2D block store: the header() setup followed by one raw_sends line per (row_block, col_block) pair, wrapped in braces. */ + const std::uint32_t num_src1 = std::min(31, cfg.block_size_in_num_grf()); /* register count written into the raw_sends mnemonic, capped at 31 */ + const std::uint32_t desc = [&] { + const std::uint32_t data_size = lsc_data_size(cfg.element_size); + std::uint32_t d = 7; + d |= data_size << 9; + d |= 1 << 25; + return d; + }(); + + auto oasm = std::ostringstream{}; + auto h = block2d_native_helper(oasm, cfg, make_tmp, 1); + + oasm << "{\n"; + h.header(); + h.walk([&](std::int32_t m, std::int32_t n) { + const auto offset = cfg.byte_offset(0, 0, 0, n, m); + oasm << std::dec /* header() leaves oasm in std::hex mode; num_src1 must be printed in decimal, as load_block2d_native does for num_dst */ << "raw_sends.15.1." << num_src1 << ".0 (M1, 1) 0x0:ud 0x" << std::hex << desc + << ":ud " << h.temp(m, n) << ".0 $0." 
<< std::dec << offset << " %null.0\n"; + }); + oasm << "}\n"; + + return std::move(oasm).str(); +} + +} // namespace tinytc::spv diff --git a/src/spv/block2d_diy.hpp b/src/spv/block2d_diy.hpp new file mode 100644 index 00000000..1dce4cb1 --- /dev/null +++ b/src/spv/block2d_diy.hpp @@ -0,0 +1,52 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef BLOCK2D_DIY_20250219_HPP +#define BLOCK2D_DIY_20250219_HPP + +#include "tinytc/types.h" + +#include +#include +#include + +namespace tinytc { +class temp_counter; +} // namespace tinytc + +namespace tinytc::spv { + +struct block_config { + tinytc_type_t sty; + std::int32_t element_size; + std::int32_t array_length; + std::int32_t rows; + std::int32_t cols; + std::int32_t row_blocks; + std::int32_t col_blocks; + bool transpose; + bool vnni; + std::int32_t pos0_shr; // Number of bits to shift pos0 to the right (= divide by 2^pos0_shr) + std::int32_t cache_level; + + auto block_size_in_bytes() const -> std::int32_t; + auto block_size_in_num_grf() const -> std::int32_t; + auto byte_offset(std::int32_t row, std::int32_t col, std::int32_t array_idx, + std::int32_t col_block, std::int32_t row_block) const -> std::int32_t; + auto origin(std::int32_t row, std::int32_t col, std::int32_t array_idx, std::int32_t col_block, + std::int32_t row_block) const -> std::array; + auto total_rows() const -> std::int32_t; +}; + +auto lsc_data_size(std::int32_t element_size) -> std::int32_t; +auto region_origin(std::int32_t element_size, std::int32_t byte_offset) + -> std::array; +auto visa_type(tinytc_type_t ty) -> char const *; + +auto load_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string; +auto prefetch_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string; +auto store_block2d_native(block_config const &cfg, temp_counter &make_tmp) -> std::string; + +} // namespace tinytc::spv + +#endif // BLOCK2D_DIY_20250219_HPP diff --git 
a/src/spv/capex_util.cpp b/src/spv/capex_util.cpp new file mode 100644 index 00000000..0cd1fb98 --- /dev/null +++ b/src/spv/capex_util.cpp @@ -0,0 +1,686 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "capex_util.hpp" +#include "enums.hpp" + +namespace tinytc::spv { + +auto capabilities(ExecutionModel e) -> array_view { + switch (e) { + case ExecutionModel::Vertex: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::TessellationControl: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionModel::TessellationEvaluation: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionModel::Geometry: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionModel::Fragment: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::GLCompute: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionModel::Kernel: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionModel::TaskNV: { + constexpr static Capability values[] = {Capability::MeshShadingNV}; + return {values, 1}; + } + case ExecutionModel::MeshNV: { + constexpr static Capability values[] = {Capability::MeshShadingNV}; + return {values, 1}; + } + case ExecutionModel::RayGenerationKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::IntersectionKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::AnyHitKHR: { + constexpr static 
Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::ClosestHitKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::MissKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::CallableKHR: { + constexpr static Capability values[] = {Capability::RayTracingNV, + Capability::RayTracingKHR}; + return {values, 2}; + } + case ExecutionModel::TaskEXT: { + constexpr static Capability values[] = {Capability::MeshShadingEXT}; + return {values, 1}; + } + case ExecutionModel::MeshEXT: { + constexpr static Capability values[] = {Capability::MeshShadingEXT}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(AddressingModel e) -> array_view { + switch (e) { + case AddressingModel::Physical32: { + constexpr static Capability values[] = {Capability::Addresses}; + return {values, 1}; + } + case AddressingModel::Physical64: { + constexpr static Capability values[] = {Capability::Addresses}; + return {values, 1}; + } + case AddressingModel::PhysicalStorageBuffer64: { + constexpr static Capability values[] = {Capability::PhysicalStorageBufferAddresses}; + return {values, 1}; + } + default: + return {}; + } +} +auto capabilities(MemoryModel e) -> array_view { + switch (e) { + case MemoryModel::Simple: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case MemoryModel::GLSL450: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case MemoryModel::OpenCL: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case MemoryModel::Vulkan: { + constexpr static Capability values[] = {Capability::VulkanMemoryModel}; + return {values, 1}; + } + default: + return {}; + } +} +auto 
capabilities(ExecutionMode e) -> array_view { + switch (e) { + case ExecutionMode::Invocations: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::SpacingEqual: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::SpacingFractionalEven: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::SpacingFractionalOdd: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::VertexOrderCw: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::VertexOrderCcw: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::PixelCenterInteger: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::OriginUpperLeft: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::OriginLowerLeft: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::EarlyFragmentTests: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::PointMode: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::Xfb: { + constexpr static Capability values[] = {Capability::TransformFeedback}; + return {values, 1}; + } + case ExecutionMode::DepthReplacing: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthGreater: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthLess: { + constexpr static Capability values[] = 
{Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::DepthUnchanged: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::LocalSizeHint: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::InputPoints: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::InputLines: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::InputLinesAdjacency: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::Triangles: { + constexpr static Capability values[] = {Capability::Geometry, Capability::Tessellation}; + return {values, 2}; + } + case ExecutionMode::InputTrianglesAdjacency: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::Quads: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::Isolines: { + constexpr static Capability values[] = {Capability::Tessellation}; + return {values, 1}; + } + case ExecutionMode::OutputVertices: { + constexpr static Capability values[] = {Capability::Geometry, Capability::Tessellation, + Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 4}; + } + case ExecutionMode::OutputPoints: { + constexpr static Capability values[] = {Capability::Geometry, Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 3}; + } + case ExecutionMode::OutputLineStrip: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::OutputTriangleStrip: { + constexpr static Capability values[] = {Capability::Geometry}; + return {values, 1}; + } + case ExecutionMode::VecTypeHint: { + constexpr static Capability values[] = 
{Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::ContractionOff: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::Initializer: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::Finalizer: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::SubgroupSize: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::SubgroupsPerWorkgroup: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::SubgroupsPerWorkgroupId: { + constexpr static Capability values[] = {Capability::SubgroupDispatch}; + return {values, 1}; + } + case ExecutionMode::LocalSizeHintId: { + constexpr static Capability values[] = {Capability::Kernel}; + return {values, 1}; + } + case ExecutionMode::NonCoherentColorAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageColorReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::NonCoherentDepthAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageDepthReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::NonCoherentStencilAttachmentReadEXT: { + constexpr static Capability values[] = {Capability::TileImageStencilReadAccessEXT}; + return {values, 1}; + } + case ExecutionMode::SubgroupUniformControlFlowKHR: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::PostDepthCoverage: { + constexpr static Capability values[] = {Capability::SampleMaskPostDepthCoverage}; + return {values, 1}; + } + case ExecutionMode::DenormPreserve: { + constexpr static Capability values[] = {Capability::DenormPreserve}; + return {values, 1}; + } + case ExecutionMode::DenormFlushToZero: { + constexpr static 
Capability values[] = {Capability::DenormFlushToZero}; + return {values, 1}; + } + case ExecutionMode::SignedZeroInfNanPreserve: { + constexpr static Capability values[] = {Capability::SignedZeroInfNanPreserve}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTE: { + constexpr static Capability values[] = {Capability::RoundingModeRTE}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTZ: { + constexpr static Capability values[] = {Capability::RoundingModeRTZ}; + return {values, 1}; + } + case ExecutionMode::NonCoherentTileAttachmentReadQCOM: { + constexpr static Capability values[] = {Capability::TileShadingQCOM}; + return {values, 1}; + } + case ExecutionMode::TileShadingRateQCOM: { + constexpr static Capability values[] = {Capability::TileShadingQCOM}; + return {values, 1}; + } + case ExecutionMode::EarlyAndLateFragmentTestsAMD: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::StencilRefReplacingEXT: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::CoalescingAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::IsApiEntryAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::MaxNodeRecursionAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::StaticNumWorkgroupsAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::ShaderIndexAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::MaxNumWorkgroupsAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case 
ExecutionMode::StencilRefUnchangedFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefGreaterFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefLessFrontAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefGreaterBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::StencilRefLessBackAMD: { + constexpr static Capability values[] = {Capability::StencilExportEXT}; + return {values, 1}; + } + case ExecutionMode::QuadDerivativesKHR: { + constexpr static Capability values[] = {Capability::QuadControlKHR}; + return {values, 1}; + } + case ExecutionMode::RequireFullQuadsKHR: { + constexpr static Capability values[] = {Capability::QuadControlKHR}; + return {values, 1}; + } + case ExecutionMode::SharesInputWithAMDX: { + constexpr static Capability values[] = {Capability::ShaderEnqueueAMDX}; + return {values, 1}; + } + case ExecutionMode::OutputLinesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::OutputPrimitivesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupQuadsKHR: { + constexpr static Capability values[] = {Capability::ComputeDerivativeGroupQuadsKHR}; + return {values, 1}; /* values holds exactly one entry; a larger count would read past the array */ + } + case ExecutionMode::DerivativeGroupLinearKHR: { + constexpr static Capability values[] = {Capability::ComputeDerivativeGroupLinearKHR}; + return {values, 1}; /* values holds exactly one entry; a larger count would read past the array */ + } + case 
ExecutionMode::OutputTrianglesEXT: { + constexpr static Capability values[] = {Capability::MeshShadingNV, + Capability::MeshShadingEXT}; + return {values, 2}; + } + case ExecutionMode::PixelInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderPixelInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::PixelInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderPixelInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderSampleInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderSampleInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockOrderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderShadingRateInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockUnorderedEXT: { + constexpr static Capability values[] = {Capability::FragmentShaderShadingRateInterlockEXT}; + return {values, 1}; + } + case ExecutionMode::SharedLocalMemorySizeINTEL: { + constexpr static Capability values[] = {Capability::VectorComputeINTEL}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTPINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTNINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::FloatingPointModeALTINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::FloatingPointModeIEEEINTEL: { + constexpr static Capability values[] = {Capability::RoundToInfinityINTEL}; + return {values, 1}; + } + case ExecutionMode::MaxWorkgroupSizeINTEL: { + constexpr 
static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::MaxWorkDimINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::NoGlobalOffsetINTEL: { + constexpr static Capability values[] = {Capability::KernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::NumSIMDWorkitemsINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::SchedulerTargetFmaxMhzINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximallyReconvergesKHR: { + constexpr static Capability values[] = {Capability::Shader}; + return {values, 1}; + } + case ExecutionMode::FPFastMathDefault: { + constexpr static Capability values[] = {Capability::FloatControls2}; + return {values, 1}; + } + case ExecutionMode::StreamingInterfaceINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesINTEL}; + return {values, 1}; + } + case ExecutionMode::RegisterMapInterfaceINTEL: { + constexpr static Capability values[] = {Capability::FPGAKernelAttributesv2INTEL}; + return {values, 1}; + } + case ExecutionMode::NamedBarrierCountINTEL: { + constexpr static Capability values[] = {Capability::VectorComputeINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximumRegistersINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + case ExecutionMode::MaximumRegistersIdINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + case ExecutionMode::NamedMaximumRegistersINTEL: { + constexpr static Capability values[] = {Capability::RegisterLimitsINTEL}; + return {values, 1}; + } + default: + return {}; + } +} +auto extensions(ExecutionModel e) -> array_view { + switch (e) { + 
default: + return {}; + } +} +auto extensions(AddressingModel e) -> array_view { + switch (e) { + case AddressingModel::PhysicalStorageBuffer64: { + constexpr static char const *values[] = {"SPV_EXT_physical_storage_buffer", + "SPV_KHR_physical_storage_buffer"}; + return {values, 2}; + } + default: + return {}; + } +} +auto extensions(MemoryModel e) -> array_view { + switch (e) { + case MemoryModel::Vulkan: { + constexpr static char const *values[] = {"SPV_KHR_vulkan_memory_model"}; + return {values, 1}; + } + default: + return {}; + } +} +auto extensions(ExecutionMode e) -> array_view { + switch (e) { + case ExecutionMode::SubgroupUniformControlFlowKHR: { + constexpr static char const *values[] = {"SPV_KHR_subgroup_uniform_control_flow"}; + return {values, 1}; + } + case ExecutionMode::PostDepthCoverage: { + constexpr static char const *values[] = {"SPV_KHR_post_depth_coverage"}; + return {values, 1}; + } + case ExecutionMode::DenormPreserve: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::DenormFlushToZero: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::SignedZeroInfNanPreserve: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTE: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::RoundingModeRTZ: { + constexpr static char const *values[] = {"SPV_KHR_float_controls"}; + return {values, 1}; + } + case ExecutionMode::EarlyAndLateFragmentTestsAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests"}; + return {values, 1}; + } + case ExecutionMode::StencilRefReplacingEXT: { + constexpr static char const *values[] = {"SPV_EXT_shader_stencil_export"}; + return {values, 1}; + } + case ExecutionMode::StencilRefUnchangedFrontAMD: { + 
constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefGreaterFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefLessFrontAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefUnchangedBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefGreaterBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::StencilRefLessBackAMD: { + constexpr static char const *values[] = {"SPV_AMD_shader_early_and_late_fragment_tests", + "SPV_EXT_shader_stencil_export"}; + return {values, 2}; + } + case ExecutionMode::OutputLinesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::OutputPrimitivesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", "SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupQuadsKHR: { + constexpr static char const *values[] = {"SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives"}; + return {values, 2}; + } + case ExecutionMode::DerivativeGroupLinearKHR: { + constexpr static char const *values[] = {"SPV_NV_compute_shader_derivatives", + "SPV_KHR_compute_shader_derivatives"}; + return {values, 2}; + } + case ExecutionMode::OutputTrianglesEXT: { + constexpr static char const *values[] = {"SPV_NV_mesh_shader", 
"SPV_EXT_mesh_shader"}; + return {values, 2}; + } + case ExecutionMode::PixelInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::PixelInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::SampleInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockOrderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::ShadingRateInterlockUnorderedEXT: { + constexpr static char const *values[] = {"SPV_EXT_fragment_shader_interlock"}; + return {values, 1}; + } + case ExecutionMode::MaxWorkgroupSizeINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::MaxWorkDimINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::NoGlobalOffsetINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::NumSIMDWorkitemsINTEL: { + constexpr static char const *values[] = {"SPV_INTEL_kernel_attributes"}; + return {values, 1}; + } + case ExecutionMode::MaximallyReconvergesKHR: { + constexpr static char const *values[] = {"SPV_KHR_maximal_reconvergence"}; + return {values, 1}; + } + default: + return {}; + } +} + +} // namespace tinytc::spv diff --git a/src/spv/capex_util.hpp b/src/spv/capex_util.hpp new file mode 100644 index 00000000..536fb5c9 --- /dev/null +++ b/src/spv/capex_util.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2025 
Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_CAPEX_UTIL_20250630_HPP +#define GENERATED_CAPEX_UTIL_20250630_HPP + +#include "tinytc/core.hpp" + +namespace tinytc::spv { + +enum class Capability; +enum class AddressingModel; +enum class ExecutionMode; +enum class ExecutionModel; +enum class MemoryModel; +auto capabilities(ExecutionModel op) -> array_view; +auto capabilities(AddressingModel op) -> array_view; +auto capabilities(MemoryModel op) -> array_view; +auto capabilities(ExecutionMode op) -> array_view; +auto extensions(ExecutionModel op) -> array_view; +auto extensions(AddressingModel op) -> array_view; +auto extensions(MemoryModel op) -> array_view; +auto extensions(ExecutionMode op) -> array_view; + +} // namespace tinytc::spv + +#endif // GENERATED_CAPEX_UTIL_20250630_HPP diff --git a/src/spv/converter.cpp b/src/spv/converter.cpp new file mode 100644 index 00000000..c9e6be35 --- /dev/null +++ b/src/spv/converter.cpp @@ -0,0 +1,942 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/converter.hpp" +#include "analysis/gcd.hpp" +#include "analysis/stack.hpp" +#include "codegen_tools.hpp" +#include "compiler_context.hpp" +#include "converter_aux.hpp" +#include "error.hpp" +#include "matrix_ext_info.hpp" +#include "node/attr.hpp" +#include "node/func.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/prog.hpp" +#include "node/region.hpp" +#include "node/type.hpp" +#include "node/value.hpp" +#include "node/visit.hpp" +#include "spv/coopmatrix_impl_block.hpp" +#include "spv/coopmatrix_impl_dpas.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/pass/capex.hpp" +#include "spv/uniquifier.hpp" +#include "spv/visit.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include 
"util/ilist_base.hpp" +#include "util/iterator.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto convert_prog_to_spirv(tinytc_prog &p, tinytc_core_info const &info) + -> shared_handle { + auto m = shared_handle{ + std::make_unique(p.share_context(), info.core_features()).release()}; + + auto conv = inst_converter{*m, info}; + + if (m->context()->index_bit_width() == 64) { + m->add_to(section::memory_model, AddressingModel::Physical64, + MemoryModel::OpenCL); + } else { + m->add_to(section::memory_model, AddressingModel::Physical32, + MemoryModel::OpenCL); + } + + for (auto &fn : p) { + conv.run_on_function(fn); + } + + // Add missing capabilities and extensions + auto cx = capex{conv.unique()}; + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto &i : m->insts(enum_cast

(s))) { + visit(cx, i); + } + } + + for (int i = 0; i < TINYTC_ENUM_NUM_SPIRV_FEATURE; ++i) { + const auto feature = enum_cast(i); + if (cx.requires_feature(feature) && !info.have_spirv_feature(feature)) { + throw compilation_error(p.loc(), status::spirv_required_feature_unavailable, + to_string(feature)); + } + } + + return m; +} + +inst_converter::inst_converter(tinytc_spv_mod &m, tinytc_core_info const &info) + : mod_(&m), info_(&info), unique_(m) {} + +auto inst_converter::get_dope_vector(tinytc_value const &v) -> dope_vector * { + if (auto it = dope_vec_.find(&v); it != dope_vec_.end()) { + return &it->second; + } + return nullptr; +} + +auto inst_converter::declare(tinytc_value const &v, spv_inst *in) { vals_[&v] = in; } +auto inst_converter::val(tinytc_value const &v) -> spv_inst * { + if (auto it = vals_.find(&v); it != vals_.end()) { + return it->second; + } + throw compilation_error(v.loc(), status::spirv_undefined_value); +} + +auto inst_converter::spv_ty(tinytc_type_t ty) -> spv_inst * { + if (auto ct = dyn_cast(ty); ct) { + return matrix_impl().spv_ty(ct); + } + return get_spv_ty_non_coopmatrix(unique_, ty); +} + +auto inst_converter::make_dope_vector(tinytc_value const &v) -> dope_vector * { + if (dope_vec_.contains(&v)) { + throw compilation_error(v.loc(), status::internal_compiler_error); + } + + auto spv_index_ty = get_spv_index_ty(unique_, v.context()); + return ::tinytc::visit( + overloaded{[&](memref_type const &mr) -> dope_vector * { + return &(dope_vec_[&v] = dope_vector{spv_index_ty, mr.shape(), mr.stride()}); + }, + [&](group_type const &g) -> dope_vector * { + if (auto mt = dyn_cast(g.element_ty()); mt) { + auto pointer_ty = get_spv_pointer_index_ty(unique_, g.context()); + return &(dope_vec_[&v] = dope_vector{ + pointer_ty, mt->shape(), mt->stride(), spv_index_ty, + g.size(), spv_index_ty, g.offset()}); + } else { + throw compilation_error(v.loc(), status::ir_expected_memref); + } + }, + [](auto const &) -> dope_vector * { return nullptr; 
}}, + *v.ty()); +} + +auto inst_converter::matrix_impl() -> coopmatrix_impl & { + if (matrix_impl_) { + return *matrix_impl_; + } + throw status::internal_compiler_error; +} + +void inst_converter::operator()(inst_view in) { + // @todo + throw compilation_error(in.loc(), status::not_implemented); +} + +void inst_converter::operator()(alloca_inst in) { + if (in.stack_ptr() < 0) { + throw compilation_error(in.loc(), status::internal_compiler_error, + "Invalid stack_ptr in alloca. Did you run set_stack_ptrs?"); + } + if (!stack_) { + throw compilation_error(in.loc(), status::internal_compiler_error, + "Stack required but not allocated"); + } + + auto mt = get_memref_type(in.result()); + if (in.stack_ptr() % mt->element_alignment() != 0) { + throw compilation_error(in.loc(), status::ir_insufficient_alignment); + } + + auto stack_element_ty = unique_.int_ty(8); + auto stack_ptr_ty = unique_.pointer_ty(StorageClass::Workgroup, stack_element_ty, 1); + auto stack_ptr = mod_->add( + stack_ptr_ty, stack_, std::vector{unique_.constant(in.stack_ptr())}); + + auto memref_ptr_ty = get_spv_ty(unique_, mt); + declare(in.result(), mod_->add(memref_ptr_ty, stack_ptr)); + + // alloca only accepts fixed-size memrefs => dope vector is constant + auto rdv = make_dope_vector(in.result()); + for (std::int64_t i = 0; i < mt->dim(); ++i) { + rdv->shape(i, unique_.constant(mt->shape(i))); + } + for (std::int64_t i = 0; i < mt->dim(); ++i) { + rdv->stride(i, unique_.constant(mt->stride(i))); + } +} + +void inst_converter::operator()(arith_inst in) { + if (isa(*in.result().ty())) { + auto av = val(in.a()); + auto bv = val(in.b()); + declare(in.result(), matrix_impl().arith(in, av, bv)); + } else { + auto av = val(in.a()); + auto bv = val(in.b()); + auto ty = in.result().ty(); + auto ik = in.get().type_id(); + declare(in.result(), make_binary_op(unique_, ty, ik, av, bv, in.loc())); + } +} + +void inst_converter::operator()(arith_unary_inst in) { + if (isa(*in.a().ty())) { + auto av = 
val(in.a()); + declare(in.result(), matrix_impl().arith_unary(in, av)); + } else { + auto av = val(in.a()); + auto ty = in.a().ty(); + auto ik = in.get().type_id(); + declare(in.result(), make_unary_op(unique_, ty, ik, av, in.loc())); + } +} + +void inst_converter::operator()(barrier_inst in) { + std::int32_t fence = 0; + if (in.has_fence(address_space::global)) { + fence = fence | static_cast(MemorySemantics::CrossWorkgroupMemory) | + static_cast(MemorySemantics::SequentiallyConsistent); + } + if (in.has_fence(address_space::local)) { + fence = fence | static_cast(MemorySemantics::WorkgroupMemory) | + static_cast(MemorySemantics::SequentiallyConsistent); + } + auto scope = unique_.constant(static_cast(Scope::Workgroup)); + auto memory_semantics = unique_.constant(fence); + mod_->add(scope, scope, memory_semantics); +} + +void inst_converter::operator()(cast_inst in) { + if (auto ct = dyn_cast(in.result().ty()); ct) { + declare(in.result(), matrix_impl().cast(in, val(in.a()))); + } else { + auto av = val(in.a()); + auto to_ty = in.result().ty(); + auto a_ty = in.a().ty(); + declare(in.result(), make_cast(unique_, to_ty, a_ty, av, in.loc())); + } +} + +void inst_converter::operator()(compare_inst in) { + auto av = val(in.a()); + auto bv = val(in.b()); + auto tid = in.get().type_id(); + auto a_ty = in.a().ty(); + declare(in.result(), make_compare_op(unique_, a_ty, tid, av, bv, in.loc())); +} + +void inst_converter::operator()(constant_inst in) { + if (auto ct = dyn_cast(in.result().ty()); ct) { + declare(in.result(), matrix_impl().constant(in)); + } else { + auto ty = in.result().ty(); + declare(in.result(), make_constant(unique_, ty, in.value())); + } +} + +void inst_converter::operator()(cooperative_matrix_extract_inst in) { + declare(in.result(), matrix_impl().extract(in, val(in.mat()))); +} +void inst_converter::operator()(cooperative_matrix_insert_inst in) { + declare(in.result(), matrix_impl().insert(in, val(in.val()), val(in.mat()))); +} + +void 
inst_converter::operator()(cooperative_matrix_load_inst in) { + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + declare(in.result(), + matrix_impl().load(in, *odv, val(in.operand()), val(in.pos0()), val(in.pos1()))); +} + +void inst_converter::operator()(cooperative_matrix_mul_add_inst in) { + declare(in.result(), matrix_impl().mul_add(in, val(in.a()), val(in.b()), val(in.c()))); +} +void inst_converter::operator()(cooperative_matrix_prefetch_inst in) { + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + matrix_impl().prefetch(in, *odv, val(in.operand()), val(in.pos0()), val(in.pos1())); +} +void inst_converter::operator()(cooperative_matrix_reduce_inst in) { + declare(in.result(), matrix_impl().reduce(in, val(in.a()))); +} +void inst_converter::operator()(cooperative_matrix_scale_inst in) { + declare(in.result(), matrix_impl().scale(in, val(in.a()), val(in.b()))); +} +void inst_converter::operator()(cooperative_matrix_store_inst in) { + auto odv = get_dope_vector(in.operand()); + if (!odv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + matrix_impl().store(in, *odv, val(in.val()), val(in.operand()), val(in.pos0()), val(in.pos1())); +} + +void inst_converter::operator()(expand_inst in) { + auto spv_index_ty = get_spv_index_ty(unique_, in.operand().context()); + + auto shape = std::vector{}; + auto stride = std::vector{}; + auto const make_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + auto static_shape = in.static_expand_shape(); + auto dyn_shape = in.expand_shape(); + + shape.reserve(mt->dim() + static_shape.size() - 1); + stride.reserve(mt->dim() + static_shape.size() - 1); + + for (std::int64_t i = 0; i < 
in.expanded_mode(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + + auto get_shape = [&, j = std::size_t{0}](std::int64_t s) mutable { + if (is_dynamic_value(s)) { + return val(dyn_shape[j++]); + } + return unique_.constant(s); + }; + stride.push_back(dv->stride(in.expanded_mode())); + shape.push_back(get_shape(static_shape[0])); + for (std::size_t j = 1; j < static_shape.size(); ++j) { + stride.push_back(mod_->add(spv_index_ty, stride.back(), shape.back())); + shape.push_back(get_shape(static_shape[j])); + } + + for (std::int64_t i = in.expanded_mode() + 1; i < mt->dim(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + }; + make_shape_stride(); + declare(in.result(), val(in.operand())); + + auto rdv = make_dope_vector(in.result()); + + if (shape.size() != static_cast(rdv->dim()) || + stride.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride[i]); + } +} + +void inst_converter::operator()(for_inst in) { + auto header_label_op = std::make_unique(); + auto body_label_op = std::make_unique(); + auto continue_label_op = std::make_unique(); + auto merge_label_op = std::make_unique(); + auto header_label = header_label_op.get(); + auto body_label = body_label_op.get(); + auto continue_label = continue_label_op.get(); + auto merge_label = merge_label_op.get(); + + auto entry_label = get_last_label(*mod_); + if (!entry_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + + mod_->add(header_label); + + // Header block + auto spv_bool_ty = unique_.bool_ty(); + auto spv_loop_var_ty = spv_ty(in.loop_var().ty()); + mod_->insts(section::function).push_back(header_label_op.release()); + // nullptr needs to be replaced by the loop var update once it is defined + auto 
loop_var_phi = mod_->add( + spv_loop_var_ty, std::vector{PairIdRefIdRef{val(in.from()), entry_label}, + PairIdRefIdRef{nullptr, continue_label}}); + declare(in.loop_var(), loop_var_phi); + auto const make_iter_arg_phi = [&]() -> std::vector { + auto phis = std::vector{}; + auto iter_init = in.iter_init(); + phis.reserve(iter_init.size()); + for (std::int64_t i = 0; i < iter_init.size(); ++i) { + auto ty = spv_ty(in.iter_arg(i).ty()); + phis.emplace_back(mod_->add( + ty, std::vector{PairIdRefIdRef{val(iter_init[i]), entry_label}, + PairIdRefIdRef{nullptr, continue_label}})); + declare(in.iter_arg(i), phis.back()); + } + return phis; + }; + auto iter_arg_phis = make_iter_arg_phi(); + + auto condition = mod_->add(spv_bool_ty, loop_var_phi, val(in.to())); + auto loop_control = [&]() -> std::pair> { + auto unroll = get_attr(in.get().attr(), "unroll"); + if (unroll) { + auto ba = dyn_cast(unroll); + if (ba) { + return {ba->value() ? LoopControl::Unroll : LoopControl::DontUnroll, std::nullopt}; + } + auto ia = dyn_cast(unroll); + if (ia) { + return {LoopControl::PartialCount, ia->value()}; + } + throw status::ir_expected_boolean_attribute; + } + return {LoopControl::None, std::nullopt}; + }(); + mod_->add(merge_label, continue_label, loop_control.first, loop_control.second); + mod_->add(condition, body_label, merge_label, + std::vector{}); + + // Body block + mod_->insts(section::function).push_back(body_label_op.release()); + + auto results = in.results(); + auto yielded_for = run_on_region_with_yield(in.body(), results.size()); + // Update phis with yielded values + for (std::int64_t i = 0; i < results.size(); ++i) { + iter_arg_phis[i]->op0().back().first = yielded_for[i]; + } + + mod_->add(continue_label); + + // Continue block + mod_->insts(section::function).push_back(continue_label_op.release()); + auto step = [&]() -> spv_inst * { + if (in.has_step()) { + return val(in.step()); + } + return make_constant(unique_, in.loop_var().ty(), std::int64_t{1}); + }(); + auto 
loop_var_update = mod_->add(spv_loop_var_ty, loop_var_phi, step); + loop_var_phi->op0().back().first = loop_var_update; + mod_->add(header_label); + + // Merge block + mod_->insts(section::function).push_back(merge_label_op.release()); + + auto const set_results = [&] { + for (std::int64_t i = 0; i < results.size(); ++i) { + declare(results[i], val(in.iter_arg(i))); + } + }; + set_results(); +} + +void inst_converter::operator()(fuse_inst in) { + auto spv_index_ty = get_spv_index_ty(unique_, in.operand().context()); + + auto shape = std::vector{}; + auto stride = std::vector{}; + auto const make_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + shape.reserve(mt->dim()); + stride.reserve(mt->dim()); + std::int64_t i = 0; + for (; i < in.from(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + spv_inst *prod = dv->shape(i++); + for (; i <= in.to(); ++i) { + prod = mod_->add(spv_index_ty, prod, dv->shape(i)); + } + shape.push_back(prod); + stride.push_back(dv->stride(in.from())); + for (i = in.to() + 1; i < mt->dim(); ++i) { + shape.push_back(dv->shape(i)); + stride.push_back(dv->stride(i)); + } + }; + make_shape_stride(); + declare(in.result(), val(in.operand())); + + auto rdv = make_dope_vector(in.result()); + + if (shape.size() != static_cast(rdv->dim()) || + stride.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride[i]); + } +} + +void inst_converter::operator()(if_inst in) { + auto then_label = std::make_unique(); + auto otherwise_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto conditionv = val(in.condition()); + mod_->add(merge_label.get(), 
SelectionControl::None); + mod_->add(conditionv, then_label.get(), otherwise_label.get(), + std::vector{}); + mod_->insts(section::function).push_back(then_label.release()); + auto results = in.results(); + auto yielded_then = run_on_region_with_yield(in.then(), results.size()); + mod_->add(merge_label.get()); + auto then_last_label = get_last_label(*mod_); + if (!then_last_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + mod_->insts(section::function).push_back(otherwise_label.release()); + auto yielded_otherwise = run_on_region_with_yield(in.otherwise(), results.size()); + mod_->add(merge_label.get()); + auto otherwise_last_label = get_last_label(*mod_); + if (!otherwise_last_label) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + + mod_->insts(section::function).push_back(merge_label.release()); + + std::int64_t val_no = 0; + for (std::int64_t i = 0; i < results.size(); ++i) { + auto ty = spv_ty(results[i].ty()); + auto phi_inst = mod_->add( + ty, std::vector{ + PairIdRefIdRef{yielded_then[val_no], then_last_label}, + PairIdRefIdRef{yielded_otherwise[val_no], otherwise_last_label}}); + ++val_no; + declare(results[i], phi_inst); + } +} + +void inst_converter::operator()(lifetime_stop_inst) {} + +void inst_converter::operator()(load_inst in) { + auto spv_index_ty = get_spv_index_ty(unique_, in.operand().context()); + auto spv_pointer_index_ty = get_spv_pointer_index_ty(unique_, in.operand().context()); + auto spv_pointer_ty = spv_ty(in.operand().ty()); + auto spv_result_ty = spv_ty(in.result().ty()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + if (auto group_ty = dyn_cast(in.operand().ty()); group_ty) { + auto offset = mod_->add(spv_index_ty, dv->offset(), val(in.index_list()[0])); + auto pointer = mod_->add(spv_pointer_ty, val(in.operand()), + offset, std::vector{}); + declare(in.result(), 
mod_->add(spv_result_ty, pointer)); + auto rdv = make_dope_vector(in.result()); + + auto const make_dope_par = [&](std::int64_t static_s, spv_inst *s) -> spv_inst * { + if (is_dynamic_value(static_s)) { + auto pointer = mod_->add(spv_pointer_index_ty, s, offset, + std::vector{}); + return mod_->add(spv_index_ty, pointer); + } + return s; + }; + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, make_dope_par(dv->static_shape(i), dv->shape(i))); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, make_dope_par(dv->static_stride(i), dv->stride(i))); + } + } else if (auto memref_ty = dyn_cast(in.operand().ty()); memref_ty) { + const auto pointer = [&](spv_inst *additional_offset0 = nullptr) -> spv_inst * { + if (memref_ty->dim() == 0) { + return val(in.operand()); + } + + auto idx0 = val(in.index_list()[0]); + spv_inst *offset = memref_ty->stride(0) != 1 + ? mod_->add(spv_index_ty, idx0, dv->stride(0)) + : idx0; + for (std::int64_t i = 1; i < memref_ty->dim(); ++i) { + auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); + offset = mod_->add(spv_index_ty, offset, tmp); + } + if (additional_offset0) { + offset = mod_->add(spv_index_ty, offset, additional_offset0); + } + return mod_->add(spv_pointer_ty, val(in.operand()), offset, + std::vector{}); + }; + declare(in.result(), mod_->add(spv_result_ty, pointer())); + } else { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + } +} + +void inst_converter::operator()(math_unary_inst in) { + auto av = val(in.a()); + auto ty = in.result().ty(); + auto ik = in.get().type_id(); + declare(in.result(), make_math_unary_op(unique_, ty, ik, av, in.loc())); +} + +void inst_converter::operator()(parallel_inst in) { run_on_region(in.body()); } + +void inst_converter::operator()(size_inst in) { + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + const auto shape = 
::tinytc::visit( + overloaded{[&](group_type const &) -> spv_inst * { return dv->size(); }, + [&](memref_type const &) -> spv_inst * { return dv->shape(in.mode()); }, + [&](auto const &) -> spv_inst * { + throw compilation_error(in.loc(), status::ir_expected_memref_or_group); + }}, + *in.operand().ty()); + declare(in.result(), shape); +} + +void inst_converter::operator()(subgroup_broadcast_inst in) { + auto broadcast_scope = unique_.constant(static_cast(Scope::Subgroup)); + auto ty = spv_ty(in.result().ty()); + auto av = val(in.a()); + auto idxv = val(in.idx()); + declare(in.result(), mod_->add(ty, broadcast_scope, av, idxv)); +} + +void inst_converter::operator()(subgroup_operation_inst in) { + auto a_ty = in.a().ty(); + auto ik = in.get().type_id(); + declare(in.result(), make_subgroup_op(unique_, a_ty, ik, val(in.a()), in.loc())); +} + +void inst_converter::operator()(store_inst in) { + auto spv_index_ty = get_spv_index_ty(unique_, in.operand().context()); + auto spv_pointer_ty = spv_ty(in.operand().ty()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + if (auto memref_ty = dyn_cast(in.operand().ty()); memref_ty) { + const auto pointer = [&]() -> spv_inst * { + if (memref_ty->dim() == 0) { + return val(in.operand()); + } + + auto idx0 = val(in.index_list()[0]); + auto offset = memref_ty->stride(0) != 1 + ? 
mod_->add(spv_index_ty, idx0, dv->stride(0)) + : idx0; + for (std::int64_t i = 1; i < memref_ty->dim(); ++i) { + auto tmp = mod_->add(spv_index_ty, val(in.index_list()[i]), dv->stride(i)); + offset = mod_->add(spv_index_ty, offset, tmp); + } + + return mod_->add(spv_pointer_ty, val(in.operand()), offset, + std::vector{}); + }; + + auto val_ty = memref_ty->element_ty(); + make_store(unique_, in.flag(), val_ty, memref_ty->addrspace(), pointer(), val(in.val()), + in.loc()); + } else { + throw compilation_error(in.loc(), status::ir_expected_memref); + } +} + +void inst_converter::operator()(subview_inst in) { + auto spv_index_ty = get_spv_index_ty(unique_, in.operand().context()); + auto spv_result_ty = spv_ty(in.result().ty()); + + auto shape_out = std::vector{}; + auto stride_out = std::vector{}; + auto const make_offset_and_shape_stride = [&] { + auto mt = get_memref_type(in.operand()); + auto dv = get_dope_vector(in.operand()); + if (!dv) { + throw compilation_error(in.loc(), status::spirv_missing_dope_vector); + } + + shape_out.reserve(mt->dim()); + stride_out.reserve(mt->dim()); + auto dyn_offsets = in.offsets(); + auto dyn_sizes = in.sizes(); + auto offset_acc = unique_.null_constant(spv_index_ty); + for (std::int64_t i = 0, joffset = 0, jsize = 0; i < mt->dim(); ++i) { + const std::int64_t offset = in.static_offsets()[i]; + + auto const offset_inst = [&]() -> spv_inst * { + if (is_dynamic_value(offset)) { + return val(dyn_offsets[joffset++]); + } + return unique_.constant(offset); + }; + auto tmp = mod_->add(spv_index_ty, offset_inst(), dv->stride(i)); + offset_acc = mod_->add(spv_index_ty, offset_acc, tmp); + + const std::int64_t size = in.static_sizes()[i]; + if (size > 0 || is_dynamic_value(size)) { + auto const size_inst = [&]() -> spv_inst * { + if (is_dynamic_value(size)) { + return val(dyn_sizes[jsize++]); + } + return unique_.constant(size); + }; + shape_out.emplace_back(size_inst()); + stride_out.emplace_back(dv->stride(i)); + } + } + return 
offset_acc; + }; + + auto offset = make_offset_and_shape_stride(); + declare(in.result(), mod_->add(spv_result_ty, val(in.operand()), + offset, std::vector{})); + + auto rdv = make_dope_vector(in.result()); + + if (shape_out.size() != static_cast(rdv->dim()) || + stride_out.size() != static_cast(rdv->dim())) { + throw compilation_error(in.loc(), status::internal_compiler_error); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->shape(i, shape_out[i]); + } + for (std::int64_t i = 0; i < rdv->dim(); ++i) { + rdv->stride(i, stride_out[i]); + } +} + +void inst_converter::operator()(yield_inst in) { + if (yielded_vals_.empty()) { + throw compilation_error(in.loc(), status::ir_unexpected_yield); + } + + auto &top = yielded_vals_.top(); + if (static_cast(top.size()) != in.yielded_vals().size()) { + throw compilation_error(in.loc(), status::ir_yield_mismatch); + } + + std::int64_t i = 0; + for (auto &op : in.yielded_vals()) { + top[i++] = val(op); + } +} + +void inst_converter::operator()(group_id_inst in) { + auto gid = unique_.load_builtin(BuiltIn::WorkgroupId); + const std::int32_t mode = + static_cast(in.mode()) - static_cast(comp3::x); + auto rty = spv_ty(in.result().ty()); + declare(in.result(), + mod_->add(rty, gid, std::vector{mode})); +} +void inst_converter::operator()(num_groups_inst in) { + auto ng = unique_.load_builtin(BuiltIn::NumWorkgroups); + const std::int32_t mode = + static_cast(in.mode()) - static_cast(comp3::x); + auto rty = spv_ty(in.result().ty()); + declare(in.result(), mod_->add(rty, ng, std::vector{mode})); +} +void inst_converter::operator()(num_subgroups_inst in) { + auto make_constant = [&](comp3 c) -> std::int32_t { + switch (c) { + case comp3::x: + return tiling_.m_tiles(); + case comp3::y: + return tiling_.n_tiles(); + default: + break; + } + return 1; + }; + auto cst = make_constant(in.mode()); + declare(in.result(), unique_.constant(cst)); +} +void inst_converter::operator()(subgroup_size_inst in) { + declare(in.result(), 
unique_.load_builtin(BuiltIn::SubgroupSize)); +} +void inst_converter::operator()(subgroup_id_inst in) { + auto make_value = [&](comp3 c) -> spv_inst * { + auto rty = spv_ty(in.result().ty()); + const auto m_tiles = unique_.constant(tiling_.m_tiles()); + const auto sgid = unique_.load_builtin(BuiltIn::SubgroupId); + switch (c) { + case comp3::x: + return mod_->add(rty, sgid, m_tiles); + case comp3::y: + return mod_->add(rty, sgid, m_tiles); + default: + break; + } + return unique_.constant(std::int32_t{0}); + }; + declare(in.result(), make_value(in.mode())); +} +void inst_converter::operator()(subgroup_linear_id_inst in) { + declare(in.result(), unique_.load_builtin(BuiltIn::SubgroupId)); +} +void inst_converter::operator()(subgroup_local_id_inst in) { + declare(in.result(), unique_.load_builtin(BuiltIn::SubgroupLocalInvocationId)); +} + +void inst_converter::run_on_region(tinytc_region ®) { + for (auto &i : reg) { + visit(*this, i); + } +} + +auto inst_converter::run_on_region_with_yield(tinytc_region ®, std::int64_t num_results) + -> std::vector { + yielded_vals_.push(std::vector(num_results, nullptr)); + run_on_region(reg); + auto yielded_vals = std::move(yielded_vals_.top()); + if (static_cast(yielded_vals.size()) != num_results || + std::any_of(yielded_vals.begin(), yielded_vals.end(), + [](spv_inst *in) { return in == nullptr; })) { + throw compilation_error(reg.loc(), status::ir_yield_mismatch); + } + yielded_vals_.pop(); + return yielded_vals; +} + +void inst_converter::run_on_function(tinytc_func &fn) { + try { + core_cfg_ = info_->get_core_config(fn.subgroup_size()); + } catch (std::out_of_range const &e) { + throw compilation_error(fn.loc(), status::unsupported_subgroup_size); + } + + auto vars_used_by_function = std::vector{}; + + // Stack + auto const make_stack = [&] { + const auto high_water_mark = stack_high_water_mark{}.run_on_function(fn); + if (high_water_mark > 0) { + auto stack_element_ty = unique_.int_ty(8); + auto stack_array_ty = 
unique_.array_ty(stack_element_ty, high_water_mark); + auto stack_ptr_ty = unique_.pointer_ty(StorageClass::Workgroup, stack_array_ty, 0); + stack_ = mod_->add_to(section::type_const_var, stack_ptr_ty, + StorageClass::Workgroup); + const std::int32_t alignment = info_->alignment(); + mod_->add_to(section::decoration, stack_, Decoration::Alignment, + DecorationAttr{alignment}); + vars_used_by_function.emplace_back(stack_); + } else { + stack_ = nullptr; + } + }; + make_stack(); + + // Function type + auto fun_ty = unique_.function_ty(unique_.void_ty(), [&] { + auto params = std::vector{}; + params.reserve(fn.num_params()); + for (auto const &p : fn.params()) { + params.emplace_back(spv_ty(p.ty())); + auto dv = make_dope_vector(p); + if (dv) { + for (std::int64_t i = 0; i < dv->num_dynamic(); ++i) { + params.emplace_back(dv->ty()); + } + if (is_dynamic_value(dv->static_size())) { + params.emplace_back(dv->size_ty()); + } + if (is_dynamic_value(dv->static_offset())) { + params.emplace_back(dv->offset_ty()); + } + } + } + return params; + }()); + + // Function + auto const subgroup_size = fn.subgroup_size(); + auto const work_group_size = fn.work_group_size(); + tiling_[0] = work_group_size[0] / subgroup_size; + tiling_[1] = work_group_size[1]; + + matrix_impl_ = [&]() -> std::unique_ptr { + const auto gcd = gcd_analysis{info_->alignment()}.run_on_function(fn); + if (info_->matrix().have_dpas()) { + return std::make_unique(unique_, core_cfg_, std::move(gcd)); + } else if (info_->have_spirv_feature(spirv_feature::subgroup_buffer_block_io)) { + return std::make_unique(unique_, core_cfg_, std::move(gcd)); + } + return std::make_unique(unique_, core_cfg_, std::move(gcd)); + }(); + + auto void_ty = unique_.void_ty(); + auto fun = mod_->add(void_ty, FunctionControl::None, fun_ty); + for (auto const &p : fn.params()) { + declare(p, mod_->add(spv_ty(p.ty()))); + auto dv = get_dope_vector(p); + if (dv) { + auto const make_dope_par = [&](spv_inst *ty, std::int64_t s) { + return 
is_dynamic_value(s) ? mod_->add(ty) + : unique_.constant(s); + }; + for (std::int64_t i = 0; i < dv->dim(); ++i) { + dv->shape(i, make_dope_par(dv->ty(), dv->static_shape(i))); + } + for (std::int64_t i = 0; i < dv->dim(); ++i) { + dv->stride(i, make_dope_par(dv->ty(), dv->static_stride(i))); + } + if (dv->size_ty()) { + dv->size(make_dope_par(dv->size_ty(), dv->static_size())); + } + if (dv->offset_ty()) { + dv->offset(make_dope_par(dv->offset_ty(), dv->static_offset())); + } + } + } + + mod_->add(); + auto func_begin = --mod_->insts(section::function).end(); + run_on_region(fn.body()); + + auto func_end = mod_->insts(section::function).end(); + for (auto it = func_begin; it != func_end; ++it) { + if (auto ld = dyn_cast(it.get()); ld && isa(*ld->op0())) { + if (std::find(vars_used_by_function.begin(), vars_used_by_function.end(), ld->op0()) == + vars_used_by_function.end()) { + vars_used_by_function.push_back(ld->op0()); + } + } + } + + mod_->add(); + mod_->add(); + + tiling_ = {}; + + // Entry point + mod_->add_to(section::entry_point, ExecutionModel::Kernel, fun, + std::string{fn.name()}, std::move(vars_used_by_function)); + + // Execution mode + mod_->add_to( + section::execution_mode, fun, ExecutionMode::LocalSize, + ExecutionModeAttr{std::array{work_group_size[0], work_group_size[1], 1}}); + mod_->add_to(section::execution_mode, fun, ExecutionMode::SubgroupSize, + ExecutionModeAttr{subgroup_size}); +} + +} // namespace tinytc::spv + diff --git a/src/spv/converter.hpp b/src/spv/converter.hpp new file mode 100644 index 00000000..b08ab660 --- /dev/null +++ b/src/spv/converter.hpp @@ -0,0 +1,101 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONVERTER_20241111_HPP +#define CONVERTER_20241111_HPP + +#include "device_info.hpp" +#include "node/inst_view.hpp" +#include "spv/coopmatrix_impl.hpp" +#include "spv/defs.hpp" +#include "spv/dope_vector.hpp" +#include "spv/uniquifier.hpp" +#include "tiling.hpp" +#include 
"tinytc/types.h" +#include "tinytc/types.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto convert_prog_to_spirv(tinytc_prog &p, tinytc_core_info const &info) + -> shared_handle; + +class inst_converter { + public: + inst_converter(tinytc_spv_mod &m, tinytc_core_info const &info); + + // Instruction nodes + void operator()(inst_view in); + void operator()(alloca_inst in); + void operator()(arith_inst in); + void operator()(arith_unary_inst in); + void operator()(barrier_inst in); + void operator()(cast_inst in); + void operator()(compare_inst in); + void operator()(constant_inst in); + void operator()(cooperative_matrix_extract_inst in); + void operator()(cooperative_matrix_insert_inst in); + void operator()(cooperative_matrix_load_inst in); + void operator()(cooperative_matrix_mul_add_inst in); + void operator()(cooperative_matrix_prefetch_inst in); + void operator()(cooperative_matrix_reduce_inst in); + void operator()(cooperative_matrix_scale_inst in); + void operator()(cooperative_matrix_store_inst in); + void operator()(expand_inst in); + void operator()(for_inst in); + void operator()(fuse_inst in); + void operator()(if_inst in); + void operator()(lifetime_stop_inst in); + void operator()(load_inst in); + void operator()(math_unary_inst in); + void operator()(parallel_inst in); + void operator()(size_inst in); + void operator()(subgroup_broadcast_inst in); + void operator()(subgroup_operation_inst in); + void operator()(store_inst in); + void operator()(subview_inst in); + void operator()(yield_inst in); + void operator()(group_id_inst in); + void operator()(num_groups_inst in); + void operator()(num_subgroups_inst in); + void operator()(subgroup_size_inst in); + void operator()(subgroup_id_inst in); + void operator()(subgroup_linear_id_inst in); + void operator()(subgroup_local_id_inst in); + + void run_on_region(tinytc_region ®); + auto run_on_region_with_yield(tinytc_region ®, std::int64_t num_results) + -> 
std::vector; + void run_on_function(tinytc_func &fn); + + inline auto unique() -> uniquifier & { return unique_; } + + private: + auto get_dope_vector(tinytc_value const &v) -> dope_vector *; + auto declare(tinytc_value const &v, spv_inst *in); + auto val(tinytc_value const &v) -> spv_inst *; + auto spv_ty(tinytc_type_t ty) -> spv_inst *; + auto make_dope_vector(tinytc_value const &v) -> dope_vector *; + auto matrix_impl() -> coopmatrix_impl &; + auto make_matrix_impl() -> std::unique_ptr; + + tinytc_spv_mod_t mod_; + tinytc_core_info const *info_; + uniquifier unique_; + std::unique_ptr matrix_impl_ = nullptr; + std::unordered_map dope_vec_; + std::unordered_map vals_; + std::stack> yielded_vals_; + spv_inst *stack_ = nullptr; + core_config core_cfg_ = {}; + local_tiling tiling_ = {}; +}; + +} // namespace tinytc::spv + +#endif // CONVERTER_20241111_HPP diff --git a/src/spv/converter_aux.cpp b/src/spv/converter_aux.cpp new file mode 100644 index 00000000..3bf7953a --- /dev/null +++ b/src/spv/converter_aux.cpp @@ -0,0 +1,878 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/converter_aux.hpp" +#include "compiler_context.hpp" +#include "error.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include "number.hpp" +#include "number_dispatch.hpp" +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/opencl.std.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto get_spv_index_ty(uniquifier &unique, tinytc_compiler_context_t ctx) -> spv_inst * { + return unique.int_ty(ctx->index_bit_width()); +} +auto get_spv_ty(uniquifier &unique, memref_type const *ty) -> spv_inst 
* { + const auto storage_cls = address_space_to_storage_class(ty->addrspace()); + auto pointee_ty = get_spv_ty_non_coopmatrix(unique, ty->element_ty()); + const auto align = alignment(ty->element_ty()); + return unique.pointer_ty(storage_cls, pointee_ty, align); +} +auto get_spv_pointer_index_ty(uniquifier &unique, tinytc_compiler_context_t ctx, + address_space addrspace) -> spv_inst * { + auto index_ty = index_type::get(ctx); + auto memref_ty = + memref_type::get(index_ty, array_view{dynamic}, array_view{std::int64_t{1}}, addrspace); + return get_spv_ty(unique, dyn_cast(memref_ty)); +} + +auto get_spv_ty_non_coopmatrix(uniquifier &unique, tinytc_type_t ty) -> spv_inst * { + return visit(overloaded{[&](boolean_type &) -> spv_inst * { return unique.bool_ty(); }, + [&](i8_type &) -> spv_inst * { return unique.int_ty(8); }, + [&](i16_type &) -> spv_inst * { return unique.int_ty(16); }, + [&](i32_type &) -> spv_inst * { return unique.int_ty(32); }, + [&](i64_type &) -> spv_inst * { return unique.int_ty(64); }, + [&](index_type &ty) -> spv_inst * { + return get_spv_index_ty(unique, ty.context()); + }, + [&](bf16_type &) -> spv_inst * { return unique.int_ty(16); }, + [&](f16_type &) -> spv_inst * { return unique.float_ty(16); }, + [&](f32_type &) -> spv_inst * { return unique.float_ty(32); }, + [&](f64_type &) -> spv_inst * { return unique.float_ty(64); }, + [&](c32_type &) -> spv_inst * { + return unique.vec_ty(unique.float_ty(32), vector_size::v2); + }, + [&](c64_type &) -> spv_inst * { + return unique.vec_ty(unique.float_ty(64), vector_size::v2); + }, + [&](group_type &ty) -> spv_inst * { + auto pointee_ty = + get_spv_ty_non_coopmatrix(unique, ty.element_ty()); + return unique.pointer_ty(StorageClass::CrossWorkgroup, pointee_ty, + ty.context()->index_bit_width() / 8); + }, + [&](memref_type &ty) -> spv_inst * { return get_spv_ty(unique, &ty); }, + [&](void_type &) -> spv_inst * { return unique.void_ty(); }, + [](tinytc_type &) -> spv_inst * { + // Coopmatrix handled 
by matrix impl class + throw status::not_implemented; + }}, + *ty); +} + +auto get_last_label(tinytc_spv_mod &mod) -> spv_inst * { + auto &insts = mod.insts(section::function); + auto it = insts.end(); + while (it != insts.begin()) { + auto in = (--it).get(); + if (isa(*in)) { + return in; + } + } + return nullptr; +} + +auto make_binary_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, spv_inst *b, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const make_boolean = [&](IK op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { + switch (op) { + case IK::IK_and: + return mod.add(ty, a, b); + case IK::IK_or: + return mod.add(ty, a, b); + case IK::IK_xor: + return mod.add(ty, a, b); + default: + break; + } + throw compilation_error(loc, status::ir_boolean_unsupported); + }; + auto const make_int = [&](IK op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { + switch (op) { + case IK::IK_add: + return mod.add(ty, a, b); + case IK::IK_sub: + return mod.add(ty, a, b); + case IK::IK_mul: + return mod.add(ty, a, b); + case IK::IK_div: + return mod.add(ty, a, b); + case IK::IK_rem: + return mod.add(ty, a, b); + case IK::IK_shl: + return mod.add(ty, a, b); + case IK::IK_shr: + return mod.add(ty, a, b); + case IK::IK_and: + return mod.add(ty, a, b); + case IK::IK_or: + return mod.add(ty, a, b); + case IK::IK_xor: + return mod.add(ty, a, b); + case IK::IK_min: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::s_min), + std::vector{a, b}); + case IK::IK_max: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::s_max), + std::vector{a, b}); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_float = [&](IK op, spv_inst *ty, spv_inst *a, spv_inst *b) -> spv_inst * { + switch (op) { + case IK::IK_add: + return mod.add(ty, a, b); + case IK::IK_sub: + return mod.add(ty, a, b); + case IK::IK_mul: + return mod.add(ty, a, b); 
+ case IK::IK_div: + return mod.add(ty, a, b); + case IK::IK_rem: + return mod.add(ty, a, b); + case IK::IK_min: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::fmin), + std::vector{a, b}); + case IK::IK_max: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::fmax), + std::vector{a, b}); + default: + break; + } + throw compilation_error(loc, status::ir_fp_unsupported); + }; + auto const make_complex = [&](IK op, spv_inst *ty, spv_inst *float_ty, spv_inst *a, + spv_inst *b) -> spv_inst * { + switch (op) { + case IK::IK_add: + return mod.add(ty, a, b); + case IK::IK_sub: + return mod.add(ty, a, b); + case IK::IK_mul: { + return make_complex_mul(unique, ty, a, b); + } + case IK::IK_div: { + auto a_times_conj_b = make_complex_mul(unique, ty, a, b, true); + + auto b_squared = mod.add(ty, b, b); + auto b_squared_0 = + mod.add(float_ty, b_squared, std::vector{0}); + auto b_squared_1 = + mod.add(float_ty, b_squared, std::vector{1}); + spv_inst *b_abs = mod.add(float_ty, b_squared_0, b_squared_1); + auto dummy = mod.add(ty); + b_abs = mod.add(ty, b_abs, dummy, std::vector{0}); + b_abs = mod.add(ty, b_abs, dummy, std::vector{0, 0}); + return mod.add(ty, a_times_conj_b, b_abs); + } + default: + break; + } + throw compilation_error(loc, status::ir_complex_unsupported); + }; + auto ty = get_spv_ty_non_coopmatrix(unique, operand_ty); + auto binop = + visit(overloaded{[&](boolean_type &) -> spv_inst * { return make_boolean(op, ty, a, b); }, + [&](integer_type &) -> spv_inst * { return make_int(op, ty, a, b); }, + [&](bf16_type &) -> spv_inst * { + auto float_ty = unique.float_ty(32); + auto af = mod.add(float_ty, a); + auto bf = mod.add(float_ty, b); + auto af_op_bf = make_float(op, float_ty, af, bf); + return mod.add(ty, af_op_bf); + }, + [&](float_type &) -> spv_inst * { return make_float(op, ty, a, b); }, + [&](complex_type &) -> spv_inst * { + auto float_ty = + get_spv_ty_non_coopmatrix(unique, component_type(operand_ty)); + 
return make_complex(op, ty, float_ty, a, b); + }, + [](tinytc_type &) -> spv_inst * { return nullptr; }}, + *operand_ty); + if (!binop) { + throw compilation_error(loc, status::internal_compiler_error); + } + return binop; +} + +auto make_binary_op_mixed_precision(uniquifier &unique, tinytc_type_t result_ty, IK op, + tinytc_type_t a_ty, spv_inst *a, tinytc_type_t b_ty, + spv_inst *b, location const &loc) -> spv_inst * { + if (!promotable(a_ty, result_ty) || !promotable(b_ty, result_ty)) { + throw compilation_error(loc, status::ir_forbidden_promotion); + } + if (a_ty != result_ty) { + a = make_cast(unique, result_ty, a_ty, a, loc); + } + if (b_ty != result_ty) { + b = make_cast(unique, result_ty, b_ty, b, loc); + } + return make_binary_op(unique, result_ty, op, a, b, loc); +} + +auto make_cast(uniquifier &unique, tinytc_type_t to_ty, tinytc_type_t a_ty, spv_inst *a, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const cast_from_int = [&](tinytc_type_t to_ty, spv_inst *spv_to_ty, + spv_inst *a) -> spv_inst * { + auto op = visit( + overloaded{ + [&](integer_type &) -> spv_inst * { return mod.add(spv_to_ty, a); }, + [&](bf16_type &) -> spv_inst * { + auto float_ty = unique.float_ty(32); + auto af = mod.add(float_ty, a); + return mod.add(spv_to_ty, af); + }, + [&](float_type &) -> spv_inst * { return mod.add(spv_to_ty, a); }, + [&](complex_type &) -> spv_inst * { + auto cty = component_type(to_ty); + auto spv_float_ty = get_spv_ty_non_coopmatrix(unique, cty); + auto re = mod.add(spv_float_ty, a); + return mod.add(spv_to_ty, re, + unique.null_constant(spv_to_ty), + std::vector{0}); + }, + [](tinytc_type &) -> spv_inst * { return nullptr; }}, + *to_ty); + if (!op) { + throw compilation_error(loc, status::ir_forbidden_cast); + } + return op; + }; + auto const cast_from_float = [&](tinytc_type_t to_ty, spv_inst *spv_to_ty, tinytc_type_t a_ty, + spv_inst *a) -> spv_inst * { + auto op = visit( + overloaded{ + [&](integer_type &) -> spv_inst * { 
return mod.add(spv_to_ty, a); },
+                [&](bf16_type &) -> spv_inst * {
+                    return mod.add(spv_to_ty, a);
+                },
+                [&](float_type &) -> spv_inst * { return mod.add(spv_to_ty, a); },
+                [&](complex_type &) -> spv_inst * {
+                    auto re = a;
+                    auto cty = component_type(to_ty);
+                    if (cty != a_ty) {
+                        auto spv_float_ty = get_spv_ty_non_coopmatrix(unique, cty);
+                        re = mod.add(spv_float_ty, a);
+                    }
+                    // If the line below is changed, adjust make_complex_mul as well
+                    return mod.add(spv_to_ty, re,
+                                   unique.null_constant(spv_to_ty),
+                                   std::vector{0});
+                },
+                [](tinytc_type &) -> spv_inst * { return nullptr; }},
+            *to_ty);
+        if (!op) {
+            throw compilation_error(loc, status::ir_forbidden_cast);
+        }
+        return op;
+    };
+    auto const cast_from_complex = [&](tinytc_type_t to_ty, spv_inst *spv_to_ty,
+                                       spv_inst *a) -> spv_inst * {
+        if (isa(*to_ty)) {
+            return mod.add(spv_to_ty, a);
+        }
+        throw compilation_error(loc, status::ir_forbidden_cast);
+    };
+
+    auto spv_to_ty = get_spv_ty_non_coopmatrix(unique, to_ty);
+    if (a_ty == to_ty) {
+        return mod.add(spv_to_ty, a);
+    }
+
+    auto castop =
+        visit(overloaded{[&](integer_type &) { return cast_from_int(to_ty, spv_to_ty, a); },
+                         [&](bf16_type &) {
+                             auto float_ty = unique.float_ty(32);
+                             auto af = mod.add(float_ty, a);
+                             return cast_from_float(to_ty, spv_to_ty,
+                                                    f32_type::get(to_ty->context()), af);
+                         },
+                         [&](float_type &) { return cast_from_float(to_ty, spv_to_ty, a_ty, a); },
+                         [&](complex_type &) { return cast_from_complex(to_ty, spv_to_ty, a); },
+                         [](tinytc_type &) -> spv_inst * { return nullptr; }},
+              *a_ty);
+    if (!castop) {
+        throw compilation_error(loc, status::internal_compiler_error);
+    }
+    return castop;
+}
+
+// Emits the complex product a * b, or a * conj(b) if conj_b is true.
+//
+// Complex values are 2-component vectors of type `ty`; component 0 holds the real part and
+// component 1 the imaginary part.
+//
+// If one operand is recognised as the result of a non-complex -> complex cast (imaginary part
+// known to be zero), the product reduces to a component-wise multiplication of the other
+// operand with the broadcast real part. For operand `a` this shortcut yields a * b and not
+// a * conj(b), therefore it is only taken when conj_b is false. For operand `b` the shortcut
+// is always valid as conj(b) == b when the imaginary part of b is zero.
+auto make_complex_mul(uniquifier &unique, spv_inst *ty, spv_inst *a, spv_inst *b, bool conj_b)
+    -> spv_inst * {
+    auto &mod = unique.mod();
+    const auto is_imag_zero = [&](spv_inst *a) -> bool {
+        // We capture the case here if "a" stems from a non-complex -> complex cast
+        if (auto ci = dyn_cast(a); ci) {
+            return ci->type() == ty && ci->op1() == unique.null_constant(ty) &&
+                   ci->op2().size() == 1 && ci->op2()[0] == 0;
+        }
+        return false;
+    };
+
+    // Must not be taken for conj_b == true: re(a) * b would drop the conjugation of b.
+    // The general path below handles that case correctly.
+    if (!conj_b && is_imag_zero(a)) {
+        a = mod.add(ty, a, a, std::vector{0, 0});
+        return mod.add(ty, a, b);
+    } else if (is_imag_zero(b)) {
+        b = mod.add(ty, b, b, std::vector{0, 0});
+        return mod.add(ty, a, b);
+    }
+
+    // General path: result = a * b_0 + b_1 * (i * a), or b_1 * (-i * a) for the conjugate
+    auto neg_a = mod.add(ty, a);
+    auto a_times_i =
+        conj_b ? mod.add(ty, a, neg_a, std::vector{1, 2})
+               : mod.add(ty, neg_a, a, std::vector{1, 2});
+    auto b_1 = mod.add(ty, b, b, std::vector{1, 1});
+    auto b_1_a_times_i = mod.add(ty, b_1, a_times_i);
+    auto b_0 = mod.add(ty, b, b, std::vector{0, 0});
+    return mod.add(ty, unique.opencl_ext(),
+                   static_cast(OpenCLEntrypoint::fma),
+                   std::vector{a, b_0, b_1_a_times_i});
+}
+
+auto make_compare_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, spv_inst *b,
+                     location const &loc) -> spv_inst * {
+    auto &mod = unique.mod();
+    auto bool_ty = unique.bool_ty();
+    auto const compare_int = [&](IK cond, spv_inst *a, spv_inst *b) -> spv_inst * {
+        switch (cond) {
+        case IK::IK_equal:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_not_equal:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_greater_than:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_greater_than_equal:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_less_than:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_less_than_equal:
+            return mod.add(bool_ty, a, b);
+        default:
+            break;
+        }
+        throw compilation_error(loc, status::internal_compiler_error);
+    };
+    auto const compare_float = [&](IK cond, spv_inst *a, spv_inst *b) -> spv_inst * {
+        switch (cond) {
+        case IK::IK_equal:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_not_equal:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_greater_than:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_greater_than_equal:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_less_than:
+            return mod.add(bool_ty, a, b);
+        case IK::IK_less_than_equal:
+            return mod.add(bool_ty, a, b);
+        default:
+            break;
+        }
+ throw compilation_error(loc, status::internal_compiler_error); + }; + auto const compare_complex = [&](IK cond, spv_inst *a, spv_inst *b) -> spv_inst * { + auto bool2_ty = unique.vec_ty(bool_ty, vector_size::v2); + switch (cond) { + case IK::IK_equal: { + auto components_equal = mod.add(bool2_ty, a, b); + return mod.add(bool_ty, components_equal); + } + case IK::IK_not_equal: { + auto components_not_equal = mod.add(bool2_ty, a, b); + return mod.add(bool_ty, components_not_equal); + } + default: + throw compilation_error(loc, status::ir_complex_unsupported); + } + }; + auto cmpop = visit(overloaded{[&](integer_type &) { return compare_int(op, a, b); }, + [&](bf16_type &) { + auto float_ty = unique.float_ty(32); + auto af = mod.add(float_ty, a); + auto bf = mod.add(float_ty, b); + return compare_float(op, af, bf); + }, + [&](float_type &) { return compare_float(op, a, b); }, + [&](complex_type &) { return compare_complex(op, a, b); }, + [](tinytc_type &) -> spv_inst * { return nullptr; }}, + *operand_ty); + if (!cmpop) { + throw compilation_error(loc, status::internal_compiler_error); + } + return cmpop; +} + +auto make_constant(uniquifier &unique, tinytc_type_t ty, constant_value_type const &val) + -> spv_inst * { + const auto visitor = overloaded{ + [&](bool b) -> spv_inst * { + if (!isa(*ty)) { + throw status::ir_expected_boolean; + } + return unique.bool_constant(b); + }, + [&](std::int64_t i) -> spv_inst * { + return dispatch_int_to_native( + ty, [&]() { return unique.constant(static_cast(i)); }); + }, + [&](double d) -> spv_inst * { + return dispatch_float_to_native(ty, [&]() { + if constexpr (std::is_same_v) { + return unique.constant(bfloat16{static_cast(d)}.bits()); + } else if constexpr (std::is_same_v) { + return unique.constant(half{static_cast(d)}); + } else { + return unique.constant(static_cast(d)); + } + }); + }, + [&](std::complex c) -> spv_inst * { + return dispatch_complex_to_native(ty, [&]() { + auto cst = static_cast(c); + auto c_re = 
unique.constant(cst.real()); + auto c_im = unique.constant(cst.imag()); + auto cst_ty = get_spv_ty_non_coopmatrix(unique, ty); + return unique.mod().add_to( + section::type_const_var, cst_ty, std::vector{c_re, c_im}); + }); + }, + }; + auto cst = std::visit(visitor, val); + if (!cst) { + throw status::internal_compiler_error; + } + return cst; +} + +void make_conditional_execution(uniquifier &unique, spv_inst *condition, + std::function then) { + auto then_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto &mod = unique.mod(); + mod.add(merge_label.get(), SelectionControl::None); + mod.add(condition, then_label.get(), merge_label.get(), + std::vector{}); + mod.insts(section::function).push_back(then_label.release()); + then(mod); + mod.add(merge_label.get()); + mod.insts(section::function).push_back(merge_label.release()); +} + +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst *condition, + std::function then, + spv_inst *otherwise, location const &loc) -> spv_inst * { + auto then_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto &mod = unique.mod(); + auto init_last_label = get_last_label(mod); + if (!init_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + + mod.add(merge_label.get(), SelectionControl::None); + mod.add(condition, then_label.get(), merge_label.get(), + std::vector{}); + mod.insts(section::function).push_back(then_label.release()); + spv_inst *yielded_then = then(mod); + mod.add(merge_label.get()); + auto then_last_label = get_last_label(mod); + if (!then_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + + mod.insts(section::function).push_back(merge_label.release()); + return mod.add(return_ty, + std::vector{PairIdRefIdRef{yielded_then, then_last_label}, + PairIdRefIdRef{otherwise, init_last_label}}); +} + +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst 
*condition, + std::function then, + std::function otherwise, + location const &loc) -> spv_inst * { + auto then_label = std::make_unique(); + auto otherwise_label = std::make_unique(); + auto merge_label = std::make_unique(); + + auto &mod = unique.mod(); + mod.add(merge_label.get(), SelectionControl::None); + mod.add(condition, then_label.get(), otherwise_label.get(), + std::vector{}); + mod.insts(section::function).push_back(then_label.release()); + spv_inst *yielded_then = then(mod); + mod.add(merge_label.get()); + auto then_last_label = get_last_label(mod); + if (!then_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + mod.insts(section::function).push_back(otherwise_label.release()); + spv_inst *yielded_otherwise = otherwise(mod); + mod.add(merge_label.get()); + auto otherwise_last_label = get_last_label(mod); + if (!otherwise_last_label) { + throw compilation_error(loc, status::internal_compiler_error); + } + + mod.insts(section::function).push_back(merge_label.release()); + return mod.add(return_ty, std::vector{ + PairIdRefIdRef{yielded_then, then_last_label}, + PairIdRefIdRef{yielded_otherwise, otherwise_last_label}}); +} + +auto make_math_unary_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const make_float = [&](IK op, spv_inst *ty, spv_inst *a) -> spv_inst * { + auto const make_ext_inst = [&](OpenCLEntrypoint ep) { + return mod.add(ty, unique.opencl_ext(), static_cast(ep), + std::vector{a}); + }; + switch (op) { + case IK::IK_cos: + return make_ext_inst(OpenCLEntrypoint::cos); + case IK::IK_sin: + return make_ext_inst(OpenCLEntrypoint::sin); + case IK::IK_exp: + return make_ext_inst(OpenCLEntrypoint::exp); + case IK::IK_exp2: + return make_ext_inst(OpenCLEntrypoint::exp2); + case IK::IK_native_cos: + return make_ext_inst(OpenCLEntrypoint::native_cos); + case IK::IK_native_sin: + return 
make_ext_inst(OpenCLEntrypoint::native_sin); + case IK::IK_native_exp: + return make_ext_inst(OpenCLEntrypoint::native_exp); + case IK::IK_native_exp2: + return make_ext_inst(OpenCLEntrypoint::native_exp2); + default: + throw compilation_error(loc, status::internal_compiler_error); + } + }; + auto const make_complex = [&](IK op, tinytc_type_t operand_ty, spv_inst *ty, + spv_inst *a) -> spv_inst * { + auto cty = component_type(operand_ty); + auto float_ty = get_spv_ty_non_coopmatrix(unique, cty); + auto const make_complex_exp = [&](auto exp_ep, auto cos_ep, auto sin_ep, + spv_inst *im_scale = nullptr) { + auto a0 = mod.add(float_ty, a, std::vector{0}); + spv_inst *a1 = mod.add(float_ty, a, std::vector{1}); + if (im_scale) { + a1 = mod.add(float_ty, a1, im_scale); + } + auto e = mod.add(float_ty, unique.opencl_ext(), + static_cast(exp_ep), std::vector{a0}); + auto c = mod.add(float_ty, unique.opencl_ext(), + static_cast(cos_ep), std::vector{a1}); + auto s = mod.add(float_ty, unique.opencl_ext(), + static_cast(sin_ep), std::vector{a1}); + auto r = mod.add(float_ty, e, c); + auto i = mod.add(float_ty, e, s); + auto dummy = mod.add(ty); + auto result = mod.add(ty, r, dummy, std::vector{0}); + return mod.add(ty, i, result, std::vector{1}); + }; + switch (op) { + case IK::IK_exp: + return make_complex_exp(OpenCLEntrypoint::exp, OpenCLEntrypoint::cos, + OpenCLEntrypoint::sin); + case IK::IK_exp2: + return make_complex_exp(OpenCLEntrypoint::exp2, OpenCLEntrypoint::cos, + OpenCLEntrypoint::sin, unique.constant(std::log(T{2}))); + case IK::IK_native_exp: + return make_complex_exp(OpenCLEntrypoint::native_exp, OpenCLEntrypoint::native_cos, + OpenCLEntrypoint::native_sin); + case IK::IK_native_exp2: + return make_complex_exp(OpenCLEntrypoint::native_exp2, OpenCLEntrypoint::native_cos, + OpenCLEntrypoint::native_sin, unique.constant(std::log(T{2}))); + default: + throw compilation_error(loc, status::internal_compiler_error); + } + }; + + auto ty = 
get_spv_ty_non_coopmatrix(unique, operand_ty); + auto unop = + visit(overloaded{[&](bf16_type &) -> spv_inst * { + auto float_ty = unique.float_ty(32); + auto af = mod.add(float_ty, a); + auto op_af = make_float(op, float_ty, af); + return mod.add(ty, op_af); + }, + [&](float_type &) -> spv_inst * { return make_float(op, ty, a); }, + [&](c32_type &) -> spv_inst * { + return make_complex.template operator()(op, operand_ty, ty, a); + }, + [&](c64_type &) -> spv_inst * { + return make_complex.template operator()(op, operand_ty, ty, a); + }, + [](tinytc_type &) -> spv_inst * { return nullptr; }}, + *operand_ty); + if (!unop) { + throw compilation_error(loc, status::internal_compiler_error); + } + return unop; +} + +void make_store(uniquifier &unique, store_flag flag, tinytc_type_t val_ty, address_space as, + spv_inst *pointer, spv_inst *value, location const &loc) { + auto &mod = unique.mod(); + auto const split_re_im = [&]() -> std::array, 2u> { + auto component_sty = component_type(val_ty); + auto float_ty = get_spv_ty_non_coopmatrix(unique, component_sty); + const auto storage_cls = address_space_to_storage_class(as); + auto pointer_ty = unique.pointer_ty(storage_cls, float_ty, alignment(component_sty)); + auto c0 = unique.constant(std::int32_t{0}); + auto c1 = unique.constant(std::int32_t{1}); + auto re_ptr = mod.add(pointer_ty, pointer, std::vector{c0}); + auto im_ptr = mod.add(pointer_ty, pointer, std::vector{c1}); + auto re_val = mod.add(float_ty, value, std::vector{0}); + auto im_val = mod.add(float_ty, value, std::vector{1}); + return {{{re_ptr, re_val}, {im_ptr, im_val}}}; + }; + auto const make_atomic_something = [&]() { + auto result_ty = get_spv_ty_non_coopmatrix(unique, val_ty); + auto scope = unique.constant(static_cast(Scope::Workgroup)); + auto semantics = unique.constant(static_cast(MemorySemantics::Relaxed)); + if (isa(*val_ty)) { + auto re_im = split_re_im(); + auto component_sty = component_type(val_ty); + auto float_ty = 
get_spv_ty_non_coopmatrix(unique, component_sty); + mod.add(float_ty, re_im[0][0], scope, semantics, re_im[0][1]); + mod.add(float_ty, re_im[1][0], scope, semantics, re_im[1][1]); + } else if (isa(*val_ty)) { + mod.add(result_ty, pointer, scope, semantics, value); + } else if (isa(*val_ty)) { + mod.add(result_ty, pointer, scope, semantics, value); + } else { + throw compilation_error(loc, status::spirv_unsupported_atomic_data_type); + } + }; + + if (flag != store_flag::regular && + (isa(*val_ty) || isa(*val_ty) || isa(*val_ty))) { + throw compilation_error(loc, status::spirv_unsupported_atomic_data_type); + } + + switch (flag) { + case store_flag::regular: + mod.add(pointer, value); + break; + case store_flag::atomic: { + auto scope = unique.constant(static_cast(Scope::Workgroup)); + auto semantics = unique.constant(static_cast(MemorySemantics::Relaxed)); + if (isa(*val_ty)) { + auto re_im = split_re_im(); + mod.add(re_im[0][0], scope, semantics, re_im[0][1]); + mod.add(re_im[1][0], scope, semantics, re_im[1][1]); + } else { + mod.add(pointer, scope, semantics, value); + } + break; + } + case store_flag::atomic_add: + make_atomic_something.template operator()(); + break; + case store_flag::atomic_max: + make_atomic_something.template operator()(); + break; + case store_flag::atomic_min: + make_atomic_something.template operator()(); + break; + } +} + +auto make_unary_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const make_boolean = [&](IK op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case IK::IK_not: + return mod.add(ty, a); + default: + break; + } + throw compilation_error(loc, status::ir_boolean_unsupported); + }; + auto const make_int = [&](IK op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case IK::IK_abs: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::s_abs), + std::vector{a}); + case IK::IK_neg: + 
return mod.add(ty, a); + case IK::IK_not: + return mod.add(ty, a); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_float = [&](IK op, spv_inst *ty, spv_inst *a) -> spv_inst * { + switch (op) { + case IK::IK_abs: + return mod.add(ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::fabs), + std::vector{a}); + case IK::IK_neg: + return mod.add(ty, a); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); + }; + auto const make_complex = [&](IK op, spv_inst *ty, spv_inst *float_ty, + spv_inst *a) -> spv_inst * { + switch (op) { + case IK::IK_abs: { + auto a2 = mod.add(ty, a, a); + auto a2_0 = mod.add(float_ty, a2, std::vector{0}); + auto a2_1 = mod.add(float_ty, a2, std::vector{1}); + auto a2_0p1 = mod.add(float_ty, a2_0, a2_1); + return mod.add(float_ty, unique.opencl_ext(), + static_cast(OpenCLEntrypoint::sqrt), + std::vector{a2_0p1}); + } + case IK::IK_neg: + return mod.add(ty, a); + case IK::IK_conj: { + auto a_im = mod.add(float_ty, a, std::vector{1}); + auto neg_a_im = mod.add(float_ty, a_im); + return mod.add(ty, neg_a_im, a, std::vector{1}); + } + case IK::IK_im: + return mod.add(float_ty, a, std::vector{1}); + case IK::IK_re: + return mod.add(float_ty, a, std::vector{0}); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); + }; + + auto ty = get_spv_ty_non_coopmatrix(unique, operand_ty); + auto unop = + visit(overloaded{[&](boolean_type &) -> spv_inst * { return make_boolean(op, ty, a); }, + [&](integer_type &) -> spv_inst * { return make_int(op, ty, a); }, + [&](bf16_type &) -> spv_inst * { + auto float_ty = unique.float_ty(32); + auto af = mod.add(float_ty, a); + auto op_af = make_float(op, float_ty, af); + return mod.add(ty, op_af); + }, + [&](float_type &) -> spv_inst * { return make_float(op, ty, a); }, + [&](complex_type &) -> spv_inst * { + auto float_ty = + get_spv_ty_non_coopmatrix(unique, 
component_type(operand_ty)); + return make_complex(op, ty, float_ty, a); + }, + [](tinytc_type &) -> spv_inst * { return nullptr; }}, + *operand_ty); + if (!unop) { + throw compilation_error(loc, status::internal_compiler_error); + } + return unop; +} + +auto make_subgroup_op(uniquifier &unique, tinytc_type_t op_ty, IK op, spv_inst *a, + location const &loc) -> spv_inst * { + auto &mod = unique.mod(); + auto const make_impl = [&](tinytc_type_t op_ty, GroupOperation group_op, + spv_inst *ty, spv_inst *a) -> spv_inst * { + auto scope = unique.constant(static_cast(Scope::Subgroup)); + auto unop = visit(overloaded{[&](integer_type &) -> spv_inst * { + return mod.add(ty, scope, group_op, a); + }, + [&](bf16_type &) -> spv_inst * { + auto float_ty = unique.float_ty(32); + auto af = mod.add(float_ty, a); + auto op_af = + mod.add(ty, scope, group_op, af); + return mod.add(ty, op_af); + }, + [&](float_type &) -> spv_inst * { + return mod.add(ty, scope, group_op, a); + }, + [&](complex_type &) -> spv_inst * { + return mod.add(ty, scope, group_op, a); + }, + [](tinytc_type &) -> spv_inst * { return nullptr; }}, + *op_ty); + if (!unop) { + throw compilation_error(loc, status::internal_compiler_error); + } + return unop; + }; + auto ty = get_spv_ty_non_coopmatrix(unique, op_ty); + struct add_ops { + using i = OpGroupIAdd; + using f = OpGroupFAdd; + }; + struct max_ops { + using i = OpGroupSMax; + using f = OpGroupFMax; + }; + struct min_ops { + using i = OpGroupSMin; + using f = OpGroupFMin; + }; + switch (op) { + case IK::IK_subgroup_exclusive_scan_add: + return make_impl.template operator()(op_ty, GroupOperation::ExclusiveScan, ty, a); + case IK::IK_subgroup_exclusive_scan_max: + return make_impl.template operator()(op_ty, GroupOperation::ExclusiveScan, ty, a); + case IK::IK_subgroup_exclusive_scan_min: + return make_impl.template operator()(op_ty, GroupOperation::ExclusiveScan, ty, a); + case IK::IK_subgroup_inclusive_scan_add: + return make_impl.template operator()(op_ty, 
GroupOperation::InclusiveScan, ty, a); + case IK::IK_subgroup_inclusive_scan_max: + return make_impl.template operator()(op_ty, GroupOperation::InclusiveScan, ty, a); + case IK::IK_subgroup_inclusive_scan_min: + return make_impl.template operator()(op_ty, GroupOperation::InclusiveScan, ty, a); + case IK::IK_subgroup_reduce_add: + return make_impl.template operator()(op_ty, GroupOperation::Reduce, ty, a); + case IK::IK_subgroup_reduce_max: + return make_impl.template operator()(op_ty, GroupOperation::Reduce, ty, a); + case IK::IK_subgroup_reduce_min: + return make_impl.template operator()(op_ty, GroupOperation::Reduce, ty, a); + default: + break; + } + throw compilation_error(loc, status::internal_compiler_error); +} + +} // namespace tinytc::spv + diff --git a/src/spv/converter_aux.hpp b/src/spv/converter_aux.hpp new file mode 100644 index 00000000..4c860211 --- /dev/null +++ b/src/spv/converter_aux.hpp @@ -0,0 +1,58 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CONVERTER_AUX_20250416_HPP +#define CONVERTER_AUX_20250416_HPP + +#include "node/inst_view.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +#include + +namespace tinytc::spv { + +class spv_inst; +class uniquifier; + +auto get_spv_index_ty(uniquifier &unique, tinytc_compiler_context_t ctx) -> spv_inst *; +auto get_spv_ty(uniquifier &unique, memref_type const *ty) -> spv_inst *; +auto get_spv_pointer_index_ty(uniquifier &unique, tinytc_compiler_context_t ctx, + address_space addrspace = address_space::global) -> spv_inst *; +auto get_spv_ty_non_coopmatrix(uniquifier &unique, tinytc_type_t ty) -> spv_inst *; + +auto get_last_label(tinytc_spv_mod &mod) -> spv_inst *; +auto make_binary_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, spv_inst *b, + location const &loc) -> spv_inst *; +auto make_binary_op_mixed_precision(uniquifier &unique, tinytc_type_t result_ty, IK op, + tinytc_type_t a_ty, spv_inst *a, tinytc_type_t b_ty, + 
spv_inst *b, location const &loc) -> spv_inst *; +auto make_cast(uniquifier &unique, tinytc_type_t to_ty, tinytc_type_t a_ty, spv_inst *a, + location const &loc) -> spv_inst *; +auto make_complex_mul(uniquifier &unique, spv_inst *ty, spv_inst *a, spv_inst *b, + bool conj_b = false) -> spv_inst *; +auto make_compare_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, spv_inst *b, + location const &loc) -> spv_inst *; +auto make_constant(uniquifier &unique, tinytc_type_t ty, constant_value_type const &val) + -> spv_inst *; +void make_conditional_execution(uniquifier &unique, spv_inst *condition, + std::function then); +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst *condition, + std::function then, + spv_inst *otherwise, location const &loc) -> spv_inst *; +auto make_conditional_execution(uniquifier &unique, spv_inst *return_ty, spv_inst *condition, + std::function then, + std::function otherwise, + location const &loc) -> spv_inst *; +auto make_math_unary_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, + location const &loc) -> spv_inst *; +void make_store(uniquifier &unique, store_flag flag, tinytc_type_t val_ty, address_space as, + spv_inst *pointer, spv_inst *value, location const &loc); +auto make_unary_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, + location const &loc) -> spv_inst *; +auto make_subgroup_op(uniquifier &unique, tinytc_type_t operand_ty, IK op, spv_inst *a, + location const &loc) -> spv_inst *; + +} // namespace tinytc::spv + +#endif // CONVERTER_AUX_20250416_HPP diff --git a/src/spv/coopmatrix_impl.cpp b/src/spv/coopmatrix_impl.cpp new file mode 100644 index 00000000..d8722147 --- /dev/null +++ b/src/spv/coopmatrix_impl.cpp @@ -0,0 +1,706 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/coopmatrix_impl.hpp" +#include "codegen_tools.hpp" +#include "converter_aux.hpp" +#include "error.hpp" +#include 
"node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "number.hpp" +#include "spv/dope_vector.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/matrix_walker.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +coopmatrix_impl::coopmatrix_impl(uniquifier &unique, core_config const &cfg, gcd_analysis_result g) + : unique_{&unique}, cfg_{cfg}, gcd_{std::move(g)} {} + +auto coopmatrix_impl::extract(cooperative_matrix_extract_inst in, spv_inst *mat) -> spv_inst * { + auto matt = get_coopmatrix_type(in.mat()); + auto matl = get_layout(cfg(), matt); + const auto idx = in.index(); + if (idx < 0 || idx >= matl.length) { + throw compilation_error(in.loc(), status::ir_out_of_bounds); + } + return extract(matl, mat, idx); +} +auto coopmatrix_impl::insert(cooperative_matrix_insert_inst in, spv_inst *val, spv_inst *mat) + -> spv_inst * { + auto matt = get_coopmatrix_type(in.mat()); + auto matl = get_layout(cfg(), matt); + const auto idx = in.index(); + if (idx < 0 || idx >= matl.length) { + throw compilation_error(in.loc(), status::ir_out_of_bounds); + } + return insert(matl, val, mat, idx); +} + +auto coopmatrix_impl::load(cooperative_matrix_load_inst in, dope_vector const &odv, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) -> spv_inst * { + auto ot = get_memref_type(in.operand()); + auto rt = get_coopmatrix_type(in.result()); + auto pointer_ty = get_spv_ty(*unique_, ot); + + auto layout = get_layout(cfg(), rt); + auto matrix_ty = spv_ty(layout); + auto interface_ty = spv_interface_ty(layout); + + auto shape = std::array{odv.shape(0), odv.shape(1)}; + auto stride = std::array{odv.stride(0), odv.stride(1)}; + if (in.t() == 
transpose::T) { + std::swap(pos0, pos1); + std::swap(shape[0], shape[1]); + std::swap(stride[0], stride[1]); + } + + auto walker = matrix_walker(*unique_, cfg().subgroup_size, layout, pos0, pos1, shape[0], + shape[1], stride[0], stride[1], in.checked()); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(matrix_ty); + + const auto ld = [&](tinytc_spv_mod &mod) -> spv_inst * { + auto pointer = mod.add(pointer_ty, operand, walker.offset(), + std::vector{}); + return mod.add(interface_ty, pointer); + }; + const auto ld_chk = [&](tinytc_spv_mod &) { + return make_conditional_execution(*unique_, interface_ty, walker.col_ok(), ld, + unique_->null_constant(interface_ty), in.loc()); + }; + auto const ld_block = [&](tinytc_spv_mod &mod) { + spv_inst *block_result = result; + for (std::int64_t u = 0; u < layout.length / layout.blocks; ++u) { + spv_inst *val = walker.needs_mask() || walker.cols_checked() ? ld_chk(mod) : ld(mod); + block_result = insert(layout, val, block_result, walker.component_no()); + + if (u < layout.cols - 1) { + walker.advance_column(); + } + } + return block_result; + }; + auto const ld_block_chk = [&](tinytc_spv_mod &) { + auto const ld_block_zero = [&](tinytc_spv_mod &) { + spv_inst *block_result = result; + for (std::int64_t u = 0; u < layout.length / layout.blocks; ++u) { + block_result = insert(layout, unique_->null_constant(interface_ty), block_result, + walker.component_no(u)); + } + return block_result; + }; + return make_conditional_execution(*unique_, matrix_ty, walker.row_ok(), ld_block, + ld_block_zero, in.loc()); + }; + + for (std::int64_t w = 0; w < layout.blocks; ++w) { + result = walker.rows_checked() ? 
ld_block_chk(mod) : ld_block(mod); + + if (w < layout.blocks - 1) { + walker.advance_block(); + } + } + return result; +} + +void coopmatrix_impl::store(cooperative_matrix_store_inst in, dope_vector const &odv, spv_inst *val, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) { + auto ot = get_memref_type(in.operand()); + auto vt = get_coopmatrix_type(in.val()); + + auto layout = get_layout(cfg(), vt); + + auto shape = std::array{odv.shape(0), odv.shape(1)}; + auto stride = std::array{odv.stride(0), odv.stride(1)}; + if (in.t() == transpose::T) { + std::swap(pos0, pos1); + std::swap(shape[0], shape[1]); + std::swap(stride[0], stride[1]); + } + + auto walker = matrix_walker(*unique_, cfg().subgroup_size, layout, pos0, pos1, shape[0], + shape[1], stride[0], stride[1], in.checked()); + + const std::int32_t cols_per_store = [&]() -> std::int32_t { + std::int32_t cols_per_store = 1; + const std::int32_t max_cols_per_store = 16; + const bool sty_ok = !isa(*layout.sty); + const bool transpose_ok = in.t() == transpose::T; + const bool checked_ok = + in.checked() != checked_flag::cols && in.checked() != checked_flag::both; + const bool layout_ok = layout.blocks1 == 1 && layout.rows >= cfg().subgroup_size; + if (sty_ok && transpose_ok && checked_ok && layout_ok) { + const std::int32_t num_cols = layout.length / layout.blocks; + while (2 * cols_per_store <= max_cols_per_store && + num_cols % (2 * cols_per_store) == 0) { + cols_per_store *= 2; + } + } + return cols_per_store; + }(); + spv_inst *io_ty = get_spv_ty_non_coopmatrix(unique(), layout.sty); + spv_inst *io_vec_ty = cols_per_store > 1 ? 
unique().vec_ty(io_ty, cols_per_store) : io_ty; + const auto pointer_ty = [&] { + const auto storage_cls = address_space_to_storage_class(ot->addrspace()); + const auto align = ot->element_alignment(); + return unique_->pointer_ty(storage_cls, io_vec_ty, align); + }(); + + auto &mod = unique_->mod(); + const auto st = [&](tinytc_spv_mod &mod) { + auto pointer = mod.add(pointer_ty, operand, walker.offset(), + std::vector{}); + spv_inst *val_ij = nullptr; + if (cols_per_store > 1) { + val_ij = mod.add(io_vec_ty); + for (std::int32_t c = 0; c < cols_per_store; ++c) { + const auto comp_no = layout.component_no(walker.col_no() + c, walker.block_no()); + spv_inst *v = extract(layout, val, comp_no); + val_ij = mod.add(io_vec_ty, v, val_ij, + std::vector{c}); + } + } else { + val_ij = extract(layout, val, walker.component_no()); + } + + make_store(*unique_, in.flag(), layout.sty, ot->addrspace(), pointer, val_ij, in.loc()); + }; + auto const st_block = [&](tinytc_spv_mod &mod) { + for (std::int64_t u = 0; u < layout.length / layout.blocks; u += cols_per_store) { + if (walker.needs_mask() || walker.cols_checked()) { + make_conditional_execution(*unique_, walker.col_ok(), st); + } else { + st(mod); + } + + if (u < layout.cols - 1) { + for (std::int32_t c = 0; c < cols_per_store; ++c) { + walker.advance_column(); + } + } + } + }; + + for (std::int64_t w = 0; w < layout.blocks; ++w) { + if (walker.rows_checked()) { + make_conditional_execution(*unique_, walker.row_ok(), st_block); + + } else { + st_block(mod); + } + + if (w < layout.blocks - 1) { + walker.advance_block(); + } + } +} + +auto coopmatrix_impl::mul_add(cooperative_matrix_mul_add_inst in, spv_inst *a, spv_inst *b, + spv_inst *c) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + auto bt = get_coopmatrix_type(in.b()); + auto ct = get_coopmatrix_type(in.c()); + auto rt = get_coopmatrix_type(in.result()); + + if (at->rows() % cfg().subgroup_size != 0) { + throw compilation_error(in.loc(), {&in.a()}, 
status::ir_unsupported_coopmatrix_shape); + } + if (ct->rows() % cfg().subgroup_size != 0) { + throw compilation_error(in.loc(), {&in.c()}, status::ir_unsupported_coopmatrix_shape); + } + if (rt->rows() % cfg().subgroup_size != 0) { + throw compilation_error(in.loc(), {}, status::ir_unsupported_coopmatrix_shape); + } + + auto al = get_layout(cfg(), at); + auto bl = get_layout(cfg(), bt); + auto cl = get_layout(cfg(), ct); + auto rl = get_layout(cfg(), rt); + + const auto a_ty = at->component_ty(); + const auto b_ty = bt->component_ty(); + const auto b_component_ty = component_type(b_ty); + const auto c_ty = ct->component_ty(); + const auto r_ty = rt->component_ty(); + const auto spv_b_ty = get_spv_ty_non_coopmatrix(*unique_, b_ty); + const auto spv_b_component_ty = get_spv_ty_non_coopmatrix(*unique_, b_component_ty); + const auto spv_c_ty = get_spv_ty_non_coopmatrix(*unique_, c_ty); + const bool a_and_b_complex = isa(*a_ty) && isa(*b_ty); + + auto &mod = unique_->mod(); + auto result_ty = spv_ty(rl); + spv_inst *result = mod.add(result_ty); + spv_inst *imaginary_unit = + a_and_b_complex ? 
make_constant(*unique_, c_ty, std::complex{0.0, 1.0}) : nullptr; + + constexpr std::int64_t nbb = 4; + auto broadcast_scope = unique_->constant(static_cast(Scope::Subgroup)); + for (std::int64_t m_block = 0; m_block < rl.blocks; ++m_block) { + for (std::int64_t nb = 0; nb < rl.cols; nb += nbb) { + auto c_block = std::array{}; + auto c_im_block = std::array{}; + std::fill(std::begin(c_block), std::end(c_block), nullptr); + std::fill(std::begin(c_im_block), std::end(c_im_block), nullptr); + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + c_block[n - nb] = extract(cl, c, cl.component_no(n, m_block)); + if (a_and_b_complex) { + c_im_block[n - nb] = unique_->null_constant(spv_c_ty); + } + } + } + + for (std::int64_t k = 0; k < bl.rows * bl.blocks; ++k) { + auto a_mk = extract(al, a, al.component_no(k, m_block)); + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + /** For matrix B we have L(i,k_1,j,k_2) = i + k_1*I + j*I*K_1 + + * k_2*I*K_1*J. + * + * The n loop variable is equal to j and the k loop variable fuses + * iteration over indices i,k_1,k_2, such that k = i + k_1*I + + * k_2*I*K_1: The k loop variable fuses iteration over indices + * i,k_1,k_2, We recover i+k_1*I = k%(IK_1) k_2 = k/(IK_1) + * + * We have + * p + vS = L + * Therefore, + * p = L%S + * v = L/S + */ + const auto IK_1 = bl.rows * bl.blocks1; + const auto L = k % IK_1 + n * IK_1 + (k / IK_1) * IK_1 * bl.cols; + const auto p = + unique_->constant(static_cast(L % cfg().subgroup_size)); + const auto v = static_cast(L / cfg().subgroup_size); + + spv_inst *b_kn = extract(bl, b, v); + b_kn = mod.add(spv_b_ty, broadcast_scope, b_kn, p); + + auto &c_mn = c_block[n - nb]; + if (a_and_b_complex) { + auto &c_im_mn = c_im_block[n - nb]; + auto b_kn_re = mod.add( + spv_b_component_ty, b_kn, std::vector{0}); + auto b_kn_im = mod.add( + spv_b_component_ty, b_kn, std::vector{1}); + + auto ab_mn = make_binary_op_mixed_precision(*unique_, c_ty, IK::IK_mul, + a_ty, a_mk, 
b_component_ty, + b_kn_re, in.loc()); + c_mn = + make_binary_op(*unique_, c_ty, IK::IK_add, ab_mn, c_mn, in.loc()); + auto ab_im_mn = make_binary_op_mixed_precision( + *unique_, c_ty, IK::IK_mul, a_ty, a_mk, b_component_ty, b_kn_im, + in.loc()); + c_im_mn = make_binary_op(*unique_, c_ty, IK::IK_add, ab_im_mn, c_im_mn, + in.loc()); + } else { + auto ab_mn = make_binary_op_mixed_precision( + *unique_, c_ty, IK::IK_mul, a_ty, a_mk, b_ty, b_kn, in.loc()); + c_mn = + make_binary_op(*unique_, c_ty, IK::IK_add, ab_mn, c_mn, in.loc()); + } + } + } + } + if (a_and_b_complex) { + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + auto &c_mn = c_block[n - nb]; + auto &c_im_mn = c_im_block[n - nb]; + auto c_im_mn_times_i = make_binary_op(*unique_, c_ty, IK::IK_mul, c_im_mn, + imaginary_unit, in.loc()); + c_mn = make_binary_op(*unique_, c_ty, IK::IK_add, c_mn, c_im_mn_times_i, + in.loc()); + } + } + } + for (std::int64_t n = nb; n < nb + nbb; ++n) { + if (n < rl.cols) { + auto &c_mn = c_block[n - nb]; + if (c_ty != r_ty) { + c_mn = make_cast(*unique_, r_ty, c_ty, c_mn, in.loc()); + } + result = insert(cl, c_mn, result, n + m_block * rl.cols); + } + } + } + } + return result; +} + +void coopmatrix_impl::prefetch(cooperative_matrix_prefetch_inst, dope_vector const &, spv_inst *, + spv_inst *, spv_inst *) {} + +auto coopmatrix_impl::reduce(cooperative_matrix_reduce_inst in, spv_inst *a) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + const auto sgs = cfg().subgroup_size; + + if (at->rows() % sgs != 0) { + throw compilation_error(in.loc(), {&in.a()}, status::ir_unsupported_coopmatrix_shape); + } + + auto rt = get_coopmatrix_type(in.result()); + auto rl = get_layout(cfg(), rt); + auto al = get_layout(cfg(), at); + auto matrix_ty = spv_ty(rl); + auto sty = rt->component_ty(); + auto ty = get_spv_ty_non_coopmatrix(*unique_, sty); + auto bool_ty = unique_->bool_ty(); + auto i32_ty = unique_->int_ty(32); + + auto const binary_arith = [&in](IK op) { + switch 
(op) { + case IK::IK_cooperative_matrix_reduce_add: + return IK::IK_add; + case IK::IK_cooperative_matrix_reduce_max: + return IK::IK_max; + case IK::IK_cooperative_matrix_reduce_min: + return IK::IK_min; + default: + break; + } + throw compilation_error(in.loc(), status::internal_compiler_error); + }(in.get().type_id()); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(matrix_ty); + if (in.mode() == reduce_mode::column) { + auto p = unique_->load_builtin(BuiltIn::SubgroupLocalInvocationId); + auto scope = unique_->constant(static_cast(Scope::Subgroup)); + auto c0 = unique_->constant(std::int32_t{0}); + auto cnull = unique_->null_constant(ty); + + for (std::int32_t j0 = 0; j0 < al.shape1; j0 += sgs) { + auto x = std::vector(sgs, nullptr); + for (std::int32_t b = 0; b < al.blocks; ++b) { + for (std::int32_t i = 0; i < sgs; ++i) { + if (j0 + i < al.shape1) { + auto a_i = extract(al, a, al.component_no(j0 + i, b)); + x[i] = + x[i] ? make_binary_op(*unique_, sty, binary_arith, x[i], a_i, in.loc()) + : a_i; + } else if (x[i] == nullptr) { + x[i] = cnull; + } + } + } + for (std::int32_t v = 1; v < sgs; v *= 2) { + auto cv = unique_->constant(v); + spv_inst *cond = mod.add(i32_ty, p, cv); + cond = mod.add(bool_ty, cond, c0); + for (int i = 0; i < sgs / v; i += 2) { + auto xip1_up = mod.add(ty, scope, x[i + 1], cv); + auto xi_down = mod.add(ty, scope, x[i], cv); + auto s1 = mod.add(ty, cond, x[i], xip1_up); + auto s2 = mod.add(ty, cond, xi_down, x[i + 1]); + x[i / 2] = make_binary_op(*unique_, sty, binary_arith, s1, s2, in.loc()); + } + } + result = insert(rl, x[0], result, j0 / sgs); + } + } else { + for (std::int32_t b = 0; b < al.blocks; ++b) { + spv_inst *sum = nullptr; + for (std::int32_t j = 0; j < al.length / al.blocks; ++j) { + auto a_i = extract(al, a, al.component_no(j, b)); + sum = sum ? 
make_binary_op(*unique_, sty, binary_arith, sum, a_i, in.loc()) : a_i; + } + result = insert(rl, sum, result, rl.component_no(0, b)); + } + } + + return result; +} + +auto coopmatrix_impl::scale(cooperative_matrix_scale_inst in, spv_inst *a, spv_inst *b) + -> spv_inst * { + auto rt = get_coopmatrix_type(in.result()); + auto rl = get_layout(cfg(), rt); + auto bl = get_layout(cfg(), get_coopmatrix_type(in.b())); + auto sty = rt->component_ty(); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto b_v = extract(bl, b, v); + auto r_v = make_binary_op(*unique_, sty, IK::IK_mul, a, b_v, in.loc()); + result = insert(rl, r_v, result, v); + } + + return result; +} + +auto coopmatrix_impl::arith(arith_inst in, spv_inst *a, spv_inst *b) -> spv_inst * { + auto rt = get_coopmatrix_type(in.result()); + auto rl = get_layout(cfg(), rt); + auto al = get_layout(cfg(), get_coopmatrix_type(in.a())); + auto bl = get_layout(cfg(), get_coopmatrix_type(in.b())); + auto sty = rt->component_ty(); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto a_v = extract(al, a, v); + auto b_v = extract(bl, b, v); + auto r_v = make_binary_op(*unique_, sty, in.get().type_id(), a_v, b_v, in.loc()); + result = insert(rl, r_v, result, v); + } + + return result; +} + +auto coopmatrix_impl::arith_unary(arith_unary_inst in, spv_inst *a) -> spv_inst * { + auto al = get_layout(cfg(), get_coopmatrix_type(in.a())); + auto rt = get_coopmatrix_type(in.result()); + auto rl = get_layout(cfg(), rt); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto a_v = extract(al, a, v); + auto r_v = make_unary_op(*unique_, al.sty, in.get().type_id(), a_v, in.loc()); + result = insert(rl, r_v, result, v); + } 
+ + return result; +} + +auto coopmatrix_impl::cast(cast_inst in, spv_inst *a) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + auto al = get_layout(cfg(), at); + auto a_ty = at->component_ty(); + auto rt = get_coopmatrix_type(in.result()); + auto rl = get_layout(cfg(), rt); + auto r_ty = rt->component_ty(); + auto ty = spv_ty(rl); + + auto &mod = unique_->mod(); + spv_inst *result = mod.add(ty); + + const auto P = + rt->use() == matrix_use::b && at->use() == matrix_use::acc + ? std::function([&](LiteralInteger v) -> LiteralInteger { + /** + * Using that M >= S we have for matrix_b + * L_b(i,k_1,j,k_2) = i + k_1*S + j*S*K_1 + k_2*S*K_1*J. + * + * We have + * p_b + v_bS = L_b + * + * Recovering i,k_1,j,k_2 from L_b we have + * i = L_b%S = p_b + * k_1 = L_b/S%K_1 = v_b%K_1 + * j = L_b/(SK_1)%J = v_b/K_1%J + * k_2 = L_b/(SK_1J) = v_b/(K_1J) + * + * Let k=k_1 + k_2K_1, and L_1, L_2 be the block sizes of matrix + * acc. We have L_{acc} = i + (k%L_1)*S + j*S*L_1 + + * (k/L_1)*S*L_1*J. + * + * Recovering p_{acc}, v_{acc} from + * p_{acc} + v_{acc}S = L_{acc} + * we have + * p_{acc} = L_{acc}%S = p_b + * v_{acc} = L_{acc}/S = k%L_1 + j*L_1 + (k/L_1)*L_1*J + * + * If M < S, then we have K_1=K_2=L_1=L_2=1, and there is no layout + * transformation. The code below just returns v - the identity - + * if M < S. 
+ */ + auto const k_1 = v % rl.blocks1; + auto const j = v / rl.blocks1 % rl.cols; + auto const k_2 = v / (rl.blocks1 * rl.cols); + auto const k = k_1 + k_2 * rl.blocks1; + return k % al.blocks1 + j * al.blocks1 + (k / al.blocks1) * al.blocks1 * al.cols; + }) + : std::function([](LiteralInteger v) -> LiteralInteger { return v; }); + for (LiteralInteger v = 0; v < static_cast(rl.length); ++v) { + auto a_v = extract(al, a, P(v)); + auto r_v = make_cast(*unique_, r_ty, a_ty, a_v, in.loc()); + result = insert(rl, r_v, result, v); + } + + return result; +} + +auto coopmatrix_impl::constant(constant_inst in) -> spv_inst * { + auto rt = get_coopmatrix_type(in.result()); + auto rl = get_layout(cfg(), rt); + auto sty = rt->component_ty(); + auto spv_result_ty = spv_ty(rl); + + if (in.is_zero()) { + return unique_->null_constant(spv_result_ty); + } + if (rl.length == 1) { + return make_constant(*unique_, sty, in.value()); + } + + auto const init_vector = [&]() { + if (isa(*sty)) { + const auto c = std::get>(in.value()); + auto comp_ty = component_type(sty); + auto re = make_constant(*unique_, comp_ty, c.real()); + auto im = make_constant(*unique_, comp_ty, c.imag()); + auto vec = std::vector(2 * rl.length); + for (std::int64_t v = 0; v < rl.length; ++v) { + vec[2 * v] = re; + vec[2 * v + 1] = im; + } + return vec; + } else if (rl.ops_per_chan > 1) { + auto cst = std::visit( + overloaded{[&](std::int64_t i) -> spv_inst * { + if (isa(*rl.sty)) { + auto v8 = + std::bit_cast(static_cast(i)); + return unique_->constant( + std::int32_t{v8 | (v8 << 8) | (v8 << 16) | (v8 << 24)}); + } else if (isa(*rl.sty)) { + auto v16 = + std::bit_cast(static_cast(i)); + return unique_->constant(std::int32_t{v16 | (v16 << 16)}); + } + return nullptr; + }, + [&](double d) -> spv_inst * { + const float f = static_cast(d); + if (isa(*rl.sty)) { + std::uint16_t v16 = bfloat16{f}.bits(); + return unique_->constant(std::int32_t{v16 | (v16 << 16)}); + } else if (isa(*rl.sty)) { + std::uint16_t v16 = 
half{f}.bits(); + return unique_->constant(std::int32_t{v16 | (v16 << 16)}); + } + return nullptr; + }, + [&](auto const &) -> spv_inst * { return nullptr; }}, + in.value()); + if (!cst) { + throw status::internal_compiler_error; + } + return std::vector(rl.length, cst); + } + return std::vector(rl.length, make_constant(*unique_, sty, in.value())); + }; + return unique_->mod().add_to(section::type_const_var, spv_result_ty, + init_vector()); +} + +auto coopmatrix_impl::spv_interface_ty(coopmatrix_layout const &layout) -> spv_inst * { + return get_spv_ty_non_coopmatrix(*unique_, layout.sty); +} + +auto coopmatrix_impl::spv_storage_ty(coopmatrix_layout const &layout) -> spv_inst * { + if (layout.ops_per_chan > 1) { + if (layout.ops_per_chan * size(layout.sty) != 4) { + throw status::internal_compiler_error; + } + return unique_->int_ty(32); + } + return get_spv_ty_non_coopmatrix(*unique_, component_type(layout.sty)); +} + +auto coopmatrix_impl::spv_ty(coopmatrix_layout const &layout) -> spv_inst * { + if (layout.length == 1) { + return spv_interface_ty(layout); + } + const auto length = + static_cast(component_count(layout.sty)) * layout.length / layout.ops_per_chan; + auto storage_ty = spv_storage_ty(layout); + return length > 1 ? 
unique_->vec_ty(storage_ty, length) : storage_ty; +} + +auto coopmatrix_impl::spv_ty(coopmatrix_type const *ct) -> spv_inst * { + return spv_ty(get_layout(cfg(), ct)); +} + +auto coopmatrix_impl::extract(coopmatrix_layout const &layout, spv_inst *mat, LiteralInteger v) + -> spv_inst * { + if (layout.length == 1) { + return mat; + } + const auto ty = spv_interface_ty(layout); + auto &mod = unique_->mod(); + if (isa(*layout.sty)) { + const auto storage_ty = spv_storage_ty(layout); + auto re = mod.add(storage_ty, mat, std::vector{2 * v}); + auto im = mod.add(storage_ty, mat, std::vector{2 * v + 1}); + return mod.add(ty, std::vector{re, im}); + } else if (layout.ops_per_chan > 1) { + if (layout.blocks1 != 1) { + throw status::internal_compiler_error; + } + const auto storage_ty = spv_storage_ty(layout); + const auto chan_ty = unique_->vec_ty(ty, layout.ops_per_chan); + spv_inst *val = + layout.length > layout.ops_per_chan + ? mod.add(storage_ty, mat, std::vector{v / layout.ops_per_chan}) + : mat; + val = mod.add(chan_ty, val); + return mod.add(ty, val, std::vector{v % layout.ops_per_chan}); + } + return mod.add(ty, mat, std::vector{v}); +} +auto coopmatrix_impl::insert(coopmatrix_layout const &layout, spv_inst *val, spv_inst *mat, + LiteralInteger v) -> spv_inst * { + if (layout.length == 1) { + return val; + } + auto matrix_ty = spv_ty(layout); + auto &mod = unique_->mod(); + if (isa(*layout.sty)) { + const auto storage_ty = spv_storage_ty(layout); + auto re = mod.add(storage_ty, val, std::vector{0}); + auto im = mod.add(storage_ty, val, std::vector{1}); + auto tmp = mod.add(matrix_ty, re, mat, std::vector{2 * v}); + return mod.add(matrix_ty, im, tmp, std::vector{2 * v + 1}); + } else if (layout.ops_per_chan > 1) { + if (layout.blocks1 != 1) { + throw status::internal_compiler_error; + } + const auto storage_ty = spv_storage_ty(layout); + const auto channels_ty = unique_->vec_ty(spv_interface_ty(layout), layout.ops_per_chan); + const auto entry_no = std::vector{v / 
layout.ops_per_chan}; + spv_inst *channels = layout.length > layout.ops_per_chan + ? mod.add(storage_ty, mat, entry_no) + : mat; + channels = mod.add(channels_ty, channels); + channels = mod.add(channels_ty, val, channels, + std::vector{v % layout.ops_per_chan}); + channels = mod.add(storage_ty, channels); + return layout.length > layout.ops_per_chan + ? mod.add(matrix_ty, channels, mat, entry_no) + : channels; + } + return mod.add(matrix_ty, val, mat, std::vector{v}); +} + +} // namespace tinytc::spv diff --git a/src/spv/coopmatrix_impl.hpp b/src/spv/coopmatrix_impl.hpp new file mode 100644 index 00000000..3614c551 --- /dev/null +++ b/src/spv/coopmatrix_impl.hpp @@ -0,0 +1,76 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_IMPL_20250415_HPP +#define COOPMATRIX_IMPL_20250415_HPP + +#include "analysis/gcd.hpp" +#include "coopmatrix_layout.hpp" // IWYU pragma: keep +#include "device_info.hpp" +#include "spv/defs.hpp" +#include "tinytc/types.hpp" + +#include + +namespace tinytc { +class arith_inst; +class arith_unary_inst; +class cooperative_matrix_reduce_inst; +} // namespace tinytc + +namespace tinytc::spv { + +class dope_vector; +class uniquifier; + +class coopmatrix_impl { + public: + coopmatrix_impl(uniquifier &unique, core_config const &cfg, gcd_analysis_result g); + virtual ~coopmatrix_impl() = default; + + inline auto gcd() const -> gcd_analysis_result const & { return gcd_; } + inline void gcd(gcd_analysis_result g) { gcd_ = std::move(g); } + + auto cfg() const -> core_config const & { return cfg_; } + inline void cfg(core_config const &cfg) { cfg_ = cfg; } + + virtual auto extract(cooperative_matrix_extract_inst in, spv_inst *mat) -> spv_inst *; + virtual auto insert(cooperative_matrix_insert_inst in, spv_inst *val, spv_inst *mat) + -> spv_inst *; + virtual auto load(cooperative_matrix_load_inst in, dope_vector const &odv, spv_inst *operand, + spv_inst *pos0, spv_inst *pos1) -> spv_inst *; + virtual 
auto mul_add(cooperative_matrix_mul_add_inst in, spv_inst *a, spv_inst *b, spv_inst *c) + -> spv_inst *; + virtual void prefetch(cooperative_matrix_prefetch_inst in, dope_vector const &odv, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1); + virtual auto reduce(cooperative_matrix_reduce_inst in, spv_inst *a) -> spv_inst *; + virtual auto scale(cooperative_matrix_scale_inst in, spv_inst *a, spv_inst *b) -> spv_inst *; + virtual void store(cooperative_matrix_store_inst in, dope_vector const &odv, spv_inst *val, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1); + + virtual auto arith(arith_inst in, spv_inst *a, spv_inst *b) -> spv_inst *; + virtual auto arith_unary(arith_unary_inst in, spv_inst *a) -> spv_inst *; + virtual auto cast(cast_inst in, spv_inst *a) -> spv_inst *; + virtual auto constant(constant_inst in) -> spv_inst *; + + virtual auto spv_ty(coopmatrix_type const *ct) -> spv_inst *; + + protected: + auto spv_interface_ty(coopmatrix_layout const &layout) -> spv_inst *; + auto spv_storage_ty(coopmatrix_layout const &layout) -> spv_inst *; + auto spv_ty(coopmatrix_layout const &layout) -> spv_inst *; + auto extract(coopmatrix_layout const &layout, spv_inst *mat, LiteralInteger v) -> spv_inst *; + auto insert(coopmatrix_layout const &layout, spv_inst *val, spv_inst *mat, LiteralInteger v) + -> spv_inst *; + + inline auto unique() -> uniquifier & { return *unique_; } + + private: + uniquifier *unique_; + core_config cfg_; + gcd_analysis_result gcd_; +}; + +} // namespace tinytc::spv + +#endif // COOPMATRIX_IMPL_20250415_HPP diff --git a/src/spv/coopmatrix_impl_block.cpp b/src/spv/coopmatrix_impl_block.cpp new file mode 100644 index 00000000..c950977a --- /dev/null +++ b/src/spv/coopmatrix_impl_block.cpp @@ -0,0 +1,289 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/coopmatrix_impl_block.hpp" +#include "analysis/gcd.hpp" +#include "codegen_tools.hpp" +#include "converter_aux.hpp" +#include 
"coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include "number.hpp" +#include "spv/defs.hpp" +#include "spv/dope_vector.hpp" +#include "spv/instructions.hpp" +#include "spv/matrix_walker.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/math.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto max_block_io_vec_size(tinytc_type_t ty) -> std::int64_t { + return visit(overloaded{[](i8_type &) { return 16; }, [](i16_type &) { return 16; }, + [](tinytc_type &) { return 8; }}, + + *ty); +} + +auto coopmatrix_impl_block::load(cooperative_matrix_load_inst in, dope_vector const &odv, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) -> spv_inst * { + const auto ot = get_memref_type(in.operand()); + const auto rt = get_coopmatrix_type(in.result()); + const auto layout = get_layout(cfg(), rt); + const auto sty = layout.sty; + + const std::int32_t required_alignment = ot->addrspace() == address_space::global ? 
4 : 16; + + const bool layout_ok = layout.rows >= cfg().subgroup_size; + const bool transpose_ok = in.t() == transpose::N; + const bool alignment_ok = is_aligned(required_alignment, in.operand(), in.pos0()); + const bool checked_ok = + in.checked() == checked_flag::none || in.checked() == checked_flag::cols; + const bool sty_ok = !isa(*sty); // We do not have 16 byte/lane block loads + if (!layout_ok || !transpose_ok || !alignment_ok || !checked_ok || !sty_ok) { + return coopmatrix_impl::load(in, odv, operand, pos0, pos1); + } + + auto walker = matrix_walker(unique(), cfg().subgroup_size, layout, pos0, pos1, odv.shape(0), + odv.shape(1), odv.stride(0), odv.stride(1), in.checked(), 0); + + const auto io_sty = get_io_sty(sty); + const std::int32_t blocks_per_load = + is_positive_power_of_two(layout.blocks) + ? std::min(layout.blocks, max_block_io_vec_size(io_sty)) + : 1; + const std::int32_t cols_per_load = [&]() -> std::int32_t { + const auto ot = get_memref_type(in.operand()); + const bool is_contiguous = ot->dim() == 2 && ot->shape(0) == rt->shape(0) && + ot->stride(0) == 1 && ot->stride(1) == ot->shape(0); + std::int32_t cols_per_load = 1; + if (is_contiguous && !walker.cols_checked()) { + const std::int32_t max_cols_per_load = max_block_io_vec_size(io_sty) / blocks_per_load; + const std::int32_t num_cols = layout.length / layout.blocks; + while (2 * cols_per_load <= max_cols_per_load && num_cols % (2 * cols_per_load) == 0) { + cols_per_load *= 2; + } + } + return cols_per_load; + }(); + + const auto matrix_ty = spv_ty(layout); + const auto interface_ty = spv_interface_ty(layout); + auto io_ty = get_spv_ty_non_coopmatrix(unique(), io_sty); + const auto io_vec_size = blocks_per_load * cols_per_load; + spv_inst *io_vec_ty = io_vec_size > 1 ? 
unique().vec_ty(io_ty, io_vec_size) : io_ty; + const auto pointer_ty = [&] { + auto ot = get_memref_type(in.operand()); + const auto storage_cls = address_space_to_storage_class(ot->addrspace()); + const auto align = ot->element_alignment(); + return unique().pointer_ty(storage_cls, io_ty, align); + }(); + + auto &mod = unique().mod(); + operand = mod.add(pointer_ty, operand); + spv_inst *result = mod.add(matrix_ty); + + const auto ld = [&](tinytc_spv_mod &mod) -> spv_inst * { + spv_inst *offset = walker.offset(); + auto pointer = mod.add(pointer_ty, operand, offset, + std::vector{}); + return mod.add(io_vec_ty, pointer); + }; + const auto ld_chk = [&](tinytc_spv_mod &) { + return make_conditional_execution(unique(), interface_ty, walker.col_ok(), ld, + unique().null_constant(io_vec_ty), in.loc()); + }; + auto const ld_block = [&](tinytc_spv_mod &mod) { + spv_inst *block_result = result; + for (std::int64_t u = 0; u < layout.length / layout.blocks; u += cols_per_load) { + spv_inst *val = walker.needs_mask() || walker.cols_checked() ? 
ld_chk(mod) : ld(mod); + if (io_vec_size > 1) { + for (std::int32_t c = 0; c < cols_per_load; ++c) { + for (std::int32_t b = 0; b < blocks_per_load; ++b) { + spv_inst *v = mod.add( + io_ty, val, std::vector{b + c * blocks_per_load}); + v = mod.add(interface_ty, v); + const auto comp_no = + layout.component_no(walker.col_no() + c, walker.block_no() + b); + block_result = insert(layout, v, block_result, comp_no); + } + } + } else { + val = mod.add(interface_ty, val); + block_result = insert(layout, val, block_result, walker.component_no()); + } + + if (u < layout.cols - 1) { + for (std::int32_t c = 0; c < cols_per_load; ++c) { + walker.advance_column(); + } + } + } + return block_result; + }; + + for (std::int64_t w = 0; w < layout.blocks; w += blocks_per_load) { + result = ld_block(mod); + + if (w < layout.blocks - blocks_per_load) { + for (std::int32_t b = 0; b < blocks_per_load; ++b) { + walker.advance_block(); + } + } + } + return result; +} + +void coopmatrix_impl_block::store(cooperative_matrix_store_inst in, dope_vector const &odv, + spv_inst *val, spv_inst *operand, spv_inst *pos0, + spv_inst *pos1) { + constexpr std::int32_t required_alignment = 16; + + auto vt = get_coopmatrix_type(in.val()); + auto layout = get_layout(cfg(), vt); + auto sty = vt->component_ty(); + + const bool layout_ok = layout.rows >= cfg().subgroup_size; + const bool transpose_ok = in.t() == transpose::N; + const bool flag_ok = in.flag() == store_flag::regular; + const bool alignment_ok = is_aligned(required_alignment, in.operand(), in.pos0()); + const bool checked_ok = + in.checked() == checked_flag::none || in.checked() == checked_flag::cols; + const bool sty_ok = !isa(*sty); // We do not have 16 byte/lane block writes + if (!layout_ok || !transpose_ok || !flag_ok || !alignment_ok || !checked_ok || !sty_ok) { + coopmatrix_impl::store(in, odv, val, operand, pos0, pos1); + return; + } + + auto walker = matrix_walker(unique(), cfg().subgroup_size, layout, pos0, pos1, odv.shape(0), + 
odv.shape(1), odv.stride(0), odv.stride(1), in.checked(), 0); + + const auto io_sty = get_io_sty(sty); + const std::int32_t blocks_per_store = + is_positive_power_of_two(layout.blocks) + ? std::min(layout.blocks, max_block_io_vec_size(io_sty)) + : 1; + const std::int32_t cols_per_store = [&]() -> std::int32_t { + const auto ot = get_memref_type(in.operand()); + const bool is_contiguous = ot->dim() == 2 && ot->shape(0) == vt->shape(0) && + ot->stride(0) == 1 && ot->stride(1) == ot->shape(0); + std::int32_t cols_per_store = 1; + if (is_contiguous && !walker.cols_checked()) { + const std::int32_t max_cols_per_store = + max_block_io_vec_size(io_sty) / blocks_per_store; + const std::int32_t num_cols = layout.length / layout.blocks; + while (2 * cols_per_store <= max_cols_per_store && + num_cols % (2 * cols_per_store) == 0) { + cols_per_store *= 2; + } + } + return cols_per_store; + }(); + + auto io_ty = get_spv_ty_non_coopmatrix(unique(), io_sty); + auto const io_vec_size = blocks_per_store * cols_per_store; + spv_inst *io_vec_ty = io_vec_size > 1 ? 
unique().vec_ty(io_ty, io_vec_size) : io_ty; + const auto pointer_ty = [&] { + auto ot = get_memref_type(in.operand()); + const auto storage_cls = address_space_to_storage_class(ot->addrspace()); + const auto align = std::max(16, ot->element_alignment()); + return unique().pointer_ty(storage_cls, io_ty, align); + }(); + + auto &mod = unique().mod(); + operand = mod.add(pointer_ty, operand); + + for (std::int64_t w = 0; w < layout.blocks; w += blocks_per_store) { + auto const st_block = [&](tinytc_spv_mod &mod) { + for (std::int64_t u = 0; u < layout.length / layout.blocks; u += cols_per_store) { + const auto st = [&](tinytc_spv_mod &mod) { + spv_inst *offset = walker.offset(); + auto pointer = mod.add(pointer_ty, operand, offset, + std::vector{}); + spv_inst *val_ij = nullptr; + if (io_vec_size > 1) { + val_ij = mod.add(io_vec_ty); + for (std::int32_t c = 0; c < cols_per_store; ++c) { + for (std::int32_t b = 0; b < blocks_per_store; ++b) { + const auto comp_no = + layout.component_no(walker.col_no() + c, walker.block_no() + b); + spv_inst *v = extract(layout, val, comp_no); + v = mod.add(io_ty, v); + val_ij = mod.add( + io_vec_ty, v, val_ij, + std::vector{b + c * blocks_per_store}); + } + } + } else { + val_ij = extract(layout, val, walker.component_no()); + val_ij = mod.add(io_ty, val_ij); + } + mod.add(pointer, val_ij); + }; + if (walker.needs_mask() || walker.cols_checked()) { + make_conditional_execution(unique(), walker.col_ok(), st); + } else { + st(mod); + } + + if (u < layout.cols - 1) { + for (std::int32_t c = 0; c < cols_per_store; ++c) { + walker.advance_column(); + } + } + } + }; + st_block(mod); + + if (w < layout.blocks - blocks_per_store) { + for (std::int32_t b = 0; b < blocks_per_store; ++b) { + walker.advance_block(); + } + } + } +} + +auto coopmatrix_impl_block::get_io_sty(tinytc_type_t ty) -> tinytc_type_t { + return visit( + overloaded{[](bf16_type &ty) -> tinytc_type_t { return i16_type::get(ty.context()); }, + [](f16_type &ty) -> 
tinytc_type_t { return i16_type::get(ty.context()); }, + [](f32_type &ty) -> tinytc_type_t { return i32_type::get(ty.context()); }, + [](f64_type &ty) -> tinytc_type_t { return i64_type::get(ty.context()); }, + [](c32_type &ty) -> tinytc_type_t { return i64_type::get(ty.context()); }, + [](tinytc_type &ty) -> tinytc_type_t { return &ty; }}, + + *ty); +} + +auto coopmatrix_impl_block::is_aligned(std::int32_t alignment, tinytc_value const &operand, + tinytc_value const &pos0) -> bool { + auto const mt = get_memref_type(operand); + const auto sty_size = size(mt->element_ty()); + if (sty_size >= static_cast(alignment)) { + return true; + } + if (auto mi = gcd().get_memref_if(operand); mi) { + const bool base_ok = (mi->offset_gcd() * sty_size) % alignment == 0; + const bool pos0_ok = (gcd().get(pos0) * sty_size) % alignment == 0; + const bool stride_ok = + mt->stride(0) == 1 && (mi->stride_gcd()[1] * sty_size) % alignment == 0; + + return base_ok && pos0_ok && stride_ok; + } + return false; +} + +} // namespace tinytc::spv diff --git a/src/spv/coopmatrix_impl_block.hpp b/src/spv/coopmatrix_impl_block.hpp new file mode 100644 index 00000000..f06e7312 --- /dev/null +++ b/src/spv/coopmatrix_impl_block.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_IMPL_BLOCK_20250428_HPP +#define COOPMATRIX_IMPL_BLOCK_20250428_HPP + +#include "spv/coopmatrix_impl.hpp" +#include "tinytc/types.h" + +#include + +namespace tinytc::spv { + +class coopmatrix_impl_block : public coopmatrix_impl { + public: + using coopmatrix_impl::coopmatrix_impl; + + auto load(cooperative_matrix_load_inst in, dope_vector const &odv, spv_inst *operand, + spv_inst *pos0, spv_inst *pos1) -> spv_inst * override; + void store(cooperative_matrix_store_inst in, dope_vector const &odv, spv_inst *val, + spv_inst *operand, spv_inst *pos0, spv_inst *pos1) override; + + private: + auto get_io_sty(tinytc_type_t ty) -> tinytc_type_t; + auto 
is_aligned(std::int32_t alignment, tinytc_value const &operand, tinytc_value const &pos0) + -> bool; +}; + +} // namespace tinytc::spv + +#endif // COOPMATRIX_IMPL_BLOCK_20250428_HPP diff --git a/src/spv/coopmatrix_impl_dpas.cpp b/src/spv/coopmatrix_impl_dpas.cpp new file mode 100644 index 00000000..206ac592 --- /dev/null +++ b/src/spv/coopmatrix_impl_dpas.cpp @@ -0,0 +1,610 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/coopmatrix_impl_dpas.hpp" +#include "analysis/gcd.hpp" +#include "codegen_tools.hpp" +#include "coopmatrix_layout.hpp" +#include "device_info.hpp" +#include "matrix_ext_info.hpp" +#include "node/inst.hpp" +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "node/visit.hpp" +#include "number.hpp" +#include "spv/block2d_diy.hpp" +#include "spv/converter_aux.hpp" +#include "spv/coopmatrix_impl.hpp" +#include "spv/defs.hpp" +#include "spv/dope_vector.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/lut.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "spv/xe_constants.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" +#include "util/math.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto precision(tinytc_type_t ty) -> char const * { + return visit( + overloaded{[&](i8_type &) { return "s8"; }, // + [&](bf16_type &) { return "bf"; }, // + [&](f16_type &) { return "hf"; }, // + [](tinytc_type &) -> char const * { throw status::internal_compiler_error; }}, + *ty); +} + +auto coopmatrix_impl_dpas::max_rows_in_block(matrix_use use, std::int32_t element_size) const + -> std::int32_t { + if (use == matrix_use::b) { + std::int32_t ops_per_chan = xe::channel_size / element_size; + return ops_per_chan * xe::sdepth; + } + return xe::exec_size; +} + +auto coopmatrix_impl_dpas::check_2d_block_io(tinytc_value const &operand, tinytc_value const &pos0) + -> bool 
{ + if (auto mi = gcd().get_memref_if(operand); mi) { + auto const mt = get_memref_type(operand); + const auto sty_size = size(mt->element_ty()); + auto const &block_io = cfg().matrix->block_io(); + + const bool sfid_ok = mt->addrspace() == address_space::global; + const bool base_address_alignment_ok = + (mi->offset_gcd() * sty_size) % block_io.base_address_alignment == 0; + const bool pos0_alignment_ok = (gcd().get(pos0) * sty_size) % block_io.pos0_alignment == 0; + const bool stride_ok = + mt->stride(0) == 1 && (mi->stride_gcd()[1] * sty_size) % block_io.stride_alignment == 0; + const bool width_ok = (mi->shape_gcd()[0] * sty_size) % block_io.width_alignment == 0; + + return sfid_ok && base_address_alignment_ok && pos0_alignment_ok && stride_ok && width_ok; + } + return false; +} + +auto coopmatrix_impl_dpas::load_config(tinytc_type_t sty, std::int32_t rows, std::int32_t cols, + matrix_use use, transpose trans, int32_t cache_level) + -> block_config { + auto cfg = block_config{}; + cfg.sty = sty; + cfg.element_size = size(sty); + cfg.array_length = 1; + cfg.rows = rows; + cfg.cols = cols; + cfg.row_blocks = 1; + cfg.col_blocks = 1; + cfg.transpose = trans == transpose::T; + cfg.vnni = use == matrix_use::a; + cfg.pos0_shr = 0; + cfg.cache_level = cache_level; + + auto const adjust_rows = [&cfg](std::int32_t max_rows, std::int32_t max_array_length) { + if (cfg.rows > max_rows) { + const std::int32_t num_blocks = cfg.rows / max_rows; + if (num_blocks > max_array_length) { + cfg.array_length = max_array_length; + cfg.row_blocks = num_blocks / max_array_length; + } else { + cfg.array_length = num_blocks; + } + cfg.rows = max_rows; + } + }; + auto const adjust_cols = [&cfg](std::int32_t max_cols_in_block) { + if (cfg.cols > max_cols_in_block) { + cfg.col_blocks = cfg.cols / max_cols_in_block; + cfg.cols = max_cols_in_block; + } + }; + auto const max_array_length = [&cfg](std::int32_t max_rows) -> std::int32_t { + return 64 / (max_rows * cfg.element_size); + }; + + // 
transpose + vnni message is the same as transpose message on d32 + if (cfg.transpose && cfg.vnni) { + std::swap(cfg.rows, cfg.cols); + + const auto ops_per_chan = 4 / cfg.element_size; + cfg.rows /= ops_per_chan; + cfg.sty = i32_type::get(sty->context()); + cfg.element_size = 4; + cfg.pos0_shr = ilog2(ops_per_chan); + cfg.vnni = false; + + adjust_cols(xe::exec_size); + + const auto max_rows = xe::exec_size / 2; + adjust_rows(max_rows, 1); + } else if (cfg.transpose) { + std::swap(cfg.rows, cfg.cols); + // Enable VNNI as transpose loads for B matrix are missing, so we use VNNI + mov-based 8x8 + // transpose + cfg.vnni = true; + + const std::int32_t max_cols = max_rows_in_block(use, cfg.element_size); + const std::int32_t max_rows = 8; + adjust_cols(max_cols); + adjust_rows(max_rows, max_array_length(max_rows)); + } else { + const std::int32_t max_cols = 32; + const std::int32_t max_rows = max_rows_in_block(use, cfg.element_size); + + adjust_cols(max_cols); + adjust_rows(max_rows, max_array_length(max_rows)); + } + + return cfg; +} + +auto coopmatrix_impl_dpas::load_fun(coopmatrix_type const *result_ty, spv_inst *spv_operand_ty, + transpose trans) -> spv_inst * { + const auto key = load_key{result_ty, spv_operand_ty, trans}; + return lookup(load_funs_, key, [&](load_key const &key) { + const auto [result_ty, spv_operand_ty, trans] = key; + + auto sty = result_ty->component_ty(); + const auto cfg = + load_config(sty, result_ty->rows(), result_ty->cols(), result_ty->use(), trans); + auto code = load_block2d_native(cfg, tmp_); + + auto spv_i32_ty = unique().int_ty(32); + auto spv_result_ty = spv_ty(result_ty); + auto fun_ty = unique().function_ty( + spv_result_ty, array_view{spv_operand_ty, spv_i32_ty, spv_i32_ty, + spv_i32_ty, spv_i32_ty, spv_i32_ty}); + return unique().mod().add_to(section::type_const_var, spv_result_ty, fun_ty, + unique().asm_target(), code, + "=rw,rw.u,rw.u,rw.u,rw.u,rw.u,rw.u"); + }); +} + +auto coopmatrix_impl_dpas::prefetch_fun(std::int32_t 
cache_level, tinytc_type_t sty, + spv_inst *spv_operand_ty, std::int32_t rows, + std::int32_t cols) -> spv_inst * { + const auto key = prefetch_key{cache_level, sty, spv_operand_ty, rows, cols}; + return lookup(prefetch_funs_, key, [&](prefetch_key const &key) -> spv_inst * { + const auto [cache_level, sty, spv_operand_ty, rows, cols] = key; + + const auto cfg = load_config(sty, rows, cols, matrix_use::acc, transpose::N, cache_level); + auto code = prefetch_block2d_native(cfg, tmp_); + + auto spv_i32_ty = unique().int_ty(32); + auto spv_void_ty = unique().void_ty(); + auto fun_ty = unique().function_ty( + spv_void_ty, array_view{spv_operand_ty, spv_i32_ty, spv_i32_ty, spv_i32_ty, + spv_i32_ty, spv_i32_ty}); + return unique().mod().add_to(section::type_const_var, spv_void_ty, fun_ty, + unique().asm_target(), code, + "rw.u,rw.u,rw.u,rw.u,rw.u,rw.u"); + }); +} + +auto coopmatrix_impl_dpas::store_config(coopmatrix_type const *ct) -> block_config { + constexpr std::int32_t max_cols_in_block = 8; + + auto cfg = block_config{}; + cfg.sty = ct->component_ty(); + cfg.element_size = size(cfg.sty); + cfg.array_length = 1; + cfg.rows = ct->rows(); + cfg.cols = ct->cols(); + cfg.row_blocks = 1; + cfg.col_blocks = 1; + cfg.transpose = false; + cfg.vnni = false; + cfg.pos0_shr = 0; + cfg.cache_level = -1; + + if (cfg.cols > max_cols_in_block) { + cfg.col_blocks = cfg.cols / max_cols_in_block; + cfg.cols = max_cols_in_block; + } + + const auto max_rows = max_rows_in_block(ct->use(), cfg.element_size); + if (cfg.rows > max_rows) { + cfg.row_blocks = cfg.rows / max_rows; + cfg.rows = max_rows; + } + + return cfg; +} + +auto coopmatrix_impl_dpas::store_fun(coopmatrix_type const *val_ty, spv_inst *spv_operand_ty) + -> spv_inst * { + const auto key = store_key{val_ty, spv_operand_ty}; + return lookup(store_funs_, key, [&](store_key const &key) { + const auto [val_ty, spv_operand_ty] = key; + + const auto cfg = store_config(val_ty); + auto code = store_block2d_native(cfg, tmp_); + + 
auto spv_void_ty = unique().void_ty(); + auto spv_val_ty = spv_ty(val_ty); + auto spv_i32_ty = unique().int_ty(32); + auto fun_ty = unique().function_ty( + spv_void_ty, array_view{spv_val_ty, spv_operand_ty, spv_i32_ty, spv_i32_ty, + spv_i32_ty, spv_i32_ty, spv_i32_ty}); + auto &mod = unique().mod(); + auto asmop = + mod.add_to(section::type_const_var, spv_void_ty, fun_ty, + unique().asm_target(), code, "rw,rw.u,rw.u,rw.u,rw.u,rw.u,rw.u"); + mod.add_to(section::decoration, asmop, Decoration::SideEffectsINTEL); + return asmop; + }); +} + +auto coopmatrix_impl_dpas::mul_add_fun(coopmatrix_type const *at, coopmatrix_type const *bt, + coopmatrix_type const *ct, coopmatrix_type const *rt, + bool is_c_zero) -> spv_inst * { + const auto key = mul_add_key{{at, bt, ct, rt}, is_c_zero}; + return lookup(mul_add_funs_, key, [&](mul_add_key const &key) { + const auto [at, bt, ct, rt] = key.op_ty; + + auto oasm = std::ostringstream{}; + + auto at_sty = at->component_ty(); + auto bt_sty = bt->component_ty(); + auto ct_sty = ct->component_ty(); + auto rt_sty = rt->component_ty(); + const std::int32_t ops_per_chan = xe::channel_size / size(at_sty); + const std::int32_t K = ops_per_chan * xe::sdepth; + + oasm << "{\n"; + std::string result_placeholder = "$0"; + std::string temp = result_placeholder; + if (rt->component_ty() != ct->component_ty() && at->cols() / K > 1) { + temp = tmp_("temp"); + oasm << ".decl " << temp << " v_type=G type=" << visa_type(ct_sty) + << " num_elts=" << ct->rows() * ct->cols() << " align=wordx32\n"; + } + const auto mat_A = tmp_("matrix_A"); + const auto mat_B = tmp_("matrix_B"); + oasm << ".decl " << mat_A + << " v_type=G type=d num_elts=" << at->rows() * at->cols() / ops_per_chan + << " align=wordx32 alias=<$1,0>\n"; + oasm << ".decl " << mat_B + << " v_type=G type=d num_elts=" << bt->rows() * bt->cols() / ops_per_chan + << " align=wordx32 alias=<$2,0>\n"; + + /** The GRF layout must follow the layout described in the following. 
+ * + * Let CM, CN, CK be the size of the coopmatrices, where + * CM = ct->rows() = at->rows(), + * CN = ct->cols() = bt->cols(), + * CK = at->cols() = bt->rows(), + * and let M, N, K be the size expected by DPAS, where + * M = xe::exec_size, + * N = xe::rcount + * K = ops_per_chan * xe::sdepth + * Let BM:=CM/M, BN:=CN/N, BK:=CK/K be the number of blocks in the respective mode. + * + * The blocks are laid out in the GRF as following + * + * A[m,k,bk,bm] = m + k * M + bk * M * K + bm * M * K * BK + * B[k,n,bn,bk] = k + n * K + bn * K * N + bk * K * N * BN + * C[m,n,bn,bm] = m + n * M + bn * M * N + bm * M * N * BN + * + * where m \in [M], n \in [N], k \in [K], bm \in [BM], bn \in [BN], bk \in [BK]. + * + * The mapping of m,n,k,bm,bn,bk to memory address is given by + * + * MA[m,k,bk,bm] = m' + bm' * M + (k' + bk' * K) * A_stride1 + * MB[k,n,bn,bk] = k'' + bk'' * K + (n'' + bn'' * N) * B_stride1 + * MC[m,n,bn,bm] = m + bm * M + (n + bn * N) * C_stride1 + * + * where + * + * (m',k') = { (m%ops_per_chan + k*ops_per_chan, floor(m/ops_per_chan)) if A + * transposed { (floor(m/ops_per_chan) + k*(M/ops_per_chan), m%ops_per_chan) else (bm',bk') + * = { (bk,bm) if A transposed { (bm,bk) else + * + * and + * + * (k'',n'') = { (n,k) if B transposed + * { (k,n) else + * (bk'',bn'') = { (bn,bk) if B transposed + * { (bk,bn) else + * + */ + const auto precision_src1 = precision(at_sty); + const auto precision_src2 = precision(bt_sty); + for (std::int32_t k = 0; k < at->cols(); k += K) { + char const *src0 = k > 0 ? temp.c_str() : (!key.is_c_zero ? "$3" : "%null"); + char const *dst = k + K >= at->cols() ? result_placeholder.c_str() : temp.c_str(); + const auto rsize = k + K >= at->cols() ? 
size(rt_sty) : size(ct_sty); + for (std::int32_t m = 0; m < ct->rows(); m += xe::exec_size) { + for (std::int32_t n = 0; n < ct->cols(); n += xe::rcount) { + const auto aoffset = (k * xe::exec_size + m * at->cols()) * size(at_sty); + const auto brow = (k * bt->cols() + n * K) * size(bt_sty) / xe::grf_size; + const auto coffset = !key.is_c_zero || k > 0 + ? (m * ct->cols() + n * xe::exec_size) * size(ct_sty) + : 0; + const auto roffset = (m * rt->cols() + n * xe::exec_size) * rsize; + oasm << "dpas." << precision_src1 << "." << precision_src2 << "." << xe::sdepth + << "." << xe::rcount << " (M1," << xe::exec_size << ") " << dst << "." + << roffset << " " << src0 << "." << coffset << " " << mat_A << "." + << aoffset << " " << mat_B << "(" << brow << ",0)\n"; + } + } + } + oasm << "}\n"; + + auto spv_a_ty = spv_ty(at); + auto spv_b_ty = spv_ty(bt); + auto spv_c_ty = spv_ty(ct); + auto spv_result_ty = spv_ty(rt); + auto fun_ty = unique().function_ty(spv_result_ty, + array_view{spv_a_ty, spv_b_ty, spv_c_ty}); + + return unique().mod().add_to(section::type_const_var, spv_result_ty, fun_ty, + unique().asm_target(), std::move(oasm).str(), + "=rw,rw,rw,rw"); + }); +} + +auto coopmatrix_impl_dpas::reduce_fun(std::int32_t sgs, IK op, coopmatrix_type const *at, + coopmatrix_type const *rt) -> spv_inst * { + const auto key = std::make_tuple(sgs, op, at, rt); + return lookup(reduce_funs_, key, [&](reduce_key const &key) { + auto [sgs, op, at, rt] = key; + auto rl = get_layout(cfg(), rt); + auto al = get_layout(cfg(), at); + auto matrix_ty = spv_ty(rl); + const auto at_sty = at->component_ty(); + const auto sty = rt->component_ty(); + const auto sty_size = size(sty); + + auto oasm = std::ostringstream{}; + + oasm << "{\n"; + auto aview = tmp_("aview"); + oasm << ".decl " << aview << " v_type=G type=" << visa_type(at_sty) + << " num_elts=" << al.length * sgs << " align=wordx32 alias=<$1,0>\n"; + auto rview = tmp_("rview"); + oasm << ".decl " << rview << " v_type=G type=" << 
visa_type(sty) + << " num_elts=" << rl.length * sgs << " align=wordx32 alias=<$0,0>\n"; + auto predicate = tmp_("predicate"); + oasm << ".decl " << predicate << " v_type=P num_elts=" << sgs << "\n"; + + char const *reduce = [](IK op) { + switch (op) { + case IK::IK_cooperative_matrix_reduce_add: + return "add"; + case IK::IK_cooperative_matrix_reduce_max: + return "max"; + case IK::IK_cooperative_matrix_reduce_min: + return "min"; + default: + break; + } + throw status::internal_compiler_error; + }(op); + + for (std::int32_t offset = 0; offset < al.shape1; offset += sgs) { + const std::int32_t remainder = + std::min(sgs, static_cast(al.shape1 - offset)); + std::string src = aview; + if (al.blocks > 1) { + auto tmp = tmp_("tmp"); + oasm << ".decl " << tmp << " v_type=G type=" << visa_type(at_sty) + << " num_elts=" << sgs * sgs << " align=wordx32\n"; + for (std::int32_t j0 = offset; j0 < offset + remainder; ++j0) { + const auto t1 = region_origin(sty_size, sgs * (j0 - offset) * sty_size); + const auto a1 = + region_origin(sty_size, sgs * al.component_no(j0, 0) * sty_size); + const auto a2 = + region_origin(sty_size, sgs * al.component_no(j0, 1) * sty_size); + oasm << reduce << " (M1," << sgs << ") " << tmp << "(" << t1[0] << "," << t1[1] + << ")<1> " << aview << "(" << a1[0] << "," << a1[1] << ")<1;1,0> " << aview + << "(" << a2[0] << "," << a2[1] << ")<1;1,0>\n"; + for (std::int32_t b = 2; b < al.blocks; ++b) { + const auto a2 = + region_origin(sty_size, sgs * al.component_no(j0, b) * sty_size); + oasm << reduce << " (M1," << sgs << ") " << tmp << "(" << t1[0] << "," + << t1[1] << ")<1> " << tmp << "(" << t1[0] << "," << t1[1] + << ")<1;1,0> " << aview << "(" << a2[0] << "," << a2[1] + << ")<1;1,0>\n"; + } + } + src = tmp; + } + + for (std::int32_t v = 1; v < sgs; v *= 2) { + std::uint32_t pval = 0; + for (std::uint32_t j = 0; j < 32; j += 2 * v) { + pval |= (((1 << v) - 1) << j); + } + oasm << "setp (M1," << sgs << ") " << predicate << " " << pval << ":ud\n"; + + 
std::string dst = rview; + auto dst_offset = offset; + if (2 * v < sgs) { + auto tmp = tmp_("tmp"); + oasm << ".decl " << tmp << " v_type=G type=" << visa_type(at_sty) + << " num_elts=" << sgs * sgs / (2 * v) << " align=wordx32\n"; + dst = tmp; + dst_offset = 0; + } + + for (int i = 0; i < sgs / v && i < remainder; i += 2) { + auto tmp1 = tmp_("tmp"); + oasm << ".decl " << tmp1 << " v_type=G type=" << visa_type(at_sty) + << " num_elts=" << sgs << " align=wordx32\n"; + auto tmp2 = tmp_("tmp"); + oasm << ".decl " << tmp2 << " v_type=G type=" << visa_type(at_sty) + << " num_elts=" << sgs << " align=wordx32\n"; + + const auto t0 = region_origin(sty_size, (dst_offset + sgs * i / 2) * sty_size); + const auto t1 = region_origin(sty_size, sgs * i * sty_size); + const auto t1down = region_origin(sty_size, (sgs * i + v) * sty_size); + const auto t2 = region_origin(sty_size, sgs * (i + 1) * sty_size); + const auto t2up = region_origin(sty_size, (sgs * (i + 1) - v) * sty_size); + oasm << "(!" << predicate << ") sel (M1," << sgs << ") " << tmp1 << "(0,0)<1> " + << src << "(" << t2up[0] << "," << t2up[1] << ")<1;1,0> " << src << "(" + << t1[0] << "," << t1[1] << ")<1;1,0>\n"; + oasm << "(" << predicate << ") sel (M1," << sgs << ") " << tmp2 << "(0,0)<1> " + << src << "(" << t1down[0] << "," << t1down[1] << ")<1;1,0> " << src << "(" + << t2[0] << "," << t2[1] << ")<1;1,0>\n"; + oasm << reduce << " (M1," << sgs << ") " << dst << "(" << t0[0] << "," << t0[1] + << ")<1> " << tmp1 << "(0,0)<1;1,0> " << tmp2 << "(0,0)<1;1,0>\n"; + } + src = dst; + } + } + oasm << "}\n"; + + auto fun_ty = unique().function_ty(matrix_ty, array_view{spv_ty(al)}); + return unique().mod().add_to(section::type_const_var, matrix_ty, fun_ty, + unique().asm_target(), std::move(oasm).str(), + "=rw,rw"); + }); +} + +auto coopmatrix_impl_dpas::load(cooperative_matrix_load_inst in, dope_vector const &odv, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) -> spv_inst * { + auto rt = 
get_coopmatrix_type(in.result()); + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const auto type_ok = cfg().matrix->have_type(rt); + const auto block_io_ok = check_2d_block_io(in.operand(), in.pos0()); + const bool transpose_ok = in.t() == transpose::N || rt->use() == matrix_use::a; + + if (!sgs_ok || !type_ok || !block_io_ok || !transpose_ok) { + return coopmatrix_impl_block::load(in, odv, pointer, pos0, pos1); + } + + auto ot = get_memref_type(in.operand()); + auto ot_sty = ot->element_ty(); + auto ct = get_coopmatrix_type(in.result()); + auto fun = load_fun(ct, get_spv_ty(unique(), ot), in.t()); + + auto &mod = unique().mod(); + auto spv_i32_ty = unique().int_ty(32); + auto csize = unique().constant(static_cast(size(ot_sty))); + auto shape0_i32 = mod.add(spv_i32_ty, odv.shape(0)); + auto width_in_bytes = mod.add(spv_i32_ty, shape0_i32, csize); + auto height = mod.add(spv_i32_ty, odv.shape(1)); + auto stride1_i32 = mod.add(spv_i32_ty, odv.stride(1)); + auto stride_in_bytes = mod.add(spv_i32_ty, stride1_i32, csize); + auto pos0_i32 = mod.add(spv_i32_ty, pos0); + auto pos1_i32 = mod.add(spv_i32_ty, pos1); + + return mod.add(spv_ty(ct), fun, + array_view{pointer, width_in_bytes, height, + stride_in_bytes, pos0_i32, pos1_i32}); +} + +auto coopmatrix_impl_dpas::mul_add(cooperative_matrix_mul_add_inst in, spv_inst *a, spv_inst *b, + spv_inst *c) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + auto bt = get_coopmatrix_type(in.b()); + auto ct = get_coopmatrix_type(in.c()); + auto rt = get_coopmatrix_type(in.result()); + auto at_sty = at->component_ty()->type_id(); + auto bt_sty = bt->component_ty()->type_id(); + auto ct_sty = ct->component_ty()->type_id(); + auto rt_sty = rt->component_ty()->type_id(); + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const bool have_gemm = + cfg().matrix->have_gemm(at_sty, bt_sty, ct_sty, rt_sty, rt->rows(), rt->cols(), at->cols()); + if (!sgs_ok || 
!have_gemm) { + return coopmatrix_impl_block::mul_add(in, a, b, c); + } + + auto fun = mul_add_fun(at, bt, ct, rt, in.is_c_zero()); + return unique().mod().add(spv_ty(rt), fun, array_view{a, b, c}); +} + +void coopmatrix_impl_dpas::prefetch(cooperative_matrix_prefetch_inst in, dope_vector const &odv, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) { + auto ot = get_memref_type(in.operand()); + auto ot_sty = ot->element_ty(); + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const auto type_ok = size(ot_sty) <= 4; + const auto block_io_ok = check_2d_block_io(in.operand(), in.pos0()); + + if (!sgs_ok || !type_ok || !block_io_ok) { + coopmatrix_impl_block::prefetch(in, odv, pointer, pos0, pos1); + } else { + auto fun = + prefetch_fun(in.cache_level(), ot_sty, get_spv_ty(unique(), ot), in.rows(), in.cols()); + + if (fun) { + auto &mod = unique().mod(); + auto spv_void_ty = unique().void_ty(); + auto spv_i32_ty = unique().int_ty(32); + auto csize = unique().constant(static_cast(size(ot_sty))); + auto shape0_i32 = mod.add(spv_i32_ty, odv.shape(0)); + auto width_in_bytes = mod.add(spv_i32_ty, shape0_i32, csize); + auto height = mod.add(spv_i32_ty, odv.shape(1)); + auto stride1_i32 = mod.add(spv_i32_ty, odv.stride(1)); + auto stride_in_bytes = mod.add(spv_i32_ty, stride1_i32, csize); + auto pos0_i32 = mod.add(spv_i32_ty, pos0); + auto pos1_i32 = mod.add(spv_i32_ty, pos1); + + mod.add(spv_void_ty, fun, + array_view{pointer, width_in_bytes, height, + stride_in_bytes, pos0_i32, pos1_i32}); + } + } +} + +void coopmatrix_impl_dpas::store(cooperative_matrix_store_inst in, dope_vector const &odv, + spv_inst *val, spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) { + auto ct = get_coopmatrix_type(in.val()); + const bool transpose_ok = in.t() == transpose::N; + const bool sgs_ok = cfg().subgroup_size == cfg().matrix->required_subgroup_size(); + const auto type_ok = cfg().matrix->have_type(ct); + const auto block_io_ok = 
check_2d_block_io(in.operand(), in.pos0()); + + if (!transpose_ok || !sgs_ok || !type_ok || !block_io_ok) { + coopmatrix_impl_block::store(in, odv, val, pointer, pos0, pos1); + } else { + auto ot = get_memref_type(in.operand()); + auto ot_sty = ot->element_ty(); + auto fun = store_fun(ct, get_spv_ty(unique(), ot)); + + auto &mod = unique().mod(); + auto spv_void_ty = unique().void_ty(); + auto spv_i32_ty = unique().int_ty(32); + auto csize = unique().constant(static_cast(size(ot_sty))); + auto shape0_i32 = mod.add(spv_i32_ty, odv.shape(0)); + auto width_in_bytes = mod.add(spv_i32_ty, shape0_i32, csize); + auto height = mod.add(spv_i32_ty, odv.shape(1)); + auto stride1_i32 = mod.add(spv_i32_ty, odv.stride(1)); + auto stride_in_bytes = mod.add(spv_i32_ty, stride1_i32, csize); + auto pos0_i32 = mod.add(spv_i32_ty, pos0); + auto pos1_i32 = mod.add(spv_i32_ty, pos1); + + mod.add(spv_void_ty, fun, + array_view{val, pointer, width_in_bytes, height, + stride_in_bytes, pos0_i32, pos1_i32}); + } +} + +auto coopmatrix_impl_dpas::reduce(cooperative_matrix_reduce_inst in, spv_inst *a) -> spv_inst * { + auto at = get_coopmatrix_type(in.a()); + const auto sgs = cfg().subgroup_size; + + if (in.mode() != reduce_mode::column || at->rows() % sgs != 0 || at->use() == matrix_use::a) { + return coopmatrix_impl::reduce(in, a); + } + + auto rt = get_coopmatrix_type(in.result()); + auto fun = reduce_fun(sgs, in.get().type_id(), at, rt); + return unique().mod().add(spv_ty(rt), fun, array_view{a}); +} + +} // namespace tinytc::spv + diff --git a/src/spv/coopmatrix_impl_dpas.hpp b/src/spv/coopmatrix_impl_dpas.hpp new file mode 100644 index 00000000..0d46bb8d --- /dev/null +++ b/src/spv/coopmatrix_impl_dpas.hpp @@ -0,0 +1,94 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef COOPMATRIX_IMPL_DPAS_20250428_HPP +#define COOPMATRIX_IMPL_DPAS_20250428_HPP + +#include "node/inst_view.hpp" +#include "node/type.hpp" +#include "spv/block2d_diy.hpp" 
+#include "spv/coopmatrix_impl_block.hpp" +#include "support/temp_counter.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +class spv_inst; + +class coopmatrix_impl_dpas : public coopmatrix_impl_block { + public: + using coopmatrix_impl_block::coopmatrix_impl_block; + + auto load(cooperative_matrix_load_inst in, dope_vector const &odv, spv_inst *pointer, + spv_inst *pos0, spv_inst *pos1) -> spv_inst * override; + auto mul_add(cooperative_matrix_mul_add_inst in, spv_inst *a, spv_inst *b, spv_inst *c) + -> spv_inst * override; + void prefetch(cooperative_matrix_prefetch_inst in, dope_vector const &odv, spv_inst *pointer, + spv_inst *pos0, spv_inst *pos1) override; + void store(cooperative_matrix_store_inst in, dope_vector const &odv, spv_inst *val, + spv_inst *pointer, spv_inst *pos0, spv_inst *pos1) override; + auto reduce(cooperative_matrix_reduce_inst in, spv_inst *a) -> spv_inst * override; + + private: + struct mul_add_key { + std::array op_ty; + bool is_c_zero; + + auto operator==(mul_add_key const &other) const { + return op_ty == other.op_ty && is_c_zero == other.is_c_zero; + } + }; + struct mul_add_hash { + inline auto operator()(mul_add_key const &key) const -> std::size_t { + return fnv1a_combine(key.op_ty[0], key.op_ty[1], key.op_ty[2], key.op_ty[3], + key.is_c_zero); + } + }; + template struct tuple_hash { + inline auto operator()(Tuple const &key) const -> std::size_t { + return std::apply([](auto const &...args) { return fnv1a_combine(args...); }, key); + } + }; + + auto max_rows_in_block(matrix_use use, std::int32_t element_size) const -> std::int32_t; + auto check_2d_block_io(tinytc_value const &operand, tinytc_value const &pos0) -> bool; + auto load_config(tinytc_type_t sty, std::int32_t rows, std::int32_t cols, matrix_use use, + transpose trans, std::int32_t cache_level = -1) -> block_config; + auto 
load_fun(coopmatrix_type const *result_ty, spv_inst *spv_operand_ty, transpose trans) + -> spv_inst *; + auto prefetch_fun(std::int32_t cache_level, tinytc_type_t sty, spv_inst *spv_operand_ty, + std::int32_t rows, std::int32_t cols) -> spv_inst *; + auto store_config(coopmatrix_type const *ct) -> block_config; + auto store_fun(coopmatrix_type const *val_ty, spv_inst *spv_operand_ty) -> spv_inst *; + auto mul_add_fun(coopmatrix_type const *at, coopmatrix_type const *bt, + coopmatrix_type const *ct, coopmatrix_type const *rt, bool is_c_zero) + -> spv_inst *; + auto reduce_fun(std::int32_t sgs, IK op, coopmatrix_type const *at, coopmatrix_type const *rt) + -> spv_inst *; + + using load_key = std::tuple; + using prefetch_key = + std::tuple; + using store_key = std::tuple; + using reduce_key = + std::tuple; + std::unordered_map> load_funs_; + std::unordered_map> prefetch_funs_; + std::unordered_map> store_funs_; + std::unordered_map mul_add_funs_; + std::unordered_map> reduce_funs_; + temp_counter tmp_; +}; + +} // namespace tinytc::spv + +#endif // COOPMATRIX_IMPL_DPAS_20250428_HPP diff --git a/src/spv/defs.hpp b/src/spv/defs.hpp new file mode 100644 index 00000000..9f341da3 --- /dev/null +++ b/src/spv/defs.hpp @@ -0,0 +1,426 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_DEFS_20250630_HPP +#define GENERATED_DEFS_20250630_HPP + +#include "enums.hpp" +#include "tinytc/core.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +class spv_inst : public ilist_node { + public: + inline spv_inst(Op opcode, bool has_result_id) + : opcode_{opcode}, id_{has_result_id ? 
0 : std::numeric_limits::max()} {} + virtual ~spv_inst() = default; + + spv_inst(spv_inst const &other) = delete; + spv_inst(spv_inst &&other) = delete; + spv_inst &operator=(spv_inst const &other) = delete; + spv_inst &operator=(spv_inst &&other) = delete; + + inline auto opcode() const -> Op { return opcode_; } + // SPIR-V requires 0 < id < Bound, therefore, we can reserve 0 for encoding "produces result; id + // not yet assigned" and uint32_max for encoding "does not produce result" + inline auto has_result_id() const -> bool { + return id_ != std::numeric_limits::max(); + } + inline auto id() const -> std::uint32_t { return id_; } + inline void id(std::uint32_t id) { id_ = id; } + + private: + Op opcode_; + std::uint32_t id_; +}; + +using DecorationAttr = std::variant>; +using ExecutionModeAttr = std::variant>; +using LiteralContextDependentNumber = + std::variant; +using LiteralString = std::string; +using LiteralInteger = std::int32_t; +using LiteralExtInstInteger = std::int32_t; +using IdResultType = spv_inst *; +using IdRef = spv_inst *; +using IdScope = spv_inst *; +using IdMemorySemantics = spv_inst *; +using LoopControlAttr = std::int32_t; +using MemoryAccessAttr = std::int32_t; +using PairIdRefIdRef = std::pair; +using PairLiteralIntegerIdRef = + std::pair, spv_inst *>; +using PairIdRefLiteralInteger = std::pair; + +class OpNop; // IWYU pragma: export +class OpUndef; // IWYU pragma: export +class OpSourceContinued; // IWYU pragma: export +class OpSource; // IWYU pragma: export +class OpSourceExtension; // IWYU pragma: export +class OpName; // IWYU pragma: export +class OpMemberName; // IWYU pragma: export +class OpString; // IWYU pragma: export +class OpLine; // IWYU pragma: export +class OpExtension; // IWYU pragma: export +class OpExtInstImport; // IWYU pragma: export +class OpExtInst; // IWYU pragma: export +class OpMemoryModel; // IWYU pragma: export +class OpEntryPoint; // IWYU pragma: export +class OpExecutionMode; // IWYU pragma: export +class 
OpCapability; // IWYU pragma: export +class OpTypeVoid; // IWYU pragma: export +class OpTypeBool; // IWYU pragma: export +class OpTypeInt; // IWYU pragma: export +class OpTypeFloat; // IWYU pragma: export +class OpTypeVector; // IWYU pragma: export +class OpTypeMatrix; // IWYU pragma: export +class OpTypeImage; // IWYU pragma: export +class OpTypeSampler; // IWYU pragma: export +class OpTypeSampledImage; // IWYU pragma: export +class OpTypeArray; // IWYU pragma: export +class OpTypeRuntimeArray; // IWYU pragma: export +class OpTypeStruct; // IWYU pragma: export +class OpTypeOpaque; // IWYU pragma: export +class OpTypePointer; // IWYU pragma: export +class OpTypeFunction; // IWYU pragma: export +class OpTypeEvent; // IWYU pragma: export +class OpTypeDeviceEvent; // IWYU pragma: export +class OpTypeReserveId; // IWYU pragma: export +class OpTypeQueue; // IWYU pragma: export +class OpTypePipe; // IWYU pragma: export +class OpTypeForwardPointer; // IWYU pragma: export +class OpConstantTrue; // IWYU pragma: export +class OpConstantFalse; // IWYU pragma: export +class OpConstant; // IWYU pragma: export +class OpConstantComposite; // IWYU pragma: export +class OpConstantSampler; // IWYU pragma: export +class OpConstantNull; // IWYU pragma: export +class OpFunction; // IWYU pragma: export +class OpFunctionParameter; // IWYU pragma: export +class OpFunctionEnd; // IWYU pragma: export +class OpFunctionCall; // IWYU pragma: export +class OpVariable; // IWYU pragma: export +class OpImageTexelPointer; // IWYU pragma: export +class OpLoad; // IWYU pragma: export +class OpStore; // IWYU pragma: export +class OpCopyMemory; // IWYU pragma: export +class OpCopyMemorySized; // IWYU pragma: export +class OpAccessChain; // IWYU pragma: export +class OpInBoundsAccessChain; // IWYU pragma: export +class OpPtrAccessChain; // IWYU pragma: export +class OpArrayLength; // IWYU pragma: export +class OpGenericPtrMemSemantics; // IWYU pragma: export +class OpInBoundsPtrAccessChain; // IWYU 
pragma: export +class OpDecorate; // IWYU pragma: export +class OpMemberDecorate; // IWYU pragma: export +class OpDecorationGroup; // IWYU pragma: export +class OpGroupDecorate; // IWYU pragma: export +class OpGroupMemberDecorate; // IWYU pragma: export +class OpVectorExtractDynamic; // IWYU pragma: export +class OpVectorInsertDynamic; // IWYU pragma: export +class OpVectorShuffle; // IWYU pragma: export +class OpCompositeConstruct; // IWYU pragma: export +class OpCompositeExtract; // IWYU pragma: export +class OpCompositeInsert; // IWYU pragma: export +class OpCopyObject; // IWYU pragma: export +class OpTranspose; // IWYU pragma: export +class OpSampledImage; // IWYU pragma: export +class OpImageSampleImplicitLod; // IWYU pragma: export +class OpImageSampleExplicitLod; // IWYU pragma: export +class OpImageSampleDrefImplicitLod; // IWYU pragma: export +class OpImageSampleDrefExplicitLod; // IWYU pragma: export +class OpImageSampleProjImplicitLod; // IWYU pragma: export +class OpImageSampleProjExplicitLod; // IWYU pragma: export +class OpImageSampleProjDrefImplicitLod; // IWYU pragma: export +class OpImageSampleProjDrefExplicitLod; // IWYU pragma: export +class OpImageFetch; // IWYU pragma: export +class OpImageGather; // IWYU pragma: export +class OpImageDrefGather; // IWYU pragma: export +class OpImageRead; // IWYU pragma: export +class OpImageWrite; // IWYU pragma: export +class OpImage; // IWYU pragma: export +class OpImageQueryFormat; // IWYU pragma: export +class OpImageQueryOrder; // IWYU pragma: export +class OpImageQuerySizeLod; // IWYU pragma: export +class OpImageQuerySize; // IWYU pragma: export +class OpImageQueryLod; // IWYU pragma: export +class OpImageQueryLevels; // IWYU pragma: export +class OpImageQuerySamples; // IWYU pragma: export +class OpConvertFToU; // IWYU pragma: export +class OpConvertFToS; // IWYU pragma: export +class OpConvertSToF; // IWYU pragma: export +class OpConvertUToF; // IWYU pragma: export +class OpUConvert; // IWYU pragma: 
export +class OpSConvert; // IWYU pragma: export +class OpFConvert; // IWYU pragma: export +class OpQuantizeToF16; // IWYU pragma: export +class OpConvertPtrToU; // IWYU pragma: export +class OpSatConvertSToU; // IWYU pragma: export +class OpSatConvertUToS; // IWYU pragma: export +class OpConvertUToPtr; // IWYU pragma: export +class OpPtrCastToGeneric; // IWYU pragma: export +class OpGenericCastToPtr; // IWYU pragma: export +class OpGenericCastToPtrExplicit; // IWYU pragma: export +class OpBitcast; // IWYU pragma: export +class OpSNegate; // IWYU pragma: export +class OpFNegate; // IWYU pragma: export +class OpIAdd; // IWYU pragma: export +class OpFAdd; // IWYU pragma: export +class OpISub; // IWYU pragma: export +class OpFSub; // IWYU pragma: export +class OpIMul; // IWYU pragma: export +class OpFMul; // IWYU pragma: export +class OpUDiv; // IWYU pragma: export +class OpSDiv; // IWYU pragma: export +class OpFDiv; // IWYU pragma: export +class OpUMod; // IWYU pragma: export +class OpSRem; // IWYU pragma: export +class OpSMod; // IWYU pragma: export +class OpFRem; // IWYU pragma: export +class OpFMod; // IWYU pragma: export +class OpVectorTimesScalar; // IWYU pragma: export +class OpMatrixTimesScalar; // IWYU pragma: export +class OpVectorTimesMatrix; // IWYU pragma: export +class OpMatrixTimesVector; // IWYU pragma: export +class OpMatrixTimesMatrix; // IWYU pragma: export +class OpOuterProduct; // IWYU pragma: export +class OpDot; // IWYU pragma: export +class OpIAddCarry; // IWYU pragma: export +class OpISubBorrow; // IWYU pragma: export +class OpUMulExtended; // IWYU pragma: export +class OpSMulExtended; // IWYU pragma: export +class OpAny; // IWYU pragma: export +class OpAll; // IWYU pragma: export +class OpIsNan; // IWYU pragma: export +class OpIsInf; // IWYU pragma: export +class OpIsFinite; // IWYU pragma: export +class OpIsNormal; // IWYU pragma: export +class OpSignBitSet; // IWYU pragma: export +class OpLessOrGreater; // IWYU pragma: export +class 
OpOrdered; // IWYU pragma: export +class OpUnordered; // IWYU pragma: export +class OpLogicalEqual; // IWYU pragma: export +class OpLogicalNotEqual; // IWYU pragma: export +class OpLogicalOr; // IWYU pragma: export +class OpLogicalAnd; // IWYU pragma: export +class OpLogicalNot; // IWYU pragma: export +class OpSelect; // IWYU pragma: export +class OpIEqual; // IWYU pragma: export +class OpINotEqual; // IWYU pragma: export +class OpUGreaterThan; // IWYU pragma: export +class OpSGreaterThan; // IWYU pragma: export +class OpUGreaterThanEqual; // IWYU pragma: export +class OpSGreaterThanEqual; // IWYU pragma: export +class OpULessThan; // IWYU pragma: export +class OpSLessThan; // IWYU pragma: export +class OpULessThanEqual; // IWYU pragma: export +class OpSLessThanEqual; // IWYU pragma: export +class OpFOrdEqual; // IWYU pragma: export +class OpFUnordEqual; // IWYU pragma: export +class OpFOrdNotEqual; // IWYU pragma: export +class OpFUnordNotEqual; // IWYU pragma: export +class OpFOrdLessThan; // IWYU pragma: export +class OpFUnordLessThan; // IWYU pragma: export +class OpFOrdGreaterThan; // IWYU pragma: export +class OpFUnordGreaterThan; // IWYU pragma: export +class OpFOrdLessThanEqual; // IWYU pragma: export +class OpFUnordLessThanEqual; // IWYU pragma: export +class OpFOrdGreaterThanEqual; // IWYU pragma: export +class OpFUnordGreaterThanEqual; // IWYU pragma: export +class OpShiftRightLogical; // IWYU pragma: export +class OpShiftRightArithmetic; // IWYU pragma: export +class OpShiftLeftLogical; // IWYU pragma: export +class OpBitwiseOr; // IWYU pragma: export +class OpBitwiseXor; // IWYU pragma: export +class OpBitwiseAnd; // IWYU pragma: export +class OpNot; // IWYU pragma: export +class OpBitFieldInsert; // IWYU pragma: export +class OpBitFieldSExtract; // IWYU pragma: export +class OpBitFieldUExtract; // IWYU pragma: export +class OpBitReverse; // IWYU pragma: export +class OpBitCount; // IWYU pragma: export +class OpDPdx; // IWYU pragma: export +class 
OpDPdy; // IWYU pragma: export +class OpFwidth; // IWYU pragma: export +class OpDPdxFine; // IWYU pragma: export +class OpDPdyFine; // IWYU pragma: export +class OpFwidthFine; // IWYU pragma: export +class OpDPdxCoarse; // IWYU pragma: export +class OpDPdyCoarse; // IWYU pragma: export +class OpFwidthCoarse; // IWYU pragma: export +class OpEmitVertex; // IWYU pragma: export +class OpEndPrimitive; // IWYU pragma: export +class OpEmitStreamVertex; // IWYU pragma: export +class OpEndStreamPrimitive; // IWYU pragma: export +class OpControlBarrier; // IWYU pragma: export +class OpMemoryBarrier; // IWYU pragma: export +class OpAtomicLoad; // IWYU pragma: export +class OpAtomicStore; // IWYU pragma: export +class OpAtomicExchange; // IWYU pragma: export +class OpAtomicCompareExchange; // IWYU pragma: export +class OpAtomicCompareExchangeWeak; // IWYU pragma: export +class OpAtomicIIncrement; // IWYU pragma: export +class OpAtomicIDecrement; // IWYU pragma: export +class OpAtomicIAdd; // IWYU pragma: export +class OpAtomicISub; // IWYU pragma: export +class OpAtomicSMin; // IWYU pragma: export +class OpAtomicUMin; // IWYU pragma: export +class OpAtomicSMax; // IWYU pragma: export +class OpAtomicUMax; // IWYU pragma: export +class OpAtomicAnd; // IWYU pragma: export +class OpAtomicOr; // IWYU pragma: export +class OpAtomicXor; // IWYU pragma: export +class OpPhi; // IWYU pragma: export +class OpLoopMerge; // IWYU pragma: export +class OpSelectionMerge; // IWYU pragma: export +class OpLabel; // IWYU pragma: export +class OpBranch; // IWYU pragma: export +class OpBranchConditional; // IWYU pragma: export +class OpSwitch; // IWYU pragma: export +class OpKill; // IWYU pragma: export +class OpReturn; // IWYU pragma: export +class OpReturnValue; // IWYU pragma: export +class OpUnreachable; // IWYU pragma: export +class OpLifetimeStart; // IWYU pragma: export +class OpLifetimeStop; // IWYU pragma: export +class OpGroupAsyncCopy; // IWYU pragma: export +class OpGroupWaitEvents; // 
IWYU pragma: export +class OpGroupAll; // IWYU pragma: export +class OpGroupAny; // IWYU pragma: export +class OpGroupBroadcast; // IWYU pragma: export +class OpGroupIAdd; // IWYU pragma: export +class OpGroupFAdd; // IWYU pragma: export +class OpGroupFMin; // IWYU pragma: export +class OpGroupUMin; // IWYU pragma: export +class OpGroupSMin; // IWYU pragma: export +class OpGroupFMax; // IWYU pragma: export +class OpGroupUMax; // IWYU pragma: export +class OpGroupSMax; // IWYU pragma: export +class OpReadPipe; // IWYU pragma: export +class OpWritePipe; // IWYU pragma: export +class OpReservedReadPipe; // IWYU pragma: export +class OpReservedWritePipe; // IWYU pragma: export +class OpReserveReadPipePackets; // IWYU pragma: export +class OpReserveWritePipePackets; // IWYU pragma: export +class OpCommitReadPipe; // IWYU pragma: export +class OpCommitWritePipe; // IWYU pragma: export +class OpIsValidReserveId; // IWYU pragma: export +class OpGetNumPipePackets; // IWYU pragma: export +class OpGetMaxPipePackets; // IWYU pragma: export +class OpGroupReserveReadPipePackets; // IWYU pragma: export +class OpGroupReserveWritePipePackets; // IWYU pragma: export +class OpGroupCommitReadPipe; // IWYU pragma: export +class OpGroupCommitWritePipe; // IWYU pragma: export +class OpEnqueueMarker; // IWYU pragma: export +class OpEnqueueKernel; // IWYU pragma: export +class OpGetKernelNDrangeSubGroupCount; // IWYU pragma: export +class OpGetKernelNDrangeMaxSubGroupSize; // IWYU pragma: export +class OpGetKernelWorkGroupSize; // IWYU pragma: export +class OpGetKernelPreferredWorkGroupSizeMultiple; // IWYU pragma: export +class OpRetainEvent; // IWYU pragma: export +class OpReleaseEvent; // IWYU pragma: export +class OpCreateUserEvent; // IWYU pragma: export +class OpIsValidEvent; // IWYU pragma: export +class OpSetUserEventStatus; // IWYU pragma: export +class OpCaptureEventProfilingInfo; // IWYU pragma: export +class OpGetDefaultQueue; // IWYU pragma: export +class OpBuildNDRange; // 
IWYU pragma: export +class OpImageSparseSampleImplicitLod; // IWYU pragma: export +class OpImageSparseSampleExplicitLod; // IWYU pragma: export +class OpImageSparseSampleDrefImplicitLod; // IWYU pragma: export +class OpImageSparseSampleDrefExplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjImplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjExplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjDrefImplicitLod; // IWYU pragma: export +class OpImageSparseSampleProjDrefExplicitLod; // IWYU pragma: export +class OpImageSparseFetch; // IWYU pragma: export +class OpImageSparseGather; // IWYU pragma: export +class OpImageSparseDrefGather; // IWYU pragma: export +class OpImageSparseTexelsResident; // IWYU pragma: export +class OpNoLine; // IWYU pragma: export +class OpAtomicFlagTestAndSet; // IWYU pragma: export +class OpAtomicFlagClear; // IWYU pragma: export +class OpImageSparseRead; // IWYU pragma: export +class OpSizeOf; // IWYU pragma: export +class OpTypePipeStorage; // IWYU pragma: export +class OpConstantPipeStorage; // IWYU pragma: export +class OpCreatePipeFromPipeStorage; // IWYU pragma: export +class OpGetKernelLocalSizeForSubgroupCount; // IWYU pragma: export +class OpGetKernelMaxNumSubgroups; // IWYU pragma: export +class OpTypeNamedBarrier; // IWYU pragma: export +class OpNamedBarrierInitialize; // IWYU pragma: export +class OpMemoryNamedBarrier; // IWYU pragma: export +class OpModuleProcessed; // IWYU pragma: export +class OpExecutionModeId; // IWYU pragma: export +class OpDecorateId; // IWYU pragma: export +class OpGroupNonUniformElect; // IWYU pragma: export +class OpGroupNonUniformAll; // IWYU pragma: export +class OpGroupNonUniformAny; // IWYU pragma: export +class OpGroupNonUniformAllEqual; // IWYU pragma: export +class OpGroupNonUniformBroadcast; // IWYU pragma: export +class OpGroupNonUniformBroadcastFirst; // IWYU pragma: export +class OpGroupNonUniformBallot; // IWYU pragma: export +class 
OpGroupNonUniformInverseBallot; // IWYU pragma: export +class OpGroupNonUniformBallotBitExtract; // IWYU pragma: export +class OpGroupNonUniformBallotBitCount; // IWYU pragma: export +class OpGroupNonUniformBallotFindLSB; // IWYU pragma: export +class OpGroupNonUniformBallotFindMSB; // IWYU pragma: export +class OpGroupNonUniformShuffle; // IWYU pragma: export +class OpGroupNonUniformShuffleXor; // IWYU pragma: export +class OpGroupNonUniformShuffleUp; // IWYU pragma: export +class OpGroupNonUniformShuffleDown; // IWYU pragma: export +class OpGroupNonUniformIAdd; // IWYU pragma: export +class OpGroupNonUniformFAdd; // IWYU pragma: export +class OpGroupNonUniformIMul; // IWYU pragma: export +class OpGroupNonUniformFMul; // IWYU pragma: export +class OpGroupNonUniformSMin; // IWYU pragma: export +class OpGroupNonUniformUMin; // IWYU pragma: export +class OpGroupNonUniformFMin; // IWYU pragma: export +class OpGroupNonUniformSMax; // IWYU pragma: export +class OpGroupNonUniformUMax; // IWYU pragma: export +class OpGroupNonUniformFMax; // IWYU pragma: export +class OpGroupNonUniformBitwiseAnd; // IWYU pragma: export +class OpGroupNonUniformBitwiseOr; // IWYU pragma: export +class OpGroupNonUniformBitwiseXor; // IWYU pragma: export +class OpGroupNonUniformLogicalAnd; // IWYU pragma: export +class OpGroupNonUniformLogicalOr; // IWYU pragma: export +class OpGroupNonUniformLogicalXor; // IWYU pragma: export +class OpGroupNonUniformQuadBroadcast; // IWYU pragma: export +class OpGroupNonUniformQuadSwap; // IWYU pragma: export +class OpCopyLogical; // IWYU pragma: export +class OpPtrEqual; // IWYU pragma: export +class OpPtrNotEqual; // IWYU pragma: export +class OpPtrDiff; // IWYU pragma: export +class OpTypeCooperativeMatrixKHR; // IWYU pragma: export +class OpCooperativeMatrixLoadKHR; // IWYU pragma: export +class OpCooperativeMatrixStoreKHR; // IWYU pragma: export +class OpCooperativeMatrixMulAddKHR; // IWYU pragma: export +class OpCooperativeMatrixLengthKHR; // IWYU 
pragma: export +class OpSubgroupBlockReadINTEL; // IWYU pragma: export +class OpSubgroupBlockWriteINTEL; // IWYU pragma: export +class OpAsmTargetINTEL; // IWYU pragma: export +class OpAsmINTEL; // IWYU pragma: export +class OpAsmCallINTEL; // IWYU pragma: export +class OpAtomicFMinEXT; // IWYU pragma: export +class OpAtomicFMaxEXT; // IWYU pragma: export +class OpAtomicFAddEXT; // IWYU pragma: export +class OpConvertFToBF16INTEL; // IWYU pragma: export +class OpConvertBF16ToFINTEL; // IWYU pragma: export +class OpControlBarrierArriveINTEL; // IWYU pragma: export +class OpControlBarrierWaitINTEL; // IWYU pragma: export +class OpCooperativeMatrixLoadCheckedINTEL; // IWYU pragma: export +class OpCooperativeMatrixStoreCheckedINTEL; // IWYU pragma: export + +} // namespace tinytc::spv + +#endif // GENERATED_DEFS_20250630_HPP diff --git a/src/spv/dope_vector.cpp b/src/spv/dope_vector.cpp new file mode 100644 index 00000000..eeaf7217 --- /dev/null +++ b/src/spv/dope_vector.cpp @@ -0,0 +1,41 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/dope_vector.hpp" +#include "spv/defs.hpp" +#include "tinytc/types.hpp" + +#include <utility> +#include <vector> + +namespace tinytc::spv { + +dope_vector::dope_vector(spv_inst *ty, std::vector<std::int64_t> static_shape, + std::vector<std::int64_t> static_stride, spv_inst *size_ty, + std::int64_t static_size, spv_inst *offset_ty, std::int64_t static_offset) + // dim() reads static_shape_, which is declared ahead of shape_ and stride_ and is therefore already initialized here + : ty_(ty), static_shape_(std::move(static_shape)), static_stride_(std::move(static_stride)), + shape_(dim(), nullptr), stride_(dim(), nullptr), size_ty_(size_ty), offset_ty_(offset_ty), + static_size_(static_size), static_offset_(static_offset) { + // Shape and stride must have the same number of entries + if (static_shape_.size() != static_stride_.size()) { + throw status::internal_compiler_error; + } +} + +auto dope_vector::num_dynamic() const -> std::int64_t { + // Counts the entries of vec for which is_dynamic_value() returns true + auto const sum_dynamic = [](std::vector<std::int64_t> const &vec) { + std::int64_t num_dynamic = 0; + for (auto &v : vec) { + if (is_dynamic_value(v)) { + ++num_dynamic; + } + } + 
return num_dynamic; + }; + return sum_dynamic(static_shape_) + sum_dynamic(static_stride_); +} + +} // namespace tinytc::spv + +diff --git a/src/spv/dope_vector.hpp b/src/spv/dope_vector.hpp new file mode 100644 index 00000000..64bbec3f --- /dev/null +++ b/src/spv/dope_vector.hpp @@ -0,0 +1,61 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DOPE_VECTOR_20241213_HPP +#define DOPE_VECTOR_20241213_HPP + +#include <cstdint> +#include <vector> + +namespace tinytc::spv { + +class spv_inst; + +/// Bundles ty with static shape and stride entries plus one spv_inst slot per entry; the slots +/// start out as nullptr and are assigned through shape(i, s) and stride(i, s). Optionally carries +/// a size and an offset, each with a static value, an spv_inst slot, and an accompanying size_ty / offset_ty. +class dope_vector { + public: + dope_vector() = default; + /// Throws status::internal_compiler_error if static_shape and static_stride differ in length. + dope_vector(spv_inst *ty, std::vector<std::int64_t> static_shape, + std::vector<std::int64_t> static_stride, spv_inst *size_ty = nullptr, + std::int64_t static_size = 0, spv_inst *offset_ty = nullptr, + std::int64_t static_offset = 0); + + inline auto dim() const -> std::int64_t { return static_shape_.size(); } + inline auto ty() const -> spv_inst * { return ty_; } + inline auto static_shape(std::int64_t i) const -> std::int64_t { return static_shape_[i]; } + inline auto static_stride(std::int64_t i) const -> std::int64_t { return static_stride_[i]; } + inline auto shape(std::int64_t i) const -> spv_inst * { return shape_[i]; } + inline auto stride(std::int64_t i) const -> spv_inst * { return stride_[i]; } + inline void shape(std::int64_t i, spv_inst *s) { shape_[i] = s; } + inline void stride(std::int64_t i, spv_inst *s) { stride_[i] = s; } + + inline auto size_ty() const -> spv_inst * { return size_ty_; } + inline auto static_size() const -> std::int64_t { return static_size_; } + inline auto size() const -> spv_inst * { return size_; } + inline void size(spv_inst *size) { size_ = size; } + + inline auto offset_ty() const -> spv_inst * { return offset_ty_; } + inline auto static_offset() const -> std::int64_t { return static_offset_; } + inline auto offset() const -> spv_inst * { return offset_; } + inline void offset(spv_inst *offset) { offset_ = offset; } + + /// Number of static_shape and static_stride entries for which is_dynamic_value() returns true. + auto num_dynamic() const -> std::int64_t; + + private: + spv_inst *ty_ = 
nullptr; + std::vector<std::int64_t> static_shape_, static_stride_; + std::vector<spv_inst *> shape_, stride_; + spv_inst *size_ty_ = nullptr, *offset_ty_ = nullptr; + // Zero-initialized so that a default-constructed dope_vector never exposes indeterminate values + std::int64_t static_size_ = 0, static_offset_ = 0; + spv_inst *size_ = nullptr; + spv_inst *offset_ = nullptr; +}; + +} // namespace tinytc::spv + +#endif // DOPE_VECTOR_20241213_HPP diff --git a/src/spv/enums.hpp b/src/spv/enums.hpp new file mode 100644 index 00000000..00e0775f --- /dev/null +++ b/src/spv/enums.hpp @@ -0,0 +1,1546 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_ENUMS_20250630_HPP +#define GENERATED_ENUMS_20250630_HPP + +#include + +namespace tinytc::spv { + +constexpr std::int32_t magic_number = 0x07230203; + +enum class Op { + Nop = 0, + Undef = 1, + SourceContinued = 2, + Source = 3, + SourceExtension = 4, + Name = 5, + MemberName = 6, + String = 7, + Line = 8, + Extension = 10, + ExtInstImport = 11, + ExtInst = 12, + MemoryModel = 14, + EntryPoint = 15, + ExecutionMode = 16, + Capability = 17, + TypeVoid = 19, + TypeBool = 20, + TypeInt = 21, + TypeFloat = 22, + TypeVector = 23, + TypeMatrix = 24, + TypeImage = 25, + TypeSampler = 26, + TypeSampledImage = 27, + TypeArray = 28, + TypeRuntimeArray = 29, + TypeStruct = 30, + TypeOpaque = 31, + TypePointer = 32, + TypeFunction = 33, + TypeEvent = 34, + TypeDeviceEvent = 35, + TypeReserveId = 36, + TypeQueue = 37, + TypePipe = 38, + TypeForwardPointer = 39, + ConstantTrue = 41, + ConstantFalse = 42, + Constant = 43, + ConstantComposite = 44, + ConstantSampler = 45, + ConstantNull = 46, + Function = 54, + FunctionParameter = 55, + FunctionEnd = 56, + FunctionCall = 57, + Variable = 59, + ImageTexelPointer = 60, + Load = 61, + Store = 62, + CopyMemory = 63, + CopyMemorySized = 64, + AccessChain = 65, + InBoundsAccessChain = 66, + PtrAccessChain = 67, + ArrayLength = 68, + GenericPtrMemSemantics = 69, + InBoundsPtrAccessChain = 70, + Decorate = 71, + 
MemberDecorate = 72, + DecorationGroup = 73, + GroupDecorate = 74, + GroupMemberDecorate = 75, + VectorExtractDynamic = 77, + VectorInsertDynamic = 78, + VectorShuffle = 79, + CompositeConstruct = 80, + CompositeExtract = 81, + CompositeInsert = 82, + CopyObject = 83, + Transpose = 84, + SampledImage = 86, + ImageSampleImplicitLod = 87, + ImageSampleExplicitLod = 88, + ImageSampleDrefImplicitLod = 89, + ImageSampleDrefExplicitLod = 90, + ImageSampleProjImplicitLod = 91, + ImageSampleProjExplicitLod = 92, + ImageSampleProjDrefImplicitLod = 93, + ImageSampleProjDrefExplicitLod = 94, + ImageFetch = 95, + ImageGather = 96, + ImageDrefGather = 97, + ImageRead = 98, + ImageWrite = 99, + Image = 100, + ImageQueryFormat = 101, + ImageQueryOrder = 102, + ImageQuerySizeLod = 103, + ImageQuerySize = 104, + ImageQueryLod = 105, + ImageQueryLevels = 106, + ImageQuerySamples = 107, + ConvertFToU = 109, + ConvertFToS = 110, + ConvertSToF = 111, + ConvertUToF = 112, + UConvert = 113, + SConvert = 114, + FConvert = 115, + QuantizeToF16 = 116, + ConvertPtrToU = 117, + SatConvertSToU = 118, + SatConvertUToS = 119, + ConvertUToPtr = 120, + PtrCastToGeneric = 121, + GenericCastToPtr = 122, + GenericCastToPtrExplicit = 123, + Bitcast = 124, + SNegate = 126, + FNegate = 127, + IAdd = 128, + FAdd = 129, + ISub = 130, + FSub = 131, + IMul = 132, + FMul = 133, + UDiv = 134, + SDiv = 135, + FDiv = 136, + UMod = 137, + SRem = 138, + SMod = 139, + FRem = 140, + FMod = 141, + VectorTimesScalar = 142, + MatrixTimesScalar = 143, + VectorTimesMatrix = 144, + MatrixTimesVector = 145, + MatrixTimesMatrix = 146, + OuterProduct = 147, + Dot = 148, + IAddCarry = 149, + ISubBorrow = 150, + UMulExtended = 151, + SMulExtended = 152, + Any = 154, + All = 155, + IsNan = 156, + IsInf = 157, + IsFinite = 158, + IsNormal = 159, + SignBitSet = 160, + LessOrGreater = 161, + Ordered = 162, + Unordered = 163, + LogicalEqual = 164, + LogicalNotEqual = 165, + LogicalOr = 166, + LogicalAnd = 167, + LogicalNot = 168, 
+ Select = 169, + IEqual = 170, + INotEqual = 171, + UGreaterThan = 172, + SGreaterThan = 173, + UGreaterThanEqual = 174, + SGreaterThanEqual = 175, + ULessThan = 176, + SLessThan = 177, + ULessThanEqual = 178, + SLessThanEqual = 179, + FOrdEqual = 180, + FUnordEqual = 181, + FOrdNotEqual = 182, + FUnordNotEqual = 183, + FOrdLessThan = 184, + FUnordLessThan = 185, + FOrdGreaterThan = 186, + FUnordGreaterThan = 187, + FOrdLessThanEqual = 188, + FUnordLessThanEqual = 189, + FOrdGreaterThanEqual = 190, + FUnordGreaterThanEqual = 191, + ShiftRightLogical = 194, + ShiftRightArithmetic = 195, + ShiftLeftLogical = 196, + BitwiseOr = 197, + BitwiseXor = 198, + BitwiseAnd = 199, + Not = 200, + BitFieldInsert = 201, + BitFieldSExtract = 202, + BitFieldUExtract = 203, + BitReverse = 204, + BitCount = 205, + DPdx = 207, + DPdy = 208, + Fwidth = 209, + DPdxFine = 210, + DPdyFine = 211, + FwidthFine = 212, + DPdxCoarse = 213, + DPdyCoarse = 214, + FwidthCoarse = 215, + EmitVertex = 218, + EndPrimitive = 219, + EmitStreamVertex = 220, + EndStreamPrimitive = 221, + ControlBarrier = 224, + MemoryBarrier = 225, + AtomicLoad = 227, + AtomicStore = 228, + AtomicExchange = 229, + AtomicCompareExchange = 230, + AtomicCompareExchangeWeak = 231, + AtomicIIncrement = 232, + AtomicIDecrement = 233, + AtomicIAdd = 234, + AtomicISub = 235, + AtomicSMin = 236, + AtomicUMin = 237, + AtomicSMax = 238, + AtomicUMax = 239, + AtomicAnd = 240, + AtomicOr = 241, + AtomicXor = 242, + Phi = 245, + LoopMerge = 246, + SelectionMerge = 247, + Label = 248, + Branch = 249, + BranchConditional = 250, + Switch = 251, + Kill = 252, + Return = 253, + ReturnValue = 254, + Unreachable = 255, + LifetimeStart = 256, + LifetimeStop = 257, + GroupAsyncCopy = 259, + GroupWaitEvents = 260, + GroupAll = 261, + GroupAny = 262, + GroupBroadcast = 263, + GroupIAdd = 264, + GroupFAdd = 265, + GroupFMin = 266, + GroupUMin = 267, + GroupSMin = 268, + GroupFMax = 269, + GroupUMax = 270, + GroupSMax = 271, + ReadPipe = 274, + 
WritePipe = 275, + ReservedReadPipe = 276, + ReservedWritePipe = 277, + ReserveReadPipePackets = 278, + ReserveWritePipePackets = 279, + CommitReadPipe = 280, + CommitWritePipe = 281, + IsValidReserveId = 282, + GetNumPipePackets = 283, + GetMaxPipePackets = 284, + GroupReserveReadPipePackets = 285, + GroupReserveWritePipePackets = 286, + GroupCommitReadPipe = 287, + GroupCommitWritePipe = 288, + EnqueueMarker = 291, + EnqueueKernel = 292, + GetKernelNDrangeSubGroupCount = 293, + GetKernelNDrangeMaxSubGroupSize = 294, + GetKernelWorkGroupSize = 295, + GetKernelPreferredWorkGroupSizeMultiple = 296, + RetainEvent = 297, + ReleaseEvent = 298, + CreateUserEvent = 299, + IsValidEvent = 300, + SetUserEventStatus = 301, + CaptureEventProfilingInfo = 302, + GetDefaultQueue = 303, + BuildNDRange = 304, + ImageSparseSampleImplicitLod = 305, + ImageSparseSampleExplicitLod = 306, + ImageSparseSampleDrefImplicitLod = 307, + ImageSparseSampleDrefExplicitLod = 308, + ImageSparseSampleProjImplicitLod = 309, + ImageSparseSampleProjExplicitLod = 310, + ImageSparseSampleProjDrefImplicitLod = 311, + ImageSparseSampleProjDrefExplicitLod = 312, + ImageSparseFetch = 313, + ImageSparseGather = 314, + ImageSparseDrefGather = 315, + ImageSparseTexelsResident = 316, + NoLine = 317, + AtomicFlagTestAndSet = 318, + AtomicFlagClear = 319, + ImageSparseRead = 320, + SizeOf = 321, + TypePipeStorage = 322, + ConstantPipeStorage = 323, + CreatePipeFromPipeStorage = 324, + GetKernelLocalSizeForSubgroupCount = 325, + GetKernelMaxNumSubgroups = 326, + TypeNamedBarrier = 327, + NamedBarrierInitialize = 328, + MemoryNamedBarrier = 329, + ModuleProcessed = 330, + ExecutionModeId = 331, + DecorateId = 332, + GroupNonUniformElect = 333, + GroupNonUniformAll = 334, + GroupNonUniformAny = 335, + GroupNonUniformAllEqual = 336, + GroupNonUniformBroadcast = 337, + GroupNonUniformBroadcastFirst = 338, + GroupNonUniformBallot = 339, + GroupNonUniformInverseBallot = 340, + GroupNonUniformBallotBitExtract = 341, + 
GroupNonUniformBallotBitCount = 342, + GroupNonUniformBallotFindLSB = 343, + GroupNonUniformBallotFindMSB = 344, + GroupNonUniformShuffle = 345, + GroupNonUniformShuffleXor = 346, + GroupNonUniformShuffleUp = 347, + GroupNonUniformShuffleDown = 348, + GroupNonUniformIAdd = 349, + GroupNonUniformFAdd = 350, + GroupNonUniformIMul = 351, + GroupNonUniformFMul = 352, + GroupNonUniformSMin = 353, + GroupNonUniformUMin = 354, + GroupNonUniformFMin = 355, + GroupNonUniformSMax = 356, + GroupNonUniformUMax = 357, + GroupNonUniformFMax = 358, + GroupNonUniformBitwiseAnd = 359, + GroupNonUniformBitwiseOr = 360, + GroupNonUniformBitwiseXor = 361, + GroupNonUniformLogicalAnd = 362, + GroupNonUniformLogicalOr = 363, + GroupNonUniformLogicalXor = 364, + GroupNonUniformQuadBroadcast = 365, + GroupNonUniformQuadSwap = 366, + CopyLogical = 400, + PtrEqual = 401, + PtrNotEqual = 402, + PtrDiff = 403, + TypeCooperativeMatrixKHR = 4456, + CooperativeMatrixLoadKHR = 4457, + CooperativeMatrixStoreKHR = 4458, + CooperativeMatrixMulAddKHR = 4459, + CooperativeMatrixLengthKHR = 4460, + SubgroupBlockReadINTEL = 5575, + SubgroupBlockWriteINTEL = 5576, + AsmTargetINTEL = 5609, + AsmINTEL = 5610, + AsmCallINTEL = 5611, + AtomicFMinEXT = 5614, + AtomicFMaxEXT = 5615, + AtomicFAddEXT = 6035, + ConvertFToBF16INTEL = 6116, + ConvertBF16ToFINTEL = 6117, + ControlBarrierArriveINTEL = 6142, + ControlBarrierWaitINTEL = 6143, + CooperativeMatrixLoadCheckedINTEL = 6193, + CooperativeMatrixStoreCheckedINTEL = 6194, +}; +enum class ImageOperands { + None = 0x0000, + Bias = 0x0001, + Lod = 0x0002, + Grad = 0x0004, + ConstOffset = 0x0008, + Offset = 0x0010, + ConstOffsets = 0x0020, + Sample = 0x0040, + MinLod = 0x0080, + MakeTexelAvailable = 0x0100, + MakeTexelVisible = 0x0200, + NonPrivateTexel = 0x0400, + VolatileTexel = 0x0800, + SignExtend = 0x1000, + ZeroExtend = 0x2000, + Nontemporal = 0x4000, + Offsets = 0x10000, +}; +enum class FPFastMathMode { + None = 0x0000, + NotNaN = 0x0001, + NotInf = 0x0002, 
+ NSZ = 0x0004, + AllowRecip = 0x0008, + Fast = 0x0010, + AllowContract = 0x10000, + AllowReassoc = 0x20000, + AllowTransform = 0x40000, +}; +enum class SelectionControl { + None = 0x0000, + Flatten = 0x0001, + DontFlatten = 0x0002, +}; +enum class LoopControl { + None = 0x0000, + Unroll = 0x0001, + DontUnroll = 0x0002, + DependencyInfinite = 0x0004, + DependencyLength = 0x0008, + MinIterations = 0x0010, + MaxIterations = 0x0020, + IterationMultiple = 0x0040, + PeelCount = 0x0080, + PartialCount = 0x0100, + InitiationIntervalINTEL = 0x10000, + MaxConcurrencyINTEL = 0x20000, + DependencyArrayINTEL = 0x40000, + PipelineEnableINTEL = 0x80000, + LoopCoalesceINTEL = 0x100000, + MaxInterleavingINTEL = 0x200000, + SpeculatedIterationsINTEL = 0x400000, + NoFusionINTEL = 0x800000, + LoopCountINTEL = 0x1000000, + MaxReinvocationDelayINTEL = 0x2000000, +}; +enum class FunctionControl { + None = 0x0000, + Inline = 0x0001, + DontInline = 0x0002, + Pure = 0x0004, + Const = 0x0008, + OptNoneEXT = 0x10000, +}; +enum class MemorySemantics { + Relaxed = 0x0000, + Acquire = 0x0002, + Release = 0x0004, + AcquireRelease = 0x0008, + SequentiallyConsistent = 0x0010, + UniformMemory = 0x0040, + SubgroupMemory = 0x0080, + WorkgroupMemory = 0x0100, + CrossWorkgroupMemory = 0x0200, + AtomicCounterMemory = 0x0400, + ImageMemory = 0x0800, + OutputMemory = 0x1000, + MakeAvailable = 0x2000, + MakeVisible = 0x4000, + Volatile = 0x8000, +}; +enum class MemoryAccess { + None = 0x0000, + Volatile = 0x0001, + Aligned = 0x0002, + Nontemporal = 0x0004, + MakePointerAvailable = 0x0008, + MakePointerVisible = 0x0010, + NonPrivatePointer = 0x0020, + AliasScopeINTELMask = 0x10000, + NoAliasINTELMask = 0x20000, +}; +enum class KernelProfilingInfo { + None = 0x0000, + CmdExecTime = 0x0001, +}; +enum class RayFlags { + NoneKHR = 0x0000, + OpaqueKHR = 0x0001, + NoOpaqueKHR = 0x0002, + TerminateOnFirstHitKHR = 0x0004, + SkipClosestHitShaderKHR = 0x0008, + CullBackFacingTrianglesKHR = 0x0010, + 
CullFrontFacingTrianglesKHR = 0x0020, + CullOpaqueKHR = 0x0040, + CullNoOpaqueKHR = 0x0080, + SkipTrianglesKHR = 0x0100, + SkipAABBsKHR = 0x0200, + ForceOpacityMicromap2StateEXT = 0x0400, +}; +enum class FragmentShadingRate { + Vertical2Pixels = 0x0001, + Vertical4Pixels = 0x0002, + Horizontal2Pixels = 0x0004, + Horizontal4Pixels = 0x0008, +}; +enum class RawAccessChainOperands { + None = 0x0000, + RobustnessPerComponentNV = 0x0001, + RobustnessPerElementNV = 0x0002, +}; +enum class SourceLanguage { + Unknown = 0, + ESSL = 1, + GLSL = 2, + OpenCL_C = 3, + OpenCL_CPP = 4, + HLSL = 5, + CPP_for_OpenCL = 6, + SYCL = 7, + HERO_C = 8, + NZSL = 9, + WGSL = 10, + Slang = 11, + Zig = 12, + Rust = 13, +}; +enum class ExecutionModel { + Vertex = 0, + TessellationControl = 1, + TessellationEvaluation = 2, + Geometry = 3, + Fragment = 4, + GLCompute = 5, + Kernel = 6, + TaskNV = 5267, + MeshNV = 5268, + RayGenerationKHR = 5313, + IntersectionKHR = 5314, + AnyHitKHR = 5315, + ClosestHitKHR = 5316, + MissKHR = 5317, + CallableKHR = 5318, + TaskEXT = 5364, + MeshEXT = 5365, +}; +enum class AddressingModel { + Logical = 0, + Physical32 = 1, + Physical64 = 2, + PhysicalStorageBuffer64 = 5348, +}; +enum class MemoryModel { + Simple = 0, + GLSL450 = 1, + OpenCL = 2, + Vulkan = 3, +}; +enum class ExecutionMode { + Invocations = 0, + SpacingEqual = 1, + SpacingFractionalEven = 2, + SpacingFractionalOdd = 3, + VertexOrderCw = 4, + VertexOrderCcw = 5, + PixelCenterInteger = 6, + OriginUpperLeft = 7, + OriginLowerLeft = 8, + EarlyFragmentTests = 9, + PointMode = 10, + Xfb = 11, + DepthReplacing = 12, + DepthGreater = 14, + DepthLess = 15, + DepthUnchanged = 16, + LocalSize = 17, + LocalSizeHint = 18, + InputPoints = 19, + InputLines = 20, + InputLinesAdjacency = 21, + Triangles = 22, + InputTrianglesAdjacency = 23, + Quads = 24, + Isolines = 25, + OutputVertices = 26, + OutputPoints = 27, + OutputLineStrip = 28, + OutputTriangleStrip = 29, + VecTypeHint = 30, + ContractionOff = 31, + 
Initializer = 33, + Finalizer = 34, + SubgroupSize = 35, + SubgroupsPerWorkgroup = 36, + SubgroupsPerWorkgroupId = 37, + LocalSizeId = 38, + LocalSizeHintId = 39, + NonCoherentColorAttachmentReadEXT = 4169, + NonCoherentDepthAttachmentReadEXT = 4170, + NonCoherentStencilAttachmentReadEXT = 4171, + SubgroupUniformControlFlowKHR = 4421, + PostDepthCoverage = 4446, + DenormPreserve = 4459, + DenormFlushToZero = 4460, + SignedZeroInfNanPreserve = 4461, + RoundingModeRTE = 4462, + RoundingModeRTZ = 4463, + NonCoherentTileAttachmentReadQCOM = 4489, + TileShadingRateQCOM = 4490, + EarlyAndLateFragmentTestsAMD = 5017, + StencilRefReplacingEXT = 5027, + CoalescingAMDX = 5069, + IsApiEntryAMDX = 5070, + MaxNodeRecursionAMDX = 5071, + StaticNumWorkgroupsAMDX = 5072, + ShaderIndexAMDX = 5073, + MaxNumWorkgroupsAMDX = 5077, + StencilRefUnchangedFrontAMD = 5079, + StencilRefGreaterFrontAMD = 5080, + StencilRefLessFrontAMD = 5081, + StencilRefUnchangedBackAMD = 5082, + StencilRefGreaterBackAMD = 5083, + StencilRefLessBackAMD = 5084, + QuadDerivativesKHR = 5088, + RequireFullQuadsKHR = 5089, + SharesInputWithAMDX = 5102, + OutputLinesEXT = 5269, + OutputPrimitivesEXT = 5270, + DerivativeGroupQuadsKHR = 5289, + DerivativeGroupLinearKHR = 5290, + OutputTrianglesEXT = 5298, + PixelInterlockOrderedEXT = 5366, + PixelInterlockUnorderedEXT = 5367, + SampleInterlockOrderedEXT = 5368, + SampleInterlockUnorderedEXT = 5369, + ShadingRateInterlockOrderedEXT = 5370, + ShadingRateInterlockUnorderedEXT = 5371, + SharedLocalMemorySizeINTEL = 5618, + RoundingModeRTPINTEL = 5620, + RoundingModeRTNINTEL = 5621, + FloatingPointModeALTINTEL = 5622, + FloatingPointModeIEEEINTEL = 5623, + MaxWorkgroupSizeINTEL = 5893, + MaxWorkDimINTEL = 5894, + NoGlobalOffsetINTEL = 5895, + NumSIMDWorkitemsINTEL = 5896, + SchedulerTargetFmaxMhzINTEL = 5903, + MaximallyReconvergesKHR = 6023, + FPFastMathDefault = 6028, + StreamingInterfaceINTEL = 6154, + RegisterMapInterfaceINTEL = 6160, + NamedBarrierCountINTEL = 
6417, + MaximumRegistersINTEL = 6461, + MaximumRegistersIdINTEL = 6462, + NamedMaximumRegistersINTEL = 6463, +}; +enum class StorageClass { + UniformConstant = 0, + Input = 1, + Uniform = 2, + Output = 3, + Workgroup = 4, + CrossWorkgroup = 5, + Private = 6, + Function = 7, + Generic = 8, + PushConstant = 9, + AtomicCounter = 10, + Image = 11, + StorageBuffer = 12, + TileImageEXT = 4172, + TileAttachmentQCOM = 4491, + NodePayloadAMDX = 5068, + CallableDataKHR = 5328, + IncomingCallableDataKHR = 5329, + RayPayloadKHR = 5338, + HitAttributeKHR = 5339, + IncomingRayPayloadKHR = 5342, + ShaderRecordBufferKHR = 5343, + PhysicalStorageBuffer = 5349, + HitObjectAttributeNV = 5385, + TaskPayloadWorkgroupEXT = 5402, + CodeSectionINTEL = 5605, + DeviceOnlyINTEL = 5936, + HostOnlyINTEL = 5937, +}; +enum class Dim { + Dim1D = 0, + Dim2D = 1, + Dim3D = 2, + Cube = 3, + Rect = 4, + Buffer = 5, + SubpassData = 6, + TileImageDataEXT = 4173, +}; +enum class SamplerAddressingMode { + None = 0, + ClampToEdge = 1, + Clamp = 2, + Repeat = 3, + RepeatMirrored = 4, +}; +enum class SamplerFilterMode { + Nearest = 0, + Linear = 1, +}; +enum class ImageFormat { + Unknown = 0, + Rgba32f = 1, + Rgba16f = 2, + R32f = 3, + Rgba8 = 4, + Rgba8Snorm = 5, + Rg32f = 6, + Rg16f = 7, + R11fG11fB10f = 8, + R16f = 9, + Rgba16 = 10, + Rgb10A2 = 11, + Rg16 = 12, + Rg8 = 13, + R16 = 14, + R8 = 15, + Rgba16Snorm = 16, + Rg16Snorm = 17, + Rg8Snorm = 18, + R16Snorm = 19, + R8Snorm = 20, + Rgba32i = 21, + Rgba16i = 22, + Rgba8i = 23, + R32i = 24, + Rg32i = 25, + Rg16i = 26, + Rg8i = 27, + R16i = 28, + R8i = 29, + Rgba32ui = 30, + Rgba16ui = 31, + Rgba8ui = 32, + R32ui = 33, + Rgb10a2ui = 34, + Rg32ui = 35, + Rg16ui = 36, + Rg8ui = 37, + R16ui = 38, + R8ui = 39, + R64ui = 40, + R64i = 41, +}; +enum class ImageChannelOrder { + R = 0, + A = 1, + RG = 2, + RA = 3, + RGB = 4, + RGBA = 5, + BGRA = 6, + ARGB = 7, + Intensity = 8, + Luminance = 9, + Rx = 10, + RGx = 11, + RGBx = 12, + Depth = 13, + DepthStencil = 14, 
+ sRGB = 15, + sRGBx = 16, + sRGBA = 17, + sBGRA = 18, + ABGR = 19, +}; +enum class ImageChannelDataType { + SnormInt8 = 0, + SnormInt16 = 1, + UnormInt8 = 2, + UnormInt16 = 3, + UnormShort565 = 4, + UnormShort555 = 5, + UnormInt101010 = 6, + SignedInt8 = 7, + SignedInt16 = 8, + SignedInt32 = 9, + UnsignedInt8 = 10, + UnsignedInt16 = 11, + UnsignedInt32 = 12, + HalfFloat = 13, + Float = 14, + UnormInt24 = 15, + UnormInt101010_2 = 16, + UnormInt10X6EXT = 17, + UnsignedIntRaw10EXT = 19, + UnsignedIntRaw12EXT = 20, + UnormInt2_101010EXT = 21, + UnsignedInt10X6EXT = 22, + UnsignedInt12X4EXT = 23, + UnsignedInt14X2EXT = 24, + UnormInt12X4EXT = 25, + UnormInt14X2EXT = 26, +}; +enum class FPRoundingMode { + RTE = 0, + RTZ = 1, + RTP = 2, + RTN = 3, +}; +enum class FPDenormMode { + Preserve = 0, + FlushToZero = 1, +}; +enum class QuantizationModes { + TRN = 0, + TRN_ZERO = 1, + RND = 2, + RND_ZERO = 3, + RND_INF = 4, + RND_MIN_INF = 5, + RND_CONV = 6, + RND_CONV_ODD = 7, +}; +enum class FPOperationMode { + IEEE = 0, + ALT = 1, +}; +enum class OverflowModes { + WRAP = 0, + SAT = 1, + SAT_ZERO = 2, + SAT_SYM = 3, +}; +enum class LinkageType { + Export = 0, + Import = 1, + LinkOnceODR = 2, +}; +enum class AccessQualifier { + ReadOnly = 0, + WriteOnly = 1, + ReadWrite = 2, +}; +enum class HostAccessQualifier { + NoneINTEL = 0, + ReadINTEL = 1, + WriteINTEL = 2, + ReadWriteINTEL = 3, +}; +enum class FunctionParameterAttribute { + Zext = 0, + Sext = 1, + ByVal = 2, + Sret = 3, + NoAlias = 4, + NoCapture = 5, + NoWrite = 6, + NoReadWrite = 7, + RuntimeAlignedINTEL = 5940, +}; +enum class Decoration { + RelaxedPrecision = 0, + SpecId = 1, + Block = 2, + BufferBlock = 3, + RowMajor = 4, + ColMajor = 5, + ArrayStride = 6, + MatrixStride = 7, + GLSLShared = 8, + GLSLPacked = 9, + CPacked = 10, + BuiltIn = 11, + NoPerspective = 13, + Flat = 14, + Patch = 15, + Centroid = 16, + Sample = 17, + Invariant = 18, + Restrict = 19, + Aliased = 20, + Volatile = 21, + Constant = 22, + Coherent 
= 23, + NonWritable = 24, + NonReadable = 25, + Uniform = 26, + UniformId = 27, + SaturatedConversion = 28, + Stream = 29, + Location = 30, + Component = 31, + Index = 32, + Binding = 33, + DescriptorSet = 34, + Offset = 35, + XfbBuffer = 36, + XfbStride = 37, + FuncParamAttr = 38, + FPRoundingMode = 39, + FPFastMathMode = 40, + LinkageAttributes = 41, + NoContraction = 42, + InputAttachmentIndex = 43, + Alignment = 44, + MaxByteOffset = 45, + AlignmentId = 46, + MaxByteOffsetId = 47, + SaturatedToLargestFloat8NormalConversionEXT = 4216, + NoSignedWrap = 4469, + NoUnsignedWrap = 4470, + WeightTextureQCOM = 4487, + BlockMatchTextureQCOM = 4488, + BlockMatchSamplerQCOM = 4499, + ExplicitInterpAMD = 4999, + NodeSharesPayloadLimitsWithAMDX = 5019, + NodeMaxPayloadsAMDX = 5020, + TrackFinishWritingAMDX = 5078, + PayloadNodeNameAMDX = 5091, + PayloadNodeBaseIndexAMDX = 5098, + PayloadNodeSparseArrayAMDX = 5099, + PayloadNodeArraySizeAMDX = 5100, + PayloadDispatchIndirectAMDX = 5105, + OverrideCoverageNV = 5248, + PassthroughNV = 5250, + ViewportRelativeNV = 5252, + SecondaryViewportRelativeNV = 5256, + PerPrimitiveEXT = 5271, + PerViewNV = 5272, + PerTaskNV = 5273, + PerVertexKHR = 5285, + NonUniform = 5300, + RestrictPointer = 5355, + AliasedPointer = 5356, + HitObjectShaderRecordBufferNV = 5386, + BindlessSamplerNV = 5398, + BindlessImageNV = 5399, + BoundSamplerNV = 5400, + BoundImageNV = 5401, + SIMTCallINTEL = 5599, + ReferencedIndirectlyINTEL = 5602, + ClobberINTEL = 5607, + SideEffectsINTEL = 5608, + VectorComputeVariableINTEL = 5624, + FuncParamIOKindINTEL = 5625, + VectorComputeFunctionINTEL = 5626, + StackCallINTEL = 5627, + GlobalVariableOffsetINTEL = 5628, + CounterBuffer = 5634, + UserSemantic = 5635, + UserTypeGOOGLE = 5636, + FunctionRoundingModeINTEL = 5822, + FunctionDenormModeINTEL = 5823, + RegisterINTEL = 5825, + MemoryINTEL = 5826, + NumbanksINTEL = 5827, + BankwidthINTEL = 5828, + MaxPrivateCopiesINTEL = 5829, + SinglepumpINTEL = 5830, + 
DoublepumpINTEL = 5831, + MaxReplicatesINTEL = 5832, + SimpleDualPortINTEL = 5833, + MergeINTEL = 5834, + BankBitsINTEL = 5835, + ForcePow2DepthINTEL = 5836, + StridesizeINTEL = 5883, + WordsizeINTEL = 5884, + TrueDualPortINTEL = 5885, + BurstCoalesceINTEL = 5899, + CacheSizeINTEL = 5900, + DontStaticallyCoalesceINTEL = 5901, + PrefetchINTEL = 5902, + StallEnableINTEL = 5905, + FuseLoopsInFunctionINTEL = 5907, + MathOpDSPModeINTEL = 5909, + AliasScopeINTEL = 5914, + NoAliasINTEL = 5915, + InitiationIntervalINTEL = 5917, + MaxConcurrencyINTEL = 5918, + PipelineEnableINTEL = 5919, + BufferLocationINTEL = 5921, + IOPipeStorageINTEL = 5944, + FunctionFloatingPointModeINTEL = 6080, + SingleElementVectorINTEL = 6085, + VectorComputeCallableFunctionINTEL = 6087, + MediaBlockIOINTEL = 6140, + StallFreeINTEL = 6151, + FPMaxErrorDecorationINTEL = 6170, + LatencyControlLabelINTEL = 6172, + LatencyControlConstraintINTEL = 6173, + ConduitKernelArgumentINTEL = 6175, + RegisterMapKernelArgumentINTEL = 6176, + MMHostInterfaceAddressWidthINTEL = 6177, + MMHostInterfaceDataWidthINTEL = 6178, + MMHostInterfaceLatencyINTEL = 6179, + MMHostInterfaceReadWriteModeINTEL = 6180, + MMHostInterfaceMaxBurstINTEL = 6181, + MMHostInterfaceWaitRequestINTEL = 6182, + StableKernelArgumentINTEL = 6183, + HostAccessINTEL = 6188, + InitModeINTEL = 6190, + ImplementInRegisterMapINTEL = 6191, + CacheControlLoadINTEL = 6442, + CacheControlStoreINTEL = 6443, +}; +enum class BuiltIn { + Position = 0, + PointSize = 1, + ClipDistance = 3, + CullDistance = 4, + VertexId = 5, + InstanceId = 6, + PrimitiveId = 7, + InvocationId = 8, + Layer = 9, + ViewportIndex = 10, + TessLevelOuter = 11, + TessLevelInner = 12, + TessCoord = 13, + PatchVertices = 14, + FragCoord = 15, + PointCoord = 16, + FrontFacing = 17, + SampleId = 18, + SamplePosition = 19, + SampleMask = 20, + FragDepth = 22, + HelperInvocation = 23, + NumWorkgroups = 24, + WorkgroupSize = 25, + WorkgroupId = 26, + LocalInvocationId = 27, + 
GlobalInvocationId = 28, + LocalInvocationIndex = 29, + WorkDim = 30, + GlobalSize = 31, + EnqueuedWorkgroupSize = 32, + GlobalOffset = 33, + GlobalLinearId = 34, + SubgroupSize = 36, + SubgroupMaxSize = 37, + NumSubgroups = 38, + NumEnqueuedSubgroups = 39, + SubgroupId = 40, + SubgroupLocalInvocationId = 41, + VertexIndex = 42, + InstanceIndex = 43, + CoreIDARM = 4160, + CoreCountARM = 4161, + CoreMaxIDARM = 4162, + WarpIDARM = 4163, + WarpMaxIDARM = 4164, + SubgroupEqMask = 4416, + SubgroupGeMask = 4417, + SubgroupGtMask = 4418, + SubgroupLeMask = 4419, + SubgroupLtMask = 4420, + BaseVertex = 4424, + BaseInstance = 4425, + DrawIndex = 4426, + PrimitiveShadingRateKHR = 4432, + DeviceIndex = 4438, + ViewIndex = 4440, + ShadingRateKHR = 4444, + TileOffsetQCOM = 4492, + TileDimensionQCOM = 4493, + TileApronSizeQCOM = 4494, + BaryCoordNoPerspAMD = 4992, + BaryCoordNoPerspCentroidAMD = 4993, + BaryCoordNoPerspSampleAMD = 4994, + BaryCoordSmoothAMD = 4995, + BaryCoordSmoothCentroidAMD = 4996, + BaryCoordSmoothSampleAMD = 4997, + BaryCoordPullModelAMD = 4998, + FragStencilRefEXT = 5014, + RemainingRecursionLevelsAMDX = 5021, + ShaderIndexAMDX = 5073, + ViewportMaskNV = 5253, + SecondaryPositionNV = 5257, + SecondaryViewportMaskNV = 5258, + PositionPerViewNV = 5261, + ViewportMaskPerViewNV = 5262, + FullyCoveredEXT = 5264, + TaskCountNV = 5274, + PrimitiveCountNV = 5275, + PrimitiveIndicesNV = 5276, + ClipDistancePerViewNV = 5277, + CullDistancePerViewNV = 5278, + LayerPerViewNV = 5279, + MeshViewCountNV = 5280, + MeshViewIndicesNV = 5281, + BaryCoordKHR = 5286, + BaryCoordNoPerspKHR = 5287, + FragSizeEXT = 5292, + FragInvocationCountEXT = 5293, + PrimitivePointIndicesEXT = 5294, + PrimitiveLineIndicesEXT = 5295, + PrimitiveTriangleIndicesEXT = 5296, + CullPrimitiveEXT = 5299, + LaunchIdKHR = 5319, + LaunchSizeKHR = 5320, + WorldRayOriginKHR = 5321, + WorldRayDirectionKHR = 5322, + ObjectRayOriginKHR = 5323, + ObjectRayDirectionKHR = 5324, + RayTminKHR = 5325, + 
RayTmaxKHR = 5326, + InstanceCustomIndexKHR = 5327, + ObjectToWorldKHR = 5330, + WorldToObjectKHR = 5331, + HitTNV = 5332, + HitKindKHR = 5333, + CurrentRayTimeNV = 5334, + HitTriangleVertexPositionsKHR = 5335, + HitMicroTriangleVertexPositionsNV = 5337, + HitMicroTriangleVertexBarycentricsNV = 5344, + IncomingRayFlagsKHR = 5351, + RayGeometryIndexKHR = 5352, + HitIsSphereNV = 5359, + HitIsLSSNV = 5360, + HitSpherePositionNV = 5361, + WarpsPerSMNV = 5374, + SMCountNV = 5375, + WarpIDNV = 5376, + SMIDNV = 5377, + HitLSSPositionsNV = 5396, + HitKindFrontFacingMicroTriangleNV = 5405, + HitKindBackFacingMicroTriangleNV = 5406, + HitSphereRadiusNV = 5420, + HitLSSRadiiNV = 5421, + ClusterIDNV = 5436, + CullMaskKHR = 6021, +}; +enum class Scope { + CrossDevice = 0, + Device = 1, + Workgroup = 2, + Subgroup = 3, + Invocation = 4, + QueueFamily = 5, + ShaderCallKHR = 6, +}; +enum class GroupOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, + ClusteredReduce = 3, + PartitionedReduceNV = 6, + PartitionedInclusiveScanNV = 7, + PartitionedExclusiveScanNV = 8, +}; +enum class KernelEnqueueFlags { + NoWait = 0, + WaitKernel = 1, + WaitWorkGroup = 2, +}; +enum class Capability { + Matrix = 0, + Shader = 1, + Geometry = 2, + Tessellation = 3, + Addresses = 4, + Linkage = 5, + Kernel = 6, + Vector16 = 7, + Float16Buffer = 8, + Float16 = 9, + Float64 = 10, + Int64 = 11, + Int64Atomics = 12, + ImageBasic = 13, + ImageReadWrite = 14, + ImageMipmap = 15, + Pipes = 17, + Groups = 18, + DeviceEnqueue = 19, + LiteralSampler = 20, + AtomicStorage = 21, + Int16 = 22, + TessellationPointSize = 23, + GeometryPointSize = 24, + ImageGatherExtended = 25, + StorageImageMultisample = 27, + UniformBufferArrayDynamicIndexing = 28, + SampledImageArrayDynamicIndexing = 29, + StorageBufferArrayDynamicIndexing = 30, + StorageImageArrayDynamicIndexing = 31, + ClipDistance = 32, + CullDistance = 33, + ImageCubeArray = 34, + SampleRateShading = 35, + ImageRect = 36, + SampledRect = 37, + 
GenericPointer = 38, + Int8 = 39, + InputAttachment = 40, + SparseResidency = 41, + MinLod = 42, + Sampled1D = 43, + Image1D = 44, + SampledCubeArray = 45, + SampledBuffer = 46, + ImageBuffer = 47, + ImageMSArray = 48, + StorageImageExtendedFormats = 49, + ImageQuery = 50, + DerivativeControl = 51, + InterpolationFunction = 52, + TransformFeedback = 53, + GeometryStreams = 54, + StorageImageReadWithoutFormat = 55, + StorageImageWriteWithoutFormat = 56, + MultiViewport = 57, + SubgroupDispatch = 58, + NamedBarrier = 59, + PipeStorage = 60, + GroupNonUniform = 61, + GroupNonUniformVote = 62, + GroupNonUniformArithmetic = 63, + GroupNonUniformBallot = 64, + GroupNonUniformShuffle = 65, + GroupNonUniformShuffleRelative = 66, + GroupNonUniformClustered = 67, + GroupNonUniformQuad = 68, + ShaderLayer = 69, + ShaderViewportIndex = 70, + UniformDecoration = 71, + CoreBuiltinsARM = 4165, + TileImageColorReadAccessEXT = 4166, + TileImageDepthReadAccessEXT = 4167, + TileImageStencilReadAccessEXT = 4168, + TensorsARM = 4174, + StorageTensorArrayDynamicIndexingARM = 4175, + StorageTensorArrayNonUniformIndexingARM = 4176, + CooperativeMatrixLayoutsARM = 4201, + Float8EXT = 4212, + Float8CooperativeMatrixEXT = 4213, + FragmentShadingRateKHR = 4422, + SubgroupBallotKHR = 4423, + DrawParameters = 4427, + WorkgroupMemoryExplicitLayoutKHR = 4428, + WorkgroupMemoryExplicitLayout8BitAccessKHR = 4429, + WorkgroupMemoryExplicitLayout16BitAccessKHR = 4430, + SubgroupVoteKHR = 4431, + StorageBuffer16BitAccess = 4433, + UniformAndStorageBuffer16BitAccess = 4434, + StoragePushConstant16 = 4435, + StorageInputOutput16 = 4436, + DeviceGroup = 4437, + MultiView = 4439, + VariablePointersStorageBuffer = 4441, + VariablePointers = 4442, + AtomicStorageOps = 4445, + SampleMaskPostDepthCoverage = 4447, + StorageBuffer8BitAccess = 4448, + UniformAndStorageBuffer8BitAccess = 4449, + StoragePushConstant8 = 4450, + DenormPreserve = 4464, + DenormFlushToZero = 4465, + SignedZeroInfNanPreserve = 4466, + 
RoundingModeRTE = 4467, + RoundingModeRTZ = 4468, + RayQueryProvisionalKHR = 4471, + RayQueryKHR = 4472, + UntypedPointersKHR = 4473, + RayTraversalPrimitiveCullingKHR = 4478, + RayTracingKHR = 4479, + TextureSampleWeightedQCOM = 4484, + TextureBoxFilterQCOM = 4485, + TextureBlockMatchQCOM = 4486, + TileShadingQCOM = 4495, + TextureBlockMatch2QCOM = 4498, + Float16ImageAMD = 5008, + ImageGatherBiasLodAMD = 5009, + FragmentMaskAMD = 5010, + StencilExportEXT = 5013, + ImageReadWriteLodAMD = 5015, + Int64ImageEXT = 5016, + ShaderClockKHR = 5055, + ShaderEnqueueAMDX = 5067, + QuadControlKHR = 5087, + Int4TypeINTEL = 5112, + Int4CooperativeMatrixINTEL = 5114, + BFloat16TypeKHR = 5116, + BFloat16DotProductKHR = 5117, + BFloat16CooperativeMatrixKHR = 5118, + SampleMaskOverrideCoverageNV = 5249, + GeometryShaderPassthroughNV = 5251, + ShaderViewportIndexLayerEXT = 5254, + ShaderViewportMaskNV = 5255, + ShaderStereoViewNV = 5259, + PerViewAttributesNV = 5260, + FragmentFullyCoveredEXT = 5265, + MeshShadingNV = 5266, + ImageFootprintNV = 5282, + MeshShadingEXT = 5283, + FragmentBarycentricKHR = 5284, + ComputeDerivativeGroupQuadsKHR = 5288, + FragmentDensityEXT = 5291, + GroupNonUniformPartitionedNV = 5297, + ShaderNonUniform = 5301, + RuntimeDescriptorArray = 5302, + InputAttachmentArrayDynamicIndexing = 5303, + UniformTexelBufferArrayDynamicIndexing = 5304, + StorageTexelBufferArrayDynamicIndexing = 5305, + UniformBufferArrayNonUniformIndexing = 5306, + SampledImageArrayNonUniformIndexing = 5307, + StorageBufferArrayNonUniformIndexing = 5308, + StorageImageArrayNonUniformIndexing = 5309, + InputAttachmentArrayNonUniformIndexing = 5310, + UniformTexelBufferArrayNonUniformIndexing = 5311, + StorageTexelBufferArrayNonUniformIndexing = 5312, + RayTracingPositionFetchKHR = 5336, + RayTracingNV = 5340, + RayTracingMotionBlurNV = 5341, + VulkanMemoryModel = 5345, + VulkanMemoryModelDeviceScope = 5346, + PhysicalStorageBufferAddresses = 5347, + ComputeDerivativeGroupLinearKHR = 
5350, + RayTracingProvisionalKHR = 5353, + CooperativeMatrixNV = 5357, + FragmentShaderSampleInterlockEXT = 5363, + FragmentShaderShadingRateInterlockEXT = 5372, + ShaderSMBuiltinsNV = 5373, + FragmentShaderPixelInterlockEXT = 5378, + DemoteToHelperInvocation = 5379, + DisplacementMicromapNV = 5380, + RayTracingOpacityMicromapEXT = 5381, + ShaderInvocationReorderNV = 5383, + BindlessTextureNV = 5390, + RayQueryPositionFetchKHR = 5391, + CooperativeVectorNV = 5394, + AtomicFloat16VectorNV = 5404, + RayTracingDisplacementMicromapNV = 5409, + RawAccessChainsNV = 5414, + RayTracingSpheresGeometryNV = 5418, + RayTracingLinearSweptSpheresGeometryNV = 5419, + CooperativeMatrixReductionsNV = 5430, + CooperativeMatrixConversionsNV = 5431, + CooperativeMatrixPerElementOperationsNV = 5432, + CooperativeMatrixTensorAddressingNV = 5433, + CooperativeMatrixBlockLoadsNV = 5434, + CooperativeVectorTrainingNV = 5435, + RayTracingClusterAccelerationStructureNV = 5437, + TensorAddressingNV = 5439, + SubgroupShuffleINTEL = 5568, + SubgroupBufferBlockIOINTEL = 5569, + SubgroupImageBlockIOINTEL = 5570, + SubgroupImageMediaBlockIOINTEL = 5579, + RoundToInfinityINTEL = 5582, + FloatingPointModeINTEL = 5583, + IntegerFunctions2INTEL = 5584, + FunctionPointersINTEL = 5603, + IndirectReferencesINTEL = 5604, + AsmINTEL = 5606, + AtomicFloat32MinMaxEXT = 5612, + AtomicFloat64MinMaxEXT = 5613, + AtomicFloat16MinMaxEXT = 5616, + VectorComputeINTEL = 5617, + VectorAnyINTEL = 5619, + ExpectAssumeKHR = 5629, + SubgroupAvcMotionEstimationINTEL = 5696, + SubgroupAvcMotionEstimationIntraINTEL = 5697, + SubgroupAvcMotionEstimationChromaINTEL = 5698, + VariableLengthArrayINTEL = 5817, + FunctionFloatControlINTEL = 5821, + FPGAMemoryAttributesINTEL = 5824, + FPFastMathModeINTEL = 5837, + ArbitraryPrecisionIntegersINTEL = 5844, + ArbitraryPrecisionFloatingPointINTEL = 5845, + UnstructuredLoopControlsINTEL = 5886, + FPGALoopControlsINTEL = 5888, + KernelAttributesINTEL = 5892, + FPGAKernelAttributesINTEL = 
5897, + FPGAMemoryAccessesINTEL = 5898, + FPGAClusterAttributesINTEL = 5904, + LoopFuseINTEL = 5906, + FPGADSPControlINTEL = 5908, + MemoryAccessAliasingINTEL = 5910, + FPGAInvocationPipeliningAttributesINTEL = 5916, + FPGABufferLocationINTEL = 5920, + ArbitraryPrecisionFixedPointINTEL = 5922, + USMStorageClassesINTEL = 5935, + RuntimeAlignedAttributeINTEL = 5939, + IOPipesINTEL = 5943, + BlockingPipesINTEL = 5945, + FPGARegINTEL = 5948, + DotProductInputAll = 6016, + DotProductInput4x8Bit = 6017, + DotProductInput4x8BitPacked = 6018, + DotProduct = 6019, + RayCullMaskKHR = 6020, + CooperativeMatrixKHR = 6022, + ReplicatedCompositesEXT = 6024, + BitInstructions = 6025, + GroupNonUniformRotateKHR = 6026, + FloatControls2 = 6029, + AtomicFloat32AddEXT = 6033, + AtomicFloat64AddEXT = 6034, + LongCompositesINTEL = 6089, + OptNoneEXT = 6094, + AtomicFloat16AddEXT = 6095, + DebugInfoModuleINTEL = 6114, + BFloat16ConversionINTEL = 6115, + SplitBarrierINTEL = 6141, + ArithmeticFenceEXT = 6144, + FPGAClusterAttributesV2INTEL = 6150, + FPGAKernelAttributesv2INTEL = 6161, + TaskSequenceINTEL = 6162, + FPMaxErrorINTEL = 6169, + FPGALatencyControlINTEL = 6171, + FPGAArgumentInterfacesINTEL = 6174, + GlobalVariableHostAccessINTEL = 6187, + GlobalVariableFPGADecorationsINTEL = 6189, + SubgroupBufferPrefetchINTEL = 6220, + Subgroup2DBlockIOINTEL = 6228, + Subgroup2DBlockTransformINTEL = 6229, + Subgroup2DBlockTransposeINTEL = 6230, + SubgroupMatrixMultiplyAccumulateINTEL = 6236, + TernaryBitwiseFunctionINTEL = 6241, + GroupUniformArithmeticKHR = 6400, + TensorFloat32RoundingINTEL = 6425, + MaskedGatherScatterINTEL = 6427, + CacheControlsINTEL = 6441, + RegisterLimitsINTEL = 6460, + BindlessImagesINTEL = 6528, + PackedCooperativeMatrixINTEL = 6434, + CooperativeMatrixInvocationInstructionsINTEL = 6435, + CooperativeMatrixTF32ComponentTypeINTEL = 6436, + CooperativeMatrixBFloat16ComponentTypeINTEL = 6437, + CooperativeMatrixCheckedInstructionsINTEL = 6192, + 
CooperativeMatrixPrefetchINTEL = 6449, +}; +enum class RayQueryIntersection { + RayQueryCandidateIntersectionKHR = 0, + RayQueryCommittedIntersectionKHR = 1, +}; +enum class RayQueryCommittedIntersectionType { + RayQueryCommittedIntersectionNoneKHR = 0, + RayQueryCommittedIntersectionTriangleKHR = 1, + RayQueryCommittedIntersectionGeneratedKHR = 2, +}; +enum class RayQueryCandidateIntersectionType { + RayQueryCandidateIntersectionTriangleKHR = 0, + RayQueryCandidateIntersectionAABBKHR = 1, +}; +enum class PackedVectorFormat { + PackedVectorFormat4x8Bit = 0, +}; +enum class CooperativeMatrixOperands { + NoneKHR = 0x0000, + MatrixASignedComponentsKHR = 0x0001, + MatrixBSignedComponentsKHR = 0x0002, + MatrixCSignedComponentsKHR = 0x0004, + MatrixResultSignedComponentsKHR = 0x0008, + SaturatingAccumulationKHR = 0x0010, +}; +enum class CooperativeMatrixLayout { + RowMajorKHR = 0, + ColumnMajorKHR = 1, + RowBlockedInterleavedARM = 4202, + ColumnBlockedInterleavedARM = 4203, +}; +enum class CooperativeMatrixUse { + MatrixAKHR = 0, + MatrixBKHR = 1, + MatrixAccumulatorKHR = 2, +}; +enum class CooperativeMatrixReduce { + Row = 0x0001, + Column = 0x0002, + CooperativeMatrixReduce2x2 = 0x0004, +}; +enum class TensorClampMode { + Undefined = 0, + Constant = 1, + ClampToEdge = 2, + Repeat = 3, + RepeatMirrored = 4, +}; +enum class TensorAddressingOperands { + None = 0x0000, + TensorView = 0x0001, + DecodeFunc = 0x0002, +}; +enum class InitializationModeQualifier { + InitOnDeviceReprogramINTEL = 0, + InitOnDeviceResetINTEL = 1, +}; +enum class LoadCacheControl { + UncachedINTEL = 0, + CachedINTEL = 1, + StreamingINTEL = 2, + InvalidateAfterReadINTEL = 3, + ConstCachedINTEL = 4, +}; +enum class StoreCacheControl { + UncachedINTEL = 0, + WriteThroughINTEL = 1, + WriteBackINTEL = 2, + StreamingINTEL = 3, +}; +enum class NamedMaximumNumberOfRegisters { + AutoINTEL = 0, +}; +enum class MatrixMultiplyAccumulateOperands { + None = 0x0, + MatrixASignedComponentsINTEL = 0x1, + 
MatrixBSignedComponentsINTEL = 0x2, + MatrixCBFloat16INTEL = 0x4, + MatrixResultBFloat16INTEL = 0x8, + MatrixAPackedInt8INTEL = 0x10, + MatrixBPackedInt8INTEL = 0x20, + MatrixAPackedInt4INTEL = 0x40, + MatrixBPackedInt4INTEL = 0x80, + MatrixATF32INTEL = 0x100, + MatrixBTF32INTEL = 0x200, + MatrixAPackedFloat16INTEL = 0x400, + MatrixBPackedFloat16INTEL = 0x800, + MatrixAPackedBFloat16INTEL = 0x1000, + MatrixBPackedBFloat16INTEL = 0x2000, +}; +enum class FPEncoding { + BFloat16KHR = 0, + Float8E4M3EXT = 4214, + Float8E5M2EXT = 4215, +}; +enum class CooperativeVectorMatrixLayout { + RowMajorNV = 0, + ColumnMajorNV = 1, + InferencingOptimalNV = 2, + TrainingOptimalNV = 3, +}; +enum class ComponentType { + Float16NV = 0, + Float32NV = 1, + Float64NV = 2, + SignedInt8NV = 3, + SignedInt16NV = 4, + SignedInt32NV = 5, + SignedInt64NV = 6, + UnsignedInt8NV = 7, + UnsignedInt16NV = 8, + UnsignedInt32NV = 9, + UnsignedInt64NV = 10, + SignedInt8PackedNV = 1000491000, + UnsignedInt8PackedNV = 1000491001, + FloatE4M3NV = 1000491002, + FloatE5M2NV = 1000491003, +}; +enum class TensorOperands { + NoneARM = 0x0000, + NontemporalARM = 0x0001, + OutOfBoundsValueARM = 0x0002, + MakeElementAvailableARM = 0x0004, + MakeElementVisibleARM = 0x0008, + NonPrivateElementARM = 0x0010, +}; + +} // namespace tinytc::spv + +#endif // GENERATED_ENUMS_20250630_HPP diff --git a/src/spv/inst_assembler.cpp b/src/spv/inst_assembler.cpp new file mode 100644 index 00000000..d65e8306 --- /dev/null +++ b/src/spv/inst_assembler.cpp @@ -0,0 +1,69 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/inst_assembler.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include + +namespace tinytc::spv { + +enum class LinkageType; + +inst_assembler::inst_assembler(word_stream &stream) : stream_{&stream} {} + +void inst_assembler::operator()(DecorationAttr const &da) { + std::visit(overloaded{[&](auto const &a) { this->operator()(a); }, + [&](std::pair 
const &a) { + *stream_ << a.first; + this->operator()(a.second); + }}, + da); +} +void inst_assembler::operator()(ExecutionModeAttr const &ea) { + std::visit(overloaded{[&](std::int32_t const &a) { *stream_ << a; }, + [&](std::array const &a) { + for (auto const &s : a) { + *stream_ << s; + } + }}, + ea); +} +void inst_assembler::operator()(LiteralContextDependentNumber const &l) { + std::visit(overloaded{[&](auto const &l) { *stream_ << l; }}, l); +} +void inst_assembler::operator()(LiteralInteger const &l) { *stream_ << l; } +void inst_assembler::operator()(LiteralString const &l) { *stream_ << l; } + +void inst_assembler::operator()(PairIdRefIdRef const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void inst_assembler::operator()(PairIdRefLiteralInteger const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void inst_assembler::operator()(PairLiteralIntegerIdRef const &p) { + std::visit(overloaded{[&](auto const &l) { *stream_ << l; }}, p.first); + this->operator()(p.second); +} + +void inst_assembler::pre_visit(spv_inst const &) { + *stream_ << std::int32_t{0}; + last_opcode_pos_ = stream_->tell(); +} + +void inst_assembler::visit_result(spv_inst const &in) { *stream_ << in.id(); } + +void inst_assembler::post_visit(spv_inst const &in) { + const std::int32_t word_count = stream_->tell() - last_opcode_pos_ + 1; + const auto ophead = (word_count << 16) | static_cast(in.opcode()); + stream_->update(last_opcode_pos_, ophead); +} + +void inst_assembler::operator()(spv_inst *const &in) { *stream_ << in->id(); } + +} // namespace tinytc::spv + diff --git a/src/spv/inst_assembler.hpp b/src/spv/inst_assembler.hpp new file mode 100644 index 00000000..f86149a5 --- /dev/null +++ b/src/spv/inst_assembler.hpp @@ -0,0 +1,96 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef INST_ASSEMBLER_20241111_HPP +#define INST_ASSEMBLER_20241111_HPP + +#include "spv/defs.hpp" +#include 
"spv/visit.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +enum class BuiltIn; +enum class LinkageType; + +template class word_stream { + public: + word_stream(std::vector &vec) : vec_{&vec} {} + + template auto operator<<(T const &t) -> word_stream & { + const std::size_t insert_pos = vec_->size() / sizeof(WordT); + vec_->resize(vec_->size() + word_count(t) * sizeof(WordT)); + update(insert_pos, t); + return *this; + } + + template static auto word_count(T const &) -> std::size_t { + return 1 + (sizeof(T) - 1) / sizeof(WordT); // ceil(sizeof(T)/sizeof(WordT)) + } + + static auto word_count(std::string const &s) -> std::size_t { + return 1 + s.size() / sizeof(WordT); // ceil((s.size()+1)/sizeof(WordT)) + } + + template auto update(std::size_t word, T const &t) -> word_stream & { + const std::size_t addr = word * sizeof(WordT); + std::memcpy(vec_->data() + addr, &t, sizeof(T)); + return *this; + } + + auto update(std::size_t word, std::string const &s) -> word_stream & { + const std::size_t addr = word * sizeof(WordT); + std::memcpy(vec_->data() + addr, s.c_str(), s.size() + 1); + return *this; + } + + //! Returns last word position + auto tell() const -> std::size_t { + const auto size = vec_->size(); + return size > 0 ? 
size / sizeof(WordT) - 1 : 0; + } + + private: + std::vector *vec_; +}; + +class inst_assembler : public default_visitor { + public: + using default_visitor::operator(); + + inst_assembler(word_stream &stream); + + void pre_visit(spv_inst const &in); + void visit_result(spv_inst const &in); + void post_visit(spv_inst const &in); + + template + requires std::is_enum_v + void operator()(T const &t) { + *stream_ << static_cast(t); + } + void operator()(DecorationAttr const &da); + void operator()(ExecutionModeAttr const &ea); + void operator()(LiteralContextDependentNumber const &l); + void operator()(LiteralInteger const &l); + void operator()(LiteralString const &l); + + void operator()(PairIdRefIdRef const &p); + void operator()(PairIdRefLiteralInteger const &p); + void operator()(PairLiteralIntegerIdRef const &p); + + void operator()(spv_inst *const &in); + + private: + word_stream *stream_; + std::size_t last_opcode_pos_ = 0; +}; + +} // namespace tinytc::spv + +#endif // INST_ASSEMBLER_20241111_HPP diff --git a/src/spv/instructions.hpp b/src/spv/instructions.hpp new file mode 100644 index 00000000..dea618d9 --- /dev/null +++ b/src/spv/instructions.hpp @@ -0,0 +1,7008 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_INSTRUCTIONS_20250630_HPP +#define GENERATED_INSTRUCTIONS_20250630_HPP + +#include "defs.hpp" +#include "enums.hpp" +#include "error.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +class OpNop : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Nop; } + OpNop() : spv_inst{Op::Nop, false} {} + + private: +}; +class OpUndef : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Undef; } + OpUndef(IdResultType type) : spv_inst{Op::Undef, true}, 
type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpSourceContinued : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceContinued; } + OpSourceContinued(LiteralString op0) + : spv_inst{Op::SourceContinued, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpSource : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Source; } + OpSource(SourceLanguage op0, LiteralInteger op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt) + : spv_inst{Op::Source, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> SourceLanguage & { return op0_; } + inline auto op0() const -> SourceLanguage const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + SourceLanguage op0_; + LiteralInteger op1_; + std::optional op2_; + std::optional op3_; +}; +class OpSourceExtension : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SourceExtension; } + OpSourceExtension(LiteralString op0) + : spv_inst{Op::SourceExtension, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString 
op0_; +}; +class OpName : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Name; } + OpName(IdRef op0, LiteralString op1) + : spv_inst{Op::Name, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralString & { return op1_; } + inline auto op1() const -> LiteralString const & { return op1_; } + + private: + IdRef op0_; + LiteralString op1_; +}; +class OpMemberName : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemberName; } + OpMemberName(IdRef op0, LiteralInteger op1, LiteralString op2) + : spv_inst{Op::MemberName, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } + inline auto op2() const -> LiteralString const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + LiteralString op2_; +}; +class OpString : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::String; } + OpString(LiteralString op0) : spv_inst{Op::String, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpLine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Line; } + OpLine(IdRef op0, LiteralInteger op1, LiteralInteger op2) + : spv_inst{Op::Line, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + 
inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + LiteralInteger op2_; +}; +class OpExtension : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Extension; } + OpExtension(LiteralString op0) : spv_inst{Op::Extension, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExtInstImport : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExtInstImport; } + OpExtInstImport(LiteralString op0) : spv_inst{Op::ExtInstImport, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpExtInst : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExtInst; } + OpExtInst(IdResultType type, IdRef op0, LiteralExtInstInteger op1, std::vector op2) + : spv_inst{Op::ExtInst, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralExtInstInteger & { return op1_; } + inline auto op1() const -> LiteralExtInstInteger const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector 
const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + LiteralExtInstInteger op1_; + std::vector op2_; +}; +class OpMemoryModel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryModel; } + OpMemoryModel(AddressingModel op0, MemoryModel op1) + : spv_inst{Op::MemoryModel, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> AddressingModel & { return op0_; } + inline auto op0() const -> AddressingModel const & { return op0_; } + inline auto op1() -> MemoryModel & { return op1_; } + inline auto op1() const -> MemoryModel const & { return op1_; } + + private: + AddressingModel op0_; + MemoryModel op1_; +}; +class OpEntryPoint : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EntryPoint; } + OpEntryPoint(ExecutionModel op0, IdRef op1, LiteralString op2, std::vector op3) + : spv_inst{Op::EntryPoint, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> ExecutionModel & { return op0_; } + inline auto op0() const -> ExecutionModel const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } + inline auto op2() const -> LiteralString const & { return op2_; } + inline auto op3() -> std::vector & { return op3_; } + inline auto op3() const -> std::vector const & { return op3_; } + + private: + ExecutionModel op0_; + IdRef op1_; + LiteralString op2_; + std::vector op3_; +}; +class OpExecutionMode : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExecutionMode; } + OpExecutionMode(IdRef op0, ExecutionMode op1, ExecutionModeAttr op2) + : spv_inst{Op::ExecutionMode, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { 
return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> ExecutionMode & { return op1_; } + inline auto op1() const -> ExecutionMode const & { return op1_; } + inline auto op2() -> ExecutionModeAttr & { return op2_; } + inline auto op2() const -> ExecutionModeAttr const & { return op2_; } + + private: + IdRef op0_; + ExecutionMode op1_; + ExecutionModeAttr op2_; +}; +class OpCapability : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Capability; } + OpCapability(Capability op0) : spv_inst{Op::Capability, false}, op0_(std::move(op0)) {} + inline auto op0() -> Capability & { return op0_; } + inline auto op0() const -> Capability const & { return op0_; } + + private: + Capability op0_; +}; +class OpTypeVoid : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeVoid; } + OpTypeVoid() : spv_inst{Op::TypeVoid, true} {} + + private: +}; +class OpTypeBool : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeBool; } + OpTypeBool() : spv_inst{Op::TypeBool, true} {} + + private: +}; +class OpTypeInt : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeInt; } + OpTypeInt(LiteralInteger op0, LiteralInteger op1) + : spv_inst{Op::TypeInt, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> LiteralInteger & { return op0_; } + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + LiteralInteger op0_; + LiteralInteger op1_; +}; +class OpTypeFloat : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFloat; } + OpTypeFloat(LiteralInteger op0, std::optional op1 = std::nullopt) + : 
spv_inst{Op::TypeFloat, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> LiteralInteger & { return op0_; } + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } + inline auto op1() const -> std::optional const & { return op1_; } + + private: + LiteralInteger op0_; + std::optional op1_; +}; +class OpTypeVector : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeVector; } + OpTypeVector(IdRef op0, LiteralInteger op1) + : spv_inst{Op::TypeVector, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpTypeMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpTypeMatrix(IdRef op0, LiteralInteger op1) + : spv_inst{Op::TypeMatrix, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpTypeImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeImage; } + OpTypeImage(IdRef op0, Dim op1, LiteralInteger op2, LiteralInteger op3, LiteralInteger op4, + LiteralInteger op5, ImageFormat op6, + std::optional op7 = std::nullopt) + : spv_inst{Op::TypeImage, true}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), 
op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)), + op6_(std::move(op6)), op7_(std::move(op7)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Dim & { return op1_; } + inline auto op1() const -> Dim const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + inline auto op3() -> LiteralInteger & { return op3_; } + inline auto op3() const -> LiteralInteger const & { return op3_; } + inline auto op4() -> LiteralInteger & { return op4_; } + inline auto op4() const -> LiteralInteger const & { return op4_; } + inline auto op5() -> LiteralInteger & { return op5_; } + inline auto op5() const -> LiteralInteger const & { return op5_; } + inline auto op6() -> ImageFormat & { return op6_; } + inline auto op6() const -> ImageFormat const & { return op6_; } + inline auto op7() -> std::optional & { return op7_; } + inline auto op7() const -> std::optional const & { return op7_; } + + private: + IdRef op0_; + Dim op1_; + LiteralInteger op2_; + LiteralInteger op3_; + LiteralInteger op4_; + LiteralInteger op5_; + ImageFormat op6_; + std::optional op7_; +}; +class OpTypeSampler : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeSampler; } + OpTypeSampler() : spv_inst{Op::TypeSampler, true} {} + + private: +}; +class OpTypeSampledImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeSampledImage; } + OpTypeSampledImage(IdRef op0) : spv_inst{Op::TypeSampledImage, true}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpTypeArray : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeArray; } + 
OpTypeArray(IdRef op0, IdRef op1) + : spv_inst{Op::TypeArray, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpTypeRuntimeArray : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeRuntimeArray; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpTypeRuntimeArray(IdRef op0) : spv_inst{Op::TypeRuntimeArray, true}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpTypeStruct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeStruct; } + OpTypeStruct(std::vector op0) : spv_inst{Op::TypeStruct, true}, op0_(std::move(op0)) {} + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + std::vector op0_; +}; +class OpTypeOpaque : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeOpaque; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpTypeOpaque(LiteralString op0) : spv_inst{Op::TypeOpaque, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpTypePointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePointer; } + OpTypePointer(StorageClass op0, IdRef op1) + : spv_inst{Op::TypePointer, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> StorageClass 
& { return op0_; } + inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + StorageClass op0_; + IdRef op1_; +}; +class OpTypeFunction : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeFunction; } + OpTypeFunction(IdRef op0, std::vector op1) + : spv_inst{Op::TypeFunction, true}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpTypeEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeEvent; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpTypeEvent() : spv_inst{Op::TypeEvent, true} {} + + private: +}; +class OpTypeDeviceEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeDeviceEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpTypeDeviceEvent() : spv_inst{Op::TypeDeviceEvent, true} {} + + private: +}; +class OpTypeReserveId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeReserveId; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpTypeReserveId() : spv_inst{Op::TypeReserveId, true} {} + + private: +}; +class OpTypeQueue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeQueue; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpTypeQueue() : spv_inst{Op::TypeQueue, true} {} + + private: +}; +class OpTypePipe : 
public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpTypePipe(AccessQualifier op0) : spv_inst{Op::TypePipe, true}, op0_(std::move(op0)) {} + inline auto op0() -> AccessQualifier & { return op0_; } + inline auto op0() const -> AccessQualifier const & { return op0_; } + + private: + AccessQualifier op0_; +}; +class OpTypeForwardPointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeForwardPointer; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpTypeForwardPointer(IdRef op0, StorageClass op1) + : spv_inst{Op::TypeForwardPointer, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> StorageClass & { return op1_; } + inline auto op1() const -> StorageClass const & { return op1_; } + + private: + IdRef op0_; + StorageClass op1_; +}; +class OpConstantTrue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantTrue; } + OpConstantTrue(IdResultType type) : spv_inst{Op::ConstantTrue, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpConstantFalse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantFalse; } + OpConstantFalse(IdResultType type) + : spv_inst{Op::ConstantFalse, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpConstant : public 
spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Constant; } + OpConstant(IdResultType type, LiteralContextDependentNumber op0) + : spv_inst{Op::Constant, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> LiteralContextDependentNumber & { return op0_; } + inline auto op0() const -> LiteralContextDependentNumber const & { return op0_; } + + private: + IdResultType type_; + LiteralContextDependentNumber op0_; +}; +class OpConstantComposite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantComposite; } + OpConstantComposite(IdResultType type, std::vector op0) + : spv_inst{Op::ConstantComposite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpConstantSampler : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantSampler; } + constexpr static std::array required_capabilities = {Capability::LiteralSampler}; + OpConstantSampler(IdResultType type, SamplerAddressingMode op0, LiteralInteger op1, + SamplerFilterMode op2) + : spv_inst{Op::ConstantSampler, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> SamplerAddressingMode & { return op0_; } + inline auto op0() const -> SamplerAddressingMode const & { return op0_; } + inline auto op1() -> 
LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> SamplerFilterMode & { return op2_; } + inline auto op2() const -> SamplerFilterMode const & { return op2_; } + + private: + IdResultType type_; + SamplerAddressingMode op0_; + LiteralInteger op1_; + SamplerFilterMode op2_; +}; +class OpConstantNull : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantNull; } + OpConstantNull(IdResultType type) : spv_inst{Op::ConstantNull, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpFunction : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Function; } + OpFunction(IdResultType type, FunctionControl op0, IdRef op1) + : spv_inst{Op::Function, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> FunctionControl & { return op0_; } + inline auto op0() const -> FunctionControl const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + FunctionControl op0_; + IdRef op1_; +}; +class OpFunctionParameter : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionParameter; } + OpFunctionParameter(IdResultType type) + : spv_inst{Op::FunctionParameter, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpFunctionEnd : public spv_inst { + public: + inline static 
bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionEnd; } + OpFunctionEnd() : spv_inst{Op::FunctionEnd, false} {} + + private: +}; +class OpFunctionCall : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FunctionCall; } + OpFunctionCall(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::FunctionCall, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpVariable : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Variable; } + OpVariable(IdResultType type, StorageClass op0, std::optional op1 = std::nullopt) + : spv_inst{Op::Variable, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> StorageClass & { return op0_; } + inline auto op0() const -> StorageClass const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } + inline auto op1() const -> std::optional const & { return op1_; } + + private: + IdResultType type_; + StorageClass op0_; + std::optional op1_; +}; +class OpImageTexelPointer : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageTexelPointer; } + OpImageTexelPointer(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::ImageTexelPointer, true}, type_(std::move(type)), op0_(std::move(op0)), + 
op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpLoad : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Load; } + OpLoad(IdResultType type, IdRef op0, std::optional op1 = std::nullopt, + std::optional op2 = std::nullopt) + : spv_inst{Op::Load, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::optional & { return op1_; } + inline auto op1() const -> std::optional const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + std::optional op1_; + std::optional op2_; +}; +class OpStore : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Store; } + OpStore(IdRef op0, IdRef op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt) + : spv_inst{Op::Store, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() 
-> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; +}; +class OpCopyMemory : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemory; } + OpCopyMemory(IdRef op0, IdRef op1, std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt) + : spv_inst{Op::CopyMemory, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + + private: + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; + std::optional op4_; +}; +class OpCopyMemorySized : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyMemorySized; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::UntypedPointersKHR}; + OpCopyMemorySized(IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt, + std::optional op5 = std::nullopt) + : 
spv_inst{Op::CopyMemorySized, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() -> std::optional & { return op5_; } + inline auto op5() const -> std::optional const & { return op5_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; + std::optional op4_; + std::optional op5_; +}; +class OpAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AccessChain; } + OpAccessChain(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::AccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpInBoundsAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::InBoundsAccessChain; } + OpInBoundsAccessChain(IdResultType type, IdRef op0, std::vector op1) + : 
spv_inst{Op::InBoundsAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpPtrAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrAccessChain; } + constexpr static std::array required_capabilities = { + Capability::Addresses, Capability::VariablePointers, + Capability::VariablePointersStorageBuffer, Capability::PhysicalStorageBufferAddresses}; + OpPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::PtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpArrayLength : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ArrayLength; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpArrayLength(IdResultType type, IdRef op0, LiteralInteger op1) + : spv_inst{Op::ArrayLength, true}, type_(std::move(type)), 
op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + LiteralInteger op1_; +}; +class OpGenericPtrMemSemantics : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GenericPtrMemSemantics; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericPtrMemSemantics(IdResultType type, IdRef op0) + : spv_inst{Op::GenericPtrMemSemantics, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpInBoundsPtrAccessChain : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::InBoundsPtrAccessChain; + } + constexpr static std::array required_capabilities = {Capability::Addresses}; + OpInBoundsPtrAccessChain(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::InBoundsPtrAccessChain, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto 
op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Decorate; } + OpDecorate(IdRef op0, Decoration op1, std::optional op2 = std::nullopt) + : spv_inst{Op::Decorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Decoration & { return op1_; } + inline auto op1() const -> Decoration const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdRef op0_; + Decoration op1_; + std::optional op2_; +}; +class OpMemberDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemberDecorate; } + OpMemberDecorate(IdRef op0, LiteralInteger op1, Decoration op2) + : spv_inst{Op::MemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> Decoration & { return op2_; } + inline auto op2() const -> Decoration const & { return op2_; } + + private: + IdRef op0_; + LiteralInteger op1_; + Decoration op2_; +}; +class OpDecorationGroup : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorationGroup; } + OpDecorationGroup() : spv_inst{Op::DecorationGroup, true} {} + + private: +}; +class OpGroupDecorate : public spv_inst { + public: + inline static bool classof(spv_inst 
const &s) { return s.opcode() == Op::GroupDecorate; } + OpGroupDecorate(IdRef op0, std::vector op1) + : spv_inst{Op::GroupDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpGroupMemberDecorate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupMemberDecorate; } + OpGroupMemberDecorate(IdRef op0, std::vector op1) + : spv_inst{Op::GroupMemberDecorate, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdRef op0_; + std::vector op1_; +}; +class OpVectorExtractDynamic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorExtractDynamic; } + OpVectorExtractDynamic(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorExtractDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorInsertDynamic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorInsertDynamic; } + OpVectorInsertDynamic(IdResultType type, 
IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::VectorInsertDynamic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpVectorShuffle : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorShuffle; } + OpVectorShuffle(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::VectorShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpCompositeConstruct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeConstruct; } + OpCompositeConstruct(IdResultType type, std::vector op0) + : spv_inst{Op::CompositeConstruct, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> 
IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpCompositeExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeExtract; } + OpCompositeExtract(IdResultType type, IdRef op0, std::vector op1) + : spv_inst{Op::CompositeExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpCompositeInsert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CompositeInsert; } + OpCompositeInsert(IdResultType type, IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::CompositeInsert, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpCopyObject : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return 
s.opcode() == Op::CopyObject; } + OpCopyObject(IdResultType type, IdRef op0) + : spv_inst{Op::CopyObject, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpTranspose : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Transpose; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpTranspose(IdResultType type, IdRef op0) + : spv_inst{Op::Transpose, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSampledImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SampledImage; } + OpSampledImage(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SampledImage, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageSampleImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleImplicitLod; + } + constexpr static std::array 
required_capabilities = {Capability::Shader}; + OpImageSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSampleImplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSampleExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleExplicitLod; + } + OpImageSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSampleExplicitLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSampleDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleDrefImplicitLod; + } + constexpr static 
std::array required_capabilities = {Capability::Shader}; + OpImageSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSampleDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSampleDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleDrefExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSampleDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + 
inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSampleProjImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSampleProjImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSampleProjExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSampleProjExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return 
op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSampleProjDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjDrefImplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSampleProjDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSampleProjDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSampleProjDrefExplicitLod; + } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageSampleProjDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSampleProjDrefExplicitLod, true}, 
type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageFetch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageFetch; } + OpImageFetch(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageFetch, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageGather; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + 
std::optional op3 = std::nullopt) + : spv_inst{Op::ImageGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageDrefGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageDrefGather; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpImageDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef 
op2_; + std::optional op3_; +}; +class OpImageRead : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageRead; } + OpImageRead(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageRead, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageWrite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageWrite; } + OpImageWrite(IdRef op0, IdRef op1, IdRef op2, std::optional op3 = std::nullopt) + : spv_inst{Op::ImageWrite, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Image; } + OpImage(IdResultType type, IdRef op0) + : 
spv_inst{Op::Image, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryFormat : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryFormat; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpImageQueryFormat(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryFormat, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryOrder : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryOrder; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpImageQueryOrder(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryOrder, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQuerySizeLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySizeLod; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySizeLod(IdResultType type, IdRef op0, IdRef op1) + : 
spv_inst{Op::ImageQuerySizeLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpImageQuerySize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySize; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySize(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQuerySize, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQueryLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryLod; } + constexpr static std::array required_capabilities = {Capability::ImageQuery}; + OpImageQueryLod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ImageQueryLod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; 
+class OpImageQueryLevels : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQueryLevels; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQueryLevels(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQueryLevels, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpImageQuerySamples : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageQuerySamples; } + constexpr static std::array required_capabilities = {Capability::Kernel, + Capability::ImageQuery}; + OpImageQuerySamples(IdResultType type, IdRef op0) + : spv_inst{Op::ImageQuerySamples, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertFToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToU; } + OpConvertFToU(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertFToS : public spv_inst { + public: + inline static bool 
classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToS; } + OpConvertFToS(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertSToF : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertSToF; } + OpConvertSToF(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertSToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertUToF : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertUToF; } + OpConvertUToF(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertUToF, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpUConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UConvert; } + OpUConvert(IdResultType type, IdRef op0) + : spv_inst{Op::UConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline 
auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SConvert; } + OpSConvert(IdResultType type, IdRef op0) + : spv_inst{Op::SConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFConvert : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FConvert; } + OpFConvert(IdResultType type, IdRef op0) + : spv_inst{Op::FConvert, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpQuantizeToF16 : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::QuantizeToF16; } + OpQuantizeToF16(IdResultType type, IdRef op0) + : spv_inst{Op::QuantizeToF16, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertPtrToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertPtrToU; } + constexpr static std::array 
required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpConvertPtrToU(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertPtrToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSatConvertSToU : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SatConvertSToU; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSatConvertSToU(IdResultType type, IdRef op0) + : spv_inst{Op::SatConvertSToU, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSatConvertUToS : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SatConvertUToS; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSatConvertUToS(IdResultType type, IdRef op0) + : spv_inst{Op::SatConvertUToS, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertUToPtr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertUToPtr; } + constexpr static std::array 
required_capabilities = { + Capability::Addresses, Capability::PhysicalStorageBufferAddresses}; + OpConvertUToPtr(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertUToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpPtrCastToGeneric : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrCastToGeneric; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpPtrCastToGeneric(IdResultType type, IdRef op0) + : spv_inst{Op::PtrCastToGeneric, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGenericCastToPtr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GenericCastToPtr; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericCastToPtr(IdResultType type, IdRef op0) + : spv_inst{Op::GenericCastToPtr, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGenericCastToPtrExplicit : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GenericCastToPtrExplicit; + } 
+ constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGenericCastToPtrExplicit(IdResultType type, IdRef op0, StorageClass op1) + : spv_inst{Op::GenericCastToPtrExplicit, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> StorageClass & { return op1_; } + inline auto op1() const -> StorageClass const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + StorageClass op1_; +}; +class OpBitcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Bitcast; } + OpBitcast(IdResultType type, IdRef op0) + : spv_inst{Op::Bitcast, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSNegate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SNegate; } + OpSNegate(IdResultType type, IdRef op0) + : spv_inst{Op::SNegate, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFNegate : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FNegate; } + OpFNegate(IdResultType type, IdRef op0) + : spv_inst{Op::FNegate, true}, 
type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IAdd; } + OpIAdd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FAdd; } + OpFAdd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpISub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ISub; } + OpISub(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ISub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline 
auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFSub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FSub; } + OpFSub(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FSub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpIMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IMul; } + OpIMul(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFMul : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FMul; } + OpFMul(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FMul, true}, 
type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UDiv; } + OpUDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SDiv; } + OpSDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFDiv : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FDiv; } + 
OpFDiv(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FDiv, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UMod; } + OpUMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSRem : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SRem; } + OpSRem(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SRem, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSMod : public spv_inst { + public: + inline static bool 
classof(spv_inst const &s) { return s.opcode() == Op::SMod; } + OpSMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFRem : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FRem; } + OpFRem(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FRem, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFMod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FMod; } + OpFMod(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FMod, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; 
+class OpVectorTimesScalar : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorTimesScalar; } + OpVectorTimesScalar(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesScalar : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesScalar; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesScalar(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesScalar, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpVectorTimesMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::VectorTimesMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpVectorTimesMatrix(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::VectorTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType 
& { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesVector : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesVector; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesVector(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesVector, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpMatrixTimesMatrix : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MatrixTimesMatrix; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpMatrixTimesMatrix(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::MatrixTimesMatrix, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; 
+}; +class OpOuterProduct : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::OuterProduct; } + constexpr static std::array required_capabilities = {Capability::Matrix}; + OpOuterProduct(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::OuterProduct, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpDot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Dot; } + OpDot(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Dot, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpIAddCarry : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IAddCarry; } + OpIAddCarry(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IAddCarry, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & 
{ return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpISubBorrow : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ISubBorrow; } + OpISubBorrow(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ISubBorrow, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUMulExtended : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UMulExtended; } + OpUMulExtended(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSMulExtended : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SMulExtended; } + OpSMulExtended(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SMulExtended, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + 
inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Any; } + OpAny(IdResultType type, IdRef op0) + : spv_inst{Op::Any, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::All; } + OpAll(IdResultType type, IdRef op0) + : spv_inst{Op::All, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsNan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsNan; } + OpIsNan(IdResultType type, IdRef op0) + : spv_inst{Op::IsNan, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsInf : public spv_inst { + public: + inline static 
bool classof(spv_inst const &s) { return s.opcode() == Op::IsInf; } + OpIsInf(IdResultType type, IdRef op0) + : spv_inst{Op::IsInf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsFinite : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsFinite; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpIsFinite(IdResultType type, IdRef op0) + : spv_inst{Op::IsFinite, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpIsNormal : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsNormal; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpIsNormal(IdResultType type, IdRef op0) + : spv_inst{Op::IsNormal, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSignBitSet : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SignBitSet; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpSignBitSet(IdResultType type, IdRef op0) + : 
spv_inst{Op::SignBitSet, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpLessOrGreater : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LessOrGreater; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLessOrGreater(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LessOrGreater, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpOrdered : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Ordered; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpOrdered(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Ordered, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUnordered : public spv_inst { + public: + 
inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Unordered; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpUnordered(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::Unordered, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalEqual; } + OpLogicalEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalNotEqual; } + OpLogicalNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { 
return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalOr; } + OpLogicalOr(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalAnd; } + OpLogicalAnd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::LogicalAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpLogicalNot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LogicalNot; } + OpLogicalNot(IdResultType type, IdRef op0) + : spv_inst{Op::LogicalNot, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return 
type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSelect : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Select; } + OpSelect(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::Select, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpIEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IEqual; } + OpIEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::IEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpINotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::INotEqual; } + OpINotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::INotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + 
op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::UGreaterThan; } + OpUGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SGreaterThan; } + OpSGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpUGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() 
== Op::UGreaterThanEqual; } + OpUGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::UGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SGreaterThanEqual; } + OpSGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpULessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ULessThan; } + OpULessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ULessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + 
private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SLessThan; } + OpSLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpULessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ULessThanEqual; } + OpULessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ULessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpSLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SLessThanEqual; } + OpSLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::SLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + 
inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdEqual; } + OpFOrdEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordEqual; } + OpFUnordEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdNotEqual; } + OpFOrdNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> 
IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordNotEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordNotEqual; } + OpFUnordNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdLessThan; } + OpFOrdLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordLessThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordLessThan; } + 
OpFUnordLessThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordLessThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdGreaterThan; } + OpFOrdGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordGreaterThan : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordGreaterThan; } + OpFUnordGreaterThan(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordGreaterThan, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + 
IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdLessThanEqual; } + OpFOrdLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordLessThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FUnordLessThanEqual; } + OpFUnordLessThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordLessThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFOrdGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FOrdGreaterThanEqual; } + OpFOrdGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FOrdGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return 
type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpFUnordGreaterThanEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::FUnordGreaterThanEqual; + } + OpFUnordGreaterThanEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::FUnordGreaterThanEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftRightLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftRightLogical; } + OpShiftRightLogical(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftRightLogical, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftRightArithmetic : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftRightArithmetic; } + OpShiftRightArithmetic(IdResultType type, 
IdRef op0, IdRef op1) + : spv_inst{Op::ShiftRightArithmetic, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpShiftLeftLogical : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ShiftLeftLogical; } + OpShiftLeftLogical(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::ShiftLeftLogical, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseOr; } + OpBitwiseOr(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class 
OpBitwiseXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseXor; } + OpBitwiseXor(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseXor, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpBitwiseAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitwiseAnd; } + OpBitwiseAnd(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::BitwiseAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpNot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Not; } + OpNot(IdResultType type, IdRef op0) + : spv_inst{Op::Not, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpBitFieldInsert : public spv_inst { + public: + 
inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldInsert; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldInsert(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::BitFieldInsert, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpBitFieldSExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldSExtract; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldSExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BitFieldSExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return 
op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpBitFieldUExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitFieldUExtract; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitFieldUExtract(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BitFieldUExtract, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpBitReverse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitReverse; } + constexpr static std::array required_capabilities = { + Capability::Shader, Capability::BitInstructions}; + OpBitReverse(IdResultType type, IdRef op0) + : spv_inst{Op::BitReverse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpBitCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BitCount; } + OpBitCount(IdResultType type, IdRef op0) + : spv_inst{Op::BitCount, true}, type_(std::move(type)), 
op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdx : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdx; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpDPdx(IdResultType type, IdRef op0) + : spv_inst{Op::DPdx, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdy : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdy; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpDPdy(IdResultType type, IdRef op0) + : spv_inst{Op::DPdy, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidth : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Fwidth; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpFwidth(IdResultType type, IdRef op0) + : spv_inst{Op::Fwidth, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() 
-> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdxFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdxFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdxFine(IdResultType type, IdRef op0) + : spv_inst{Op::DPdxFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdyFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdyFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdyFine(IdResultType type, IdRef op0) + : spv_inst{Op::DPdyFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidthFine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FwidthFine; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpFwidthFine(IdResultType type, IdRef op0) + : spv_inst{Op::FwidthFine, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return 
op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdxCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdxCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdxCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::DPdxCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpDPdyCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DPdyCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpDPdyCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::DPdyCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpFwidthCoarse : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::FwidthCoarse; } + constexpr static std::array required_capabilities = { + Capability::DerivativeControl}; + OpFwidthCoarse(IdResultType type, IdRef op0) + : spv_inst{Op::FwidthCoarse, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; 
+}; +class OpEmitVertex : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EmitVertex; } + constexpr static std::array required_capabilities = {Capability::Geometry}; + OpEmitVertex() : spv_inst{Op::EmitVertex, false} {} + + private: +}; +class OpEndPrimitive : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EndPrimitive; } + constexpr static std::array required_capabilities = {Capability::Geometry}; + OpEndPrimitive() : spv_inst{Op::EndPrimitive, false} {} + + private: +}; +class OpEmitStreamVertex : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EmitStreamVertex; } + constexpr static std::array required_capabilities = { + Capability::GeometryStreams}; + OpEmitStreamVertex(IdRef op0) : spv_inst{Op::EmitStreamVertex, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpEndStreamPrimitive : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EndStreamPrimitive; } + constexpr static std::array required_capabilities = { + Capability::GeometryStreams}; + OpEndStreamPrimitive(IdRef op0) + : spv_inst{Op::EndStreamPrimitive, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpControlBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ControlBarrier; } + OpControlBarrier(IdScope op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::ControlBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto 
op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdScope op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpMemoryBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryBarrier; } + OpMemoryBarrier(IdScope op0, IdMemorySemantics op1) + : spv_inst{Op::MemoryBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdMemorySemantics & { return op1_; } + inline auto op1() const -> IdMemorySemantics const & { return op1_; } + + private: + IdScope op0_; + IdMemorySemantics op1_; +}; +class OpAtomicLoad : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicLoad; } + OpAtomicLoad(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicLoad, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicStore : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicStore; } + OpAtomicStore(IdRef op0, IdScope op1, 
IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicStore, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicExchange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicExchange; } + OpAtomicExchange(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicExchange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicCompareExchange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::AtomicCompareExchange; + } + 
OpAtomicCompareExchange(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, + IdMemorySemantics op3, IdRef op4, IdRef op5) + : spv_inst{Op::AtomicCompareExchange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdMemorySemantics & { return op3_; } + inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdMemorySemantics op3_; + IdRef op4_; + IdRef op5_; +}; +class OpAtomicCompareExchangeWeak : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::AtomicCompareExchangeWeak; + } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicCompareExchangeWeak(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, + IdMemorySemantics op3, IdRef op4, IdRef op5) + : spv_inst{Op::AtomicCompareExchangeWeak, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)), op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto 
type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdMemorySemantics & { return op3_; } + inline auto op3() const -> IdMemorySemantics const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdMemorySemantics op3_; + IdRef op4_; + IdRef op5_; +}; +class OpAtomicIIncrement : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIIncrement; } + OpAtomicIIncrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicIIncrement, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicIDecrement : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIDecrement; 
} + OpAtomicIDecrement(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicIDecrement, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicIAdd; } + OpAtomicIAdd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicISub : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == 
Op::AtomicISub; } + OpAtomicISub(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicISub, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicSMin; } + OpAtomicSMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + 
IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicUMin; } + OpAtomicUMin(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicSMax; } + OpAtomicSMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto 
op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicUMax; } + OpAtomicUMax(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicAnd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicAnd; } + OpAtomicAnd(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicAnd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return 
op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicOr : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicOr; } + OpAtomicOr(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicOr, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicXor; } + OpAtomicXor(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicXor, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() 
const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpPhi : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Phi; } + OpPhi(IdResultType type, std::vector op0) + : spv_inst{Op::Phi, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> std::vector & { return op0_; } + inline auto op0() const -> std::vector const & { return op0_; } + + private: + IdResultType type_; + std::vector op0_; +}; +class OpLoopMerge : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LoopMerge; } + OpLoopMerge(IdRef op0, IdRef op1, LoopControl op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::LoopMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LoopControl & { return op2_; } + inline auto op2() const -> LoopControl const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + LoopControl op2_; + std::optional op3_; +}; +class OpSelectionMerge : public spv_inst { + 
public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SelectionMerge; } + OpSelectionMerge(IdRef op0, SelectionControl op1) + : spv_inst{Op::SelectionMerge, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> SelectionControl & { return op1_; } + inline auto op1() const -> SelectionControl const & { return op1_; } + + private: + IdRef op0_; + SelectionControl op1_; +}; +class OpLabel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Label; } + OpLabel() : spv_inst{Op::Label, true} {} + + private: +}; +class OpBranch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Branch; } + OpBranch(IdRef op0) : spv_inst{Op::Branch, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpBranchConditional : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BranchConditional; } + OpBranchConditional(IdRef op0, IdRef op1, IdRef op2, std::vector op3) + : spv_inst{Op::BranchConditional, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::vector & { return op3_; } + inline auto op3() const -> std::vector const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::vector op3_; +}; +class OpSwitch : public spv_inst { + public: 
+ inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Switch; } + OpSwitch(IdRef op0, IdRef op1, std::vector op2) + : spv_inst{Op::Switch, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::vector & { return op2_; } + inline auto op2() const -> std::vector const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + std::vector op2_; +}; +class OpKill : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Kill; } + constexpr static std::array required_capabilities = {Capability::Shader}; + OpKill() : spv_inst{Op::Kill, false} {} + + private: +}; +class OpReturn : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Return; } + OpReturn() : spv_inst{Op::Return, false} {} + + private: +}; +class OpReturnValue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReturnValue; } + OpReturnValue(IdRef op0) : spv_inst{Op::ReturnValue, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpUnreachable : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::Unreachable; } + OpUnreachable() : spv_inst{Op::Unreachable, false} {} + + private: +}; +class OpLifetimeStart : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LifetimeStart; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLifetimeStart(IdRef op0, LiteralInteger op1) + : spv_inst{Op::LifetimeStart, false}, 
op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpLifetimeStop : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::LifetimeStop; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpLifetimeStop(IdRef op0, LiteralInteger op1) + : spv_inst{Op::LifetimeStop, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + + private: + IdRef op0_; + LiteralInteger op1_; +}; +class OpGroupAsyncCopy : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAsyncCopy; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGroupAsyncCopy(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::GroupAsyncCopy, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { 
return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpGroupWaitEvents : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupWaitEvents; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpGroupWaitEvents(IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupWaitEvents, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupAll : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAll; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupAll(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupAll, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class 
OpGroupAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupAny; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupAny(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupAny, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupBroadcast; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupBroadcast, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupIAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupIAdd; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupIAdd(IdResultType type, 
IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFAdd : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFAdd; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFMin(IdResultType type, IdScope op0, GroupOperation op1, 
IdRef op2) + : spv_inst{Op::GroupFMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupUMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupUMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupSMin : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupSMin; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : 
spv_inst{Op::GroupSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupFMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupFMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupFMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupUMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupUMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupUMax, true}, 
type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupSMax : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupSMax; } + constexpr static std::array required_capabilities = {Capability::Groups}; + OpGroupSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReadPipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReadPipe, true}, type_(std::move(type)), 
op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::WritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpWritePipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::WritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpReservedReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReservedReadPipe; } + 
constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReservedReadPipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::ReservedReadPipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpReservedWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReservedWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReservedWritePipe(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5) + : spv_inst{Op::ReservedWritePipe, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + 
inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; +}; +class OpReserveReadPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ReserveReadPipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReserveReadPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReserveReadPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpReserveWritePipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return 
s.opcode() == Op::ReserveWritePipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpReserveWritePipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::ReserveWritePipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpCommitReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CommitReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpCommitReadPipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::CommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpCommitWritePipe : public spv_inst { 
+ public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CommitWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpCommitWritePipe(IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::CommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpIsValidReserveId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsValidReserveId; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpIsValidReserveId(IdResultType type, IdRef op0) + : spv_inst{Op::IsValidReserveId, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGetNumPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetNumPipePackets; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGetNumPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::GetNumPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> 
IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGetMaxPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetMaxPipePackets; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGetMaxPipePackets(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::GetMaxPipePackets, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupReserveReadPipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupReserveReadPipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupReserveReadPipePackets(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GroupReserveReadPipePackets, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), 
op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupReserveWritePipePackets : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupReserveWritePipePackets; + } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupReserveWritePipePackets(IdResultType type, IdScope op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GroupReserveWritePipePackets, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + 
inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupCommitReadPipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupCommitReadPipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupCommitReadPipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::GroupCommitReadPipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGroupCommitWritePipe : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupCommitWritePipe; } + constexpr static std::array required_capabilities = {Capability::Pipes}; + OpGroupCommitWritePipe(IdScope op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::GroupCommitWritePipe, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> 
IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdScope op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpEnqueueMarker : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EnqueueMarker; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpEnqueueMarker(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::EnqueueMarker, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpEnqueueKernel : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::EnqueueKernel; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpEnqueueKernel(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5, IdRef op6, IdRef op7, IdRef op8, IdRef op9, std::vector op10) + : spv_inst{Op::EnqueueKernel, true}, type_(std::move(type)), op0_(std::move(op0)), + 
op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), op8_(std::move(op8)), + op9_(std::move(op9)), op10_(std::move(op10)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() -> IdRef & { return op6_; } + inline auto op6() const -> IdRef const & { return op6_; } + inline auto op7() -> IdRef & { return op7_; } + inline auto op7() const -> IdRef const & { return op7_; } + inline auto op8() -> IdRef & { return op8_; } + inline auto op8() const -> IdRef const & { return op8_; } + inline auto op9() -> IdRef & { return op9_; } + inline auto op9() const -> IdRef const & { return op9_; } + inline auto op10() -> std::vector & { return op10_; } + inline auto op10() const -> std::vector const & { return op10_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; + IdRef op6_; + IdRef op7_; + IdRef op8_; + IdRef op9_; + std::vector op10_; +}; +class OpGetKernelNDrangeSubGroupCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelNDrangeSubGroupCount; + } + constexpr static std::array required_capabilities = 
{Capability::DeviceEnqueue}; + OpGetKernelNDrangeSubGroupCount(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GetKernelNDrangeSubGroupCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelNDrangeMaxSubGroupSize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelNDrangeMaxSubGroupSize; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelNDrangeMaxSubGroupSize(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3, + IdRef op4) + : spv_inst{Op::GetKernelNDrangeMaxSubGroupSize, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> 
IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelWorkGroupSize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelWorkGroupSize; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelWorkGroupSize(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::GetKernelWorkGroupSize, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpGetKernelPreferredWorkGroupSizeMultiple : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelPreferredWorkGroupSizeMultiple; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetKernelPreferredWorkGroupSizeMultiple(IdResultType type, IdRef op0, IdRef op1, IdRef 
op2, + IdRef op3) + : spv_inst{Op::GetKernelPreferredWorkGroupSizeMultiple, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpRetainEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::RetainEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpRetainEvent(IdRef op0) : spv_inst{Op::RetainEvent, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpReleaseEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ReleaseEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpReleaseEvent(IdRef op0) : spv_inst{Op::ReleaseEvent, false}, op0_(std::move(op0)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdRef op0_; +}; +class OpCreateUserEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CreateUserEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + 
OpCreateUserEvent(IdResultType type) + : spv_inst{Op::CreateUserEvent, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpIsValidEvent : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::IsValidEvent; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpIsValidEvent(IdResultType type, IdRef op0) + : spv_inst{Op::IsValidEvent, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSetUserEventStatus : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SetUserEventStatus; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpSetUserEventStatus(IdRef op0, IdRef op1) + : spv_inst{Op::SetUserEventStatus, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpCaptureEventProfilingInfo : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CaptureEventProfilingInfo; + } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpCaptureEventProfilingInfo(IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::CaptureEventProfilingInfo, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline 
auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGetDefaultQueue : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GetDefaultQueue; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpGetDefaultQueue(IdResultType type) + : spv_inst{Op::GetDefaultQueue, true}, type_(std::move(type)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + + private: + IdResultType type_; +}; +class OpBuildNDRange : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::BuildNDRange; } + constexpr static std::array required_capabilities = {Capability::DeviceEnqueue}; + OpBuildNDRange(IdResultType type, IdRef op0, IdRef op1, IdRef op2) + : spv_inst{Op::BuildNDRange, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; +}; +class OpImageSparseSampleImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == 
Op::ImageSparseSampleImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseSampleImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseSampleExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : spv_inst{Op::ImageSparseSampleExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; 
+}; +class OpImageSparseSampleDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleDrefImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseSampleDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseSampleDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleDrefExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSparseSampleDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline 
auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSparseSampleProjImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjImplicitLod(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseSampleProjImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseSampleProjExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjExplicitLod(IdResultType type, IdRef op0, IdRef op1, ImageOperands op2) + : 
spv_inst{Op::ImageSparseSampleProjExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> ImageOperands & { return op2_; } + inline auto op2() const -> ImageOperands const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + ImageOperands op2_; +}; +class OpImageSparseSampleProjDrefImplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjDrefImplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjDrefImplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseSampleProjDrefImplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class 
OpImageSparseSampleProjDrefExplicitLod : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseSampleProjDrefExplicitLod; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseSampleProjDrefExplicitLod(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + ImageOperands op3) + : spv_inst{Op::ImageSparseSampleProjDrefExplicitLod, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> ImageOperands & { return op3_; } + inline auto op3() const -> ImageOperands const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + ImageOperands op3_; +}; +class OpImageSparseFetch : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseFetch; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseFetch(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseFetch, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } 
+ inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpImageSparseGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseGather; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseDrefGather : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseDrefGather; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseDrefGather(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::ImageSparseDrefGather, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline 
auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class OpImageSparseTexelsResident : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ImageSparseTexelsResident; + } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseTexelsResident(IdResultType type, IdRef op0) + : spv_inst{Op::ImageSparseTexelsResident, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpNoLine : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::NoLine; } + OpNoLine() : spv_inst{Op::NoLine, false} {} + + private: +}; +class OpAtomicFlagTestAndSet : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFlagTestAndSet; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicFlagTestAndSet(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicFlagTestAndSet, true}, type_(std::move(type)), op0_(std::move(op0)), + 
op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpAtomicFlagClear : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFlagClear; } + constexpr static std::array required_capabilities = {Capability::Kernel}; + OpAtomicFlagClear(IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::AtomicFlagClear, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpImageSparseRead : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ImageSparseRead; } + constexpr static std::array required_capabilities = { + Capability::SparseResidency}; + OpImageSparseRead(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt) + : spv_inst{Op::ImageSparseRead, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() 
const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; +}; +class OpSizeOf : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::SizeOf; } + constexpr static std::array required_capabilities = {Capability::Addresses}; + OpSizeOf(IdResultType type, IdRef op0) + : spv_inst{Op::SizeOf, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpTypePipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypePipeStorage; } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpTypePipeStorage() : spv_inst{Op::TypePipeStorage, true} {} + + private: +}; +class OpConstantPipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConstantPipeStorage; } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpConstantPipeStorage(IdResultType type, LiteralInteger op0, LiteralInteger op1, + LiteralInteger op2) + : spv_inst{Op::ConstantPipeStorage, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { 
return type_; } + inline auto op0() -> LiteralInteger & { return op0_; } + inline auto op0() const -> LiteralInteger const & { return op0_; } + inline auto op1() -> LiteralInteger & { return op1_; } + inline auto op1() const -> LiteralInteger const & { return op1_; } + inline auto op2() -> LiteralInteger & { return op2_; } + inline auto op2() const -> LiteralInteger const & { return op2_; } + + private: + IdResultType type_; + LiteralInteger op0_; + LiteralInteger op1_; + LiteralInteger op2_; +}; +class OpCreatePipeFromPipeStorage : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CreatePipeFromPipeStorage; + } + constexpr static std::array required_capabilities = {Capability::PipeStorage}; + OpCreatePipeFromPipeStorage(IdResultType type, IdRef op0) + : spv_inst{Op::CreatePipeFromPipeStorage, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpGetKernelLocalSizeForSubgroupCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelLocalSizeForSubgroupCount; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupDispatch}; + OpGetKernelLocalSizeForSubgroupCount(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + IdRef op3, IdRef op4) + : spv_inst{Op::GetKernelLocalSizeForSubgroupCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const 
-> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpGetKernelMaxNumSubgroups : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GetKernelMaxNumSubgroups; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupDispatch}; + OpGetKernelMaxNumSubgroups(IdResultType type, IdRef op0, IdRef op1, IdRef op2, IdRef op3) + : spv_inst{Op::GetKernelMaxNumSubgroups, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; +}; +class OpTypeNamedBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::TypeNamedBarrier; } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + 
OpTypeNamedBarrier() : spv_inst{Op::TypeNamedBarrier, true} {} + + private: +}; +class OpNamedBarrierInitialize : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::NamedBarrierInitialize; + } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpNamedBarrierInitialize(IdResultType type, IdRef op0) + : spv_inst{Op::NamedBarrierInitialize, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpMemoryNamedBarrier : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::MemoryNamedBarrier; } + constexpr static std::array required_capabilities = {Capability::NamedBarrier}; + OpMemoryNamedBarrier(IdRef op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::MemoryNamedBarrier, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpModuleProcessed : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ModuleProcessed; } + OpModuleProcessed(LiteralString op0) + : spv_inst{Op::ModuleProcessed, false}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + 
+ private: + LiteralString op0_; +}; +class OpExecutionModeId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ExecutionModeId; } + OpExecutionModeId(IdRef op0, ExecutionMode op1) + : spv_inst{Op::ExecutionModeId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> ExecutionMode & { return op1_; } + inline auto op1() const -> ExecutionMode const & { return op1_; } + + private: + IdRef op0_; + ExecutionMode op1_; +}; +class OpDecorateId : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::DecorateId; } + constexpr static std::array required_extensions = { + "SPV_GOOGLE_hlsl_functionality1"}; + OpDecorateId(IdRef op0, Decoration op1) + : spv_inst{Op::DecorateId, false}, op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> Decoration & { return op1_; } + inline auto op1() const -> Decoration const & { return op1_; } + + private: + IdRef op0_; + Decoration op1_; +}; +class OpGroupNonUniformElect : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformElect; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniform}; + OpGroupNonUniformElect(IdResultType type, IdScope op0) + : spv_inst{Op::GroupNonUniformElect, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + + private: + IdResultType type_; + IdScope op0_; +}; +class OpGroupNonUniformAll : public spv_inst { + public: + inline 
static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformAll; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAll(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAll, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformAny : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformAny; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAny(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAny, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformAllEqual : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformAllEqual; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformVote}; + OpGroupNonUniformAllEqual(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformAllEqual, 
true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBroadcast : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBroadcast; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBroadcast, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformBroadcastFirst : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBroadcastFirst; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBroadcastFirst(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBroadcastFirst, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) 
{} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallot; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallot(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallot, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformInverseBallot : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformInverseBallot; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformInverseBallot(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformInverseBallot, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + 
inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallotBitExtract : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotBitExtract; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotBitExtract(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBallotBitExtract, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformBallotBitCount : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotBitCount; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotBitCount(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2) + : spv_inst{Op::GroupNonUniformBallotBitCount, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline 
auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; +}; +class OpGroupNonUniformBallotFindLSB : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotFindLSB; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotFindLSB(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallotFindLSB, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformBallotFindMSB : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBallotFindMSB; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformBallot}; + OpGroupNonUniformBallotFindMSB(IdResultType type, IdScope op0, IdRef op1) + : spv_inst{Op::GroupNonUniformBallotFindMSB, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> 
IdRef const & { return op1_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; +}; +class OpGroupNonUniformShuffle : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffle; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffle}; + OpGroupNonUniformShuffle(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffle, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleXor : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffle}; + OpGroupNonUniformShuffleXor(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { 
return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleUp : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleUp; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffleRelative}; + OpGroupNonUniformShuffleUp(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleUp, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformShuffleDown : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformShuffleDown; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformShuffleRelative}; + OpGroupNonUniformShuffleDown(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformShuffleDown, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline 
auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformIAdd : public spv_inst { /* Typed wrapper for Op::GroupNonUniformIAdd; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformIAdd; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformIAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformIAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFAdd : public spv_inst { /* Typed wrapper for Op::GroupNonUniformFAdd with the same member layout: result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 defaulting to std::nullopt. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFAdd; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFAdd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + 
std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFAdd, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformIMul : public spv_inst { /* Typed wrapper for Op::GroupNonUniformIMul; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformIMul; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformIMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformIMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + 
inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMul : public spv_inst { /* Typed wrapper for Op::GroupNonUniformFMul; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMul; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMul(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFMul, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformSMin : public spv_inst { /* Typed wrapper for Op::GroupNonUniformSMin with the same member layout: result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 defaulting to std::nullopt. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformSMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformSMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + 
std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformSMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformUMin : public spv_inst { /* Typed wrapper for Op::GroupNonUniformUMin; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformUMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformUMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformUMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + 
inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMin : public spv_inst { /* Typed wrapper for Op::GroupNonUniformFMin; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMin; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMin(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFMin, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformSMax : public spv_inst { /* Typed wrapper for Op::GroupNonUniformSMax with the same member layout: result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 defaulting to std::nullopt. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformSMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformSMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + 
std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformSMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformUMax : public spv_inst { /* Typed wrapper for Op::GroupNonUniformUMax; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformUMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformUMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformUMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + 
inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformFMax : public spv_inst { /* Typed wrapper for Op::GroupNonUniformFMax; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::GroupNonUniformFMax; } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformFMax(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformFMax, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseAnd : public spv_inst { /* Typed wrapper for Op::GroupNonUniformBitwiseAnd with the same member layout: result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 defaulting to std::nullopt. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseAnd; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseAnd(IdResultType type, IdScope op0, GroupOperation 
op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformBitwiseAnd, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseOr : public spv_inst { /* Typed wrapper for Op::GroupNonUniformBitwiseOr; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseOr; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformBitwiseOr, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto 
op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformBitwiseXor : public spv_inst { /* Typed wrapper for Op::GroupNonUniformBitwiseXor; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformBitwiseXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformBitwiseXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformBitwiseXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalAnd : public spv_inst { /* Typed wrapper for Op::GroupNonUniformLogicalAnd; classof() matches exactly that opcode and required_capabilities lists the same three arithmetic-group entries as the class above. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalAnd; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + 
OpGroupNonUniformLogicalAnd(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformLogicalAnd, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalOr : public spv_inst { /* Typed wrapper for Op::GroupNonUniformLogicalOr; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalOr; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalOr(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformLogicalOr, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { 
return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformLogicalXor : public spv_inst { /* Typed wrapper for Op::GroupNonUniformLogicalXor; classof() matches exactly that opcode. Stores a result type, op0 (IdScope), op1 (GroupOperation), op2 (IdRef) and an optional op3 that defaults to std::nullopt, each with a mutable and a const accessor. NOTE(review): std::optional is written without a template argument in this text -- the element type was presumably lost in extraction; confirm against the generated header. NOTE(review): required_capabilities lists three entries; whether a consumer needs all of them or any one is not visible here. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformLogicalXor; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformArithmetic, Capability::GroupNonUniformClustered, + Capability::GroupNonUniformPartitionedNV}; + OpGroupNonUniformLogicalXor(IdResultType type, IdScope op0, GroupOperation op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::GroupNonUniformLogicalXor, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> GroupOperation & { return op1_; } + inline auto op1() const -> GroupOperation const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdScope op0_; + GroupOperation op1_; + IdRef op2_; + std::optional op3_; +}; +class OpGroupNonUniformQuadBroadcast : public spv_inst { /* Typed wrapper for Op::GroupNonUniformQuadBroadcast; classof() matches exactly that opcode and required_capabilities names a single entry, GroupNonUniformQuad. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformQuadBroadcast; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformQuad}; + 
OpGroupNonUniformQuadBroadcast(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformQuadBroadcast, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpGroupNonUniformQuadSwap : public spv_inst { /* Typed wrapper for Op::GroupNonUniformQuadSwap; classof() matches exactly that opcode. Holds a result type, op0 (IdScope) and op1/op2 (IdRef), each with a mutable and a const accessor. required_capabilities names GroupNonUniformQuad. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm against spv_inst. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::GroupNonUniformQuadSwap; + } + constexpr static std::array required_capabilities = { + Capability::GroupNonUniformQuad}; + OpGroupNonUniformQuadSwap(IdResultType type, IdScope op0, IdRef op1, IdRef op2) + : spv_inst{Op::GroupNonUniformQuadSwap, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + + private: + IdResultType type_; + IdScope op0_; + IdRef op1_; + IdRef op2_; +}; +class OpCopyLogical : public spv_inst { /* Typed wrapper for Op::CopyLogical: a result type plus one operand op0 (IdRef). No required_capabilities member is declared for this class. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::CopyLogical; } + OpCopyLogical(IdResultType type, IdRef op0) + : 
spv_inst{Op::CopyLogical, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpPtrEqual : public spv_inst { /* Typed wrapper for Op::PtrEqual; classof() matches exactly that opcode. Holds a result type and two IdRef operands (op0, op1), each with a mutable and a const accessor. No required_capabilities member is declared. The true passed to spv_inst presumably flags a result-producing instruction -- TODO confirm against spv_inst. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrEqual; } + OpPtrEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpPtrNotEqual : public spv_inst { /* Typed wrapper for Op::PtrNotEqual with the same member layout as OpPtrEqual: result type plus IdRef operands op0 and op1; no required_capabilities member. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrNotEqual; } + OpPtrNotEqual(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrNotEqual, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpPtrDiff : public spv_inst { /* Typed wrapper for Op::PtrDiff: result type plus IdRef operands op0 and op1. Unlike the two pointer comparisons above it declares a required_capabilities array. */ + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::PtrDiff; } + constexpr static std::array required_capabilities = { + 
Capability::Addresses, Capability::VariablePointers, + Capability::VariablePointersStorageBuffer}; + OpPtrDiff(IdResultType type, IdRef op0, IdRef op1) + : spv_inst{Op::PtrDiff, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; +}; +class OpTypeCooperativeMatrixKHR : public spv_inst { /* Typed wrapper for Op::TypeCooperativeMatrixKHR; classof() matches exactly that opcode. It has no type_ member: it stores op0 (IdRef), op1 (IdScope) and op2..op4 (IdRef), each with a mutable and a const accessor. required_capabilities names CooperativeMatrixKHR. NOTE(review): the roles of the five operands are not stated here; check the SPV_KHR_cooperative_matrix specification before relying on their order. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::TypeCooperativeMatrixKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpTypeCooperativeMatrixKHR(IdRef op0, IdScope op1, IdRef op2, IdRef op3, IdRef op4) + : spv_inst{Op::TypeCooperativeMatrixKHR, true}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + + private: + IdRef op0_; + IdScope op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; +}; +class OpCooperativeMatrixLoadKHR : public spv_inst { /* Typed wrapper for Op::CooperativeMatrixLoadKHR; classof() matches exactly that opcode. */ + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLoadKHR; + } + constexpr 
static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixLoadKHR(IdResultType type, IdRef op0, IdRef op1, + std::optional op2 = std::nullopt, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt) + : spv_inst{Op::CooperativeMatrixLoadKHR, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> std::optional & { return op2_; } + inline auto op2() const -> std::optional const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + std::optional op2_; + std::optional op3_; + std::optional op4_; +}; +class OpCooperativeMatrixStoreKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixStoreKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixStoreKHR(IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt, + std::optional op4 = std::nullopt, + std::optional op5 = std::nullopt) + : spv_inst{Op::CooperativeMatrixStoreKHR, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> 
IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + inline auto op4() -> std::optional & { return op4_; } + inline auto op4() const -> std::optional const & { return op4_; } + inline auto op5() -> std::optional & { return op5_; } + inline auto op5() const -> std::optional const & { return op5_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; + std::optional op4_; + std::optional op5_; +}; +class OpCooperativeMatrixMulAddKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixMulAddKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixMulAddKHR(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + std::optional op3 = std::nullopt) + : spv_inst{Op::CooperativeMatrixMulAddKHR, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> std::optional & { return op3_; } + inline auto op3() const -> std::optional const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + std::optional op3_; +}; +class 
OpCooperativeMatrixLengthKHR : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLengthKHR; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixKHR}; + OpCooperativeMatrixLengthKHR(IdResultType type, IdRef op0) + : spv_inst{Op::CooperativeMatrixLengthKHR, true}, type_(std::move(type)), + op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSubgroupBlockReadINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::SubgroupBlockReadINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupBufferBlockIOINTEL}; + OpSubgroupBlockReadINTEL(IdResultType type, IdRef op0) + : spv_inst{Op::SubgroupBlockReadINTEL, true}, type_(std::move(type)), op0_(std::move(op0)) { + } + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpSubgroupBlockWriteINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::SubgroupBlockWriteINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SubgroupBufferBlockIOINTEL}; + OpSubgroupBlockWriteINTEL(IdRef op0, IdRef op1) + : spv_inst{Op::SubgroupBlockWriteINTEL, false}, op0_(std::move(op0)), op1_(std::move(op1)) { + } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + 
inline auto op1() const -> IdRef const & { return op1_; } + + private: + IdRef op0_; + IdRef op1_; +}; +class OpAsmTargetINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AsmTargetINTEL; } + constexpr static std::array required_capabilities = {Capability::AsmINTEL}; + OpAsmTargetINTEL(LiteralString op0) + : spv_inst{Op::AsmTargetINTEL, true}, op0_(std::move(op0)) {} + inline auto op0() -> LiteralString & { return op0_; } + inline auto op0() const -> LiteralString const & { return op0_; } + + private: + LiteralString op0_; +}; +class OpAsmINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AsmINTEL; } + constexpr static std::array required_capabilities = {Capability::AsmINTEL}; + OpAsmINTEL(IdResultType type, IdRef op0, IdRef op1, LiteralString op2, LiteralString op3) + : spv_inst{Op::AsmINTEL, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> LiteralString & { return op2_; } + inline auto op2() const -> LiteralString const & { return op2_; } + inline auto op3() -> LiteralString & { return op3_; } + inline auto op3() const -> LiteralString const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + LiteralString op2_; + LiteralString op3_; +}; +class OpAsmCallINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AsmCallINTEL; } + constexpr static std::array required_capabilities = {Capability::AsmINTEL}; + OpAsmCallINTEL(IdResultType 
type, IdRef op0, std::vector op1) + : spv_inst{Op::AsmCallINTEL, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> std::vector & { return op1_; } + inline auto op1() const -> std::vector const & { return op1_; } + + private: + IdResultType type_; + IdRef op0_; + std::vector op1_; +}; +class OpAtomicFMinEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFMinEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16MinMaxEXT, Capability::AtomicFloat32MinMaxEXT, + Capability::AtomicFloat64MinMaxEXT, Capability::AtomicFloat16VectorNV}; + OpAtomicFMinEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFMinEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicFMaxEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == 
Op::AtomicFMaxEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16MinMaxEXT, Capability::AtomicFloat32MinMaxEXT, + Capability::AtomicFloat64MinMaxEXT, Capability::AtomicFloat16VectorNV}; + OpAtomicFMaxEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFMaxEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpAtomicFAddEXT : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::AtomicFAddEXT; } + constexpr static std::array required_capabilities = { + Capability::AtomicFloat16AddEXT, Capability::AtomicFloat32AddEXT, + Capability::AtomicFloat64AddEXT, Capability::AtomicFloat16VectorNV}; + constexpr static std::array required_extensions = { + "SPV_EXT_shader_atomic_float_add"}; + OpAtomicFAddEXT(IdResultType type, IdRef op0, IdScope op1, IdMemorySemantics op2, IdRef op3) + : spv_inst{Op::AtomicFAddEXT, true}, type_(std::move(type)), op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() 
-> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + + private: + IdResultType type_; + IdRef op0_; + IdScope op1_; + IdMemorySemantics op2_; + IdRef op3_; +}; +class OpConvertFToBF16INTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertFToBF16INTEL; } + constexpr static std::array required_capabilities = { + Capability::BFloat16ConversionINTEL}; + OpConvertFToBF16INTEL(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertFToBF16INTEL, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpConvertBF16ToFINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { return s.opcode() == Op::ConvertBF16ToFINTEL; } + constexpr static std::array required_capabilities = { + Capability::BFloat16ConversionINTEL}; + OpConvertBF16ToFINTEL(IdResultType type, IdRef op0) + : spv_inst{Op::ConvertBF16ToFINTEL, true}, type_(std::move(type)), op0_(std::move(op0)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + + private: + IdResultType type_; + IdRef op0_; +}; +class OpControlBarrierArriveINTEL : public spv_inst { + public: + inline 
static bool classof(spv_inst const &s) { + return s.opcode() == Op::ControlBarrierArriveINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SplitBarrierINTEL}; + OpControlBarrierArriveINTEL(IdScope op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::ControlBarrierArriveINTEL, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdScope op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpControlBarrierWaitINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::ControlBarrierWaitINTEL; + } + constexpr static std::array required_capabilities = { + Capability::SplitBarrierINTEL}; + OpControlBarrierWaitINTEL(IdScope op0, IdScope op1, IdMemorySemantics op2) + : spv_inst{Op::ControlBarrierWaitINTEL, false}, op0_(std::move(op0)), op1_(std::move(op1)), + op2_(std::move(op2)) {} + inline auto op0() -> IdScope & { return op0_; } + inline auto op0() const -> IdScope const & { return op0_; } + inline auto op1() -> IdScope & { return op1_; } + inline auto op1() const -> IdScope const & { return op1_; } + inline auto op2() -> IdMemorySemantics & { return op2_; } + inline auto op2() const -> IdMemorySemantics const & { return op2_; } + + private: + IdScope op0_; + IdScope op1_; + IdMemorySemantics op2_; +}; +class OpCooperativeMatrixLoadCheckedINTEL : public spv_inst { + public: + inline static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixLoadCheckedINTEL; + } + constexpr static std::array required_capabilities = { + 
Capability::CooperativeMatrixCheckedInstructionsINTEL}; + OpCooperativeMatrixLoadCheckedINTEL(IdResultType type, IdRef op0, IdRef op1, IdRef op2, + IdRef op3, IdRef op4, IdRef op5, + std::optional op6 = std::nullopt, + std::optional op7 = std::nullopt, + std::optional op8 = std::nullopt) + : spv_inst{Op::CooperativeMatrixLoadCheckedINTEL, true}, type_(std::move(type)), + op0_(std::move(op0)), op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), + op4_(std::move(op4)), op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), + op8_(std::move(op8)) {} + inline auto type() -> IdResultType & { return type_; } + inline auto type() const -> IdResultType const & { return type_; } + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() -> std::optional & { return op6_; } + inline auto op6() const -> std::optional const & { return op6_; } + inline auto op7() -> std::optional & { return op7_; } + inline auto op7() const -> std::optional const & { return op7_; } + inline auto op8() -> std::optional & { return op8_; } + inline auto op8() const -> std::optional const & { return op8_; } + + private: + IdResultType type_; + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; + std::optional op6_; + std::optional op7_; + std::optional op8_; +}; +class OpCooperativeMatrixStoreCheckedINTEL : public spv_inst { + public: + inline 
static bool classof(spv_inst const &s) { + return s.opcode() == Op::CooperativeMatrixStoreCheckedINTEL; + } + constexpr static std::array required_capabilities = { + Capability::CooperativeMatrixCheckedInstructionsINTEL}; + OpCooperativeMatrixStoreCheckedINTEL(IdRef op0, IdRef op1, IdRef op2, IdRef op3, IdRef op4, + IdRef op5, IdRef op6, + std::optional op7 = std::nullopt, + std::optional op8 = std::nullopt, + std::optional op9 = std::nullopt) + : spv_inst{Op::CooperativeMatrixStoreCheckedINTEL, false}, op0_(std::move(op0)), + op1_(std::move(op1)), op2_(std::move(op2)), op3_(std::move(op3)), op4_(std::move(op4)), + op5_(std::move(op5)), op6_(std::move(op6)), op7_(std::move(op7)), op8_(std::move(op8)), + op9_(std::move(op9)) {} + inline auto op0() -> IdRef & { return op0_; } + inline auto op0() const -> IdRef const & { return op0_; } + inline auto op1() -> IdRef & { return op1_; } + inline auto op1() const -> IdRef const & { return op1_; } + inline auto op2() -> IdRef & { return op2_; } + inline auto op2() const -> IdRef const & { return op2_; } + inline auto op3() -> IdRef & { return op3_; } + inline auto op3() const -> IdRef const & { return op3_; } + inline auto op4() -> IdRef & { return op4_; } + inline auto op4() const -> IdRef const & { return op4_; } + inline auto op5() -> IdRef & { return op5_; } + inline auto op5() const -> IdRef const & { return op5_; } + inline auto op6() -> IdRef & { return op6_; } + inline auto op6() const -> IdRef const & { return op6_; } + inline auto op7() -> std::optional & { return op7_; } + inline auto op7() const -> std::optional const & { return op7_; } + inline auto op8() -> std::optional & { return op8_; } + inline auto op8() const -> std::optional const & { return op8_; } + inline auto op9() -> std::optional & { return op9_; } + inline auto op9() const -> std::optional const & { return op9_; } + + private: + IdRef op0_; + IdRef op1_; + IdRef op2_; + IdRef op3_; + IdRef op4_; + IdRef op5_; + IdRef op6_; + std::optional op7_; + 
std::optional op8_; + std::optional op9_; +}; + +} // namespace tinytc::spv + +#endif // GENERATED_INSTRUCTIONS_20250630_HPP diff --git a/src/spv/lut.hpp b/src/spv/lut.hpp new file mode 100644 index 00000000..4af4b86b --- /dev/null +++ b/src/spv/lut.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef LUT_20241216_HPP +#define LUT_20241216_HPP + +namespace tinytc::spv { + +class spv_inst; + +template +auto lookup(Map &map, Key &&key, Maker &&maker) { + auto it = map.find(key); + if (it == map.end()) { + map[key] = maker(key); + return map[key]; + } + return it->second; +} +template auto lookup(spv_inst *&var, Maker &&maker) -> spv_inst * { + if (!var) { + var = maker(); + } + return var; +} + +} // namespace tinytc::spv + +#endif // LUT_20241216_HPP diff --git a/src/spv/matrix_walker.cpp b/src/spv/matrix_walker.cpp new file mode 100644 index 00000000..24a78b93 --- /dev/null +++ b/src/spv/matrix_walker.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/matrix_walker.hpp" +#include "coopmatrix_layout.hpp" +#include "node/type.hpp" +#include "spv/converter_aux.hpp" +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "tinytc/types.hpp" + +namespace tinytc::spv { + +matrix_walker::matrix_walker(uniquifier &unique, std::int32_t sgs, coopmatrix_layout const &layout, + spv_inst *pos0, spv_inst *pos1, spv_inst *shape0, spv_inst *shape1, + spv_inst *stride0, spv_inst *stride1, checked_flag chk, + std::int32_t constant_p) + : unique_{unique}, layout_{layout}, chk_{chk} { + index_ty_ = get_spv_index_ty(unique, layout.sty->context()); + + auto &mod = unique_.mod(); + auto crows = unique.constant(layout_.rows); + row_inc_ = mod.add(index_ty_, crows, stride0); + col_inc_factor_ = sgs / layout.rows; + col_inc_ = mod.add(index_ty_, 
unique.constant(col_inc_factor_), stride1); + + spv_inst *p = constant_p >= 0 ? unique.constant(constant_p) + : unique.load_builtin(BuiltIn::SubgroupLocalInvocationId); + p = mod.add(index_ty_, p); + + row_ = layout.rows < sgs ? mod.add(index_ty_, p, crows) : p; + row_ = mod.add(index_ty_, row_, pos0); + row_ = mod.add(index_ty_, row_, stride0); + + auto c0 = unique.null_constant(index_ty_); + col0_ = layout.rows < sgs ? mod.add(index_ty_, p, crows) : c0; + col0_ = mod.add(index_ty_, col0_, pos1); + col0_ = mod.add(index_ty_, col0_, stride1); + col_ = col0_; + + if (rows_checked()) { + row_max_ = mod.add(index_ty_, shape0, stride0); + } + if (may_need_mask() || cols_checked()) { + col_max_ = mod.add(index_ty_, shape1, stride1); + } +} + +void matrix_walker::advance_block() { + col_ = col0_; + col_no_ = 0; + row_ = unique_.mod().add(index_ty_, row_, row_inc_); + ++block_no_; +} +void matrix_walker::advance_column() { + col_ = unique_.mod().add(index_ty_, col_, col_inc_); + ++col_no_; +} + +auto matrix_walker::component_no(std::int32_t col_no) const -> std::int32_t { + return layout_.component_no(col_no, block_no_); +} +auto matrix_walker::component_no() const -> std::int32_t { return component_no(col_no_); } +auto matrix_walker::offset() const -> spv_inst * { + return unique_.mod().add(index_ty_, row_, col_); +} +auto matrix_walker::rows_checked() const -> bool { + return chk_ == checked_flag::both || chk_ == checked_flag::rows; +} +auto matrix_walker::cols_checked() const -> bool { + return chk_ == checked_flag::both || chk_ == checked_flag::cols; +} +auto matrix_walker::needs_mask() const -> bool { + return (col_no_ + 1) * col_inc_factor_ > layout_.shape1; +} +auto matrix_walker::may_need_mask() const -> bool { + return layout_.cols * col_inc_factor_ > layout_.shape1; +} + +auto matrix_walker::col_ok() const -> spv_inst * { + auto c0 = unique_.null_constant(index_ty_); + auto bool_ty = unique_.bool_ty(); + auto &mod = unique_.mod(); + auto check1 = 
mod.add(bool_ty, c0, col_); + auto check2 = mod.add(bool_ty, col_, col_max_); + return mod.add(bool_ty, check1, check2); +} +auto matrix_walker::row_ok() const -> spv_inst * { + auto c0 = unique_.null_constant(index_ty_); + auto bool_ty = unique_.bool_ty(); + auto &mod = unique_.mod(); + auto check1 = mod.add(bool_ty, c0, row_); + auto check2 = mod.add(bool_ty, row_, row_max_); + return mod.add(bool_ty, check1, check2); +} + +} // namespace tinytc::spv diff --git a/src/spv/matrix_walker.hpp b/src/spv/matrix_walker.hpp new file mode 100644 index 00000000..d4b36579 --- /dev/null +++ b/src/spv/matrix_walker.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef MATRIX_WALKER_20250428_HPP +#define MATRIX_WALKER_20250428_HPP + +#include + +namespace tinytc { +enum class checked_flag; +struct coopmatrix_layout; +} // namespace tinytc + +namespace tinytc::spv { +class spv_inst; +class uniquifier; + +class matrix_walker { + public: + matrix_walker(uniquifier &unique, std::int32_t sgs, coopmatrix_layout const &layout, + spv_inst *pos0, spv_inst *pos1, spv_inst *shape0, spv_inst *shape1, + spv_inst *stride0, spv_inst *stride1, checked_flag chk, + std::int32_t constant_p = -1); + + void advance_block(); + void advance_column(); + + auto component_no(std::int32_t col_no) const -> std::int32_t; + auto component_no() const -> std::int32_t; + auto offset() const -> spv_inst *; + auto rows_checked() const -> bool; + auto cols_checked() const -> bool; + auto needs_mask() const -> bool; + auto may_need_mask() const -> bool; + auto col_ok() const -> spv_inst *; + auto row_ok() const -> spv_inst *; + inline auto block_no() const { return block_no_; } + inline auto col_no() const { return col_no_; } + + private: + uniquifier &unique_; + coopmatrix_layout const &layout_; + checked_flag chk_; + spv_inst *index_ty_; + spv_inst *row_inc_; + std::int64_t col_inc_factor_; + spv_inst *col_inc_; + spv_inst *row_; + spv_inst *col0_; + 
spv_inst *col_; + spv_inst *row_max_ = nullptr; + spv_inst *col_max_ = nullptr; + std::int32_t block_no_ = 0; + std::int32_t col_no_ = 0; +}; + +} // namespace tinytc::spv + +#endif // MATRIX_WALKER_20250428_HPP diff --git a/src/spv/module.cpp b/src/spv/module.cpp new file mode 100644 index 00000000..eed0df09 --- /dev/null +++ b/src/spv/module.cpp @@ -0,0 +1,106 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "tinytc/types.h" + +#include "error.hpp" +#include "spv/defs.hpp" +#include "spv/module.hpp" +#include "spv/pass/dump_asm.hpp" +#include "tinytc/core.h" +#include "tinytc/types.hpp" +#include "util/ilist_base.hpp" + +#include +#include +#include +#include +#include +#include + +namespace tinytc { +void ilist_callbacks::node_added(spv::spv_inst *) {} +void ilist_callbacks::node_moved(spv::spv_inst *) {} +void ilist_callbacks::node_removed(spv::spv_inst *node) { delete node; } +} // namespace tinytc + +using namespace tinytc; + +tinytc_spv_mod::tinytc_spv_mod(shared_handle ctx, + tinytc_core_feature_flags_t core_features, + std::int32_t major_version, std::int32_t minor_version) + : ctx_{std::move(ctx)}, core_features_(core_features), major_version_{major_version}, + minor_version_{minor_version} {} +tinytc_spv_mod::~tinytc_spv_mod() {} + +auto tinytc_spv_mod::bound() const -> std::uint32_t { + std::uint32_t bnd = 0; + for (auto const &sec : insts_) { + for (auto const &i : sec) { + if (i.has_result_id()) { + bnd = std::max(bnd, i.id()); + } + } + } + return bnd + 1; +} + +extern "C" { + +tinytc_status_t tinytc_spv_mod_dump(const_tinytc_spv_mod_t mod) { + if (mod == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { spv::dump_asm_pass{std::cerr}.run_on_module(*mod); }); +} + +tinytc_status_t tinytc_spv_mod_print_to_file(const_tinytc_spv_mod_t mod, char const *filename) { + if (mod == nullptr || filename == nullptr) { + return tinytc_status_invalid_arguments; + } + 
return exception_to_status_code([&] { + auto stream = std::ofstream(filename); + if (!stream.good()) { + throw status::file_io_error; + } + spv::dump_asm_pass{stream}.run_on_module(*mod); + }); +} + +tinytc_status_t tinytc_spv_mod_print_to_string(const_tinytc_spv_mod_t mod, char **str) { + if (mod == nullptr || str == nullptr) { + return tinytc_status_invalid_arguments; + } + return exception_to_status_code([&] { + auto const text = [&] { + auto oss = std::ostringstream{}; + spv::dump_asm_pass{oss}.run_on_module(*mod); + return std::move(oss).str(); + }(); + auto const length = text.size() + 1; // Need to include terminating null character + *str = (char *)malloc(length * sizeof(char)); + // Check the allocation result (*str); the out-parameter str was already validated above + if (!*str) { + throw status::bad_alloc; + } + std::strncpy(*str, text.c_str(), length); + }); +} +tinytc_status_t tinytc_spv_mod_release(tinytc_spv_mod_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + auto ref_count = obj->dec_ref(); + if (ref_count == 0) { + delete obj; + } + return tinytc_status_success; +} + +tinytc_status_t tinytc_spv_mod_retain(tinytc_spv_mod_t obj) { + if (obj == nullptr) { + return tinytc_status_invalid_arguments; + } + obj->inc_ref(); + return tinytc_status_success; +} +} diff --git a/src/spv/module.hpp b/src/spv/module.hpp new file mode 100644 index 00000000..64b0c9fa --- /dev/null +++ b/src/spv/module.hpp @@ -0,0 +1,97 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef MODULE_20241029_HPP +#define MODULE_20241029_HPP + +#include "reference_counted.hpp" +#include "spv/defs.hpp" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/ilist.hpp" + +#include +#include +#include +#include + +namespace tinytc { + +template <> struct ilist_callbacks { + void node_added(spv::spv_inst *node); + void node_moved(spv::spv_inst *node); + void node_removed(spv::spv_inst *node); +}; + +namespace spv { + +enum class section { + capability = 0, + extension = 1, + ext_inst = 2, + 
memory_model = 3, + entry_point = 4, + execution_mode = 5, + decoration = 6, + type_const_var = 7, + function = 8 +}; +inline constexpr std::int32_t num_module_sections = 9; + +} // namespace spv +} // namespace tinytc + +struct tinytc_spv_mod final : tinytc::reference_counted { + public: + using iterator = tinytc::ilist::iterator; + using const_iterator = tinytc::ilist::const_iterator; + + tinytc_spv_mod(tinytc::shared_handle ctx, + tinytc_core_feature_flags_t core_features, std::int32_t major_version = 1, + std::int32_t minor_version = 6); + ~tinytc_spv_mod(); + + inline auto context() const -> tinytc_compiler_context_t { return ctx_.get(); } + inline auto share_context() const -> tinytc::shared_handle { + return ctx_; + } + inline auto core_features() const -> tinytc_core_feature_flags_t { return core_features_; } + + auto bound() const -> std::uint32_t; + + inline auto insts(tinytc::spv::section s) -> tinytc::ilist & { + return insts_[static_cast(s)]; + } + inline auto insts(tinytc::spv::section s) const + -> tinytc::ilist const & { + return insts_[static_cast(s)]; + } + inline auto empty(tinytc::spv::section s) const -> bool { + return insts_[static_cast(s)].empty(); + } + + inline auto major_version() const -> std::int32_t { return major_version_; } + inline auto minor_version() const -> std::int32_t { return minor_version_; } + + template + auto add_to(tinytc::spv::section s, Args &&...args) -> T * { + auto ptr = std::make_unique(std::forward(args)...).release(); + insts(s).push_back(ptr); + return ptr; + } + template auto add(Args &&...args) -> T * { + return add_to(tinytc::spv::section::function, std::forward(args)...); + } + + private: + tinytc::shared_handle ctx_; + tinytc_core_feature_flags_t core_features_; + std::array, tinytc::spv::num_module_sections> insts_; + std::int32_t major_version_, minor_version_; +}; + +namespace tinytc::spv { +// using mod = ::tinytc_spv_mod; +} // namespace tinytc::spv + +#endif // MODULE_20241029_HPP diff --git 
a/src/spv/names.cpp b/src/spv/names.cpp new file mode 100644 index 00000000..4f82ceeb --- /dev/null +++ b/src/spv/names.cpp @@ -0,0 +1,3125 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "names.hpp" +#include "enums.hpp" + +namespace tinytc::spv { + +auto to_string(Op op) -> char const * { + switch (op) { + case Op::Nop: + return "Nop"; + case Op::Undef: + return "Undef"; + case Op::SourceContinued: + return "SourceContinued"; + case Op::Source: + return "Source"; + case Op::SourceExtension: + return "SourceExtension"; + case Op::Name: + return "Name"; + case Op::MemberName: + return "MemberName"; + case Op::String: + return "String"; + case Op::Line: + return "Line"; + case Op::Extension: + return "Extension"; + case Op::ExtInstImport: + return "ExtInstImport"; + case Op::ExtInst: + return "ExtInst"; + case Op::MemoryModel: + return "MemoryModel"; + case Op::EntryPoint: + return "EntryPoint"; + case Op::ExecutionMode: + return "ExecutionMode"; + case Op::Capability: + return "Capability"; + case Op::TypeVoid: + return "TypeVoid"; + case Op::TypeBool: + return "TypeBool"; + case Op::TypeInt: + return "TypeInt"; + case Op::TypeFloat: + return "TypeFloat"; + case Op::TypeVector: + return "TypeVector"; + case Op::TypeMatrix: + return "TypeMatrix"; + case Op::TypeImage: + return "TypeImage"; + case Op::TypeSampler: + return "TypeSampler"; + case Op::TypeSampledImage: + return "TypeSampledImage"; + case Op::TypeArray: + return "TypeArray"; + case Op::TypeRuntimeArray: + return "TypeRuntimeArray"; + case Op::TypeStruct: + return "TypeStruct"; + case Op::TypeOpaque: + return "TypeOpaque"; + case Op::TypePointer: + return "TypePointer"; + case Op::TypeFunction: + return "TypeFunction"; + case Op::TypeEvent: + return "TypeEvent"; + case Op::TypeDeviceEvent: + return "TypeDeviceEvent"; + case Op::TypeReserveId: + return "TypeReserveId"; + case Op::TypeQueue: + return 
"TypeQueue"; + case Op::TypePipe: + return "TypePipe"; + case Op::TypeForwardPointer: + return "TypeForwardPointer"; + case Op::ConstantTrue: + return "ConstantTrue"; + case Op::ConstantFalse: + return "ConstantFalse"; + case Op::Constant: + return "Constant"; + case Op::ConstantComposite: + return "ConstantComposite"; + case Op::ConstantSampler: + return "ConstantSampler"; + case Op::ConstantNull: + return "ConstantNull"; + case Op::Function: + return "Function"; + case Op::FunctionParameter: + return "FunctionParameter"; + case Op::FunctionEnd: + return "FunctionEnd"; + case Op::FunctionCall: + return "FunctionCall"; + case Op::Variable: + return "Variable"; + case Op::ImageTexelPointer: + return "ImageTexelPointer"; + case Op::Load: + return "Load"; + case Op::Store: + return "Store"; + case Op::CopyMemory: + return "CopyMemory"; + case Op::CopyMemorySized: + return "CopyMemorySized"; + case Op::AccessChain: + return "AccessChain"; + case Op::InBoundsAccessChain: + return "InBoundsAccessChain"; + case Op::PtrAccessChain: + return "PtrAccessChain"; + case Op::ArrayLength: + return "ArrayLength"; + case Op::GenericPtrMemSemantics: + return "GenericPtrMemSemantics"; + case Op::InBoundsPtrAccessChain: + return "InBoundsPtrAccessChain"; + case Op::Decorate: + return "Decorate"; + case Op::MemberDecorate: + return "MemberDecorate"; + case Op::DecorationGroup: + return "DecorationGroup"; + case Op::GroupDecorate: + return "GroupDecorate"; + case Op::GroupMemberDecorate: + return "GroupMemberDecorate"; + case Op::VectorExtractDynamic: + return "VectorExtractDynamic"; + case Op::VectorInsertDynamic: + return "VectorInsertDynamic"; + case Op::VectorShuffle: + return "VectorShuffle"; + case Op::CompositeConstruct: + return "CompositeConstruct"; + case Op::CompositeExtract: + return "CompositeExtract"; + case Op::CompositeInsert: + return "CompositeInsert"; + case Op::CopyObject: + return "CopyObject"; + case Op::Transpose: + return "Transpose"; + case Op::SampledImage: + 
return "SampledImage"; + case Op::ImageSampleImplicitLod: + return "ImageSampleImplicitLod"; + case Op::ImageSampleExplicitLod: + return "ImageSampleExplicitLod"; + case Op::ImageSampleDrefImplicitLod: + return "ImageSampleDrefImplicitLod"; + case Op::ImageSampleDrefExplicitLod: + return "ImageSampleDrefExplicitLod"; + case Op::ImageSampleProjImplicitLod: + return "ImageSampleProjImplicitLod"; + case Op::ImageSampleProjExplicitLod: + return "ImageSampleProjExplicitLod"; + case Op::ImageSampleProjDrefImplicitLod: + return "ImageSampleProjDrefImplicitLod"; + case Op::ImageSampleProjDrefExplicitLod: + return "ImageSampleProjDrefExplicitLod"; + case Op::ImageFetch: + return "ImageFetch"; + case Op::ImageGather: + return "ImageGather"; + case Op::ImageDrefGather: + return "ImageDrefGather"; + case Op::ImageRead: + return "ImageRead"; + case Op::ImageWrite: + return "ImageWrite"; + case Op::Image: + return "Image"; + case Op::ImageQueryFormat: + return "ImageQueryFormat"; + case Op::ImageQueryOrder: + return "ImageQueryOrder"; + case Op::ImageQuerySizeLod: + return "ImageQuerySizeLod"; + case Op::ImageQuerySize: + return "ImageQuerySize"; + case Op::ImageQueryLod: + return "ImageQueryLod"; + case Op::ImageQueryLevels: + return "ImageQueryLevels"; + case Op::ImageQuerySamples: + return "ImageQuerySamples"; + case Op::ConvertFToU: + return "ConvertFToU"; + case Op::ConvertFToS: + return "ConvertFToS"; + case Op::ConvertSToF: + return "ConvertSToF"; + case Op::ConvertUToF: + return "ConvertUToF"; + case Op::UConvert: + return "UConvert"; + case Op::SConvert: + return "SConvert"; + case Op::FConvert: + return "FConvert"; + case Op::QuantizeToF16: + return "QuantizeToF16"; + case Op::ConvertPtrToU: + return "ConvertPtrToU"; + case Op::SatConvertSToU: + return "SatConvertSToU"; + case Op::SatConvertUToS: + return "SatConvertUToS"; + case Op::ConvertUToPtr: + return "ConvertUToPtr"; + case Op::PtrCastToGeneric: + return "PtrCastToGeneric"; + case Op::GenericCastToPtr: + return 
"GenericCastToPtr"; + case Op::GenericCastToPtrExplicit: + return "GenericCastToPtrExplicit"; + case Op::Bitcast: + return "Bitcast"; + case Op::SNegate: + return "SNegate"; + case Op::FNegate: + return "FNegate"; + case Op::IAdd: + return "IAdd"; + case Op::FAdd: + return "FAdd"; + case Op::ISub: + return "ISub"; + case Op::FSub: + return "FSub"; + case Op::IMul: + return "IMul"; + case Op::FMul: + return "FMul"; + case Op::UDiv: + return "UDiv"; + case Op::SDiv: + return "SDiv"; + case Op::FDiv: + return "FDiv"; + case Op::UMod: + return "UMod"; + case Op::SRem: + return "SRem"; + case Op::SMod: + return "SMod"; + case Op::FRem: + return "FRem"; + case Op::FMod: + return "FMod"; + case Op::VectorTimesScalar: + return "VectorTimesScalar"; + case Op::MatrixTimesScalar: + return "MatrixTimesScalar"; + case Op::VectorTimesMatrix: + return "VectorTimesMatrix"; + case Op::MatrixTimesVector: + return "MatrixTimesVector"; + case Op::MatrixTimesMatrix: + return "MatrixTimesMatrix"; + case Op::OuterProduct: + return "OuterProduct"; + case Op::Dot: + return "Dot"; + case Op::IAddCarry: + return "IAddCarry"; + case Op::ISubBorrow: + return "ISubBorrow"; + case Op::UMulExtended: + return "UMulExtended"; + case Op::SMulExtended: + return "SMulExtended"; + case Op::Any: + return "Any"; + case Op::All: + return "All"; + case Op::IsNan: + return "IsNan"; + case Op::IsInf: + return "IsInf"; + case Op::IsFinite: + return "IsFinite"; + case Op::IsNormal: + return "IsNormal"; + case Op::SignBitSet: + return "SignBitSet"; + case Op::LessOrGreater: + return "LessOrGreater"; + case Op::Ordered: + return "Ordered"; + case Op::Unordered: + return "Unordered"; + case Op::LogicalEqual: + return "LogicalEqual"; + case Op::LogicalNotEqual: + return "LogicalNotEqual"; + case Op::LogicalOr: + return "LogicalOr"; + case Op::LogicalAnd: + return "LogicalAnd"; + case Op::LogicalNot: + return "LogicalNot"; + case Op::Select: + return "Select"; + case Op::IEqual: + return "IEqual"; + case 
Op::INotEqual: + return "INotEqual"; + case Op::UGreaterThan: + return "UGreaterThan"; + case Op::SGreaterThan: + return "SGreaterThan"; + case Op::UGreaterThanEqual: + return "UGreaterThanEqual"; + case Op::SGreaterThanEqual: + return "SGreaterThanEqual"; + case Op::ULessThan: + return "ULessThan"; + case Op::SLessThan: + return "SLessThan"; + case Op::ULessThanEqual: + return "ULessThanEqual"; + case Op::SLessThanEqual: + return "SLessThanEqual"; + case Op::FOrdEqual: + return "FOrdEqual"; + case Op::FUnordEqual: + return "FUnordEqual"; + case Op::FOrdNotEqual: + return "FOrdNotEqual"; + case Op::FUnordNotEqual: + return "FUnordNotEqual"; + case Op::FOrdLessThan: + return "FOrdLessThan"; + case Op::FUnordLessThan: + return "FUnordLessThan"; + case Op::FOrdGreaterThan: + return "FOrdGreaterThan"; + case Op::FUnordGreaterThan: + return "FUnordGreaterThan"; + case Op::FOrdLessThanEqual: + return "FOrdLessThanEqual"; + case Op::FUnordLessThanEqual: + return "FUnordLessThanEqual"; + case Op::FOrdGreaterThanEqual: + return "FOrdGreaterThanEqual"; + case Op::FUnordGreaterThanEqual: + return "FUnordGreaterThanEqual"; + case Op::ShiftRightLogical: + return "ShiftRightLogical"; + case Op::ShiftRightArithmetic: + return "ShiftRightArithmetic"; + case Op::ShiftLeftLogical: + return "ShiftLeftLogical"; + case Op::BitwiseOr: + return "BitwiseOr"; + case Op::BitwiseXor: + return "BitwiseXor"; + case Op::BitwiseAnd: + return "BitwiseAnd"; + case Op::Not: + return "Not"; + case Op::BitFieldInsert: + return "BitFieldInsert"; + case Op::BitFieldSExtract: + return "BitFieldSExtract"; + case Op::BitFieldUExtract: + return "BitFieldUExtract"; + case Op::BitReverse: + return "BitReverse"; + case Op::BitCount: + return "BitCount"; + case Op::DPdx: + return "DPdx"; + case Op::DPdy: + return "DPdy"; + case Op::Fwidth: + return "Fwidth"; + case Op::DPdxFine: + return "DPdxFine"; + case Op::DPdyFine: + return "DPdyFine"; + case Op::FwidthFine: + return "FwidthFine"; + case Op::DPdxCoarse: + 
return "DPdxCoarse"; + case Op::DPdyCoarse: + return "DPdyCoarse"; + case Op::FwidthCoarse: + return "FwidthCoarse"; + case Op::EmitVertex: + return "EmitVertex"; + case Op::EndPrimitive: + return "EndPrimitive"; + case Op::EmitStreamVertex: + return "EmitStreamVertex"; + case Op::EndStreamPrimitive: + return "EndStreamPrimitive"; + case Op::ControlBarrier: + return "ControlBarrier"; + case Op::MemoryBarrier: + return "MemoryBarrier"; + case Op::AtomicLoad: + return "AtomicLoad"; + case Op::AtomicStore: + return "AtomicStore"; + case Op::AtomicExchange: + return "AtomicExchange"; + case Op::AtomicCompareExchange: + return "AtomicCompareExchange"; + case Op::AtomicCompareExchangeWeak: + return "AtomicCompareExchangeWeak"; + case Op::AtomicIIncrement: + return "AtomicIIncrement"; + case Op::AtomicIDecrement: + return "AtomicIDecrement"; + case Op::AtomicIAdd: + return "AtomicIAdd"; + case Op::AtomicISub: + return "AtomicISub"; + case Op::AtomicSMin: + return "AtomicSMin"; + case Op::AtomicUMin: + return "AtomicUMin"; + case Op::AtomicSMax: + return "AtomicSMax"; + case Op::AtomicUMax: + return "AtomicUMax"; + case Op::AtomicAnd: + return "AtomicAnd"; + case Op::AtomicOr: + return "AtomicOr"; + case Op::AtomicXor: + return "AtomicXor"; + case Op::Phi: + return "Phi"; + case Op::LoopMerge: + return "LoopMerge"; + case Op::SelectionMerge: + return "SelectionMerge"; + case Op::Label: + return "Label"; + case Op::Branch: + return "Branch"; + case Op::BranchConditional: + return "BranchConditional"; + case Op::Switch: + return "Switch"; + case Op::Kill: + return "Kill"; + case Op::Return: + return "Return"; + case Op::ReturnValue: + return "ReturnValue"; + case Op::Unreachable: + return "Unreachable"; + case Op::LifetimeStart: + return "LifetimeStart"; + case Op::LifetimeStop: + return "LifetimeStop"; + case Op::GroupAsyncCopy: + return "GroupAsyncCopy"; + case Op::GroupWaitEvents: + return "GroupWaitEvents"; + case Op::GroupAll: + return "GroupAll"; + case Op::GroupAny: + 
return "GroupAny"; + case Op::GroupBroadcast: + return "GroupBroadcast"; + case Op::GroupIAdd: + return "GroupIAdd"; + case Op::GroupFAdd: + return "GroupFAdd"; + case Op::GroupFMin: + return "GroupFMin"; + case Op::GroupUMin: + return "GroupUMin"; + case Op::GroupSMin: + return "GroupSMin"; + case Op::GroupFMax: + return "GroupFMax"; + case Op::GroupUMax: + return "GroupUMax"; + case Op::GroupSMax: + return "GroupSMax"; + case Op::ReadPipe: + return "ReadPipe"; + case Op::WritePipe: + return "WritePipe"; + case Op::ReservedReadPipe: + return "ReservedReadPipe"; + case Op::ReservedWritePipe: + return "ReservedWritePipe"; + case Op::ReserveReadPipePackets: + return "ReserveReadPipePackets"; + case Op::ReserveWritePipePackets: + return "ReserveWritePipePackets"; + case Op::CommitReadPipe: + return "CommitReadPipe"; + case Op::CommitWritePipe: + return "CommitWritePipe"; + case Op::IsValidReserveId: + return "IsValidReserveId"; + case Op::GetNumPipePackets: + return "GetNumPipePackets"; + case Op::GetMaxPipePackets: + return "GetMaxPipePackets"; + case Op::GroupReserveReadPipePackets: + return "GroupReserveReadPipePackets"; + case Op::GroupReserveWritePipePackets: + return "GroupReserveWritePipePackets"; + case Op::GroupCommitReadPipe: + return "GroupCommitReadPipe"; + case Op::GroupCommitWritePipe: + return "GroupCommitWritePipe"; + case Op::EnqueueMarker: + return "EnqueueMarker"; + case Op::EnqueueKernel: + return "EnqueueKernel"; + case Op::GetKernelNDrangeSubGroupCount: + return "GetKernelNDrangeSubGroupCount"; + case Op::GetKernelNDrangeMaxSubGroupSize: + return "GetKernelNDrangeMaxSubGroupSize"; + case Op::GetKernelWorkGroupSize: + return "GetKernelWorkGroupSize"; + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return "GetKernelPreferredWorkGroupSizeMultiple"; + case Op::RetainEvent: + return "RetainEvent"; + case Op::ReleaseEvent: + return "ReleaseEvent"; + case Op::CreateUserEvent: + return "CreateUserEvent"; + case Op::IsValidEvent: + return 
"IsValidEvent"; + case Op::SetUserEventStatus: + return "SetUserEventStatus"; + case Op::CaptureEventProfilingInfo: + return "CaptureEventProfilingInfo"; + case Op::GetDefaultQueue: + return "GetDefaultQueue"; + case Op::BuildNDRange: + return "BuildNDRange"; + case Op::ImageSparseSampleImplicitLod: + return "ImageSparseSampleImplicitLod"; + case Op::ImageSparseSampleExplicitLod: + return "ImageSparseSampleExplicitLod"; + case Op::ImageSparseSampleDrefImplicitLod: + return "ImageSparseSampleDrefImplicitLod"; + case Op::ImageSparseSampleDrefExplicitLod: + return "ImageSparseSampleDrefExplicitLod"; + case Op::ImageSparseSampleProjImplicitLod: + return "ImageSparseSampleProjImplicitLod"; + case Op::ImageSparseSampleProjExplicitLod: + return "ImageSparseSampleProjExplicitLod"; + case Op::ImageSparseSampleProjDrefImplicitLod: + return "ImageSparseSampleProjDrefImplicitLod"; + case Op::ImageSparseSampleProjDrefExplicitLod: + return "ImageSparseSampleProjDrefExplicitLod"; + case Op::ImageSparseFetch: + return "ImageSparseFetch"; + case Op::ImageSparseGather: + return "ImageSparseGather"; + case Op::ImageSparseDrefGather: + return "ImageSparseDrefGather"; + case Op::ImageSparseTexelsResident: + return "ImageSparseTexelsResident"; + case Op::NoLine: + return "NoLine"; + case Op::AtomicFlagTestAndSet: + return "AtomicFlagTestAndSet"; + case Op::AtomicFlagClear: + return "AtomicFlagClear"; + case Op::ImageSparseRead: + return "ImageSparseRead"; + case Op::SizeOf: + return "SizeOf"; + case Op::TypePipeStorage: + return "TypePipeStorage"; + case Op::ConstantPipeStorage: + return "ConstantPipeStorage"; + case Op::CreatePipeFromPipeStorage: + return "CreatePipeFromPipeStorage"; + case Op::GetKernelLocalSizeForSubgroupCount: + return "GetKernelLocalSizeForSubgroupCount"; + case Op::GetKernelMaxNumSubgroups: + return "GetKernelMaxNumSubgroups"; + case Op::TypeNamedBarrier: + return "TypeNamedBarrier"; + case Op::NamedBarrierInitialize: + return "NamedBarrierInitialize"; + case 
Op::MemoryNamedBarrier: + return "MemoryNamedBarrier"; + case Op::ModuleProcessed: + return "ModuleProcessed"; + case Op::ExecutionModeId: + return "ExecutionModeId"; + case Op::DecorateId: + return "DecorateId"; + case Op::GroupNonUniformElect: + return "GroupNonUniformElect"; + case Op::GroupNonUniformAll: + return "GroupNonUniformAll"; + case Op::GroupNonUniformAny: + return "GroupNonUniformAny"; + case Op::GroupNonUniformAllEqual: + return "GroupNonUniformAllEqual"; + case Op::GroupNonUniformBroadcast: + return "GroupNonUniformBroadcast"; + case Op::GroupNonUniformBroadcastFirst: + return "GroupNonUniformBroadcastFirst"; + case Op::GroupNonUniformBallot: + return "GroupNonUniformBallot"; + case Op::GroupNonUniformInverseBallot: + return "GroupNonUniformInverseBallot"; + case Op::GroupNonUniformBallotBitExtract: + return "GroupNonUniformBallotBitExtract"; + case Op::GroupNonUniformBallotBitCount: + return "GroupNonUniformBallotBitCount"; + case Op::GroupNonUniformBallotFindLSB: + return "GroupNonUniformBallotFindLSB"; + case Op::GroupNonUniformBallotFindMSB: + return "GroupNonUniformBallotFindMSB"; + case Op::GroupNonUniformShuffle: + return "GroupNonUniformShuffle"; + case Op::GroupNonUniformShuffleXor: + return "GroupNonUniformShuffleXor"; + case Op::GroupNonUniformShuffleUp: + return "GroupNonUniformShuffleUp"; + case Op::GroupNonUniformShuffleDown: + return "GroupNonUniformShuffleDown"; + case Op::GroupNonUniformIAdd: + return "GroupNonUniformIAdd"; + case Op::GroupNonUniformFAdd: + return "GroupNonUniformFAdd"; + case Op::GroupNonUniformIMul: + return "GroupNonUniformIMul"; + case Op::GroupNonUniformFMul: + return "GroupNonUniformFMul"; + case Op::GroupNonUniformSMin: + return "GroupNonUniformSMin"; + case Op::GroupNonUniformUMin: + return "GroupNonUniformUMin"; + case Op::GroupNonUniformFMin: + return "GroupNonUniformFMin"; + case Op::GroupNonUniformSMax: + return "GroupNonUniformSMax"; + case Op::GroupNonUniformUMax: + return "GroupNonUniformUMax"; + case 
Op::GroupNonUniformFMax: + return "GroupNonUniformFMax"; + case Op::GroupNonUniformBitwiseAnd: + return "GroupNonUniformBitwiseAnd"; + case Op::GroupNonUniformBitwiseOr: + return "GroupNonUniformBitwiseOr"; + case Op::GroupNonUniformBitwiseXor: + return "GroupNonUniformBitwiseXor"; + case Op::GroupNonUniformLogicalAnd: + return "GroupNonUniformLogicalAnd"; + case Op::GroupNonUniformLogicalOr: + return "GroupNonUniformLogicalOr"; + case Op::GroupNonUniformLogicalXor: + return "GroupNonUniformLogicalXor"; + case Op::GroupNonUniformQuadBroadcast: + return "GroupNonUniformQuadBroadcast"; + case Op::GroupNonUniformQuadSwap: + return "GroupNonUniformQuadSwap"; + case Op::CopyLogical: + return "CopyLogical"; + case Op::PtrEqual: + return "PtrEqual"; + case Op::PtrNotEqual: + return "PtrNotEqual"; + case Op::PtrDiff: + return "PtrDiff"; + case Op::TypeCooperativeMatrixKHR: + return "TypeCooperativeMatrixKHR"; + case Op::CooperativeMatrixLoadKHR: + return "CooperativeMatrixLoadKHR"; + case Op::CooperativeMatrixStoreKHR: + return "CooperativeMatrixStoreKHR"; + case Op::CooperativeMatrixMulAddKHR: + return "CooperativeMatrixMulAddKHR"; + case Op::CooperativeMatrixLengthKHR: + return "CooperativeMatrixLengthKHR"; + case Op::SubgroupBlockReadINTEL: + return "SubgroupBlockReadINTEL"; + case Op::SubgroupBlockWriteINTEL: + return "SubgroupBlockWriteINTEL"; + case Op::AsmTargetINTEL: + return "AsmTargetINTEL"; + case Op::AsmINTEL: + return "AsmINTEL"; + case Op::AsmCallINTEL: + return "AsmCallINTEL"; + case Op::AtomicFMinEXT: + return "AtomicFMinEXT"; + case Op::AtomicFMaxEXT: + return "AtomicFMaxEXT"; + case Op::AtomicFAddEXT: + return "AtomicFAddEXT"; + case Op::ConvertFToBF16INTEL: + return "ConvertFToBF16INTEL"; + case Op::ConvertBF16ToFINTEL: + return "ConvertBF16ToFINTEL"; + case Op::ControlBarrierArriveINTEL: + return "ControlBarrierArriveINTEL"; + case Op::ControlBarrierWaitINTEL: + return "ControlBarrierWaitINTEL"; + case Op::CooperativeMatrixLoadCheckedINTEL: + return 
"CooperativeMatrixLoadCheckedINTEL"; + case Op::CooperativeMatrixStoreCheckedINTEL: + return "CooperativeMatrixStoreCheckedINTEL"; + } + return "unknown"; +} +/* Returns the enumerator name of an ImageOperands value as a string literal. Only values equal to one of the listed enumerators are named; anything else yields "unknown". */ auto to_string(ImageOperands e) -> char const * { + switch (e) { + case ImageOperands::None: + return "None"; + case ImageOperands::Bias: + return "Bias"; + case ImageOperands::Lod: + return "Lod"; + case ImageOperands::Grad: + return "Grad"; + case ImageOperands::ConstOffset: + return "ConstOffset"; + case ImageOperands::Offset: + return "Offset"; + case ImageOperands::ConstOffsets: + return "ConstOffsets"; + case ImageOperands::Sample: + return "Sample"; + case ImageOperands::MinLod: + return "MinLod"; + case ImageOperands::MakeTexelAvailable: + return "MakeTexelAvailable"; + case ImageOperands::MakeTexelVisible: + return "MakeTexelVisible"; + case ImageOperands::NonPrivateTexel: + return "NonPrivateTexel"; + case ImageOperands::VolatileTexel: + return "VolatileTexel"; + case ImageOperands::SignExtend: + return "SignExtend"; + case ImageOperands::ZeroExtend: + return "ZeroExtend"; + case ImageOperands::Nontemporal: + return "Nontemporal"; + case ImageOperands::Offsets: + return "Offsets"; + } + return "unknown"; +} +/* Returns the enumerator name of an FPFastMathMode value as a string literal; a value matching no case yields "unknown". */ auto to_string(FPFastMathMode e) -> char const * { + switch (e) { + case FPFastMathMode::None: + return "None"; + case FPFastMathMode::NotNaN: + return "NotNaN"; + case FPFastMathMode::NotInf: + return "NotInf"; + case FPFastMathMode::NSZ: + return "NSZ"; + case FPFastMathMode::AllowRecip: + return "AllowRecip"; + case FPFastMathMode::Fast: + return "Fast"; + case FPFastMathMode::AllowContract: + return "AllowContract"; + case FPFastMathMode::AllowReassoc: + return "AllowReassoc"; + case FPFastMathMode::AllowTransform: + return "AllowTransform"; + } + return "unknown"; +} +/* Returns the enumerator name of a SelectionControl value as a string literal; a value matching no case yields "unknown". */ auto to_string(SelectionControl e) -> char const * { + switch (e) { + case SelectionControl::None: + return "None"; + case SelectionControl::Flatten: + return "Flatten"; + case SelectionControl::DontFlatten: + return 
"DontFlatten"; + } + return "unknown"; +} +/* Returns the enumerator name of a LoopControl value as a string literal. Only values equal to one of the listed enumerators are named; anything else yields "unknown". */ auto to_string(LoopControl e) -> char const * { + switch (e) { + case LoopControl::None: + return "None"; + case LoopControl::Unroll: + return "Unroll"; + case LoopControl::DontUnroll: + return "DontUnroll"; + case LoopControl::DependencyInfinite: + return "DependencyInfinite"; + case LoopControl::DependencyLength: + return "DependencyLength"; + case LoopControl::MinIterations: + return "MinIterations"; + case LoopControl::MaxIterations: + return "MaxIterations"; + case LoopControl::IterationMultiple: + return "IterationMultiple"; + case LoopControl::PeelCount: + return "PeelCount"; + case LoopControl::PartialCount: + return "PartialCount"; + case LoopControl::InitiationIntervalINTEL: + return "InitiationIntervalINTEL"; + case LoopControl::MaxConcurrencyINTEL: + return "MaxConcurrencyINTEL"; + case LoopControl::DependencyArrayINTEL: + return "DependencyArrayINTEL"; + case LoopControl::PipelineEnableINTEL: + return "PipelineEnableINTEL"; + case LoopControl::LoopCoalesceINTEL: + return "LoopCoalesceINTEL"; + case LoopControl::MaxInterleavingINTEL: + return "MaxInterleavingINTEL"; + case LoopControl::SpeculatedIterationsINTEL: + return "SpeculatedIterationsINTEL"; + case LoopControl::NoFusionINTEL: + return "NoFusionINTEL"; + case LoopControl::LoopCountINTEL: + return "LoopCountINTEL"; + case LoopControl::MaxReinvocationDelayINTEL: + return "MaxReinvocationDelayINTEL"; + } + return "unknown"; +} +/* Returns the enumerator name of a FunctionControl value as a string literal; a value matching no case yields "unknown". */ auto to_string(FunctionControl e) -> char const * { + switch (e) { + case FunctionControl::None: + return "None"; + case FunctionControl::Inline: + return "Inline"; + case FunctionControl::DontInline: + return "DontInline"; + case FunctionControl::Pure: + return "Pure"; + case FunctionControl::Const: + return "Const"; + case FunctionControl::OptNoneEXT: + return "OptNoneEXT"; + } + return "unknown"; +} +/* Returns the enumerator name of a MemorySemantics value as a string literal; a value matching no case yields "unknown". */ auto to_string(MemorySemantics e) -> char const * { + switch (e) { + case MemorySemantics::Relaxed: + return "Relaxed"; + case 
MemorySemantics::Acquire: + return "Acquire"; + case MemorySemantics::Release: + return "Release"; + case MemorySemantics::AcquireRelease: + return "AcquireRelease"; + case MemorySemantics::SequentiallyConsistent: + return "SequentiallyConsistent"; + case MemorySemantics::UniformMemory: + return "UniformMemory"; + case MemorySemantics::SubgroupMemory: + return "SubgroupMemory"; + case MemorySemantics::WorkgroupMemory: + return "WorkgroupMemory"; + case MemorySemantics::CrossWorkgroupMemory: + return "CrossWorkgroupMemory"; + case MemorySemantics::AtomicCounterMemory: + return "AtomicCounterMemory"; + case MemorySemantics::ImageMemory: + return "ImageMemory"; + case MemorySemantics::OutputMemory: + return "OutputMemory"; + case MemorySemantics::MakeAvailable: + return "MakeAvailable"; + case MemorySemantics::MakeVisible: + return "MakeVisible"; + case MemorySemantics::Volatile: + return "Volatile"; + } + return "unknown"; +} +/* Returns the enumerator name of a MemoryAccess value as a string literal. Only values equal to one of the listed enumerators are named; anything else yields "unknown". */ auto to_string(MemoryAccess e) -> char const * { + switch (e) { + case MemoryAccess::None: + return "None"; + case MemoryAccess::Volatile: + return "Volatile"; + case MemoryAccess::Aligned: + return "Aligned"; + case MemoryAccess::Nontemporal: + return "Nontemporal"; + case MemoryAccess::MakePointerAvailable: + return "MakePointerAvailable"; + case MemoryAccess::MakePointerVisible: + return "MakePointerVisible"; + case MemoryAccess::NonPrivatePointer: + return "NonPrivatePointer"; + case MemoryAccess::AliasScopeINTELMask: + return "AliasScopeINTELMask"; + case MemoryAccess::NoAliasINTELMask: + return "NoAliasINTELMask"; + } + return "unknown"; +} +/* Returns the enumerator name of a KernelProfilingInfo value as a string literal; a value matching no case yields "unknown". */ auto to_string(KernelProfilingInfo e) -> char const * { + switch (e) { + case KernelProfilingInfo::None: + return "None"; + case KernelProfilingInfo::CmdExecTime: + return "CmdExecTime"; + } + return "unknown"; +} +/* Returns the enumerator name of a RayFlags value as a string literal; a value matching no case yields "unknown". */ auto to_string(RayFlags e) -> char const * { + switch (e) { + case RayFlags::NoneKHR: + return "NoneKHR"; + case RayFlags::OpaqueKHR: + return "OpaqueKHR"; + case RayFlags::NoOpaqueKHR: + 
return "NoOpaqueKHR"; + case RayFlags::TerminateOnFirstHitKHR: + return "TerminateOnFirstHitKHR"; + case RayFlags::SkipClosestHitShaderKHR: + return "SkipClosestHitShaderKHR"; + case RayFlags::CullBackFacingTrianglesKHR: + return "CullBackFacingTrianglesKHR"; + case RayFlags::CullFrontFacingTrianglesKHR: + return "CullFrontFacingTrianglesKHR"; + case RayFlags::CullOpaqueKHR: + return "CullOpaqueKHR"; + case RayFlags::CullNoOpaqueKHR: + return "CullNoOpaqueKHR"; + case RayFlags::SkipTrianglesKHR: + return "SkipTrianglesKHR"; + case RayFlags::SkipAABBsKHR: + return "SkipAABBsKHR"; + case RayFlags::ForceOpacityMicromap2StateEXT: + return "ForceOpacityMicromap2StateEXT"; + } + return "unknown"; +} +/* Returns the enumerator name of a FragmentShadingRate value as a string literal; a value matching no case yields "unknown". */ auto to_string(FragmentShadingRate e) -> char const * { + switch (e) { + case FragmentShadingRate::Vertical2Pixels: + return "Vertical2Pixels"; + case FragmentShadingRate::Vertical4Pixels: + return "Vertical4Pixels"; + case FragmentShadingRate::Horizontal2Pixels: + return "Horizontal2Pixels"; + case FragmentShadingRate::Horizontal4Pixels: + return "Horizontal4Pixels"; + } + return "unknown"; +} +/* Returns the enumerator name of a RawAccessChainOperands value as a string literal; a value matching no case yields "unknown". */ auto to_string(RawAccessChainOperands e) -> char const * { + switch (e) { + case RawAccessChainOperands::None: + return "None"; + case RawAccessChainOperands::RobustnessPerComponentNV: + return "RobustnessPerComponentNV"; + case RawAccessChainOperands::RobustnessPerElementNV: + return "RobustnessPerElementNV"; + } + return "unknown"; +} +/* Returns the enumerator name of a SourceLanguage value as a string literal. Note that SourceLanguage::Unknown maps to "Unknown" (capitalised), while a value matching no case yields the lower-case "unknown". */ auto to_string(SourceLanguage e) -> char const * { + switch (e) { + case SourceLanguage::Unknown: + return "Unknown"; + case SourceLanguage::ESSL: + return "ESSL"; + case SourceLanguage::GLSL: + return "GLSL"; + case SourceLanguage::OpenCL_C: + return "OpenCL_C"; + case SourceLanguage::OpenCL_CPP: + return "OpenCL_CPP"; + case SourceLanguage::HLSL: + return "HLSL"; + case SourceLanguage::CPP_for_OpenCL: + return "CPP_for_OpenCL"; + case SourceLanguage::SYCL: + return "SYCL"; + case SourceLanguage::HERO_C: + return "HERO_C"; + case 
SourceLanguage::NZSL: + return "NZSL"; + case SourceLanguage::WGSL: + return "WGSL"; + case SourceLanguage::Slang: + return "Slang"; + case SourceLanguage::Zig: + return "Zig"; + case SourceLanguage::Rust: + return "Rust"; + } + return "unknown"; +} +/* Returns the enumerator name of an ExecutionModel value as a string literal; a value matching no case yields "unknown". */ auto to_string(ExecutionModel e) -> char const * { + switch (e) { + case ExecutionModel::Vertex: + return "Vertex"; + case ExecutionModel::TessellationControl: + return "TessellationControl"; + case ExecutionModel::TessellationEvaluation: + return "TessellationEvaluation"; + case ExecutionModel::Geometry: + return "Geometry"; + case ExecutionModel::Fragment: + return "Fragment"; + case ExecutionModel::GLCompute: + return "GLCompute"; + case ExecutionModel::Kernel: + return "Kernel"; + case ExecutionModel::TaskNV: + return "TaskNV"; + case ExecutionModel::MeshNV: + return "MeshNV"; + case ExecutionModel::RayGenerationKHR: + return "RayGenerationKHR"; + case ExecutionModel::IntersectionKHR: + return "IntersectionKHR"; + case ExecutionModel::AnyHitKHR: + return "AnyHitKHR"; + case ExecutionModel::ClosestHitKHR: + return "ClosestHitKHR"; + case ExecutionModel::MissKHR: + return "MissKHR"; + case ExecutionModel::CallableKHR: + return "CallableKHR"; + case ExecutionModel::TaskEXT: + return "TaskEXT"; + case ExecutionModel::MeshEXT: + return "MeshEXT"; + } + return "unknown"; +} +/* Returns the enumerator name of an AddressingModel value as a string literal; a value matching no case yields "unknown". */ auto to_string(AddressingModel e) -> char const * { + switch (e) { + case AddressingModel::Logical: + return "Logical"; + case AddressingModel::Physical32: + return "Physical32"; + case AddressingModel::Physical64: + return "Physical64"; + case AddressingModel::PhysicalStorageBuffer64: + return "PhysicalStorageBuffer64"; + } + return "unknown"; +} +/* Returns the enumerator name of a MemoryModel value as a string literal; a value matching no case yields "unknown". */ auto to_string(MemoryModel e) -> char const * { + switch (e) { + case MemoryModel::Simple: + return "Simple"; + case MemoryModel::GLSL450: + return "GLSL450"; + case MemoryModel::OpenCL: + return "OpenCL"; + case MemoryModel::Vulkan: + return "Vulkan"; + } + return "unknown"; +} +/* Name lookup for ExecutionMode values: returns the enumerator name as a string literal, or "unknown" when no case matches. */ auto 
to_string(ExecutionMode e) -> char const * { + switch (e) { + case ExecutionMode::Invocations: + return "Invocations"; + case ExecutionMode::SpacingEqual: + return "SpacingEqual"; + case ExecutionMode::SpacingFractionalEven: + return "SpacingFractionalEven"; + case ExecutionMode::SpacingFractionalOdd: + return "SpacingFractionalOdd"; + case ExecutionMode::VertexOrderCw: + return "VertexOrderCw"; + case ExecutionMode::VertexOrderCcw: + return "VertexOrderCcw"; + case ExecutionMode::PixelCenterInteger: + return "PixelCenterInteger"; + case ExecutionMode::OriginUpperLeft: + return "OriginUpperLeft"; + case ExecutionMode::OriginLowerLeft: + return "OriginLowerLeft"; + case ExecutionMode::EarlyFragmentTests: + return "EarlyFragmentTests"; + case ExecutionMode::PointMode: + return "PointMode"; + case ExecutionMode::Xfb: + return "Xfb"; + case ExecutionMode::DepthReplacing: + return "DepthReplacing"; + case ExecutionMode::DepthGreater: + return "DepthGreater"; + case ExecutionMode::DepthLess: + return "DepthLess"; + case ExecutionMode::DepthUnchanged: + return "DepthUnchanged"; + case ExecutionMode::LocalSize: + return "LocalSize"; + case ExecutionMode::LocalSizeHint: + return "LocalSizeHint"; + case ExecutionMode::InputPoints: + return "InputPoints"; + case ExecutionMode::InputLines: + return "InputLines"; + case ExecutionMode::InputLinesAdjacency: + return "InputLinesAdjacency"; + case ExecutionMode::Triangles: + return "Triangles"; + case ExecutionMode::InputTrianglesAdjacency: + return "InputTrianglesAdjacency"; + case ExecutionMode::Quads: + return "Quads"; + case ExecutionMode::Isolines: + return "Isolines"; + case ExecutionMode::OutputVertices: + return "OutputVertices"; + case ExecutionMode::OutputPoints: + return "OutputPoints"; + case ExecutionMode::OutputLineStrip: + return "OutputLineStrip"; + case ExecutionMode::OutputTriangleStrip: + return "OutputTriangleStrip"; + case ExecutionMode::VecTypeHint: + return "VecTypeHint"; + case ExecutionMode::ContractionOff: 
+ return "ContractionOff"; + case ExecutionMode::Initializer: + return "Initializer"; + case ExecutionMode::Finalizer: + return "Finalizer"; + case ExecutionMode::SubgroupSize: + return "SubgroupSize"; + case ExecutionMode::SubgroupsPerWorkgroup: + return "SubgroupsPerWorkgroup"; + case ExecutionMode::SubgroupsPerWorkgroupId: + return "SubgroupsPerWorkgroupId"; + case ExecutionMode::LocalSizeId: + return "LocalSizeId"; + case ExecutionMode::LocalSizeHintId: + return "LocalSizeHintId"; + case ExecutionMode::NonCoherentColorAttachmentReadEXT: + return "NonCoherentColorAttachmentReadEXT"; + case ExecutionMode::NonCoherentDepthAttachmentReadEXT: + return "NonCoherentDepthAttachmentReadEXT"; + case ExecutionMode::NonCoherentStencilAttachmentReadEXT: + return "NonCoherentStencilAttachmentReadEXT"; + case ExecutionMode::SubgroupUniformControlFlowKHR: + return "SubgroupUniformControlFlowKHR"; + case ExecutionMode::PostDepthCoverage: + return "PostDepthCoverage"; + case ExecutionMode::DenormPreserve: + return "DenormPreserve"; + case ExecutionMode::DenormFlushToZero: + return "DenormFlushToZero"; + case ExecutionMode::SignedZeroInfNanPreserve: + return "SignedZeroInfNanPreserve"; + case ExecutionMode::RoundingModeRTE: + return "RoundingModeRTE"; + case ExecutionMode::RoundingModeRTZ: + return "RoundingModeRTZ"; + case ExecutionMode::NonCoherentTileAttachmentReadQCOM: + return "NonCoherentTileAttachmentReadQCOM"; + case ExecutionMode::TileShadingRateQCOM: + return "TileShadingRateQCOM"; + case ExecutionMode::EarlyAndLateFragmentTestsAMD: + return "EarlyAndLateFragmentTestsAMD"; + case ExecutionMode::StencilRefReplacingEXT: + return "StencilRefReplacingEXT"; + case ExecutionMode::CoalescingAMDX: + return "CoalescingAMDX"; + case ExecutionMode::IsApiEntryAMDX: + return "IsApiEntryAMDX"; + case ExecutionMode::MaxNodeRecursionAMDX: + return "MaxNodeRecursionAMDX"; + case ExecutionMode::StaticNumWorkgroupsAMDX: + return "StaticNumWorkgroupsAMDX"; + case 
ExecutionMode::ShaderIndexAMDX: + return "ShaderIndexAMDX"; + case ExecutionMode::MaxNumWorkgroupsAMDX: + return "MaxNumWorkgroupsAMDX"; + case ExecutionMode::StencilRefUnchangedFrontAMD: + return "StencilRefUnchangedFrontAMD"; + case ExecutionMode::StencilRefGreaterFrontAMD: + return "StencilRefGreaterFrontAMD"; + case ExecutionMode::StencilRefLessFrontAMD: + return "StencilRefLessFrontAMD"; + case ExecutionMode::StencilRefUnchangedBackAMD: + return "StencilRefUnchangedBackAMD"; + case ExecutionMode::StencilRefGreaterBackAMD: + return "StencilRefGreaterBackAMD"; + case ExecutionMode::StencilRefLessBackAMD: + return "StencilRefLessBackAMD"; + case ExecutionMode::QuadDerivativesKHR: + return "QuadDerivativesKHR"; + case ExecutionMode::RequireFullQuadsKHR: + return "RequireFullQuadsKHR"; + case ExecutionMode::SharesInputWithAMDX: + return "SharesInputWithAMDX"; + case ExecutionMode::OutputLinesEXT: + return "OutputLinesEXT"; + case ExecutionMode::OutputPrimitivesEXT: + return "OutputPrimitivesEXT"; + case ExecutionMode::DerivativeGroupQuadsKHR: + return "DerivativeGroupQuadsKHR"; + case ExecutionMode::DerivativeGroupLinearKHR: + return "DerivativeGroupLinearKHR"; + case ExecutionMode::OutputTrianglesEXT: + return "OutputTrianglesEXT"; + case ExecutionMode::PixelInterlockOrderedEXT: + return "PixelInterlockOrderedEXT"; + case ExecutionMode::PixelInterlockUnorderedEXT: + return "PixelInterlockUnorderedEXT"; + case ExecutionMode::SampleInterlockOrderedEXT: + return "SampleInterlockOrderedEXT"; + case ExecutionMode::SampleInterlockUnorderedEXT: + return "SampleInterlockUnorderedEXT"; + case ExecutionMode::ShadingRateInterlockOrderedEXT: + return "ShadingRateInterlockOrderedEXT"; + case ExecutionMode::ShadingRateInterlockUnorderedEXT: + return "ShadingRateInterlockUnorderedEXT"; + case ExecutionMode::SharedLocalMemorySizeINTEL: + return "SharedLocalMemorySizeINTEL"; + case ExecutionMode::RoundingModeRTPINTEL: + return "RoundingModeRTPINTEL"; + case 
ExecutionMode::RoundingModeRTNINTEL: + return "RoundingModeRTNINTEL"; + case ExecutionMode::FloatingPointModeALTINTEL: + return "FloatingPointModeALTINTEL"; + case ExecutionMode::FloatingPointModeIEEEINTEL: + return "FloatingPointModeIEEEINTEL"; + case ExecutionMode::MaxWorkgroupSizeINTEL: + return "MaxWorkgroupSizeINTEL"; + case ExecutionMode::MaxWorkDimINTEL: + return "MaxWorkDimINTEL"; + case ExecutionMode::NoGlobalOffsetINTEL: + return "NoGlobalOffsetINTEL"; + case ExecutionMode::NumSIMDWorkitemsINTEL: + return "NumSIMDWorkitemsINTEL"; + case ExecutionMode::SchedulerTargetFmaxMhzINTEL: + return "SchedulerTargetFmaxMhzINTEL"; + case ExecutionMode::MaximallyReconvergesKHR: + return "MaximallyReconvergesKHR"; + case ExecutionMode::FPFastMathDefault: + return "FPFastMathDefault"; + case ExecutionMode::StreamingInterfaceINTEL: + return "StreamingInterfaceINTEL"; + case ExecutionMode::RegisterMapInterfaceINTEL: + return "RegisterMapInterfaceINTEL"; + case ExecutionMode::NamedBarrierCountINTEL: + return "NamedBarrierCountINTEL"; + case ExecutionMode::MaximumRegistersINTEL: + return "MaximumRegistersINTEL"; + case ExecutionMode::MaximumRegistersIdINTEL: + return "MaximumRegistersIdINTEL"; + case ExecutionMode::NamedMaximumRegistersINTEL: + return "NamedMaximumRegistersINTEL"; + } + return "unknown"; +} +auto to_string(StorageClass e) -> char const * { + switch (e) { + case StorageClass::UniformConstant: + return "UniformConstant"; + case StorageClass::Input: + return "Input"; + case StorageClass::Uniform: + return "Uniform"; + case StorageClass::Output: + return "Output"; + case StorageClass::Workgroup: + return "Workgroup"; + case StorageClass::CrossWorkgroup: + return "CrossWorkgroup"; + case StorageClass::Private: + return "Private"; + case StorageClass::Function: + return "Function"; + case StorageClass::Generic: + return "Generic"; + case StorageClass::PushConstant: + return "PushConstant"; + case StorageClass::AtomicCounter: + return "AtomicCounter"; + case 
StorageClass::Image: + return "Image"; + case StorageClass::StorageBuffer: + return "StorageBuffer"; + case StorageClass::TileImageEXT: + return "TileImageEXT"; + case StorageClass::TileAttachmentQCOM: + return "TileAttachmentQCOM"; + case StorageClass::NodePayloadAMDX: + return "NodePayloadAMDX"; + case StorageClass::CallableDataKHR: + return "CallableDataKHR"; + case StorageClass::IncomingCallableDataKHR: + return "IncomingCallableDataKHR"; + case StorageClass::RayPayloadKHR: + return "RayPayloadKHR"; + case StorageClass::HitAttributeKHR: + return "HitAttributeKHR"; + case StorageClass::IncomingRayPayloadKHR: + return "IncomingRayPayloadKHR"; + case StorageClass::ShaderRecordBufferKHR: + return "ShaderRecordBufferKHR"; + case StorageClass::PhysicalStorageBuffer: + return "PhysicalStorageBuffer"; + case StorageClass::HitObjectAttributeNV: + return "HitObjectAttributeNV"; + case StorageClass::TaskPayloadWorkgroupEXT: + return "TaskPayloadWorkgroupEXT"; + case StorageClass::CodeSectionINTEL: + return "CodeSectionINTEL"; + case StorageClass::DeviceOnlyINTEL: + return "DeviceOnlyINTEL"; + case StorageClass::HostOnlyINTEL: + return "HostOnlyINTEL"; + } + return "unknown"; +} +auto to_string(Dim e) -> char const * { + switch (e) { + case Dim::Dim1D: + return "Dim1D"; + case Dim::Dim2D: + return "Dim2D"; + case Dim::Dim3D: + return "Dim3D"; + case Dim::Cube: + return "Cube"; + case Dim::Rect: + return "Rect"; + case Dim::Buffer: + return "Buffer"; + case Dim::SubpassData: + return "SubpassData"; + case Dim::TileImageDataEXT: + return "TileImageDataEXT"; + } + return "unknown"; +} +auto to_string(SamplerAddressingMode e) -> char const * { + switch (e) { + case SamplerAddressingMode::None: + return "None"; + case SamplerAddressingMode::ClampToEdge: + return "ClampToEdge"; + case SamplerAddressingMode::Clamp: + return "Clamp"; + case SamplerAddressingMode::Repeat: + return "Repeat"; + case SamplerAddressingMode::RepeatMirrored: + return "RepeatMirrored"; + } + return 
"unknown"; +} +auto to_string(SamplerFilterMode e) -> char const * { + switch (e) { + case SamplerFilterMode::Nearest: + return "Nearest"; + case SamplerFilterMode::Linear: + return "Linear"; + } + return "unknown"; +} +auto to_string(ImageFormat e) -> char const * { + switch (e) { + case ImageFormat::Unknown: + return "Unknown"; + case ImageFormat::Rgba32f: + return "Rgba32f"; + case ImageFormat::Rgba16f: + return "Rgba16f"; + case ImageFormat::R32f: + return "R32f"; + case ImageFormat::Rgba8: + return "Rgba8"; + case ImageFormat::Rgba8Snorm: + return "Rgba8Snorm"; + case ImageFormat::Rg32f: + return "Rg32f"; + case ImageFormat::Rg16f: + return "Rg16f"; + case ImageFormat::R11fG11fB10f: + return "R11fG11fB10f"; + case ImageFormat::R16f: + return "R16f"; + case ImageFormat::Rgba16: + return "Rgba16"; + case ImageFormat::Rgb10A2: + return "Rgb10A2"; + case ImageFormat::Rg16: + return "Rg16"; + case ImageFormat::Rg8: + return "Rg8"; + case ImageFormat::R16: + return "R16"; + case ImageFormat::R8: + return "R8"; + case ImageFormat::Rgba16Snorm: + return "Rgba16Snorm"; + case ImageFormat::Rg16Snorm: + return "Rg16Snorm"; + case ImageFormat::Rg8Snorm: + return "Rg8Snorm"; + case ImageFormat::R16Snorm: + return "R16Snorm"; + case ImageFormat::R8Snorm: + return "R8Snorm"; + case ImageFormat::Rgba32i: + return "Rgba32i"; + case ImageFormat::Rgba16i: + return "Rgba16i"; + case ImageFormat::Rgba8i: + return "Rgba8i"; + case ImageFormat::R32i: + return "R32i"; + case ImageFormat::Rg32i: + return "Rg32i"; + case ImageFormat::Rg16i: + return "Rg16i"; + case ImageFormat::Rg8i: + return "Rg8i"; + case ImageFormat::R16i: + return "R16i"; + case ImageFormat::R8i: + return "R8i"; + case ImageFormat::Rgba32ui: + return "Rgba32ui"; + case ImageFormat::Rgba16ui: + return "Rgba16ui"; + case ImageFormat::Rgba8ui: + return "Rgba8ui"; + case ImageFormat::R32ui: + return "R32ui"; + case ImageFormat::Rgb10a2ui: + return "Rgb10a2ui"; + case ImageFormat::Rg32ui: + return "Rg32ui"; + case 
ImageFormat::Rg16ui: + return "Rg16ui"; + case ImageFormat::Rg8ui: + return "Rg8ui"; + case ImageFormat::R16ui: + return "R16ui"; + case ImageFormat::R8ui: + return "R8ui"; + case ImageFormat::R64ui: + return "R64ui"; + case ImageFormat::R64i: + return "R64i"; + } + return "unknown"; +} +auto to_string(ImageChannelOrder e) -> char const * { + switch (e) { + case ImageChannelOrder::R: + return "R"; + case ImageChannelOrder::A: + return "A"; + case ImageChannelOrder::RG: + return "RG"; + case ImageChannelOrder::RA: + return "RA"; + case ImageChannelOrder::RGB: + return "RGB"; + case ImageChannelOrder::RGBA: + return "RGBA"; + case ImageChannelOrder::BGRA: + return "BGRA"; + case ImageChannelOrder::ARGB: + return "ARGB"; + case ImageChannelOrder::Intensity: + return "Intensity"; + case ImageChannelOrder::Luminance: + return "Luminance"; + case ImageChannelOrder::Rx: + return "Rx"; + case ImageChannelOrder::RGx: + return "RGx"; + case ImageChannelOrder::RGBx: + return "RGBx"; + case ImageChannelOrder::Depth: + return "Depth"; + case ImageChannelOrder::DepthStencil: + return "DepthStencil"; + case ImageChannelOrder::sRGB: + return "sRGB"; + case ImageChannelOrder::sRGBx: + return "sRGBx"; + case ImageChannelOrder::sRGBA: + return "sRGBA"; + case ImageChannelOrder::sBGRA: + return "sBGRA"; + case ImageChannelOrder::ABGR: + return "ABGR"; + } + return "unknown"; +} +auto to_string(ImageChannelDataType e) -> char const * { + switch (e) { + case ImageChannelDataType::SnormInt8: + return "SnormInt8"; + case ImageChannelDataType::SnormInt16: + return "SnormInt16"; + case ImageChannelDataType::UnormInt8: + return "UnormInt8"; + case ImageChannelDataType::UnormInt16: + return "UnormInt16"; + case ImageChannelDataType::UnormShort565: + return "UnormShort565"; + case ImageChannelDataType::UnormShort555: + return "UnormShort555"; + case ImageChannelDataType::UnormInt101010: + return "UnormInt101010"; + case ImageChannelDataType::SignedInt8: + return "SignedInt8"; + case 
ImageChannelDataType::SignedInt16: + return "SignedInt16"; + case ImageChannelDataType::SignedInt32: + return "SignedInt32"; + case ImageChannelDataType::UnsignedInt8: + return "UnsignedInt8"; + case ImageChannelDataType::UnsignedInt16: + return "UnsignedInt16"; + case ImageChannelDataType::UnsignedInt32: + return "UnsignedInt32"; + case ImageChannelDataType::HalfFloat: + return "HalfFloat"; + case ImageChannelDataType::Float: + return "Float"; + case ImageChannelDataType::UnormInt24: + return "UnormInt24"; + case ImageChannelDataType::UnormInt101010_2: + return "UnormInt101010_2"; + case ImageChannelDataType::UnormInt10X6EXT: + return "UnormInt10X6EXT"; + case ImageChannelDataType::UnsignedIntRaw10EXT: + return "UnsignedIntRaw10EXT"; + case ImageChannelDataType::UnsignedIntRaw12EXT: + return "UnsignedIntRaw12EXT"; + case ImageChannelDataType::UnormInt2_101010EXT: + return "UnormInt2_101010EXT"; + case ImageChannelDataType::UnsignedInt10X6EXT: + return "UnsignedInt10X6EXT"; + case ImageChannelDataType::UnsignedInt12X4EXT: + return "UnsignedInt12X4EXT"; + case ImageChannelDataType::UnsignedInt14X2EXT: + return "UnsignedInt14X2EXT"; + case ImageChannelDataType::UnormInt12X4EXT: + return "UnormInt12X4EXT"; + case ImageChannelDataType::UnormInt14X2EXT: + return "UnormInt14X2EXT"; + } + return "unknown"; +} +auto to_string(FPRoundingMode e) -> char const * { + switch (e) { + case FPRoundingMode::RTE: + return "RTE"; + case FPRoundingMode::RTZ: + return "RTZ"; + case FPRoundingMode::RTP: + return "RTP"; + case FPRoundingMode::RTN: + return "RTN"; + } + return "unknown"; +} +auto to_string(FPDenormMode e) -> char const * { + switch (e) { + case FPDenormMode::Preserve: + return "Preserve"; + case FPDenormMode::FlushToZero: + return "FlushToZero"; + } + return "unknown"; +} +auto to_string(QuantizationModes e) -> char const * { + switch (e) { + case QuantizationModes::TRN: + return "TRN"; + case QuantizationModes::TRN_ZERO: + return "TRN_ZERO"; + case 
QuantizationModes::RND: + return "RND"; + case QuantizationModes::RND_ZERO: + return "RND_ZERO"; + case QuantizationModes::RND_INF: + return "RND_INF"; + case QuantizationModes::RND_MIN_INF: + return "RND_MIN_INF"; + case QuantizationModes::RND_CONV: + return "RND_CONV"; + case QuantizationModes::RND_CONV_ODD: + return "RND_CONV_ODD"; + } + return "unknown"; +} +auto to_string(FPOperationMode e) -> char const * { + switch (e) { + case FPOperationMode::IEEE: + return "IEEE"; + case FPOperationMode::ALT: + return "ALT"; + } + return "unknown"; +} +auto to_string(OverflowModes e) -> char const * { + switch (e) { + case OverflowModes::WRAP: + return "WRAP"; + case OverflowModes::SAT: + return "SAT"; + case OverflowModes::SAT_ZERO: + return "SAT_ZERO"; + case OverflowModes::SAT_SYM: + return "SAT_SYM"; + } + return "unknown"; +} +auto to_string(LinkageType e) -> char const * { + switch (e) { + case LinkageType::Export: + return "Export"; + case LinkageType::Import: + return "Import"; + case LinkageType::LinkOnceODR: + return "LinkOnceODR"; + } + return "unknown"; +} +auto to_string(AccessQualifier e) -> char const * { + switch (e) { + case AccessQualifier::ReadOnly: + return "ReadOnly"; + case AccessQualifier::WriteOnly: + return "WriteOnly"; + case AccessQualifier::ReadWrite: + return "ReadWrite"; + } + return "unknown"; +} +auto to_string(HostAccessQualifier e) -> char const * { + switch (e) { + case HostAccessQualifier::NoneINTEL: + return "NoneINTEL"; + case HostAccessQualifier::ReadINTEL: + return "ReadINTEL"; + case HostAccessQualifier::WriteINTEL: + return "WriteINTEL"; + case HostAccessQualifier::ReadWriteINTEL: + return "ReadWriteINTEL"; + } + return "unknown"; +} +auto to_string(FunctionParameterAttribute e) -> char const * { + switch (e) { + case FunctionParameterAttribute::Zext: + return "Zext"; + case FunctionParameterAttribute::Sext: + return "Sext"; + case FunctionParameterAttribute::ByVal: + return "ByVal"; + case FunctionParameterAttribute::Sret: + 
return "Sret"; + case FunctionParameterAttribute::NoAlias: + return "NoAlias"; + case FunctionParameterAttribute::NoCapture: + return "NoCapture"; + case FunctionParameterAttribute::NoWrite: + return "NoWrite"; + case FunctionParameterAttribute::NoReadWrite: + return "NoReadWrite"; + case FunctionParameterAttribute::RuntimeAlignedINTEL: + return "RuntimeAlignedINTEL"; + } + return "unknown"; +} +auto to_string(Decoration e) -> char const * { + switch (e) { + case Decoration::RelaxedPrecision: + return "RelaxedPrecision"; + case Decoration::SpecId: + return "SpecId"; + case Decoration::Block: + return "Block"; + case Decoration::BufferBlock: + return "BufferBlock"; + case Decoration::RowMajor: + return "RowMajor"; + case Decoration::ColMajor: + return "ColMajor"; + case Decoration::ArrayStride: + return "ArrayStride"; + case Decoration::MatrixStride: + return "MatrixStride"; + case Decoration::GLSLShared: + return "GLSLShared"; + case Decoration::GLSLPacked: + return "GLSLPacked"; + case Decoration::CPacked: + return "CPacked"; + case Decoration::BuiltIn: + return "BuiltIn"; + case Decoration::NoPerspective: + return "NoPerspective"; + case Decoration::Flat: + return "Flat"; + case Decoration::Patch: + return "Patch"; + case Decoration::Centroid: + return "Centroid"; + case Decoration::Sample: + return "Sample"; + case Decoration::Invariant: + return "Invariant"; + case Decoration::Restrict: + return "Restrict"; + case Decoration::Aliased: + return "Aliased"; + case Decoration::Volatile: + return "Volatile"; + case Decoration::Constant: + return "Constant"; + case Decoration::Coherent: + return "Coherent"; + case Decoration::NonWritable: + return "NonWritable"; + case Decoration::NonReadable: + return "NonReadable"; + case Decoration::Uniform: + return "Uniform"; + case Decoration::UniformId: + return "UniformId"; + case Decoration::SaturatedConversion: + return "SaturatedConversion"; + case Decoration::Stream: + return "Stream"; + case Decoration::Location: + 
return "Location"; + case Decoration::Component: + return "Component"; + case Decoration::Index: + return "Index"; + case Decoration::Binding: + return "Binding"; + case Decoration::DescriptorSet: + return "DescriptorSet"; + case Decoration::Offset: + return "Offset"; + case Decoration::XfbBuffer: + return "XfbBuffer"; + case Decoration::XfbStride: + return "XfbStride"; + case Decoration::FuncParamAttr: + return "FuncParamAttr"; + case Decoration::FPRoundingMode: + return "FPRoundingMode"; + case Decoration::FPFastMathMode: + return "FPFastMathMode"; + case Decoration::LinkageAttributes: + return "LinkageAttributes"; + case Decoration::NoContraction: + return "NoContraction"; + case Decoration::InputAttachmentIndex: + return "InputAttachmentIndex"; + case Decoration::Alignment: + return "Alignment"; + case Decoration::MaxByteOffset: + return "MaxByteOffset"; + case Decoration::AlignmentId: + return "AlignmentId"; + case Decoration::MaxByteOffsetId: + return "MaxByteOffsetId"; + case Decoration::SaturatedToLargestFloat8NormalConversionEXT: + return "SaturatedToLargestFloat8NormalConversionEXT"; + case Decoration::NoSignedWrap: + return "NoSignedWrap"; + case Decoration::NoUnsignedWrap: + return "NoUnsignedWrap"; + case Decoration::WeightTextureQCOM: + return "WeightTextureQCOM"; + case Decoration::BlockMatchTextureQCOM: + return "BlockMatchTextureQCOM"; + case Decoration::BlockMatchSamplerQCOM: + return "BlockMatchSamplerQCOM"; + case Decoration::ExplicitInterpAMD: + return "ExplicitInterpAMD"; + case Decoration::NodeSharesPayloadLimitsWithAMDX: + return "NodeSharesPayloadLimitsWithAMDX"; + case Decoration::NodeMaxPayloadsAMDX: + return "NodeMaxPayloadsAMDX"; + case Decoration::TrackFinishWritingAMDX: + return "TrackFinishWritingAMDX"; + case Decoration::PayloadNodeNameAMDX: + return "PayloadNodeNameAMDX"; + case Decoration::PayloadNodeBaseIndexAMDX: + return "PayloadNodeBaseIndexAMDX"; + case Decoration::PayloadNodeSparseArrayAMDX: + return 
"PayloadNodeSparseArrayAMDX"; + case Decoration::PayloadNodeArraySizeAMDX: + return "PayloadNodeArraySizeAMDX"; + case Decoration::PayloadDispatchIndirectAMDX: + return "PayloadDispatchIndirectAMDX"; + case Decoration::OverrideCoverageNV: + return "OverrideCoverageNV"; + case Decoration::PassthroughNV: + return "PassthroughNV"; + case Decoration::ViewportRelativeNV: + return "ViewportRelativeNV"; + case Decoration::SecondaryViewportRelativeNV: + return "SecondaryViewportRelativeNV"; + case Decoration::PerPrimitiveEXT: + return "PerPrimitiveEXT"; + case Decoration::PerViewNV: + return "PerViewNV"; + case Decoration::PerTaskNV: + return "PerTaskNV"; + case Decoration::PerVertexKHR: + return "PerVertexKHR"; + case Decoration::NonUniform: + return "NonUniform"; + case Decoration::RestrictPointer: + return "RestrictPointer"; + case Decoration::AliasedPointer: + return "AliasedPointer"; + case Decoration::HitObjectShaderRecordBufferNV: + return "HitObjectShaderRecordBufferNV"; + case Decoration::BindlessSamplerNV: + return "BindlessSamplerNV"; + case Decoration::BindlessImageNV: + return "BindlessImageNV"; + case Decoration::BoundSamplerNV: + return "BoundSamplerNV"; + case Decoration::BoundImageNV: + return "BoundImageNV"; + case Decoration::SIMTCallINTEL: + return "SIMTCallINTEL"; + case Decoration::ReferencedIndirectlyINTEL: + return "ReferencedIndirectlyINTEL"; + case Decoration::ClobberINTEL: + return "ClobberINTEL"; + case Decoration::SideEffectsINTEL: + return "SideEffectsINTEL"; + case Decoration::VectorComputeVariableINTEL: + return "VectorComputeVariableINTEL"; + case Decoration::FuncParamIOKindINTEL: + return "FuncParamIOKindINTEL"; + case Decoration::VectorComputeFunctionINTEL: + return "VectorComputeFunctionINTEL"; + case Decoration::StackCallINTEL: + return "StackCallINTEL"; + case Decoration::GlobalVariableOffsetINTEL: + return "GlobalVariableOffsetINTEL"; + case Decoration::CounterBuffer: + return "CounterBuffer"; + case Decoration::UserSemantic: + return 
"UserSemantic"; + case Decoration::UserTypeGOOGLE: + return "UserTypeGOOGLE"; + case Decoration::FunctionRoundingModeINTEL: + return "FunctionRoundingModeINTEL"; + case Decoration::FunctionDenormModeINTEL: + return "FunctionDenormModeINTEL"; + case Decoration::RegisterINTEL: + return "RegisterINTEL"; + case Decoration::MemoryINTEL: + return "MemoryINTEL"; + case Decoration::NumbanksINTEL: + return "NumbanksINTEL"; + case Decoration::BankwidthINTEL: + return "BankwidthINTEL"; + case Decoration::MaxPrivateCopiesINTEL: + return "MaxPrivateCopiesINTEL"; + case Decoration::SinglepumpINTEL: + return "SinglepumpINTEL"; + case Decoration::DoublepumpINTEL: + return "DoublepumpINTEL"; + case Decoration::MaxReplicatesINTEL: + return "MaxReplicatesINTEL"; + case Decoration::SimpleDualPortINTEL: + return "SimpleDualPortINTEL"; + case Decoration::MergeINTEL: + return "MergeINTEL"; + case Decoration::BankBitsINTEL: + return "BankBitsINTEL"; + case Decoration::ForcePow2DepthINTEL: + return "ForcePow2DepthINTEL"; + case Decoration::StridesizeINTEL: + return "StridesizeINTEL"; + case Decoration::WordsizeINTEL: + return "WordsizeINTEL"; + case Decoration::TrueDualPortINTEL: + return "TrueDualPortINTEL"; + case Decoration::BurstCoalesceINTEL: + return "BurstCoalesceINTEL"; + case Decoration::CacheSizeINTEL: + return "CacheSizeINTEL"; + case Decoration::DontStaticallyCoalesceINTEL: + return "DontStaticallyCoalesceINTEL"; + case Decoration::PrefetchINTEL: + return "PrefetchINTEL"; + case Decoration::StallEnableINTEL: + return "StallEnableINTEL"; + case Decoration::FuseLoopsInFunctionINTEL: + return "FuseLoopsInFunctionINTEL"; + case Decoration::MathOpDSPModeINTEL: + return "MathOpDSPModeINTEL"; + case Decoration::AliasScopeINTEL: + return "AliasScopeINTEL"; + case Decoration::NoAliasINTEL: + return "NoAliasINTEL"; + case Decoration::InitiationIntervalINTEL: + return "InitiationIntervalINTEL"; + case Decoration::MaxConcurrencyINTEL: + return "MaxConcurrencyINTEL"; + case 
Decoration::PipelineEnableINTEL: + return "PipelineEnableINTEL"; + case Decoration::BufferLocationINTEL: + return "BufferLocationINTEL"; + case Decoration::IOPipeStorageINTEL: + return "IOPipeStorageINTEL"; + case Decoration::FunctionFloatingPointModeINTEL: + return "FunctionFloatingPointModeINTEL"; + case Decoration::SingleElementVectorINTEL: + return "SingleElementVectorINTEL"; + case Decoration::VectorComputeCallableFunctionINTEL: + return "VectorComputeCallableFunctionINTEL"; + case Decoration::MediaBlockIOINTEL: + return "MediaBlockIOINTEL"; + case Decoration::StallFreeINTEL: + return "StallFreeINTEL"; + case Decoration::FPMaxErrorDecorationINTEL: + return "FPMaxErrorDecorationINTEL"; + case Decoration::LatencyControlLabelINTEL: + return "LatencyControlLabelINTEL"; + case Decoration::LatencyControlConstraintINTEL: + return "LatencyControlConstraintINTEL"; + case Decoration::ConduitKernelArgumentINTEL: + return "ConduitKernelArgumentINTEL"; + case Decoration::RegisterMapKernelArgumentINTEL: + return "RegisterMapKernelArgumentINTEL"; + case Decoration::MMHostInterfaceAddressWidthINTEL: + return "MMHostInterfaceAddressWidthINTEL"; + case Decoration::MMHostInterfaceDataWidthINTEL: + return "MMHostInterfaceDataWidthINTEL"; + case Decoration::MMHostInterfaceLatencyINTEL: + return "MMHostInterfaceLatencyINTEL"; + case Decoration::MMHostInterfaceReadWriteModeINTEL: + return "MMHostInterfaceReadWriteModeINTEL"; + case Decoration::MMHostInterfaceMaxBurstINTEL: + return "MMHostInterfaceMaxBurstINTEL"; + case Decoration::MMHostInterfaceWaitRequestINTEL: + return "MMHostInterfaceWaitRequestINTEL"; + case Decoration::StableKernelArgumentINTEL: + return "StableKernelArgumentINTEL"; + case Decoration::HostAccessINTEL: + return "HostAccessINTEL"; + case Decoration::InitModeINTEL: + return "InitModeINTEL"; + case Decoration::ImplementInRegisterMapINTEL: + return "ImplementInRegisterMapINTEL"; + case Decoration::CacheControlLoadINTEL: + return "CacheControlLoadINTEL"; + case 
Decoration::CacheControlStoreINTEL: + return "CacheControlStoreINTEL"; + } + return "unknown"; +} +auto to_string(BuiltIn e) -> char const * { + switch (e) { + case BuiltIn::Position: + return "Position"; + case BuiltIn::PointSize: + return "PointSize"; + case BuiltIn::ClipDistance: + return "ClipDistance"; + case BuiltIn::CullDistance: + return "CullDistance"; + case BuiltIn::VertexId: + return "VertexId"; + case BuiltIn::InstanceId: + return "InstanceId"; + case BuiltIn::PrimitiveId: + return "PrimitiveId"; + case BuiltIn::InvocationId: + return "InvocationId"; + case BuiltIn::Layer: + return "Layer"; + case BuiltIn::ViewportIndex: + return "ViewportIndex"; + case BuiltIn::TessLevelOuter: + return "TessLevelOuter"; + case BuiltIn::TessLevelInner: + return "TessLevelInner"; + case BuiltIn::TessCoord: + return "TessCoord"; + case BuiltIn::PatchVertices: + return "PatchVertices"; + case BuiltIn::FragCoord: + return "FragCoord"; + case BuiltIn::PointCoord: + return "PointCoord"; + case BuiltIn::FrontFacing: + return "FrontFacing"; + case BuiltIn::SampleId: + return "SampleId"; + case BuiltIn::SamplePosition: + return "SamplePosition"; + case BuiltIn::SampleMask: + return "SampleMask"; + case BuiltIn::FragDepth: + return "FragDepth"; + case BuiltIn::HelperInvocation: + return "HelperInvocation"; + case BuiltIn::NumWorkgroups: + return "NumWorkgroups"; + case BuiltIn::WorkgroupSize: + return "WorkgroupSize"; + case BuiltIn::WorkgroupId: + return "WorkgroupId"; + case BuiltIn::LocalInvocationId: + return "LocalInvocationId"; + case BuiltIn::GlobalInvocationId: + return "GlobalInvocationId"; + case BuiltIn::LocalInvocationIndex: + return "LocalInvocationIndex"; + case BuiltIn::WorkDim: + return "WorkDim"; + case BuiltIn::GlobalSize: + return "GlobalSize"; + case BuiltIn::EnqueuedWorkgroupSize: + return "EnqueuedWorkgroupSize"; + case BuiltIn::GlobalOffset: + return "GlobalOffset"; + case BuiltIn::GlobalLinearId: + return "GlobalLinearId"; + case BuiltIn::SubgroupSize: + 
return "SubgroupSize"; + case BuiltIn::SubgroupMaxSize: + return "SubgroupMaxSize"; + case BuiltIn::NumSubgroups: + return "NumSubgroups"; + case BuiltIn::NumEnqueuedSubgroups: + return "NumEnqueuedSubgroups"; + case BuiltIn::SubgroupId: + return "SubgroupId"; + case BuiltIn::SubgroupLocalInvocationId: + return "SubgroupLocalInvocationId"; + case BuiltIn::VertexIndex: + return "VertexIndex"; + case BuiltIn::InstanceIndex: + return "InstanceIndex"; + case BuiltIn::CoreIDARM: + return "CoreIDARM"; + case BuiltIn::CoreCountARM: + return "CoreCountARM"; + case BuiltIn::CoreMaxIDARM: + return "CoreMaxIDARM"; + case BuiltIn::WarpIDARM: + return "WarpIDARM"; + case BuiltIn::WarpMaxIDARM: + return "WarpMaxIDARM"; + case BuiltIn::SubgroupEqMask: + return "SubgroupEqMask"; + case BuiltIn::SubgroupGeMask: + return "SubgroupGeMask"; + case BuiltIn::SubgroupGtMask: + return "SubgroupGtMask"; + case BuiltIn::SubgroupLeMask: + return "SubgroupLeMask"; + case BuiltIn::SubgroupLtMask: + return "SubgroupLtMask"; + case BuiltIn::BaseVertex: + return "BaseVertex"; + case BuiltIn::BaseInstance: + return "BaseInstance"; + case BuiltIn::DrawIndex: + return "DrawIndex"; + case BuiltIn::PrimitiveShadingRateKHR: + return "PrimitiveShadingRateKHR"; + case BuiltIn::DeviceIndex: + return "DeviceIndex"; + case BuiltIn::ViewIndex: + return "ViewIndex"; + case BuiltIn::ShadingRateKHR: + return "ShadingRateKHR"; + case BuiltIn::TileOffsetQCOM: + return "TileOffsetQCOM"; + case BuiltIn::TileDimensionQCOM: + return "TileDimensionQCOM"; + case BuiltIn::TileApronSizeQCOM: + return "TileApronSizeQCOM"; + case BuiltIn::BaryCoordNoPerspAMD: + return "BaryCoordNoPerspAMD"; + case BuiltIn::BaryCoordNoPerspCentroidAMD: + return "BaryCoordNoPerspCentroidAMD"; + case BuiltIn::BaryCoordNoPerspSampleAMD: + return "BaryCoordNoPerspSampleAMD"; + case BuiltIn::BaryCoordSmoothAMD: + return "BaryCoordSmoothAMD"; + case BuiltIn::BaryCoordSmoothCentroidAMD: + return "BaryCoordSmoothCentroidAMD"; + case 
BuiltIn::BaryCoordSmoothSampleAMD: + return "BaryCoordSmoothSampleAMD"; + case BuiltIn::BaryCoordPullModelAMD: + return "BaryCoordPullModelAMD"; + case BuiltIn::FragStencilRefEXT: + return "FragStencilRefEXT"; + case BuiltIn::RemainingRecursionLevelsAMDX: + return "RemainingRecursionLevelsAMDX"; + case BuiltIn::ShaderIndexAMDX: + return "ShaderIndexAMDX"; + case BuiltIn::ViewportMaskNV: + return "ViewportMaskNV"; + case BuiltIn::SecondaryPositionNV: + return "SecondaryPositionNV"; + case BuiltIn::SecondaryViewportMaskNV: + return "SecondaryViewportMaskNV"; + case BuiltIn::PositionPerViewNV: + return "PositionPerViewNV"; + case BuiltIn::ViewportMaskPerViewNV: + return "ViewportMaskPerViewNV"; + case BuiltIn::FullyCoveredEXT: + return "FullyCoveredEXT"; + case BuiltIn::TaskCountNV: + return "TaskCountNV"; + case BuiltIn::PrimitiveCountNV: + return "PrimitiveCountNV"; + case BuiltIn::PrimitiveIndicesNV: + return "PrimitiveIndicesNV"; + case BuiltIn::ClipDistancePerViewNV: + return "ClipDistancePerViewNV"; + case BuiltIn::CullDistancePerViewNV: + return "CullDistancePerViewNV"; + case BuiltIn::LayerPerViewNV: + return "LayerPerViewNV"; + case BuiltIn::MeshViewCountNV: + return "MeshViewCountNV"; + case BuiltIn::MeshViewIndicesNV: + return "MeshViewIndicesNV"; + case BuiltIn::BaryCoordKHR: + return "BaryCoordKHR"; + case BuiltIn::BaryCoordNoPerspKHR: + return "BaryCoordNoPerspKHR"; + case BuiltIn::FragSizeEXT: + return "FragSizeEXT"; + case BuiltIn::FragInvocationCountEXT: + return "FragInvocationCountEXT"; + case BuiltIn::PrimitivePointIndicesEXT: + return "PrimitivePointIndicesEXT"; + case BuiltIn::PrimitiveLineIndicesEXT: + return "PrimitiveLineIndicesEXT"; + case BuiltIn::PrimitiveTriangleIndicesEXT: + return "PrimitiveTriangleIndicesEXT"; + case BuiltIn::CullPrimitiveEXT: + return "CullPrimitiveEXT"; + case BuiltIn::LaunchIdKHR: + return "LaunchIdKHR"; + case BuiltIn::LaunchSizeKHR: + return "LaunchSizeKHR"; + case BuiltIn::WorldRayOriginKHR: + return 
"WorldRayOriginKHR"; + case BuiltIn::WorldRayDirectionKHR: + return "WorldRayDirectionKHR"; + case BuiltIn::ObjectRayOriginKHR: + return "ObjectRayOriginKHR"; + case BuiltIn::ObjectRayDirectionKHR: + return "ObjectRayDirectionKHR"; + case BuiltIn::RayTminKHR: + return "RayTminKHR"; + case BuiltIn::RayTmaxKHR: + return "RayTmaxKHR"; + case BuiltIn::InstanceCustomIndexKHR: + return "InstanceCustomIndexKHR"; + case BuiltIn::ObjectToWorldKHR: + return "ObjectToWorldKHR"; + case BuiltIn::WorldToObjectKHR: + return "WorldToObjectKHR"; + case BuiltIn::HitTNV: + return "HitTNV"; + case BuiltIn::HitKindKHR: + return "HitKindKHR"; + case BuiltIn::CurrentRayTimeNV: + return "CurrentRayTimeNV"; + case BuiltIn::HitTriangleVertexPositionsKHR: + return "HitTriangleVertexPositionsKHR"; + case BuiltIn::HitMicroTriangleVertexPositionsNV: + return "HitMicroTriangleVertexPositionsNV"; + case BuiltIn::HitMicroTriangleVertexBarycentricsNV: + return "HitMicroTriangleVertexBarycentricsNV"; + case BuiltIn::IncomingRayFlagsKHR: + return "IncomingRayFlagsKHR"; + case BuiltIn::RayGeometryIndexKHR: + return "RayGeometryIndexKHR"; + case BuiltIn::HitIsSphereNV: + return "HitIsSphereNV"; + case BuiltIn::HitIsLSSNV: + return "HitIsLSSNV"; + case BuiltIn::HitSpherePositionNV: + return "HitSpherePositionNV"; + case BuiltIn::WarpsPerSMNV: + return "WarpsPerSMNV"; + case BuiltIn::SMCountNV: + return "SMCountNV"; + case BuiltIn::WarpIDNV: + return "WarpIDNV"; + case BuiltIn::SMIDNV: + return "SMIDNV"; + case BuiltIn::HitLSSPositionsNV: + return "HitLSSPositionsNV"; + case BuiltIn::HitKindFrontFacingMicroTriangleNV: + return "HitKindFrontFacingMicroTriangleNV"; + case BuiltIn::HitKindBackFacingMicroTriangleNV: + return "HitKindBackFacingMicroTriangleNV"; + case BuiltIn::HitSphereRadiusNV: + return "HitSphereRadiusNV"; + case BuiltIn::HitLSSRadiiNV: + return "HitLSSRadiiNV"; + case BuiltIn::ClusterIDNV: + return "ClusterIDNV"; + case BuiltIn::CullMaskKHR: + return "CullMaskKHR"; + } + return "unknown"; 
+} +auto to_string(Scope e) -> char const * { + switch (e) { + case Scope::CrossDevice: + return "CrossDevice"; + case Scope::Device: + return "Device"; + case Scope::Workgroup: + return "Workgroup"; + case Scope::Subgroup: + return "Subgroup"; + case Scope::Invocation: + return "Invocation"; + case Scope::QueueFamily: + return "QueueFamily"; + case Scope::ShaderCallKHR: + return "ShaderCallKHR"; + } + return "unknown"; +} +auto to_string(GroupOperation e) -> char const * { + switch (e) { + case GroupOperation::Reduce: + return "Reduce"; + case GroupOperation::InclusiveScan: + return "InclusiveScan"; + case GroupOperation::ExclusiveScan: + return "ExclusiveScan"; + case GroupOperation::ClusteredReduce: + return "ClusteredReduce"; + case GroupOperation::PartitionedReduceNV: + return "PartitionedReduceNV"; + case GroupOperation::PartitionedInclusiveScanNV: + return "PartitionedInclusiveScanNV"; + case GroupOperation::PartitionedExclusiveScanNV: + return "PartitionedExclusiveScanNV"; + } + return "unknown"; +} +auto to_string(KernelEnqueueFlags e) -> char const * { + switch (e) { + case KernelEnqueueFlags::NoWait: + return "NoWait"; + case KernelEnqueueFlags::WaitKernel: + return "WaitKernel"; + case KernelEnqueueFlags::WaitWorkGroup: + return "WaitWorkGroup"; + } + return "unknown"; +} +auto to_string(Capability e) -> char const * { + switch (e) { + case Capability::Matrix: + return "Matrix"; + case Capability::Shader: + return "Shader"; + case Capability::Geometry: + return "Geometry"; + case Capability::Tessellation: + return "Tessellation"; + case Capability::Addresses: + return "Addresses"; + case Capability::Linkage: + return "Linkage"; + case Capability::Kernel: + return "Kernel"; + case Capability::Vector16: + return "Vector16"; + case Capability::Float16Buffer: + return "Float16Buffer"; + case Capability::Float16: + return "Float16"; + case Capability::Float64: + return "Float64"; + case Capability::Int64: + return "Int64"; + case Capability::Int64Atomics: + 
return "Int64Atomics"; + case Capability::ImageBasic: + return "ImageBasic"; + case Capability::ImageReadWrite: + return "ImageReadWrite"; + case Capability::ImageMipmap: + return "ImageMipmap"; + case Capability::Pipes: + return "Pipes"; + case Capability::Groups: + return "Groups"; + case Capability::DeviceEnqueue: + return "DeviceEnqueue"; + case Capability::LiteralSampler: + return "LiteralSampler"; + case Capability::AtomicStorage: + return "AtomicStorage"; + case Capability::Int16: + return "Int16"; + case Capability::TessellationPointSize: + return "TessellationPointSize"; + case Capability::GeometryPointSize: + return "GeometryPointSize"; + case Capability::ImageGatherExtended: + return "ImageGatherExtended"; + case Capability::StorageImageMultisample: + return "StorageImageMultisample"; + case Capability::UniformBufferArrayDynamicIndexing: + return "UniformBufferArrayDynamicIndexing"; + case Capability::SampledImageArrayDynamicIndexing: + return "SampledImageArrayDynamicIndexing"; + case Capability::StorageBufferArrayDynamicIndexing: + return "StorageBufferArrayDynamicIndexing"; + case Capability::StorageImageArrayDynamicIndexing: + return "StorageImageArrayDynamicIndexing"; + case Capability::ClipDistance: + return "ClipDistance"; + case Capability::CullDistance: + return "CullDistance"; + case Capability::ImageCubeArray: + return "ImageCubeArray"; + case Capability::SampleRateShading: + return "SampleRateShading"; + case Capability::ImageRect: + return "ImageRect"; + case Capability::SampledRect: + return "SampledRect"; + case Capability::GenericPointer: + return "GenericPointer"; + case Capability::Int8: + return "Int8"; + case Capability::InputAttachment: + return "InputAttachment"; + case Capability::SparseResidency: + return "SparseResidency"; + case Capability::MinLod: + return "MinLod"; + case Capability::Sampled1D: + return "Sampled1D"; + case Capability::Image1D: + return "Image1D"; + case Capability::SampledCubeArray: + return 
"SampledCubeArray"; + case Capability::SampledBuffer: + return "SampledBuffer"; + case Capability::ImageBuffer: + return "ImageBuffer"; + case Capability::ImageMSArray: + return "ImageMSArray"; + case Capability::StorageImageExtendedFormats: + return "StorageImageExtendedFormats"; + case Capability::ImageQuery: + return "ImageQuery"; + case Capability::DerivativeControl: + return "DerivativeControl"; + case Capability::InterpolationFunction: + return "InterpolationFunction"; + case Capability::TransformFeedback: + return "TransformFeedback"; + case Capability::GeometryStreams: + return "GeometryStreams"; + case Capability::StorageImageReadWithoutFormat: + return "StorageImageReadWithoutFormat"; + case Capability::StorageImageWriteWithoutFormat: + return "StorageImageWriteWithoutFormat"; + case Capability::MultiViewport: + return "MultiViewport"; + case Capability::SubgroupDispatch: + return "SubgroupDispatch"; + case Capability::NamedBarrier: + return "NamedBarrier"; + case Capability::PipeStorage: + return "PipeStorage"; + case Capability::GroupNonUniform: + return "GroupNonUniform"; + case Capability::GroupNonUniformVote: + return "GroupNonUniformVote"; + case Capability::GroupNonUniformArithmetic: + return "GroupNonUniformArithmetic"; + case Capability::GroupNonUniformBallot: + return "GroupNonUniformBallot"; + case Capability::GroupNonUniformShuffle: + return "GroupNonUniformShuffle"; + case Capability::GroupNonUniformShuffleRelative: + return "GroupNonUniformShuffleRelative"; + case Capability::GroupNonUniformClustered: + return "GroupNonUniformClustered"; + case Capability::GroupNonUniformQuad: + return "GroupNonUniformQuad"; + case Capability::ShaderLayer: + return "ShaderLayer"; + case Capability::ShaderViewportIndex: + return "ShaderViewportIndex"; + case Capability::UniformDecoration: + return "UniformDecoration"; + case Capability::CoreBuiltinsARM: + return "CoreBuiltinsARM"; + case Capability::TileImageColorReadAccessEXT: + return 
"TileImageColorReadAccessEXT"; + case Capability::TileImageDepthReadAccessEXT: + return "TileImageDepthReadAccessEXT"; + case Capability::TileImageStencilReadAccessEXT: + return "TileImageStencilReadAccessEXT"; + case Capability::TensorsARM: + return "TensorsARM"; + case Capability::StorageTensorArrayDynamicIndexingARM: + return "StorageTensorArrayDynamicIndexingARM"; + case Capability::StorageTensorArrayNonUniformIndexingARM: + return "StorageTensorArrayNonUniformIndexingARM"; + case Capability::CooperativeMatrixLayoutsARM: + return "CooperativeMatrixLayoutsARM"; + case Capability::Float8EXT: + return "Float8EXT"; + case Capability::Float8CooperativeMatrixEXT: + return "Float8CooperativeMatrixEXT"; + case Capability::FragmentShadingRateKHR: + return "FragmentShadingRateKHR"; + case Capability::SubgroupBallotKHR: + return "SubgroupBallotKHR"; + case Capability::DrawParameters: + return "DrawParameters"; + case Capability::WorkgroupMemoryExplicitLayoutKHR: + return "WorkgroupMemoryExplicitLayoutKHR"; + case Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR: + return "WorkgroupMemoryExplicitLayout8BitAccessKHR"; + case Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR: + return "WorkgroupMemoryExplicitLayout16BitAccessKHR"; + case Capability::SubgroupVoteKHR: + return "SubgroupVoteKHR"; + case Capability::StorageBuffer16BitAccess: + return "StorageBuffer16BitAccess"; + case Capability::UniformAndStorageBuffer16BitAccess: + return "UniformAndStorageBuffer16BitAccess"; + case Capability::StoragePushConstant16: + return "StoragePushConstant16"; + case Capability::StorageInputOutput16: + return "StorageInputOutput16"; + case Capability::DeviceGroup: + return "DeviceGroup"; + case Capability::MultiView: + return "MultiView"; + case Capability::VariablePointersStorageBuffer: + return "VariablePointersStorageBuffer"; + case Capability::VariablePointers: + return "VariablePointers"; + case Capability::AtomicStorageOps: + return "AtomicStorageOps"; + case 
Capability::SampleMaskPostDepthCoverage: + return "SampleMaskPostDepthCoverage"; + case Capability::StorageBuffer8BitAccess: + return "StorageBuffer8BitAccess"; + case Capability::UniformAndStorageBuffer8BitAccess: + return "UniformAndStorageBuffer8BitAccess"; + case Capability::StoragePushConstant8: + return "StoragePushConstant8"; + case Capability::DenormPreserve: + return "DenormPreserve"; + case Capability::DenormFlushToZero: + return "DenormFlushToZero"; + case Capability::SignedZeroInfNanPreserve: + return "SignedZeroInfNanPreserve"; + case Capability::RoundingModeRTE: + return "RoundingModeRTE"; + case Capability::RoundingModeRTZ: + return "RoundingModeRTZ"; + case Capability::RayQueryProvisionalKHR: + return "RayQueryProvisionalKHR"; + case Capability::RayQueryKHR: + return "RayQueryKHR"; + case Capability::UntypedPointersKHR: + return "UntypedPointersKHR"; + case Capability::RayTraversalPrimitiveCullingKHR: + return "RayTraversalPrimitiveCullingKHR"; + case Capability::RayTracingKHR: + return "RayTracingKHR"; + case Capability::TextureSampleWeightedQCOM: + return "TextureSampleWeightedQCOM"; + case Capability::TextureBoxFilterQCOM: + return "TextureBoxFilterQCOM"; + case Capability::TextureBlockMatchQCOM: + return "TextureBlockMatchQCOM"; + case Capability::TileShadingQCOM: + return "TileShadingQCOM"; + case Capability::TextureBlockMatch2QCOM: + return "TextureBlockMatch2QCOM"; + case Capability::Float16ImageAMD: + return "Float16ImageAMD"; + case Capability::ImageGatherBiasLodAMD: + return "ImageGatherBiasLodAMD"; + case Capability::FragmentMaskAMD: + return "FragmentMaskAMD"; + case Capability::StencilExportEXT: + return "StencilExportEXT"; + case Capability::ImageReadWriteLodAMD: + return "ImageReadWriteLodAMD"; + case Capability::Int64ImageEXT: + return "Int64ImageEXT"; + case Capability::ShaderClockKHR: + return "ShaderClockKHR"; + case Capability::ShaderEnqueueAMDX: + return "ShaderEnqueueAMDX"; + case Capability::QuadControlKHR: + return 
"QuadControlKHR"; + case Capability::Int4TypeINTEL: + return "Int4TypeINTEL"; + case Capability::Int4CooperativeMatrixINTEL: + return "Int4CooperativeMatrixINTEL"; + case Capability::BFloat16TypeKHR: + return "BFloat16TypeKHR"; + case Capability::BFloat16DotProductKHR: + return "BFloat16DotProductKHR"; + case Capability::BFloat16CooperativeMatrixKHR: + return "BFloat16CooperativeMatrixKHR"; + case Capability::SampleMaskOverrideCoverageNV: + return "SampleMaskOverrideCoverageNV"; + case Capability::GeometryShaderPassthroughNV: + return "GeometryShaderPassthroughNV"; + case Capability::ShaderViewportIndexLayerEXT: + return "ShaderViewportIndexLayerEXT"; + case Capability::ShaderViewportMaskNV: + return "ShaderViewportMaskNV"; + case Capability::ShaderStereoViewNV: + return "ShaderStereoViewNV"; + case Capability::PerViewAttributesNV: + return "PerViewAttributesNV"; + case Capability::FragmentFullyCoveredEXT: + return "FragmentFullyCoveredEXT"; + case Capability::MeshShadingNV: + return "MeshShadingNV"; + case Capability::ImageFootprintNV: + return "ImageFootprintNV"; + case Capability::MeshShadingEXT: + return "MeshShadingEXT"; + case Capability::FragmentBarycentricKHR: + return "FragmentBarycentricKHR"; + case Capability::ComputeDerivativeGroupQuadsKHR: + return "ComputeDerivativeGroupQuadsKHR"; + case Capability::FragmentDensityEXT: + return "FragmentDensityEXT"; + case Capability::GroupNonUniformPartitionedNV: + return "GroupNonUniformPartitionedNV"; + case Capability::ShaderNonUniform: + return "ShaderNonUniform"; + case Capability::RuntimeDescriptorArray: + return "RuntimeDescriptorArray"; + case Capability::InputAttachmentArrayDynamicIndexing: + return "InputAttachmentArrayDynamicIndexing"; + case Capability::UniformTexelBufferArrayDynamicIndexing: + return "UniformTexelBufferArrayDynamicIndexing"; + case Capability::StorageTexelBufferArrayDynamicIndexing: + return "StorageTexelBufferArrayDynamicIndexing"; + case 
Capability::UniformBufferArrayNonUniformIndexing: + return "UniformBufferArrayNonUniformIndexing"; + case Capability::SampledImageArrayNonUniformIndexing: + return "SampledImageArrayNonUniformIndexing"; + case Capability::StorageBufferArrayNonUniformIndexing: + return "StorageBufferArrayNonUniformIndexing"; + case Capability::StorageImageArrayNonUniformIndexing: + return "StorageImageArrayNonUniformIndexing"; + case Capability::InputAttachmentArrayNonUniformIndexing: + return "InputAttachmentArrayNonUniformIndexing"; + case Capability::UniformTexelBufferArrayNonUniformIndexing: + return "UniformTexelBufferArrayNonUniformIndexing"; + case Capability::StorageTexelBufferArrayNonUniformIndexing: + return "StorageTexelBufferArrayNonUniformIndexing"; + case Capability::RayTracingPositionFetchKHR: + return "RayTracingPositionFetchKHR"; + case Capability::RayTracingNV: + return "RayTracingNV"; + case Capability::RayTracingMotionBlurNV: + return "RayTracingMotionBlurNV"; + case Capability::VulkanMemoryModel: + return "VulkanMemoryModel"; + case Capability::VulkanMemoryModelDeviceScope: + return "VulkanMemoryModelDeviceScope"; + case Capability::PhysicalStorageBufferAddresses: + return "PhysicalStorageBufferAddresses"; + case Capability::ComputeDerivativeGroupLinearKHR: + return "ComputeDerivativeGroupLinearKHR"; + case Capability::RayTracingProvisionalKHR: + return "RayTracingProvisionalKHR"; + case Capability::CooperativeMatrixNV: + return "CooperativeMatrixNV"; + case Capability::FragmentShaderSampleInterlockEXT: + return "FragmentShaderSampleInterlockEXT"; + case Capability::FragmentShaderShadingRateInterlockEXT: + return "FragmentShaderShadingRateInterlockEXT"; + case Capability::ShaderSMBuiltinsNV: + return "ShaderSMBuiltinsNV"; + case Capability::FragmentShaderPixelInterlockEXT: + return "FragmentShaderPixelInterlockEXT"; + case Capability::DemoteToHelperInvocation: + return "DemoteToHelperInvocation"; + case Capability::DisplacementMicromapNV: + return 
"DisplacementMicromapNV"; + case Capability::RayTracingOpacityMicromapEXT: + return "RayTracingOpacityMicromapEXT"; + case Capability::ShaderInvocationReorderNV: + return "ShaderInvocationReorderNV"; + case Capability::BindlessTextureNV: + return "BindlessTextureNV"; + case Capability::RayQueryPositionFetchKHR: + return "RayQueryPositionFetchKHR"; + case Capability::CooperativeVectorNV: + return "CooperativeVectorNV"; + case Capability::AtomicFloat16VectorNV: + return "AtomicFloat16VectorNV"; + case Capability::RayTracingDisplacementMicromapNV: + return "RayTracingDisplacementMicromapNV"; + case Capability::RawAccessChainsNV: + return "RawAccessChainsNV"; + case Capability::RayTracingSpheresGeometryNV: + return "RayTracingSpheresGeometryNV"; + case Capability::RayTracingLinearSweptSpheresGeometryNV: + return "RayTracingLinearSweptSpheresGeometryNV"; + case Capability::CooperativeMatrixReductionsNV: + return "CooperativeMatrixReductionsNV"; + case Capability::CooperativeMatrixConversionsNV: + return "CooperativeMatrixConversionsNV"; + case Capability::CooperativeMatrixPerElementOperationsNV: + return "CooperativeMatrixPerElementOperationsNV"; + case Capability::CooperativeMatrixTensorAddressingNV: + return "CooperativeMatrixTensorAddressingNV"; + case Capability::CooperativeMatrixBlockLoadsNV: + return "CooperativeMatrixBlockLoadsNV"; + case Capability::CooperativeVectorTrainingNV: + return "CooperativeVectorTrainingNV"; + case Capability::RayTracingClusterAccelerationStructureNV: + return "RayTracingClusterAccelerationStructureNV"; + case Capability::TensorAddressingNV: + return "TensorAddressingNV"; + case Capability::SubgroupShuffleINTEL: + return "SubgroupShuffleINTEL"; + case Capability::SubgroupBufferBlockIOINTEL: + return "SubgroupBufferBlockIOINTEL"; + case Capability::SubgroupImageBlockIOINTEL: + return "SubgroupImageBlockIOINTEL"; + case Capability::SubgroupImageMediaBlockIOINTEL: + return "SubgroupImageMediaBlockIOINTEL"; + case 
Capability::RoundToInfinityINTEL: + return "RoundToInfinityINTEL"; + case Capability::FloatingPointModeINTEL: + return "FloatingPointModeINTEL"; + case Capability::IntegerFunctions2INTEL: + return "IntegerFunctions2INTEL"; + case Capability::FunctionPointersINTEL: + return "FunctionPointersINTEL"; + case Capability::IndirectReferencesINTEL: + return "IndirectReferencesINTEL"; + case Capability::AsmINTEL: + return "AsmINTEL"; + case Capability::AtomicFloat32MinMaxEXT: + return "AtomicFloat32MinMaxEXT"; + case Capability::AtomicFloat64MinMaxEXT: + return "AtomicFloat64MinMaxEXT"; + case Capability::AtomicFloat16MinMaxEXT: + return "AtomicFloat16MinMaxEXT"; + case Capability::VectorComputeINTEL: + return "VectorComputeINTEL"; + case Capability::VectorAnyINTEL: + return "VectorAnyINTEL"; + case Capability::ExpectAssumeKHR: + return "ExpectAssumeKHR"; + case Capability::SubgroupAvcMotionEstimationINTEL: + return "SubgroupAvcMotionEstimationINTEL"; + case Capability::SubgroupAvcMotionEstimationIntraINTEL: + return "SubgroupAvcMotionEstimationIntraINTEL"; + case Capability::SubgroupAvcMotionEstimationChromaINTEL: + return "SubgroupAvcMotionEstimationChromaINTEL"; + case Capability::VariableLengthArrayINTEL: + return "VariableLengthArrayINTEL"; + case Capability::FunctionFloatControlINTEL: + return "FunctionFloatControlINTEL"; + case Capability::FPGAMemoryAttributesINTEL: + return "FPGAMemoryAttributesINTEL"; + case Capability::FPFastMathModeINTEL: + return "FPFastMathModeINTEL"; + case Capability::ArbitraryPrecisionIntegersINTEL: + return "ArbitraryPrecisionIntegersINTEL"; + case Capability::ArbitraryPrecisionFloatingPointINTEL: + return "ArbitraryPrecisionFloatingPointINTEL"; + case Capability::UnstructuredLoopControlsINTEL: + return "UnstructuredLoopControlsINTEL"; + case Capability::FPGALoopControlsINTEL: + return "FPGALoopControlsINTEL"; + case Capability::KernelAttributesINTEL: + return "KernelAttributesINTEL"; + case Capability::FPGAKernelAttributesINTEL: + return 
"FPGAKernelAttributesINTEL"; + case Capability::FPGAMemoryAccessesINTEL: + return "FPGAMemoryAccessesINTEL"; + case Capability::FPGAClusterAttributesINTEL: + return "FPGAClusterAttributesINTEL"; + case Capability::LoopFuseINTEL: + return "LoopFuseINTEL"; + case Capability::FPGADSPControlINTEL: + return "FPGADSPControlINTEL"; + case Capability::MemoryAccessAliasingINTEL: + return "MemoryAccessAliasingINTEL"; + case Capability::FPGAInvocationPipeliningAttributesINTEL: + return "FPGAInvocationPipeliningAttributesINTEL"; + case Capability::FPGABufferLocationINTEL: + return "FPGABufferLocationINTEL"; + case Capability::ArbitraryPrecisionFixedPointINTEL: + return "ArbitraryPrecisionFixedPointINTEL"; + case Capability::USMStorageClassesINTEL: + return "USMStorageClassesINTEL"; + case Capability::RuntimeAlignedAttributeINTEL: + return "RuntimeAlignedAttributeINTEL"; + case Capability::IOPipesINTEL: + return "IOPipesINTEL"; + case Capability::BlockingPipesINTEL: + return "BlockingPipesINTEL"; + case Capability::FPGARegINTEL: + return "FPGARegINTEL"; + case Capability::DotProductInputAll: + return "DotProductInputAll"; + case Capability::DotProductInput4x8Bit: + return "DotProductInput4x8Bit"; + case Capability::DotProductInput4x8BitPacked: + return "DotProductInput4x8BitPacked"; + case Capability::DotProduct: + return "DotProduct"; + case Capability::RayCullMaskKHR: + return "RayCullMaskKHR"; + case Capability::CooperativeMatrixKHR: + return "CooperativeMatrixKHR"; + case Capability::ReplicatedCompositesEXT: + return "ReplicatedCompositesEXT"; + case Capability::BitInstructions: + return "BitInstructions"; + case Capability::GroupNonUniformRotateKHR: + return "GroupNonUniformRotateKHR"; + case Capability::FloatControls2: + return "FloatControls2"; + case Capability::AtomicFloat32AddEXT: + return "AtomicFloat32AddEXT"; + case Capability::AtomicFloat64AddEXT: + return "AtomicFloat64AddEXT"; + case Capability::LongCompositesINTEL: + return "LongCompositesINTEL"; + case 
Capability::OptNoneEXT: + return "OptNoneEXT"; + case Capability::AtomicFloat16AddEXT: + return "AtomicFloat16AddEXT"; + case Capability::DebugInfoModuleINTEL: + return "DebugInfoModuleINTEL"; + case Capability::BFloat16ConversionINTEL: + return "BFloat16ConversionINTEL"; + case Capability::SplitBarrierINTEL: + return "SplitBarrierINTEL"; + case Capability::ArithmeticFenceEXT: + return "ArithmeticFenceEXT"; + case Capability::FPGAClusterAttributesV2INTEL: + return "FPGAClusterAttributesV2INTEL"; + case Capability::FPGAKernelAttributesv2INTEL: + return "FPGAKernelAttributesv2INTEL"; + case Capability::TaskSequenceINTEL: + return "TaskSequenceINTEL"; + case Capability::FPMaxErrorINTEL: + return "FPMaxErrorINTEL"; + case Capability::FPGALatencyControlINTEL: + return "FPGALatencyControlINTEL"; + case Capability::FPGAArgumentInterfacesINTEL: + return "FPGAArgumentInterfacesINTEL"; + case Capability::GlobalVariableHostAccessINTEL: + return "GlobalVariableHostAccessINTEL"; + case Capability::GlobalVariableFPGADecorationsINTEL: + return "GlobalVariableFPGADecorationsINTEL"; + case Capability::SubgroupBufferPrefetchINTEL: + return "SubgroupBufferPrefetchINTEL"; + case Capability::Subgroup2DBlockIOINTEL: + return "Subgroup2DBlockIOINTEL"; + case Capability::Subgroup2DBlockTransformINTEL: + return "Subgroup2DBlockTransformINTEL"; + case Capability::Subgroup2DBlockTransposeINTEL: + return "Subgroup2DBlockTransposeINTEL"; + case Capability::SubgroupMatrixMultiplyAccumulateINTEL: + return "SubgroupMatrixMultiplyAccumulateINTEL"; + case Capability::TernaryBitwiseFunctionINTEL: + return "TernaryBitwiseFunctionINTEL"; + case Capability::GroupUniformArithmeticKHR: + return "GroupUniformArithmeticKHR"; + case Capability::TensorFloat32RoundingINTEL: + return "TensorFloat32RoundingINTEL"; + case Capability::MaskedGatherScatterINTEL: + return "MaskedGatherScatterINTEL"; + case Capability::CacheControlsINTEL: + return "CacheControlsINTEL"; + case Capability::RegisterLimitsINTEL: + return 
"RegisterLimitsINTEL"; + case Capability::BindlessImagesINTEL: + return "BindlessImagesINTEL"; + case Capability::PackedCooperativeMatrixINTEL: + return "PackedCooperativeMatrixINTEL"; + case Capability::CooperativeMatrixInvocationInstructionsINTEL: + return "CooperativeMatrixInvocationInstructionsINTEL"; + case Capability::CooperativeMatrixTF32ComponentTypeINTEL: + return "CooperativeMatrixTF32ComponentTypeINTEL"; + case Capability::CooperativeMatrixBFloat16ComponentTypeINTEL: + return "CooperativeMatrixBFloat16ComponentTypeINTEL"; + case Capability::CooperativeMatrixCheckedInstructionsINTEL: + return "CooperativeMatrixCheckedInstructionsINTEL"; + case Capability::CooperativeMatrixPrefetchINTEL: + return "CooperativeMatrixPrefetchINTEL"; + } + return "unknown"; +} +auto to_string(RayQueryIntersection e) -> char const * { + switch (e) { + case RayQueryIntersection::RayQueryCandidateIntersectionKHR: + return "RayQueryCandidateIntersectionKHR"; + case RayQueryIntersection::RayQueryCommittedIntersectionKHR: + return "RayQueryCommittedIntersectionKHR"; + } + return "unknown"; +} +auto to_string(RayQueryCommittedIntersectionType e) -> char const * { + switch (e) { + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionNoneKHR: + return "RayQueryCommittedIntersectionNoneKHR"; + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionTriangleKHR: + return "RayQueryCommittedIntersectionTriangleKHR"; + case RayQueryCommittedIntersectionType::RayQueryCommittedIntersectionGeneratedKHR: + return "RayQueryCommittedIntersectionGeneratedKHR"; + } + return "unknown"; +} +auto to_string(RayQueryCandidateIntersectionType e) -> char const * { + switch (e) { + case RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR: + return "RayQueryCandidateIntersectionTriangleKHR"; + case RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionAABBKHR: + return "RayQueryCandidateIntersectionAABBKHR"; + } + return "unknown"; +} +auto 
to_string(PackedVectorFormat e) -> char const * { + switch (e) { + case PackedVectorFormat::PackedVectorFormat4x8Bit: + return "PackedVectorFormat4x8Bit"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixOperands e) -> char const * { + switch (e) { + case CooperativeMatrixOperands::NoneKHR: + return "NoneKHR"; + case CooperativeMatrixOperands::MatrixASignedComponentsKHR: + return "MatrixASignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixBSignedComponentsKHR: + return "MatrixBSignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixCSignedComponentsKHR: + return "MatrixCSignedComponentsKHR"; + case CooperativeMatrixOperands::MatrixResultSignedComponentsKHR: + return "MatrixResultSignedComponentsKHR"; + case CooperativeMatrixOperands::SaturatingAccumulationKHR: + return "SaturatingAccumulationKHR"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixLayout e) -> char const * { + switch (e) { + case CooperativeMatrixLayout::RowMajorKHR: + return "RowMajorKHR"; + case CooperativeMatrixLayout::ColumnMajorKHR: + return "ColumnMajorKHR"; + case CooperativeMatrixLayout::RowBlockedInterleavedARM: + return "RowBlockedInterleavedARM"; + case CooperativeMatrixLayout::ColumnBlockedInterleavedARM: + return "ColumnBlockedInterleavedARM"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixUse e) -> char const * { + switch (e) { + case CooperativeMatrixUse::MatrixAKHR: + return "MatrixAKHR"; + case CooperativeMatrixUse::MatrixBKHR: + return "MatrixBKHR"; + case CooperativeMatrixUse::MatrixAccumulatorKHR: + return "MatrixAccumulatorKHR"; + } + return "unknown"; +} +auto to_string(CooperativeMatrixReduce e) -> char const * { + switch (e) { + case CooperativeMatrixReduce::Row: + return "Row"; + case CooperativeMatrixReduce::Column: + return "Column"; + case CooperativeMatrixReduce::CooperativeMatrixReduce2x2: + return "CooperativeMatrixReduce2x2"; + } + return "unknown"; +} +auto to_string(TensorClampMode e) -> char const * { + switch (e) { 
+ case TensorClampMode::Undefined: + return "Undefined"; + case TensorClampMode::Constant: + return "Constant"; + case TensorClampMode::ClampToEdge: + return "ClampToEdge"; + case TensorClampMode::Repeat: + return "Repeat"; + case TensorClampMode::RepeatMirrored: + return "RepeatMirrored"; + } + return "unknown"; +} +auto to_string(TensorAddressingOperands e) -> char const * { + switch (e) { + case TensorAddressingOperands::None: + return "None"; + case TensorAddressingOperands::TensorView: + return "TensorView"; + case TensorAddressingOperands::DecodeFunc: + return "DecodeFunc"; + } + return "unknown"; +} +auto to_string(InitializationModeQualifier e) -> char const * { + switch (e) { + case InitializationModeQualifier::InitOnDeviceReprogramINTEL: + return "InitOnDeviceReprogramINTEL"; + case InitializationModeQualifier::InitOnDeviceResetINTEL: + return "InitOnDeviceResetINTEL"; + } + return "unknown"; +} +auto to_string(LoadCacheControl e) -> char const * { + switch (e) { + case LoadCacheControl::UncachedINTEL: + return "UncachedINTEL"; + case LoadCacheControl::CachedINTEL: + return "CachedINTEL"; + case LoadCacheControl::StreamingINTEL: + return "StreamingINTEL"; + case LoadCacheControl::InvalidateAfterReadINTEL: + return "InvalidateAfterReadINTEL"; + case LoadCacheControl::ConstCachedINTEL: + return "ConstCachedINTEL"; + } + return "unknown"; +} +auto to_string(StoreCacheControl e) -> char const * { + switch (e) { + case StoreCacheControl::UncachedINTEL: + return "UncachedINTEL"; + case StoreCacheControl::WriteThroughINTEL: + return "WriteThroughINTEL"; + case StoreCacheControl::WriteBackINTEL: + return "WriteBackINTEL"; + case StoreCacheControl::StreamingINTEL: + return "StreamingINTEL"; + } + return "unknown"; +} +auto to_string(NamedMaximumNumberOfRegisters e) -> char const * { + switch (e) { + case NamedMaximumNumberOfRegisters::AutoINTEL: + return "AutoINTEL"; + } + return "unknown"; +} +auto to_string(MatrixMultiplyAccumulateOperands e) -> char const * { + 
switch (e) { + case MatrixMultiplyAccumulateOperands::None: + return "None"; + case MatrixMultiplyAccumulateOperands::MatrixASignedComponentsINTEL: + return "MatrixASignedComponentsINTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBSignedComponentsINTEL: + return "MatrixBSignedComponentsINTEL"; + case MatrixMultiplyAccumulateOperands::MatrixCBFloat16INTEL: + return "MatrixCBFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixResultBFloat16INTEL: + return "MatrixResultBFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedInt8INTEL: + return "MatrixAPackedInt8INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedInt8INTEL: + return "MatrixBPackedInt8INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedInt4INTEL: + return "MatrixAPackedInt4INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedInt4INTEL: + return "MatrixBPackedInt4INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixATF32INTEL: + return "MatrixATF32INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBTF32INTEL: + return "MatrixBTF32INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedFloat16INTEL: + return "MatrixAPackedFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedFloat16INTEL: + return "MatrixBPackedFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixAPackedBFloat16INTEL: + return "MatrixAPackedBFloat16INTEL"; + case MatrixMultiplyAccumulateOperands::MatrixBPackedBFloat16INTEL: + return "MatrixBPackedBFloat16INTEL"; + } + return "unknown"; +} +auto to_string(FPEncoding e) -> char const * { + switch (e) { + case FPEncoding::BFloat16KHR: + return "BFloat16KHR"; + case FPEncoding::Float8E4M3EXT: + return "Float8E4M3EXT"; + case FPEncoding::Float8E5M2EXT: + return "Float8E5M2EXT"; + } + return "unknown"; +} +auto to_string(CooperativeVectorMatrixLayout e) -> char const * { + switch (e) { + case CooperativeVectorMatrixLayout::RowMajorNV: + return "RowMajorNV"; + case 
CooperativeVectorMatrixLayout::ColumnMajorNV: + return "ColumnMajorNV"; + case CooperativeVectorMatrixLayout::InferencingOptimalNV: + return "InferencingOptimalNV"; + case CooperativeVectorMatrixLayout::TrainingOptimalNV: + return "TrainingOptimalNV"; + } + return "unknown"; +} +auto to_string(ComponentType e) -> char const * { + switch (e) { + case ComponentType::Float16NV: + return "Float16NV"; + case ComponentType::Float32NV: + return "Float32NV"; + case ComponentType::Float64NV: + return "Float64NV"; + case ComponentType::SignedInt8NV: + return "SignedInt8NV"; + case ComponentType::SignedInt16NV: + return "SignedInt16NV"; + case ComponentType::SignedInt32NV: + return "SignedInt32NV"; + case ComponentType::SignedInt64NV: + return "SignedInt64NV"; + case ComponentType::UnsignedInt8NV: + return "UnsignedInt8NV"; + case ComponentType::UnsignedInt16NV: + return "UnsignedInt16NV"; + case ComponentType::UnsignedInt32NV: + return "UnsignedInt32NV"; + case ComponentType::UnsignedInt64NV: + return "UnsignedInt64NV"; + case ComponentType::SignedInt8PackedNV: + return "SignedInt8PackedNV"; + case ComponentType::UnsignedInt8PackedNV: + return "UnsignedInt8PackedNV"; + case ComponentType::FloatE4M3NV: + return "FloatE4M3NV"; + case ComponentType::FloatE5M2NV: + return "FloatE5M2NV"; + } + return "unknown"; +} +auto to_string(TensorOperands e) -> char const * { + switch (e) { + case TensorOperands::NoneARM: + return "NoneARM"; + case TensorOperands::NontemporalARM: + return "NontemporalARM"; + case TensorOperands::OutOfBoundsValueARM: + return "OutOfBoundsValueARM"; + case TensorOperands::MakeElementAvailableARM: + return "MakeElementAvailableARM"; + case TensorOperands::MakeElementVisibleARM: + return "MakeElementVisibleARM"; + case TensorOperands::NonPrivateElementARM: + return "NonPrivateElementARM"; + } + return "unknown"; +} + +} // namespace tinytc::spv diff --git a/src/spv/names.hpp b/src/spv/names.hpp new file mode 100644 index 00000000..fb5b463d --- /dev/null +++ 
b/src/spv/names.hpp @@ -0,0 +1,131 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_NAMES_20250630_HPP +#define GENERATED_NAMES_20250630_HPP + +namespace tinytc::spv { + +enum class Op; +auto to_string(Op op) -> char const *; +enum class ImageOperands; +auto to_string(ImageOperands e) -> char const *; +enum class FPFastMathMode; +auto to_string(FPFastMathMode e) -> char const *; +enum class SelectionControl; +auto to_string(SelectionControl e) -> char const *; +enum class LoopControl; +auto to_string(LoopControl e) -> char const *; +enum class FunctionControl; +auto to_string(FunctionControl e) -> char const *; +enum class MemorySemantics; +auto to_string(MemorySemantics e) -> char const *; +enum class MemoryAccess; +auto to_string(MemoryAccess e) -> char const *; +enum class KernelProfilingInfo; +auto to_string(KernelProfilingInfo e) -> char const *; +enum class RayFlags; +auto to_string(RayFlags e) -> char const *; +enum class FragmentShadingRate; +auto to_string(FragmentShadingRate e) -> char const *; +enum class RawAccessChainOperands; +auto to_string(RawAccessChainOperands e) -> char const *; +enum class SourceLanguage; +auto to_string(SourceLanguage e) -> char const *; +enum class ExecutionModel; +auto to_string(ExecutionModel e) -> char const *; +enum class AddressingModel; +auto to_string(AddressingModel e) -> char const *; +enum class MemoryModel; +auto to_string(MemoryModel e) -> char const *; +enum class ExecutionMode; +auto to_string(ExecutionMode e) -> char const *; +enum class StorageClass; +auto to_string(StorageClass e) -> char const *; +enum class Dim; +auto to_string(Dim e) -> char const *; +enum class SamplerAddressingMode; +auto to_string(SamplerAddressingMode e) -> char const *; +enum class SamplerFilterMode; +auto to_string(SamplerFilterMode e) -> char const *; +enum class ImageFormat; +auto to_string(ImageFormat e) -> char const *; 
+enum class ImageChannelOrder; +auto to_string(ImageChannelOrder e) -> char const *; +enum class ImageChannelDataType; +auto to_string(ImageChannelDataType e) -> char const *; +enum class FPRoundingMode; +auto to_string(FPRoundingMode e) -> char const *; +enum class FPDenormMode; +auto to_string(FPDenormMode e) -> char const *; +enum class QuantizationModes; +auto to_string(QuantizationModes e) -> char const *; +enum class FPOperationMode; +auto to_string(FPOperationMode e) -> char const *; +enum class OverflowModes; +auto to_string(OverflowModes e) -> char const *; +enum class LinkageType; +auto to_string(LinkageType e) -> char const *; +enum class AccessQualifier; +auto to_string(AccessQualifier e) -> char const *; +enum class HostAccessQualifier; +auto to_string(HostAccessQualifier e) -> char const *; +enum class FunctionParameterAttribute; +auto to_string(FunctionParameterAttribute e) -> char const *; +enum class Decoration; +auto to_string(Decoration e) -> char const *; +enum class BuiltIn; +auto to_string(BuiltIn e) -> char const *; +enum class Scope; +auto to_string(Scope e) -> char const *; +enum class GroupOperation; +auto to_string(GroupOperation e) -> char const *; +enum class KernelEnqueueFlags; +auto to_string(KernelEnqueueFlags e) -> char const *; +enum class Capability; +auto to_string(Capability e) -> char const *; +enum class RayQueryIntersection; +auto to_string(RayQueryIntersection e) -> char const *; +enum class RayQueryCommittedIntersectionType; +auto to_string(RayQueryCommittedIntersectionType e) -> char const *; +enum class RayQueryCandidateIntersectionType; +auto to_string(RayQueryCandidateIntersectionType e) -> char const *; +enum class PackedVectorFormat; +auto to_string(PackedVectorFormat e) -> char const *; +enum class CooperativeMatrixOperands; +auto to_string(CooperativeMatrixOperands e) -> char const *; +enum class CooperativeMatrixLayout; +auto to_string(CooperativeMatrixLayout e) -> char const *; +enum class CooperativeMatrixUse; 
+auto to_string(CooperativeMatrixUse e) -> char const *; +enum class CooperativeMatrixReduce; +auto to_string(CooperativeMatrixReduce e) -> char const *; +enum class TensorClampMode; +auto to_string(TensorClampMode e) -> char const *; +enum class TensorAddressingOperands; +auto to_string(TensorAddressingOperands e) -> char const *; +enum class InitializationModeQualifier; +auto to_string(InitializationModeQualifier e) -> char const *; +enum class LoadCacheControl; +auto to_string(LoadCacheControl e) -> char const *; +enum class StoreCacheControl; +auto to_string(StoreCacheControl e) -> char const *; +enum class NamedMaximumNumberOfRegisters; +auto to_string(NamedMaximumNumberOfRegisters e) -> char const *; +enum class MatrixMultiplyAccumulateOperands; +auto to_string(MatrixMultiplyAccumulateOperands e) -> char const *; +enum class FPEncoding; +auto to_string(FPEncoding e) -> char const *; +enum class CooperativeVectorMatrixLayout; +auto to_string(CooperativeVectorMatrixLayout e) -> char const *; +enum class ComponentType; +auto to_string(ComponentType e) -> char const *; +enum class TensorOperands; +auto to_string(TensorOperands e) -> char const *; + +} // namespace tinytc::spv + +#endif // GENERATED_NAMES_20250630_HPP diff --git a/src/spv/opencl.std.cpp b/src/spv/opencl.std.cpp new file mode 100644 index 00000000..1284448b --- /dev/null +++ b/src/spv/opencl.std.cpp @@ -0,0 +1,341 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#include "opencl.std.hpp" + +namespace tinytc::spv { + +auto to_string(OpenCLEntrypoint ep) -> char const * { + switch (ep) { + case OpenCLEntrypoint::acos: + return "acos"; + case OpenCLEntrypoint::acosh: + return "acosh"; + case OpenCLEntrypoint::acospi: + return "acospi"; + case OpenCLEntrypoint::asin: + return "asin"; + case OpenCLEntrypoint::asinh: + return "asinh"; + case OpenCLEntrypoint::asinpi: + return "asinpi"; + case 
OpenCLEntrypoint::atan: + return "atan"; + case OpenCLEntrypoint::atan2: + return "atan2"; + case OpenCLEntrypoint::atanh: + return "atanh"; + case OpenCLEntrypoint::atanpi: + return "atanpi"; + case OpenCLEntrypoint::atan2pi: + return "atan2pi"; + case OpenCLEntrypoint::cbrt: + return "cbrt"; + case OpenCLEntrypoint::ceil: + return "ceil"; + case OpenCLEntrypoint::copysign: + return "copysign"; + case OpenCLEntrypoint::cos: + return "cos"; + case OpenCLEntrypoint::cosh: + return "cosh"; + case OpenCLEntrypoint::cospi: + return "cospi"; + case OpenCLEntrypoint::erfc: + return "erfc"; + case OpenCLEntrypoint::erf: + return "erf"; + case OpenCLEntrypoint::exp: + return "exp"; + case OpenCLEntrypoint::exp2: + return "exp2"; + case OpenCLEntrypoint::exp10: + return "exp10"; + case OpenCLEntrypoint::expm1: + return "expm1"; + case OpenCLEntrypoint::fabs: + return "fabs"; + case OpenCLEntrypoint::fdim: + return "fdim"; + case OpenCLEntrypoint::floor: + return "floor"; + case OpenCLEntrypoint::fma: + return "fma"; + case OpenCLEntrypoint::fmax: + return "fmax"; + case OpenCLEntrypoint::fmin: + return "fmin"; + case OpenCLEntrypoint::fmod: + return "fmod"; + case OpenCLEntrypoint::fract: + return "fract"; + case OpenCLEntrypoint::frexp: + return "frexp"; + case OpenCLEntrypoint::hypot: + return "hypot"; + case OpenCLEntrypoint::ilogb: + return "ilogb"; + case OpenCLEntrypoint::ldexp: + return "ldexp"; + case OpenCLEntrypoint::lgamma: + return "lgamma"; + case OpenCLEntrypoint::lgamma_r: + return "lgamma_r"; + case OpenCLEntrypoint::log: + return "log"; + case OpenCLEntrypoint::log2: + return "log2"; + case OpenCLEntrypoint::log10: + return "log10"; + case OpenCLEntrypoint::log1p: + return "log1p"; + case OpenCLEntrypoint::logb: + return "logb"; + case OpenCLEntrypoint::mad: + return "mad"; + case OpenCLEntrypoint::maxmag: + return "maxmag"; + case OpenCLEntrypoint::minmag: + return "minmag"; + case OpenCLEntrypoint::modf: + return "modf"; + case OpenCLEntrypoint::nan: + 
return "nan"; + case OpenCLEntrypoint::nextafter: + return "nextafter"; + case OpenCLEntrypoint::pow: + return "pow"; + case OpenCLEntrypoint::pown: + return "pown"; + case OpenCLEntrypoint::powr: + return "powr"; + case OpenCLEntrypoint::remainder: + return "remainder"; + case OpenCLEntrypoint::remquo: + return "remquo"; + case OpenCLEntrypoint::rint: + return "rint"; + case OpenCLEntrypoint::rootn: + return "rootn"; + case OpenCLEntrypoint::round: + return "round"; + case OpenCLEntrypoint::rsqrt: + return "rsqrt"; + case OpenCLEntrypoint::sin: + return "sin"; + case OpenCLEntrypoint::sincos: + return "sincos"; + case OpenCLEntrypoint::sinh: + return "sinh"; + case OpenCLEntrypoint::sinpi: + return "sinpi"; + case OpenCLEntrypoint::sqrt: + return "sqrt"; + case OpenCLEntrypoint::tan: + return "tan"; + case OpenCLEntrypoint::tanh: + return "tanh"; + case OpenCLEntrypoint::tanpi: + return "tanpi"; + case OpenCLEntrypoint::tgamma: + return "tgamma"; + case OpenCLEntrypoint::trunc: + return "trunc"; + case OpenCLEntrypoint::half_cos: + return "half_cos"; + case OpenCLEntrypoint::half_divide: + return "half_divide"; + case OpenCLEntrypoint::half_exp: + return "half_exp"; + case OpenCLEntrypoint::half_exp2: + return "half_exp2"; + case OpenCLEntrypoint::half_exp10: + return "half_exp10"; + case OpenCLEntrypoint::half_log: + return "half_log"; + case OpenCLEntrypoint::half_log2: + return "half_log2"; + case OpenCLEntrypoint::half_log10: + return "half_log10"; + case OpenCLEntrypoint::half_powr: + return "half_powr"; + case OpenCLEntrypoint::half_recip: + return "half_recip"; + case OpenCLEntrypoint::half_rsqrt: + return "half_rsqrt"; + case OpenCLEntrypoint::half_sin: + return "half_sin"; + case OpenCLEntrypoint::half_sqrt: + return "half_sqrt"; + case OpenCLEntrypoint::half_tan: + return "half_tan"; + case OpenCLEntrypoint::native_cos: + return "native_cos"; + case OpenCLEntrypoint::native_divide: + return "native_divide"; + case OpenCLEntrypoint::native_exp: + return 
"native_exp"; + case OpenCLEntrypoint::native_exp2: + return "native_exp2"; + case OpenCLEntrypoint::native_exp10: + return "native_exp10"; + case OpenCLEntrypoint::native_log: + return "native_log"; + case OpenCLEntrypoint::native_log2: + return "native_log2"; + case OpenCLEntrypoint::native_log10: + return "native_log10"; + case OpenCLEntrypoint::native_powr: + return "native_powr"; + case OpenCLEntrypoint::native_recip: + return "native_recip"; + case OpenCLEntrypoint::native_rsqrt: + return "native_rsqrt"; + case OpenCLEntrypoint::native_sin: + return "native_sin"; + case OpenCLEntrypoint::native_sqrt: + return "native_sqrt"; + case OpenCLEntrypoint::native_tan: + return "native_tan"; + case OpenCLEntrypoint::s_abs: + return "s_abs"; + case OpenCLEntrypoint::s_abs_diff: + return "s_abs_diff"; + case OpenCLEntrypoint::s_add_sat: + return "s_add_sat"; + case OpenCLEntrypoint::u_add_sat: + return "u_add_sat"; + case OpenCLEntrypoint::s_hadd: + return "s_hadd"; + case OpenCLEntrypoint::u_hadd: + return "u_hadd"; + case OpenCLEntrypoint::s_rhadd: + return "s_rhadd"; + case OpenCLEntrypoint::u_rhadd: + return "u_rhadd"; + case OpenCLEntrypoint::s_clamp: + return "s_clamp"; + case OpenCLEntrypoint::u_clamp: + return "u_clamp"; + case OpenCLEntrypoint::clz: + return "clz"; + case OpenCLEntrypoint::ctz: + return "ctz"; + case OpenCLEntrypoint::s_mad_hi: + return "s_mad_hi"; + case OpenCLEntrypoint::u_mad_sat: + return "u_mad_sat"; + case OpenCLEntrypoint::s_mad_sat: + return "s_mad_sat"; + case OpenCLEntrypoint::s_max: + return "s_max"; + case OpenCLEntrypoint::u_max: + return "u_max"; + case OpenCLEntrypoint::s_min: + return "s_min"; + case OpenCLEntrypoint::u_min: + return "u_min"; + case OpenCLEntrypoint::s_mul_hi: + return "s_mul_hi"; + case OpenCLEntrypoint::rotate: + return "rotate"; + case OpenCLEntrypoint::s_sub_sat: + return "s_sub_sat"; + case OpenCLEntrypoint::u_sub_sat: + return "u_sub_sat"; + case OpenCLEntrypoint::u_upsample: + return "u_upsample"; + case 
OpenCLEntrypoint::s_upsample: + return "s_upsample"; + case OpenCLEntrypoint::popcount: + return "popcount"; + case OpenCLEntrypoint::s_mad24: + return "s_mad24"; + case OpenCLEntrypoint::u_mad24: + return "u_mad24"; + case OpenCLEntrypoint::s_mul24: + return "s_mul24"; + case OpenCLEntrypoint::u_mul24: + return "u_mul24"; + case OpenCLEntrypoint::u_abs: + return "u_abs"; + case OpenCLEntrypoint::u_abs_diff: + return "u_abs_diff"; + case OpenCLEntrypoint::u_mul_hi: + return "u_mul_hi"; + case OpenCLEntrypoint::u_mad_hi: + return "u_mad_hi"; + case OpenCLEntrypoint::fclamp: + return "fclamp"; + case OpenCLEntrypoint::degrees: + return "degrees"; + case OpenCLEntrypoint::fmax_common: + return "fmax_common"; + case OpenCLEntrypoint::fmin_common: + return "fmin_common"; + case OpenCLEntrypoint::mix: + return "mix"; + case OpenCLEntrypoint::radians: + return "radians"; + case OpenCLEntrypoint::step: + return "step"; + case OpenCLEntrypoint::smoothstep: + return "smoothstep"; + case OpenCLEntrypoint::sign: + return "sign"; + case OpenCLEntrypoint::cross: + return "cross"; + case OpenCLEntrypoint::distance: + return "distance"; + case OpenCLEntrypoint::length: + return "length"; + case OpenCLEntrypoint::normalize: + return "normalize"; + case OpenCLEntrypoint::fast_distance: + return "fast_distance"; + case OpenCLEntrypoint::fast_length: + return "fast_length"; + case OpenCLEntrypoint::fast_normalize: + return "fast_normalize"; + case OpenCLEntrypoint::bitselect: + return "bitselect"; + case OpenCLEntrypoint::select: + return "select"; + case OpenCLEntrypoint::vloadn: + return "vloadn"; + case OpenCLEntrypoint::vstoren: + return "vstoren"; + case OpenCLEntrypoint::vload_half: + return "vload_half"; + case OpenCLEntrypoint::vload_halfn: + return "vload_halfn"; + case OpenCLEntrypoint::vstore_half: + return "vstore_half"; + case OpenCLEntrypoint::vstore_half_r: + return "vstore_half_r"; + case OpenCLEntrypoint::vstore_halfn: + return "vstore_halfn"; + case 
OpenCLEntrypoint::vstore_halfn_r: + return "vstore_halfn_r"; + case OpenCLEntrypoint::vloada_halfn: + return "vloada_halfn"; + case OpenCLEntrypoint::vstorea_halfn: + return "vstorea_halfn"; + case OpenCLEntrypoint::vstorea_halfn_r: + return "vstorea_halfn_r"; + case OpenCLEntrypoint::shuffle: + return "shuffle"; + case OpenCLEntrypoint::shuffle2: + return "shuffle2"; + case OpenCLEntrypoint::printf: + return "printf"; + case OpenCLEntrypoint::prefetch: + return "prefetch"; + } + return "unknown"; +} + +} // namespace tinytc::spv diff --git a/src/spv/opencl.std.hpp b/src/spv/opencl.std.hpp new file mode 100644 index 00000000..b662101f --- /dev/null +++ b/src/spv/opencl.std.hpp @@ -0,0 +1,183 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_OPENCL_STD_2024115_HPP +#define GENERATED_OPENCL_STD_2024115_HPP + +namespace tinytc::spv { + +constexpr char const *OpenCLExt = "OpenCL.std"; + +enum class OpenCLEntrypoint { + acos = 0, + acosh = 1, + acospi = 2, + asin = 3, + asinh = 4, + asinpi = 5, + atan = 6, + atan2 = 7, + atanh = 8, + atanpi = 9, + atan2pi = 10, + cbrt = 11, + ceil = 12, + copysign = 13, + cos = 14, + cosh = 15, + cospi = 16, + erfc = 17, + erf = 18, + exp = 19, + exp2 = 20, + exp10 = 21, + expm1 = 22, + fabs = 23, + fdim = 24, + floor = 25, + fma = 26, + fmax = 27, + fmin = 28, + fmod = 29, + fract = 30, + frexp = 31, + hypot = 32, + ilogb = 33, + ldexp = 34, + lgamma = 35, + lgamma_r = 36, + log = 37, + log2 = 38, + log10 = 39, + log1p = 40, + logb = 41, + mad = 42, + maxmag = 43, + minmag = 44, + modf = 45, + nan = 46, + nextafter = 47, + pow = 48, + pown = 49, + powr = 50, + remainder = 51, + remquo = 52, + rint = 53, + rootn = 54, + round = 55, + rsqrt = 56, + sin = 57, + sincos = 58, + sinh = 59, + sinpi = 60, + sqrt = 61, + tan = 62, + tanh = 63, + tanpi = 64, + tgamma = 65, + trunc = 66, + half_cos = 67, + half_divide = 68, + half_exp = 
69, + half_exp2 = 70, + half_exp10 = 71, + half_log = 72, + half_log2 = 73, + half_log10 = 74, + half_powr = 75, + half_recip = 76, + half_rsqrt = 77, + half_sin = 78, + half_sqrt = 79, + half_tan = 80, + native_cos = 81, + native_divide = 82, + native_exp = 83, + native_exp2 = 84, + native_exp10 = 85, + native_log = 86, + native_log2 = 87, + native_log10 = 88, + native_powr = 89, + native_recip = 90, + native_rsqrt = 91, + native_sin = 92, + native_sqrt = 93, + native_tan = 94, + s_abs = 141, + s_abs_diff = 142, + s_add_sat = 143, + u_add_sat = 144, + s_hadd = 145, + u_hadd = 146, + s_rhadd = 147, + u_rhadd = 148, + s_clamp = 149, + u_clamp = 150, + clz = 151, + ctz = 152, + s_mad_hi = 153, + u_mad_sat = 154, + s_mad_sat = 155, + s_max = 156, + u_max = 157, + s_min = 158, + u_min = 159, + s_mul_hi = 160, + rotate = 161, + s_sub_sat = 162, + u_sub_sat = 163, + u_upsample = 164, + s_upsample = 165, + popcount = 166, + s_mad24 = 167, + u_mad24 = 168, + s_mul24 = 169, + u_mul24 = 170, + u_abs = 201, + u_abs_diff = 202, + u_mul_hi = 203, + u_mad_hi = 204, + fclamp = 95, + degrees = 96, + fmax_common = 97, + fmin_common = 98, + mix = 99, + radians = 100, + step = 101, + smoothstep = 102, + sign = 103, + cross = 104, + distance = 105, + length = 106, + normalize = 107, + fast_distance = 108, + fast_length = 109, + fast_normalize = 110, + bitselect = 186, + select = 187, + vloadn = 171, + vstoren = 172, + vload_half = 173, + vload_halfn = 174, + vstore_half = 175, + vstore_half_r = 176, + vstore_halfn = 177, + vstore_halfn_r = 178, + vloada_halfn = 179, + vstorea_halfn = 180, + vstorea_halfn_r = 181, + shuffle = 182, + shuffle2 = 183, + printf = 184, + prefetch = 185, +}; + +auto to_string(OpenCLEntrypoint op) -> char const *; + +} // namespace tinytc::spv + +#endif // GENERATED_OPENCL_STD_2024115_HPP diff --git a/src/spv/pass/assemble.cpp b/src/spv/pass/assemble.cpp new file mode 100644 index 00000000..f202a502 --- /dev/null +++ b/src/spv/pass/assemble.cpp @@ -0,0 +1,51 
@@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/assemble.hpp" +#include "spv/enums.hpp" +#include "spv/inst_assembler.hpp" +#include "spv/module.hpp" +#include "spv/visit.hpp" +#include "tinytc/core.h" +#include "tinytc/types.h" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include + +namespace tinytc::spv { + +auto assembler::run_on_module(tinytc_spv_mod const &mod) -> shared_handle { + auto data = std::vector{}; + auto stream = word_stream{data}; + + const std::int32_t bound = mod.bound(); + // Guess instruction stream by using 5 words per instruction that produces a result + // Not really important, but could be improved + data.reserve(5 * sizeof(std::int32_t) * bound); + + // Make header + const std::int32_t version = (mod.major_version() << 16) | (mod.minor_version() << 8); + const std::int32_t generator_number = 0; + stream << magic_number << version << generator_number << bound << std::int32_t{0}; + + // Assemble instructions + auto ia = inst_assembler{stream}; + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : mod.insts(enum_cast
(s))) { + visit(ia, i); + } + } + + // Create binary + tinytc_binary_t bin; + CHECK_STATUS(tinytc_binary_create(&bin, mod.context(), tinytc_bundle_format_spirv, data.size(), + data.data(), mod.core_features())); + return shared_handle{bin}; +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/assemble.hpp b/src/spv/pass/assemble.hpp new file mode 100644 index 00000000..fc6b17f0 --- /dev/null +++ b/src/spv/pass/assemble.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ASSEMBLE_20241111_HPP +#define ASSEMBLE_20241111_HPP + +#include "tinytc/types.h" +#include "tinytc/types.hpp" + +namespace tinytc::spv { + +class assembler { + public: + auto run_on_module(tinytc_spv_mod const &mod) -> shared_handle; +}; + +} // namespace tinytc::spv + +#endif // ASSEMBLE_20241111_HPP diff --git a/src/spv/pass/assign_ids.cpp b/src/spv/pass/assign_ids.cpp new file mode 100644 index 00000000..6295faae --- /dev/null +++ b/src/spv/pass/assign_ids.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/assign_ids.hpp" +#include "spv/defs.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" + +#include +#include + +namespace tinytc::spv { + +void id_assigner::declare(spv_inst *in) { + if (!slot_map_.contains(in)) { + const auto slot = slot_++; + slot_map_[in] = slot; + in->id(slot); + } +} + +void id_assigner::visit_result(spv_inst &in) { declare(&in); } + +void id_assigner::operator()(spv_inst *&in) { + if (!slot_map_.contains(in)) { + if (isa(*in) || isa(*in) || isa(*in) || + isa(*in) || isa(*in)) { + declare(in); + } else { + throw status::spirv_forbidden_forward_declaration; + } + } +} + +void id_assigner::operator()(OpPhi &in) { + pre_visit(in); + this->operator()(in.type()); + this->visit_result(in); + for (auto &op 
: in.op0()) { + // Forward references are allowed in phi instructions + declare(op.first); + this->operator()(op); + } + post_visit(in); +} + +void id_assigner::run_on_module(tinytc_spv_mod &m) { + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto &i : m.insts(enum_cast
(s))) { + visit(*this, i); + } + } +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/assign_ids.hpp b/src/spv/pass/assign_ids.hpp new file mode 100644 index 00000000..a6e16c8c --- /dev/null +++ b/src/spv/pass/assign_ids.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef ASSIGN_IDS_20241111_HPP +#define ASSIGN_IDS_20241111_HPP + +#include "spv/defs.hpp" +#include "spv/visit.hpp" +#include "tinytc/types.h" + +#include +#include + +namespace tinytc::spv { + +class id_assigner : public default_visitor { + public: + using default_visitor::operator(); + + void visit_result(spv_inst &in); + + // Do nothing by default + template void operator()(T &) {} + + void operator()(spv_inst *&in); + void operator()(OpPhi &in); + + void run_on_module(tinytc_spv_mod &m); + + private: + void declare(spv_inst *in); + + std::uint32_t slot_ = 1; + std::unordered_map slot_map_; +}; + +} // namespace tinytc::spv + +#endif // ASSIGN_IDS_20241111_HPP diff --git a/src/spv/pass/capex.cpp b/src/spv/pass/capex.cpp new file mode 100644 index 00000000..39a82376 --- /dev/null +++ b/src/spv/pass/capex.cpp @@ -0,0 +1,260 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/capex.hpp" +#include "spv/capex_util.hpp" +#include "spv/enums.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/uniquifier.hpp" +#include "spv/visit.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/overloaded.hpp" + +#include +#include + +namespace tinytc::spv { + +template +concept inst_with_return_type = requires(T &t) { + { t.type() } -> std::same_as; +}; + +capex::capex(uniquifier &unique) : unique_{&unique} {} + +void capex::operator()(spv_inst const &) {} +void capex::operator()(OpAtomicStore const &in) { + auto ty = 
visit(overloaded{[](inst_with_return_type auto &a) -> spv_inst * { return a.type(); }, + [](spv_inst &) -> spv_inst * { return nullptr; }}, + *in.op3()); + if (!ty) { + throw status::internal_compiler_error; + } + auto ity = dyn_cast(ty); + if (ity && ity->op0() == 64) { + unique_->capability(Capability::Int64Atomics); + required_features_[tinytc_spirv_feature_int64_atomics] = true; + } +} +auto capex::float_atomic_class(spv_inst *raw_ty, spv_inst *op0) + -> std::pair { + auto ty = dyn_cast(raw_ty); + if (!ty) { + throw status::internal_compiler_error; + } + + auto pointer_ty = visit(overloaded{[](inst_with_return_type auto &a) -> OpTypePointer * { + return dyn_cast(a.type()); + }, + [](spv_inst &) -> OpTypePointer * { return nullptr; }}, + *op0); + if (!pointer_ty) { + throw status::internal_compiler_error; + } + return {ty->op0(), pointer_ty->op0()}; +} +void capex::operator()(OpAtomicFAddEXT const &in) { + const auto [bits, storage_cls] = float_atomic_class(in.type(), in.op0()); + switch (bits) { + case 16: + unique_->capability(Capability::AtomicFloat16AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float16_add"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float16_add_local + : tinytc_spirv_feature_atomic_float16_add_global] = true; + break; + case 32: + unique_->capability(Capability::AtomicFloat32AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float_add"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float32_add_local + : tinytc_spirv_feature_atomic_float32_add_global] = true; + break; + case 64: + unique_->capability(Capability::AtomicFloat64AddEXT); + unique_->extension("SPV_EXT_shader_atomic_float_add"); + required_features_[storage_cls == StorageClass::Workgroup + ? 
tinytc_spirv_feature_atomic_float64_add_local + : tinytc_spirv_feature_atomic_float64_add_global] = true; + break; + default: + break; + } +} +void capex::check_float_min_max_atomic(spv_inst *ty, spv_inst *op0) { + const auto [bits, storage_cls] = float_atomic_class(ty, op0); + switch (bits) { + case 16: + unique_->capability(Capability::AtomicFloat16MinMaxEXT); + unique_->extension("SPV_EXT_shader_atomic_float16_min_max"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float16_min_max_local + : tinytc_spirv_feature_atomic_float16_min_max_global] = true; + break; + case 32: + unique_->capability(Capability::AtomicFloat32MinMaxEXT); + unique_->extension("SPV_EXT_shader_atomic_float_min_max"); + required_features_[storage_cls == StorageClass::Workgroup + ? tinytc_spirv_feature_atomic_float32_min_max_local + : tinytc_spirv_feature_atomic_float32_min_max_global] = true; + break; + case 64: + unique_->capability(Capability::AtomicFloat64MinMaxEXT); + unique_->extension("SPV_EXT_shader_atomic_float_min_max"); + required_features_[storage_cls == StorageClass::Workgroup + ? 
tinytc_spirv_feature_atomic_float64_min_max_local + : tinytc_spirv_feature_atomic_float64_min_max_global] = true; + break; + default: + break; + } +} +void capex::operator()(OpAtomicFMaxEXT const &in) { + check_float_min_max_atomic(in.type(), in.op0()); +} +void capex::operator()(OpAtomicFMinEXT const &in) { + check_float_min_max_atomic(in.type(), in.op0()); +} +void capex::check_int_atomic(spv_inst *raw_ty) { + auto ty = dyn_cast(raw_ty); + if (!ty) { + throw status::internal_compiler_error; + } + if (ty && ty->op0() == 64) { + unique_->capability(Capability::Int64Atomics); + required_features_[tinytc_spirv_feature_int64_atomics] = true; + } +} +void capex::operator()(OpAtomicIAdd const &in) { check_int_atomic(in.type()); } +void capex::operator()(OpAtomicSMax const &in) { check_int_atomic(in.type()); } +void capex::operator()(OpAtomicSMin const &in) { check_int_atomic(in.type()); } +void capex::operator()(OpAsmTargetINTEL const &) { + unique_->capability(Capability::AsmINTEL); + unique_->extension("SPV_INTEL_inline_assembly"); +} +void capex::operator()(OpAsmINTEL const &) { + unique_->capability(Capability::AsmINTEL); + unique_->extension("SPV_INTEL_inline_assembly"); +} +void capex::operator()(OpAsmCallINTEL const &) { + unique_->capability(Capability::AsmINTEL); + unique_->extension("SPV_INTEL_inline_assembly"); +} +void capex::operator()(OpConvertBF16ToFINTEL const &) { + unique_->capability(Capability::BFloat16ConversionINTEL); + unique_->extension("SPV_INTEL_bfloat16_conversion"); + required_features_[tinytc_spirv_feature_bfloat16_conversion] = true; +} +void capex::operator()(OpConvertFToBF16INTEL const &) { + unique_->capability(Capability::BFloat16ConversionINTEL); + unique_->extension("SPV_INTEL_bfloat16_conversion"); + required_features_[tinytc_spirv_feature_bfloat16_conversion] = true; +} +void capex::operator()(OpCooperativeMatrixLoadKHR const &) { + unique_->capability(Capability::CooperativeMatrixKHR); + 
unique_->extension("SPV_KHR_cooperative_matrix"); +} +void capex::operator()(OpCooperativeMatrixMulAddKHR const &) { + unique_->capability(Capability::CooperativeMatrixKHR); + unique_->extension("SPV_KHR_cooperative_matrix"); +} +void capex::operator()(OpCooperativeMatrixStoreKHR const &) { + unique_->capability(Capability::CooperativeMatrixKHR); + unique_->extension("SPV_KHR_cooperative_matrix"); +} +void capex::operator()(OpEntryPoint const &in) { + for (auto const &cap : capabilities(in.op0())) { + unique_->capability(cap); + } +} +void capex::operator()(OpExecutionMode const &in) { + for (auto const &cap : capabilities(in.op1())) { + unique_->capability(cap); + if (cap == Capability::SubgroupDispatch) { + required_features_[tinytc_spirv_feature_subgroup_dispatch] = true; + } + } +} +void capex::operator()(OpGroupBroadcast const &) { + unique_->capability(Capability::Groups); + required_features_[tinytc_spirv_feature_groups] = true; +} +void capex::operator()(OpGroupFAdd const &) { + unique_->capability(Capability::Groups); + required_features_[tinytc_spirv_feature_groups] = true; +} +void capex::operator()(OpGroupIAdd const &) { + unique_->capability(Capability::Groups); + required_features_[tinytc_spirv_feature_groups] = true; +} +void capex::operator()(OpInBoundsPtrAccessChain const &) { + unique_->capability(Capability::Addresses); +} +void capex::operator()(OpMemoryModel const &in) { + for (auto const &cap : capabilities(in.op0())) { + unique_->capability(cap); + } + for (auto const &cap : capabilities(in.op1())) { + unique_->capability(cap); + } +} +void capex::operator()(OpSubgroupBlockReadINTEL const &) { + unique_->capability(Capability::SubgroupBufferBlockIOINTEL); + unique_->extension("SPV_INTEL_subgroups"); + required_features_[tinytc_spirv_feature_subgroup_buffer_block_io] = true; +} +void capex::operator()(OpSubgroupBlockWriteINTEL const &) { + unique_->capability(Capability::SubgroupBufferBlockIOINTEL); + unique_->extension("SPV_INTEL_subgroups"); 
+ required_features_[tinytc_spirv_feature_subgroup_buffer_block_io] = true; +} +void capex::operator()(OpTypeFloat const &in) { + switch (in.op0()) { + case 16: + unique_->capability(Capability::Float16); + required_features_[tinytc_spirv_feature_float16] = true; + break; + case 64: + unique_->capability(Capability::Float64); + required_features_[tinytc_spirv_feature_float64] = true; + break; + default: + break; + } +} +void capex::operator()(OpTypeInt const &in) { + switch (in.op0()) { + case 8: + unique_->capability(Capability::Int8); + break; + case 16: + unique_->capability(Capability::Int16); + break; + case 64: + unique_->capability(Capability::Int64); + break; + default: + break; + } +} +void capex::operator()(OpTypeVector const &in) { + if (in.op1() > 4) { + unique_->capability(Capability::Vector16); + } +} + +void capex::run_on_module(tinytc_spv_mod const &mod) { + for (std::int32_t s = 0; s < num_module_sections; ++s) { + for (auto const &i : mod.insts(enum_cast
(s))) { + visit(*this, i); + } + } +} + +} // namespace tinytc::spv + diff --git a/src/spv/pass/capex.hpp b/src/spv/pass/capex.hpp new file mode 100644 index 00000000..025dc965 --- /dev/null +++ b/src/spv/pass/capex.hpp @@ -0,0 +1,72 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef CAPEX_20241113_HPP +#define CAPEX_20241113_HPP + +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "tinytc/types.h" + +#include +#include + +namespace tinytc { +enum class spirv_feature; +} + +namespace tinytc::spv { + +class uniquifier; + +class capex { + public: + capex(uniquifier &unique); + + void operator()(spv_inst const &in); + void operator()(OpAtomicStore const &in); + void operator()(OpAtomicFAddEXT const &in); + void operator()(OpAtomicFMaxEXT const &in); + void operator()(OpAtomicFMinEXT const &in); + void operator()(OpAtomicIAdd const &in); + void operator()(OpAtomicSMax const &in); + void operator()(OpAtomicSMin const &in); + void operator()(OpAsmTargetINTEL const &in); + void operator()(OpAsmINTEL const &in); + void operator()(OpAsmCallINTEL const &in); + void operator()(OpConvertBF16ToFINTEL const &in); + void operator()(OpConvertFToBF16INTEL const &in); + void operator()(OpCooperativeMatrixLoadKHR const &in); + void operator()(OpCooperativeMatrixMulAddKHR const &in); + void operator()(OpCooperativeMatrixStoreKHR const &in); + void operator()(OpEntryPoint const &in); + void operator()(OpExecutionMode const &in); + void operator()(OpGroupBroadcast const &in); + void operator()(OpGroupFAdd const &in); + void operator()(OpGroupIAdd const &in); + void operator()(OpInBoundsPtrAccessChain const &in); + void operator()(OpMemoryModel const &in); + void operator()(OpSubgroupBlockReadINTEL const &in); + void operator()(OpSubgroupBlockWriteINTEL const &in); + void operator()(OpTypeFloat const &in); + void operator()(OpTypeInt const &in); + void operator()(OpTypeVector const &in); + + void run_on_module(tinytc_spv_mod const 
&mod); + + inline auto requires_feature(spirv_feature f) const -> bool { + return required_features_[static_cast(f)]; + } + + private: + auto float_atomic_class(spv_inst *ty, spv_inst *op0) -> std::pair; + void check_float_min_max_atomic(spv_inst *ty, spv_inst *op0); + void check_int_atomic(spv_inst *ty); + + uniquifier *unique_; + std::array required_features_ = {}; +}; + +} // namespace tinytc::spv + +#endif // CAPEX_20241113_HPP diff --git a/src/spv/pass/dump_asm.cpp b/src/spv/pass/dump_asm.cpp new file mode 100644 index 00000000..d3898ac7 --- /dev/null +++ b/src/spv/pass/dump_asm.cpp @@ -0,0 +1,151 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/pass/dump_asm.hpp" +#include "spv/instructions.hpp" +#include "spv/module.hpp" +#include "spv/opencl.std.hpp" +#include "tinytc/core.hpp" +#include "util/casting.hpp" +#include "util/ilist.hpp" +#include "util/ilist_base.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +enum class LinkageType; + +dump_asm_pass::dump_asm_pass(std::ostream &os) : os_(&os) {} + +void dump_asm_pass::pre_visit(spv_inst const &in) { + auto const num_digits = [](std::int64_t number) { + std::int64_t d = 1; + while (number /= 10) { + ++d; + } + return d; + }; + *os_ << std::endl; + if (in.has_result_id()) { + const auto id = in.id(); + + for (int i = 0; i < rhs_indent - 4 - num_digits(id); ++i) { + *os_ << ' '; + } + *os_ << "%" << id << " = "; + } else { + for (int i = 0; i < rhs_indent; ++i) { + *os_ << ' '; + } + } + *os_ << "Op" << to_string(in.opcode()); +} + +void dump_asm_pass::operator()(DecorationAttr const &da) { + std::visit(overloaded{[&](auto const &a) { this->operator()(a); }, + [&](std::pair const &a) { + *os_ << " \"" << a.first << '"'; + this->operator()(a.second); + }}, + da); +} +void dump_asm_pass::operator()(ExecutionModeAttr const &ea) { + std::visit( + 
overloaded{[&](std::int32_t const &a) { *os_ << " " << static_cast(a); }, + [&](std::array const &a) { + for (auto const &s : a) { + *os_ << " " << static_cast(s); + } + }}, + ea); +} +void dump_asm_pass::operator()(LiteralContextDependentNumber const &l) { + std::visit(overloaded{[&](std::int8_t const &l) { + *os_ << " " + << static_cast(static_cast(l)); + }, + [&](std::signed_integral auto const &l) { + using unsigned_t = std::make_unsigned_t>; + *os_ << " " << static_cast(l); + }, + [&](std::floating_point auto const &l) { + auto flags = os_->flags(); + *os_ << " " << std::hexfloat << l; + os_->flags(flags); + }, + [&](half const &l) { + auto flags = os_->flags(); + *os_ << " " << std::hexfloat << l; + os_->flags(flags); + }}, + l); +} +void dump_asm_pass::operator()(LiteralInteger const &l) { + *os_ << " " << static_cast>(l); +} +void dump_asm_pass::operator()(LiteralString const &l) { *os_ << " \"" << l << '"'; } + +void dump_asm_pass::operator()(PairIdRefIdRef const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void dump_asm_pass::operator()(PairIdRefLiteralInteger const &p) { + this->operator()(p.first); + this->operator()(p.second); +} +void dump_asm_pass::operator()(PairLiteralIntegerIdRef const &p) { + std::visit(overloaded{[&](auto const &l) { + using unsigned_t = std::make_unsigned_t>; + *os_ << " " << static_cast(l); + }}, + p.first); + this->operator()(p.second); +} + +void dump_asm_pass::operator()(spv_inst *const &in) { *os_ << " %" << in->id(); } +void dump_asm_pass::operator()(OpExtInst const &in) { + pre_visit(in); + this->operator()(in.type()); + visit_result(in); + this->operator()(in.op0()); + + if (auto extimport = dyn_cast(in.op0()); + extimport && extimport->op0() == OpenCLExt) { + this->operator()(static_cast(in.op1())); + } else { + this->operator()(in.op1()); + } + + for (auto const &op : in.op2()) { + this->operator()(op); + } + post_visit(in); +} + +void dump_asm_pass::run_on_module(tinytc_spv_mod const &m) { + 
auto const visit_section = [&](section s) { + for (auto const &i : m.insts(s)) { + visit(*this, i); + } + }; + *os_ << "; SPIR-V" << std::endl + << "; Version " << m.major_version() << '.' << m.minor_version() << std::endl + << "; Generator: Tiny Tensor Compiler" << std::endl + << "; Bound: " << m.bound() << std::endl + << "; Schema: 0"; + for (std::int32_t s = 0; s < num_module_sections; ++s) { + visit_section(enum_cast
(s)); + } + *os_ << std::endl; +} + +} // namespace tinytc::spv diff --git a/src/spv/pass/dump_asm.hpp b/src/spv/pass/dump_asm.hpp new file mode 100644 index 00000000..a1bbad29 --- /dev/null +++ b/src/spv/pass/dump_asm.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef DUMP_ASM_20241029_HPP +#define DUMP_ASM_20241029_HPP + +#include "spv/defs.hpp" +#include "spv/names.hpp" +#include "spv/visit.hpp" +#include "tinytc/types.h" + +#include + +namespace tinytc::spv { + +class dump_asm_pass : public default_visitor { + public: + using default_visitor::operator(); + constexpr static int rhs_indent = 15; + + dump_asm_pass(std::ostream &os); + + void pre_visit(spv_inst const &in); + + template + requires requires(T const &e) { to_string(e); } + void operator()(T const &e) { + *os_ << " " << to_string(e); + } + void operator()(DecorationAttr const &da); + void operator()(ExecutionModeAttr const &ea); + void operator()(LiteralContextDependentNumber const &l); + void operator()(LiteralInteger const &l); + void operator()(LiteralString const &l); + + void operator()(PairIdRefIdRef const &p); + void operator()(PairIdRefLiteralInteger const &p); + void operator()(PairLiteralIntegerIdRef const &p); + + void operator()(spv_inst *const &in); + void operator()(OpExtInst const &in); + + void run_on_module(tinytc_spv_mod const &m); + + private: + std::ostream *os_; +}; + +} // namespace tinytc::spv + +#endif // DUMP_ASM_20241029_HPP diff --git a/src/spv/uniquifier.cpp b/src/spv/uniquifier.cpp new file mode 100644 index 00000000..65b9070d --- /dev/null +++ b/src/spv/uniquifier.cpp @@ -0,0 +1,250 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#include "spv/uniquifier.hpp" +#include "compiler_context.hpp" +#include "number.hpp" +#include "spv/defs.hpp" +#include "spv/instructions.hpp" +#include "spv/lut.hpp" +#include "spv/module.hpp" +#include "spv/opencl.std.hpp" +#include 
"support/fnv1a_array_view.hpp" +#include "tinytc/types.hpp" +#include "util/overloaded.hpp" + +#include +#include +#include +#include +#include + +namespace tinytc::spv { + +auto address_space_to_storage_class(address_space as) -> StorageClass { + return as == address_space::local ? StorageClass::Workgroup : StorageClass::CrossWorkgroup; +} + +uniquifier::uniquifier(tinytc_spv_mod &m) : mod_(&m) {} + +auto uniquifier::asm_target() -> spv_inst * { + if (!asm_target_) { + asm_target_ = + mod_->add_to(section::type_const_var, "spirv64-unknown-unknown"); + } + return asm_target_; +} + +auto uniquifier::bool_constant(bool b) -> spv_inst * { + if (b) { + return lookup(bool_true_, [&] { + return mod_->add_to(section::type_const_var, bool_ty()); + }); + } + return lookup(bool_false_, [&] { + return mod_->add_to(section::type_const_var, bool_ty()); + }); +} + +auto uniquifier::builtin_alignment(BuiltIn b) -> std::int32_t { + switch (b) { + case BuiltIn::WorkDim: + case BuiltIn::SubgroupSize: + case BuiltIn::SubgroupMaxSize: + case BuiltIn::NumSubgroups: + case BuiltIn::NumEnqueuedSubgroups: + case BuiltIn::SubgroupId: + case BuiltIn::SubgroupLocalInvocationId: + return 4; // i32 + case BuiltIn::GlobalLinearId: + case BuiltIn::LocalInvocationIndex: + return mod().context()->index_bit_width() / 8; // index + case BuiltIn::GlobalSize: + case BuiltIn::GlobalInvocationId: + case BuiltIn::WorkgroupSize: + case BuiltIn::EnqueuedWorkgroupSize: + case BuiltIn::LocalInvocationId: + case BuiltIn::NumWorkgroups: + case BuiltIn::WorkgroupId: + case BuiltIn::GlobalOffset: + return 4 * (mod().context()->index_bit_width() / 8); // index x 3 + default: + throw status::internal_compiler_error; + } +} + +auto uniquifier::builtin_pointee_ty(BuiltIn b) -> spv_inst * { + switch (b) { + case BuiltIn::WorkDim: + case BuiltIn::SubgroupSize: + case BuiltIn::SubgroupMaxSize: + case BuiltIn::NumSubgroups: + case BuiltIn::NumEnqueuedSubgroups: + case BuiltIn::SubgroupId: + case 
BuiltIn::SubgroupLocalInvocationId: + return int_ty(32); // i32 + case BuiltIn::GlobalLinearId: + case BuiltIn::LocalInvocationIndex: + return int_ty(mod().context()->index_bit_width()); // index + case BuiltIn::GlobalSize: + case BuiltIn::GlobalInvocationId: + case BuiltIn::WorkgroupSize: + case BuiltIn::EnqueuedWorkgroupSize: + case BuiltIn::LocalInvocationId: + case BuiltIn::NumWorkgroups: + case BuiltIn::WorkgroupId: + case BuiltIn::GlobalOffset: { + auto index_ty = int_ty(mod().context()->index_bit_width()); + return vec_ty(index_ty, vector_size::v3); // index x 3 + } + default: + throw status::internal_compiler_error; + } +} + +auto uniquifier::builtin_var(BuiltIn b) -> spv_inst * { + return lookup(builtin_, b, [&](BuiltIn b) { + auto ty = pointer_ty(StorageClass::Input, builtin_pointee_ty(b), builtin_alignment(b)); + auto var = mod_->add_to(section::type_const_var, ty, StorageClass::Input, + std::nullopt); + mod_->add_to(section::decoration, var, Decoration::Constant); + mod_->add_to(section::decoration, var, Decoration::BuiltIn, b); + return var; + }); +} + +void uniquifier::capability(Capability cap) { + if (!capabilities_.contains(cap)) { + mod_->add_to(section::capability, cap); + capabilities_.insert(cap); + } +} + +auto uniquifier::constant(LiteralContextDependentNumber cst) -> spv_inst * { + return lookup(cst_map_, cst, [&](LiteralContextDependentNumber cst) { + const auto visitor = overloaded{ + [&](std::int8_t &) -> spv_inst * { + return mod_->add_to(section::type_const_var, int_ty(8), cst); + }, + [&](std::int16_t &) -> spv_inst * { + return mod_->add_to(section::type_const_var, int_ty(16), cst); + }, + [&](std::int32_t &) -> spv_inst * { + return mod_->add_to(section::type_const_var, int_ty(32), cst); + }, + [&](std::int64_t &) -> spv_inst * { + return mod_->add_to(section::type_const_var, int_ty(64), cst); + }, + [&](half &) -> spv_inst * { + return mod_->add_to(section::type_const_var, float_ty(16), cst); + }, + [&](float &) -> spv_inst * { + 
return mod_->add_to(section::type_const_var, float_ty(32), cst); + }, + [&](double &) -> spv_inst * { + return mod_->add_to(section::type_const_var, float_ty(64), cst); + }}; + return std::visit(visitor, cst); + }); +} + +void uniquifier::extension(char const *ext_name) { + if (!extensions_.contains(ext_name)) { + mod_->add_to(section::extension, ext_name); + extensions_.insert(ext_name); + } +} + +auto uniquifier::null_constant(spv_inst *spv_ty) -> spv_inst * { + return lookup(null_cst_, spv_ty, [&](spv_inst *spv_ty) { + return mod_->add_to(section::type_const_var, spv_ty); + }); +} + +auto uniquifier::opencl_ext() -> spv_inst * { + return lookup(opencl_ext_, + [&] { return mod_->add_to(section::ext_inst, OpenCLExt); }); +} + +auto uniquifier::array_ty(spv_inst *element_ty, std::int32_t length) -> spv_inst * { + auto key = std::make_pair(element_ty, length); + return lookup(array_tys_, key, [&](std::pair const &key) { + return mod_->add_to(section::type_const_var, key.first, constant(key.second)); + }); +} + +auto uniquifier::bool_ty() -> spv_inst * { + if (!bool_ty_) { + bool_ty_ = mod_->add_to(section::type_const_var); + } + return bool_ty_; +} + +auto uniquifier::float_ty(std::int32_t width) -> spv_inst * { + return lookup(float_tys_, width, [&](std::int32_t width) { + return mod_->add_to(section::type_const_var, width); + }); +} + +auto uniquifier::function_ty(spv_inst *return_ty, array_view params) -> spv_inst * { + const auto map_key = fnv1a_step(fnv1a_step(fnv1a0(), return_ty), params); + auto range = function_tys_.equal_range(map_key); + for (auto it = range.first; it != range.second; ++it) { + if (return_ty == it->second->op0() && + std::equal(params.begin(), params.end(), it->second->op1().begin(), + it->second->op1().end())) { + return it->second; + } + } + return function_tys_ + .emplace(map_key, mod_->add_to(section::type_const_var, return_ty, + std::move(params))) + ->second; +} + +auto uniquifier::int_ty(std::int32_t width) -> spv_inst * { + return 
lookup(int_tys_, width, [&](std::int32_t width) { + return mod_->add_to(section::type_const_var, width, 0); + }); +} + +auto uniquifier::pointer_ty(StorageClass cls, spv_inst *pointee_ty, std::int32_t alignment) + -> spv_inst * { + auto key = std::make_tuple(cls, pointee_ty, alignment); + return lookup( + pointer_tys_, key, [&](std::tuple const &key) { + auto pointer_ty = mod_->add_to(section::type_const_var, std::get<0>(key), + std::get<1>(key)); + if (std::get<2>(key) > 0) { + mod_->add_to(section::decoration, pointer_ty, Decoration::Alignment, + DecorationAttr{std::get<2>(key)}); + } + return pointer_ty; + }); +} + +auto uniquifier::vec_ty(spv_inst *component_ty, std::int32_t length) -> spv_inst * { + auto key = std::make_pair(component_ty, length); + return lookup(vec_tys_, key, [&](std::pair const &key) { + return mod_->add_to(section::type_const_var, key.first, key.second); + }); +} +auto uniquifier::vec_ty(spv_inst *component_ty, vector_size length) -> spv_inst * { + return vec_ty(component_ty, static_cast(length)); +} + +auto uniquifier::void_ty() -> spv_inst * { + if (!void_ty_) { + void_ty_ = mod_->add_to(section::type_const_var); + } + return void_ty_; +} + +auto uniquifier::load_builtin(BuiltIn b) -> spv_inst * { + auto builtin = builtin_var(b); + return mod_->add(builtin_pointee_ty(b), builtin, MemoryAccess::Aligned, + builtin_alignment(b)); +} + +} // namespace tinytc::spv + diff --git a/src/spv/uniquifier.hpp b/src/spv/uniquifier.hpp new file mode 100644 index 00000000..c67262bb --- /dev/null +++ b/src/spv/uniquifier.hpp @@ -0,0 +1,98 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef UNIQUIFIER_20241107_HPP +#define UNIQUIFIER_20241107_HPP + +#include "spv/defs.hpp" +#include "spv/enums.hpp" +#include "tinytc/core.hpp" +#include "tinytc/types.h" +#include "util/fnv1a.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace tinytc { +enum class address_space; +enum class 
vector_size; +} // namespace tinytc + +namespace tinytc::spv { + +auto address_space_to_storage_class(address_space as) -> StorageClass; + +class uniquifier { + public: + uniquifier(tinytc_spv_mod &m); + + inline auto mod() -> tinytc_spv_mod & { return *mod_; } + + auto asm_target() -> spv_inst *; + auto bool_constant(bool b) -> spv_inst *; + auto builtin_alignment(BuiltIn b) -> std::int32_t; + auto builtin_pointee_ty(BuiltIn b) -> spv_inst *; + auto builtin_var(BuiltIn b) -> spv_inst *; + void capability(Capability cap); + auto constant(LiteralContextDependentNumber cst) -> spv_inst *; + void extension(char const *ext_name); + auto null_constant(spv_inst *spv_ty) -> spv_inst *; + auto opencl_ext() -> spv_inst *; + + // types + auto array_ty(spv_inst *element_ty, std::int32_t length) -> spv_inst *; + auto bool_ty() -> spv_inst *; + auto float_ty(std::int32_t width) -> spv_inst *; + auto function_ty(spv_inst *return_ty, array_view params) -> spv_inst *; + auto int_ty(std::int32_t width) -> spv_inst *; + auto pointer_ty(StorageClass cls, spv_inst *pointee_ty, std::int32_t alignment) -> spv_inst *; + auto vec_ty(spv_inst *component_ty, std::int32_t length) -> spv_inst *; + auto vec_ty(spv_inst *component_ty, vector_size length) -> spv_inst *; + auto void_ty() -> spv_inst *; + + // util + auto load_builtin(BuiltIn b) -> spv_inst *; + + private: + struct array_key_hash { + inline auto operator()(std::pair const &key) const + -> std::size_t { + return fnv1a_combine(key.first, key.second); + } + }; + struct pointer_key_hash { + inline auto operator()(std::tuple const &key) const + -> std::size_t { + return fnv1a_combine(std::get<0>(key), std::get<1>(key), std::get<2>(key)); + } + }; + + tinytc_spv_mod_t mod_; + spv_inst *asm_target_ = nullptr; + spv_inst *bool_true_ = nullptr, *bool_false_ = nullptr; + spv_inst *opencl_ext_ = nullptr; + std::unordered_map builtin_; + std::unordered_set capabilities_; + std::unordered_map cst_map_; + std::unordered_set extensions_; + 
std::unordered_map null_cst_; + + std::unordered_map, spv_inst *, array_key_hash> array_tys_; + spv_inst *bool_ty_ = nullptr; + std::unordered_multimap function_tys_; + std::unordered_map, spv_inst *, + pointer_key_hash> + pointer_tys_; + std::unordered_map int_tys_, float_tys_; + std::unordered_map, spv_inst *, array_key_hash> vec_tys_; + spv_inst *void_ty_ = nullptr; +}; + +} // namespace tinytc::spv + +#endif // UNIQUIFIER_20241107_HPP diff --git a/src/spv/visit.hpp b/src/spv/visit.hpp new file mode 100644 index 00000000..c18f728a --- /dev/null +++ b/src/spv/visit.hpp @@ -0,0 +1,4579 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +// This file is generated +// Do not edit manually + +#ifndef GENERATED_VISIT_20250630_HPP +#define GENERATED_VISIT_20250630_HPP + +#include "defs.hpp" +#include "enums.hpp" +#include "instructions.hpp" +#include "util/overloaded.hpp" + +namespace tinytc::spv { + +template auto visit(Visitor &&visitor, spv_inst &inst) { + switch (inst.opcode()) { + case Op::Nop: + return visitor(static_cast(inst)); + case Op::Undef: + return visitor(static_cast(inst)); + case Op::SourceContinued: + return visitor(static_cast(inst)); + case Op::Source: + return visitor(static_cast(inst)); + case Op::SourceExtension: + return visitor(static_cast(inst)); + case Op::Name: + return visitor(static_cast(inst)); + case Op::MemberName: + return visitor(static_cast(inst)); + case Op::String: + return visitor(static_cast(inst)); + case Op::Line: + return visitor(static_cast(inst)); + case Op::Extension: + return visitor(static_cast(inst)); + case Op::ExtInstImport: + return visitor(static_cast(inst)); + case Op::ExtInst: + return visitor(static_cast(inst)); + case Op::MemoryModel: + return visitor(static_cast(inst)); + case Op::EntryPoint: + return visitor(static_cast(inst)); + case Op::ExecutionMode: + return visitor(static_cast(inst)); + case Op::Capability: + return visitor(static_cast(inst)); + case Op::TypeVoid: + 
return visitor(static_cast(inst)); + case Op::TypeBool: + return visitor(static_cast(inst)); + case Op::TypeInt: + return visitor(static_cast(inst)); + case Op::TypeFloat: + return visitor(static_cast(inst)); + case Op::TypeVector: + return visitor(static_cast(inst)); + case Op::TypeMatrix: + return visitor(static_cast(inst)); + case Op::TypeImage: + return visitor(static_cast(inst)); + case Op::TypeSampler: + return visitor(static_cast(inst)); + case Op::TypeSampledImage: + return visitor(static_cast(inst)); + case Op::TypeArray: + return visitor(static_cast(inst)); + case Op::TypeRuntimeArray: + return visitor(static_cast(inst)); + case Op::TypeStruct: + return visitor(static_cast(inst)); + case Op::TypeOpaque: + return visitor(static_cast(inst)); + case Op::TypePointer: + return visitor(static_cast(inst)); + case Op::TypeFunction: + return visitor(static_cast(inst)); + case Op::TypeEvent: + return visitor(static_cast(inst)); + case Op::TypeDeviceEvent: + return visitor(static_cast(inst)); + case Op::TypeReserveId: + return visitor(static_cast(inst)); + case Op::TypeQueue: + return visitor(static_cast(inst)); + case Op::TypePipe: + return visitor(static_cast(inst)); + case Op::TypeForwardPointer: + return visitor(static_cast(inst)); + case Op::ConstantTrue: + return visitor(static_cast(inst)); + case Op::ConstantFalse: + return visitor(static_cast(inst)); + case Op::Constant: + return visitor(static_cast(inst)); + case Op::ConstantComposite: + return visitor(static_cast(inst)); + case Op::ConstantSampler: + return visitor(static_cast(inst)); + case Op::ConstantNull: + return visitor(static_cast(inst)); + case Op::Function: + return visitor(static_cast(inst)); + case Op::FunctionParameter: + return visitor(static_cast(inst)); + case Op::FunctionEnd: + return visitor(static_cast(inst)); + case Op::FunctionCall: + return visitor(static_cast(inst)); + case Op::Variable: + return visitor(static_cast(inst)); + case Op::ImageTexelPointer: + return 
visitor(static_cast(inst)); + case Op::Load: + return visitor(static_cast(inst)); + case Op::Store: + return visitor(static_cast(inst)); + case Op::CopyMemory: + return visitor(static_cast(inst)); + case Op::CopyMemorySized: + return visitor(static_cast(inst)); + case Op::AccessChain: + return visitor(static_cast(inst)); + case Op::InBoundsAccessChain: + return visitor(static_cast(inst)); + case Op::PtrAccessChain: + return visitor(static_cast(inst)); + case Op::ArrayLength: + return visitor(static_cast(inst)); + case Op::GenericPtrMemSemantics: + return visitor(static_cast(inst)); + case Op::InBoundsPtrAccessChain: + return visitor(static_cast(inst)); + case Op::Decorate: + return visitor(static_cast(inst)); + case Op::MemberDecorate: + return visitor(static_cast(inst)); + case Op::DecorationGroup: + return visitor(static_cast(inst)); + case Op::GroupDecorate: + return visitor(static_cast(inst)); + case Op::GroupMemberDecorate: + return visitor(static_cast(inst)); + case Op::VectorExtractDynamic: + return visitor(static_cast(inst)); + case Op::VectorInsertDynamic: + return visitor(static_cast(inst)); + case Op::VectorShuffle: + return visitor(static_cast(inst)); + case Op::CompositeConstruct: + return visitor(static_cast(inst)); + case Op::CompositeExtract: + return visitor(static_cast(inst)); + case Op::CompositeInsert: + return visitor(static_cast(inst)); + case Op::CopyObject: + return visitor(static_cast(inst)); + case Op::Transpose: + return visitor(static_cast(inst)); + case Op::SampledImage: + return visitor(static_cast(inst)); + case Op::ImageSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjExplicitLod: + return 
visitor(static_cast(inst)); + case Op::ImageSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageFetch: + return visitor(static_cast(inst)); + case Op::ImageGather: + return visitor(static_cast(inst)); + case Op::ImageDrefGather: + return visitor(static_cast(inst)); + case Op::ImageRead: + return visitor(static_cast(inst)); + case Op::ImageWrite: + return visitor(static_cast(inst)); + case Op::Image: + return visitor(static_cast(inst)); + case Op::ImageQueryFormat: + return visitor(static_cast(inst)); + case Op::ImageQueryOrder: + return visitor(static_cast(inst)); + case Op::ImageQuerySizeLod: + return visitor(static_cast(inst)); + case Op::ImageQuerySize: + return visitor(static_cast(inst)); + case Op::ImageQueryLod: + return visitor(static_cast(inst)); + case Op::ImageQueryLevels: + return visitor(static_cast(inst)); + case Op::ImageQuerySamples: + return visitor(static_cast(inst)); + case Op::ConvertFToU: + return visitor(static_cast(inst)); + case Op::ConvertFToS: + return visitor(static_cast(inst)); + case Op::ConvertSToF: + return visitor(static_cast(inst)); + case Op::ConvertUToF: + return visitor(static_cast(inst)); + case Op::UConvert: + return visitor(static_cast(inst)); + case Op::SConvert: + return visitor(static_cast(inst)); + case Op::FConvert: + return visitor(static_cast(inst)); + case Op::QuantizeToF16: + return visitor(static_cast(inst)); + case Op::ConvertPtrToU: + return visitor(static_cast(inst)); + case Op::SatConvertSToU: + return visitor(static_cast(inst)); + case Op::SatConvertUToS: + return visitor(static_cast(inst)); + case Op::ConvertUToPtr: + return visitor(static_cast(inst)); + case Op::PtrCastToGeneric: + return visitor(static_cast(inst)); + case Op::GenericCastToPtr: + return visitor(static_cast(inst)); + case Op::GenericCastToPtrExplicit: + return visitor(static_cast(inst)); + case Op::Bitcast: + return 
visitor(static_cast(inst)); + case Op::SNegate: + return visitor(static_cast(inst)); + case Op::FNegate: + return visitor(static_cast(inst)); + case Op::IAdd: + return visitor(static_cast(inst)); + case Op::FAdd: + return visitor(static_cast(inst)); + case Op::ISub: + return visitor(static_cast(inst)); + case Op::FSub: + return visitor(static_cast(inst)); + case Op::IMul: + return visitor(static_cast(inst)); + case Op::FMul: + return visitor(static_cast(inst)); + case Op::UDiv: + return visitor(static_cast(inst)); + case Op::SDiv: + return visitor(static_cast(inst)); + case Op::FDiv: + return visitor(static_cast(inst)); + case Op::UMod: + return visitor(static_cast(inst)); + case Op::SRem: + return visitor(static_cast(inst)); + case Op::SMod: + return visitor(static_cast(inst)); + case Op::FRem: + return visitor(static_cast(inst)); + case Op::FMod: + return visitor(static_cast(inst)); + case Op::VectorTimesScalar: + return visitor(static_cast(inst)); + case Op::MatrixTimesScalar: + return visitor(static_cast(inst)); + case Op::VectorTimesMatrix: + return visitor(static_cast(inst)); + case Op::MatrixTimesVector: + return visitor(static_cast(inst)); + case Op::MatrixTimesMatrix: + return visitor(static_cast(inst)); + case Op::OuterProduct: + return visitor(static_cast(inst)); + case Op::Dot: + return visitor(static_cast(inst)); + case Op::IAddCarry: + return visitor(static_cast(inst)); + case Op::ISubBorrow: + return visitor(static_cast(inst)); + case Op::UMulExtended: + return visitor(static_cast(inst)); + case Op::SMulExtended: + return visitor(static_cast(inst)); + case Op::Any: + return visitor(static_cast(inst)); + case Op::All: + return visitor(static_cast(inst)); + case Op::IsNan: + return visitor(static_cast(inst)); + case Op::IsInf: + return visitor(static_cast(inst)); + case Op::IsFinite: + return visitor(static_cast(inst)); + case Op::IsNormal: + return visitor(static_cast(inst)); + case Op::SignBitSet: + return visitor(static_cast(inst)); + case 
Op::LessOrGreater: + return visitor(static_cast(inst)); + case Op::Ordered: + return visitor(static_cast(inst)); + case Op::Unordered: + return visitor(static_cast(inst)); + case Op::LogicalEqual: + return visitor(static_cast(inst)); + case Op::LogicalNotEqual: + return visitor(static_cast(inst)); + case Op::LogicalOr: + return visitor(static_cast(inst)); + case Op::LogicalAnd: + return visitor(static_cast(inst)); + case Op::LogicalNot: + return visitor(static_cast(inst)); + case Op::Select: + return visitor(static_cast(inst)); + case Op::IEqual: + return visitor(static_cast(inst)); + case Op::INotEqual: + return visitor(static_cast(inst)); + case Op::UGreaterThan: + return visitor(static_cast(inst)); + case Op::SGreaterThan: + return visitor(static_cast(inst)); + case Op::UGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::SGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ULessThan: + return visitor(static_cast(inst)); + case Op::SLessThan: + return visitor(static_cast(inst)); + case Op::ULessThanEqual: + return visitor(static_cast(inst)); + case Op::SLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdEqual: + return visitor(static_cast(inst)); + case Op::FUnordEqual: + return visitor(static_cast(inst)); + case Op::FOrdNotEqual: + return visitor(static_cast(inst)); + case Op::FUnordNotEqual: + return visitor(static_cast(inst)); + case Op::FOrdLessThan: + return visitor(static_cast(inst)); + case Op::FUnordLessThan: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThan: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThan: + return visitor(static_cast(inst)); + case Op::FOrdLessThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ShiftRightLogical: + return 
visitor(static_cast(inst)); + case Op::ShiftRightArithmetic: + return visitor(static_cast(inst)); + case Op::ShiftLeftLogical: + return visitor(static_cast(inst)); + case Op::BitwiseOr: + return visitor(static_cast(inst)); + case Op::BitwiseXor: + return visitor(static_cast(inst)); + case Op::BitwiseAnd: + return visitor(static_cast(inst)); + case Op::Not: + return visitor(static_cast(inst)); + case Op::BitFieldInsert: + return visitor(static_cast(inst)); + case Op::BitFieldSExtract: + return visitor(static_cast(inst)); + case Op::BitFieldUExtract: + return visitor(static_cast(inst)); + case Op::BitReverse: + return visitor(static_cast(inst)); + case Op::BitCount: + return visitor(static_cast(inst)); + case Op::DPdx: + return visitor(static_cast(inst)); + case Op::DPdy: + return visitor(static_cast(inst)); + case Op::Fwidth: + return visitor(static_cast(inst)); + case Op::DPdxFine: + return visitor(static_cast(inst)); + case Op::DPdyFine: + return visitor(static_cast(inst)); + case Op::FwidthFine: + return visitor(static_cast(inst)); + case Op::DPdxCoarse: + return visitor(static_cast(inst)); + case Op::DPdyCoarse: + return visitor(static_cast(inst)); + case Op::FwidthCoarse: + return visitor(static_cast(inst)); + case Op::EmitVertex: + return visitor(static_cast(inst)); + case Op::EndPrimitive: + return visitor(static_cast(inst)); + case Op::EmitStreamVertex: + return visitor(static_cast(inst)); + case Op::EndStreamPrimitive: + return visitor(static_cast(inst)); + case Op::ControlBarrier: + return visitor(static_cast(inst)); + case Op::MemoryBarrier: + return visitor(static_cast(inst)); + case Op::AtomicLoad: + return visitor(static_cast(inst)); + case Op::AtomicStore: + return visitor(static_cast(inst)); + case Op::AtomicExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchangeWeak: + return visitor(static_cast(inst)); + case Op::AtomicIIncrement: + return 
visitor(static_cast(inst)); + case Op::AtomicIDecrement: + return visitor(static_cast(inst)); + case Op::AtomicIAdd: + return visitor(static_cast(inst)); + case Op::AtomicISub: + return visitor(static_cast(inst)); + case Op::AtomicSMin: + return visitor(static_cast(inst)); + case Op::AtomicUMin: + return visitor(static_cast(inst)); + case Op::AtomicSMax: + return visitor(static_cast(inst)); + case Op::AtomicUMax: + return visitor(static_cast(inst)); + case Op::AtomicAnd: + return visitor(static_cast(inst)); + case Op::AtomicOr: + return visitor(static_cast(inst)); + case Op::AtomicXor: + return visitor(static_cast(inst)); + case Op::Phi: + return visitor(static_cast(inst)); + case Op::LoopMerge: + return visitor(static_cast(inst)); + case Op::SelectionMerge: + return visitor(static_cast(inst)); + case Op::Label: + return visitor(static_cast(inst)); + case Op::Branch: + return visitor(static_cast(inst)); + case Op::BranchConditional: + return visitor(static_cast(inst)); + case Op::Switch: + return visitor(static_cast(inst)); + case Op::Kill: + return visitor(static_cast(inst)); + case Op::Return: + return visitor(static_cast(inst)); + case Op::ReturnValue: + return visitor(static_cast(inst)); + case Op::Unreachable: + return visitor(static_cast(inst)); + case Op::LifetimeStart: + return visitor(static_cast(inst)); + case Op::LifetimeStop: + return visitor(static_cast(inst)); + case Op::GroupAsyncCopy: + return visitor(static_cast(inst)); + case Op::GroupWaitEvents: + return visitor(static_cast(inst)); + case Op::GroupAll: + return visitor(static_cast(inst)); + case Op::GroupAny: + return visitor(static_cast(inst)); + case Op::GroupBroadcast: + return visitor(static_cast(inst)); + case Op::GroupIAdd: + return visitor(static_cast(inst)); + case Op::GroupFAdd: + return visitor(static_cast(inst)); + case Op::GroupFMin: + return visitor(static_cast(inst)); + case Op::GroupUMin: + return visitor(static_cast(inst)); + case Op::GroupSMin: + return 
visitor(static_cast(inst)); + case Op::GroupFMax: + return visitor(static_cast(inst)); + case Op::GroupUMax: + return visitor(static_cast(inst)); + case Op::GroupSMax: + return visitor(static_cast(inst)); + case Op::ReadPipe: + return visitor(static_cast(inst)); + case Op::WritePipe: + return visitor(static_cast(inst)); + case Op::ReservedReadPipe: + return visitor(static_cast(inst)); + case Op::ReservedWritePipe: + return visitor(static_cast(inst)); + case Op::ReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::ReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::CommitReadPipe: + return visitor(static_cast(inst)); + case Op::CommitWritePipe: + return visitor(static_cast(inst)); + case Op::IsValidReserveId: + return visitor(static_cast(inst)); + case Op::GetNumPipePackets: + return visitor(static_cast(inst)); + case Op::GetMaxPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::GroupCommitReadPipe: + return visitor(static_cast(inst)); + case Op::GroupCommitWritePipe: + return visitor(static_cast(inst)); + case Op::EnqueueMarker: + return visitor(static_cast(inst)); + case Op::EnqueueKernel: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeSubGroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeMaxSubGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelWorkGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return visitor(static_cast(inst)); + case Op::RetainEvent: + return visitor(static_cast(inst)); + case Op::ReleaseEvent: + return visitor(static_cast(inst)); + case Op::CreateUserEvent: + return visitor(static_cast(inst)); + case Op::IsValidEvent: + return visitor(static_cast(inst)); + case Op::SetUserEventStatus: + return visitor(static_cast(inst)); + case 
Op::CaptureEventProfilingInfo: + return visitor(static_cast(inst)); + case Op::GetDefaultQueue: + return visitor(static_cast(inst)); + case Op::BuildNDRange: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseFetch: + return visitor(static_cast(inst)); + case Op::ImageSparseGather: + return visitor(static_cast(inst)); + case Op::ImageSparseDrefGather: + return visitor(static_cast(inst)); + case Op::ImageSparseTexelsResident: + return visitor(static_cast(inst)); + case Op::NoLine: + return visitor(static_cast(inst)); + case Op::AtomicFlagTestAndSet: + return visitor(static_cast(inst)); + case Op::AtomicFlagClear: + return visitor(static_cast(inst)); + case Op::ImageSparseRead: + return visitor(static_cast(inst)); + case Op::SizeOf: + return visitor(static_cast(inst)); + case Op::TypePipeStorage: + return visitor(static_cast(inst)); + case Op::ConstantPipeStorage: + return visitor(static_cast(inst)); + case Op::CreatePipeFromPipeStorage: + return visitor(static_cast(inst)); + case Op::GetKernelLocalSizeForSubgroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelMaxNumSubgroups: + return visitor(static_cast(inst)); + case Op::TypeNamedBarrier: + return visitor(static_cast(inst)); + case Op::NamedBarrierInitialize: + return visitor(static_cast(inst)); + case Op::MemoryNamedBarrier: + return 
visitor(static_cast(inst)); + case Op::ModuleProcessed: + return visitor(static_cast(inst)); + case Op::ExecutionModeId: + return visitor(static_cast(inst)); + case Op::DecorateId: + return visitor(static_cast(inst)); + case Op::GroupNonUniformElect: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAll: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAny: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAllEqual: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcastFirst: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformInverseBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitExtract: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitCount: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindLSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindMSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffle: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleUp: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleDown: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMax: + 
return visitor(static_cast(inst)); + case Op::GroupNonUniformFMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadSwap: + return visitor(static_cast(inst)); + case Op::CopyLogical: + return visitor(static_cast(inst)); + case Op::PtrEqual: + return visitor(static_cast(inst)); + case Op::PtrNotEqual: + return visitor(static_cast(inst)); + case Op::PtrDiff: + return visitor(static_cast(inst)); + case Op::TypeCooperativeMatrixKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixMulAddKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLengthKHR: + return visitor(static_cast(inst)); + case Op::SubgroupBlockReadINTEL: + return visitor(static_cast(inst)); + case Op::SubgroupBlockWriteINTEL: + return visitor(static_cast(inst)); + case Op::AsmTargetINTEL: + return visitor(static_cast(inst)); + case Op::AsmINTEL: + return visitor(static_cast(inst)); + case Op::AsmCallINTEL: + return visitor(static_cast(inst)); + case Op::AtomicFMinEXT: + return visitor(static_cast(inst)); + case Op::AtomicFMaxEXT: + return visitor(static_cast(inst)); + case Op::AtomicFAddEXT: + return visitor(static_cast(inst)); + case Op::ConvertFToBF16INTEL: + return visitor(static_cast(inst)); + case Op::ConvertBF16ToFINTEL: + return visitor(static_cast(inst)); + case 
Op::ControlBarrierArriveINTEL: + return visitor(static_cast(inst)); + case Op::ControlBarrierWaitINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadCheckedINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreCheckedINTEL: + return visitor(static_cast(inst)); + } + throw internal_compiler_error(); +} + /* visit for a const spv_inst: switches on inst.opcode() and, in the matching Op case, returns the result of calling visitor with inst passed through a static_cast; if no case matches, control leaves the switch and internal_compiler_error is thrown. NOTE(review): the template parameter list and the static_cast target types are not visible in this view (angle-bracket contents appear to have been stripped); presumably each case casts to the const instruction class for that opcode -- confirm against the original header. */+template auto visit(Visitor &&visitor, spv_inst const &inst) { + switch (inst.opcode()) { + case Op::Nop: + return visitor(static_cast(inst)); + case Op::Undef: + return visitor(static_cast(inst)); + case Op::SourceContinued: + return visitor(static_cast(inst)); + case Op::Source: + return visitor(static_cast(inst)); + case Op::SourceExtension: + return visitor(static_cast(inst)); + case Op::Name: + return visitor(static_cast(inst)); + case Op::MemberName: + return visitor(static_cast(inst)); + case Op::String: + return visitor(static_cast(inst)); + case Op::Line: + return visitor(static_cast(inst)); + case Op::Extension: + return visitor(static_cast(inst)); + case Op::ExtInstImport: + return visitor(static_cast(inst)); + case Op::ExtInst: + return visitor(static_cast(inst)); + case Op::MemoryModel: + return visitor(static_cast(inst)); + case Op::EntryPoint: + return visitor(static_cast(inst)); + case Op::ExecutionMode: + return visitor(static_cast(inst)); + case Op::Capability: + return visitor(static_cast(inst)); + case Op::TypeVoid: + return visitor(static_cast(inst)); + case Op::TypeBool: + return visitor(static_cast(inst)); + case Op::TypeInt: + return visitor(static_cast(inst)); + case Op::TypeFloat: + return visitor(static_cast(inst)); + case Op::TypeVector: + return visitor(static_cast(inst)); + case Op::TypeMatrix: + return visitor(static_cast(inst)); + case Op::TypeImage: + return visitor(static_cast(inst)); + case Op::TypeSampler: + return visitor(static_cast(inst)); + case Op::TypeSampledImage: + return visitor(static_cast(inst)); + case Op::TypeArray: + return visitor(static_cast(inst)); + case 
Op::TypeRuntimeArray: + return visitor(static_cast(inst)); + case Op::TypeStruct: + return visitor(static_cast(inst)); + case Op::TypeOpaque: + return visitor(static_cast(inst)); + case Op::TypePointer: + return visitor(static_cast(inst)); + case Op::TypeFunction: + return visitor(static_cast(inst)); + case Op::TypeEvent: + return visitor(static_cast(inst)); + case Op::TypeDeviceEvent: + return visitor(static_cast(inst)); + case Op::TypeReserveId: + return visitor(static_cast(inst)); + case Op::TypeQueue: + return visitor(static_cast(inst)); + case Op::TypePipe: + return visitor(static_cast(inst)); + case Op::TypeForwardPointer: + return visitor(static_cast(inst)); + case Op::ConstantTrue: + return visitor(static_cast(inst)); + case Op::ConstantFalse: + return visitor(static_cast(inst)); + case Op::Constant: + return visitor(static_cast(inst)); + case Op::ConstantComposite: + return visitor(static_cast(inst)); + case Op::ConstantSampler: + return visitor(static_cast(inst)); + case Op::ConstantNull: + return visitor(static_cast(inst)); + case Op::Function: + return visitor(static_cast(inst)); + case Op::FunctionParameter: + return visitor(static_cast(inst)); + case Op::FunctionEnd: + return visitor(static_cast(inst)); + case Op::FunctionCall: + return visitor(static_cast(inst)); + case Op::Variable: + return visitor(static_cast(inst)); + case Op::ImageTexelPointer: + return visitor(static_cast(inst)); + case Op::Load: + return visitor(static_cast(inst)); + case Op::Store: + return visitor(static_cast(inst)); + case Op::CopyMemory: + return visitor(static_cast(inst)); + case Op::CopyMemorySized: + return visitor(static_cast(inst)); + case Op::AccessChain: + return visitor(static_cast(inst)); + case Op::InBoundsAccessChain: + return visitor(static_cast(inst)); + case Op::PtrAccessChain: + return visitor(static_cast(inst)); + case Op::ArrayLength: + return visitor(static_cast(inst)); + case Op::GenericPtrMemSemantics: + return visitor(static_cast(inst)); + case 
Op::InBoundsPtrAccessChain: + return visitor(static_cast(inst)); + case Op::Decorate: + return visitor(static_cast(inst)); + case Op::MemberDecorate: + return visitor(static_cast(inst)); + case Op::DecorationGroup: + return visitor(static_cast(inst)); + case Op::GroupDecorate: + return visitor(static_cast(inst)); + case Op::GroupMemberDecorate: + return visitor(static_cast(inst)); + case Op::VectorExtractDynamic: + return visitor(static_cast(inst)); + case Op::VectorInsertDynamic: + return visitor(static_cast(inst)); + case Op::VectorShuffle: + return visitor(static_cast(inst)); + case Op::CompositeConstruct: + return visitor(static_cast(inst)); + case Op::CompositeExtract: + return visitor(static_cast(inst)); + case Op::CompositeInsert: + return visitor(static_cast(inst)); + case Op::CopyObject: + return visitor(static_cast(inst)); + case Op::Transpose: + return visitor(static_cast(inst)); + case Op::SampledImage: + return visitor(static_cast(inst)); + case Op::ImageSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageFetch: + return visitor(static_cast(inst)); + case Op::ImageGather: + return visitor(static_cast(inst)); + case Op::ImageDrefGather: + return visitor(static_cast(inst)); + case Op::ImageRead: + return visitor(static_cast(inst)); + case Op::ImageWrite: + return visitor(static_cast(inst)); + case Op::Image: + return visitor(static_cast(inst)); + case Op::ImageQueryFormat: + return visitor(static_cast(inst)); 
+ case Op::ImageQueryOrder: + return visitor(static_cast(inst)); + case Op::ImageQuerySizeLod: + return visitor(static_cast(inst)); + case Op::ImageQuerySize: + return visitor(static_cast(inst)); + case Op::ImageQueryLod: + return visitor(static_cast(inst)); + case Op::ImageQueryLevels: + return visitor(static_cast(inst)); + case Op::ImageQuerySamples: + return visitor(static_cast(inst)); + case Op::ConvertFToU: + return visitor(static_cast(inst)); + case Op::ConvertFToS: + return visitor(static_cast(inst)); + case Op::ConvertSToF: + return visitor(static_cast(inst)); + case Op::ConvertUToF: + return visitor(static_cast(inst)); + case Op::UConvert: + return visitor(static_cast(inst)); + case Op::SConvert: + return visitor(static_cast(inst)); + case Op::FConvert: + return visitor(static_cast(inst)); + case Op::QuantizeToF16: + return visitor(static_cast(inst)); + case Op::ConvertPtrToU: + return visitor(static_cast(inst)); + case Op::SatConvertSToU: + return visitor(static_cast(inst)); + case Op::SatConvertUToS: + return visitor(static_cast(inst)); + case Op::ConvertUToPtr: + return visitor(static_cast(inst)); + case Op::PtrCastToGeneric: + return visitor(static_cast(inst)); + case Op::GenericCastToPtr: + return visitor(static_cast(inst)); + case Op::GenericCastToPtrExplicit: + return visitor(static_cast(inst)); + case Op::Bitcast: + return visitor(static_cast(inst)); + case Op::SNegate: + return visitor(static_cast(inst)); + case Op::FNegate: + return visitor(static_cast(inst)); + case Op::IAdd: + return visitor(static_cast(inst)); + case Op::FAdd: + return visitor(static_cast(inst)); + case Op::ISub: + return visitor(static_cast(inst)); + case Op::FSub: + return visitor(static_cast(inst)); + case Op::IMul: + return visitor(static_cast(inst)); + case Op::FMul: + return visitor(static_cast(inst)); + case Op::UDiv: + return visitor(static_cast(inst)); + case Op::SDiv: + return visitor(static_cast(inst)); + case Op::FDiv: + return visitor(static_cast(inst)); + case 
Op::UMod: + return visitor(static_cast(inst)); + case Op::SRem: + return visitor(static_cast(inst)); + case Op::SMod: + return visitor(static_cast(inst)); + case Op::FRem: + return visitor(static_cast(inst)); + case Op::FMod: + return visitor(static_cast(inst)); + case Op::VectorTimesScalar: + return visitor(static_cast(inst)); + case Op::MatrixTimesScalar: + return visitor(static_cast(inst)); + case Op::VectorTimesMatrix: + return visitor(static_cast(inst)); + case Op::MatrixTimesVector: + return visitor(static_cast(inst)); + case Op::MatrixTimesMatrix: + return visitor(static_cast(inst)); + case Op::OuterProduct: + return visitor(static_cast(inst)); + case Op::Dot: + return visitor(static_cast(inst)); + case Op::IAddCarry: + return visitor(static_cast(inst)); + case Op::ISubBorrow: + return visitor(static_cast(inst)); + case Op::UMulExtended: + return visitor(static_cast(inst)); + case Op::SMulExtended: + return visitor(static_cast(inst)); + case Op::Any: + return visitor(static_cast(inst)); + case Op::All: + return visitor(static_cast(inst)); + case Op::IsNan: + return visitor(static_cast(inst)); + case Op::IsInf: + return visitor(static_cast(inst)); + case Op::IsFinite: + return visitor(static_cast(inst)); + case Op::IsNormal: + return visitor(static_cast(inst)); + case Op::SignBitSet: + return visitor(static_cast(inst)); + case Op::LessOrGreater: + return visitor(static_cast(inst)); + case Op::Ordered: + return visitor(static_cast(inst)); + case Op::Unordered: + return visitor(static_cast(inst)); + case Op::LogicalEqual: + return visitor(static_cast(inst)); + case Op::LogicalNotEqual: + return visitor(static_cast(inst)); + case Op::LogicalOr: + return visitor(static_cast(inst)); + case Op::LogicalAnd: + return visitor(static_cast(inst)); + case Op::LogicalNot: + return visitor(static_cast(inst)); + case Op::Select: + return visitor(static_cast(inst)); + case Op::IEqual: + return visitor(static_cast(inst)); + case Op::INotEqual: + return 
visitor(static_cast(inst)); + case Op::UGreaterThan: + return visitor(static_cast(inst)); + case Op::SGreaterThan: + return visitor(static_cast(inst)); + case Op::UGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::SGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ULessThan: + return visitor(static_cast(inst)); + case Op::SLessThan: + return visitor(static_cast(inst)); + case Op::ULessThanEqual: + return visitor(static_cast(inst)); + case Op::SLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdEqual: + return visitor(static_cast(inst)); + case Op::FUnordEqual: + return visitor(static_cast(inst)); + case Op::FOrdNotEqual: + return visitor(static_cast(inst)); + case Op::FUnordNotEqual: + return visitor(static_cast(inst)); + case Op::FOrdLessThan: + return visitor(static_cast(inst)); + case Op::FUnordLessThan: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThan: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThan: + return visitor(static_cast(inst)); + case Op::FOrdLessThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordLessThanEqual: + return visitor(static_cast(inst)); + case Op::FOrdGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::FUnordGreaterThanEqual: + return visitor(static_cast(inst)); + case Op::ShiftRightLogical: + return visitor(static_cast(inst)); + case Op::ShiftRightArithmetic: + return visitor(static_cast(inst)); + case Op::ShiftLeftLogical: + return visitor(static_cast(inst)); + case Op::BitwiseOr: + return visitor(static_cast(inst)); + case Op::BitwiseXor: + return visitor(static_cast(inst)); + case Op::BitwiseAnd: + return visitor(static_cast(inst)); + case Op::Not: + return visitor(static_cast(inst)); + case Op::BitFieldInsert: + return visitor(static_cast(inst)); + case Op::BitFieldSExtract: + return visitor(static_cast(inst)); + case Op::BitFieldUExtract: + return visitor(static_cast(inst)); + case Op::BitReverse: + return 
visitor(static_cast(inst)); + case Op::BitCount: + return visitor(static_cast(inst)); + case Op::DPdx: + return visitor(static_cast(inst)); + case Op::DPdy: + return visitor(static_cast(inst)); + case Op::Fwidth: + return visitor(static_cast(inst)); + case Op::DPdxFine: + return visitor(static_cast(inst)); + case Op::DPdyFine: + return visitor(static_cast(inst)); + case Op::FwidthFine: + return visitor(static_cast(inst)); + case Op::DPdxCoarse: + return visitor(static_cast(inst)); + case Op::DPdyCoarse: + return visitor(static_cast(inst)); + case Op::FwidthCoarse: + return visitor(static_cast(inst)); + case Op::EmitVertex: + return visitor(static_cast(inst)); + case Op::EndPrimitive: + return visitor(static_cast(inst)); + case Op::EmitStreamVertex: + return visitor(static_cast(inst)); + case Op::EndStreamPrimitive: + return visitor(static_cast(inst)); + case Op::ControlBarrier: + return visitor(static_cast(inst)); + case Op::MemoryBarrier: + return visitor(static_cast(inst)); + case Op::AtomicLoad: + return visitor(static_cast(inst)); + case Op::AtomicStore: + return visitor(static_cast(inst)); + case Op::AtomicExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchange: + return visitor(static_cast(inst)); + case Op::AtomicCompareExchangeWeak: + return visitor(static_cast(inst)); + case Op::AtomicIIncrement: + return visitor(static_cast(inst)); + case Op::AtomicIDecrement: + return visitor(static_cast(inst)); + case Op::AtomicIAdd: + return visitor(static_cast(inst)); + case Op::AtomicISub: + return visitor(static_cast(inst)); + case Op::AtomicSMin: + return visitor(static_cast(inst)); + case Op::AtomicUMin: + return visitor(static_cast(inst)); + case Op::AtomicSMax: + return visitor(static_cast(inst)); + case Op::AtomicUMax: + return visitor(static_cast(inst)); + case Op::AtomicAnd: + return visitor(static_cast(inst)); + case Op::AtomicOr: + return visitor(static_cast(inst)); + case Op::AtomicXor: + return visitor(static_cast(inst)); + case 
Op::Phi: + return visitor(static_cast(inst)); + case Op::LoopMerge: + return visitor(static_cast(inst)); + case Op::SelectionMerge: + return visitor(static_cast(inst)); + case Op::Label: + return visitor(static_cast(inst)); + case Op::Branch: + return visitor(static_cast(inst)); + case Op::BranchConditional: + return visitor(static_cast(inst)); + case Op::Switch: + return visitor(static_cast(inst)); + case Op::Kill: + return visitor(static_cast(inst)); + case Op::Return: + return visitor(static_cast(inst)); + case Op::ReturnValue: + return visitor(static_cast(inst)); + case Op::Unreachable: + return visitor(static_cast(inst)); + case Op::LifetimeStart: + return visitor(static_cast(inst)); + case Op::LifetimeStop: + return visitor(static_cast(inst)); + case Op::GroupAsyncCopy: + return visitor(static_cast(inst)); + case Op::GroupWaitEvents: + return visitor(static_cast(inst)); + case Op::GroupAll: + return visitor(static_cast(inst)); + case Op::GroupAny: + return visitor(static_cast(inst)); + case Op::GroupBroadcast: + return visitor(static_cast(inst)); + case Op::GroupIAdd: + return visitor(static_cast(inst)); + case Op::GroupFAdd: + return visitor(static_cast(inst)); + case Op::GroupFMin: + return visitor(static_cast(inst)); + case Op::GroupUMin: + return visitor(static_cast(inst)); + case Op::GroupSMin: + return visitor(static_cast(inst)); + case Op::GroupFMax: + return visitor(static_cast(inst)); + case Op::GroupUMax: + return visitor(static_cast(inst)); + case Op::GroupSMax: + return visitor(static_cast(inst)); + case Op::ReadPipe: + return visitor(static_cast(inst)); + case Op::WritePipe: + return visitor(static_cast(inst)); + case Op::ReservedReadPipe: + return visitor(static_cast(inst)); + case Op::ReservedWritePipe: + return visitor(static_cast(inst)); + case Op::ReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::ReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::CommitReadPipe: + return visitor(static_cast(inst)); 
+ case Op::CommitWritePipe: + return visitor(static_cast(inst)); + case Op::IsValidReserveId: + return visitor(static_cast(inst)); + case Op::GetNumPipePackets: + return visitor(static_cast(inst)); + case Op::GetMaxPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveReadPipePackets: + return visitor(static_cast(inst)); + case Op::GroupReserveWritePipePackets: + return visitor(static_cast(inst)); + case Op::GroupCommitReadPipe: + return visitor(static_cast(inst)); + case Op::GroupCommitWritePipe: + return visitor(static_cast(inst)); + case Op::EnqueueMarker: + return visitor(static_cast(inst)); + case Op::EnqueueKernel: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeSubGroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelNDrangeMaxSubGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelWorkGroupSize: + return visitor(static_cast(inst)); + case Op::GetKernelPreferredWorkGroupSizeMultiple: + return visitor(static_cast(inst)); + case Op::RetainEvent: + return visitor(static_cast(inst)); + case Op::ReleaseEvent: + return visitor(static_cast(inst)); + case Op::CreateUserEvent: + return visitor(static_cast(inst)); + case Op::IsValidEvent: + return visitor(static_cast(inst)); + case Op::SetUserEventStatus: + return visitor(static_cast(inst)); + case Op::CaptureEventProfilingInfo: + return visitor(static_cast(inst)); + case Op::GetDefaultQueue: + return visitor(static_cast(inst)); + case Op::BuildNDRange: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjExplicitLod: + return 
visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefImplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseSampleProjDrefExplicitLod: + return visitor(static_cast(inst)); + case Op::ImageSparseFetch: + return visitor(static_cast(inst)); + case Op::ImageSparseGather: + return visitor(static_cast(inst)); + case Op::ImageSparseDrefGather: + return visitor(static_cast(inst)); + case Op::ImageSparseTexelsResident: + return visitor(static_cast(inst)); + case Op::NoLine: + return visitor(static_cast(inst)); + case Op::AtomicFlagTestAndSet: + return visitor(static_cast(inst)); + case Op::AtomicFlagClear: + return visitor(static_cast(inst)); + case Op::ImageSparseRead: + return visitor(static_cast(inst)); + case Op::SizeOf: + return visitor(static_cast(inst)); + case Op::TypePipeStorage: + return visitor(static_cast(inst)); + case Op::ConstantPipeStorage: + return visitor(static_cast(inst)); + case Op::CreatePipeFromPipeStorage: + return visitor(static_cast(inst)); + case Op::GetKernelLocalSizeForSubgroupCount: + return visitor(static_cast(inst)); + case Op::GetKernelMaxNumSubgroups: + return visitor(static_cast(inst)); + case Op::TypeNamedBarrier: + return visitor(static_cast(inst)); + case Op::NamedBarrierInitialize: + return visitor(static_cast(inst)); + case Op::MemoryNamedBarrier: + return visitor(static_cast(inst)); + case Op::ModuleProcessed: + return visitor(static_cast(inst)); + case Op::ExecutionModeId: + return visitor(static_cast(inst)); + case Op::DecorateId: + return visitor(static_cast(inst)); + case Op::GroupNonUniformElect: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAll: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAny: + return visitor(static_cast(inst)); + case Op::GroupNonUniformAllEqual: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcast: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBroadcastFirst: + return visitor(static_cast(inst)); + case 
Op::GroupNonUniformBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformInverseBallot: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitExtract: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotBitCount: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindLSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBallotFindMSB: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffle: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleUp: + return visitor(static_cast(inst)); + case Op::GroupNonUniformShuffleDown: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFAdd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformIMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMul: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMin: + return visitor(static_cast(inst)); + case Op::GroupNonUniformSMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformUMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformFMax: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformBitwiseXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalAnd: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalOr: + return visitor(static_cast(inst)); + case Op::GroupNonUniformLogicalXor: + return visitor(static_cast(inst)); + case Op::GroupNonUniformQuadBroadcast: + return visitor(static_cast(inst)); + case 
Op::GroupNonUniformQuadSwap: + return visitor(static_cast(inst)); + case Op::CopyLogical: + return visitor(static_cast(inst)); + case Op::PtrEqual: + return visitor(static_cast(inst)); + case Op::PtrNotEqual: + return visitor(static_cast(inst)); + case Op::PtrDiff: + return visitor(static_cast(inst)); + case Op::TypeCooperativeMatrixKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixMulAddKHR: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLengthKHR: + return visitor(static_cast(inst)); + case Op::SubgroupBlockReadINTEL: + return visitor(static_cast(inst)); + case Op::SubgroupBlockWriteINTEL: + return visitor(static_cast(inst)); + case Op::AsmTargetINTEL: + return visitor(static_cast(inst)); + case Op::AsmINTEL: + return visitor(static_cast(inst)); + case Op::AsmCallINTEL: + return visitor(static_cast(inst)); + case Op::AtomicFMinEXT: + return visitor(static_cast(inst)); + case Op::AtomicFMaxEXT: + return visitor(static_cast(inst)); + case Op::AtomicFAddEXT: + return visitor(static_cast(inst)); + case Op::ConvertFToBF16INTEL: + return visitor(static_cast(inst)); + case Op::ConvertBF16ToFINTEL: + return visitor(static_cast(inst)); + case Op::ControlBarrierArriveINTEL: + return visitor(static_cast(inst)); + case Op::ControlBarrierWaitINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixLoadCheckedINTEL: + return visitor(static_cast(inst)); + case Op::CooperativeMatrixStoreCheckedINTEL: + return visitor(static_cast(inst)); + } + throw internal_compiler_error(); +} + +template class default_visitor { + public: + template using const_t = std::conditional_t, T>; + auto pre_visit(const_t &) {} + auto visit_result(const_t &) {} + auto post_visit(const_t &) {} + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + 
auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); 
+ static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + for (auto &op : in.op3()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + 
static_cast(this)->operator()(*in.op1()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + if (in.op7()) { + static_cast(this)->operator()(*in.op7()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + 
static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + 
static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + 
static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + if (in.op1()) { + static_cast(this)->operator()(*in.op1()); + } + + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + 
static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + if (in.op5()) { + static_cast(this)->operator()(*in.op5()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + 
static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + 
static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + 
static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + 
static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + 
static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + 
static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto 
operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + 
static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + 
static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + 
static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + 
static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + 
static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto 
operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + 
static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + 
static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + 
static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + 
static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + for (auto &op : in.op0()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + for (auto &op : in.op3()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + 
auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + for (auto &op : in.op2()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + 
static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + 
static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + 
auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto 
operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + static_cast(this)->operator()(in.op7()); + static_cast(this)->operator()(in.op8()); + static_cast(this)->operator()(in.op9()); + for (auto &op : in.op10()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + 
static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + 
auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + 
static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + 
static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + 
static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + 
static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + 
static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + 
static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + 
static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + 
static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + 
static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + 
static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + if (in.op2()) { + static_cast(this)->operator()(*in.op2()); + } + + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + if (in.op4()) { + static_cast(this)->operator()(*in.op4()); + } + + if (in.op5()) { + static_cast(this)->operator()(*in.op5()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + if (in.op3()) { + static_cast(this)->operator()(*in.op3()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + 
static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + for (auto &op : in.op1()) { + static_cast(this)->operator()(op); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + 
static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.type()); + static_cast(this)->visit_result(in); + static_cast(this)->operator()(in.op0()); + static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + if (in.op6()) { + static_cast(this)->operator()(*in.op6()); + } + + if (in.op7()) { + static_cast(this)->operator()(*in.op7()); + } + + if (in.op8()) { + static_cast(this)->operator()(*in.op8()); + } + + static_cast(this)->post_visit(in); + } + auto operator()(const_t &in) { + static_cast(this)->pre_visit(in); + static_cast(this)->operator()(in.op0()); + 
static_cast(this)->operator()(in.op1()); + static_cast(this)->operator()(in.op2()); + static_cast(this)->operator()(in.op3()); + static_cast(this)->operator()(in.op4()); + static_cast(this)->operator()(in.op5()); + static_cast(this)->operator()(in.op6()); + if (in.op7()) { + static_cast(this)->operator()(*in.op7()); + } + + if (in.op8()) { + static_cast(this)->operator()(*in.op8()); + } + + if (in.op9()) { + static_cast(this)->operator()(*in.op9()); + } + + static_cast(this)->post_visit(in); + } +}; + +} // namespace tinytc::spv + +#endif // GENERATED_VISIT_20250630_HPP diff --git a/src/spv/xe_constants.hpp b/src/spv/xe_constants.hpp new file mode 100644 index 00000000..6da62431 --- /dev/null +++ b/src/spv/xe_constants.hpp @@ -0,0 +1,16 @@ +#ifndef XE_CONSTANTS_20250219_HPP +#define XE_CONSTANTS_20250219_HPP + +#include + +namespace tinytc::spv::xe { +constexpr static std::int32_t grf_size = 64; +constexpr static std::int32_t exec_size = 16; +constexpr static std::int32_t channel_size = 4; +constexpr static std::int32_t sdepth = 8; +constexpr static std::int32_t rcount = 8; +constexpr static std::int32_t load_batch_size = 4; +constexpr static std::int32_t store_batch_size = 1; +} // namespace tinytc::spv::xe + +#endif // XE_CONSTANTS_20250219_HPP diff --git a/src/support/fnv1a_array_view.hpp b/src/support/fnv1a_array_view.hpp new file mode 100644 index 00000000..527879c2 --- /dev/null +++ b/src/support/fnv1a_array_view.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FNV1A_ARRAY_VIEW_20241010_HPP +#define FNV1A_ARRAY_VIEW_20241010_HPP + +#include "tinytc/core.hpp" +#include "util/fnv1a.hpp" + +namespace tinytc { + +template +constexpr auto fnv1a_step(std::uint64_t hash, array_view const &data) -> std::uint64_t { + for (auto const &i : data) { + hash = fnv1a_step(hash, i); + } + return hash; +} + +} // namespace tinytc + +#endif // FNV1A_ARRAY_VIEW_20241010_HPP diff --git a/src/support/fp_util.hpp 
b/src/support/fp_util.hpp new file mode 100644 index 00000000..8bbc6ade --- /dev/null +++ b/src/support/fp_util.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef FP_UTIL_20241126_HPP +#define FP_UTIL_20241126_HPP + +#include "tinytc/core.hpp" + +#include +#include + +namespace tinytc { + +template class U> +struct is_instance_of : public std::false_type {}; +template