High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
- SHARK Discord server: Real time discussions with the SHARK team and other users
- GitHub issues: Feature requests, bugs, etc.
Installation (Linux and macOS)
This step sets up a new virtual environment for Python:
```shell
python --version  # check that you have Python 3.7-3.10 on Linux, or 3.10 on macOS
python -m venv shark_venv
source shark_venv/bin/activate
# If you are using conda, create and activate a new conda env instead.

# Some older pip versions may not be able to handle the recent PyTorch deps.
python -m pip install --upgrade pip
```
macOS Metal users: please install the Vulkan SDK from https://sdk.lunarg.com/sdk/download/latest/mac/vulkan-sdk.dmg and enable "System wide install".
This step pip-installs SHARK and related packages (Linux: Python 3.7, 3.8, 3.9, or 3.10; macOS: Python 3.10):
```shell
pip install nodai-shark -f https://github.com/nod-ai/SHARK/releases -f https://github.com/llvm/torch-mlir/releases -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
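To verify the install, a minimal sanity check; it only confirms that the package imports (the `SharkInference` class is documented in the API Reference below):

```python
# Quick sanity check that the SHARK wheels installed correctly.
from shark.shark_inference import SharkInference  # noqa: F401

print("SHARK import OK")
```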
If you are on an Intel macOS machine, you need this workaround for an upstream issue.
```shell
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/resnet50_script.py
# Install deps for the test script
pip install --pre torch torchvision torchaudio tqdm pillow --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./resnet50_script.py --device="cpu"  # or cuda, vulkan, metal
```
```shell
curl -O https://raw.githubusercontent.com/nod-ai/SHARK/main/shark/examples/shark_inference/minilm_jit.py
# Install deps for the test script
pip install transformers torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
python ./minilm_jit.py --device="cpu"  # or cuda, vulkan, metal
```
Source Installation
```shell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
# Set up the venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
./setup_venv.sh
source shark.venv/bin/activate
```
For example, if you want to use Python 3.10 and upstream IREE with the TF import tools, you can set environment variables like:
```shell
PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 USE_IREE=1 ./setup_venv.sh
```
If you are a Torch-MLIR or IREE developer and want to test local changes, you can uninstall the provided packages with `pip uninstall torch-mlir` and/or `pip uninstall iree-compiler iree-runtime`, build locally with Python bindings, and set your `PYTHONPATH` as mentioned here for IREE and here for Torch-MLIR.
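Once your local builds are on `PYTHONPATH`, a quick way to confirm that Python resolves them instead of the pip wheels is to print each package's location; a minimal sketch, assuming both sets of Python bindings are built:

```python
# Each path should point into your local build tree, not site-packages.
import torch_mlir
import iree.compiler
import iree.runtime

print(torch_mlir.__file__)
print(iree.compiler.__file__)
print(iree.runtime.__file__)
```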
Run an example to validate the setup:

```shell
python -m shark.examples.shark_inference.resnet50_script --device="cpu"  # or cuda, vulkan, metal
# Or run a single pytest:
pytest tank/tf/hf_masked_lm/albert-base-v2_test.py::AlbertBaseModuleTest::test_module_static_cpu
```
Testing
```shell
pytest tank

# On Linux, for quicker results:
pytest tank -n auto

# Run tests for a specific model:
pytest tank/<MODEL_NAME>  # e.g., pytest tank/bert-base-uncased

# Run tests for a specific case:
pytest tank/<MODEL_NAME>/<MODEL_TEST>.py::<MODEL>ModuleTest::<CASE>
# e.g., pytest tank/bert-base-uncased/bert-base-uncased_test.py::BertModuleTest::test_module_static_cpu
# For frontends other than pytorch, if available for a model, add the frontend
# to the filename: tank/bert-base-uncased/bert-base-uncased_tf_test.py

# Run all tests, including tests for benchmarking and SHARK modules,
# from the base SHARK directory:
pytest
pytest benchmarks
```
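The same selections can also be driven from Python; a minimal sketch using pytest's programmatic entry point (`-n auto` assumes pytest-xdist is installed):

```python
# Programmatic equivalent of `pytest tank -n auto`.
import pytest

exit_code = pytest.main(["tank", "-n", "auto"])
print("pytest exited with code", exit_code)
```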
API Reference
```python
from shark.shark_importer import SharkImporter

# SharkImporter imports an MLIR module from a torch, tensorflow, or tf-lite module.
mlir_importer = SharkImporter(
    torch_module,      # your model instance
    (input,),          # model inputs, passed as a tuple (note the trailing comma)
    frontend="torch",  # or "tf", "tf-lite"
)
torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)

# SharkInference accepts MLIR in the linalg, mhlo, and tosa dialects.
from shark.shark_inference import SharkInference

shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
shark_module.compile()
result = shark_module.forward((input,))
```
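Putting the two pieces together, a minimal end-to-end sketch; the `SmallMLP` model and shapes are illustrative stand-ins for `torch_module` and `input` above, not part of the SHARK API:

```python
import torch
from shark.shark_importer import SharkImporter
from shark.shark_inference import SharkInference

class SmallMLP(torch.nn.Module):
    """Toy model standing in for `torch_module`."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

example_input = torch.randn(1, 4)

mlir_importer = SharkImporter(SmallMLP(), (example_input,), frontend="torch")
mlir_module, func_name = mlir_importer.import_mlir(tracing_required=True)

shark_module = SharkInference(mlir_module, func_name, device="cpu", mlir_dialect="linalg")
shark_module.compile()
print(shark_module.forward((example_input,)))
```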
An example running MHLO IR directly:

```python
from shark.shark_inference import SharkInference
import numpy as np

mhlo_ir = r"""builtin.module {
  func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
    %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
    %1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    return %1 : tensor<4x4xf32>
  }
}"""

arg0 = np.ones((1, 4)).astype(np.float32)
arg1 = np.ones((4, 1)).astype(np.float32)
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
shark_module.compile()
result = shark_module.forward((arg0, arg1))
```
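For these inputs the result is easy to check by hand: the broadcast add pairs each element of `arg0` with each element of `arg1`, so every entry is 1 + 1 = 2, and `mhlo.abs` leaves positive values unchanged:

```python
import numpy as np

# All 16 entries should be abs(1 + 1) = 2.
assert np.allclose(result, np.full((4, 4), 2.0, dtype=np.float32))
```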
PyTorch Models
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✓ (JIT) | ✓ | ✓ | ✓ |
| Albert | ✓ (JIT) | ✓ | ✓ | ✓ |
| BigBird | ✓ (AOT) | | | |
| DistilBERT | ✓ (JIT) | ✓ | ✓ | ✓ |
| GPT2 | ✓ (AOT) | | | |
| MobileBert | ✓ (JIT) | ✓ | ✓ | ✓ |
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| AlexNet | ✓ (Script) | ✓ | ✓ | ✓ |
| DenseNet121 | ✓ (Script) | | | |
| MNasNet1_0 | ✓ (Script) | ✓ | ✓ | ✓ |
| MobileNetV2 | ✓ (Script) | ✓ | ✓ | ✓ |
| MobileNetV3 | ✓ (Script) | ✓ | ✓ | ✓ |
| Unet | ✓ (Script) | | | |
| Resnet18 | ✓ (Script) | ✓ | ✓ | ✓ |
| Resnet50 | ✓ (Script) | ✓ | ✓ | ✓ |
| Resnet101 | ✓ (Script) | ✓ | ✓ | ✓ |
| Resnext50_32x4d | ✓ (Script) | ✓ | ✓ | ✓ |
| ShuffleNet_v2 | ✓ (Script) | | | |
| SqueezeNet | ✓ (Script) | ✓ | ✓ | ✓ |
| EfficientNet | ✓ (Script) | | | |
| Regnet | ✓ (Script) | ✓ | ✓ | ✓ |
| Resnest | ✓ (Script) | | | |
| Vision Transformer | ✓ (Script) | | | |
| VGG 16 | ✓ (Script) | ✓ | ✓ | |
| Wide Resnet | ✓ (Script) | ✓ | ✓ | ✓ |
| RAFT | ✓ (JIT) | | | |
For more information, refer to the MODEL TRACKING SHEET.
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✓ | ✓ | | |
| FullyConnected | ✓ | ✓ | | |
JAX Models
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| DALL-E | ✓ | ✓ | | |
| FullyConnected | ✓ | ✓ | | |
TFLite Models
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✓ | ✓ | | |
| FullyConnected | ✓ | ✓ | | |
| albert | ✓ | ✓ | | |
| asr_conformer | ✓ | ✓ | | |
| bird_classifier | ✓ | ✓ | | |
| cartoon_gan | ✓ | ✓ | | |
| craft_text | ✓ | ✓ | | |
| deeplab_v3 | ✓ | ✓ | | |
| densenet | ✓ | ✓ | | |
| east_text_detector | ✓ | ✓ | | |
| efficientnet_lite0_int8 | ✓ | ✓ | | |
| efficientnet | ✓ | ✓ | | |
| gpt2 | ✓ | ✓ | | |
| image_stylization | ✓ | ✓ | | |
| inception_v4 | ✓ | ✓ | | |
| inception_v4_uint8 | ✓ | ✓ | | |
| lightning_fp16 | ✓ | ✓ | | |
| lightning_i8 | ✓ | ✓ | | |
| lightning | ✓ | ✓ | | |
| magenta | ✓ | ✓ | | |
| midas | ✓ | ✓ | | |
| mirnet | ✓ | ✓ | | |
| mnasnet | ✓ | ✓ | | |
| mobilebert_edgetpu_s_float | ✓ | ✓ | | |
| mobilebert_edgetpu_s_quant | ✓ | ✓ | | |
| mobilebert | ✓ | ✓ | | |
| mobilebert_tf2_float | ✓ | ✓ | | |
| mobilebert_tf2_quant | ✓ | ✓ | | |
| mobilenet_ssd_quant | ✓ | ✓ | | |
| mobilenet_v1 | ✓ | ✓ | | |
| mobilenet_v1_uint8 | ✓ | ✓ | | |
| mobilenet_v2_int8 | ✓ | ✓ | | |
| mobilenet_v2 | ✓ | ✓ | | |
| mobilenet_v2_uint8 | ✓ | ✓ | | |
| mobilenet_v3-large | ✓ | ✓ | | |
| mobilenet_v3-large_uint8 | ✓ | ✓ | | |
| mobilenet_v35-int8 | ✓ | ✓ | | |
| nasnet | ✓ | ✓ | | |
| person_detect | ✓ | ✓ | | |
| posenet | ✓ | ✓ | | |
| resnet_50_int8 | ✓ | ✓ | | |
| rosetta | ✓ | ✓ | | |
| spice | ✓ | ✓ | | |
| squeezenet | ✓ | ✓ | | |
| ssd_mobilenet_v1 | ✓ | ✓ | | |
| ssd_mobilenet_v1_uint8 | ✓ | ✓ | | |
| ssd_mobilenet_v2_fpnlite | ✓ | ✓ | | |
| ssd_mobilenet_v2_fpnlite_uint8 | ✓ | ✓ | | |
| ssd_mobilenet_v2_int8 | ✓ | ✓ | | |
| ssd_mobilenet_v2 | ✓ | ✓ | | |
| ssd_spaghettinet_large | ✓ | ✓ | | |
| ssd_spaghettinet_large_uint8 | ✓ | ✓ | | |
| visual_wake_words_i8 | ✓ | ✓ | | |
TF Models
| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|---|---|---|---|---|
| BERT | ✓ | ✓ | ✓ | ✓ |
| albert-base-v2 | ✓ | ✓ | ✓ | ✓ |
| DistilBERT | ✓ | ✓ | ✓ | ✓ |
| CamemBert | ✓ | ✓ | ✓ | ✓ |
| ConvBert | ✓ | ✓ | ✓ | ✓ |
| Deberta | | | | |
| electra | ✓ | ✓ | ✓ | ✓ |
| funnel | | | | |
| layoutlm | ✓ | ✓ | ✓ | ✓ |
| longformer | | | | |
| mobile-bert | ✓ | ✓ | ✓ | ✓ |
| remembert | | | | |
| tapas | | | | |
| flaubert | ✓ | ✓ | ✓ | ✓ |
| roberta | ✓ | ✓ | ✓ | ✓ |
| xlm-roberta | ✓ | ✓ | ✓ | ✓ |
| mpnet | ✓ | ✓ | ✓ | ✓ |
IREE Project Channels
- Upstream IREE issues: Feature requests, bugs, and other work tracking
- Upstream IREE Discord server: Daily development discussions with the core team and collaborators
- iree-discuss email list: Announcements, general and low-priority discussion
MLIR and Torch-MLIR Project Channels
- `#torch-mlir` channel on the LLVM Discord: the most active communication channel
- Torch-MLIR GitHub issues: here
- `torch-mlir` section of LLVM Discourse
- Weekly meetings on Mondays, 9AM PST; see here for more information
- MLIR topic within LLVM Discourse: SHARK and IREE are enabled by and heavily rely on MLIR
nod.ai SHARK is licensed under the terms of the Apache 2.0 License with LLVM Exceptions. See LICENSE for more information.