Commit ee10e21
add a simple inference example
songyouwei committed Jan 8, 2019
1 parent f84a215 commit ee10e21
Showing 4 changed files with 82 additions and 4 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -30,6 +30,10 @@ python train.py --model_name ian --dataset twitter --logdir ian_logs
 tensorboard --logdir=./ian_logs
 ```
 
+### Inference
+
+Please refer to [infer_example.py](./infer_example.py).
+
 ## Implemented models
 
 ### AOA ([aoa.py](./models/aoa.py))
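With the defaults wired into the new script's argument parser, the inference example can be run directly from the repository root. A minimal invocation might look like the following; the checkpoint path is simply the script's default and assumes you have already trained and saved a matching IAN model on the restaurant dataset (the modified train.py in this commit writes such checkpoints under state_dict/):

```sh
python infer_example.py --model_name ian --dataset restaurant --state_dict_path state_dict/ian_restaurant_acc0.7911
```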
1 change: 1 addition & 0 deletions data_utils.py
@@ -172,6 +172,7 @@ def __init__(self, dataset='twitter', embed_dim=100, max_seq_len=40):
         text = ABSADatesetReader.__read_text__([fname[dataset]['train'], fname[dataset]['test']])
         tokenizer = Tokenizer(max_seq_len=max_seq_len)
         tokenizer.fit_on_text(text.lower())
+        self.tokenizer = tokenizer
         self.embedding_matrix = build_embedding_matrix(tokenizer.word2idx, embed_dim, dataset)
         self.train_data = ABSADataset(ABSADatesetReader.__read_data__(fname[dataset]['train'], tokenizer))
         self.test_data = ABSADataset(ABSADatesetReader.__read_data__(fname[dataset]['test'], tokenizer))
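This one-line change is what enables the new script: by keeping a reference to the fitted tokenizer on the reader, inference code can reuse the exact vocabulary built at training time instead of re-fitting one. A minimal sketch of the idea (the constructor arguments are just the commit's defaults):

```python
# a sketch: reuse the training-time vocabulary to index a new sentence
from data_utils import ABSADatesetReader

reader = ABSADatesetReader(dataset='restaurant', embed_dim=300, max_seq_len=80)
seq = reader.tokenizer.text_to_sequence('the service is terrible')  # padded sequence of word indices
```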
66 changes: 66 additions & 0 deletions infer_example.py
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# file: infer_example.py
# author: songyouwei <[email protected]>
# Copyright (C) 2019. All Rights Reserved.

from data_utils import ABSADatesetReader
import torch
import torch.nn.functional as F
import argparse

from models import IAN, MemNet, TD_LSTM, ATAE_LSTM, AOA


class Inferer:
    """A simple inference example"""
    def __init__(self, opt):
        self.opt = opt
        absa_dataset = ABSADatesetReader(dataset=opt.dataset, embed_dim=opt.embed_dim, max_seq_len=opt.max_seq_len)
        self.tokenizer = absa_dataset.tokenizer

        self.model = opt.model_class(absa_dataset.embedding_matrix, opt)
        self.model.load_state_dict(torch.load(opt.state_dict_path))
        self.model = self.model.to(opt.device)

    def evaluate(self, raw_texts):
        context_seqs = [self.tokenizer.text_to_sequence(raw_text.lower().strip()) for raw_text in raw_texts]
        # every model expects an aspect sequence, so feed a 'null' placeholder aspect
        aspect_seqs = [self.tokenizer.text_to_sequence('null')] * len(raw_texts)
        context_indices = torch.tensor(context_seqs, dtype=torch.int64).to(self.opt.device)
        aspect_indices = torch.tensor(aspect_seqs, dtype=torch.int64).to(self.opt.device)
        # switch model to evaluation mode
        self.model.eval()
        with torch.no_grad():
            t_inputs = [context_indices, aspect_indices]
            t_outputs = self.model(t_inputs)
            t_probs = F.softmax(t_outputs, dim=-1).cpu().numpy()
        return t_probs


if __name__ == '__main__':
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', default='ian', type=str)
    parser.add_argument('--state_dict_path', default='state_dict/ian_restaurant_acc0.7911', type=str)
    parser.add_argument('--dataset', default='restaurant', type=str, help='twitter, restaurant, laptop')
    parser.add_argument('--embed_dim', default=300, type=int)
    parser.add_argument('--hidden_dim', default=300, type=int)
    parser.add_argument('--max_seq_len', default=80, type=int)
    parser.add_argument('--polarities_dim', default=3, type=int)
    parser.add_argument('--device', default=None, type=str)
    opt = parser.parse_args()

    model_classes = {
        'td_lstm': TD_LSTM,
        'atae_lstm': ATAE_LSTM,
        'ian': IAN,
        'memnet': MemNet,
        'aoa': AOA,
    }
    opt.model_class = model_classes[opt.model_name]
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \
        if opt.device is None else torch.device(opt.device)

    inf = Inferer(opt)
    t_probs = inf.evaluate(['happy memory', 'the service is terrible', 'just normal food'])
    # argmax gives class indices 0/1/2; subtracting 1 maps them to polarity labels -1/0/1
    print(t_probs.argmax(axis=-1) - 1)
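The final print deserves a note: evaluate returns one softmax distribution per input sentence, and argmax(axis=-1) - 1 shifts the class indices 0/1/2 down to polarity labels -1/0/1. Assuming the usual encoding in this repository (0 = negative, 1 = neutral, 2 = positive), a hypothetical way to render the predictions as readable labels:

```python
# hypothetical post-processing of Inferer.evaluate output;
# assumes polarity encoding -1 = negative, 0 = neutral, 1 = positive
texts = ['happy memory', 'the service is terrible', 'just normal food']
label_names = {-1: 'negative', 0: 'neutral', 1: 'positive'}
t_probs = inf.evaluate(texts)
for text, pred in zip(texts, t_probs.argmax(axis=-1) - 1):
    print('{} -> {}'.format(text, label_names[int(pred)]))
```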
15 changes: 11 additions & 4 deletions train.py
@@ -11,6 +11,7 @@
 from tensorboardX import SummaryWriter
 import argparse
 import math
+import os
 
 from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA

@@ -50,7 +51,7 @@ def _reset_params(self):
                     stdv = 1. / math.sqrt(p.shape[0])
                     torch.nn.init.uniform_(p, a=-stdv, b=stdv)
 
-    def _train(self, criterion, optimizer):
+    def _train(self, criterion, optimizer, max_test_acc_overall=0):
         writer = SummaryWriter(log_dir=self.opt.logdir)
         max_test_acc = 0
         max_f1 = 0
@@ -82,13 +83,19 @@ def _train(self, criterion, optimizer):
                     test_acc, f1 = self._evaluate_acc_f1()
                     if test_acc > max_test_acc:
                         max_test_acc = test_acc
+                        if test_acc > max_test_acc_overall:
+                            if not os.path.exists('state_dict'):
+                                os.mkdir('state_dict')
+                            path = 'state_dict/{0}_{1}_acc{2}'.format(self.opt.model_name, self.opt.dataset, round(test_acc, 4))
+                            torch.save(self.model.state_dict(), path)
+                            print('>> saved: ' + path)
                     if f1 > max_f1:
                         max_f1 = f1
 
                     writer.add_scalar('loss', loss, global_step)
                     writer.add_scalar('acc', train_acc, global_step)
                     writer.add_scalar('test_acc', test_acc, global_step)
-                    print('loss: {:.4f}, acc: {:.4f}, test_acc: {:.4f}'.format(loss.item(), train_acc, test_acc))
+                    print('loss: {:.4f}, acc: {:.4f}, test_acc: {:.4f}, f1: {:.4f}'.format(loss.item(), train_acc, test_acc, f1))
 
         writer.close()
         return max_test_acc, max_f1
@@ -129,7 +136,7 @@ def run(self, repeats=1):
         for i in range(repeats):
             print('repeat: ', i)
             self._reset_params()
-            max_test_acc, max_f1 = self._train(criterion, optimizer)
+            max_test_acc, max_f1 = self._train(criterion, optimizer, max_test_acc_overall=max_test_acc_overall)
             print('max_test_acc: {0} max_f1: {1}'.format(max_test_acc, max_f1))
             max_test_acc_overall = max(max_test_acc, max_test_acc_overall)
             max_f1_overall = max(max_f1, max_f1_overall)
@@ -204,4 +211,4 @@
         if opt.device is None else torch.device(opt.device)
 
     ins = Instructor(opt)
-    ins.run()
+    ins.run(5)
