Skip to content

Commit

Permalink
New YOLOv8 Results() class for prediction outputs (ultralytics#314)
Browse files Browse the repository at this point in the history
Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: Laughing-q <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Laughing <[email protected]>
Co-authored-by: Viet Nhat Thai <[email protected]>
Co-authored-by: Paula Derrenger <[email protected]>
  • Loading branch information
8 people authored Jan 17, 2023
1 parent 0cb87f7 commit c6985da
Show file tree
Hide file tree
Showing 32 changed files with 816 additions and 262 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ jobs:
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
- name: Pytest tests
shell: bash # for Windows compatibility
run: pytest tests
run: pytest tests
2 changes: 1 addition & 1 deletion .github/workflows/cla.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
steps:
- name: "CLA Assistant"
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I sign the CLA') || github.event_name == 'pull_request_target'
uses: contributor-assistant/[email protected].0
uses: contributor-assistant/[email protected].1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# must be repository secret token
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ include requirements.txt
include LICENSE
include setup.py
recursive-include ultralytics *.yaml
recursive-exclude __pycache__ *
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ To request an Enterprise License please complete the form at [Ultralytics Licens

<div align="center">

[Ultralytics Live Session 3](https://youtu.be/IPcpYO5ITa8) ✨ is here! Join us on January 18th at 18 CET as we dive into the latest advancements in YOLOv8, and demonstrate how to use this cutting-edge, SOTA model to improve your object detection, instance segmentation, and image classification projects. See firsthand how YOLOv8's speed, accuracy, and ease of use make it a top choice for professionals and researchers alike.
[Ultralytics Live Session 3](https://youtu.be/IPcpYO5ITa8) ✨ is here! Join us on January 24th at 18 CET as we dive into the latest advancements in YOLOv8, and demonstrate how to use this cutting-edge, SOTA model to improve your object detection, instance segmentation, and image classification projects. See firsthand how YOLOv8's speed, accuracy, and ease of use make it a top choice for professionals and researchers alike.

In addition to learning about the exciting new features and improvements of Ultralytics YOLOv8, you will also have the opportunity to ask questions and interact with our team during the live Q&A session. We encourage all of you to come prepared with any questions you may have.
In addition to learning about the exciting new features and improvements of Ultralytics YOLOv8, you will also have the opportunity to ask questions and interact with our team during the live Q&A session. We encourage you to come prepared with any questions you may have.

Don't miss out on this opportunity! To join the webinar, visit our YouTube [Channel](https://www.youtube.com/@Ultralytics/streams) and turn on your notifications!
To join the webinar, visit our YouTube [Channel](https://www.youtube.com/@Ultralytics/streams) and turn on your notifications!

<a align="center" href="https://youtu.be/IPcpYO5ITa8" target="_blank">
<img width="80%" src="https://user-images.githubusercontent.com/26833433/212472119-7de539c1-5022-41cf-ae28-37b69158fbbe.png"></a>
<img width="80%" src="https://user-images.githubusercontent.com/107626595/212887899-e94b006c-5192-40fa-8b24-7b5428e065e8.png"></a>
</div>

## <div align="center">Documentation</div>
Expand All @@ -76,7 +76,7 @@ documentation on training, validation, prediction and deployment.

Pip install the ultralytics package including
all [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) in a
[**Python>=3.7.0**](https://www.python.org/) environment, including
[**3.10>=Python>=3.7**](https://www.python.org/) environment, including
[**PyTorch>=1.7**](https://pytorch.org/get-started/locally/).

```bash
Expand Down
2 changes: 1 addition & 1 deletion README.zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
<details open>
<summary>安装</summary>

Pip 安装包含所有 [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) 的 ultralytics 包,环境要求 [**Python>=3.7.0**](https://www.python.org/),且 [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/)
Pip 安装包含所有 [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) 的 ultralytics 包,环境要求 [**3.10>=Python>=3.7**](https://www.python.org/),且 [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/)

```bash
pip install ultralytics
Expand Down
2 changes: 1 addition & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics Docs

Deployed to https://docs.ultralytics.com
Ultralytics Docs are deployed to [https://docs.ultralytics.com](https://docs.ultralytics.com).

### Install Ultralytics package

Expand Down
2 changes: 1 addition & 1 deletion docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ yolo cfg=default.yaml
yolo task=init
yolo cfg=default.yaml
```
=== "Result"
=== "Results"
TODO: add terminal output


File renamed without changes.
72 changes: 72 additions & 0 deletions docs/predict.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
Inference or prediction of a task returns a list of `Results` objects. Alternatively, in the streaming mode, it returns a generator of `Results` objects which is memory efficient. Streaming mode can be enabled by passing `stream=True` in predictor's call method.

!!! example "Predict"
=== "Getting a List"
```python
inputs = [img, img] # list of np arrays
results = model(inputs) # List of Results objects
for result in results:
boxes = results.boxes # Boxes object for bbox outputs
masks = results.masks # Masks object for segmenation masks outputs
probs = results.probs # Class probabilities for classification outputs
...
```
=== "Getting a Generator"
```python
inputs = [img, img] # list of np arrays
results = model(inputs, stream="True") # Generator of Results objects
for result in results:
boxes = results.boxes # Boxes object for bbox outputs
masks = results.masks # Masks object for segmenation masks outputs
probs = results.probs # Class probabilities for classification outputs
...
```

## Working with Results

Results object consists of these component objects:

- `results.boxes` : It is an object of class `Boxes`. It has properties and methods for manipulating bboxes
- `results.masks` : It is an object of class `Masks`. It can be used to index masks or to get segment coordinates.
- `results.prob` : It is a `Tensor` object. It contains the class probabilities/logits.

Each result is composed of torch.Tensor by default, in which you can easily use following functionality:
```python
results = results.cuda()
results = results.cpu()
results = results.to("cpu")
results = results.numpy()
```
### Boxes
`Boxes` object can be used index, manipulate and convert bboxes to different formats. The box format conversion operations are cached, which means they're only calculated once per object and those values are reused for future calls.

- Indexing a `Boxes` objects returns a `Boxes` object
```python
boxes = results.boxes
box = boxes[0] # returns one box
box.xyxy
```
- Properties and conversions
```
results.boxes.xyxy # box with xyxy format, (N, 4)
results.boxes.xywh # box with xywh format, (N, 4)
results.boxes.xyxyn # box with xyxy format but normalized, (N, 4)
results.boxes.xywhn # box with xywh format but normalized, (N, 4)
results.boxes.conf # confidence score, (N, 1)
results.boxes.cls # cls, (N, 1)
```
### Masks
`Masks` object can be used index, manipulate and convert masks to segments. The segment conversion operation is cached.

```python
results.masks.masks # masks, (N, H, W)
results.masks.segments # bounding coordinates of masks, List[segment] * N
```

### probs
`probs` attribute of `Results` class is a `Tensor` containing class probabilities of a classification operation.
```python
results.probs # cls prob, (num_class, )
```

Class reference documentation for `Results` module and its components can be found [here](reference/results.md)
72 changes: 48 additions & 24 deletions docs/python.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
This is the simplest way of simply using YOLOv8 models in a Python environment. It can be imported from
the `ultralytics` module.
The simplest way of simply using YOLOv8 directly in a Python environment.

!!! example "Train"

Expand Down Expand Up @@ -51,35 +50,60 @@ the `ultralytics` module.
=== "From source"
```python
from ultralytics import YOLO
from PIL import Image
import cv2

model = YOLO("model.pt")
model.predict(source="0") # accepts all formats - img/folder/vid.*(mp4/format). 0 for webcam
model.predict(source="folder", show=True) # Display preds. Accepts all yolo predict arguments
# accepts all formats - image/dir/Path/URL/video/PIL/ndarray. 0 for webcam
results = model.predict(source="0")
results = model.predict(source="folder", show=True) # Display preds. Accepts all YOLO predict arguments

```
# from PIL
im1 = Image.open("bus.jpg")
results = model.predict(source=im1, save=True) # save plotted images

=== "From image/ndarray/tensor"
```python
# TODO, still working on it.
```
# from ndarray
im2 = cv2.imread("bus.jpg")
results = model.predict(source=im2, save=True, save_txt=True) # save predictions as labels

# from list of PIL/ndarray
results = model.predict(source=[im1, im2])
```

=== "Return outputs"
=== "Results usage"
```python
from ultralytics import YOLO

model = YOLO("model.pt")
outputs = model.predict(source="0", return_outputs=True) # treat predict as a Python generator
for output in outputs:
# each output here is a dict.
# for detection
print(output["det"]) # np.ndarray, (N, 6), xyxy, score, cls
# for segmentation
print(output["det"]) # np.ndarray, (N, 6), xyxy, score, cls
print(output["segment"]) # List[np.ndarray] * N, bounding coordinates of masks
# for classify
print(output["prob"]) # np.ndarray, (num_class, ), cls prob

# results would be a list of Results object including all the predictions by default
# but be careful as it could occupy a lot memory when there're many images,
# especially the task is segmentation.
# 1. return as a list
results = model.predict(source="folder")

# results would be a generator which is more friendly to memory by setting stream=True
# 2. return as a generator
results = model.predict(source=0, stream=True)

for result in results:
# detection
result.boxes.xyxy # box with xyxy format, (N, 4)
result.boxes.xywh # box with xywh format, (N, 4)
result.boxes.xyxyn # box with xyxy format but normalized, (N, 4)
result.boxes.xywhn # box with xywh format but normalized, (N, 4)
result.boxes.conf # confidence score, (N, 1)
result.boxes.cls # cls, (N, 1)

# segmentation
result.masks.masks # masks, (N, H, W)
result.masks.segments # bounding coordinates of masks, List[segment] * N

# classification
result.probs # cls prob, (num_class, )

# Each result is composed of torch.Tensor by default,
# in which you can easily use following functionality:
result = result.cuda()
result = result.cpu()
result = result.to("cpu")
result = result.numpy()
```

!!! note "Export and Deployment"
Expand Down
11 changes: 11 additions & 0 deletions docs/reference/results.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
### Results API Reference

:::ultralytics.yolo.engine.results.Results

### Boxes API Reference

:::ultralytics.yolo.engine.results.Boxes

### Masks API Reference

:::ultralytics.yolo.engine.results.Masks
10 changes: 5 additions & 5 deletions examples/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
"import ultralytics\n",
"ultralytics.checks()"
],
"execution_count": 1,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -145,7 +145,7 @@
},
"source": [
"&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;\n",
"<img align=\"left\" src=\"https://user-images.githubusercontent.com/26833433/127574988-6a558aa1-d268-44b9-bf6b-62d4c605cc72.jpg\" width=\"600\">"
"<img align=\"left\" src=\"https://user-images.githubusercontent.com/26833433/212889447-69e5bdf1-5800-4e29-835e-2ed2336dede2.jpg\" width=\"600\">"
]
},
{
Expand Down Expand Up @@ -185,7 +185,7 @@
"# Validate YOLOv8n on COCO128 val\n",
"!yolo task=detect mode=val model=yolov8n.pt data=coco128.yaml"
],
"execution_count": 2,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -310,7 +310,7 @@
"# Train YOLOv8n on COCO128 for 3 epochs\n",
"!yolo task=detect mode=train model=yolov8n.pt data=coco128.yaml epochs=3 imgsz=640"
],
"execution_count": 3,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -501,7 +501,7 @@
"id": "CYIjW4igCjqD",
"outputId": "3bb45917-f90e-4951-959d-7bcd26680f2e"
},
"execution_count": 4,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down
9 changes: 6 additions & 3 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ plugins:

# Primary navigation
nav:
- Home: index.md
- Home: home.md
- Quickstart: quickstart.md
- Tasks:
- Detection: tasks/detection.md
Expand All @@ -84,6 +84,7 @@ nav:
- Usage:
- CLI: cli.md
- Python: python.md
- Predict: predict.md
- Configuration: config.md
- Customization Guide: engine.md
- Ultralytics HUB: hub.md
Expand All @@ -95,5 +96,7 @@ nav:
- Validator: reference/base_val.md
- Predictor: reference/base_pred.md
- Exporter: reference/exporter.md
- nn Module: reference/nn.md
- operations: reference/ops.md
- Results: reference/results.md
- ultralytics.nn: reference/nn.md
- Operations: reference/ops.md
- Docs: README.md
13 changes: 6 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,30 @@

def get_version():
file = PARENT / 'ultralytics/__init__.py'
return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(), re.M)[1]
return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', file.read_text(encoding="utf-8"), re.M)[1]


setup(
name="ultralytics", # name of pypi package
version=get_version(), # version of pypi package
python_requires=">=3.7.0",
python_requires=">=3.7,<=3.11",
license='GPL-3.0',
description='Ultralytics YOLOv8 and HUB',
description='Ultralytics YOLOv8',
long_description=README,
long_description_content_type="text/markdown",
url="https://github.com/ultralytics/ultralytics",
project_urls={
'Bug Reports': 'https://github.com/ultralytics/ultralytics/issues',
'Funding': 'https://ultralytics.com',
'Source': 'https://github.com/ultralytics/ultralytics',},
'Source': 'https://github.com/ultralytics/ultralytics'},
author="Ultralytics",
author_email='[email protected]',
packages=find_packages(), # required
include_package_data=True,
install_requires=REQUIREMENTS,
extras_require={
'dev':
['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs', 'mkdocstrings[python]', 'mkdocs-material'],},
['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs', 'mkdocstrings[python]', 'mkdocs-material']},
classifiers=[
"Intended Audience :: Developers", "Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Programming Language :: Python :: 3",
Expand All @@ -49,5 +49,4 @@ def get_version():
"Topic :: Scientific/Engineering :: Image Recognition", "Operating System :: POSIX :: Linux",
"Operating System :: MacOS", "Operating System :: Microsoft :: Windows"],
keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
entry_points={
'console_scripts': ['yolo = ultralytics.yolo.cli:cli', 'ultralytics = ultralytics.yolo.cli:cli'],})
entry_points={'console_scripts': ['yolo = ultralytics.yolo.cli:cli', 'ultralytics = ultralytics.yolo.cli:cli']})
Loading

0 comments on commit c6985da

Please sign in to comment.