Skip to content

Commit

Permalink
fix empty inputs (microsoft#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
msftsw authored Mar 29, 2021
1 parent 9138379 commit c889ee3
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion backends/c-hlsl_win64/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

local_dll_path = os.environ["ANTARES_DRIVER_PATH"]

if not os.path.exists('{local_dll_path}/dxcompiler.dll'):
if not os.path.exists(f'{local_dll_path}/dxcompiler.dll'):
os.system(f'curl -Ls https://github.com/microsoft/antares/releases/download/v0.1.0/antares_hlsl_v0.2dev3_x64.dll -o {local_dll_path}/antares_hlsl_v0.2_x64.dll')
os.system(f'curl -Ls https://github.com/microsoft/antares/releases/download/v0.1.0/dxil.dll -o {local_dll_path}/dxil.dll')
os.system(f'curl -Ls https://github.com/microsoft/antares/releases/download/v0.1.0/dxcompiler.dll -o {local_dll_path}/dxcompiler.dll')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,15 @@ namespace {
return iter;
}

static std::vector<std::string> ssplit(const std::string& source, const std::string& delim) {
if (!source.size())
return {};
static std::vector<std::string> ssplit(const std::string& source, const std::string& delim, bool allow_empty = false) {
std::vector<std::string> ret;
int it = 0, next;
while (next = (int)source.find(delim, it), next >= 0) {
if (next > it)
if (next > it || allow_empty)
ret.push_back(source.substr(it, next - it));
it = next + (int)delim.size();
}
if (it < source.size())
if (it < source.size() || allow_empty)
ret.push_back(source.substr(it));
return std::move(ret);
}
Expand Down Expand Up @@ -270,18 +268,20 @@ void* dxShaderLoad_v2(const char* shader_src)

if (legacy_format) {
str_params = get_between(source, "///", "\n");
arr_params = ssplit(str_params, ":");
arr_params = ssplit(str_params, ":", true);
assert(arr_params.size() == 2);
in_params = ssplit(arr_params[0], ",");
out_params = ssplit(arr_params[1], ",");
}
else {
str_params = get_between(source, " -- ", "\n");
arr_params = ssplit(str_params, " -> ");
arr_params = ssplit(str_params, " -> ", true);
assert(arr_params.size() == 2);
in_params = ssplit(arr_params[0] + ", ", "], ");
out_params = ssplit(arr_params[1] + ", ", "], ");
}
if (!arr_params[0].size())
in_params.clear();

auto parse_tensor = [&](const std::string & param) -> dx_tensor_t {
dx_tensor_t ret;
Expand Down
7 changes: 4 additions & 3 deletions backends/c-rocm/schedule/standard/algo_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@ def plan_threads(attrs, axes):
while num_threads > 1:
unchanged = True
for i, x in enumerate(shape):
if x % th == 0:
if x % th == 0 and num_threads % th == 0:
num_threads //= th
shape[i] //= th
init_threads[i] *= th
unchanged = False
if unchanged:
break
init_vthreads = [1] * len(axes)
num_vthreads, init_vthreads = 256, [1] * len(axes)
for i, x in enumerate(shape):
if x % 2 == 0:
if x % 2 == 0 and num_vthreads % 2 == 0:
num_vthreads //= 2
shape[i] //= 2
init_vthreads[i] *= 2
return init_threads, init_vthreads
Expand Down
10 changes: 5 additions & 5 deletions graph_evaluator/execute_module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ std::string get_between(const std::string &str, const std::string &begin, const
return str.substr(at, next - at);
}

std::vector<std::string> ssplit(const std::string &str, const std::string &sub) {
std::vector<std::string> ssplit(const std::string &str, const std::string &sub, bool allow_empty = false) {
std::vector<std::string> ret;
int it = 0, next;
while (next = str.find(sub, it), next >= 0) {
if (next > it)
if (next > it || allow_empty)
ret.push_back(str.substr(it, next - it));
it = next + sub.size();
}
if (it < str.size())
if (it < str.size() || allow_empty)
ret.push_back(str.substr(it));
return std::move(ret);
}
Expand Down Expand Up @@ -186,7 +186,7 @@ struct ExecutionModule {
}

auto encoded_params = get_between(source, "// GLOBALS: ", "\n");
auto params = ssplit(encoded_params, " -> ");
auto params = ssplit(encoded_params, " -> ", true);
global_inputs = parse_properties(params[0]), global_outputs = parse_properties(params[1]);

backend = get_between(source, "// BACKEND: ", " (");
Expand All @@ -200,7 +200,7 @@ struct ExecutionModule {
local_kernels.push_back(kernel_property{});
auto &kp = local_kernels[local_kernels.size() - 1];
kp.fname = name;
auto inputs_outputs = ssplit(get_between(kernel_slices[i], " -- ", "\n"), " -> ");
auto inputs_outputs = ssplit(get_between(kernel_slices[i], " -- ", "\n"), " -> ", true);
auto local_inputs = parse_properties(inputs_outputs[0]);
auto local_outputs = parse_properties(inputs_outputs[1]);

Expand Down

0 comments on commit c889ee3

Please sign in to comment.