Skip to content

Commit

Permalink
update cuda mps cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Kedreamix committed Jun 12, 2024
1 parent 37d7ac5 commit e7cc6ca
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
17 changes: 14 additions & 3 deletions TFG/MuseTalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ class MuseTalk_RealTime:
def __init__(self):
# load model weights
self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
self.timesteps = torch.tensor([0], device=device)
self.pe = self.pe.half()
self.vae.vae = self.vae.vae.half()
Expand Down Expand Up @@ -517,7 +522,13 @@ class MuseTalk:
def __init__(self):
# load model weights
self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import platform
if torch.cuda.is_available():
device = "cuda"
elif platform.system() == 'Darwin': # macos
device = "mps"
else:
device = "cpu"
self.timesteps = torch.tensor([0], device=device)


Expand Down Expand Up @@ -753,7 +764,7 @@ def check_video(self, video):
# musetalk = MuseTalk()
musetalk = MuseTalk_RealTime()
audio_path = "Musetalk/data/audio/sun.wav"
video_path = "Musetalk/data/video/yongen.mp4"
video_path = "Musetalk/data/video/yongen_musev.mp4"
bbox_shift = 5
video_path, bbox_shift_text = musetalk.prepare_material(video_path, bbox_shift)
# print(video_path, bbox_shift_text)
Expand Down
2 changes: 1 addition & 1 deletion TFG/SadTalker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class SadTalker():

def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):

import platform
if torch.cuda.is_available():
device = "cuda"
elif platform.system() == 'Darwin': # macos
Expand Down

0 comments on commit e7cc6ca

Please sign in to comment.