[Target] enable -arch=sm_xx for assigning cuda target arch and deprecate autotvm.measure.set_cuda_target_arch api (apache#9544)

* [Target] enable -arch=sm_xx for assigning cuda target arch and deprecate autotvm.measure.set_cuda_target_arch api

Signed-off-by: ZQPei <[email protected]>

* [Format] fix format error in CI

Signed-off-by: ZQPei <[email protected]>

* [Target] add warnings to target.cuda and fix errors in ci

Signed-off-by: ZQPei <[email protected]>

* [Target] fix docstring

Signed-off-by: ZQPei <[email protected]>

* [Target] amend warning condition

Signed-off-by: ZQPei <[email protected]>
ZQPei authored Nov 24, 2021
1 parent 289bd90 commit 0195afc
Showing 16 changed files with 118 additions and 87 deletions.
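
In short, this commit moves the CUDA arch from a process-global autotvm setting onto the Target itself. A minimal usage sketch follows (not taken from the diff; the sm_86 value is illustrative):

import tvm

# Old style, deprecated by this commit: a global autotvm hook.
# from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
# set_cuda_target_arch("sm_86")  # raises ValueError after this change

# New style: the arch rides along with the target.
target = tvm.target.Target("cuda -arch=sm_86")
target_helper = tvm.target.cuda(model="unknown", arch="sm_86")  # helper updated below

# nvcc.compile_cuda resolves the arch from Target.current() when none is passed,
# so code built inside a `with target:` scope compiles for sm_86.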
4 changes: 2 additions & 2 deletions apps/topi_recipe/broadcast/test_broadcast_map.py
@@ -27,9 +27,9 @@
 USE_MANUAL_CODE = False


-@tvm.register_func
+@tvm.register_func("tvm_callback_cuda_compile", override=True)
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target_format="ptx")
     return ptx

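The same two-line pattern repeats in the topi_recipe scripts below. A sketch of why override=True is now required (this assumes the new default hook added in python/tvm/contrib/nvcc.py later in this diff):

import tvm
from tvm.contrib import nvcc

# nvcc.py now registers a default "tvm_callback_cuda_compile" that emits fatbin,
# so a script that wants PTX must explicitly override the registered hook.
@tvm.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code):
    # compile the generated CUDA C source to PTX with nvcc
    ptx = nvcc.compile_cuda(code, target_format="ptx")
    return ptx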
4 changes: 2 additions & 2 deletions apps/topi_recipe/conv/depthwise_conv2d_test.py
@@ -32,9 +32,9 @@
 USE_MANUAL_CODE = False


-@tvm.register_func
+@tvm.register_func("tvm_callback_cuda_compile", override=True)
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target_format="ptx")
     return ptx

4 changes: 2 additions & 2 deletions apps/topi_recipe/conv/test_conv2d_hwcn_map.py
@@ -28,9 +28,9 @@
 USE_MANUAL_CODE = False


-@tvm.register_func
+@tvm.register_func("tvm_callback_cuda_compile", override=True)
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target_format="ptx")
     return ptx

4 changes: 2 additions & 2 deletions apps/topi_recipe/gemm/cuda_gemm_square.py
@@ -27,9 +27,9 @@
 USE_MANUAL_CODE = False


-@tvm.register_func
+@tvm.register_func("tvm_callback_cuda_compile", override=True)
 def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target_format="ptx")
     return ptx

4 changes: 2 additions & 2 deletions apps/topi_recipe/rnn/lstm.py
@@ -30,10 +30,10 @@
 UNROLL_WLOAD = True


-@tvm.register_func
+@tvm.register_func("tvm_callback_cuda_compile", override=True)
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
-    ptx = nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target_format="ptx")
     return ptx

4 changes: 2 additions & 2 deletions apps/topi_recipe/rnn/matexp.py
@@ -39,10 +39,10 @@
 SKIP_CHECK = False


-@tvm.register_func
+@tvm.register_func("tvm_callback_cuda_compile", override=True)
 def tvm_callback_cuda_compile(code):
     """Use nvcc compiler for better perf."""
-    ptx = nvcc.compile_cuda(code, target="ptx")
+    ptx = nvcc.compile_cuda(code, target_format="ptx")
     return ptx

8 changes: 7 additions & 1 deletion jvm/core/src/test/scripts/test_add_gpu.py
@@ -18,7 +18,13 @@

 import tvm
 from tvm import te
-from tvm.contrib import cc, utils
+from tvm.contrib import cc, utils, nvcc
+
+
+@tvm.register_func("tvm_callback_cuda_compile", override=True)
+def tvm_callback_cuda_compile(code):
+    ptx = nvcc.compile_cuda(code, target_format="ptx")
+    return ptx


 def test_add(target_dir):
5 changes: 0 additions & 5 deletions python/tvm/auto_scheduler/measure.py
@@ -42,7 +42,6 @@
 from tvm.runtime import Object, module, ndarray
 from tvm.driver import build_module
 from tvm.ir import transform
-from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
 from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope
 from tvm.contrib import tar, ndk
 from tvm.contrib.popen_pool import PopenWorker, PopenPoolExecutor, StatusKind
@@ -550,10 +549,6 @@ def __init__(
        from tvm.rpc.tracker import Tracker
        from tvm.rpc.server import Server

-        dev = tvm.device("cuda", 0)
-        if dev.exist:
-            cuda_arch = "sm_" + "".join(dev.compute_version.split("."))
-            set_cuda_target_arch(cuda_arch)
        self.tracker = Tracker(port=9000, port_end=10000, silent=True)
        device_key = "$local$device$%d" % self.tracker.port
        self.server = Server(
1 change: 0 additions & 1 deletion python/tvm/autotvm/env.py
@@ -24,7 +24,6 @@ def __init__(self):
        self._old = AutotvmGlobalScope.current
        AutotvmGlobalScope.current = self

-        self.cuda_target_arch = None
        self.in_tuning = False
        self.silent = False

34 changes: 10 additions & 24 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -39,7 +39,7 @@
 from tvm import nd
 from tvm import rpc as _rpc
 from tvm.autotvm.env import AutotvmGlobalScope, reset_global_scope
-from tvm.contrib import ndk, nvcc, stackvm, tar
+from tvm.contrib import ndk, stackvm, tar
 from tvm.contrib.popen_pool import PopenPoolExecutor
 from tvm.driver import build
 from tvm.error import TVMError
@@ -322,9 +322,6 @@ def get_build_kwargs(self):
                "max_thread_z": max_dims[2],
            }

-            if "cuda" in self.task.target.keys:
-                kwargs["cuda_arch"] = "sm_" + "".join(dev.compute_version.split("."))
-
        return kwargs

    def run(self, measure_inputs, build_results):
@@ -463,9 +460,7 @@ def set_task(self, task):
    return server, tracker


-def _build_func_common(
-    measure_input, runtime=None, check_gpu=None, cuda_arch=None, build_option=None
-):
+def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None):
    """Common part for building a configuration"""
    target, task, config = measure_input
    target, task.target_host = Target.check_and_update_host_consist(target, task.target_host)
@@ -480,8 +475,6 @@ def _build_func_common(
    opts = build_option or {}
    if check_gpu:  # Add verify pass to filter out invalid configs in advance.
        opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
-    if cuda_arch:
-        set_cuda_target_arch(cuda_arch)

    # if target is vta, we need to use vta build
    if (
@@ -789,21 +782,10 @@ def _check():
    return not t.is_alive()


-@tvm._ffi.register_func
-def tvm_callback_cuda_compile(code):
-    """use nvcc to generate ptx code for better optimization"""
-    curr_cuda_target_arch = AutotvmGlobalScope.current.cuda_target_arch
-    # e.g., target arch could be [
-    #     "-gencode", "arch=compute_52,code=sm_52",
-    #     "-gencode", "arch=compute_70,code=sm_70"
-    # ]
-    target = "fatbin" if isinstance(curr_cuda_target_arch, list) else "ptx"
-    ptx = nvcc.compile_cuda(code, target=target, arch=AutotvmGlobalScope.current.cuda_target_arch)
-    return ptx
-
-
 def set_cuda_target_arch(arch):
-    """set target architecture of nvcc compiler
+    """THIS API IS DEPRECATED.
+
+    set target architecture of nvcc compiler

    Parameters
    ----------
@@ -812,7 +794,11 @@ def set_cuda_target_arch(arch):
        it can also be a count of gencode arguments pass to nvcc command line,
        e.g., ["-gencode", "arch=compute_52,code=sm_52", "-gencode", "arch=compute_70,code=sm_70"]
    """
-    AutotvmGlobalScope.current.cuda_target_arch = arch
+    raise ValueError(
+        "The API 'autotvm.measure.set_cuda_target_arch' is deprecated."
+        "Try specifying it by adding '-arch=sm_xx' to your target, such as 'cuda -arch=sm_86'."
+        "See https://github.com/apache/tvm/pull/9544 for the upgrade guide."
+    )


 def gpu_verify_pass(**kwargs):
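A hedged sketch of the new failure mode and its replacement (sm_70 is arbitrary; the import path is the one the diff keeps, and the tuning_target variable below is only illustrative):

from tvm.autotvm.measure.measure_methods import set_cuda_target_arch

try:
    set_cuda_target_arch("sm_70")  # deprecated by this commit
except ValueError as err:
    print(err)  # points users at 'cuda -arch=sm_xx' and apache#9544

# Replacement: encode the arch in the target used for tuning and measuring.
tuning_target = "cuda -arch=sm_70"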
76 changes: 40 additions & 36 deletions python/tvm/contrib/nvcc.py
@@ -23,28 +23,28 @@
 import warnings

 import tvm._ffi
-from tvm.runtime import ndarray as nd
+from tvm.target import Target

 from . import utils
 from .._ffi.base import py_str


-def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None):
+def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
    """Compile cuda code with NVCC from env.

    Parameters
    ----------
    code : str
        The cuda code.
-    target : str
-        The target format
+    target_format : str
+        The target format of nvcc compiler.
    arch : str
-        The architecture
+        The cuda architecture.
    options : str or list of str
-        The additional options
+        The additional options.
    path_target : str, optional
        Output file.
@@ -54,28 +54,33 @@ def compile_cuda(code, target="ptx", arch=None, options=None, path_target=None):
    cubin : bytearray
        The bytearray of the cubin
    """
+    if arch is None:
+        # If None, then it will use `tvm.target.Target.current().arch`.
+        # Target arch could be a str like "sm_xx", or a list, such as
+        # [
+        #     "-gencode", "arch=compute_52,code=sm_52",
+        #     "-gencode", "arch=compute_70,code=sm_70"
+        # ]
+        compute_version = "".join(
+            get_target_compute_version(Target.current(allow_none=True)).split(".")
+        )
+        arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
+
    temp = utils.tempdir()
-    if target not in ["cubin", "ptx", "fatbin"]:
-        raise ValueError("target must be in cubin, ptx, fatbin")
+    if target_format not in ["cubin", "ptx", "fatbin"]:
+        raise ValueError("target_format must be in cubin, ptx, fatbin")
    temp_code = temp.relpath("my_kernel.cu")
-    temp_target = temp.relpath("my_kernel.%s" % target)
+    temp_target = temp.relpath("my_kernel.%s" % target_format)

    with open(temp_code, "w") as out_file:
        out_file.write(code)

-    if arch is None:
-        if nd.cuda(0).exist:
-            # auto detect the compute arch argument
-            arch = "sm_" + "".join(nd.cuda(0).compute_version.split("."))
-        else:
-            raise ValueError("arch(sm_xy) is not passed, and we cannot detect it from env")
-
    file_target = path_target if path_target else temp_target
    cmd = ["nvcc"]
-    cmd += ["--%s" % target, "-O3"]
+    cmd += ["--%s" % target_format, "-O3"]
    if isinstance(arch, list):
        cmd += arch
-    else:
+    elif isinstance(arch, str):
        cmd += ["-arch", arch]

    if options:
@@ -172,6 +177,13 @@ def get_cuda_version(cuda_path):
    raise RuntimeError("Cannot read cuda version file")


+@tvm._ffi.register_func
+def tvm_callback_cuda_compile(code):
+    """use nvcc to generate fatbin code for better optimization"""
+    ptx = compile_cuda(code, target_format="fatbin")
+    return ptx
+
+
 @tvm._ffi.register_func("tvm_callback_libdevice_path")
 def find_libdevice_path(arch):
    """Utility function to find libdevice
@@ -221,8 +233,8 @@ def callback_libdevice_path(arch):
 def get_target_compute_version(target=None):
    """Utility function to get compute capability of compilation target.

-    Looks for the arch in three different places, first in the target attributes, then the global
-    scope, and finally the GPU device (if it exists).
+    Looks for the target arch in three different places, first in the target input, then the
+    Target.current() scope, and finally the GPU device (if it exists).

    Parameters
    ----------
@@ -232,31 +244,23 @@ def get_target_compute_version(target=None):
    Returns
    -------
    compute_version : str
-        compute capability of a GPU (e.g. "8.0")
+        compute capability of a GPU (e.g. "8.6")
    """
-    # 1. Target
-    if target:
-        if "arch" in target.attrs:
-            compute_version = target.attrs["arch"]
-            major, minor = compute_version.split("_")[1]
-            return major + "." + minor
-
-    # 2. Global scope
-    from tvm.autotvm.env import AutotvmGlobalScope  # pylint: disable=import-outside-toplevel
-
-    if AutotvmGlobalScope.current.cuda_target_arch:
-        major, minor = AutotvmGlobalScope.current.cuda_target_arch.split("_")[1]
+    # 1. input target object
+    # 2. Target.current()
+    target = target or Target.current()
+    if target and target.arch:
+        major, minor = target.arch.split("_")[1]
        return major + "." + minor

-    # 3. GPU
+    # 3. GPU compute version
    if tvm.cuda(0).exist:
        return tvm.cuda(0).compute_version

-    warnings.warn(
+    raise ValueError(
        "No CUDA architecture was specified or GPU detected."
        "Try specifying it by adding '-arch=sm_xx' to your target."
    )
-    return None


 def parse_compute_version(compute_version):
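A short usage sketch of the renamed compile_cuda parameter and the new arch resolution (the kernel string is illustrative, nvcc must be on PATH, and sm_70/sm_86 are placeholders):

from tvm.contrib import nvcc
from tvm.target import Target

kernel = r'''
extern "C" __global__ void add_one(float* x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += 1.0f;
}
'''

# Explicit arch: compile straight to PTX for a given sm.
ptx = nvcc.compile_cuda(kernel, target_format="ptx", arch="sm_70")

# Implicit arch: with no `arch`, compile_cuda now consults Target.current(),
# then the local GPU, and raises ValueError if neither yields one.
with Target("cuda -arch=sm_86"):
    fatbin = nvcc.compile_cuda(kernel, target_format="fatbin")

# get_target_compute_version reads the arch off a target as well.
print(nvcc.get_target_compute_version(Target("cuda -arch=sm_86")))  # "8.6"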
4 changes: 0 additions & 4 deletions python/tvm/meta_schedule/builder/local_builder.py
@@ -206,14 +206,10 @@ def default_build(mod: IRModule, target: Target) -> Module:
        The built Module.
    """
    # pylint: disable=import-outside-toplevel
-    from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
    from tvm.driver import build as tvm_build

    # pylint: enable=import-outside-toplevel

-    if target.kind.name == "cuda":
-        set_cuda_target_arch(target.attrs["arch"])
-
    return tvm_build(mod, target=target)

15 changes: 14 additions & 1 deletion python/tvm/target/target.py
@@ -139,12 +139,19 @@ def current(allow_none=True):
        """
        return _ffi_api.TargetCurrent(allow_none)

+    @property
+    def arch(self):
+        """Returns the cuda arch from the target if it exists."""
+        return str(self.attrs.get("arch", ""))
+
    @property
    def max_num_threads(self):
        """Returns the max_num_threads from the target if it exists."""
        return int(self.attrs["max_num_threads"])

    @property
    def thread_warp_size(self):
        """Returns the thread_warp_size from the target if it exists."""
        return int(self.attrs["thread_warp_size"])

    @property
@@ -228,17 +235,23 @@ def _merge_opts(opts, new_opts):
    return opts


-def cuda(model="unknown", options=None):
+def cuda(model="unknown", arch=None, options=None):
    """Returns a cuda target.

    Parameters
    ----------
    model: str
        The model of cuda device (e.g. 1080ti)
+    arch: str
+        The cuda architecture (e.g. sm_61)
    options : str or list of str
        Additional options
    """
    opts = _merge_opts(["-model=%s" % model], options)
+    if arch:
+        opts = _merge_opts(["-arch=%s" % arch], opts)
+    if not any(["-arch" in opt for opt in opts]):
+        warnings.warn("Try specifying cuda arch by adding 'arch=sm_xx' to your target.")
    return Target(" ".join(["cuda"] + opts))

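Usage sketch for the updated helper and the new arch property (the model and arch values are illustrative):

import tvm

t = tvm.target.cuda(model="unknown", arch="sm_61")
print(t.arch)           # "sm_61", via the new Target.arch property
print(t.attrs["arch"])  # same value, stored as a target attribute

# Omitting arch still works but now warns, nudging users toward '-arch=sm_xx'.
t_legacy = tvm.target.cuda(model="1080ti")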