Skip to content

Commit

Permalink
compatibility with older torch
Browse files Browse the repository at this point in the history
  • Loading branch information
hkchengrex committed Jul 18, 2023
1 parent 10371a7 commit 3371c72
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
5 changes: 4 additions & 1 deletion inference/interact/fbrs/controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import torch
from torch import mps
try:
from torch import mps
except:
pass

from ..fbrs.inference import clicker
from ..fbrs.inference.predictors import get_predictor
Expand Down
5 changes: 4 additions & 1 deletion inference/interact/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@

import numpy as np
import torch
from torch import mps
try:
from torch import mps
except:
print('torch.MPS not available.')

from PyQt6.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox,
QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog,
Expand Down
13 changes: 8 additions & 5 deletions inference/interact/interactive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def index_numpy_to_one_hot_torch(mask, num_classes):
"""
Some constants fro visualization
"""
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
try:
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
except:
device = torch.device("cpu")

color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
Expand Down

0 comments on commit 3371c72

Please sign in to comment.