Skip to content

Commit

Permalink
新增api.py中:可在启动后动态修改模型,以此满足同一个api不同的朗读者请求
Browse files Browse the repository at this point in the history
可在启动后动态修改模型,以此满足同一个api不同的朗读者请求
  • Loading branch information
JavaAndPython55 authored Feb 21, 2024
1 parent a16de2e commit 4b0fae8
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion api.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")

parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
Expand Down Expand Up @@ -227,6 +227,44 @@ def is_full(*items): # 任意一项为空返回False
return False
return True

def change_sovits_weights(sovits_path):
global vq_model, hps
dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model
)
if ("pretrained" not in sovits_path):
del vq_model.enc_q
if is_half == True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
with open("./sweight.txt", "w", encoding="utf-8") as f:
f.write(sovits_path)
def change_gpt_weights(gpt_path):
global hz, max_sec, t2s_model, config
hz = 50
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half == True:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path)


def get_bert_feature(text, word2ph):
with torch.no_grad():
Expand Down Expand Up @@ -452,6 +490,20 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):

app = FastAPI()

#clark新增-----2024-02-21
#可在启动后动态修改模型,以此满足同一个api不同的朗读者请求
@app.post("/set_model")
async def set_model(request: Request):
json_post_raw = await request.json()
global gpt_path
gpt_path=json_post_raw.get("gpt_model_path")
global sovits_path
sovits_path=json_post_raw.get("sovits_model_path")
print("gptpath"+gpt_path+";vitspath"+sovits_path)
change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path)
return "ok"
# 新增-----end------

@app.post("/control")
async def control(request: Request):
Expand Down

0 comments on commit 4b0fae8

Please sign in to comment.