-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
492 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#coding: UTF-8 | ||
import time | ||
import mindspore | ||
import numpy as np | ||
from train_eval import train, init_network | ||
from importlib import import_module | ||
import mindspore.ops as ops | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description='Chinese Text Classification') | ||
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer') | ||
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained') | ||
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char') | ||
args = parser.parse_args() | ||
|
||
if __name__ == '__main__': | ||
dataset = 'THUCNews' # 数据集 | ||
|
||
# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random | ||
embedding = 'embedding_SougouNews.npz' | ||
if args.embedding == 'random': | ||
embedding = 'random' | ||
model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer | ||
if model_name == 'FastText': | ||
from utils_fasttext import build_dataset, build_iterator, get_time_dif | ||
embedding = 'random' | ||
else: | ||
from utils import build_dataset, build_iterator, get_time_dif | ||
|
||
x = import_module('models.' + model_name) | ||
config = x.Config(dataset, embedding) | ||
np.random.seed(1) | ||
mindspore.set_seed(1) | ||
|
||
start_time = time.time() | ||
print("Loading data...") | ||
vocab, train_data, dev_data, test_data = build_dataset(config, args.word) | ||
train_iter = build_iterator(train_data, config) | ||
dev_iter = build_iterator(dev_data, config) | ||
test_iter = build_iterator(test_data, config) | ||
time_dif = get_time_dif(start_time) | ||
print("Time usage:", time_dif) | ||
|
||
# train | ||
config.n_vocab = len(vocab) | ||
model = x.Model(config).to(config.device) | ||
if model_name != 'Transformer': | ||
init_network(model) | ||
print(model.parameters) | ||
train(config, model, train_iter, dev_iter, test_iter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#coding: UTF-8 | ||
import numpy as np | ||
import mindspore | ||
import mindspore.nn as nn | ||
import mindspore.ops as ops | ||
from sklearn import metrics | ||
import time | ||
from utils import get_time_dif | ||
from mindspore.common.initializer import initializer, XavierNormal, Normal, HeNormal, Constant | ||
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor | ||
|
||
|
||
|
||
# 权重初始化,默认xavier | ||
def init_network(model, method='xavier', exclude='embedding', seed=123): | ||
for name, w in model.named_parameters(): | ||
if exclude not in name: | ||
if 'weight' in name: | ||
if method == 'xavier': | ||
XavierNormal(w) | ||
elif method == 'kaiming': | ||
HeNormal(w) | ||
else: | ||
Normal(w) | ||
elif 'bias' in name: | ||
Constant(w, 0) | ||
else: | ||
pass | ||
|
||
|
||
def train(config, model, train_iter, dev_iter, test_iter): | ||
start_time = time.time() | ||
model.train() | ||
optimizer = nn.Adam(model.parameters(), lr=config.learning_rate) | ||
|
||
# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率 | ||
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) | ||
total_batch = 0 # 记录进行到多少batch | ||
dev_best_loss = float('inf') | ||
last_improve = 0 # 记录上次验证集loss下降的batch数 | ||
flag = False # 记录是否很久没有效果提升 | ||
for epoch in range(config.num_epochs): | ||
print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs)) | ||
# scheduler.step() # 学习率衰减 | ||
for i, (trains, labels) in enumerate(train_iter): | ||
outputs = model(trains) | ||
model.zero_grad() | ||
loss = ops.cross_entropy(outputs, labels) | ||
grads = mindspore.grad(loss) | ||
optimizer(grads) | ||
if total_batch % 100 == 0: | ||
# 每多少轮输出在训练集和验证集上的效果 | ||
true = labels.data.cpu() | ||
predic = ops.max(outputs.data, 1)[1].cpu() | ||
train_acc = metrics.accuracy_score(true, predic) | ||
dev_acc, dev_loss = evaluate(config, model, dev_iter) | ||
if dev_loss < dev_best_loss: | ||
dev_best_loss = dev_loss | ||
mindspore.save_checkpoint(model.state_dict(), config.save_path) | ||
improve = '*' | ||
last_improve = total_batch | ||
else: | ||
improve = '' | ||
time_dif = get_time_dif(start_time) | ||
msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}' | ||
print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve)) | ||
model.train() | ||
total_batch += 1 | ||
if total_batch - last_improve > config.require_improvement: | ||
# 验证集loss超过1000batch没下降,结束训练 | ||
print("No optimization for a long time, auto-stopping...") | ||
flag = True | ||
break | ||
if flag: | ||
break | ||
test(config, model, test_iter) | ||
|
||
|
||
def test(config, model, test_iter): | ||
# test | ||
param_dict = mindspore.load_checkpoint(config.save_path) | ||
mindspore.load_param_into_net(model, param_dict) | ||
model.eval() | ||
start_time = time.time() | ||
test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True) | ||
msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}' | ||
print(msg.format(test_loss, test_acc)) | ||
print("Precision, Recall and F1-Score...") | ||
print(test_report) | ||
print("Confusion Matrix...") | ||
print(test_confusion) | ||
time_dif = get_time_dif(start_time) | ||
print("Time usage:", time_dif) | ||
|
||
|
||
def evaluate(config, model, data_iter, test=False): | ||
model.eval() | ||
loss_total = 0 | ||
predict_all = np.array([], dtype=int) | ||
labels_all = np.array([], dtype=int) | ||
with ops.stop_gradient(): | ||
for texts, labels in data_iter: | ||
outputs = model(texts) | ||
loss = ops.cross_entropy(outputs, labels) | ||
loss_total += loss | ||
labels = labels.data.cpu().numpy() | ||
predic = ops.max(outputs.data, 1)[1].cpu().numpy() | ||
labels_all = np.append(labels_all, labels) | ||
predict_all = np.append(predict_all, predic) | ||
|
||
acc = metrics.accuracy_score(labels_all, predict_all) | ||
if test: | ||
report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4) | ||
confusion = metrics.confusion_matrix(labels_all, predict_all) | ||
return acc, loss_total / len(data_iter), report, confusion | ||
return acc, loss_total / len(data_iter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
#coding: UTF-8 | ||
import os | ||
import mindspore | ||
import numpy as np | ||
import pickle as pkl | ||
from tqdm import tqdm | ||
import time | ||
from datetime import timedelta | ||
from mindspore import Tensor | ||
import mindspore.nn as nn | ||
|
||
|
||
|
||
MAX_VOCAB_SIZE = 10000 # 词表长度限制 | ||
UNK, PAD = '<UNK>', '<PAD>' # 未知字,padding符号 | ||
|
||
|
||
def build_vocab(file_path, tokenizer, max_size, min_freq): | ||
vocab_dic = {} | ||
with open(file_path, 'r', encoding='UTF-8') as f: | ||
for line in tqdm(f): | ||
lin = line.strip() | ||
if not lin: | ||
continue | ||
content = lin.split('\t')[0] | ||
for word in tokenizer(content): | ||
vocab_dic[word] = vocab_dic.get(word, 0) + 1 | ||
vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size] | ||
vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)} | ||
vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1}) | ||
return vocab_dic | ||
|
||
|
||
def build_dataset(config, ues_word): | ||
if ues_word: | ||
tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level | ||
else: | ||
tokenizer = lambda x: [y for y in x] # char-level | ||
if os.path.exists(config.vocab_path): | ||
vocab = pkl.load(open(config.vocab_path, 'rb')) | ||
else: | ||
vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1) | ||
pkl.dump(vocab, open(config.vocab_path, 'wb')) | ||
print(f"Vocab size: {len(vocab)}") | ||
|
||
def load_dataset(path, pad_size=32): | ||
contents = [] | ||
with open(path, 'r', encoding='UTF-8') as f: | ||
for line in tqdm(f): | ||
lin = line.strip() | ||
if not lin: | ||
continue | ||
content, label = lin.split('\t') | ||
words_line = [] | ||
token = tokenizer(content) | ||
seq_len = len(token) | ||
if pad_size: | ||
if len(token) < pad_size: | ||
token.extend([PAD] * (pad_size - len(token))) | ||
else: | ||
token = token[:pad_size] | ||
seq_len = pad_size | ||
# word to id | ||
for word in token: | ||
words_line.append(vocab.get(word, vocab.get(UNK))) | ||
contents.append((words_line, int(label), seq_len)) | ||
return contents # [([...], 0), ([...], 1), ...] | ||
train = load_dataset(config.train_path, config.pad_size) | ||
dev = load_dataset(config.dev_path, config.pad_size) | ||
test = load_dataset(config.test_path, config.pad_size) | ||
return vocab, train, dev, test | ||
|
||
class DatasetIterater(object): | ||
def __init__(self, batches, batch_size, device): | ||
self.batch_size = batch_size | ||
self.batches = batches | ||
self.n_batches = len(batches) // batch_size | ||
self.residue = False # 记录batch数量是否为整数 | ||
if len(batches) % self.n_batches != 0: | ||
self.residue = True | ||
self.index = 0 | ||
self.device = device | ||
|
||
def _to_tensor(self, datas): | ||
x = Tensor.long([_[0] for _ in datas]).to(self.device) | ||
y = Tensor.long([_[1] for _ in datas]).to(self.device) | ||
|
||
# pad前的长度(超过pad_size的设为pad_size) | ||
seq_len = Tensor.long([_[2] for _ in datas]).to(self.device) | ||
return (x, seq_len), y | ||
|
||
def __next__(self): | ||
if self.residue and self.index == self.n_batches: | ||
batches = self.batches[self.index * self.batch_size: len(self.batches)] | ||
self.index += 1 | ||
batches = self._to_tensor(batches) | ||
return batches | ||
|
||
elif self.index >= self.n_batches: | ||
self.index = 0 | ||
raise StopIteration | ||
else: | ||
batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size] | ||
self.index += 1 | ||
batches = self._to_tensor(batches) | ||
return batches | ||
|
||
def __iter__(self): | ||
return self | ||
|
||
def __len__(self): | ||
if self.residue: | ||
return self.n_batches + 1 | ||
else: | ||
return self.n_batches | ||
|
||
def build_iterator(dataset, config): | ||
iter = DatasetIterater(dataset, config.batch_size, config.device) | ||
return iter | ||
|
||
|
||
def get_time_dif(start_time): | ||
"""获取已使用时间""" | ||
end_time = time.time() | ||
time_dif = end_time - start_time | ||
return timedelta(seconds=int(round(time_dif))) | ||
|
||
|
||
if __name__ == "__main__": | ||
'''提取预训练词向量''' | ||
# 下面的目录、文件名按需更改。 | ||
train_dir = "./THUCNews/data/train.txt" | ||
vocab_dir = "./THUCNews/data/vocab.pkl" | ||
pretrain_dir = "./THUCNews/data/sgns.sogou.char" | ||
emb_dim = 300 | ||
filename_trimmed_dir = "./THUCNews/data/embedding_SougouNews" | ||
if os.path.exists(vocab_dir): | ||
word_to_id = pkl.load(open(vocab_dir, 'rb')) | ||
else: | ||
# tokenizer = lambda x: x.split(' ') # 以词为单位构建词表(数据集中词之间以空格隔开) | ||
tokenizer = lambda x: [y for y in x] # 以字为单位构建词表 | ||
word_to_id = build_vocab(train_dir, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1) | ||
pkl.dump(word_to_id, open(vocab_dir, 'wb')) | ||
|
||
embeddings = np.random.rand(len(word_to_id), emb_dim) | ||
f = open(pretrain_dir, "r", encoding='UTF-8') | ||
for i, line in enumerate(f.readlines()): | ||
# if i == 0: # 若第一行是标题,则跳过 | ||
# continue | ||
lin = line.strip().split(" ") | ||
if lin[0] in word_to_id: | ||
idx = word_to_id[lin[0]] | ||
emb = [float(x) for x in lin[1:301]] | ||
embeddings[idx] = np.asarray(emb, dtype='float32') | ||
f.close() | ||
np.savez_compressed(filename_trimmed_dir, embeddings=embeddings) |
Oops, something went wrong.