diff --git a/README.md b/README.md
index 8541aad..539b0ac 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,10 @@ python train.py --model_name ian --dataset twitter --logdir ian_logs
 tensorboard --logdir=./ian_logs
 ```
 
+### Inference
+
+Please refer to [infer_example.py](./infer_example.py).
+
 ## Implemented models
 
 ### AOA ([aoa.py](./models/aoa.py))
diff --git a/data_utils.py b/data_utils.py
index e546fd6..2c749bb 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -172,6 +172,7 @@ def __init__(self, dataset='twitter', embed_dim=100, max_seq_len=40):
         text = ABSADatesetReader.__read_text__([fname[dataset]['train'], fname[dataset]['test']])
         tokenizer = Tokenizer(max_seq_len=max_seq_len)
         tokenizer.fit_on_text(text.lower())
+        self.tokenizer = tokenizer
         self.embedding_matrix = build_embedding_matrix(tokenizer.word2idx, embed_dim, dataset)
         self.train_data = ABSADataset(ABSADatesetReader.__read_data__(fname[dataset]['train'], tokenizer))
         self.test_data = ABSADataset(ABSADatesetReader.__read_data__(fname[dataset]['test'], tokenizer))
diff --git a/infer_example.py b/infer_example.py
new file mode 100644
index 0000000..a9351f5
--- /dev/null
+++ b/infer_example.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# file: infer_example.py
+# author: songyouwei
+# Copyright (C) 2019. All Rights Reserved.
+
+from data_utils import ABSADatesetReader
+import torch
+import torch.nn.functional as F
+import argparse
+
+from models import IAN, MemNet, TD_LSTM, ATAE_LSTM, AOA
+
+
+class Inferer:
+    """A simple inference example"""
+    def __init__(self, opt):
+        self.opt = opt
+        absa_dataset = ABSADatesetReader(dataset=opt.dataset, embed_dim=opt.embed_dim, max_seq_len=opt.max_seq_len)
+        self.tokenizer = absa_dataset.tokenizer
+
+        self.model = opt.model_class(absa_dataset.embedding_matrix, opt)
+        self.model.load_state_dict(torch.load(opt.state_dict_path))
+        self.model = self.model.to(opt.device)
+
+    def evaluate(self, raw_texts):
+        context_seqs = [self.tokenizer.text_to_sequence(raw_text.lower().strip()) for raw_text in raw_texts]
+        aspect_seqs = [self.tokenizer.text_to_sequence('null')] * len(raw_texts)
+        context_indices = torch.tensor(context_seqs, dtype=torch.int64).to(self.opt.device)
+        aspect_indices = torch.tensor(aspect_seqs, dtype=torch.int64).to(self.opt.device)
+        # switch model to evaluation mode
+        self.model.eval()
+        with torch.no_grad():
+            t_inputs = [context_indices, aspect_indices]
+            t_outputs = self.model(t_inputs)
+
+            t_probs = F.softmax(t_outputs, dim=-1).cpu().numpy()
+        return t_probs
+
+
+if __name__ == '__main__':
+    # Hyper Parameters
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_name', default='ian', type=str)
+    parser.add_argument('--state_dict_path', default='state_dict/ian_restaurant_acc0.7911', type=str)
+    parser.add_argument('--dataset', default='restaurant', type=str, help='twitter, restaurant, laptop')
+    parser.add_argument('--embed_dim', default=300, type=int)
+    parser.add_argument('--hidden_dim', default=300, type=int)
+    parser.add_argument('--max_seq_len', default=80, type=int)
+    parser.add_argument('--polarities_dim', default=3, type=int)
+    parser.add_argument('--device', default=None, type=str)
+    opt = parser.parse_args()
+
+    model_classes = {
+        'td_lstm': TD_LSTM,
+        'atae_lstm': ATAE_LSTM,
+        'ian': IAN,
+        'memnet': MemNet,
+        'aoa': AOA,
+    }
+    opt.model_class = model_classes[opt.model_name]
+    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \
+        if opt.device is None else torch.device(opt.device)
+
+    inf = Inferer(opt)
+    t_probs = inf.evaluate(['happy memory', 'the service is terrible', 'just normal food'])
+    print(t_probs.argmax(axis=-1) - 1)
diff --git a/train.py b/train.py
index 8416cbd..816b4bf 100644
--- a/train.py
+++ b/train.py
@@ -11,6 +11,7 @@
 from tensorboardX import SummaryWriter
 import argparse
 import math
+import os
 
 from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA
 
@@ -50,7 +51,7 @@ def _reset_params(self):
                     stdv = 1. / math.sqrt(p.shape[0])
                     torch.nn.init.uniform_(p, a=-stdv, b=stdv)
 
-    def _train(self, criterion, optimizer):
+    def _train(self, criterion, optimizer, max_test_acc_overall=0):
         writer = SummaryWriter(log_dir=self.opt.logdir)
         max_test_acc = 0
         max_f1 = 0
@@ -82,13 +83,19 @@ def _train(self, criterion, optimizer):
                     test_acc, f1 = self._evaluate_acc_f1()
                     if test_acc > max_test_acc:
                         max_test_acc = test_acc
+                        if test_acc > max_test_acc_overall:
+                            if not os.path.exists('state_dict'):
+                                os.mkdir('state_dict')
+                            path = 'state_dict/{0}_{1}_acc{2}'.format(self.opt.model_name, self.opt.dataset, round(test_acc, 4))
+                            torch.save(self.model.state_dict(), path)
+                            print('>> saved: ' + path)
                     if f1 > max_f1:
                         max_f1 = f1
 
                     writer.add_scalar('loss', loss, global_step)
                     writer.add_scalar('acc', train_acc, global_step)
                     writer.add_scalar('test_acc', test_acc, global_step)
-                    print('loss: {:.4f}, acc: {:.4f}, test_acc: {:.4f}'.format(loss.item(), train_acc, test_acc))
+                    print('loss: {:.4f}, acc: {:.4f}, test_acc: {:.4f}, f1: {:.4f}'.format(loss.item(), train_acc, test_acc, f1))
 
         writer.close()
         return max_test_acc, max_f1
@@ -129,7 +136,7 @@ def run(self, repeats=1):
         for i in range(repeats):
             print('repeat: ', i)
             self._reset_params()
-            max_test_acc, max_f1 = self._train(criterion, optimizer)
+            max_test_acc, max_f1 = self._train(criterion, optimizer, max_test_acc_overall=max_test_acc_overall)
             print('max_test_acc: {0} max_f1: {1}'.format(max_test_acc, max_f1))
             max_test_acc_overall = max(max_test_acc, max_test_acc_overall)
             max_f1_overall = max(max_f1, max_f1_overall)
@@ -204,4 +211,4 @@
         if opt.device is None else torch.device(opt.device)
 
     ins = Instructor(opt)
-    ins.run()
+    ins.run(5)
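
Usage note (editor's sketch, not part of the patch above): the snippet below drives the new Inferer class directly from Python rather than through the infer_example.py command line. It assumes the repo's dataset and embedding files are in place and that a checkpoint trained by train.py exists at the default path; the Namespace fields and the example review sentences are illustrative and simply mirror the argparse defaults in infer_example.py.

    # Minimal sketch: programmatic use of Inferer (assumes a trained checkpoint).
    from argparse import Namespace

    import torch

    from infer_example import Inferer
    from models import IAN

    opt = Namespace(
        model_name='ian',
        model_class=IAN,                                        # the class itself, not a string
        state_dict_path='state_dict/ian_restaurant_acc0.7911',  # written by train.py
        dataset='restaurant',
        embed_dim=300,
        hidden_dim=300,
        max_seq_len=80,
        polarities_dim=3,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    )

    inf = Inferer(opt)
    t_probs = inf.evaluate(['the staff was friendly', 'the food was bland'])
    # argmax yields class indices {0, 1, 2}; subtracting 1 maps them back to
    # the dataset's polarity labels {-1: negative, 0: neutral, 1: positive}
    print(t_probs.argmax(axis=-1) - 1)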
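Design note on the train.py changes: a checkpoint is saved only when the current test accuracy beats max_test_acc_overall, the best result carried across the ins.run(5) repeats, so state_dict/ accumulates only successive record-setting snapshots instead of one file per evaluation step. The file name encodes the model, dataset, and rounded accuracy, which is why the default state_dict_path in infer_example.py looks like state_dict/ian_restaurant_acc0.7911.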