Skip to content

Commit

Permalink
feat: reserve interface for other torch devices (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tohrusky authored Nov 19, 2024
1 parent 4fe1a73 commit 0da0246
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 20 deletions.
23 changes: 5 additions & 18 deletions Final2x_core/SRclass.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import math
from typing import Any

import cv2
import numpy as np
from ccrestoration import AutoModel, SRBaseModel
from loguru import logger

from Final2x_core.config import SRConfig
from Final2x_core.util import PrintProgressLog
from Final2x_core.util import PrintProgressLog, get_device


class CCRestoration:
Expand All @@ -22,27 +21,15 @@ def __init__(self, config: SRConfig) -> None:

PrintProgressLog().set(len(self.config.input_path), 1)

self._SR_class: SRBaseModel = self._init_SR_model()

logger.info("SR Class init, device: " + str(self._SR_class.device))

def _init_SR_model(self) -> Any:
"""
init sr model from ccrestoration
:return:
"""
if self.config.device == "auto":
_device = None
else:
_device = self.config.device

return AutoModel.from_pretrained(
self._SR_class: SRBaseModel = AutoModel.from_pretrained(
pretrained_model_name=self.config.pretrained_model_name,
fp16=False,
device=_device,
device=get_device(self.config.device),
gh_proxy=self.config.gh_proxy,
)

logger.info("SR Class init, device: " + str(self._SR_class.device))

@logger.catch # type: ignore
def process(self, img: np.ndarray) -> np.ndarray:
"""
Expand Down
2 changes: 1 addition & 1 deletion Final2x_core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def from_base64(cls, base64_str: str) -> Any:

@field_validator("device")
def device_match(cls, v: str) -> str:
device_list = ["auto", "cpu", "cuda", "mps", "xpu", "xla", "meta"]
device_list = ["auto", "cpu", "cuda", "mps", "directml", "xpu"]
for d in device_list:
if v.startswith(d):
return v
Expand Down
1 change: 1 addition & 0 deletions Final2x_core/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from Final2x_core.util.progressLog import PrintProgressLog # noqa
from Final2x_core.util.singleton import singleton # noqa
from Final2x_core.util.device import get_device # noqa
29 changes: 29 additions & 0 deletions Final2x_core/util/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Union

import torch
from ccrestoration.util.device import default_device


def get_device(device: str) -> Union[torch.device, str]:
"""
Get device from string
:param device: device string
"""
if device.startswith("auto"):
return default_device()
elif device.startswith("cpu"):
return torch.device("cpu")
elif device.startswith("cuda"):
return torch.device("cuda")
elif device.startswith("mps"):
return torch.device("mps")
elif device.startswith("directml"):
import torch_directml

return torch_directml.device()
elif device.startswith("xpu"):
return torch.device("xpu")
else:
print(f"Unknown device: {device}, use auto instead.")
return default_device()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ license = "BSD-3-Clause"
name = "Final2x_core"
readme = "README.md"
repository = "https://github.com/Tohrusky/Final2x-core"
version = "3.0.1"
version = "3.0.2"

# Requirements
[tool.poetry.dependencies]
Expand Down

0 comments on commit 0da0246

Please sign in to comment.