Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
# Conflicts:
#	configs/some.yaml
  • Loading branch information
autumn-2-net committed Sep 13, 2023
2 parents 3e00d16 + 898c8a5 commit 02e245e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion configs/some.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ audio_sample_rate: 16000
test_prefixes:
- item1
- item2
units_encoder: contentvec
units_encoder: contentvec768l12
units_encoder_ckpt: pretrained/contentvec/checkpoint_best_legacy_500.pt
pe: rmvpe
pe_ckpt: pretrained/rmvpe/model.pt
Expand Down
6 changes: 3 additions & 3 deletions modules/contentvec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from fairseq import checkpoint_utils


class ContentVec(torch.nn.Module):
class ContentVec768L12(torch.nn.Module):
def __init__(self, path, h_sample_rate=16000, h_hop_size=320, device='cpu'):
super().__init__()
self.device = device
models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="", )
models, self.saved_cfg, self.task = checkpoint_utils.load_model_ensemble_and_task([path], suffix="")
self.hubert = models[0].to(self.device).eval()

def forward(self, waveform): # B, T
Expand All @@ -19,6 +19,6 @@ def forward(self, waveform): # B, T
}
with torch.no_grad():
logits = self.hubert.extract_features(**inputs)
feats = self.hubert.final_proj(logits[0])
feats = logits[0]
units = feats # .transpose(2, 1)
return units
2 changes: 1 addition & 1 deletion preprocessing/me_binarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def process_item(self, item_name, meta_data, binarization_args):
wav_tensor = torch.from_numpy(waveform).to(self.device)
global contentvec
if contentvec is None:
contentvec = modules.contentvec.ContentVec(self.config['units_encoder_ckpt'], device=self.device)
contentvec = modules.contentvec.ContentVec768L12(self.config['units_encoder_ckpt'], device=self.device)
units = contentvec(wav_tensor).squeeze(0).cpu().numpy()
assert len(units.shape) == 2 and units.shape[1] == self.config['units_dim'], \
f'Shape of units must be [T, units_dim], but is {units.shape}.'
Expand Down
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def train(config, exp_name, work_dir):
work_dir = work_dir / exp_name
assert not work_dir.exists() or work_dir.is_dir(), f'Path \'{work_dir}\' is not a directory.'
work_dir.mkdir(parents=True, exist_ok=True)
with open(work_dir / 'config.yaml', 'w', encoding='utf8') as f:
yaml.safe_dump(config, f)

if config['ddp_backend'] == 'nccl_no_p2p':
print("Disabling NCCL P2P")
Expand Down

0 comments on commit 02e245e

Please sign in to comment.