
Commit 099e62b: minor changes
- change position of dataloader
- fix computation of training loss & acc
- simple README for cross validation
songyouwei committed May 30, 2019
1 parent f59cfce commit 099e62b
Showing 4 changed files with 33 additions and 32 deletions.
README.md: 2 additions & 0 deletions
@@ -29,6 +29,8 @@ python train.py --model_name bert_spc --dataset restaurant

 See [train.py](./train.py) for more training arguments.

+Refer to [train_k_fold_cross_val.py](./train_k_fold_cross_val.py) for k-fold cross validation support.
+
 ### Inference

 Please refer to [infer_example.py](./infer_example.py).
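A plausible way to run the referenced script, mirroring the train.py example in the hunk header above (the exact flags, including any option for choosing the number of folds, are an assumption; check the script's argparse section):

    python train_k_fold_cross_val.py --model_name bert_spc --dataset restaurant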
models/aen.py: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def forward(self, pre, label, size_average=True):
             return torch.sum(loss)


-class AEN(nn.Module):
+class AEN_GloVe(nn.Module):
     def __init__(self, embedding_matrix, opt):
         super(AEN, self).__init__()
         self.opt = opt
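One caveat worth flagging in this hunk: the unchanged context line still reads super(AEN, self).__init__(). With the class renamed, the name AEN no longer exists in the module, so that call would raise a NameError the first time the model is instantiated. A minimal sketch of how the constructor would need to read (hypothetical; the layers are elided here just as they are in the hunk):

    import torch.nn as nn

    class AEN_GloVe(nn.Module):
        def __init__(self, embedding_matrix, opt):
            super(AEN_GloVe, self).__init__()  # must name the renamed class, or use bare super() on Python 3
            self.opt = opt
            # ... embedding, attention, and projection layers elided ...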
train.py: 23 additions & 23 deletions
@@ -20,7 +20,7 @@
 from data_utils import build_tokenizer, build_embedding_matrix, Tokenizer4Bert, ABSADataset

 from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA, MGAN
-from models.aen import CrossEntropyLoss_LSR, AEN, AEN_BERT
+from models.aen import CrossEntropyLoss_LSR, AEN_BERT
 from models.bert_spc import BERT_SPC

 logger = logging.getLogger()
@@ -47,17 +47,14 @@ def __init__(self, opt):
                 dat_fname='{0}_{1}_embedding_matrix.dat'.format(str(opt.embed_dim), opt.dataset))
             self.model = opt.model_class(embedding_matrix, opt).to(opt.device)

-        trainset = ABSADataset(opt.dataset_file['train'], tokenizer)
-        testset = ABSADataset(opt.dataset_file['test'], tokenizer)
+        self.trainset = ABSADataset(opt.dataset_file['train'], tokenizer)
+        self.testset = ABSADataset(opt.dataset_file['test'], tokenizer)
         assert 0 <= opt.valset_ratio < 1
         if opt.valset_ratio > 0:
-            valset_len = int(len(trainset) * opt.valset_ratio)
-            trainset, valset = random_split(trainset, (len(trainset)-valset_len, valset_len))
+            valset_len = int(len(self.trainset) * opt.valset_ratio)
+            self.trainset, self.valset = random_split(self.trainset, (len(self.trainset)-valset_len, valset_len))
         else:
-            valset = testset
-        self.train_data_loader = DataLoader(dataset=trainset, batch_size=opt.batch_size, shuffle=True)
-        self.test_data_loader = DataLoader(dataset=testset, batch_size=opt.batch_size, shuffle=False)
-        self.val_data_loader = DataLoader(dataset=valset, batch_size=opt.batch_size, shuffle=False)
+            self.valset = self.testset

         if opt.device.type == 'cuda':
             logger.info('cuda memory allocated: {}'.format(torch.cuda.memory_allocated(device=opt.device.index)))
@@ -87,18 +84,18 @@ def _reset_params(self):
                         stdv = 1. / math.sqrt(p.shape[0])
                         torch.nn.init.uniform_(p, a=-stdv, b=stdv)

-    def _train(self, criterion, optimizer):
+    def _train(self, criterion, optimizer, train_data_loader, val_data_loader):
         max_val_acc = 0
         max_val_f1 = 0
         global_step = 0
         path = None
         for epoch in range(self.opt.num_epoch):
             logger.info('>' * 100)
             logger.info('epoch: {}'.format(epoch))
-            n_correct, n_total = 0, 0
+            n_correct, n_total, loss_total = 0, 0, 0
             # switch model to training mode
             self.model.train()
-            for i_batch, sample_batched in enumerate(self.train_data_loader):
+            for i_batch, sample_batched in enumerate(train_data_loader):
                 global_step += 1
                 # clear gradient accumulators
                 optimizer.zero_grad()
@@ -111,16 +108,15 @@ def _train(self, criterion, optimizer):
                 loss.backward()
                 optimizer.step()

+                n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
+                n_total += len(outputs)
+                loss_total += loss.item() * len(outputs)
                 if global_step % self.opt.log_step == 0:
-                    n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
-                    n_total += len(outputs)
                     train_acc = n_correct / n_total
-
-                    logger.info('loss: {:.4f}, acc: {:.4f}'.format(loss.item(), train_acc))
+                    train_loss = loss_total / n_total
+                    logger.info('loss: {:.4f}, acc: {:.4f}'.format(train_loss, train_acc))

-            # switch model to evaluation mode
-            self.model.eval()
-            val_acc, val_f1 = self._evaluate_acc_f1(self.val_data_loader)
+            val_acc, val_f1 = self._evaluate_acc_f1(val_data_loader)
             logger.info('> val_acc: {:.4f}, val_f1: {:.4f}'.format(val_acc, val_f1))
             if val_acc > max_val_acc:
                 max_val_acc = val_acc
@@ -137,6 +133,8 @@
     def _evaluate_acc_f1(self, data_loader):
         n_correct, n_total = 0, 0
         t_targets_all, t_outputs_all = None, None
+        # switch model to evaluation mode
+        self.model.eval()
         with torch.no_grad():
             for t_batch, t_sample_batched in enumerate(data_loader):
                 t_inputs = [t_sample_batched[col].to(self.opt.device) for col in self.opt.inputs_cols]
@@ -163,11 +161,15 @@ def run(self):
         _params = filter(lambda p: p.requires_grad, self.model.parameters())
         optimizer = self.opt.optimizer(_params, lr=self.opt.learning_rate, weight_decay=self.opt.l2reg)

+        train_data_loader = DataLoader(dataset=self.trainset, batch_size=self.opt.batch_size, shuffle=True)
+        test_data_loader = DataLoader(dataset=self.testset, batch_size=self.opt.batch_size, shuffle=False)
+        val_data_loader = DataLoader(dataset=self.valset, batch_size=self.opt.batch_size, shuffle=False)
+
         self._reset_params()
-        best_model_path = self._train(criterion, optimizer)
+        best_model_path = self._train(criterion, optimizer, train_data_loader, val_data_loader)
         self.model.load_state_dict(torch.load(best_model_path))
         self.model.eval()
-        test_acc, test_f1 = self._evaluate_acc_f1(self.test_data_loader)
+        test_acc, test_f1 = self._evaluate_acc_f1(test_data_loader)
         logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(test_acc, test_f1))


@@ -216,7 +218,6 @@ def main():
         'aoa': AOA,
         'mgan': MGAN,
         'bert_spc': BERT_SPC,
-        'aen': AEN,
         'aen_bert': AEN_BERT,
     }
     dataset_files = {
@@ -245,7 +246,6 @@ def main():
         'aoa': ['text_raw_indices', 'aspect_indices'],
         'mgan': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
         'bert_spc': ['text_bert_indices', 'bert_segments_ids'],
-        'aen': ['text_raw_indices', 'aspect_indices'],
         'aen_bert': ['text_raw_bert_indices', 'aspect_bert_indices'],
     }
     initializers = {
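Read together, the train.py changes do two things. First, DataLoader construction moves out of __init__ and into run(), with _train and _evaluate_acc_f1 now taking loaders as parameters; keeping the raw datasets on self and building loaders late is what lets k-fold code rebuild fresh loaders for every fold. Relatedly, self.model.eval() moves into _evaluate_acc_f1, so every evaluation path flips the mode itself instead of trusting the caller. Second, the metric fix: previously n_correct and n_total were only updated on logging steps, so the reported accuracy sampled a fraction of batches, and the logged loss was the last batch's loss.item() rather than an epoch average. A condensed, standalone sketch of the corrected bookkeeping (generic names and a generic (inputs, targets) loader; not the repository's exact code):

    import torch

    def train_one_epoch(model, criterion, optimizer, loader, device):
        # accumulate on every batch so the reported numbers are epoch-to-date averages
        model.train()
        n_correct, n_total, loss_total = 0, 0, 0.0
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
            n_total += len(outputs)
            # weight each batch's loss by its size so a short final batch doesn't skew the mean
            loss_total += loss.item() * len(outputs)
        return loss_total / n_total, n_correct / n_total  # mean loss, accuracy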
train_k_fold_cross_val.py: 7 additions & 8 deletions
@@ -21,7 +21,7 @@
 from data_utils import build_tokenizer, build_embedding_matrix, Tokenizer4Bert, ABSADataset

 from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA, MGAN
-from models.aen import CrossEntropyLoss_LSR, AEN, AEN_BERT
+from models.aen import CrossEntropyLoss_LSR, AEN_BERT
 from models.bert_spc import BERT_SPC

 logger = logging.getLogger()
@@ -89,7 +89,7 @@ def _train(self, criterion, optimizer, train_data_loader, val_data_loader):
         path = None
         for epoch in range(self.opt.num_epoch):
             logger.info('epoch: {}'.format(epoch))
-            n_correct, n_total = 0, 0
+            n_correct, n_total, loss_total = 0, 0, 0
             # switch model to training mode
             self.model.train()
             for i_batch, sample_batched in enumerate(train_data_loader):
@@ -105,12 +105,13 @@ def _train(self, criterion, optimizer, train_data_loader, val_data_loader):
                 loss.backward()
                 optimizer.step()

+                n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
+                n_total += len(outputs)
+                loss_total += loss.item() * len(outputs)
                 if global_step % self.opt.log_step == 0:
-                    n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
-                    n_total += len(outputs)
                     train_acc = n_correct / n_total
-
-                    logger.info('loss: {:.4f}, acc: {:.4f}'.format(loss.item(), train_acc))
+                    train_loss = loss_total / n_total
+                    logger.info('loss: {:.4f}, acc: {:.4f}'.format(train_loss, train_acc))

             val_acc, val_f1 = self._evaluate_acc_f1(val_data_loader)
             logger.info('> val_acc: {:.4f}, val_f1: {:.4f}'.format(val_acc, val_f1))
@@ -229,7 +230,6 @@ def main():
         'aoa': AOA,
         'mgan': MGAN,
         'bert_spc': BERT_SPC,
-        'aen': AEN,
         'aen_bert': AEN_BERT,
     }
     dataset_files = {
@@ -258,7 +258,6 @@ def main():
         'aoa': ['text_raw_indices', 'aspect_indices'],
         'mgan': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
         'bert_spc': ['text_bert_indices', 'bert_segments_ids'],
-        'aen': ['text_raw_indices', 'aspect_indices'],
         'aen_bert': ['text_raw_bert_indices', 'aspect_bert_indices'],
     }
     initializers = {
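train_k_fold_cross_val.py already passed loaders into _train, so it only needs the same running-average fix here. For orientation, a generic sketch of how k folds can be cut from one dataset with plain torch utilities, which is the shape of setup the _train signature implies (an illustration under that assumption, not the script's verbatim logic):

    from torch.utils.data import ConcatDataset, DataLoader, random_split

    def k_fold_loaders(dataset, k, batch_size):
        # split into k near-equal folds, then rotate which fold is held out
        fold_len = len(dataset) // k
        lengths = [fold_len] * (k - 1) + [len(dataset) - fold_len * (k - 1)]
        folds = random_split(dataset, lengths)
        for i in range(k):
            train_part = ConcatDataset([fold for j, fold in enumerate(folds) if j != i])
            yield (DataLoader(train_part, batch_size=batch_size, shuffle=True),
                   DataLoader(folds[i], batch_size=batch_size, shuffle=False))

Each (train_loader, val_loader) pair can then be handed to _train, matching the calling convention shown in the hunks above.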
