
Commit

Merge pull request iree-org#7040 from KoolJBlack:main-to-google
PiperOrigin-RevId: 396661234
iree-copybara-bot committed Sep 14, 2021
2 parents afbc06e + 288ff98 commit 426612f
Showing 108 changed files with 2,404 additions and 726 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -193,6 +193,7 @@ if(${IREE_ENABLE_CCACHE})
   endif()
 endif()

+option(IREE_DEV_MODE "Configure settings to optimize for IREE development (as opposed to CI or release)" OFF)

 #-------------------------------------------------------------------------------
 # IREE assertions
145 changes: 145 additions & 0 deletions benchmarks/dashboard.md
@@ -0,0 +1,145 @@
# IREE Performance Dashboard

This documentation explains IREE's performance dashboard (https://perf.iree.dev).
A [Buildkite pipeline](https://buildkite.com/iree/iree-benchmark) runs on each
commit to the `main` branch and posts those results to the dashboard.

## Benchmarking philosophy

Benchmarking and interpreting results properly is a delicate task. We can record
metrics from various parts of a system, but depending on what we are trying to
evaluate, those numbers may or may not be relevant. For example, for somebody
working solely on better kernel code generation, the end-to-end model inference
latency is unlikely to be meaningful given that it also includes runtime
overhead. The environment can also vary per benchmark run in uncontrollable
ways, causing instability in the results. This is especially true for mobile and
embedded systems, where a tight compromise between performance and
thermal/battery limits is made. Many factors can affect the benchmarking
results, so before going into details, it's worth noting the general guidelines
for IREE benchmarking as context.

The overarching goal for benchmarking here is to track IREE's performance
progress and guard against regression. So the benchmarks are meant to understand
the performance of IREE _itself_, not the absolute capability of the exercised
hardware. In order to fulfill the above goal, we have the following guidelines
for benchmarking:

* We choose representative real-world models with varying characteristics.
* We cover different IREE backends and different modes for each backend so that
folks working on different components can find the metrics they need.

## Model benchmark specification

Each benchmark in IREE has a unique identifier with the following format:

```
<model-name> `[` <model-tag>.. `]` `(` <model-source> `)` <benchmark-mode>..
`with` <iree-driver>
`@` <device-name> `(` <target-architecture> `)`
```
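
As a rough illustration of the format above, the sketch below splits an identifier into its fields. The regular expression, the `parse_benchmark_identifier` helper, the comma separator inside the tag and mode lists, and the example identifier are all assumptions made for illustration; none of them are taken from IREE's tooling.

```python
import re
from typing import Any, Dict, Optional

# Hypothetical pattern for the identifier format. The comma separator inside
# the tag list and the mode list is an assumption.
_IDENTIFIER_RE = re.compile(
    r"^(?P<model_name>\S+) "
    r"\[(?P<model_tags>[^\]]*)\] "
    r"\((?P<model_source>[^)]*)\) "
    r"(?P<benchmark_modes>\S+) "
    r"with (?P<iree_driver>\S+) "
    r"@ (?P<device_name>\S+) "
    r"\((?P<target_architecture>[^)]*)\)$")


def parse_benchmark_identifier(identifier: str) -> Optional[Dict[str, Any]]:
    """Returns the identifier's fields, or None if it does not match."""
    match = _IDENTIFIER_RE.match(identifier)
    if match is None:
        return None
    fields: Dict[str, Any] = dict(match.groupdict())
    # Tags and modes may repeat, so expose them as lists.
    fields["model_tags"] = fields["model_tags"].split(",")
    fields["benchmark_modes"] = fields["benchmark_modes"].split(",")
    return fields


# A made-up identifier assembled from field values listed in this document.
example = ("DeepLabV3 [f32] (TFLite) kernel-execution "
           "with Vulkan @ SM-G980F (GPU-Mali-G77)")
parsed = parse_benchmark_identifier(example)
print(parsed["model_name"], parsed["benchmark_modes"],
      parsed["target_architecture"])
```
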

The following subsections explain possible choices in each field.

### Model source

This field specifies the original model source:

```
├── TensorFlow
│ * models authored in TensorFlow and imported with `iree-import-tf`
└── TFLite
* models converted to TensorFlow Lite and imported with `iree-import-tflite`
```

### Model name

This field specifies the input model:

* `DeepLabV3` [[source](https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/default/1)]:
Vision model for semantic image segmentation.
Characteristics: convolution, feedforward NN.
* `MobileBERT` [[source](https://github.com/google-research/google-research/tree/master/mobilebert)]:
NLP for Q&A.
Characteristics: matmul, attention, feedforward NN.
* `MobileNetV2` [[source](https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV2)]:
Vision model for image classification.
Characteristics: convolution, feedforward NN.
* `MobileNetV3Small` [[source](https://www.tensorflow.org/api_docs/python/tf/keras/applications/MobileNetV3Small)]:
Vision model for image classification.
Characteristics: convolution, feedforward NN.
* `MobileSSD` [[source](https://www.tensorflow.org/lite/performance/gpu#demo_app_tutorials)]:
Vision model for object detection.
Characteristics: convolution, feedforward NN.
* `PoseNet` [[source](https://tfhub.dev/tensorflow/lite-model/posenet/mobilenet/float/075/1/default/1)]:
Vision model for pose estimation.
Characteristics: convolution, feedforward NN.

### Model tag

This field specifies the model variant. It depends on the model, but here are
some examples:

* `f32`: the model operates on 32-bit floating-point types.
* `imagenet`: the model takes ImageNet-sized inputs (224x224x3).

### IREE driver

This field specifies the IREE HAL driver:

* [`Dylib`](https://google.github.io/iree/deployment-configurations/cpu-dylib/):
For CPU via dynamic library. Kernels contain CPU native instructions AOT
compiled using LLVM. This driver issues workloads to the CPU asynchronously
and supports multithreading.
* [`Dylib-Sync`](https://google.github.io/iree/deployment-configurations/cpu-dylib/):
For CPU via dynamic library. Kernels contain CPU native instructions AOT
compiled using LLVM. This driver issues workloads to the CPU synchronously.
* [`VMVX`](https://github.com/google/iree/issues/5123):
For CPU via dynamic library. Kernels contain vector-level intrinsics that
are backed by fast implementations ([WIP](https://github.com/google/iree/issues/5819)).
This driver issues workloads to the CPU asynchronously and supports
multithreading.
* [`Vulkan`](https://google.github.io/iree/deployment-configurations/gpu-vulkan/):
For GPU via Vulkan. Kernels contain SPIR-V. This driver issues workloads to
the GPU via the Vulkan API.

### Device name and target architecture

These two fields are tightly coupled. They specify the device and hardware
target for executing the benchmark.

Right now there are two Android devices:

* `Pixel-4`: Google Pixel 4 running Android 11. The SoC is
[Snapdragon 855](https://www.qualcomm.com/products/snapdragon-855-plus-and-860-mobile-platform),
with 1+3+4 ARMv8.2 CPU cores and Adreno 640 GPU.
* `SM-G980F`: Samsung Galaxy S20 running Android 11. The SoC is
[Exynos 990](https://www.samsung.com/semiconductor/minisite/exynos/products/mobileprocessor/exynos-990/),
with 2+2+4 ARMv8.2 CPU cores and Mali G77 MP11 GPU.

Therefore the target architectures are:

* `CPU-CPU-ARMv8.2-A`: can benchmark all CPU-based IREE backends and drivers.
* `GPU-Adreno-640`: can benchmark IREE Vulkan with Adreno target triples.
* `GPU-Mali-G77`: can benchmark IREE Vulkan with Mali target triples.

### Benchmark mode

This field further specifies the benchmark variant, given the same input
model and target architecture. It controls important aspects like:

* `*-core`: specifies the core flavor for CPU.
* `*-thread`: specifies the number of threads for CPU.
* `full-inference`: measures the latency for one full inference. Note that this
does not include the IREE system initialization time.
* `kernel-execution`: measures only kernel execution latency for GPU. Note that
this is only possible for feedforward NN models that can be put into one
command buffer.

`*-core` and `*-thread` together determine the `taskset` mask used for
benchmarking IREE backends and drivers on CPU. For example:

* `1-thread,big-core` would mean `taskset 80`.
* `1-thread,little-core` would mean `taskset 08`.
* `3-thread,big-core` would mean `taskset f0`.
* `3-thread,little-core` would mean `taskset 0f`.
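
To make the masks above easier to read, here is a small sketch that decodes a hexadecimal `taskset` mask into CPU indices, using the standard convention that the lowest bit corresponds to CPU 0. The `cores_in_taskset_mask` helper is a hypothetical name used only for illustration.

```python
from typing import List


def cores_in_taskset_mask(mask_hex: str) -> List[int]:
    """Returns the CPU indices selected by a hexadecimal taskset mask."""
    mask = int(mask_hex, 16)
    # Bit i of the mask selects CPU i.
    return [i for i in range(mask.bit_length()) if mask & (1 << i)]


# The four benchmark modes and masks listed above.
for mode, mask_hex in [
    ("1-thread,big-core", "80"),
    ("1-thread,little-core", "08"),
    ("3-thread,big-core", "f0"),
    ("3-thread,little-core", "0f"),
]:
    print(f"{mode}: taskset {mask_hex} -> CPUs {cores_in_taskset_mask(mask_hex)}")
```

Decoded this way, the `big-core` masks select CPUs in the 4-7 range and the `little-core` masks select CPUs in the 0-3 range.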
11 changes: 11 additions & 0 deletions bindings/python/iree/runtime/function.py
@@ -381,11 +381,22 @@ def convert(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
   return convert


+def _vm_to_pylist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  # The descriptor for a pylist is like:
+  #   ['pylist', element_type]
+  sub_vm_list = vm_list.get_as_list(vm_index)
+  element_type_desc = desc[1:]
+  py_items = _extract_vm_sequence_to_python(
+      inv, sub_vm_list, element_type_desc * len(sub_vm_list))
+  return py_items
+
+
 VM_TO_PYTHON_CONVERTERS = {
     "ndarray": _vm_to_ndarray,
     "sdict": _vm_to_sdict,
     "slist": _vm_to_slist,
     "stuple": _vm_to_stuple,
+    "py_homogeneous_list": _vm_to_pylist,

     # Scalars.
     "i8": _vm_to_scalar(int),
20 changes: 20 additions & 0 deletions bindings/python/iree/runtime/function_test.py
@@ -218,6 +218,26 @@ def invoke(arg_list, ret_list):
     # assertEqual on bool arrays is fraught for... reasons.
     self.assertEqual("array([ True, False])", repr(result))

+  def testReturnTypeList(self):
+    vm_list = VmVariantList(2)
+    vm_list.push_int(1)
+    vm_list.push_int(2)
+
+    def invoke(arg_list, ret_list):
+      ret_list.push_list(vm_list)
+
+    vm_context = MockVmContext(invoke)
+    vm_function = MockVmFunction(reflection={
+        "iree.abi":
+            json.dumps({
+                "a": [],
+                "r": [["py_homogeneous_list", "i64"]],
+            })
+    })
+    invoker = FunctionInvoker(vm_context, self.device, vm_function, tracer=None)
+    result = invoker()
+    self.assertEqual("[1, 2]", repr(result))
+

 if __name__ == "__main__":
   absltest.main()
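
The `py_homogeneous_list` converter added in `function.py` relies on repeating the single element descriptor once per list item (`desc[1:] * len(sub_vm_list)`). The sketch below isolates that step in plain Python. `SCALAR_CONVERTERS`, `expand_homogeneous_desc`, and `convert_items` are hypothetical stand-ins for illustration; they are not part of the IREE runtime.

```python
from typing import Any, Callable, Dict, List

# Hypothetical stand-in for the scalar entries of a converter table.
SCALAR_CONVERTERS: Dict[str, Callable[[Any], Any]] = {
    "i64": int,
    "f32": float,
}


def expand_homogeneous_desc(desc: List[Any], length: int) -> List[Any]:
    """Repeats the element descriptor once per list element."""
    # desc looks like ["py_homogeneous_list", element_type]; desc[1:] is a
    # one-item list, so list repetition yields one descriptor per element.
    return desc[1:] * length


def convert_items(raw_items: List[Any], desc: List[Any]) -> List[Any]:
    """Converts each raw item using its per-element descriptor."""
    per_element = expand_homogeneous_desc(desc, len(raw_items))
    return [
        SCALAR_CONVERTERS[element_type](item)
        for item, element_type in zip(raw_items, per_element)
    ]


desc = ["py_homogeneous_list", "i64"]
print(expand_homogeneous_desc(desc, 2))
print(convert_items([1, 2], desc))
```
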
30 changes: 29 additions & 1 deletion bindings/python/iree/runtime/vm.cc
@@ -527,7 +527,35 @@ void AppendListContents(std::string& out, iree_vm_list_t* list,
     if (i > 0) out.append(", ");

     if (iree_vm_variant_is_value(variant)) {
-      out += std::to_string(variant.i32);
+      // Convert a value type to a string.
+      switch (variant.type.value_type) {
+        case IREE_VM_VALUE_TYPE_I8: {
+          out += std::to_string(variant.i8);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_I16: {
+          out += std::to_string(variant.i16);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_I32: {
+          out += std::to_string(variant.i32);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_I64: {
+          out += std::to_string(variant.i64);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_F32: {
+          out += std::to_string(variant.f32);
+          break;
+        }
+        case IREE_VM_VALUE_TYPE_F64: {
+          out += std::to_string(variant.f64);
+          break;
+        }
+        default:
+          throw RaiseValueError("Unsupported VM value type to string");
+      }
     } else if (iree_vm_variant_is_ref(variant)) {
       // Pretty print a subset of ABI impacting known types.
       if (iree_hal_buffer_isa(variant.ref)) {
6 changes: 6 additions & 0 deletions bindings/python/iree/runtime/vm_test.py
@@ -78,6 +78,12 @@ def test_variant_list(self):
     logging.info("variant_list: %s", l)
     self.assertEqual(l.size, 0)

+  def test_variant_list_i64(self):
+    l = iree.runtime.VmVariantList(5)
+    # Push a value that exceeds 32-bit range.
+    l.push_int(10 * 1000 * 1000 * 1000)
+    self.assertEqual(str(l), "<VmVariantList(1): [10000000000]>")
+
   def test_variant_list_buffers(self):
     ET = iree.runtime.HalElementType
     for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
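
The new `test_variant_list_i64` test guards against printing a 64-bit value through a 32-bit field, which is what the old `variant.i32` code path in `vm.cc` did. As a minimal sketch of why that matters, the snippet below reinterprets the low four bytes of a little-endian 64-bit integer as a signed 32-bit integer. It assumes the variant's fields share storage and that the machine is little-endian; `read_low_i32` is a hypothetical helper, not an IREE API.

```python
import struct


def read_low_i32(value: int) -> int:
    """Reinterprets the low 4 bytes of a little-endian int64 as an int32."""
    low_bytes = struct.pack("<q", value)[:4]
    return struct.unpack("<i", low_bytes)[0]


value = 10 * 1000 * 1000 * 1000  # The value pushed in the test above.
print(value)                # 10000000000
print(read_low_i32(value))  # 1410065408, i.e. value modulo 2**32
```

Under these assumptions, only values that fit in 32 bits survive the narrower read unchanged, which is why the switch on the value type is needed.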
File renamed without changes.
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2021 The IREE Authors
 #
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
@@ -19,7 +18,8 @@
 from typing import Any, Dict, Sequence

 __all__ = [
-    "AndroidDeviceInfo", "BenchmarkInfo", "BenchmarkResults", "get_output"
+    "AndroidDeviceInfo", "BenchmarkInfo", "BenchmarkResults",
+    "execute_cmd_and_get_output"
 ]

 # A map for IREE driver names. This allows us to normalize driver names like
@@ -33,7 +33,9 @@
 }


-def get_output(args: Sequence[str], verbose: bool = False, **kwargs) -> str:
+def execute_cmd_and_get_output(args: Sequence[str],
+                               verbose: bool = False,
+                               **kwargs) -> str:
   """Executes a command and returns its stdout."""
   if verbose:
     cmd = " ".join(args)
@@ -47,22 +49,22 @@ def get_output(args: Sequence[str], verbose: bool = False, **kwargs) -> str:

 def get_android_device_model(verbose: bool = False) -> str:
   """Returns the Android device model."""
-  model = get_output(["adb", "shell", "getprop", "ro.product.model"],
-                     verbose=verbose)
+  model = execute_cmd_and_get_output(
+      ["adb", "shell", "getprop", "ro.product.model"], verbose=verbose)
   model = re.sub(r"\W+", "-", model)
   return model


 def get_android_cpu_abi(verbose: bool = False) -> str:
   """Returns the CPU ABI for the Android device."""
-  return get_output(["adb", "shell", "getprop", "ro.product.cpu.abi"],
-                    verbose=verbose)
+  return execute_cmd_and_get_output(
+      ["adb", "shell", "getprop", "ro.product.cpu.abi"], verbose=verbose)


 def get_android_cpu_features(verbose: bool = False) -> Sequence[str]:
   """Returns the CPU features for the Android device."""
-  cpuinfo = get_output(["adb", "shell", "cat", "/proc/cpuinfo"],
-                       verbose=verbose)
+  cpuinfo = execute_cmd_and_get_output(["adb", "shell", "cat", "/proc/cpuinfo"],
+                                       verbose=verbose)
   features = []
   for line in cpuinfo.splitlines():
     if line.startswith("Features"):
@@ -73,7 +75,8 @@ def get_android_cpu_features(verbose: bool = False) -> Sequence[str]:

 def get_android_gpu_name(verbose: bool = False) -> str:
   """Returns the GPU name for the Android device."""
-  vkjson = get_output(["adb", "shell", "cmd", "gpu", "vkjson"], verbose=verbose)
+  vkjson = execute_cmd_and_get_output(["adb", "shell", "cmd", "gpu", "vkjson"],
+                                      verbose=verbose)
   vkjson = json.loads(vkjson)
   name = vkjson["devices"][0]["properties"]["deviceName"]

22 changes: 22 additions & 0 deletions build_tools/benchmarks/common/noisy_benchmarks.py
@@ -0,0 +1,22 @@
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""A list of noisy benchmarks and their average thresholds."""

import re

# A list of noisy benchmarks. Each one is a tuple that contains the following
# fields:
# - A regular expression to match against the benchmark identifier.
# - A threshold for computing the benchmark value average. Benchmark sample
#   values from consecutive runs and within the given range will be considered
#   as similar (with some noise). They will be used to compute the moving
#   average. The number will be interpreted as a percentage. What value to set
#   depends on the noise range of the particular benchmark.
NOISY_BENCHMARKS = [
    (re.compile(r"^DeepLabV3.*GPU-Mali-G77"), 100),
    (re.compile(r"^MobileSSD.*GPU-Mali-G77"), 100),
    (re.compile(r"^PoseNet.*GPU-Mali-G77"), 100),
]
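
The commit does not show how this table is consumed, so the following is only a sketch of one plausible lookup. `DEFAULT_THRESHOLD`, `get_threshold`, and `is_similar` are hypothetical names, the default of 5 percent is made up, and reading "within the given range" as a relative difference against the previous sample is an assumption.

```python
import re

# Copied from the table above so that the sketch is self-contained.
NOISY_BENCHMARKS = [
    (re.compile(r"^DeepLabV3.*GPU-Mali-G77"), 100),
    (re.compile(r"^MobileSSD.*GPU-Mali-G77"), 100),
    (re.compile(r"^PoseNet.*GPU-Mali-G77"), 100),
]

# Hypothetical default for benchmarks that are not listed as noisy.
DEFAULT_THRESHOLD = 5


def get_threshold(benchmark_id: str) -> int:
    """Returns the percentage threshold for the given benchmark identifier."""
    for pattern, threshold in NOISY_BENCHMARKS:
        if pattern.search(benchmark_id):
            return threshold
    return DEFAULT_THRESHOLD


def is_similar(previous: float, current: float, threshold: int) -> bool:
    """Assumed check: current is within threshold percent of previous."""
    return abs(current - previous) <= previous * threshold / 100


# Made-up identifiers following the format described in benchmarks/dashboard.md.
noisy_id = ("DeepLabV3 [f32] (TFLite) kernel-execution "
            "with Vulkan @ SM-G980F (GPU-Mali-G77)")
quiet_id = ("DeepLabV3 [f32] (TFLite) kernel-execution "
            "with Vulkan @ Pixel-4 (GPU-Adreno-640)")
print(get_threshold(noisy_id), get_threshold(quiet_id))
print(is_similar(10.0, 15.0, get_threshold(noisy_id)))
print(is_similar(10.0, 15.0, get_threshold(quiet_id)))
```
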