forked from ultralytics/ultralytics
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
c5f5b80
commit d0b0fe2
Showing
6 changed files
with
51 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,5 @@ | ||
import hydra | ||
|
||
import ultralytics | ||
import ultralytics.yolo.v8 as yolo | ||
|
||
from .engine.model import YOLO | ||
from .engine.trainer import DEFAULT_CONFIG, BaseTrainer | ||
from .engine.trainer import BaseTrainer | ||
from .engine.validator import BaseValidator | ||
from .utils import LOGGER | ||
|
||
__all__ = ["BaseTrainer", "BaseValidator", "YOLO"] # allow simpler import | ||
|
||
|
||
@hydra.main(version_base=None, config_path="utils/configs", config_name="default") | ||
def cli(cfg): | ||
LOGGER.info(f"using Ultralytics YOLO v{ultralytics.__version__}") | ||
module_file = None | ||
if cfg.task.lower() == "detect": | ||
module_file = yolo.detect | ||
elif cfg.task.lower() == "segment": | ||
module_file = yolo.segment | ||
elif cfg.task.lower() == "classify": | ||
module_file = yolo.classify | ||
|
||
if not module_file: | ||
raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`") | ||
|
||
module_function = None | ||
|
||
if cfg.mode.lower() == "train": | ||
module_function = module_file.train | ||
elif cfg.mode.lower() == "val": | ||
module_function = module_file.val | ||
elif cfg.mode.lower() == "infer": | ||
module_function = module_file.infer | ||
|
||
if not module_function: | ||
raise Exception("mode not recognized. Choices are `'train', 'val', 'infer'`") | ||
module_function(cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import os | ||
import shutil | ||
|
||
import hydra | ||
|
||
import ultralytics | ||
import ultralytics.yolo.v8 as yolo | ||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG | ||
|
||
from .utils import LOGGER, colorstr | ||
|
||
|
||
@hydra.main(version_base=None, config_path="utils/configs", config_name="default") | ||
def cli(cfg): | ||
LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}") | ||
|
||
module_file = None | ||
if cfg.task.lower() == "init": # special case | ||
shutil.copy2(DEFAULT_CONFIG, os.getcwd()) | ||
LOGGER.info(f""" | ||
{colorstr("YOLO :")} configuration saved to {os.getcwd()}/{DEFAULT_CONFIG.name}. | ||
To run experiments using custom configuration: | ||
yolo task='task' mode='mode' --config-name config_file.yaml | ||
""") | ||
return | ||
elif cfg.task.lower() == "detect": | ||
module_file = yolo.detect | ||
elif cfg.task.lower() == "segment": | ||
module_file = yolo.segment | ||
elif cfg.task.lower() == "classify": | ||
module_file = yolo.classify | ||
|
||
if not module_file: | ||
raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`") | ||
|
||
module_function = None | ||
|
||
if cfg.mode.lower() == "train": | ||
module_function = module_file.train | ||
elif cfg.mode.lower() == "val": | ||
module_function = module_file.val | ||
elif cfg.mode.lower() == "infer": | ||
module_function = module_file.infer | ||
|
||
if not module_function: | ||
raise Exception("mode not recognized. Choices are `'train', 'val', 'infer'`") | ||
module_function(cfg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import contextlib | ||
import logging | ||
import logging.config | ||
import os | ||
import platform | ||
import sys | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters