add k-fold validation version of train file
songyouwei committed May 18, 2019
1 parent 9ef6cd5 commit f59cfce
Showing 1 changed file with 294 additions and 0 deletions.
294 changes: 294 additions & 0 deletions train_k_fold_cross_val.py
@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-
# file: train_k_fold_cross_val.py
# author: songyouwei <[email protected]>
# Copyright (C) 2019. All Rights Reserved.
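#
# k-fold cross-validation variant of the training script: the training set is
# split into `cross_val_fold` folds; for each fold the model is re-initialized,
# trained on the remaining folds, validated on the held-out fold, and its best
# checkpoint is evaluated on the test set. Mean test accuracy and macro-F1 over
# all folds are reported at the end.
# Example run (argument defaults are defined in main()):
#   python train_k_fold_cross_val.py --model_name bert_spc --dataset restaurant --cross_val_fold 10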

import copy
import logging
import argparse
import math
import os
import sys
from time import strftime, localtime
import random
import numpy

from pytorch_pretrained_bert import BertModel
from sklearn import metrics
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, ConcatDataset

from data_utils import build_tokenizer, build_embedding_matrix, Tokenizer4Bert, ABSADataset

from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA, MGAN
from models.aen import CrossEntropyLoss_LSR, AEN, AEN_BERT
from models.bert_spc import BERT_SPC

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))


class Instructor:
def __init__(self, opt):
self.opt = opt

if 'bert' in opt.model_name:
tokenizer = Tokenizer4Bert(opt.max_seq_len, opt.pretrained_bert_name)
bert = BertModel.from_pretrained(opt.pretrained_bert_name)
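# cache a deep copy of the pretrained weights so every fold can restart from the same BERT
# initialization (restored in _reset_params); bert.state_dict() alone returns references
# that would be overwritten during fine-tuning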
self.pretrained_bert_state_dict = copy.deepcopy(bert.state_dict())
self.model = opt.model_class(bert, opt).to(opt.device)
else:
tokenizer = build_tokenizer(
fnames=[opt.dataset_file['train'], opt.dataset_file['test']],
max_seq_len=opt.max_seq_len,
dat_fname='{0}_tokenizer.dat'.format(opt.dataset))
embedding_matrix = build_embedding_matrix(
word2idx=tokenizer.word2idx,
embed_dim=opt.embed_dim,
dat_fname='{0}_{1}_embedding_matrix.dat'.format(str(opt.embed_dim), opt.dataset))
self.model = opt.model_class(embedding_matrix, opt).to(opt.device)

self.trainset = ABSADataset(opt.dataset_file['train'], tokenizer)
self.testset = ABSADataset(opt.dataset_file['test'], tokenizer)
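# the train split is further divided into `cross_val_fold` folds in run(); the test split is shared by all folds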

if opt.device.type == 'cuda':
logger.info('cuda memory allocated: {}'.format(torch.cuda.memory_allocated(device=opt.device.index)))
self._print_args()

def _print_args(self):
n_trainable_params, n_nontrainable_params = 0, 0
for p in self.model.parameters():
n_params = torch.prod(torch.tensor(p.shape)).item()
if p.requires_grad:
n_trainable_params += n_params
else:
n_nontrainable_params += n_params
logger.info('n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
logger.info('> training arguments:')
for arg in vars(self.opt):
logger.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))

def _reset_params(self):
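# re-initialize all non-BERT parameters and restore the cached pretrained BERT weights,
# so that each fold starts from the same fresh model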
for child in self.model.children():
if type(child) != BertModel: # skip bert params
for p in child.parameters():
if p.requires_grad:
if len(p.shape) > 1:
self.opt.initializer(p)
else:
stdv = 1. / math.sqrt(p.shape[0])
torch.nn.init.uniform_(p, a=-stdv, b=stdv)
else:
self.model.bert.load_state_dict(self.pretrained_bert_state_dict)

def _train(self, criterion, optimizer, train_data_loader, val_data_loader):
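# train for num_epoch epochs, validate on the held-out fold after each epoch,
# and return the path of the checkpoint with the best validation accuracy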
max_val_acc = 0
max_val_f1 = 0
global_step = 0
path = None
for epoch in range(self.opt.num_epoch):
logger.info('epoch: {}'.format(epoch))
n_correct, n_total = 0, 0
# switch model to training mode
self.model.train()
for i_batch, sample_batched in enumerate(train_data_loader):
global_step += 1
# clear gradient accumulators
optimizer.zero_grad()

inputs = [sample_batched[col].to(self.opt.device) for col in self.opt.inputs_cols]
outputs = self.model(inputs)
targets = sample_batched['polarity'].to(self.opt.device)

loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

if global_step % self.opt.log_step == 0:
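# running training accuracy is accumulated only on logging steps, so train_acc reflects the logged batches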
n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
n_total += len(outputs)
train_acc = n_correct / n_total

logger.info('loss: {:.4f}, acc: {:.4f}'.format(loss.item(), train_acc))

val_acc, val_f1 = self._evaluate_acc_f1(val_data_loader)
logger.info('> val_acc: {:.4f}, val_f1: {:.4f}'.format(val_acc, val_f1))
if val_acc > max_val_acc:
max_val_acc = val_acc
if not os.path.exists('state_dict'):
os.mkdir('state_dict')
path = 'state_dict/{0}_{1}_val_temp'.format(self.opt.model_name, self.opt.dataset)
torch.save(self.model.state_dict(), path)
logger.info('>> saved: {}'.format(path))
if val_f1 > max_val_f1:
max_val_f1 = val_f1

return path

def _evaluate_acc_f1(self, data_loader):
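# compute accuracy and macro-F1 (over the three polarity classes) on the given data loader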
n_correct, n_total = 0, 0
t_targets_all, t_outputs_all = None, None
# switch model to evaluation mode
self.model.eval()
with torch.no_grad():
for t_batch, t_sample_batched in enumerate(data_loader):
t_inputs = [t_sample_batched[col].to(self.opt.device) for col in self.opt.inputs_cols]
t_targets = t_sample_batched['polarity'].to(self.opt.device)
t_outputs = self.model(t_inputs)

n_correct += (torch.argmax(t_outputs, -1) == t_targets).sum().item()
n_total += len(t_outputs)

if t_targets_all is None:
t_targets_all = t_targets
t_outputs_all = t_outputs
else:
t_targets_all = torch.cat((t_targets_all, t_targets), dim=0)
t_outputs_all = torch.cat((t_outputs_all, t_outputs), dim=0)

acc = n_correct / n_total
f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2], average='macro')
return acc, f1

def run(self):
# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
_params = filter(lambda p: p.requires_grad, self.model.parameters())
optimizer = self.opt.optimizer(_params, lr=self.opt.learning_rate, weight_decay=self.opt.l2reg)
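# note: this optimizer instance (and its internal state) is shared across all folds;
# only the model parameters are re-initialized per fold in _reset_params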

test_data_loader = DataLoader(dataset=self.testset, batch_size=self.opt.batch_size, shuffle=False)
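# randomly split the training set into `cross_val_fold` subsets of (nearly) equal size;
# the last subset absorbs the remainder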
valset_len = len(self.trainset) // self.opt.cross_val_fold
split_datasets = random_split(self.trainset, tuple([valset_len] * (self.opt.cross_val_fold - 1) + [len(self.trainset) - valset_len * (self.opt.cross_val_fold - 1)]))

all_test_acc, all_test_f1 = [], []
for fid in range(self.opt.cross_val_fold):
logger.info('fold : {}'.format(fid))
logger.info('>' * 100)
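# hold out fold `fid` for validation and concatenate the remaining folds for training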
trainset = ConcatDataset([x for i, x in enumerate(split_datasets) if i != fid])
valset = split_datasets[fid]
train_data_loader = DataLoader(dataset=trainset, batch_size=self.opt.batch_size, shuffle=True)
val_data_loader = DataLoader(dataset=valset, batch_size=self.opt.batch_size, shuffle=False)

self._reset_params()
best_model_path = self._train(criterion, optimizer, train_data_loader, val_data_loader)

self.model.load_state_dict(torch.load(best_model_path))
test_acc, test_f1 = self._evaluate_acc_f1(test_data_loader)
all_test_acc.append(test_acc)
all_test_f1.append(test_f1)
logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(test_acc, test_f1))

mean_test_acc, mean_test_f1 = numpy.mean(all_test_acc), numpy.mean(all_test_f1)
logger.info('>' * 100)
logger.info('>>> mean_test_acc: {:.4f}, mean_test_f1: {:.4f}'.format(mean_test_acc, mean_test_f1))


def main():
# Hyper Parameters
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default='bert_spc', type=str)
parser.add_argument('--dataset', default='twitter', type=str, help='twitter, restaurant, laptop')
parser.add_argument('--optimizer', default='adam', type=str)
parser.add_argument('--initializer', default='xavier_uniform_', type=str)
parser.add_argument('--learning_rate', default=2e-5, type=float, help='try 5e-5, 2e-5 for BERT, 1e-3 for others')
parser.add_argument('--dropout', default=0.1, type=float)
parser.add_argument('--l2reg', default=0.01, type=float)
parser.add_argument('--num_epoch', default=10, type=int, help='try larger number for non-BERT models')
parser.add_argument('--batch_size', default=64, type=int, help='try 16, 32, 64 for BERT models')
parser.add_argument('--log_step', default=10, type=int)
parser.add_argument('--embed_dim', default=300, type=int)
parser.add_argument('--hidden_dim', default=300, type=int)
parser.add_argument('--bert_dim', default=768, type=int)
parser.add_argument('--pretrained_bert_name', default='bert-base-uncased', type=str)
parser.add_argument('--max_seq_len', default=80, type=int)
parser.add_argument('--polarities_dim', default=3, type=int)
parser.add_argument('--hops', default=3, type=int)
parser.add_argument('--device', default=None, type=str, help='e.g. cuda:0')
parser.add_argument('--seed', default=None, type=int, help='set seed for reproducibility')
parser.add_argument('--cross_val_fold', default=10, type=int, help='k-fold cross validation')
opt = parser.parse_args()

if opt.seed is not None:
random.seed(opt.seed)
numpy.random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model_classes = {
'lstm': LSTM,
'td_lstm': TD_LSTM,
'atae_lstm': ATAE_LSTM,
'ian': IAN,
'memnet': MemNet,
'ram': RAM,
'cabasc': Cabasc,
'tnet_lf': TNet_LF,
'aoa': AOA,
'mgan': MGAN,
'bert_spc': BERT_SPC,
'aen': AEN,
'aen_bert': AEN_BERT,
}
dataset_files = {
'twitter': {
'train': './datasets/acl-14-short-data/train.raw',
'test': './datasets/acl-14-short-data/test.raw'
},
'restaurant': {
'train': './datasets/semeval14/Restaurants_Train.xml.seg',
'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'
},
'laptop': {
'train': './datasets/semeval14/Laptops_Train.xml.seg',
'test': './datasets/semeval14/Laptops_Test_Gold.xml.seg'
}
}
input_colses = {
'lstm': ['text_raw_indices'],
'td_lstm': ['text_left_with_aspect_indices', 'text_right_with_aspect_indices'],
'atae_lstm': ['text_raw_indices', 'aspect_indices'],
'ian': ['text_raw_indices', 'aspect_indices'],
'memnet': ['text_raw_without_aspect_indices', 'aspect_indices'],
'ram': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
'cabasc': ['text_raw_indices', 'aspect_indices', 'text_left_with_aspect_indices', 'text_right_with_aspect_indices'],
'tnet_lf': ['text_raw_indices', 'aspect_indices', 'aspect_in_text'],
'aoa': ['text_raw_indices', 'aspect_indices'],
'mgan': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
'bert_spc': ['text_bert_indices', 'bert_segments_ids'],
'aen': ['text_raw_indices', 'aspect_indices'],
'aen_bert': ['text_raw_bert_indices', 'aspect_bert_indices'],
}
initializers = {
'xavier_uniform_': torch.nn.init.xavier_uniform_,
'xavier_normal_': torch.nn.init.xavier_normal_,
'orthogonal_': torch.nn.init.orthogonal_,
}
optimizers = {
'adadelta': torch.optim.Adadelta, # default lr=1.0
'adagrad': torch.optim.Adagrad, # default lr=0.01
'adam': torch.optim.Adam, # default lr=0.001
'adamax': torch.optim.Adamax, # default lr=0.002
'asgd': torch.optim.ASGD, # default lr=0.01
'rmsprop': torch.optim.RMSprop, # default lr=0.01
'sgd': torch.optim.SGD,
}
opt.model_class = model_classes[opt.model_name]
opt.dataset_file = dataset_files[opt.dataset]
opt.inputs_cols = input_colses[opt.model_name]
opt.initializer = initializers[opt.initializer]
opt.optimizer = optimizers[opt.optimizer]
opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \
if opt.device is None else torch.device(opt.device)

log_file = '{}-{}-{}.log'.format(opt.model_name, opt.dataset, strftime("%y%m%d-%H%M", localtime()))
logger.addHandler(logging.FileHandler(log_file))

ins = Instructor(opt)
ins.run()


if __name__ == '__main__':
main()
