Fix inconsistent segmentation caused by the tokenizer vocabulary #466
AlongWY committed Dec 28, 2020
1 parent aae913b commit 0f3c956
Showing 3 changed files with 60 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ltp/__init__.py
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # Author: Yunlong Feng <[email protected]>

-__version__ = '4.1.3'
+__version__ = '4.1.3.post1'

 from . import const
 from . import nn, utils
7 changes: 4 additions & 3 deletions ltp/frontend.py
@@ -141,6 +141,7 @@ def __init__(self, path: str = 'small', device=None, **kwargs):
         self.model.eval()

         self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
+        self.seg_vocab_dict = {tag: idx for idx, tag in enumerate(self.seg_vocab)}
         self.pos_vocab = ckpt.get('pos', [])
         self.ner_vocab = ckpt.get('ner', [])
         self.dep_vocab = ckpt.get('dep', [])
@@ -255,10 +256,10 @@ def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True
             matches = self.seg_with_dict(inputs, tokenized, batch_prefix)
             for sent_match, sent_seg in zip(matches, seg):
                 for start, end in sent_match:
-                    sent_seg[start] = 0
-                    sent_seg[start + 1:end] = 1
+                    sent_seg[start] = self.seg_vocab_dict[WORD_START]
+                    sent_seg[start + 1:end] = self.seg_vocab_dict[WORD_MIDDLE]
                     if end < len(sent_seg):
-                        sent_seg[end] = 0
+                        sent_seg[end] = self.seg_vocab_dict[WORD_START]

         if is_preseged:
             sentences = inputs
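This hunk is the heart of the fix. The old code wrote the hardcoded indices 0 and 1 into sent_seg, silently assuming where WORD_START and WORD_MIDDLE sit in the checkpoint's seg vocabulary; since ckpt.get('seg', [WORD_MIDDLE, WORD_START]) can yield the tags in either order, dictionary-forced segmentation could disagree with the labels the model itself predicts. A minimal standalone sketch of the failure mode (the tag strings and both orderings are illustrative placeholders, not values from a real checkpoint):

# Sketch only: WORD_START / WORD_MIDDLE stand in for the constants in ltp.const.
WORD_START, WORD_MIDDLE = 'B-W', 'I-W'

for seg_vocab in ([WORD_START, WORD_MIDDLE], [WORD_MIDDLE, WORD_START]):
    seg_vocab_dict = {tag: idx for idx, tag in enumerate(seg_vocab)}
    hardcoded = 0                            # what the old code wrote for a word start
    looked_up = seg_vocab_dict[WORD_START]   # what the fixed code writes
    print(seg_vocab, '-> hardcoded:', hardcoded, 'looked up:', looked_up)

With the first ordering the two agree; with the second (the default shown in the frontend.py hunk above) the hardcoded 0 actually points at WORD_MIDDLE, which is exactly the inconsistency the commit title describes.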
56 changes: 55 additions & 1 deletion ltp/multitask.py
@@ -7,9 +7,11 @@
 from argparse import ArgumentParser
 from collections import OrderedDict

+import numpy
 import torch
 import torch.utils.data
 from pytorch_lightning import Trainer
+from tqdm import tqdm

 import ltp
 from ltp import (
@@ -20,7 +22,7 @@
 from ltp.data import dataset as datasets
 from ltp.data.utils import collate, MultiTaskDataloader
 from ltp.transformer_multitask import TransformerMultiTask as Model
-from ltp.utils import TaskInfo, common_train, tune_train
+from ltp.utils import TaskInfo, common_train, tune_train, map2device, convert2npy
 from ltp.utils import deploy_model

 os.environ['TOKENIZERS_PARALLELISM'] = 'true'
@@ -190,6 +192,55 @@ def configure_optimizers(self: Model):
     )


+def build_ner_distill_dataset(args):
+    model = Model.load_from_checkpoint(
+        args.resume_from_checkpoint, hparams=args
+    )
+
+    model.eval()
+    model.freeze()
+
+    dataset, metric = task_named_entity_recognition.build_dataset(model, args.ner_data_dir, task_info.task_name)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset[datasets.Split.TRAIN],
+        batch_size=args.batch_size,
+        collate_fn=collate,
+        num_workers=args.num_workers
+    )
+
+    output = os.path.join(args.ner_data_dir, task_info.task_name, 'output.npz')
+
+    if torch.cuda.is_available():
+        model.cuda()
+        map2cpu = lambda x: map2device(x)
+        map2cuda = lambda x: map2device(x, model.device)
+    else:
+        map2cpu = lambda x: x
+        map2cuda = lambda x: x
+
+    with torch.no_grad():
+        batchs = []
+        for batch in tqdm(train_dataloader):
+            batch = map2cuda(batch)
+            logits = model.forward(task='ner', **batch).logits
+            batch.update(logits=logits)
+            batchs.append(map2cpu(batch))
+        try:
+            numpy.savez(
+                output,
+                data=convert2npy(batchs),
+                extra=convert2npy({
+                    'transitions': model.ner_classifier.crf.transitions,
+                    'start_transitions': model.ner_classifier.crf.start_transitions,
+                    'end_transitions': model.ner_classifier.crf.end_transitions
+                })
+            )
+        except Exception as e:
+            numpy.savez(output, data=convert2npy(batchs))
+
+    print("Done")
+
+
 def add_task_specific_args(parent_parser):
     parser = ArgumentParser(parents=[parent_parser], add_help=False)
     parser.add_argument('--seed', type=int, default=19980524)
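The new build_ner_distill_dataset runs the frozen multi-task model over the NER training split and dumps each batch together with the teacher logits, plus the CRF transition tensors when they serialize cleanly, into output.npz for a later distillation step. Reading the archive back is not part of this commit; a hypothetical consumer, assuming only the 'data'/'extra' keys passed to numpy.savez above and that convert2npy yields pickled object arrays:

import numpy

# Hypothetical reader of the dump; the key names come from the commit,
# everything else (allow_pickle, .item() on the dict) is an assumption
# about what convert2npy produces.
archive = numpy.load('output.npz', allow_pickle=True)
batches = archive['data']          # per-batch inputs plus teacher 'logits'
if 'extra' in archive:             # present only when the CRF tensors saved
    crf = archive['extra'].item()
    print('CRF transitions shape:', crf['transitions'].shape)
print('number of batches:', len(batches))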
@@ -210,6 +261,7 @@ def add_task_specific_args(parent_parser):
     parser.add_argument('--dep_data_dir', type=str, default=None)
     parser.add_argument('--sdp_data_dir', type=str, default=None)
    parser.add_argument('--srl_data_dir', type=str, default=None)
+    parser.add_argument('--build_ner_dataset', action='store_true')
     return parser


@@ -226,6 +278,8 @@ def main():

     if args.ltp_model is not None and args.resume_from_checkpoint is not None:
         deploy_model(args, args.ltp_version)
+    elif args.build_ner_dataset:
+        build_ner_distill_dataset(args)
     elif args.tune:
         tune_train(args, model_class=Model, task_info=task_info, build_method=build_method)
     else:
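With the dispatch above, the dump is produced by running the multitask script with the new flag while --ltp_model is left unset. A hypothetical invocation (only --build_ner_dataset is defined in this diff; the checkpoint and data paths are placeholders, and the remaining flags are inferred from the args the new function reads):

import subprocess
import sys

# Sketch: run multitask.py in its new dataset-building mode. Paths are
# placeholders; --resume_from_checkpoint must point at a trained checkpoint.
subprocess.run([
    sys.executable, 'ltp/multitask.py',
    '--build_ner_dataset',
    '--resume_from_checkpoint', 'checkpoints/last.ckpt',
    '--ner_data_dir', 'data/ner',
    '--batch_size', '16',
    '--num_workers', '4',
], check=True)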
