Skip to content

Commit

Permalink
Merge pull request #6940 from stuartarchibald/fix/hardware_extension_…
Browse files Browse the repository at this point in the history
…api_intrin

Fix function resolution for intrinsics across hardware.
  • Loading branch information
sklam authored Apr 22, 2021
2 parents 47ff1be + 4e8786c commit a694947
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 4 deletions.
5 changes: 5 additions & 0 deletions numba/core/extending_hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def __call__(self):
class Target(ABC):
""" Implements a hardware/pseudo-hardware target """

@classmethod
def inherits_from(cls, other):
"""Returns True if this target inherits from 'other' False otherwise"""
return issubclass(cls, other)


class Generic(Target):
"""Mark the hardware target as generic, i.e. suitable for compilation on
Expand Down
2 changes: 1 addition & 1 deletion numba/core/types/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def get_call_type(self, context, args, kws):
hw = temp_cls.metadata.get('hardware', DEFAULT_HARDWARE)
if hw is not None:
hw_clazz = hardware_registry[hw]
if hw_clazz in target_hw.__mro__:
if target_hw.inherits_from(hw_clazz):
usable.append((temp_cls, hw_clazz, ix))

# sort templates based on hardware specificity
Expand Down
18 changes: 15 additions & 3 deletions numba/core/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,12 +885,24 @@ def generic(self, args, kws):
"""
Type the intrinsic by the arguments.
"""
from numba.core.extending_hardware import resolve_dispatcher_from_str
from numba.core.extending_hardware import (get_local_target,
resolve_target_str,
dispatcher_registry)
from numba.core.imputils import builtin_registry

cache_key = self.context, args, tuple(kws.items())
hwstr = self.metadata.get('hardware', 'cpu')
disp = resolve_dispatcher_from_str(hwstr)
hwstr = self.metadata.get('hardware', 'generic')
# Get the class for the target declared by the function
hw_clazz = resolve_target_str(hwstr)
# get the local target
target_hw = get_local_target(self.context)
# make sure the target_hw is in the MRO for hw_clazz else bail
if not target_hw.inherits_from(hw_clazz):
msg = (f"Intrinsic being resolved on a target from which it does "
f"not inherit. Local target is {target_hw}, declared "
f"target class is {hw_clazz}.")
raise InternalError(msg)
disp = dispatcher_registry[target_hw]
tgtctx = disp.targetdescr.target_context
# This is all workarounds...
# The issue is that whilst targets shouldn't care about which registry
Expand Down
90 changes: 90 additions & 0 deletions numba/tests/test_hardware_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,96 @@ def foo():
msg = "No target is registered against 'invalid_silicon'"
self.assertIn(msg, str(raises.exception))

def test_intrinsic_selection(self):
"""
Test to make sure that targets can share generic implementations and
cannot reach implementations that are not in their hardware hierarchy.
"""

# NOTE: The actual operation performed by these functions is irrelevant
@intrinsic(hardware="generic")
def intrin_math_generic(tyctx, x, y):
sig = x(x, y)

def codegen(cgctx, builder, tyargs, llargs):
return builder.mul(*llargs)

return sig, codegen

@intrinsic(hardware="dpu")
def intrin_math_dpu(tyctx, x, y):
sig = x(x, y)

def codegen(cgctx, builder, tyargs, llargs):
return builder.sub(*llargs)

return sig, codegen

@intrinsic(hardware="cpu")
def intrin_math_cpu(tyctx, x, y):
sig = x(x, y)

def codegen(cgctx, builder, tyargs, llargs):
return builder.add(*llargs)

return sig, codegen

# CPU can use the CPU version
@njit
def cpu_foo_specific():
return intrin_math_cpu(3, 4)

self.assertEqual(cpu_foo_specific(), 7)

# CPU can use the 'generic' version
@njit
def cpu_foo_generic():
return intrin_math_generic(3, 4)

self.assertEqual(cpu_foo_generic(), 12)

# CPU cannot use the 'dpu' version
@njit
def cpu_foo_dpu():
return intrin_math_dpu(3, 4)

with self.assertRaises(errors.TypingError) as raises:
cpu_foo_dpu()

msgs = ["Function resolution cannot find any matches for function",
"intrinsic intrin_math_dpu",
"for the current hardware",]
for msg in msgs:
self.assertIn(msg, str(raises.exception))

# DPU can use the DPU version
@djit(nopython=True)
def dpu_foo_specific():
return intrin_math_dpu(3, 4)

self.assertEqual(dpu_foo_specific(), -1)

# DPU can use the 'generic' version
@djit(nopython=True)
def dpu_foo_generic():
return intrin_math_generic(3, 4)

self.assertEqual(dpu_foo_generic(), 12)

# DPU cannot use the 'cpu' version
@djit(nopython=True)
def dpu_foo_cpu():
return intrin_math_cpu(3, 4)

with self.assertRaises(errors.TypingError) as raises:
dpu_foo_cpu()

msgs = ["Function resolution cannot find any matches for function",
"intrinsic intrin_math_cpu",
"for the current hardware",]
for msg in msgs:
self.assertIn(msg, str(raises.exception))


class TestHardwareOffload(TestCase):
"""In this use case the CPU compilation pipeline is extended with a new
Expand Down

0 comments on commit a694947

Please sign in to comment.