Skip to content

Commit

Permalink
update bert eval.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Mar 18, 2020
1 parent 5ab5de7 commit bb06acb
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 3 deletions.
23 changes: 23 additions & 0 deletions pycorrector/bert/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,29 @@ python3 bert_corrector.py
纠错结果还算可圈可点,速度有点慢,可以用albert-tiny之类的参数小些的模型加速预测。


### Evaluate

提供评估脚本[pycorrector/utils/eval.py](../utils/eval.py),该脚本有两个功能:
- 构建评估样本集:自动生成评估集[pycorrector/data/eval_corpus.json](../data/eval_corpus.json),包括字粒度错误100条、词粒度错误100条、语法错误100条,正确句子200条。用户可以修改条数生成其他评估样本分布。
- 计算纠错准召率:采用保守计算方式,简单地把纠错之后与正确句子完全匹配的视为正确,否则视为错误。

执行该评估脚本后,

Bert模型纠错效果评估如下:
- 准确率:284/500=56.8%
- 召回率:105/300=35%

规则方法的纠错效果评估如下:
- 准确率:320/500=64%
- 召回率:152/300=50.67%


可以看出Bert模型对文本有强大的表达能力,仅仅依赖预训练的MLM模型,在纠错能力上就比优化良久的专家规则方法稍差而已,而且看结果细节一些纠正还挺靠谱。

看来选择一个好的模型,选择一个正确的方向真的很重要。我在这里只能希望规则的方法尽量做到扩展性好些,深度模型尽量做到调研各种模型全一些,深入一些。



## Fine-tuned BERT model with chinese corpus

### chinese corpus
Expand Down
58 changes: 58 additions & 0 deletions pycorrector/utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def build_cged_no_error_corpus(data_path, output_path, limit_size=500):


def build_eval_corpus(output_eval_path=eval_data_path):
"""
生成评估样本集,抽样分布可修改
:param output_eval_path:
:return: json file
"""
bcmi_path = os.path.join(pwd_path, '../data/cn/bcmi.txt')
clp_path = os.path.join(pwd_path, '../data/cn/clp14_C1.pkl')
sighan_path = os.path.join(pwd_path, '../data/cn/sighan15_A2.pkl')
Expand Down Expand Up @@ -256,6 +261,59 @@ def eval_corpus(input_eval_path=eval_data_path, output_eval_path=output_eval_err
save_json(res, output_eval_path)


def eval_corpus_by_bert(input_eval_path=eval_data_path, output_eval_path=output_eval_error_path, verbose=True):
    """
    Evaluate the BERT corrector on the evaluation corpus.

    A sample counts as right only when the predicted sentence is exactly
    equal to the reference correction. Samples that are not right are
    collected and written to ``output_eval_path`` with ``save_json``.

    :param input_eval_path: path of the evaluation corpus, read with ``load_json``
    :param output_eval_path: path where the mismatched samples are saved
    :param verbose: when true, print every mismatched sample and its prediction
    :return: None; the summary rates and counts are printed
    """
    from pycorrector.bert.bert_corrector import BertCorrector

    corrector = BertCorrector()
    mismatches = []
    samples = load_json(input_eval_path)
    total_count = right_count = 0
    recall_total_count = recall_right_count = 0

    for sample in samples:
        text = sample.get('text', '')
        correction = sample.get('correction', '')
        errors = sample.get('errors', [])

        # pred_detail: list(wrong, right, begin_idx, end_idx)
        pred_sentence, pred_detail = corrector.bert_correct(text)
        total_count += 1
        exact_match = correction == pred_sentence

        # Recall: only samples that carry annotated errors are counted, and a
        # hit additionally requires that the model reported some change.
        if errors:
            recall_total_count += 1
            if pred_detail and exact_match:
                recall_right_count += 1

        # Overall right rate: exact sentence match over all samples.
        if exact_match:
            right_count += 1
            continue

        failed = copy.deepcopy(sample)
        failed['pred_sentence'] = pred_sentence
        failed['pred_errors'] = str(pred_detail)
        mismatches.append(failed)
        if verbose:
            print('truth:', text, errors)
            print('predict:', pred_sentence, pred_detail)

    # Rates stay at 0.0 when there is nothing to divide by.
    right_rate = right_count / total_count if total_count > 0 else 0.0
    recall_rate = recall_right_count / recall_total_count if recall_total_count > 0 else 0.0

    summary = ('right_rate:{}, right_count:{}, total_count:{};\n'
               'recall_rate:{},recall_right_count:{},recall_total_count:{}')
    print(summary.format(right_rate, right_count, total_count,
                         recall_rate, recall_right_count, recall_total_count))
    save_json(mismatches, output_eval_path)
if __name__ == "__main__":
    # Build the evaluation sample set. One has already been generated;
    # uncomment the call below to build a set with your own sample distribution.
    # build_eval_corpus()

    # Evaluate precision/recall of the rule-based corrector.
    eval_corpus(eval_data_path, output_eval_path=output_eval_error_path)

    # Evaluate precision/recall of the BERT corrector; its mismatched samples
    # are saved to a separate file.
    bert_error_path = os.path.join(pwd_path, './eval_corpus_error_by_bert.json')
    eval_corpus_by_bert(eval_data_path, output_eval_path=bert_error_path)
2 changes: 1 addition & 1 deletion pycorrector/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
@description:
"""

__version__ = '0.2.4'
__version__ = '0.2.5'
6 changes: 4 additions & 2 deletions tests/eval_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# -*- coding: utf-8 -*-
# Author: XuMing <[email protected]>
# Brief:
import sys

sys.path.append("../")
import sys
import os
sys.path.append("../")

from pycorrector.utils.eval import eval_bcmi_data, get_bcmi_corpus, eval_sighan_corpus

pwd_path = os.path.abspath(os.path.dirname(__file__))
Expand Down Expand Up @@ -35,3 +36,4 @@ def test_sighan_data():
rate = eval_sighan_corpus(sighan_path, True)
print('sighan right rate:{}'.format(rate))
# sighan right rate:0.5724725943970768

0 comments on commit bb06acb

Please sign in to comment.