forked from songyouwei/ABSA-PyTorch
Commit
add k-fold validation version of train file
1 parent 9ef6cd5 · commit f59cfce
Showing 1 changed file with 294 additions and 0 deletions.
@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-
# file: train_k_fold_cross_val.py
# author: songyouwei <[email protected]>
# Copyright (C) 2019. All Rights Reserved.

import logging
import argparse
import math
import os
import sys
from time import strftime, localtime
import random
import numpy

from pytorch_pretrained_bert import BertModel
from sklearn import metrics
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, ConcatDataset

from data_utils import build_tokenizer, build_embedding_matrix, Tokenizer4Bert, ABSADataset

from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA, MGAN
from models.aen import CrossEntropyLoss_LSR, AEN, AEN_BERT
from models.bert_spc import BERT_SPC

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))


class Instructor:
    def __init__(self, opt):
        self.opt = opt

        if 'bert' in opt.model_name:
            tokenizer = Tokenizer4Bert(opt.max_seq_len, opt.pretrained_bert_name)
            bert = BertModel.from_pretrained(opt.pretrained_bert_name)
            self.pretrained_bert_state_dict = bert.state_dict()
            self.model = opt.model_class(bert, opt).to(opt.device)
        else:
            tokenizer = build_tokenizer(
                fnames=[opt.dataset_file['train'], opt.dataset_file['test']],
                max_seq_len=opt.max_seq_len,
                dat_fname='{0}_tokenizer.dat'.format(opt.dataset))
            embedding_matrix = build_embedding_matrix(
                word2idx=tokenizer.word2idx,
                embed_dim=opt.embed_dim,
                dat_fname='{0}_{1}_embedding_matrix.dat'.format(str(opt.embed_dim), opt.dataset))
            self.model = opt.model_class(embedding_matrix, opt).to(opt.device)

        self.trainset = ABSADataset(opt.dataset_file['train'], tokenizer)
        self.testset = ABSADataset(opt.dataset_file['test'], tokenizer)

        if opt.device.type == 'cuda':
            logger.info('cuda memory allocated: {}'.format(torch.cuda.memory_allocated(device=opt.device.index)))
        self._print_args()

    def _print_args(self):
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            n_params = torch.prod(torch.tensor(p.shape))
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        logger.info('n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
        logger.info('> training arguments:')
        for arg in vars(self.opt):
            logger.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))

    def _reset_params(self):
        for child in self.model.children():
            if type(child) != BertModel:  # skip bert params
                for p in child.parameters():
                    if p.requires_grad:
                        if len(p.shape) > 1:
                            self.opt.initializer(p)
                        else:
                            stdv = 1. / math.sqrt(p.shape[0])
                            torch.nn.init.uniform_(p, a=-stdv, b=stdv)
            else:
                self.model.bert.load_state_dict(self.pretrained_bert_state_dict)
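
    # NOTE: run() invokes _reset_params() before every cross-validation fold:
    # non-BERT children are re-initialized with opt.initializer (or a uniform
    # fan-in range for 1-D parameters), while the BERT child is restored from
    # the pretrained state dict captured in __init__, so each fold starts fresh.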

    def _train(self, criterion, optimizer, train_data_loader, val_data_loader):
        max_val_acc = 0
        max_val_f1 = 0
        global_step = 0
        path = None
        for epoch in range(self.opt.num_epoch):
            logger.info('epoch: {}'.format(epoch))
            n_correct, n_total = 0, 0
            # switch model to training mode
            self.model.train()
            for i_batch, sample_batched in enumerate(train_data_loader):
                global_step += 1
                # clear gradient accumulators
                optimizer.zero_grad()

                inputs = [sample_batched[col].to(self.opt.device) for col in self.opt.inputs_cols]
                outputs = self.model(inputs)
                targets = sample_batched['polarity'].to(self.opt.device)

                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                if global_step % self.opt.log_step == 0:
                    n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
                    n_total += len(outputs)
                    train_acc = n_correct / n_total

                    logger.info('loss: {:.4f}, acc: {:.4f}'.format(loss.item(), train_acc))

            val_acc, val_f1 = self._evaluate_acc_f1(val_data_loader)
            logger.info('> val_acc: {:.4f}, val_f1: {:.4f}'.format(val_acc, val_f1))
            if val_acc > max_val_acc:
                max_val_acc = val_acc
                if not os.path.exists('state_dict'):
                    os.mkdir('state_dict')
                path = 'state_dict/{0}_{1}_val_temp'.format(self.opt.model_name, self.opt.dataset)
                torch.save(self.model.state_dict(), path)
                logger.info('>> saved: {}'.format(path))
            if val_f1 > max_val_f1:
                max_val_f1 = val_f1

        return path

    def _evaluate_acc_f1(self, data_loader):
        n_correct, n_total = 0, 0
        t_targets_all, t_outputs_all = None, None
        # switch model to evaluation mode
        self.model.eval()
        with torch.no_grad():
            for t_batch, t_sample_batched in enumerate(data_loader):
                t_inputs = [t_sample_batched[col].to(self.opt.device) for col in self.opt.inputs_cols]
                t_targets = t_sample_batched['polarity'].to(self.opt.device)
                t_outputs = self.model(t_inputs)

                n_correct += (torch.argmax(t_outputs, -1) == t_targets).sum().item()
                n_total += len(t_outputs)

                if t_targets_all is None:
                    t_targets_all = t_targets
                    t_outputs_all = t_outputs
                else:
                    t_targets_all = torch.cat((t_targets_all, t_targets), dim=0)
                    t_outputs_all = torch.cat((t_outputs_all, t_outputs), dim=0)

        acc = n_correct / n_total
        f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2], average='macro')
        return acc, f1

    def run(self):
        # Loss and Optimizer
        criterion = nn.CrossEntropyLoss()
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = self.opt.optimizer(_params, lr=self.opt.learning_rate, weight_decay=self.opt.l2reg)

        test_data_loader = DataLoader(dataset=self.testset, batch_size=self.opt.batch_size, shuffle=False)
        valset_len = len(self.trainset) // self.opt.cross_val_fold
        splitedsets = random_split(self.trainset, tuple([valset_len] * (self.opt.cross_val_fold - 1) + [len(self.trainset) - valset_len * (self.opt.cross_val_fold - 1)]))
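        # The lengths passed to random_split partition the training set into
        # cross_val_fold random subsets: the first (cross_val_fold - 1) hold
        # len(trainset) // cross_val_fold samples each and the last absorbs the
        # remainder (e.g. 1005 samples with 10 folds -> nine subsets of 100 and
        # one of 105); each loop iteration below holds one subset out for validation.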

        all_test_acc, all_test_f1 = [], []
        for fid in range(self.opt.cross_val_fold):
            logger.info('fold : {}'.format(fid))
            logger.info('>' * 100)
            trainset = ConcatDataset([x for i, x in enumerate(splitedsets) if i != fid])
            valset = splitedsets[fid]
            train_data_loader = DataLoader(dataset=trainset, batch_size=self.opt.batch_size, shuffle=True)
            val_data_loader = DataLoader(dataset=valset, batch_size=self.opt.batch_size, shuffle=False)

            self._reset_params()
            best_model_path = self._train(criterion, optimizer, train_data_loader, val_data_loader)

            self.model.load_state_dict(torch.load(best_model_path))
            test_acc, test_f1 = self._evaluate_acc_f1(test_data_loader)
            all_test_acc.append(test_acc)
            all_test_f1.append(test_f1)
            logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(test_acc, test_f1))

        mean_test_acc, mean_test_f1 = numpy.mean(all_test_acc), numpy.mean(all_test_f1)
        logger.info('>' * 100)
        logger.info('>>> mean_test_acc: {:.4f}, mean_test_f1: {:.4f}'.format(mean_test_acc, mean_test_f1))


def main():
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', default='bert_spc', type=str)
    parser.add_argument('--dataset', default='twitter', type=str, help='twitter, restaurant, laptop')
    parser.add_argument('--optimizer', default='adam', type=str)
    parser.add_argument('--initializer', default='xavier_uniform_', type=str)
    parser.add_argument('--learning_rate', default=2e-5, type=float, help='try 5e-5, 2e-5 for BERT, 1e-3 for others')
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--l2reg', default=0.01, type=float)
    parser.add_argument('--num_epoch', default=10, type=int, help='try larger number for non-BERT models')
    parser.add_argument('--batch_size', default=64, type=int, help='try 16, 32, 64 for BERT models')
    parser.add_argument('--log_step', default=10, type=int)
    parser.add_argument('--embed_dim', default=300, type=int)
    parser.add_argument('--hidden_dim', default=300, type=int)
    parser.add_argument('--bert_dim', default=768, type=int)
    parser.add_argument('--pretrained_bert_name', default='bert-base-uncased', type=str)
    parser.add_argument('--max_seq_len', default=80, type=int)
    parser.add_argument('--polarities_dim', default=3, type=int)
    parser.add_argument('--hops', default=3, type=int)
    parser.add_argument('--device', default=None, type=str, help='e.g. cuda:0')
    parser.add_argument('--seed', default=None, type=int, help='set seed for reproducibility')
    parser.add_argument('--cross_val_fold', default=10, type=int, help='k-fold cross validation')
    opt = parser.parse_args()

    if opt.seed is not None:
        random.seed(opt.seed)
        numpy.random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed(opt.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model_classes = {
        'lstm': LSTM,
        'td_lstm': TD_LSTM,
        'atae_lstm': ATAE_LSTM,
        'ian': IAN,
        'memnet': MemNet,
        'ram': RAM,
        'cabasc': Cabasc,
        'tnet_lf': TNet_LF,
        'aoa': AOA,
        'mgan': MGAN,
        'bert_spc': BERT_SPC,
        'aen': AEN,
        'aen_bert': AEN_BERT,
    }
    dataset_files = {
        'twitter': {
            'train': './datasets/acl-14-short-data/train.raw',
            'test': './datasets/acl-14-short-data/test.raw'
        },
        'restaurant': {
            'train': './datasets/semeval14/Restaurants_Train.xml.seg',
            'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'
        },
        'laptop': {
            'train': './datasets/semeval14/Laptops_Train.xml.seg',
            'test': './datasets/semeval14/Laptops_Test_Gold.xml.seg'
        }
    }
    input_colses = {
        'lstm': ['text_raw_indices'],
        'td_lstm': ['text_left_with_aspect_indices', 'text_right_with_aspect_indices'],
        'atae_lstm': ['text_raw_indices', 'aspect_indices'],
        'ian': ['text_raw_indices', 'aspect_indices'],
        'memnet': ['text_raw_without_aspect_indices', 'aspect_indices'],
        'ram': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
        'cabasc': ['text_raw_indices', 'aspect_indices', 'text_left_with_aspect_indices', 'text_right_with_aspect_indices'],
        'tnet_lf': ['text_raw_indices', 'aspect_indices', 'aspect_in_text'],
        'aoa': ['text_raw_indices', 'aspect_indices'],
        'mgan': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
        'bert_spc': ['text_bert_indices', 'bert_segments_ids'],
        'aen': ['text_raw_indices', 'aspect_indices'],
        'aen_bert': ['text_raw_bert_indices', 'aspect_bert_indices'],
    }
    initializers = {
        'xavier_uniform_': torch.nn.init.xavier_uniform_,
        'xavier_normal_': torch.nn.init.xavier_normal_,
        'orthogonal_': torch.nn.init.orthogonal_,
    }
    optimizers = {
        'adadelta': torch.optim.Adadelta,  # default lr=1.0
        'adagrad': torch.optim.Adagrad,  # default lr=0.01
        'adam': torch.optim.Adam,  # default lr=0.001
        'adamax': torch.optim.Adamax,  # default lr=0.002
        'asgd': torch.optim.ASGD,  # default lr=0.01
        'rmsprop': torch.optim.RMSprop,  # default lr=0.01
        'sgd': torch.optim.SGD,
    }
    opt.model_class = model_classes[opt.model_name]
    opt.dataset_file = dataset_files[opt.dataset]
    opt.inputs_cols = input_colses[opt.model_name]
    opt.initializer = initializers[opt.initializer]
    opt.optimizer = optimizers[opt.optimizer]
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \
        if opt.device is None else torch.device(opt.device)

    log_file = '{}-{}-{}.log'.format(opt.model_name, opt.dataset, strftime("%y%m%d-%H%M", localtime()))
    logger.addHandler(logging.FileHandler(log_file))

    ins = Instructor(opt)
    ins.run()


if __name__ == '__main__':
    main()
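
As a usage sketch (assuming the datasets and dependencies of the parent ABSA-PyTorch repository are in place), the new script can be invoked along these lines:

python train_k_fold_cross_val.py --model_name bert_spc --dataset restaurant --cross_val_fold 10 --device cuda:0 --seed 1234

This trains bert_spc on the SemEval-2014 restaurant data with 10-fold cross-validation and logs the mean test accuracy and macro-F1 across the folds.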