Skip to content

Commit 3d11b33

Browse files
committed
add py file
1 parent 46b1861 commit 3d11b33

File tree

2 files changed

+225
-0
lines changed

2 files changed

+225
-0
lines changed

TransE/rebuild_relation2vec.py

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# -*- coding:utf-8 -*-
2+
import numpy as np
3+
import codecs
4+
'''
5+
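Predicting t from h + r works poorly because the learned r vectors are too weak a representation, so the r vectors are rebuilt with the method below.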
6+
Method: for each relation r, use the mean of (t - h) over all of its training triples as the r vector.
7+
8+
@chenbingjin 2016-05-07
9+
'''
10+
# Lookup tables between entity / relation names and their integer ids.
entity2id, id2entity = {}, {}
relation2id, id2relation = {}, {}
# Embedding tables keyed by integer id.
entity2vec, relation2vec = {}, {}
# Relation sets: each relation id maps to a list of entity-id pairs,
# e.g. [(h, t), (h', t'), ...].
relationsets = {}
# Embedding dimensionality.
n_dim = 50
# Embedding dtype (float32 / float64).
d_type = np.float32
19+
20+
def init_entity_id():
    """Load the entity name <-> id tables from ``./data/entity2id.txt``.

    Every non-blank line is split on tabs; the first field is the entity
    name and the second its integer id. The module-level ``entity2id``
    (name -> id) and ``id2entity`` (id -> name) dictionaries are filled in
    place.

    Raises:
        IndexError: a non-blank line has no second tab-separated field.
        ValueError: the second field is not an integer.
    """
    print("Loading entity2id ...")
    # Named ``fin`` rather than ``file`` so the Python 2 builtin is not shadowed.
    with codecs.open('./data/entity2id.txt', 'r', encoding='utf-8') as fin:
        for line in fin:
            text = line.strip()
            if not text:
                # A blank (typically trailing) line used to crash with
                # IndexError; skip it instead.
                continue
            fields = text.split('\t')
            key = fields[0]
            eid = int(fields[1])
            entity2id[key] = eid
            id2entity[eid] = key
29+
30+
def init_relation_id():
    """Load the relation name <-> id tables from ``./data/relation2id.txt``.

    Every non-blank line is split on tabs; the first field is the relation
    name and the second its integer id. The module-level ``relation2id``
    (name -> id) and ``id2relation`` (id -> name) dictionaries are filled in
    place.

    Raises:
        IndexError: a non-blank line has no second tab-separated field.
        ValueError: the second field is not an integer.
    """
    print("Loading relation2id ...")
    with codecs.open('./data/relation2id.txt', 'r', encoding='utf-8') as fin:
        for line in fin:
            text = line.strip()
            if not text:
                # A blank (typically trailing) line used to crash with
                # IndexError; skip it instead.
                continue
            fields = text.split('\t')
            rel = fields[0]
            rid = int(fields[1])
            relation2id[rel] = rid
            id2relation[rid] = rel
39+
40+
def init_entity_vector():
    """Load entity embeddings from ``./vec/entity2vec.bern`` into ``entity2vec``.

    Each non-blank line holds one tab-separated vector. Vectors are stored
    under consecutive integer keys starting at 0, in file order, converted
    to ``d_type``.

    Raises:
        ValueError: a field on a non-blank line cannot be parsed as a number.
    """
    print("Loading entity2vec ...")
    eid = 0
    with open('./vec/entity2vec.bern') as fin:
        for line in fin:
            row = line.strip()
            if not row:
                # A blank line used to reach np.array as [''] and raise
                # ValueError; it carries no vector, so skip it without
                # consuming an id.
                continue
            entity2vec[eid] = np.array(row.split('\t'), dtype=d_type)
            eid += 1
48+
49+
# Data preparation
50+
def prepare():
    """Run the three loaders: entity ids, relation ids, entity vectors."""
    print("Data Preparing ...")
    for load in (init_entity_id, init_relation_id, init_entity_vector):
        load()
55+
56+
def save(filename, final):
    """Write relation vectors to ``filename`` as a tab-separated text matrix.

    Args:
        filename: output path handed to ``np.savetxt``.
        final: iterable of ``(relation_id, vector)`` pairs; each vector is
            written to the row whose index equals its relation id.

    The matrix has ``len(relation2id)`` rows and ``n_dim`` columns. Rows
    whose relation id does not appear in ``final`` stay all-zero. Values are
    formatted with ``%.6f``.
    """
    # Use the module-wide n_dim rather than a second hard-coded 50 so the
    # two cannot drift apart.
    arr = np.zeros((len(relation2id), n_dim))
    for rid, vec in final:
        # run() produces (1, n_dim) arrays; flatten explicitly instead of
        # relying on assignment broadcasting.
        arr[rid] = np.reshape(vec, -1)
    np.savetxt(filename, arr, fmt='%.6f', delimiter='\t')
61+
62+
# Rebuild the relation (r) vectors
63+
def run():
    """Rebuild each relation vector as the mean of (tail - head) over training triples.

    Reads ``./data/train.txt``, groups the (head id, tail id) pairs of every
    relation into the module-level ``relationsets``, stores the mean
    ``entity2vec[t] - entity2vec[h]`` of each group in ``relation2vec`` and
    writes the result, ordered by relation id, to ``./vec/relation2vec.new``
    through ``save``.

    The id tables and ``entity2vec`` must already be populated (see
    ``prepare``).

    Raises:
        KeyError: a training triple names an entity or relation that is
            missing from the id tables, or an entity id has no vector.
    """
    print("relation2id len: %d" % len(relation2id))
    print("Build Relation sets ...")
    _load_relation_sets('./data/train.txt')
    print("relationsets len: %d" % len(relationsets))

    # Mean of (t - h) over each relation's pairs becomes its vector.
    for rel, pairs in relationsets.items():
        relation2vec[rel] = _mean_translation(pairs)

    # items() instead of iteritems() so the code also runs on Python 3.
    final = sorted(relation2vec.items(), key=lambda item: item[0])
    # Preview of the first two (relation id, vector) entries.
    print(final[:2])

    save('./vec/relation2vec.new', final)


def _load_relation_sets(path):
    """Group the (head id, tail id) pairs of ``path`` by relation id into ``relationsets``."""
    with codecs.open(path, 'r', encoding='utf-8') as fin:
        for line in fin:
            triplet = line.strip().split('\t')
            if len(triplet) != 3:
                # Not a head / relation / tail line; ignore it.
                continue
            # Translate names to ids (head, relation, tail order).
            h = entity2id[triplet[0]]
            r = relation2id[triplet[1]]
            t = entity2id[triplet[2]]
            relationsets.setdefault(r, []).append((h, t))


def _mean_translation(pairs):
    """Return the mean of ``entity2vec[t] - entity2vec[h]`` over ``pairs`` as a (1, n_dim) array."""
    total = np.zeros((1, n_dim), dtype=d_type)  # start from the zero vector
    for h, t in pairs:
        total += entity2vec[t] - entity2vec[h]
    return total / len(pairs)
95+
96+
97+
if __name__ == '__main__':
    # Load the lookup tables and embeddings, then rebuild the relation vectors.
    for step in (prepare, run):
        step()

TransE/test_similarity.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# -*- coding: utf-8 -*-
2+
import codecs
3+
import numpy as np
4+
from scipy.spatial.distance import cosine
5+
import sys
6+
# Python 2 only: make UTF-8 the implicit str <-> unicode codec so that
# unicode(raw_input(...)) in the interactive loop accepts non-ASCII names.
# Neither the reload() builtin nor sys.setdefaultencoding exists on Python 3
# (whose default is already UTF-8), so the hack is guarded by version.
if sys.version_info[0] == 2:
    reload(sys)  # re-exposes sys.setdefaultencoding, removed at start-up
    sys.setdefaultencoding('utf-8')
8+
'''
9+
find topK similar entities.
10+
@chenbingjin 2016-05-06
11+
'''
12+
# Lookup tables between entity / relation names and their integer ids.
entity2id, id2entity = {}, {}
relation2id, id2relation = {}, {}
# Embedding tables keyed by integer id.
entity2vec, relation2vec = {}, {}
18+
19+
def init_entity_id():
    """Load the entity name <-> id tables from ``./data/entity2id.txt``.

    Every non-blank line is split on tabs; the first field is the entity
    name and the second its integer id. The module-level ``entity2id``
    (name -> id) and ``id2entity`` (id -> name) dictionaries are filled in
    place.

    Raises:
        IndexError: a non-blank line has no second tab-separated field.
        ValueError: the second field is not an integer.
    """
    print("load entity2id ...")
    # Named ``fin`` rather than ``file`` so the Python 2 builtin is not shadowed.
    with codecs.open('./data/entity2id.txt', 'r', encoding='utf-8') as fin:
        for line in fin:
            text = line.strip()
            if not text:
                # A blank (typically trailing) line used to crash with
                # IndexError; skip it instead.
                continue
            fields = text.split('\t')
            key = fields[0]
            eid = int(fields[1])
            entity2id[key] = eid
            id2entity[eid] = key
28+
29+
def init_relation_id():
    """Load the relation name <-> id tables from ``./data/relation2id.txt``.

    Every non-blank line is split on tabs; the first field is the relation
    name and the second its integer id. The module-level ``relation2id``
    (name -> id) and ``id2relation`` (id -> name) dictionaries are filled in
    place.

    Raises:
        IndexError: a non-blank line has no second tab-separated field.
        ValueError: the second field is not an integer.
    """
    print("load relation2id ...")
    with codecs.open('./data/relation2id.txt', 'r', encoding='utf-8') as fin:
        for line in fin:
            text = line.strip()
            if not text:
                # A blank (typically trailing) line used to crash with
                # IndexError; skip it instead.
                continue
            fields = text.split('\t')
            rel = fields[0]
            rid = int(fields[1])
            relation2id[rel] = rid
            id2relation[rid] = rel
38+
39+
def init_entity_vector():
    """Load entity embeddings from ``./vec/entity2vec.bern`` into ``entity2vec``.

    Each non-blank line holds one tab-separated vector. Vectors are stored
    as float32 arrays under consecutive integer keys starting at 0, in file
    order.

    Raises:
        ValueError: a field on a non-blank line cannot be parsed as a number.
    """
    print("load entity2vec ...")
    eid = 0
    with open('./vec/entity2vec.bern') as fin:
        for line in fin:
            row = line.strip()
            if not row:
                # A blank line used to reach np.array as [''] and raise
                # ValueError; it carries no vector, so skip it without
                # consuming an id.
                continue
            entity2vec[eid] = np.array(row.split('\t'), dtype=np.float32)
            eid += 1
47+
48+
def init_relation_vector():
    """Load relation embeddings from ``./vec/relation2vec.bern`` into ``relation2vec``.

    Each non-blank line holds one tab-separated vector. Vectors are stored
    as float32 arrays under consecutive integer keys starting at 0, in file
    order.

    Raises:
        ValueError: a field on a non-blank line cannot be parsed as a number.
    """
    print("load relation2vec ...")
    rid = 0
    with open('./vec/relation2vec.bern') as fin:
        for line in fin:
            row = line.strip()
            if not row:
                # A blank line used to reach np.array as [''] and raise
                # ValueError; it carries no vector, so skip it without
                # consuming an id.
                continue
            relation2vec[rid] = np.array(row.split('\t'), dtype=np.float32)
            rid += 1
56+
57+
def sim_cosine(head, rel):
    """Rank every entity by cosine distance to ``head + rel``.

    Args:
        head: entity name, looked up in ``entity2id``.
        rel: relation name, looked up in ``relation2id``.

    Returns:
        A list of ``(entity_id, cosine_distance)`` tuples sorted by
        ascending distance, i.e. the closest candidates for t in
        h + r = t come first.
    """
    print('finding entity (h+r=t) ...')
    query = entity2vec[entity2id[head]] + relation2vec[relation2id[rel]]
    distances = {}
    for eid in entity2id.values():
        distances[eid] = cosine(query, entity2vec[eid])
    return sorted(distances.items(), key=lambda pair: pair[1])
70+
71+
def sim_relation(head, tail):
    """Rank every relation by cosine distance to ``tail - head``.

    Args:
        head: head entity name, looked up in ``entity2id``.
        tail: tail entity name, looked up in ``entity2id``.

    Returns:
        A list of ``(relation_id, cosine_distance)`` tuples sorted by
        ascending distance, i.e. the closest candidates for r in
        r = t - h come first.
    """
    print('finding relation (r=t-h) ...')
    offset = entity2vec[entity2id[tail]] - entity2vec[entity2id[head]]
    distances = {}
    for rid in relation2id.values():
        distances[rid] = cosine(offset, relation2vec[rid])
    return sorted(distances.items(), key=lambda pair: pair[1])
81+
def sim_entity(head):
    """Rank every entity by cosine distance to the entity ``head``.

    Args:
        head: entity name, looked up in ``entity2id``.

    Returns:
        A list of ``(entity_id, cosine_distance)`` tuples sorted by
        ascending distance (most similar first).
    """
    query = entity2vec[entity2id[head]]
    distances = {}
    print('similar entity finding ...')
    for eid in entity2id.values():
        distances[eid] = cosine(query, entity2vec[eid])
    return sorted(distances.items(), key=lambda pair: pair[1])
94+
95+
def _print_top30(ranked, id2name):
    """Print the names of the first 30 ``(id, distance)`` entries of ``ranked``."""
    print("-----------top30---------------")
    for item_id, _ in ranked[:30]:
        print(id2name[item_id])


if __name__ == '__main__':
    # Interactive query loop: load all tables once, then answer similarity
    # queries until end of input.
    init_entity_id()
    init_relation_id()
    init_entity_vector()
    init_relation_vector()

    # Pick the line reader / text type for the running interpreter.
    try:
        read_line, to_text = raw_input, unicode
    except NameError:  # Python 3
        read_line, to_text = input, str

    print("\nThree choices: ")
    print("\t0.entity similarity;\n\t1.find t (t = h+r);\n\t2.find r (r = t-h)\n")
    while True:
        # Read the choice as plain text and convert it with int(). Python 2's
        # input() eval()s whatever the user types, which must never be done
        # on interactive input.
        try:
            raw_choice = read_line("input choice (0/1/2):")
        except EOFError:
            break  # end of input: leave the loop instead of a traceback
        try:
            choice = int(raw_choice)
        except ValueError:
            choice = None
        if choice not in (0, 1, 2):
            # Previously anything other than 0 or 1 silently ran mode 2.
            print("invalid choice: expected 0, 1 or 2")
            continue

        try:
            if choice == 0:
                entity = read_line("input entity: ")
                _print_top30(sim_entity(to_text(entity)), id2entity)
            elif choice == 1:
                head = read_line("input head entity: ")
                rel = read_line("input relation: ")
                _print_top30(sim_cosine(to_text(head), to_text(rel)), id2entity)
            else:
                head = read_line("input head entity: ")
                tail = read_line("input tail entity: ")
                _print_top30(sim_relation(to_text(head), to_text(tail)), id2relation)
        except KeyError as err:
            # A mistyped name should not end the session after the slow load.
            print("unknown entity or relation: %s" % err)

0 commit comments

Comments
 (0)