Skip to content

Commit

Permalink
Build NNCF extensions in separate folders per torch version (#1120)
Browse files Browse the repository at this point in the history
* Build NNCF extensions in separate folders per torch version

* Fix tests

* Extend .pylintrc with ignores
  • Loading branch information
vshampor authored Mar 14, 2022
1 parent 37a830a commit 3402adb
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 36 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ disable = arguments-differ,
wrong-import-order,
attribute-defined-outside-init,
import-outside-toplevel,
duplicate-code
duplicate-code,
consider-using-f-string

max-line-length = 120
ignore-docstrings = yes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.utils.cpp_extension import load

from nncf.torch.extensions import CudaNotAvailableStub
from nncf.torch.extensions import get_build_directory_for_extension

ext_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)))
if torch.cuda.is_available():
Expand All @@ -14,7 +15,8 @@
os.path.join(ext_dir, 'nms/nms.cpp'),
os.path.join(ext_dir, 'nms/nms_kernel.cu'),
],
verbose=False
verbose=False,
build_directory=get_build_directory_for_extension('extensions')
)
else:
EXTENSIONS = CudaNotAvailableStub
34 changes: 24 additions & 10 deletions nncf/torch/binarization/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,39 @@

@EXTENSIONS.register()
class BinarizedFunctionsCPULoader(ExtensionLoader):
@staticmethod
def extension_type():
@classmethod
def name(cls) -> str:
return 'binarized_functions_cpu'

@classmethod
def extension_type(cls):
return ExtensionsType.CPU

@staticmethod
def load():
return load('binarized_functions_cpu', CPU_EXT_SRC_LIST, extra_include_paths=EXT_INCLUDE_DIRS,
@classmethod
def load(cls):
return load(cls.name(),
CPU_EXT_SRC_LIST,
extra_include_paths=EXT_INCLUDE_DIRS,
build_directory=cls.get_build_dir(),
verbose=False)


@EXTENSIONS.register()
class BinarizedFunctionsCUDALoader(ExtensionLoader):
@staticmethod
def extension_type():
@classmethod
def name(cls) -> str:
return 'binarized_functions_cuda'

@classmethod
def extension_type(cls):
return ExtensionsType.CUDA

@staticmethod
def load():
return load('binarized_functions_cuda', CUDA_EXT_SRC_LIST, extra_include_paths=EXT_INCLUDE_DIRS,
@classmethod
def load(cls):
return load(cls.name(),
CUDA_EXT_SRC_LIST,
extra_include_paths=EXT_INCLUDE_DIRS,
build_directory=cls.get_build_dir(),
verbose=False)


Expand Down
34 changes: 28 additions & 6 deletions nncf/torch/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import enum
from pathlib import Path

import torch

from abc import ABC, abstractmethod

from torch.utils.cpp_extension import _get_build_directory

from nncf.common.utils.registry import Registry
from nncf.common.utils.logger import logger as nncf_logger

EXTENSIONS = Registry('extensions')

Expand All @@ -11,18 +17,34 @@ class ExtensionsType(enum.Enum):
CPU = 0
CUDA = 1

def get_build_directory_for_extension(name: str) -> Path:
build_dir = Path(_get_build_directory('nncf/' + name, verbose=False)) / torch.__version__
if not build_dir.exists():
nncf_logger.debug("Creating build directory: {}".format(str(build_dir)))
build_dir.mkdir(parents=True, exist_ok=True)
return build_dir


class ExtensionLoader(ABC):
@staticmethod
@classmethod
@abstractmethod
def extension_type():
def extension_type(cls):
pass

@staticmethod
@classmethod
@abstractmethod
def load():
def load(cls):
pass

@classmethod
@abstractmethod
def name(cls) -> str:
pass

@classmethod
def get_build_dir(cls) -> str:
return str(get_build_directory_for_extension(cls.name()))


def _force_build_extensions(ext_type: ExtensionsType):
for class_type in EXTENSIONS.registry_dict.values():
Expand All @@ -41,5 +63,5 @@ def force_build_cuda_extensions():

class CudaNotAvailableStub:
def __getattr__(self, item):
raise RuntimeError("CUDA is not available on this machine. Check that the machine has a GPU and a proper"
"driver supporting CUDA {} is installed".format(torch.version.cuda))
raise RuntimeError(f"CUDA is not available on this machine. Check that the machine has a GPU and a proper"
f"driver supporting CUDA {torch.version.cuda} is installed")
34 changes: 24 additions & 10 deletions nncf/torch/quantization/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,41 @@

@EXTENSIONS.register()
class QuantizedFunctionsCPULoader(ExtensionLoader):
@staticmethod
def extension_type():
@classmethod
def extension_type(cls):
return ExtensionsType.CPU

@staticmethod
def load():
return load('quantized_functions_cpu', CPU_EXT_SRC_LIST, extra_include_paths=EXT_INCLUDE_DIRS,
@classmethod
def name(cls) -> str:
return 'quantized_functions_cpu'

@classmethod
def load(cls):
return load(cls.name(),
CPU_EXT_SRC_LIST,
extra_include_paths=EXT_INCLUDE_DIRS,
build_directory=cls.get_build_dir(),
verbose=False)


@EXTENSIONS.register()
class QuantizedFunctionsCUDALoader(ExtensionLoader):
@staticmethod
def extension_type():
@classmethod
def extension_type(cls):
return ExtensionsType.CUDA

@staticmethod
def load():
return load('quantized_functions_cuda', CUDA_EXT_SRC_LIST, extra_include_paths=EXT_INCLUDE_DIRS,
@classmethod
def load(cls):
return load(cls.name(),
CUDA_EXT_SRC_LIST,
extra_include_paths=EXT_INCLUDE_DIRS,
build_directory=cls.get_build_dir(),
verbose=False)

@classmethod
def name(cls) -> str:
return 'quantized_functions_cuda'


QuantizedFunctionsCPU = QuantizedFunctionsCPULoader.load()

Expand Down
21 changes: 13 additions & 8 deletions tests/torch/test_extensions_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,29 @@ def test_force_cuda_build(tmp_venv_with_nncf, install_type, tmp_path, package_ty
path=run_path)
command.run()

cpu_ext_dir = (torch_ext_dir / 'quantized_functions_cpu')
version_command = Command('{} -c "import torch; print(torch.__version__)"'.format(python_executable_with_venv),
path=run_path)
version_command.run()
torch_version = version_command.output[0].replace('\n', '')

cpu_ext_dir = (torch_ext_dir / 'nncf' / 'quantized_functions_cpu' / torch_version)
assert cpu_ext_dir.exists()
cpu_ext_so = (cpu_ext_dir / 'quantized_functions_cpu.so')
cpu_ext_so = (cpu_ext_dir / 'quantized_functions_cpu.so' )
assert cpu_ext_so.exists()

cuda_ext_dir = (torch_ext_dir / 'quantized_functions_cuda')
cuda_ext_dir = (torch_ext_dir / 'nncf'/ 'quantized_functions_cuda' / torch_version)
assert not cuda_ext_dir.exists()
cuda_ext_so = (cuda_ext_dir / 'quantized_functions_cuda.so')
assert not cuda_ext_so.exists()

cpu_ext_dir = (torch_ext_dir / 'binarized_functions_cpu')
cpu_ext_dir = (torch_ext_dir / 'nncf' / 'binarized_functions_cpu' / torch_version)
assert cpu_ext_dir.exists()
cpu_ext_so = (cpu_ext_dir / 'binarized_functions_cpu.so')
assert cpu_ext_so.exists()

cuda_ext_dir = (torch_ext_dir / 'binarized_functions_cuda')
cuda_ext_dir = (torch_ext_dir / 'nncf' / 'binarized_functions_cuda' / torch_version)
assert not cuda_ext_dir.exists()
cuda_ext_so = (cuda_ext_dir / 'binarized_functions_cuda.so')
cuda_ext_so = (cuda_ext_dir / 'nncf' / torch_version / 'binarized_functions_cuda.so')
assert not cuda_ext_so.exists()

mode = 'cuda'
Expand All @@ -78,12 +83,12 @@ def test_force_cuda_build(tmp_venv_with_nncf, install_type, tmp_path, package_ty
path=run_path)
command.run()

cuda_ext_dir = (torch_ext_dir / 'quantized_functions_cuda')
cuda_ext_dir = (torch_ext_dir / 'nncf' / 'quantized_functions_cuda' / torch_version)
assert cuda_ext_dir.exists()
cuda_ext_so = (cuda_ext_dir / 'quantized_functions_cuda.so')
assert cuda_ext_so.exists()

cuda_ext_dir = (torch_ext_dir / 'binarized_functions_cuda')
cuda_ext_dir = (torch_ext_dir / 'nncf' / 'binarized_functions_cuda' / torch_version)
assert cuda_ext_dir.exists()
cuda_ext_so = (cuda_ext_dir / 'binarized_functions_cuda.so')
assert cuda_ext_so.exists()

0 comments on commit 3402adb

Please sign in to comment.