diff --git a/Final2x_core/SRclass.py b/Final2x_core/SRclass.py index eb4b077..e8713b7 100644 --- a/Final2x_core/SRclass.py +++ b/Final2x_core/SRclass.py @@ -1,5 +1,4 @@ import math -from typing import Any import cv2 import numpy as np @@ -7,7 +6,7 @@ from loguru import logger from Final2x_core.config import SRConfig -from Final2x_core.util import PrintProgressLog +from Final2x_core.util import PrintProgressLog, get_device class CCRestoration: @@ -22,27 +21,15 @@ def __init__(self, config: SRConfig) -> None: PrintProgressLog().set(len(self.config.input_path), 1) - self._SR_class: SRBaseModel = self._init_SR_model() - - logger.info("SR Class init, device: " + str(self._SR_class.device)) - - def _init_SR_model(self) -> Any: - """ - init sr model from ccrestoration - :return: - """ - if self.config.device == "auto": - _device = None - else: - _device = self.config.device - - return AutoModel.from_pretrained( + self._SR_class: SRBaseModel = AutoModel.from_pretrained( pretrained_model_name=self.config.pretrained_model_name, fp16=False, - device=_device, + device=get_device(self.config.device), gh_proxy=self.config.gh_proxy, ) + logger.info("SR Class init, device: " + str(self._SR_class.device)) + @logger.catch # type: ignore def process(self, img: np.ndarray) -> np.ndarray: """ diff --git a/Final2x_core/config.py b/Final2x_core/config.py index bb19974..3e55510 100644 --- a/Final2x_core/config.py +++ b/Final2x_core/config.py @@ -60,7 +60,7 @@ def from_base64(cls, base64_str: str) -> Any: @field_validator("device") def device_match(cls, v: str) -> str: - device_list = ["auto", "cpu", "cuda", "mps", "xpu", "xla", "meta"] + device_list = ["auto", "cpu", "cuda", "mps", "directml", "xpu"] for d in device_list: if v.startswith(d): return v diff --git a/Final2x_core/util/__init__.py b/Final2x_core/util/__init__.py index 1903ad5..a340aa8 100644 --- a/Final2x_core/util/__init__.py +++ b/Final2x_core/util/__init__.py @@ -1,2 +1,3 @@ from Final2x_core.util.progressLog import PrintProgressLog # noqa from Final2x_core.util.singleton import singleton # noqa +from Final2x_core.util.device import get_device # noqa diff --git a/Final2x_core/util/device.py b/Final2x_core/util/device.py new file mode 100644 index 0000000..3b61a71 --- /dev/null +++ b/Final2x_core/util/device.py @@ -0,0 +1,29 @@ +from typing import Union + +import torch +from ccrestoration.util.device import default_device + + +def get_device(device: str) -> Union[torch.device, str]: + """ + Get device from string + + :param device: device string + """ + if device.startswith("auto"): + return default_device() + elif device.startswith("cpu"): + return torch.device("cpu") + elif device.startswith("cuda"): + return torch.device("cuda") + elif device.startswith("mps"): + return torch.device("mps") + elif device.startswith("directml"): + import torch_directml + + return torch_directml.device() + elif device.startswith("xpu"): + return torch.device("xpu") + else: + print(f"Unknown device: {device}, use auto instead.") + return default_device() diff --git a/pyproject.toml b/pyproject.toml index 7cfc099..f428266 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ license = "BSD-3-Clause" name = "Final2x_core" readme = "README.md" repository = "https://github.com/Tohrusky/Final2x-core" -version = "3.0.1" +version = "3.0.2" # Requirements [tool.poetry.dependencies]