Skip to content

Commit

Permalink
[MERGE] Merge main into unity 2023-07-03
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jul 3, 2023
2 parents 918fc4e + a60b815 commit 40300a3
Show file tree
Hide file tree
Showing 43 changed files with 1,194 additions and 150 deletions.
21 changes: 0 additions & 21 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -179,24 +179,3 @@ jobs:
with:
name: android_deploy-debug.apk
path: ./apps/android_deploy/app/build/outputs/apk/debug/app-debug.apk
- name: Build android_camera
working-directory: apps/android_camera
run: |
export TVM_HOME=~/work/tvm/tvm
export PYTHONPATH=$TVM_HOME/python:${PYTHONPATH}
set -eux
mkdir -p app/src/main/assets/models/
export TVM_NDK_CC=${ANDROID_NDK_LATEST_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android30-clang++
python3 ${TVM_HOME}/python/gen_requirements.py
pip3 install -r ${TVM_HOME}/python/requirements/core.txt
cd models
pip3 install -r requirements.txt
python3 prepare_model.py
cd ..
export PATH="${ANDROID_NDK_LATEST_HOME}:$PATH"
gradle clean build
- name: Upload android_camera APK
uses: actions/upload-artifact@v2
with:
name: android_camera-debug.apk
path: ./apps/android_camera/app/build/outputs/apk/debug/app-debug.apk
1 change: 1 addition & 0 deletions apps/android_rpc/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#endif

#ifdef TVM_VULKAN_RUNTIME
#include "../src/runtime/vulkan/vulkan_amdrgp.cc"
#include "../src/runtime/vulkan/vulkan_buffer.cc"
#include "../src/runtime/vulkan/vulkan_common.cc"
#include "../src/runtime/vulkan/vulkan_device.cc"
Expand Down
32 changes: 26 additions & 6 deletions cmake/utils/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ macro(find_llvm use_llvm)
if(NOT "${__llvm_exit_code}" STREQUAL "0")
message(FATAL_ERROR "Fatal error executing: ${LLVM_CONFIG} --libdir")
endif()
message(STATUS "LLVM libdir: ${__llvm_libdir}")
# map prefix => $
# to handle the case when the prefix contains space.
string(REPLACE ${__llvm_prefix} "$" __llvm_cxxflags ${__llvm_cxxflags_space})
Expand Down Expand Up @@ -144,14 +145,33 @@ macro(find_llvm use_llvm)
endforeach()
separate_arguments(__llvm_system_libs)
foreach(__flag IN ITEMS ${__llvm_system_libs})
# If the library file ends in .lib try to
# also search the llvm_libdir
if(__flag MATCHES ".lib$")
if(EXISTS "${__llvm_libdir}/${__flag}")
set(__flag "${__llvm_libdir}/${__flag}")
if("${__flag}" STREQUAL "-lm")
message(STATUS "LLVM links against math")
list(APPEND LLVM_LIBS "m")
elseif(("${__flag}" STREQUAL "-lz") OR ("${__flag}" STREQUAL "z.lib"))
message(STATUS "LLVM links against zlib")
find_package(ZLIB REQUIRED)
list(APPEND LLVM_LIBS "ZLIB::ZLIB")
elseif("${__flag}" STREQUAL "-lzstd" OR ("${__flag}" STREQUAL "zstd.dll.lib"))
find_package(zstd REQUIRED)
if (TARGET "zstd::libzstd_static")
message(STATUS "LLVM links against static zstd")
list(APPEND LLVM_LIBS "zstd::libzstd_static")
else()
message(STATUS "LLVM links against shared zstd")
list(APPEND LLVM_LIBS "zstd::libzstd_shared")
endif()
elseif("${__flag}" STREQUAL "-lxml2")
message(STATUS "LLVM links against xml2")
list(APPEND LLVM_LIBS "-lxml2")
elseif((__flag MATCHES ".lib$") AND (EXISTS "${__llvm_libdir}/${__flag}"))
# If the library file ends in .lib try to also search the llvm_libdir
message(STATUS "LLVM linker flag under LLVM libdir: ${__llvm_libdir}/${__flag}")
list(APPEND LLVM_LIBS "${__llvm_libdir}/${__flag}")
else()
message(STATUS "LLVM linker flag: ${__flag}")
list(APPEND LLVM_LIBS "${__flag}")
endif()
list(APPEND LLVM_LIBS "${__flag}")
endforeach()
endif()
message(STATUS "Found LLVM_INCLUDE_DIRS=" "${LLVM_INCLUDE_DIRS}")
Expand Down
2 changes: 1 addition & 1 deletion conda/recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

{% set version = '0.13.dev0' %}
{% set version = '0.14.dev0' %}
{% set pkg_name = 'tvm' %}
{% set cuda_tag = cuda_version | replace('.', '') %} # [cuda]
{% set pkg_name = pkg_name + '-cu' + cuda_tag %} # [cuda]
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,22 @@ class TVM_DLL Analyzer {
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief constructor */
Analyzer();
/*!
* \brief Mark the value as non-negative value globally in analyzer.
*
* Only call this function if the non-neg condition is global and
* not context-dependent.
*
* This function does best-effort propagations to the sub-analyzers
*
* \note We expose this function because non-negative global values,
* such as symbolic buffer shapes in function arguments are really
* important to ensure the best simplification, and usually they
* can be handled in a simpler way than the generic constraints.
*
* This function may call into the Update function of the sub-analyzers.
*/
void MarkGlobalNonNegValue(const PrimExpr& value);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
#endif

// TVM version
#define TVM_VERSION "0.13.dev0"
#define TVM_VERSION "0.14.dev0"

// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,12 @@ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
*/
TVM_DLL Pass ConvertBlocksToOpaque();

/*!
* \brief Lift the same thread bindings to their LCA loops
* \return The pass.
*/
TVM_DLL Pass LiftThreadBinding();

/*!
* \brief Compact the buffer access region by removing the buffer regions that are not accessed,
* i.e. narrowing the buffer shape and adjust the access region if necessary.
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ class Var : public PrimExpr {
* \param span The location of this object in the source code.
*/
TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
/*!
* \brief Make a new copy of var with same type, but a different nam
* \param name The new name to be used.
* \return the new Var copy
*/
TVM_DLL Var copy_with_name(const String& name) const;
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
Expand Down
6 changes: 2 additions & 4 deletions python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@
import collections
import os
import re
import textwrap
import sys
import textwrap
import typing


RequirementsByPieceType = typing.List[typing.Tuple[str, typing.Tuple[str, typing.List[str]]]]


Expand Down Expand Up @@ -253,8 +252,7 @@
("h5py", "==2.10.0"),
("image", None),
("matplotlib", None),
# Workaround, see https://github.com/apache/tvm/issues/13647
("numpy", "<=1.23"),
("numpy", None),
("onnx", None),
("onnxoptimizer", None),
("onnxruntime", None),
Expand Down
39 changes: 30 additions & 9 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
# pylint: disable=invalid-name, exec-used
"""Setup TVM package."""
import os
import pathlib
import shutil
import sys
import sysconfig
import pathlib
import platform

from setuptools import find_packages
from setuptools.dist import Distribution
Expand Down Expand Up @@ -77,6 +76,26 @@ def get_lib_path():
libs.append(candidate_path)
break

for dir in [
"3rdparty",
"jvm",
"web",
"rust",
"golang",
"include",
"src",
"cmake",
"CMakeLists.txt",
]:
for name in lib_path:
candidate_path = os.path.abspath(os.path.join(os.path.dirname(name), "..", dir))
if os.path.exists(candidate_path):
libs.append(candidate_path)
if dir == "3rdparty":
# remove large files
_remove_path(os.path.join(candidate_path, "cutlass", "docs"))
_remove_path(os.path.join(candidate_path, "cutlass", "media"))
break
else:
libs = None

Expand All @@ -94,6 +113,14 @@ def git_describe_version(original_version):
return gd_version


def _remove_path(path):
if os.path.exists(path):
if os.path.isfile(path):
os.remove(path)
elif os.path.isdir(path):
shutil.rmtree(path)


LIB_LIST, __version__ = get_lib_path()
__version__ = git_describe_version(__version__)

Expand Down Expand Up @@ -245,10 +272,4 @@ def long_description_contents():
os.remove("MANIFEST.in")
for path in LIB_LIST:
_, libname = os.path.split(path)
path_to_be_removed = f"tvm/{libname}"

if os.path.isfile(path_to_be_removed):
os.remove(path_to_be_removed)

if os.path.isdir(path_to_be_removed):
shutil.rmtree(path_to_be_removed)
_remove_path(f"tvm/{libname}")
2 changes: 1 addition & 1 deletion python/tvm/_ffi/libinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,4 @@ def find_include_path(name=None, search_path=None, optional=False):
# We use the version of the incoming release for code
# that is under development.
# The following line is set by tvm/python/update_version.py
__version__ = "0.13.dev0"
__version__ = "0.14.dev0"
5 changes: 5 additions & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class RPCError(TVMError):
"""Error thrown by the remote server handling the RPC call."""


@register_error
class RPCSessionTimeoutError(RPCError, TimeoutError):
"""Error thrown by the remote server when the RPC session has expired."""


@register_error
class OpError(TVMError):
"""Base class of all operator errors in frontends."""
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,10 +955,13 @@ def _convert_concat(
if input_shape is None:
input_shape = keras_layer.input_shape

if data_layout == "NHWC" or len(input_shape[0]) < 4:
axis = -1
else:
axis = 1
axis = keras_layer.axis
dims = len(input_shape[0])
if data_layout == "NCHW": # need_transpose
if axis == -1:
axis = 1
else:
axis = axis + 1 if axis < dims else 1
return _op.concatenate(_as_list(inexpr), axis=axis)


Expand Down
88 changes: 45 additions & 43 deletions python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
import os
import ctypes
import socket
import select
Expand Down Expand Up @@ -119,16 +120,6 @@ def download_linked_module(file_name):
return temp


def _serve_loop(sock, addr, load_library, work_path=None):
"""Server loop"""
sockfd = sock.fileno()
temp = _server_env(load_library, work_path)
_ffi_api.ServerLoop(sockfd)
if not work_path:
temp.remove()
logger.info("Finish serving %s", addr)


def _parse_server_opt(opts):
# parse client options
ret = {}
Expand All @@ -138,6 +129,47 @@ def _parse_server_opt(opts):
return ret


def _serving(sock, addr, opts, load_library):
logger.info(f"connected from {addr}")
work_path = utils.tempdir()
old_cwd = os.getcwd()
os.chdir(work_path.path) # Avoiding file name conflict between sessions.
logger.info(f"start serving at {work_path.path}")

def _serve_loop():
_server_env(load_library, work_path)
_ffi_api.ServerLoop(sock.fileno())

server_proc = multiprocessing.Process(target=_serve_loop)
server_proc.start()
server_proc.join(opts.get("timeout", None)) # Wait until finish or timeout.

if server_proc.is_alive():
logger.info("timeout in RPC session, kill..")
_ffi_api.ReturnException(
sock.fileno(),
f'RPCSessionTimeoutError: Your {opts["timeout"]}s session has expired, '
f'try to increase the "session_timeout" value.',
)

try:
import psutil # pylint: disable=import-outside-toplevel

# Terminate worker children firstly.
for child in psutil.Process(server_proc.pid).children(recursive=True):
child.terminate()
except ImportError:
# Don't dependent `psutil` hardly, because it isn't a pure Python
# package and maybe hard to be installed on some platforms.
pass
server_proc.terminate()

logger.info(f"finish serving {addr}")
os.chdir(old_cwd)
work_path.remove()
sock.close()


def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
"""Listening loop of the server."""

Expand Down Expand Up @@ -238,30 +270,7 @@ def _accept_conn(listen_sock, tracker_conn, ping_period=2):
raise exc

# step 3: serving
work_path = utils.tempdir()
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(
target=_serve_loop, args=(conn, addr, load_library, work_path)
)

server_proc.start()
# close from our side.
conn.close()
# wait until server process finish or timeout
server_proc.join(opts.get("timeout", None))

if server_proc.is_alive():
logger.info("Timeout in RPC session, kill..")
# pylint: disable=import-outside-toplevel
import psutil

parent = psutil.Process(server_proc.pid)
# terminate worker children
for child in parent.children(recursive=True):
child.terminate()
# terminate the worker
server_proc.terminate()
work_path.remove()
_serving(conn, addr, opts, load_library)


def _connect_proxy_loop(addr, key, load_library):
Expand All @@ -286,15 +295,8 @@ def _connect_proxy_loop(addr, key, load_library):
raise RuntimeError(f"{str(addr)} is not RPC Proxy")
keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
remote_key = py_str(base.recvall(sock, keylen))
opts = _parse_server_opt(remote_key.split()[1:])
logger.info("connected to %s", str(addr))
process = multiprocessing.Process(target=_serve_loop, args=(sock, addr, load_library))
process.start()
sock.close()
process.join(opts.get("timeout", None))
if process.is_alive():
logger.info("Timeout in RPC session, kill..")
process.terminate()

_serving(sock, addr, _parse_server_opt(remote_key.split()[1:]), load_library)
retry_count = 0
except (socket.error, IOError) as err:
retry_count += 1
Expand Down
Loading

0 comments on commit 40300a3

Please sign in to comment.