diff --git a/backends/c-hlsl_win64/config.py b/backends/c-hlsl_win64/config.py index 47e6c878..d03b15d1 100644 --- a/backends/c-hlsl_win64/config.py +++ b/backends/c-hlsl_win64/config.py @@ -6,7 +6,7 @@ local_dll_path = os.environ["ANTARES_DRIVER_PATH"] -if not os.path.exists('{local_dll_path}/dxcompiler.dll'): +if not os.path.exists(f'{local_dll_path}/dxcompiler.dll'): os.system(f'curl -Ls https://github.com/microsoft/antares/releases/download/v0.1.0/antares_hlsl_v0.2dev3_x64.dll -o {local_dll_path}/antares_hlsl_v0.2_x64.dll') os.system(f'curl -Ls https://github.com/microsoft/antares/releases/download/v0.1.0/dxil.dll -o {local_dll_path}/dxil.dll') os.system(f'curl -Ls https://github.com/microsoft/antares/releases/download/v0.1.0/dxcompiler.dll -o {local_dll_path}/dxcompiler.dll') diff --git a/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp b/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp index 44d1a6a4..bc13792f 100644 --- a/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp +++ b/backends/c-hlsl_xbox/evaluator/AntaresHlslLib/D3D12APIWrapper.cpp @@ -145,17 +145,15 @@ namespace { return iter; } - static std::vector ssplit(const std::string& source, const std::string& delim) { - if (!source.size()) - return {}; + static std::vector ssplit(const std::string& source, const std::string& delim, bool allow_empty = false) { std::vector ret; int it = 0, next; while (next = (int)source.find(delim, it), next >= 0) { - if (next > it) + if (next > it || allow_empty) ret.push_back(source.substr(it, next - it)); it = next + (int)delim.size(); } - if (it < source.size()) + if (it < source.size() || allow_empty) ret.push_back(source.substr(it)); return std::move(ret); } @@ -270,18 +268,20 @@ void* dxShaderLoad_v2(const char* shader_src) if (legacy_format) { str_params = get_between(source, "///", "\n"); - arr_params = ssplit(str_params, ":"); + arr_params = ssplit(str_params, ":", true); assert(arr_params.size() == 2); in_params = ssplit(arr_params[0], ","); out_params = ssplit(arr_params[1], ","); } else { str_params = get_between(source, " -- ", "\n"); - arr_params = ssplit(str_params, " -> "); + arr_params = ssplit(str_params, " -> ", true); assert(arr_params.size() == 2); in_params = ssplit(arr_params[0] + ", ", "], "); out_params = ssplit(arr_params[1] + ", ", "], "); } + if (!arr_params[0].size()) + in_params.clear(); auto parse_tensor = [&](const std::string & param) -> dx_tensor_t { dx_tensor_t ret; diff --git a/backends/c-rocm/schedule/standard/algo_tiling.py b/backends/c-rocm/schedule/standard/algo_tiling.py index fde957eb..0f9f6095 100644 --- a/backends/c-rocm/schedule/standard/algo_tiling.py +++ b/backends/c-rocm/schedule/standard/algo_tiling.py @@ -10,16 +10,17 @@ def plan_threads(attrs, axes): while num_threads > 1: unchanged = True for i, x in enumerate(shape): - if x % th == 0: + if x % th == 0 and num_threads % th == 0: num_threads //= th shape[i] //= th init_threads[i] *= th unchanged = False if unchanged: break - init_vthreads = [1] * len(axes) + num_vthreads, init_vthreads = 256, [1] * len(axes) for i, x in enumerate(shape): - if x % 2 == 0: + if x % 2 == 0 and num_vthreads % 2 == 0: + num_vthreads //= 2 shape[i] //= 2 init_vthreads[i] *= 2 return init_threads, init_vthreads diff --git a/graph_evaluator/execute_module.hpp b/graph_evaluator/execute_module.hpp index d275e70d..57b0708b 100644 --- a/graph_evaluator/execute_module.hpp +++ b/graph_evaluator/execute_module.hpp @@ -92,15 +92,15 @@ std::string get_between(const std::string &str, const std::string &begin, const return str.substr(at, next - at); } -std::vector ssplit(const std::string &str, const std::string &sub) { +std::vector ssplit(const std::string &str, const std::string &sub, bool allow_empty = false) { std::vector ret; int it = 0, next; while (next = str.find(sub, it), next >= 0) { - if (next > it) + if (next > it || allow_empty) ret.push_back(str.substr(it, next - it)); it = next + sub.size(); } - if (it < str.size()) + if (it < str.size() || allow_empty) ret.push_back(str.substr(it)); return std::move(ret); } @@ -186,7 +186,7 @@ struct ExecutionModule { } auto encoded_params = get_between(source, "// GLOBALS: ", "\n"); - auto params = ssplit(encoded_params, " -> "); + auto params = ssplit(encoded_params, " -> ", true); global_inputs = parse_properties(params[0]), global_outputs = parse_properties(params[1]); backend = get_between(source, "// BACKEND: ", " ("); @@ -200,7 +200,7 @@ struct ExecutionModule { local_kernels.push_back(kernel_property{}); auto &kp = local_kernels[local_kernels.size() - 1]; kp.fname = name; - auto inputs_outputs = ssplit(get_between(kernel_slices[i], " -- ", "\n"), " -> "); + auto inputs_outputs = ssplit(get_between(kernel_slices[i], " -- ", "\n"), " -> ", true); auto local_inputs = parse_properties(inputs_outputs[0]); auto local_outputs = parse_properties(inputs_outputs[1]);