forked from songyouwei/ABSA-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer_example.py
80 lines (71 loc) · 2.96 KB
/
infer_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# -*- coding: utf-8 -*-
# file: infer_example.py
# author: songyouwei <[email protected]>
# Copyright (C) 2019. All Rights Reserved.
import torch
import torch.nn.functional as F
import argparse
from data_utils import build_tokenizer, build_embedding_matrix
from models import IAN, MemNet, ATAE_LSTM, AOA
class Inferer:
    """A simple inference example.

    Rebuilds the tokenizer and embedding matrix for ``opt.dataset``,
    instantiates ``opt.model_class``, loads the trained weights found at
    ``opt.state_dict_path`` and exposes :meth:`evaluate` to score raw texts.
    """

    def __init__(self, opt):
        """Prepare tokenizer, embeddings and a ready-to-use model.

        Args:
            opt: option object. This constructor reads ``dataset``,
                ``dataset_file`` (keys ``'train'`` and ``'test'``),
                ``max_seq_len``, ``embed_dim``, ``model_class``,
                ``model_name``, ``state_dict_path`` and ``device``; the
                whole object is also handed to ``opt.model_class``.
        """
        self.opt = opt
        self.tokenizer = build_tokenizer(
            fnames=[opt.dataset_file['train'], opt.dataset_file['test']],
            max_seq_len=opt.max_seq_len,
            dat_fname='{0}_tokenizer.dat'.format(opt.dataset))
        embedding_matrix = build_embedding_matrix(
            word2idx=self.tokenizer.word2idx,
            embed_dim=opt.embed_dim,
            dat_fname='{0}_{1}_embedding_matrix.dat'.format(str(opt.embed_dim), opt.dataset))
        self.model = opt.model_class(embedding_matrix, opt)
        print('loading model {0} ...'.format(opt.model_name))
        # map_location remaps the stored tensors onto the target device, so a
        # checkpoint written on a GPU machine still loads on a CPU-only one.
        state_dict = torch.load(opt.state_dict_path, map_location=opt.device)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(opt.device)
        # switch model to evaluation mode
        self.model.eval()
        # NOTE: this disables autograd for the whole process, not only for
        # this object; kept because existing callers may rely on it.
        torch.autograd.set_grad_enabled(False)

    def evaluate(self, raw_texts):
        """Return softmax probabilities of the model outputs for ``raw_texts``.

        Each text is lower-cased and stripped before being tokenized. The
        aspect input is the fixed placeholder ``'null'`` for every text.

        Args:
            raw_texts: sequence of strings to score.

        Returns:
            A numpy array holding the softmax (over the last dimension) of
            the model outputs -- presumably one row per input text and one
            column per polarity class; confirm against the model in use.
        """
        context_seqs = [
            self.tokenizer.text_to_sequence(raw_text.lower().strip())
            for raw_text in raw_texts
        ]
        aspect_seqs = [self.tokenizer.text_to_sequence('null')] * len(raw_texts)
        device = self.opt.device
        context_indices = torch.tensor(context_seqs, dtype=torch.int64).to(device)
        aspect_indices = torch.tensor(aspect_seqs, dtype=torch.int64).to(device)
        t_inputs = [context_indices, aspect_indices]
        # Local no_grad guard: keeps .numpy() below valid even if autograd
        # was re-enabled elsewhere after construction.
        with torch.no_grad():
            t_outputs = self.model(t_inputs)
            print(t_outputs)
            t_probs = F.softmax(t_outputs, dim=-1).cpu().numpy()
        return t_probs
if __name__ == '__main__':
    # Demo entry point: configure one trained model, load it and classify a
    # few sample sentences.

    class Option(object):
        """Plain attribute holder standing in for parsed command-line options."""

    # model name -> model class
    model_classes = {
        'atae_lstm': ATAE_LSTM,
        'ian': IAN,
        'memnet': MemNet,
        'aoa': AOA,
    }
    # set your trained models here
    model_state_dict_paths = {
        'atae_lstm': 'state_dict/atae_lstm_restaurant_acc0.7786',
        'ian': 'state_dict/ian_restaurant_val_acc0.65',
        'memnet': 'state_dict/memnet_restaurant_val_acc0.7045',
        'aoa': 'state_dict/aoa_restaurant_acc0.8063',
    }

    model_name = 'memnet'
    settings = {
        'model_name': model_name,
        'model_class': model_classes[model_name],
        'dataset': 'restaurant',
        'dataset_file': {
            'train': './datasets/semeval14/Restaurants_Train.xml.seg',
            'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'
        },
        'state_dict_path': model_state_dict_paths[model_name],
        'embed_dim': 300,
        'hidden_dim': 300,
        'max_seq_len': 80,
        'polarities_dim': 3,
        'hops': 3,
        # use the GPU when one is available, otherwise fall back to the CPU
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    }
    opt = Option()
    for key, value in settings.items():
        setattr(opt, key, value)

    inferer = Inferer(opt)
    probs = inferer.evaluate(['happy memory', 'bad service', 'just normal food'])
    # Shift the winning class index down by one before printing
    # (presumably mapping indices 0..2 to polarity labels -1..1).
    labels = probs.argmax(axis=-1) - 1
    print(labels)