Skip to content

Commit

Permalink
Make system_api.load_vm_flatbuffer_file mmap the file. (iree-org#14333)
Browse files Browse the repository at this point in the history
Also adds a unit test to load_vm_flatbuffer.

Fixes iree-org#14321.
  • Loading branch information
stellaraccident authored Jul 11, 2023
1 parent 0dce100 commit b7c942a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 22 deletions.
55 changes: 39 additions & 16 deletions runtime/bindings/python/iree/runtime/system_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,9 @@ def load_vm_module(vm_module, config: Optional[Config] = None):
return load_vm_modules(vm_module, config=config)[0]


def load_vm_flatbuffer(
vm_flatbuffer: bytes, *, driver: Optional[str] = None, backend: Optional[str] = None
) -> BoundModule:
"""Loads a VM Flatbuffer into a callable module.
Either 'driver' or 'backend' must be specified.
"""
def _create_config(
*, driver: Optional[str] = None, backend: Optional[str] = None
) -> Config:
if driver is None and backend is None:
raise ValueError(
"Either 'driver' or 'backend' must be specified, but got "
Expand All @@ -302,20 +298,47 @@ def load_vm_flatbuffer(
if backend is not None:
driver = TARGET_BACKEND_TO_DRIVER[backend]
config = Config(driver)
vm_module = _binding.VmModule.from_flatbuffer(config.vm_instance, vm_flatbuffer)
bound_module = load_vm_module(vm_module, config)
return bound_module
return config


def load_vm_flatbuffer(
vm_flatbuffer: bytes, *, driver: Optional[str] = None, backend: Optional[str] = None
) -> BoundModule:
"""Loads a VM Flatbuffer into a callable module.
Either 'driver' or 'backend' must be specified.
Note that this API makes a defensive copy to ensure proper alignment and is
therefore not suitable for large flatbuffers. See load_vm_flatbuffer_file()
or mmap APIs on VmModule.
"""
config = _create_config(driver=driver, backend=backend)
vm_module = _binding.VmModule.copy_buffer(config.vm_instance, vm_flatbuffer)
return load_vm_module(vm_module, config)


# TODO: There should be an API for mmap'ing the file which should be used
# instead of reading into memory.
def load_vm_flatbuffer_file(
path: str, *, driver: Optional[str] = None, backend: Optional[str] = None
path: str,
*,
driver: Optional[str] = None,
backend: Optional[str] = None,
destroy_callback=None,
) -> BoundModule:
"""Loads a file containing a VM Flatbuffer into a callable module.
Either 'driver' or 'backend' must be specified.
Note that this delegates to the lower level VmModule.mmap() API, which,
as the name implies, memory maps the file. This can be fiddly across
platforms and for maximum compatibility, ensure that the file is not
otherwise open for write or deleted while in use.
If provided, 'destroy_callback' will be passed to VmModule.mmap and will
be invoked when no further references to the mapping exist. This can be
used to clean up test state, etc (in a Windows compatible way).
"""
with open(path, "rb") as f:
vm_flatbuffer = f.read()
return load_vm_flatbuffer(vm_flatbuffer, driver=driver, backend=backend)
config = _create_config(driver=driver, backend=backend)
vm_module = _binding.VmModule.mmap(
config.vm_instance, str(path), destroy_callback=destroy_callback
)
return load_vm_module(vm_module, config)
45 changes: 39 additions & 6 deletions runtime/bindings/python/tests/system_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pylint: disable=unused-variable

import gc
import logging
import os
import re
Expand All @@ -17,19 +18,28 @@
import numpy as np


def create_simple_mul_module(instance):
binary = iree.compiler.compile_str(
"""
_SIMPLE_MUL_BINARY = None


def compile_simple_mul_binary():
global _SIMPLE_MUL_BINARY
if not _SIMPLE_MUL_BINARY:
_SIMPLE_MUL_BINARY = iree.compiler.compile_str(
"""
module @arithmetic {
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}
}
""",
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
target_backends=iree.compiler.core.DEFAULT_TESTING_BACKENDS,
)
return _SIMPLE_MUL_BINARY


def create_simple_mul_module(instance):
m = iree.runtime.VmModule.from_flatbuffer(instance, compile_simple_mul_binary())
return m


Expand Down Expand Up @@ -147,6 +157,29 @@ def test_load_multiple_modules(self):
m1 = iree.runtime.load_vm_module(m)
m2 = iree.runtime.load_vm_module(m)

def test_load_vm_flatbuffer(self):
# This API is old and not highly recommended but testing as-is.
m = iree.runtime.load_vm_flatbuffer(
compile_simple_mul_binary(), driver="local-sync"
)
m = iree.runtime.load_vm_flatbuffer(
compile_simple_mul_binary(), backend="llvm-cpu"
)

def test_load_vm_flatbuffer_file(self):
with tempfile.NamedTemporaryFile(delete=False) as f:
f.write(compile_simple_mul_binary())

def _cleanup():
os.unlink(f.name)

m = iree.runtime.load_vm_flatbuffer_file(
f.name, driver="local-sync", destroy_callback=_cleanup
)
del m
gc.collect()
self.assertFalse(os.path.exists(f.name))


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
Expand Down

0 comments on commit b7c942a

Please sign in to comment.