Skip to content

Commit

Permalink
CLI updates (ultralytics#58)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
AyushExel and pre-commit-ci[bot] authored Nov 29, 2022
1 parent c5f5b80 commit d0b0fe2
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 39 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def get_version():
keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
entry_points={
'console_scripts': [
'yolo = ultralytics.yolo.__init__:cli',],})
'yolo = ultralytics.yolo.cli:cli',],})
36 changes: 1 addition & 35 deletions ultralytics/yolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,5 @@
import hydra

import ultralytics
import ultralytics.yolo.v8 as yolo

from .engine.model import YOLO
from .engine.trainer import DEFAULT_CONFIG, BaseTrainer
from .engine.trainer import BaseTrainer
from .engine.validator import BaseValidator
from .utils import LOGGER

__all__ = ["BaseTrainer", "BaseValidator", "YOLO"] # allow simpler import


@hydra.main(version_base=None, config_path="utils/configs", config_name="default")
def cli(cfg):
LOGGER.info(f"using Ultralytics YOLO v{ultralytics.__version__}")
module_file = None
if cfg.task.lower() == "detect":
module_file = yolo.detect
elif cfg.task.lower() == "segment":
module_file = yolo.segment
elif cfg.task.lower() == "classify":
module_file = yolo.classify

if not module_file:
raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`")

module_function = None

if cfg.mode.lower() == "train":
module_function = module_file.train
elif cfg.mode.lower() == "val":
module_function = module_file.val
elif cfg.mode.lower() == "infer":
module_function = module_file.infer

if not module_function:
raise Exception("mode not recognized. Choices are `'train', 'val', 'infer'`")
module_function(cfg)
47 changes: 47 additions & 0 deletions ultralytics/yolo/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import shutil

import hydra

import ultralytics
import ultralytics.yolo.v8 as yolo
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG

from .utils import LOGGER, colorstr


@hydra.main(version_base=None, config_path="utils/configs", config_name="default")
def cli(cfg):
LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")

module_file = None
if cfg.task.lower() == "init": # special case
shutil.copy2(DEFAULT_CONFIG, os.getcwd())
LOGGER.info(f"""
{colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}.
To run experiments using custom configuration:
yolo task='task' mode='mode' --config-name config_file.yaml
""")
return
elif cfg.task.lower() == "detect":
module_file = yolo.detect
elif cfg.task.lower() == "segment":
module_file = yolo.segment
elif cfg.task.lower() == "classify":
module_file = yolo.classify

if not module_file:
raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`")

module_function = None

if cfg.mode.lower() == "train":
module_function = module_file.train
elif cfg.mode.lower() == "val":
module_function = module_file.val
elif cfg.mode.lower() == "infer":
module_function = module_file.infer

if not module_function:
raise Exception("mode not recognized. Choices are `'train', 'val', 'infer'`")
module_function(cfg)
1 change: 0 additions & 1 deletion ultralytics/yolo/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
import yaml

import ultralytics.yolo as yolo
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.modeling import get_model
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/yolo/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import contextlib
import logging
import logging.config
import os
import platform
import sys
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/yolo/utils/configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Default training settings and hyperparameters for medium-augmentation COCO training

# Task and Mode
task: "classify" # choices=['detect', 'segment', 'classify']
task: "classify" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case
mode: "train" # choice=['train', 'val', 'infer']

# Train settings -------------------------------------------------------------------------------------------------------
Expand Down

0 comments on commit d0b0fe2

Please sign in to comment.