Skip to content

Commit

Permalink
Update requirements for TANGO
Browse files Browse the repository at this point in the history
  • Loading branch information
deepanwayx committed Apr 26, 2023
1 parent 3f13e32 commit 493be0d
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ More generated samples are shown [here](https://github.com/declare-lab/tango/blo

## Prerequisites

Our code is built on pytorch version 1.13.1+cu117. We mention `torch==1.13.1` in the requirements file but you might need to install a specific cuda version of torch depending on your GPU device type.

Install `requirements.txt`. You will also need to install the `diffusers` package from the directory provided in this repo:

```bash
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ huggingface_hub==0.13.3
importlib_metadata==6.3.0
librosa==0.9.2
matplotlib==3.5.2
numpy==1.22.0
numpy==1.23.0
omegaconf==2.3.0
packaging==23.1
pandas==1.4.1
Expand Down
6 changes: 3 additions & 3 deletions tango.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def __init__(self, name="declare-lab/tango", device="cuda:0"):
self.stft = TacotronSTFT(**stft_config).to(device)
self.model = AudioDiffusion(**main_config).to(device)

vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path))
stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path))
main_weights = torch.load("{}/pytorch_model_main.bin".format(path))
vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)

self.vae.load_state_dict(vae_weights)
self.stft.load_state_dict(stft_weights)
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ def main():

if args.hf_model:
hf_model_path = snapshot_download(repo_id=args.hf_model)
model.load_state_dict(torch.load("{}/pytorch_model_main.bin".format(hf_model_path)))
model.load_state_dict(torch.load("{}/pytorch_model_main.bin".format(hf_model_path), map_location="cpu"))
accelerator.print("Successfully loaded checkpoint from:", args.hf_model)

if args.prefix:
prefix = args.prefix
Expand Down

0 comments on commit 493be0d

Please sign in to comment.