forked from nakaizura/Source-Code-Notebook
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
63 lines (52 loc) · 2.6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
'''
Created on May 29, 2020
@author: nakaizura
'''
import tensorflow as tf
import numpy as np
from model import RippleNet
def train(args, data_info, show_loss):
    """Train a RippleNet model and print AUC/accuracy on every split after each epoch.

    Args:
        args: hyper-parameter namespace. This function reads ``args.n_epoch`` and
            ``args.batch_size`` and passes ``args`` on to ``RippleNet``,
            ``get_feed_dict`` and ``evaluation``.
        data_info: indexable whose first six entries are, in order: the training
            array, the eval array, the test array, the entity count, the relation
            count, and the ripple set.
        show_loss: when true, print the progress percentage and the batch loss
            after every training step.

    Note:
        ``train_data`` (``data_info[0]``) is shuffled IN PLACE with
        ``np.random.shuffle`` at the start of every epoch, so the caller's array
        is reordered as a side effect.
    """
    train_data = data_info[0]
    eval_data = data_info[1]
    test_data = data_info[2]
    n_entity = data_info[3]
    n_relation = data_info[4]
    ripple_set = data_info[5]
    # Shuffling does not change the row count, so compute it once.
    n_train = train_data.shape[0]

    model = RippleNet(args, n_entity, n_relation)  # instantiate the model
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(args.n_epoch):
            # Training: reshuffle the rows, then sweep them in mini-batches.
            np.random.shuffle(train_data)
            start = 0
            while start < n_train:
                # model.train returns a pair whose second element is the batch loss.
                _, loss = model.train(
                    sess, get_feed_dict(args, model, train_data, ripple_set, start, start + args.batch_size))
                start += args.batch_size
                if show_loss:
                    # Clamp so the final (possibly partial) batch reports 100%, not more.
                    print('%.1f%% %.4f' % (min(start, n_train) / n_train * 100, loss))
            # Evaluation: report metrics on the train, eval and test splits.
            train_auc, train_acc = evaluation(sess, args, model, train_data, ripple_set, args.batch_size)
            eval_auc, eval_acc = evaluation(sess, args, model, eval_data, ripple_set, args.batch_size)
            test_auc, test_acc = evaluation(sess, args, model, test_data, ripple_set, args.batch_size)
            print('epoch %d train auc: %.4f acc: %.4f eval auc: %.4f acc: %.4f test auc: %.4f acc: %.4f'
                  % (step, train_auc, train_acc, eval_auc, eval_acc, test_auc, test_acc))
def get_feed_dict(args, model, data, ripple_set, start, end):
    """Build the feed dict for the rows ``data[start:end]``.

    Column 0 of ``data`` holds the key used to index ``ripple_set``. Columns 1
    and 2 are fed directly to ``model.items`` and ``model.labels``.

    For each hop ``i`` in ``range(args.n_hop)``, ``ripple_set[key][i]`` supplies
    three entries. Entry 0 feeds ``model.memories_h[i]``, entry 1 feeds
    ``model.memories_r[i]`` and entry 2 feeds ``model.memories_t[i]``. Each is
    fed as a list with one element per row of the slice.

    Returns:
        dict mapping model placeholders to the values described above.
    """
    batch = data[start:end]
    user_keys = batch[:, 0]
    feed = {model.items: batch[:, 1], model.labels: batch[:, 2]}
    for hop in range(args.n_hop):
        # One lookup per user for this hop; unpack head/relation/tail from it.
        hop_entries = [ripple_set[key][hop] for key in user_keys]
        feed[model.memories_h[hop]] = [entry[0] for entry in hop_entries]
        feed[model.memories_r[hop]] = [entry[1] for entry in hop_entries]
        feed[model.memories_t[hop]] = [entry[2] for entry in hop_entries]
    return feed
def evaluation(sess, args, model, data, ripple_set, batch_size):
    """Evaluate ``model`` on every row of ``data`` and return ``(auc, acc)``.

    Works for any split (the training loop calls it for train, eval and test).
    ``data`` is processed in slices of ``batch_size`` rows, each built with
    ``get_feed_dict``. ``model.eval`` must return an ``(auc, acc)`` pair per slice.

    The per-slice values are averaged with weights equal to the slice length,
    so a short final slice no longer counts as much as a full one. For accuracy
    this equals the overall accuracy, presumably assuming ``model.eval`` reports
    a per-slice mean (TODO confirm). For AUC, a weighted mean of per-slice AUCs
    is still only an approximation of the AUC over all rows.

    Args:
        sess: session passed through to ``model.eval``.
        args: hyper-parameter namespace passed through to ``get_feed_dict``.
        model: object exposing ``eval(sess, feed_dict) -> (auc, acc)``.
        data: 2-D array whose rows are evaluated (``data.shape[0]`` rows).
        ripple_set: passed through to ``get_feed_dict``.
        batch_size: number of rows per slice; must be positive.

    Returns:
        Tuple of Python floats ``(auc, acc)``; ``(nan, nan)`` when ``data`` has no rows.

    Raises:
        ValueError: if ``batch_size`` is not positive (the slicing loop would never advance).
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    n_rows = data.shape[0]
    auc_list = []
    acc_list = []
    weights = []
    for start in range(0, n_rows, batch_size):
        auc, acc = model.eval(sess, get_feed_dict(args, model, data, ripple_set, start, start + batch_size))
        auc_list.append(auc)
        acc_list.append(acc)
        # Number of rows actually in this slice; only the last one can be short.
        weights.append(min(batch_size, n_rows - start))
    if not weights:
        # No rows: keep returning nan as before, but without numpy's empty-mean warning.
        return float('nan'), float('nan')
    return (float(np.average(auc_list, weights=weights)),
            float(np.average(acc_list, weights=weights)))