Skip to content

Commit

Permalink
Fix capability detection on fake-hpu (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
madamczykhabana authored Nov 4, 2024
1 parent 42127d5 commit 565204e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
21 changes: 17 additions & 4 deletions vllm_hpu_extension/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,32 @@
from vllm_hpu_extension.environment import get_environment


class VersionRange:
class Check:
def __init__(self, *required_params):
self.required_params = required_params

def __call__(self, **kwargs):
if any(kwargs[rp] is None for rp in self.required_params):
return False
return self.check(**kwargs)


class VersionRange(Check):
def __init__(self, *specifiers):
super().__init__('build')
self.specifiers = [SpecifierSet(s) for s in specifiers]

def __call__(self, build, **_):
def check(self, build, **_):
version = Version(build)
return any(version in s for s in self.specifiers)


class Hardware:
class Hardware(Check):
def __init__(self, target_hw):
super().__init__('hw')
self.target_hw = target_hw

def __call__(self, hw, **_):
def check(self, hw, **_):
return hw == self.target_hw


Expand All @@ -52,6 +64,7 @@ def capabilities():
"gaudi": Hardware("gaudi"),
"gaudi2": Hardware("gaudi2"),
"gaudi3": Hardware("gaudi3"),
"cpu": Hardware("cpu"),
}
environment = get_environment()
capabilities = Capabilities(supported_features, environment)
Expand Down
14 changes: 12 additions & 2 deletions vllm_hpu_extension/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

from vllm.logger import init_logger
from vllm_hpu_extension.utils import is_fake_hpu

logger = init_logger(__name__)


def get_hw():
import habana_frameworks.torch.utils.experimental as htexp
device_type = htexp._get_device_type()
Expand All @@ -15,7 +21,10 @@ def get_hw():
return "gaudi2"
case htexp.synDeviceType.synDeviceGaudi3:
return "gaudi3"
raise RuntimeError(f'Unknown device type: {device_type}')
if is_fake_hpu():
return "cpu"
logger.warning(f'Unknown device type: {device_type}')
return None


def get_build():
Expand All @@ -30,7 +39,8 @@ def get_build():
match = version_re.search(output.stdout)
if output.returncode == 0 and match:
return match.group('version')
raise RuntimeError("Unable to detect habana-torch-plugin version!")
logger.warning("Unable to detect habana-torch-plugin version!")
return None


def get_environment(**overrides):
Expand Down

0 comments on commit 565204e

Please sign in to comment.