forked from yeyupiaoling/PPASR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path · infer_server.py
108 lines (95 loc) · 5.24 KB
/
infer_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import functools
import json
import os
import sys
import time
from datetime import datetime

from flask import request, Flask, render_template
from flask_cors import CORS

from ppasr.predict import Predictor
from ppasr.utils.audio_vad import crop_audio_vad
from ppasr.utils.utils import add_arguments, print_arguments
# Command-line configuration. Each tuple is (name, type, default, help) and is
# registered, in order, through ppasr's add_arguments helper.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
_PLAIN_ARGUMENT_SPECS = (
    ("host", str, "0.0.0.0", "监听主机的IP地址"),
    ("port", int, 5000, "服务所使用的端口号"),
    ("save_path", str, 'dataset/upload/', "上传音频文件的保存目录"),
    ('use_gpu', bool, True, "是否使用GPU预测"),
    ('use_pun', bool, False, "是否给识别结果加标点符号"),
    ('to_an', bool, False, "是否转为阿拉伯数字"),
    ('beam_size', int, 300, "集束搜索解码相关参数,搜索大小,范围:[5, 500]"),
    ('alpha', float, 2.2, "集束搜索解码相关参数,LM系数"),
    ('beta', float, 4.3, "集束搜索解码相关参数,WC系数"),
    ('cutoff_prob', float, 0.99, "集束搜索解码相关参数,剪枝的概率"),
    ('cutoff_top_n', int, 40, "集束搜索解码相关参数,剪枝的最大值"),
    ('use_model', str, 'deepspeech2', "所使用的模型"),
    ('vocab_path', str, 'dataset/vocabulary.txt', "数据集的词汇表文件路径"),
    ('model_dir', str, 'models/deepspeech2/infer/', "导出的预测模型文件夹路径"),
    ('pun_model_dir', str, 'models/pun_models/', "加标点符号的模型文件夹路径"),
    ('lang_model_path', str, 'lm/zh_giga.no_cna_cmn.prune01244.klm', "集束搜索解码相关参数,语言模型文件路径"),
)
for _arg_spec in _PLAIN_ARGUMENT_SPECS:
    add_arg(*_arg_spec)
# The two options restricted to a fixed set of values are registered last.
add_arg('feature_method', str, 'linear', "音频预处理方法", choices=['linear', 'mfcc', 'fbank'])
add_arg('decoder', str, 'ctc_beam_search', "结果解码方法", choices=['ctc_beam_search', 'ctc_greedy'])
args = parser.parse_args()

app = Flask(__name__, template_folder="templates", static_folder="static", static_url_path="/")
# Allow cross-origin (CORS) requests.
CORS(app)

# Build the recognizer once at import time; both recognition endpoints use it.
predictor = Predictor(
    # Model and input features.
    model_dir=args.model_dir,
    vocab_path=args.vocab_path,
    use_model=args.use_model,
    feature_method=args.feature_method,
    use_gpu=args.use_gpu,
    # Decoding.
    decoder=args.decoder,
    lang_model_path=args.lang_model_path,
    alpha=args.alpha,
    beta=args.beta,
    beam_size=args.beam_size,
    cutoff_prob=args.cutoff_prob,
    cutoff_top_n=args.cutoff_top_n,
    # Punctuation restoration.
    use_pun=args.use_pun,
    pun_model_dir=args.pun_model_dir,
)
# Speech recognition endpoint for a single short clip.
@app.route("/recognition", methods=['POST'])
def recognition():
    """Recognize one uploaded audio clip sent as the ``audio`` form file.

    The upload is saved under ``args.save_path`` and passed to
    ``predictor.predict``.

    Returns:
        A JSON string:
        - success: ``{"code": 0, "msg": "success", "result": <text>,
          "score": <score rounded to 3 decimals>}``
        - saving or recognition raised: ``{"error": 1, "msg": "audio read fail!"}``
        - no usable ``audio`` file in the request: ``{"error": 3, "msg": "audio is None!"}``
    """
    # .get() instead of [] so a request without an "audio" field reaches the
    # error-3 response below instead of aborting with a 400 BadRequestKeyError.
    f = request.files.get('audio')
    # Keep only the last path component of the client-supplied name so a
    # filename such as "../../x.wav" cannot write outside args.save_path.
    filename = os.path.basename(f.filename or '') if f else ''
    if not filename:
        return json.dumps({"error": 3, "msg": "audio is None!"}, ensure_ascii=False)
    try:
        # The directory is otherwise only created in the __main__ block, so
        # make sure it exists when the app is served some other way.
        os.makedirs(args.save_path, exist_ok=True)
        # Path where the upload is stored before recognition.
        file_path = os.path.join(args.save_path, filename)
        f.save(file_path)
        start = time.time()
        # Run recognition.
        score, text = predictor.predict(audio_path=file_path, to_an=args.to_an)
        end = time.time()
        print("识别时间:%dms,识别结果:%s, 得分: %f" % (round((end - start) * 1000), text, score))
        # json.dumps yields valid JSON even when the text contains quotes;
        # float() guards against non-builtin numeric score types.
        return json.dumps({"code": 0, "msg": "success", "result": text, "score": round(float(score), 3)},
                          ensure_ascii=False)
    except Exception as e:
        # Request boundary: log and report failure to the client instead of a 500.
        print(f'[{datetime.now()}] 短语音识别失败,错误信息:{e}', file=sys.stderr)
        return json.dumps({"error": 1, "msg": "audio read fail!"}, ensure_ascii=False)
# Speech recognition endpoint for long recordings.
@app.route("/recognition_long_audio", methods=['POST'])
def recognition_long_audio():
    """Recognize a long uploaded recording by splitting it into segments.

    The ``audio`` form file is saved under ``args.save_path`` and split with
    ``crop_audio_vad``. Each segment is then recognized with
    ``predictor.predict``. Segment texts are concatenated directly when
    ``args.use_pun`` is set and joined with ``','`` otherwise. The reported
    score is the mean of the segment scores, or 0.0 (with empty text) when no
    segments are produced.

    Returns:
        A JSON string:
        - success: ``{"code": 0, "msg": "success", "result": <text>,
          "score": <mean score rounded to 3 decimals>}``
        - saving, splitting or recognition raised: ``{"error": 1, "msg": "audio read fail!"}``
        - no usable ``audio`` file in the request: ``{"error": 3, "msg": "audio is None!"}``
    """
    # .get() instead of [] so a request without an "audio" field reaches the
    # error-3 response below instead of aborting with a 400 BadRequestKeyError.
    f = request.files.get('audio')
    # Keep only the last path component of the client-supplied name so a
    # filename such as "../../x.wav" cannot write outside args.save_path.
    filename = os.path.basename(f.filename or '') if f else ''
    if not filename:
        return json.dumps({"error": 3, "msg": "audio is None!"}, ensure_ascii=False)
    try:
        # The directory is otherwise only created in the __main__ block, so
        # make sure it exists when the app is served some other way.
        os.makedirs(args.save_path, exist_ok=True)
        # Path where the upload is stored before recognition.
        file_path = os.path.join(args.save_path, filename)
        f.save(file_path)
        start = time.time()
        # Split the long recording into segments.
        audios_bytes = crop_audio_vad(file_path)
        texts = []
        scores = []
        # Recognize each segment.
        for audio_bytes in audios_bytes:
            score, text = predictor.predict(audio_bytes=audio_bytes, to_an=args.to_an)
            texts.append(text)
            scores.append(float(score))
        # join() avoids the leading ',' that incremental concatenation produced.
        separator = '' if args.use_pun else ','
        result_text = separator.join(texts)
        # Explicit empty check instead of a ZeroDivisionError when nothing was segmented.
        avg_score = sum(scores) / len(scores) if scores else 0.0
        end = time.time()
        print("识别时间:%dms,识别结果:%s, 得分: %f" % (round((end - start) * 1000), result_text, avg_score))
        return json.dumps({"code": 0, "msg": "success", "result": result_text, "score": round(avg_score, 3)},
                          ensure_ascii=False)
    except Exception as e:
        # Request boundary: log and report failure to the client instead of a 500.
        print(f'[{datetime.now()}] 长语音识别失败,错误信息:{e}', file=sys.stderr)
        return json.dumps({"error": 1, "msg": "audio read fail!"}, ensure_ascii=False)
@app.route('/')
def home():
    """Render the front-end page ``index.html`` from the app's template folder."""
    page_name = "index.html"
    return render_template(page_name)
if __name__ == '__main__':
    # Echo the parsed configuration before starting the server.
    print_arguments(args)
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # fail if the directory is created by someone else in between.
    os.makedirs(args.save_path, exist_ok=True)
    app.run(host=args.host, port=args.port)