Auto download missing model for bert_gen.py (fishaudio#146)
* auto download missing model

* support openi

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix wrong delete

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pass pre-commit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix repeat login

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Isotr0py and pre-commit-ci[bot] authored Nov 3, 2023
1 parent 4d6de24 commit 8609449
Showing 9 changed files with 89 additions and 22 deletions.
14 changes: 14 additions & 0 deletions bert/bert_models.json
@@ -0,0 +1,14 @@
+{
+    "deberta-v2-large-japanese": {
+        "repo_id": "ku-nlp/deberta-v2-large-japanese",
+        "files": ["spm.model", "pytorch_model.bin"]
+    },
+    "chinese-roberta-wwm-ext-large": {
+        "repo_id": "hfl/chinese-roberta-wwm-ext-large",
+        "files": ["pytorch_model.bin"]
+    },
+    "deberta-v3-large": {
+        "repo_id": "microsoft/deberta-v3-large",
+        "files": ["spm.model", "pytorch_model.bin"]
+    }
+}
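
The manifest above is what the new download check walks: each top-level key names a local folder under ./bert, repo_id points at the upstream Hugging Face repository, and files lists the artifacts that must exist locally. A minimal sketch of consuming it, assuming the repository layout (the reporting line is illustrative):

import json
from pathlib import Path

with open("./bert/bert_models.json", "r", encoding="utf-8") as fp:
    models = json.load(fp)

for name, meta in models.items():
    local_path = Path("./bert") / name
    missing = [f for f in meta["files"] if not (local_path / f).exists()]
    print(f"{name} ({meta['repo_id']}): missing {missing or 'nothing'}")
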
3 changes: 2 additions & 1 deletion bert_gen.py
@@ -3,7 +3,7 @@
 import commons
 import utils
 from tqdm import tqdm
-from text import cleaned_text_to_sequence, get_bert
+from text import check_bert_models, cleaned_text_to_sequence, get_bert
 import argparse
 import torch.multiprocessing as mp
 from config import config
@@ -57,6 +57,7 @@ def process_line(line):
     args, _ = parser.parse_known_args()
     config_path = args.config
     hps = utils.get_hparams_from_file(config_path)
+    check_bert_models()
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
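
The new call sits in the __main__ block after the hyperparameters are loaded and before the training list is read, so any missing weights are fetched once in the parent process rather than by competing multiprocessing workers. A hedged sketch of that ordering (the config path is illustrative; utils and text are the repository's own modules):

import utils
from text import check_bert_models

if __name__ == "__main__":
    hps = utils.get_hparams_from_file("config.json")
    check_bert_models()  # download any missing BERT checkpoints up front
    with open(hps.data.training_files, encoding="utf-8") as f:
        lines = f.readlines()
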
3 changes: 3 additions & 0 deletions config.py
@@ -195,7 +195,10 @@ def __init__(self, config_path: str):
         with open(file=config_path, mode="r", encoding="utf-8") as file:
             yaml_config: Dict[str, any] = yaml.safe_load(file.read())
             dataset_path: str = yaml_config["dataset_path"]
+            openi_token: str = yaml_config["openi_token"]
         self.dataset_path: str = dataset_path
+        self.mirror: str = yaml_config["mirror"]
+        self.openi_token: str = openi_token
         self.resample_config: Resample_config = Resample_config.from_dict(
             dataset_path, yaml_config["resample"]
         )
3 changes: 2 additions & 1 deletion default_config.yml
@@ -5,7 +5,8 @@
 # Each dataset and its corresponding model live under one shared path; all later path settings are relative to dataset_path
 # Leave this unset or empty to make paths relative to the project root
 dataset_path: "Data/your_dataset"
-
+mirror: "openi" # model mirror source
+openi_token: "1145141919810" # OpenI token
 
 # resample: audio resampling configuration
 # Note: a space is required after ":"
20 changes: 20 additions & 0 deletions text/__init__.py
@@ -26,3 +26,23 @@ def get_bert(norm_text, word2ph, language, device):
     lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
     bert = lang_bert_func_map[language](norm_text, word2ph, device)
     return bert
+
+
+def check_bert_models():
+    import json
+    from pathlib import Path
+
+    from config import config
+    from .bert_utils import _check_bert
+
+    if config.mirror.lower() == "openi":
+        import openi
+
+        kwargs = {"token": config.openi_token} if config.openi_token else {}
+        openi.login(**kwargs)
+
+    with open("./bert/bert_models.json", "r") as fp:
+        models = json.load(fp)
+    for k, v in models.items():
+        local_path = Path("./bert").joinpath(k)
+        _check_bert(v["repo_id"], v["files"], local_path)
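
check_bert_models() is the entry point the rest of the commit hooks into: when the OpenI mirror is selected it logs in once up front (presumably what the "fix repeat login" item in the commit message refers to), then walks bert_models.json and hands each entry to _check_bert. A hedged usage sketch, assuming the repository layout and a filled-in config file:

from text import check_bert_models

check_bert_models()  # verify or fetch all three BERT checkpoints under ./bert
# get_bert(...) can then load tokenizers and models from the local folders as usual.
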
23 changes: 23 additions & 0 deletions text/bert_utils.py
@@ -0,0 +1,23 @@
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download
+
+from config import config
+
+
+MIRROR: str = config.mirror
+
+
+def _check_bert(repo_id, files, local_path):
+    for file in files:
+        if not Path(local_path).joinpath(file).exists():
+            if MIRROR.lower() == "openi":
+                import openi
+
+                openi.model.download_model(
+                    "Stardust_minus/Bert-VITS2", repo_id.split("/")[-1], "./bert"
+                )
+            else:
+                hf_hub_download(
+                    repo_id, file, local_dir=local_path, local_dir_use_symlinks=False
+                )
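
On the Hugging Face path each missing file is fetched individually into the model's local folder, with local_dir_use_symlinks=False so a real copy lands there instead of a symlink into the cache; on the OpenI path the model is pulled from the Stardust_minus/Bert-VITS2 mirror repository into ./bert. A standalone sketch of the Hugging Face branch (repo and filename come from bert_models.json above; huggingface_hub must be installed):

from pathlib import Path

from huggingface_hub import hf_hub_download

local_path = Path("./bert/chinese-roberta-wwm-ext-large")
if not (local_path / "pytorch_model.bin").exists():
    hf_hub_download(
        "hfl/chinese-roberta-wwm-ext-large",
        "pytorch_model.bin",
        local_dir=local_path,
        local_dir_use_symlinks=False,
    )
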
16 changes: 8 additions & 8 deletions text/chinese_bert.py
@@ -1,9 +1,13 @@
-import torch
 import sys
-from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+import torch
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+
 from config import config
 
-tokenizer = AutoTokenizer.from_pretrained("./bert/chinese-roberta-wwm-ext-large")
+LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large"
+
+tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
 
 models = dict()
 
@@ -18,9 +22,7 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
     if not device:
         device = "cuda"
     if device not in models.keys():
-        models[device] = AutoModelForMaskedLM.from_pretrained(
-            "./bert/chinese-roberta-wwm-ext-large"
-        ).to(device)
+        models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
@@ -41,8 +43,6 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
 
 
 if __name__ == "__main__":
-    import torch
-
     word_level_feature = torch.rand(38, 1024)  # 12 words, each with 1024-dim features
     word2phone = [
         1,
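
All three language front ends end up with the same shape after this refactor: a module-level tokenizer, a LOCAL_PATH constant, and a models dict that lazily loads one model copy per device. The pattern in isolation, as a sketch (the path matches the diff above; the helper name is illustrative, and the first call on each device pays the load cost):

from transformers import AutoModelForMaskedLM, AutoTokenizer

LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large"

tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
models = {}  # maps a device string to its loaded model


def get_model(device: str):
    # Load once per device, then reuse the cached instance.
    if device not in models:
        models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
    return models[device]
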
13 changes: 8 additions & 5 deletions text/english_bert_mock.py
@@ -1,9 +1,14 @@
+import sys
+
 import torch
 from transformers import DebertaV2Model, DebertaV2Tokenizer
+
 from config import config
-import sys
 
-tokenizer = DebertaV2Tokenizer.from_pretrained("./bert/deberta-v3-large")
+
+LOCAL_PATH = "./bert/deberta-v3-large"
+
+tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
 
 models = dict()
 
@@ -18,9 +23,7 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
     if not device:
         device = "cuda"
     if device not in models.keys():
-        models[device] = DebertaV2Model.from_pretrained("./bert/deberta-v3-large").to(
-            device
-        )
+        models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device)
     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
         for i in inputs:
16 changes: 9 additions & 7 deletions text/japanese_bert.py
@@ -1,10 +1,14 @@
-import torch
-from transformers import AutoTokenizer, AutoModelForMaskedLM
 import sys
-from text.japanese import text2sep_kata
 
+import torch
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+
 from config import config
+from text.japanese import text2sep_kata
+
+LOCAL_PATH = "./bert/deberta-v2-large-japanese"
 
-tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese")
+tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
 
 models = dict()
+
@@ -27,9 +31,7 @@ def get_bert_feature_with_token(tokens, word2ph, device=config.bert_gen_config.device):
     if not device:
         device = "cuda"
     if device not in models.keys():
-        models[device] = AutoModelForMaskedLM.from_pretrained(
-            "./bert/deberta-v2-large-japanese"
-        ).to(device)
+        models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
     with torch.no_grad():
         inputs = torch.tensor(tokens).to(device).unsqueeze(0)
         token_type_ids = torch.zeros_like(inputs).to(device)
