1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # @Date : 2017-11-13 22:09:37
4
+ # @Author : jimmy ([email protected] )
5
+ # @Link : http://sdcs.sysu.edu.cn
6
+ # @Version : $Id$
7
+
8
import datetime
import os
import pickle
import random
import time
from copy import deepcopy

import numpy as np
import torch.nn as nn

import loss

18
class Triple(object):
    """A knowledge-graph triple, stored as integer ids.

    Attributes:
        h: head entity id
        t: tail entity id
        r: relation id
    """

    def __init__(self, head, tail, relation):
        self.h = head
        self.t = tail
        self.r = relation

    def __repr__(self):
        # Added for debuggability: the original printed as an opaque object.
        return 'Triple(h=%r, t=%r, r=%r)' % (self.h, self.t, self.r)
24
def cmp_head(a, b):
    """True when triple *a* sorts strictly before *b* by (head, relation, tail)."""
    return (a.h, a.r, a.t) < (b.h, b.r, b.t)
27
+
28
def cmp_tail(a, b):
    """True when triple *a* sorts strictly before *b* by (tail, relation, head)."""
    return (a.t, a.r, a.h) < (b.t, b.r, b.h)
31
+
32
def cmp_rel(a, b):
    """True when triple *a* sorts strictly before *b* by (relation, head, tail)."""
    return (a.r, a.h, a.t) < (b.r, b.h, b.t)
35
+
36
def minimal(a, b):
    """Return the smaller of two comparable values (the first wins on a tie)."""
    return b if a > b else a
40
+
41
def cmp_list(a, b):
    """True when the smaller endpoint id of triple *a* exceeds that of *b* (descending-order key)."""
    return min(a.h, a.t) > min(b.h, b.t)
43
+
44
# Module-level placeholder triple with all ids set to 0.
emptyTriple = Triple(0, 0, 0)
45
+
46
def _first_int(path):
    # Id/count files store their total on the first line; returns None for an empty file.
    with open(path, 'r') as fr:
        for line in fr:
            return int(line)


def _read_split(path, triples, freqEnt, freqRel):
    # Append the (head, tail, rel) triples of one split file to *triples*,
    # updating the entity/relation frequency counters.
    # The first line of the file is the triple count and is skipped.
    with open(path, 'r') as fr:
        for i, line in enumerate(fr):
            if i == 0:
                continue
            parts = line.split()
            head = int(parts[0])
            tail = int(parts[1])
            rel = int(parts[2])
            triples.append((head, tail, rel))
            freqEnt[head] += 1
            freqEnt[tail] += 1
            freqRel[rel] += 1


# Calculate the statistics of datasets
def calculate(datasetPath):
    """Compute per-relation head/tail statistics for a triple dataset.

    Reads relation2id.txt, entity2id.txt and the train/valid/test splits
    under *datasetPath* (the first line of every file is a count), then:

      * prints tail_per_head and head_per_tail — per relation, the average
        number of distinct tails per head and distinct heads per tail;
      * writes head_tail_proportion.pkl (two pickle records:
        tail_per_head, then head_per_tail);
      * writes head_tail_connection.pkl (two pickle records:
        connectedTailDict, then connectedHeadDict — relation id -> set of
        all entity ids ever seen as tail / head of that relation).

    Returns None. Compared to the original, this drops dead code
    (tripleHead/tripleTail copies, their sorts, and the unused
    listTriple* lists) and deduplicates the three identical split-reading
    loops; the printed and pickled outputs are unchanged.
    """
    relationTotal = _first_int(os.path.join(datasetPath, 'relation2id.txt'))
    freqRel = [0] * relationTotal  # frequency of each relation
    entityTotal = _first_int(os.path.join(datasetPath, 'entity2id.txt'))
    freqEnt = [0] * entityTotal  # frequency of each entity

    tripleList = []
    for split in ('train2id.txt', 'valid2id.txt', 'test2id.txt'):
        _read_split(os.path.join(datasetPath, split), tripleList, freqEnt, freqRel)

    # headDict[r][h] = set of tails seen with (h, r); tailDict[r][t] = set of heads.
    headDict = {}
    tailDict = {}
    for head, tail, rel in tripleList:
        headDict.setdefault(rel, {}).setdefault(head, set()).add(tail)
        tailDict.setdefault(rel, {}).setdefault(tail, set()).add(head)

    # Average distinct tails per head (and heads per tail) for each relation;
    # relations absent from the data keep the integer 0 placeholder.
    tail_per_head = [0] * relationTotal
    head_per_tail = [0] * relationTotal
    for rel, heads in headDict.items():
        tail_per_head[rel] = sum(len(ts) for ts in heads.values()) / len(heads)
    for rel, tails in tailDict.items():
        head_per_tail[rel] = sum(len(hs) for hs in tails.values()) / len(tails)

    # Every entity ever connected as tail (resp. head) of each relation.
    connectedTailDict = {rel: set().union(*headDict[rel].values()) for rel in headDict}
    connectedHeadDict = {rel: set().union(*tailDict[rel].values()) for rel in tailDict}

    print(tail_per_head)
    print(head_per_tail)

    with open(os.path.join(datasetPath, 'head_tail_proportion.pkl'), 'wb') as fw:
        pickle.dump(tail_per_head, fw)
        pickle.dump(head_per_tail, fw)

    with open(os.path.join(datasetPath, 'head_tail_connection.pkl'), 'wb') as fw:
        pickle.dump(connectedTailDict, fw)
        pickle.dump(connectedHeadDict, fw)
191
+
192
def getRel(triple):
    """Key function: return the relation id of *triple* (for sorting/grouping)."""
    relation_id = triple.r
    return relation_id
194
+
195
def getAnythingTotal(inPath, fileName):
    """Return the integer count stored on the first line of *fileName* under *inPath*.

    Returns None when the file has no lines at all.
    """
    target = os.path.join(inPath, fileName)
    with open(target, 'r') as handle:
        for first_line in handle:
            return int(first_line)
199
+
200
def loadTriple(inPath, fileName):
    """Load a *2id triple file from *inPath*/*fileName*.

    File format: the first line is the declared triple count; every
    following line is 'head tail relation' (whitespace-separated ids).

    Returns:
        (tripleTotal, tripleList, tripleDict) where tripleDict maps
        (h, t, r) -> True for O(1) membership tests.

    Fix over the original: tripleTotal is initialised to 0, so an empty
    file returns (0, [], {}) instead of raising UnboundLocalError.
    """
    tripleTotal = 0
    tripleList = []
    with open(os.path.join(inPath, fileName), 'r') as fr:
        for i, line in enumerate(fr):
            if i == 0:
                # Header line: the declared number of triples.
                tripleTotal = int(line)
            else:
                line_split = line.split()
                head = int(line_split[0])
                tail = int(line_split[1])
                rel = int(line_split[2])
                tripleList.append(Triple(head, tail, rel))

    tripleDict = {(t.h, t.t, t.r): True for t in tripleList}

    return tripleTotal, tripleList, tripleDict
220
+
221
def which_loss_type(num):
    """Map a numeric flag to a loss class.

    0 -> loss.marginLoss, 1 -> loss.EMLoss, 2 -> loss.WGANLoss,
    3 -> torch.nn.MSELoss.

    Raises:
        ValueError: for any other value (the original silently fell
            through and returned None).

    Fixes over the original: 'nn' was referenced without being imported
    (NameError on num == 3); unknown values now fail loudly.
    """
    if num == 0:
        return loss.marginLoss
    elif num == 1:
        return loss.EMLoss
    elif num == 2:
        return loss.WGANLoss
    elif num == 3:
        return nn.MSELoss
    raise ValueError('unknown loss type: %r' % (num,))