Skip to content

Commit

Permalink
ultralytics 8.0.37 add TFLite metadata in AutoBackend (ultralytics#953
Browse files Browse the repository at this point in the history
)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <[email protected]>
Co-authored-by: Yonghye Kwon <[email protected]>
Co-authored-by: Aarni Koskela <[email protected]>
  • Loading branch information
5 people authored Feb 14, 2023
1 parent 20fe708 commit bdc6cd4
Show file tree
Hide file tree
Showing 18 changed files with 86 additions and 46 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ target/
profile_default/
ipython_config.py

# Profiling
*.pclprof

# pyenv
.python-version

Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,7 @@ See [Classification Docs](https://docs.ultralytics.com/tasks/classification/) fo

## <div align="center">Ultralytics HUB</div>

[Ultralytics HUB](https://bit.ly/ultralytics_hub) is our ⭐ **NEW** no-code solution to visualize datasets, train YOLOv8
🚀 models, and deploy to the real world in a seamless experience. Get started for **Free** now! Also run YOLOv8 models on
your iOS or Android device by downloading the [Ultralytics App](https://ultralytics.com/app_install)!
Experience seamless AI with [Ultralytics HUB](https://bit.ly/ultralytics_hub) ⭐, the all-in-one solution for data visualization, YOLOv5 and YOLOv8 (coming soon) 🚀 model training and deployment, without any coding. Transform images into actionable insights and bring your AI visions to life with ease using our cutting-edge platform and user-friendly [Ultralytics App](https://ultralytics.com/app_install). Start your journey for **Free** now!

<a href="https://bit.ly/ultralytics_hub" target="_blank">
<img width="100%" src="https://github.com/ultralytics/assets/raw/main/im/ultralytics-hub.png"></a>
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ seaborn>=0.11.0
# openvino-dev>=2022.3 # OpenVINO export

# Extras --------------------------------------
ipython # interactive notebook
psutil # system utilization
thop>=0.1.1 # FLOPs computation
wheel>=0.38.0 # Snyk vulnerability fix
# ipython # interactive notebook
# albumentations>=1.0.3
# pycocotools>=2.0.6 # COCO mAP
# roboflow
4 changes: 2 additions & 2 deletions ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

__version__ = "8.0.36"
__version__ = "8.0.37"

from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

__all__ = ["__version__", "YOLO", "hub", "checks"] # allow simpler import
__all__ = ["__version__", "YOLO", "checks"] # allow simpler import
4 changes: 3 additions & 1 deletion ultralytics/hub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import requests

from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, ENVIRONMENT, LOGGER, RANK, SETTINGS, TryExcept, __version__,
colorstr, emojis, get_git_origin_url, is_git_dir, is_github_actions_ci,
colorstr, emojis, get_git_origin_url, is_colab, is_git_dir, is_github_actions_ci,
is_pip_package, is_pytest_running)
from ultralytics.yolo.utils.checks import check_online

Expand All @@ -36,6 +36,8 @@ def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', s

def request_with_credentials(url: str) -> any:
""" Make an ajax request with cookies attached """
if not is_colab():
raise OSError('request_with_credentials() must run in a Colab environment')
from google.colab import output # noqa
from IPython import display # noqa
display.display(
Expand Down
15 changes: 13 additions & 2 deletions ultralytics/nn/autobackend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

import ast
import contextlib
import json
import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
from urllib.parse import urlparse
Expand Down Expand Up @@ -207,14 +209,20 @@ def gd_outputs(gd):
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
stride, names = int(meta['stride']), meta['names']
elif tfjs: # TF.js
raise NotImplementedError('ERROR: YOLOv8 TF.js inference is not supported')
elif paddle: # PaddlePaddle
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
import paddle.inference as pdi
if not Path(w).is_file(): # if not *.pdmodel
w = next(Path(w).rglob('*.pdmodel')) # get *.xml file from *_openvino_model dir
w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
weights = Path(w).with_suffix('.pdiparams')
config = pdi.Config(str(w), str(weights))
if cuda:
Expand Down Expand Up @@ -328,6 +336,9 @@ def forward(self, im, augment=False, visualize=False):
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(self.output_details) == 2: # segment
y = [y[1], np.transpose(y[0], (0, 3, 1, 2))]
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels

Expand Down
50 changes: 28 additions & 22 deletions ultralytics/nn/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

import ast
import contextlib
from copy import deepcopy
from pathlib import Path
Expand Down Expand Up @@ -427,6 +428,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
# TODO: re-implement with eval() removal if possible
# args[j] = (locals()[a] if a in locals() else ast.literal_eval(a)) if isinstance(a, str) else a
with contextlib.suppress(NameError):
args[j] = eval(a) if isinstance(a, str) else a # eval strings

Expand Down Expand Up @@ -480,28 +483,9 @@ def guess_model_task(model):
Raises:
SyntaxError: If the task of the model could not be determined.
"""
cfg = None
if isinstance(model, dict):
cfg = model
elif isinstance(model, nn.Module): # PyTorch model
for x in 'model.args', 'model.model.args', 'model.model.model.args':
with contextlib.suppress(Exception):
return eval(x)['task']
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
with contextlib.suppress(Exception):
cfg = eval(x)
break
elif isinstance(model, (str, Path)):
model = str(model)
if '-seg' in model:
return "segment"
elif '-cls' in model:
return "classify"
else:
return "detect"

# Guess from YAML dictionary
if cfg:
def cfg2task(cfg):
# Guess from YAML dictionary
m = cfg["head"][-1][-2].lower() # output module name
if m in ["classify", "classifier", "cls", "fc"]:
return "classify"
Expand All @@ -510,8 +494,20 @@ def guess_model_task(model):
if m in ["segment"]:
return "segment"

# Guess from model cfg
if isinstance(model, dict):
with contextlib.suppress(Exception):
return cfg2task(model)

# Guess from PyTorch model
if isinstance(model, nn.Module):
if isinstance(model, nn.Module): # PyTorch model
for x in 'model.args', 'model.model.args', 'model.model.model.args':
with contextlib.suppress(Exception):
return eval(x)['task']
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
with contextlib.suppress(Exception):
return cfg2task(eval(x))

for m in model.modules():
if isinstance(m, Detect):
return "detect"
Expand All @@ -520,6 +516,16 @@ def guess_model_task(model):
elif isinstance(m, Classify):
return "classify"

# Guess from model filename
if isinstance(model, (str, Path)):
model = Path(model).stem
if '-seg' in model:
return "segment"
elif '-cls' in model:
return "classify"
else:
return "detect"

# Unable to determine task from model
raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
2 changes: 2 additions & 0 deletions ultralytics/yolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

from . import v8

__all__ = ["v8"]
2 changes: 1 addition & 1 deletion ultralytics/yolo/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
string = ''
for x in mismatched:
matches = get_close_matches(x, base) # key list
matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT[k] is not None else k for k in matches] # k=v
matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
match_str = f"Similar arguments are i.e. {matches}." if matches else ''
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
raise SyntaxError(string + CLI_HELP_MSG) from e
Expand Down
10 changes: 10 additions & 0 deletions ultralytics/yolo/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,13 @@
from .build import build_classification_dataloader, build_dataloader, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
from .dataset_wrappers import MixAndRectDataset

__all__ = [
"BaseDataset",
"ClassificationDataset",
"MixAndRectDataset",
"SemanticDataset",
"YOLODataset",
"build_classification_dataloader",
"build_dataloader",
"load_inference_source",]
4 changes: 2 additions & 2 deletions ultralytics/yolo/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode, get_latest_opset
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode

MACOS = platform.system() == 'Darwin' # macOS environment

Expand Down Expand Up @@ -508,7 +508,7 @@ def _export_saved_model(self,
onnx = self.file.with_suffix('.onnx')

# Export to TF SavedModel
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
subprocess.run(f'onnx2tf -i {onnx} -o {f} --non_verbose', shell=True)

# Add TFLite metadata
for tflite_file in Path(f).rglob('*.tflite'):
Expand Down
6 changes: 3 additions & 3 deletions ultralytics/yolo/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def _check_is_pytorch_model(self):
Raises TypeError is model is not a PyTorch model
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"model='{self.model}' must be a PyTorch model, but is a different type. PyTorch models "
f"can be used to train, val, predict and export, i.e. "
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f"PyTorch models can be used to train, val, predict and export, i.e. "
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")

Expand Down Expand Up @@ -240,7 +240,7 @@ def train(self, **kwargs):
if RANK in {0, -1}:
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics_data = self.trainer.validator.metrics
self.metrics_data = self.trainer.validator.metrics

def to(self, device):
"""
Expand Down
11 changes: 4 additions & 7 deletions ultralytics/yolo/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,10 @@ def is_jupyter():
Returns:
bool: True if running inside a Jupyter Notebook, False otherwise.
"""
try:
with contextlib.suppress(Exception):
from IPython import get_ipython
return get_ipython() is not None
except ImportError:
return False
return False


def is_docker() -> bool:
Expand Down Expand Up @@ -287,11 +286,9 @@ def is_pytest_running():
Returns:
(bool): True if pytest is running, False otherwise.
"""
try:
import sys
with contextlib.suppress(Exception):
return "pytest" in sys.modules
except ImportError:
return False
return False


def is_github_actions_ci() -> bool:
Expand Down
4 changes: 4 additions & 0 deletions ultralytics/yolo/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .base import add_integration_callbacks, default_callbacks

__all__ = [
'add_integration_callbacks',
'default_callbacks',]
5 changes: 3 additions & 2 deletions ultralytics/yolo/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import pkg_resources as pkg
import psutil
import torch
from IPython import display
from matplotlib import font_manager

from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
Expand Down Expand Up @@ -292,8 +291,10 @@ def check_yolo(verbose=True):
gib = 1 << 30 # bytes per GiB
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage("/")
display.clear_output()
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
with contextlib.suppress(Exception): # clear display if ipython is installed
from IPython import display
display.clear_output()
else:
s = ''

Expand Down
2 changes: 2 additions & 0 deletions ultralytics/yolo/v8/classify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val

__all__ = ["ClassificationPredictor", "predict", "ClassificationTrainer", "train", "ClassificationValidator", "val"]
2 changes: 2 additions & 0 deletions ultralytics/yolo/v8/detect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from .predict import DetectionPredictor, predict
from .train import DetectionTrainer, train
from .val import DetectionValidator, val

__all__ = ["DetectionPredictor", "predict", "DetectionTrainer", "train", "DetectionValidator", "val"]
2 changes: 2 additions & 0 deletions ultralytics/yolo/v8/segment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from .predict import SegmentationPredictor, predict
from .train import SegmentationTrainer, train
from .val import SegmentationValidator, val

__all__ = ["SegmentationPredictor", "predict", "SegmentationTrainer", "train", "SegmentationValidator", "val"]

0 comments on commit bdc6cd4

Please sign in to comment.