forked from w-okada/voice-changer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
wataru
committed
Oct 29, 2022
1 parent
9d5c714
commit c01b733
Showing
9 changed files
with
760 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,280 @@ | ||
import sys, os, struct, argparse, logging, shutil, base64, traceback | ||
sys.path.append("/MMVC_Trainer") | ||
sys.path.append("/MMVC_Trainer/text") | ||
|
||
import uvicorn | ||
from fastapi import FastAPI, UploadFile, File, Form | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from fastapi.responses import JSONResponse | ||
from fastapi.encoders import jsonable_encoder | ||
from pydantic import BaseModel | ||
|
||
from scipy.io.wavfile import write, read | ||
|
||
import socketio | ||
from distutils.util import strtobool | ||
from datetime import datetime | ||
|
||
import torch | ||
import numpy as np | ||
|
||
from mods.ssl import create_self_signed_cert | ||
from mods.VoiceChanger import VoiceChanger | ||
|
||
class UvicornSuppressFilter(logging.Filter):
    """Logging filter that rejects every record it is asked about."""

    def filter(self, record: logging.LogRecord) -> bool:
        # A falsy result tells the logging machinery to drop the record.
        return False
|
||
# Attach the drop-everything filter to uvicorn's "uvicorn.error" logger.
logging.getLogger("uvicorn.error").addFilter(UvicornSuppressFilter())

# Stop records of the multipart parser from propagating to ancestor loggers.
# ``logger`` stays bound to this logger at module level.
logger = logging.getLogger("multipart.multipart")
logger.propagate = False
|
||
|
||
|
||
class VoiceModel(BaseModel):
    # JSON request body accepted by the POST "/test" endpoint.
    # The first four fields are forwarded unchanged to the voice changer.
    gpu: int
    srcId: int
    dstId: int
    timestamp: int
    # Base64 text; the endpoint decodes it and unpacks little-endian int16 samples.
    buffer: str
|
||
|
||
class MyCustomNamespace(socketio.AsyncNamespace):
    """Socket.IO namespace that runs incoming audio through a ``VoiceChanger``.

    A model must be loaded with :meth:`loadModel` before audio can be
    converted. Clients send ``request_message`` events and get the converted
    audio back in a ``response`` event.
    """

    def __init__(self, namespace):
        super().__init__(namespace)
        # None until loadModel() succeeds.
        self.voiceChanger = None

    def loadModel(self, config, model):
        """Replace the current voice changer with one built from the given files.

        Args:
            config: Path of the config file handed to ``VoiceChanger``.
            model: Path of the model file handed to ``VoiceChanger``.
        """
        previous = getattr(self, 'voiceChanger', None)
        if previous is not None:
            # Release the old model first, and forget it so that a failed
            # load below does not leave a destroyed instance attached.
            previous.destroy()
            self.voiceChanger = None
        self.voiceChanger = VoiceChanger(config, model)

    def changeVoice(self, gpu, srcId, dstId, timestamp, unpackedData):
        """Convert one chunk of samples with the loaded model.

        Returns:
            Whatever ``VoiceChanger.on_request`` returns for the arguments.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if getattr(self, 'voiceChanger', None) is None:
            raise RuntimeError("Voice changer model is not loaded.")
        return self.voiceChanger.on_request(gpu, srcId, dstId, timestamp, unpackedData)

    def on_connect(self, sid, environ):
        """Accept every connection; nothing to set up."""
        pass

    async def on_request_message(self, sid, msg):
        """Handle a ``request_message`` event.

        Args:
            sid: Session id of the sending client.
            msg: Sequence ``[gpu, srcId, dstId, timestamp, data]`` where
                ``data`` is a bytes-like buffer of little-endian int16 samples.
        """
        gpu, srcId, dstId, timestamp = (int(value) for value in msg[:4])
        data = msg[4]

        sample_size = struct.calcsize('<h')
        sample_count = len(data) // sample_size
        # struct.unpack needs the buffer length to match the format exactly,
        # so a trailing partial sample is dropped instead of raising.
        unpackedData = np.array(
            struct.unpack('<%sh' % sample_count, data[:sample_count * sample_size])
        )
        audio1 = self.changeVoice(gpu, srcId, dstId, timestamp, unpackedData)

        payload = struct.pack('<%sh' % len(audio1), *audio1)

        # Reply only to the requesting client; without ``room`` the event
        # would be sent to every client connected to this namespace.
        await self.emit('response', [timestamp, payload], room=sid)

    def on_disconnect(self, sid):
        """Nothing to clean up when a client disconnects."""
        pass
|
||
|
||
def _parse_bool_flag(value):
    """Convert a command-line truth value to 1 or 0.

    Accepts the same spellings as ``distutils.util.strtobool`` (which is
    gone from Python 3.12): y/yes/t/true/on/1 and n/no/f/false/off/0,
    case-insensitively.

    Raises:
        argparse.ArgumentTypeError: If the value is not a known spelling.
    """
    normalized = str(value).strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise argparse.ArgumentTypeError(f"invalid truth value {value!r}")


def setupArgParser():
    """Build the command-line parser of the server.

    Returns:
        argparse.ArgumentParser: Parser with ``-p``, ``-c``, ``-m``,
        ``--https``, ``--httpsKey``, ``--httpsCert`` and ``--httpsSelfSigned``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", type=int, default=8080, help="port")
    parser.add_argument("-c", type=str, help="path for the config.json")
    parser.add_argument("-m", type=str, help="path for the model file")
    parser.add_argument("--https", type=_parse_bool_flag, default=False, help="use https")
    parser.add_argument("--httpsKey", type=str, default="ssl.key", help="path for the key of https")
    parser.add_argument("--httpsCert", type=str, default="ssl.cert", help="path for the cert of https")
    parser.add_argument("--httpsSelfSigned", type=_parse_bool_flag, default=True, help="generate self-signed certificate")
    return parser
|
||
def printMessage(message, level=0):
    """Print ``message`` wrapped in an ANSI escape sequence chosen by ``level``.

    Levels 0, 1 and 2 use codes 17, 34 and 32; any other level uses 47.
    Every level except 0 puts one space in front of the message.
    """
    # Fallback used when the level matches none of the known entries.
    code, pad = "47", " "
    for known_level, known_code, known_pad in ((0, "17", ""), (1, "34", " "), (2, "32", " ")):
        if level == known_level:
            code, pad = known_code, known_pad
            break
    print(f"\033[{code}m{pad}{message}\033[0m")
|
||
# NOTE: a ``global`` statement at module level has no effect; it only flags
# ``app_socketio`` as the name uvicorn is told to import from this module.
global app_socketio

printMessage(f"Phase name:{__name__}", level=2)
# Name of this file without its ".py" suffix; uvicorn imports the module
# under this name.
thisFilename = os.path.basename(__file__)[:-3]

# PHASE3: the module has been imported by uvicorn under its own file name,
# so build the ASGI application and register the HTTP routes.
if __name__ == thisFilename:
    printMessage(f"PHASE3:{__name__}", level=2)
    parser = setupArgParser()
    args = parser.parse_args()
    PORT = args.p
    CONFIG = args.c
    MODEL = args.m

    app_fastapi = FastAPI()
    sio = socketio.AsyncServer(
        async_mode='asgi',
        cors_allowed_origins='*'
    )
    namespace = MyCustomNamespace('/test')
    sio.register_namespace(namespace)
    if CONFIG and MODEL:
        namespace.loadModel(CONFIG, MODEL)
    app_socketio = socketio.ASGIApp(
        sio,
        other_asgi_app=app_fastapi,
        static_files={
            '': '../frontend/dist',
            '/': '../frontend/dist/index.html',
        }
    )

    @app_fastapi.get("/api/hello")
    async def index():
        """Simple liveness endpoint."""
        return {"result": "Index"}

    @app_fastapi.post("/api/uploadfile/model")
    async def upload_file(configFile: UploadFile = File(...), modelFile: UploadFile = File(...)):
        """Store an uploaded config/model pair in the working directory and load it."""
        if configFile and modelFile:
            saved_paths = {}
            for file in [modelFile, configFile]:
                # The file name comes from the client: keep only its last
                # path component so it cannot escape the working directory.
                safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
                if safe_name in ("", ".", ".."):
                    return {"Error": "invalid file name."}
                target_path = os.path.join(".", safe_name)
                # ``with`` closes the handle even when the copy fails.
                with open(target_path, 'wb') as target:
                    shutil.copyfileobj(file.file, target)
                saved_paths[id(file)] = target_path
            namespace.loadModel(saved_paths[id(configFile)], saved_paths[id(modelFile)])
            return {"uploaded files": f"{configFile.filename}, {modelFile.filename} "}
        return {"Error": "uploaded file is not found."}

    @app_fastapi.post("/test")
    async def post_test(voice: VoiceModel):
        """Convert base64-encoded int16 audio and return the result as base64.

        On failure the exception text is returned as the response body.
        """
        try:
            gpu = voice.gpu
            srcId = voice.srcId
            dstId = voice.dstId
            timestamp = voice.timestamp
            wav = base64.b64decode(voice.buffer)

            # The original test ``wav == 0`` compared bytes with an int and
            # could never be true; an empty buffer is what selects the dummy.
            if len(wav) == 0:
                samplerate, data = read("dummy.wav")
                unpackedData = data
            else:
                sample_size = struct.calcsize('<h')
                sample_count = len(wav) // sample_size
                # Drop a trailing partial sample; struct.unpack requires the
                # buffer length to match the format exactly.
                unpackedData = np.array(
                    struct.unpack('<%sh' % sample_count, wav[:sample_count * sample_size])
                )
                os.makedirs("logs", exist_ok=True)
                write("logs/received_data.wav", 24000, unpackedData.astype(np.int16))

            changedVoice = namespace.changeVoice(gpu, srcId, dstId, timestamp, unpackedData)
            changedVoiceBase64 = base64.b64encode(changedVoice).decode('utf-8')

            data = {
                "gpu": gpu,
                "srcId": srcId,
                "dstId": dstId,
                "timestamp": timestamp,
                "changedVoiceBase64": changedVoiceBase64
            }

            json_compatible_item_data = jsonable_encoder(data)

            return JSONResponse(content=json_compatible_item_data)
        except Exception as e:
            print("REQUEST PROCESSING!!!! EXCEPTION!!!", e)
            print(traceback.format_exc())
            return str(e)
|
||
|
||
# PHASE2: module imported as "__mp_main__"; only the phase is reported.
if __name__ == '__mp_main__':
    printMessage(f"PHASE2:{__name__}", level=2)

# PHASE1: started as a script. Print the start-up information, prepare the
# TLS files when requested, and hand control to uvicorn.
if __name__ == '__main__':
    printMessage(f"PHASE1:{__name__}", level=2)

    parser = setupArgParser()
    args = parser.parse_args()
    PORT = args.p
    CONFIG = args.c
    MODEL = args.m

    printMessage(f"Start MMVC SocketIO Server", level=0)
    printMessage(f"CONFIG:{CONFIG}, MODEL:{MODEL}", level=1)

    # Both variables are optional. ``.get`` yields None when they are unset
    # (indexing os.environ raised KeyError) and keeps the names defined for
    # the URL hints printed further down.
    EX_PORT = os.environ.get("EX_PORT")
    EX_IP = os.environ.get("EX_IP")

    if EX_PORT:
        printMessage(f"External_Port:{EX_PORT} Internal_Port:{PORT}", level=1)
    else:
        printMessage(f"Internal_Port:{PORT}", level=1)

    if EX_IP:
        printMessage(f"External_IP:{EX_IP}", level=1)

    # Prepare the HTTPS key/certificate.
    key_path = None
    cert_path = None
    if args.https and args.httpsSelfSigned:
        # HTTPS with a freshly generated self-signed certificate.
        os.makedirs("./key", exist_ok=True)
        key_base_name = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        keyname = f"{key_base_name}.key"
        certname = f"{key_base_name}.cert"
        create_self_signed_cert(
            certname,
            keyname,
            certargs={
                "Country": "JP",
                "State": "Tokyo",
                "City": "Chuo-ku",
                "Organization": "F",
                "Org. Unit": "F",
            },
            cert_dir="./key",
        )
        key_path = os.path.join("./key", keyname)
        cert_path = os.path.join("./key", certname)
        printMessage(f"protocol: HTTPS(self-signed), key:{key_path}, cert:{cert_path}", level=1)
    elif args.https:
        # HTTPS with the key/certificate given on the command line.
        key_path = args.httpsKey
        cert_path = args.httpsCert
        printMessage(f"protocol: HTTPS, key:{key_path}, cert:{cert_path}", level=1)
    else:
        # Plain HTTP.
        printMessage(f"protocol: HTTP", level=1)

    # Show the address to open.
    if args.https:
        printMessage(f"open https://<IP>:<PORT>/ with your browser.", level=0)
    else:
        printMessage(f"open http://<IP>:<PORT>/ with your browser.", level=0)

    if EX_PORT and EX_IP and args.https:
        printMessage(f"In many cases it is one of the following", level=1)
        printMessage(f"https://localhost:{EX_PORT}/", level=1)
        for ip in EX_IP.strip().split(" "):
            printMessage(f"https://{ip}:{EX_PORT}/", level=1)
    elif EX_PORT and EX_IP:
        printMessage(f"In many cases it is one of the following", level=1)
        printMessage(f"http://localhost:{EX_PORT}/", level=1)

    # Start the server; uvicorn imports this module by name to find the app.
    server_options = {
        "host": "0.0.0.0",
        "port": int(PORT),
        "reload": True,
        "log_level": "critical",
    }
    if args.https:
        server_options["ssl_keyfile"] = key_path
        server_options["ssl_certfile"] = cert_path
    uvicorn.run(
        f"{os.path.basename(__file__)[:-3]}:app_socketio",
        **server_options
    )
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import torch | ||
from scipy.io.wavfile import write, read | ||
import numpy as np | ||
import struct, traceback | ||
|
||
import utils | ||
import commons | ||
from models import SynthesizerTrn | ||
from text.symbols import symbols | ||
from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate | ||
from mel_processing import spectrogram_torch | ||
from text import text_to_sequence, cleaned_text_to_sequence | ||
|
||
|
||
class VoiceChanger():
    """Wraps a ``SynthesizerTrn`` network and runs voice conversion with it."""

    def __init__(self, config, model):
        """Build the network from ``config`` and load the checkpoint ``model``.

        Args:
            config: Path passed to ``utils.get_hparams_from_file``.
            model: Checkpoint path passed to ``utils.load_checkpoint``.
        """
        self.hps = utils.get_hparams_from_file(config)
        self.net_g = SynthesizerTrn(
            len(symbols),
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model)
        self.net_g.eval()
        self.gpu_num = torch.cuda.device_count()
        utils.load_checkpoint(model, self.net_g, None)
        print(f"VoiceChanger Initialized (GPU_NUM:{self.gpu_num})")

    def destroy(self):
        """Drop the reference to the network."""
        del self.net_g

    def on_request(self, gpu, srcId, dstId, timestamp, wav):
        """Convert one chunk of audio from speaker ``srcId`` to ``dstId``.

        Args:
            gpu: CUDA device index; a negative value, or a machine without
                CUDA devices, runs the conversion on the CPU.
            srcId: Id handed to the model as ``sid_src``.
            dstId: Id handed to the model as ``sid_tgt``.
            timestamp: Not used by the conversion.
            wav: numpy array of samples; divided by
                ``hps.data.max_wav_value`` before the spectrogram is computed.

        Returns:
            numpy int16 array with the converted audio. If the conversion
            raises, the error is printed and a single zero sample is
            returned, so the caller always gets an array back.
        """
        try:
            # The collate function expects a text entry; "a" is a placeholder.
            text_norm = text_to_sequence("a", self.hps.data.text_cleaners)
            text_norm = commons.intersperse(text_norm, 0)
            text_norm = torch.LongTensor(text_norm)

            audio = torch.FloatTensor(wav.astype(np.float32))
            audio_norm = audio / self.hps.data.max_wav_value
            audio_norm = audio_norm.unsqueeze(0)

            spec = spectrogram_torch(audio_norm, self.hps.data.filter_length,
                                     self.hps.data.sampling_rate, self.hps.data.hop_length,
                                     self.hps.data.win_length, center=False)
            spec = torch.squeeze(spec, 0)
            sid = torch.LongTensor([int(srcId)])

            data = (text_norm, spec, audio_norm, sid)
            data = TextAudioSpeakerCollate()([data])

            if gpu < 0 or self.gpu_num == 0:
                device = torch.device("cpu")
            else:
                device = torch.device("cuda", gpu)

            with torch.no_grad():
                x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [t.to(device) for t in data]
                sid_tgt1 = torch.LongTensor([dstId]).to(device)
                converted = self.net_g.to(device).voice_conversion(
                    spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0, 0]
                audio1 = (converted.data * self.hps.data.max_wav_value).cpu().float().numpy()
        except Exception as e:
            print("VC PROCESSING!!!! EXCEPTION!!!", e)
            print(traceback.format_exc())
            # ``audio1`` was never assigned on this path; falling through
            # used to raise UnboundLocalError and hide the real error.
            return np.zeros(1, dtype=np.int16)

        return audio1.astype(np.int16)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
from OpenSSL import crypto | ||
|
||
def create_self_signed_cert(certfile, keyfile, certargs, cert_dir="."):
    """Create a self-signed certificate and RSA key unless both already exist.

    Args:
        certfile: File name of the PEM certificate inside ``cert_dir``.
        keyfile: File name of the PEM private key inside ``cert_dir``.
        certargs: Mapping with the subject fields "Country", "State",
            "City", "Organization" and "Org. Unit".
        cert_dir: Directory the two files are written to.

    Returns:
        None. Nothing is generated when both files are already present.
    """
    C_F = os.path.join(cert_dir, certfile)
    K_F = os.path.join(cert_dir, keyfile)
    if os.path.exists(C_F) and os.path.exists(K_F):
        return

    k = crypto.PKey()
    k.generate_key(crypto.TYPE_RSA, 2048)

    cert = crypto.X509()
    subject = cert.get_subject()
    subject.C = certargs["Country"]
    subject.ST = certargs["State"]
    subject.L = certargs["City"]
    subject.O = certargs["Organization"]
    subject.OU = certargs["Org. Unit"]
    subject.CN = 'Example'
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    # 315360000 seconds = 3650 days of validity.
    cert.gmtime_adj_notAfter(315360000)
    # Self-signed: issuer and subject are the same name.
    cert.set_issuer(cert.get_subject())
    cert.set_pubkey(k)
    # SHA-256 instead of SHA-1, which is no longer considered a safe
    # certificate signature digest.
    cert.sign(k, 'sha256')

    # Context managers flush and close both files deterministically.
    with open(C_F, "wb") as cert_out:
        cert_out.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    with open(K_F, "wb") as key_out:
        key_out.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))
Oops, something went wrong.