main.py (forked from thu-coai/seq2seq-tensorflow) — 77 lines (67 loc) · 2.6 KB
import os
import json

import numpy as np
import tensorflow as tf
from cotk.dataloader import SingleTurnDialog
from cotk.wordvector import WordVector, Glove

from utils import debug, try_cache
from model import Seq2SeqModel


def create_model(sess, data, args, embed):
    """Build a Seq2SeqModel under a named variable scope and restore or initialize its parameters."""
    with tf.variable_scope(args.name):
        model = Seq2SeqModel(data, args, embed)
        model.print_parameters()

        latest_dir = '%s/checkpoint_latest' % args.model_dir
        best_dir = '%s/checkpoint_best' % args.model_dir

        # Make sure the checkpoint directories exist.
        if not os.path.isdir(args.model_dir):
            os.mkdir(args.model_dir)
        if not os.path.isdir(latest_dir):
            os.mkdir(latest_dir)
        if not os.path.isdir(best_dir):
            os.mkdir(best_dir)

        # Restore from the latest or best checkpoint if requested, otherwise initialize fresh variables.
        if tf.train.get_checkpoint_state(latest_dir, args.name) and args.restore == "last":
            print("Reading model parameters from %s" % tf.train.latest_checkpoint(latest_dir, args.name))
            model.latest_saver.restore(sess, tf.train.latest_checkpoint(latest_dir, args.name))
        elif tf.train.get_checkpoint_state(best_dir, args.name) and args.restore == "best":
            print("Reading model parameters from %s" % tf.train.latest_checkpoint(best_dir, args.name))
            model.best_saver.restore(sess, tf.train.latest_checkpoint(best_dir, args.name))
        else:
            print("Created model with fresh parameters.")
            global_variable = [gv for gv in tf.global_variables() if args.name in gv.name]
            sess.run(tf.variables_initializer(global_variable))

    return model


def main(args):
    if args.debug:
        debug()

    # Configure the session: let GPU memory grow as needed, or hide the GPU entirely.
    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = SingleTurnDialog.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    # Load the dataloader and the pretrained embedding matrix, optionally through a local cache.
    if args.cache:
        data = try_cache(data_class, (args.datapath,), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load_matrix(args.embedding_size, vocab)
    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            test_res = model.test_process(sess, data, args)
            # json cannot serialize bytes, so convert such values to str before dumping.
            for key, val in test_res.items():
                if isinstance(val, bytes):
                    test_res[key] = str(val)
            with open("./result.json", "w") as result_file:
                json.dump(test_res, result_file)
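A minimal sketch of how `main(args)` could be driven, assuming a companion script (the upstream repository uses `run.py`) builds an argparse namespace. The flag names below mirror the fields this file reads from `args`; the default values are illustrative assumptions, not the upstream defaults, and `Seq2SeqModel` likely reads further hyperparameters that do not appear in this file.

# Hypothetical driver sketch; flag names mirror the fields main() reads, defaults are illustrative.
import argparse
from main import main

parser = argparse.ArgumentParser()
parser.add_argument("--name", default="seq2seq")            # variable scope / checkpoint name
parser.add_argument("--mode", default="train")              # "train" runs training; anything else runs testing
parser.add_argument("--restore", default="last")            # "last", "best", or neither for fresh parameters
parser.add_argument("--model_dir", default="./model")
parser.add_argument("--cache", action="store_true")
parser.add_argument("--cache_dir", default="./cache")
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--debug", action="store_true")
parser.add_argument("--dataset", default="OpenSubtitles")   # a cotk SingleTurnDialog dataset
parser.add_argument("--datapath", default="./data")
parser.add_argument("--wvclass", default="Glove")
parser.add_argument("--wvpath", default="./wordvec")
parser.add_argument("--embedding_size", type=int, default=300)

main(parser.parse_args())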