Skip to content

Commit b81dcda

Browse files
committed
Add head-tail proportion
1 parent 968f91a commit b81dcda

File tree

7 files changed

+229
-0
lines changed

7 files changed

+229
-0
lines changed

.DS_Store

-2 KB
Binary file not shown.

datasets/.DS_Store

14 KB
Binary file not shown.
250 Bytes
Binary file not shown.
23.7 KB
Binary file not shown.
214 Bytes
Binary file not shown.
340 Bytes
Binary file not shown.

utils.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# @Date : 2017-11-13 22:09:37
4+
# @Author : jimmy ([email protected])
5+
# @Link : http://sdcs.sysu.edu.cn
6+
# @Version : $Id$
7+
8+
import os
9+
from copy import deepcopy
10+
import pickle
11+
import random
12+
import numpy as np
13+
import time
14+
import datetime
15+
16+
import loss
17+
18+
class Triple(object):
    """A knowledge-graph triple of integer ids.

    Attributes use the short (h, t, r) notation used throughout this
    module: h = head entity id, t = tail entity id, r = relation id.
    """

    def __init__(self, head, tail, relation):
        self.h, self.t, self.r = head, tail, relation
23+
24+
# Compare two Triples in the order of head, relation and tail
def cmp_head(a, b):
    """Return True iff *a* sorts strictly before *b* by (head, relation, tail)."""
    # Lexicographic tuple comparison is equivalent to the chained
    # boolean form for integer ids.
    return (a.h, a.r, a.t) < (b.h, b.r, b.t)
27+
28+
# Compare two Triples in the order of tail, relation and head
def cmp_tail(a, b):
    """Return True iff *a* sorts strictly before *b* by (tail, relation, head)."""
    # Lexicographic tuple comparison is equivalent to the chained
    # boolean form for integer ids.
    return (a.t, a.r, a.h) < (b.t, b.r, b.h)
31+
32+
# Compare two Triples in the order of relation, head and tail
def cmp_rel(a, b):
    """Return True iff *a* sorts strictly before *b* by (relation, head, tail)."""
    # Lexicographic tuple comparison is equivalent to the chained
    # boolean form for integer ids.
    return (a.r, a.h, a.t) < (b.r, b.h, b.t)
35+
36+
def minimal(a, b):
    """Return the smaller of *a* and *b*; *a* wins ties."""
    return b if a > b else a
40+
41+
def cmp_list(a, b):
    """Return True iff the smaller endpoint id of *a* exceeds that of *b*."""
    lo_a = minimal(a.h, a.t)
    lo_b = minimal(b.h, b.t)
    return lo_a > lo_b
43+
44+
# Sentinel triple with all ids set to zero.
emptyTriple = Triple(0, 0, 0)
45+
46+
# Calculate the statistics of datasets
def calculate(datasetPath):
    """Compute per-relation head/tail statistics for a dataset directory.

    Reads ``relation2id.txt`` / ``entity2id.txt`` (first line = count) and the
    three split files ``train2id.txt`` / ``valid2id.txt`` / ``test2id.txt``
    (first line = triple count, then ``head tail relation`` integer rows),
    prints the average tails-per-head and heads-per-tail lists, and writes:

    - ``head_tail_proportion.pkl``: ``tail_per_head`` then ``head_per_tail``
      (one pickle.dump each).
    - ``head_tail_connection.pkl``: ``connectedTailDict`` then
      ``connectedHeadDict`` (relation id -> set of connected tail/head ids).

    The original version repeated the split-parsing loop three times and
    built sorted copies (``tripleHead``/``tripleTail``) plus three
    ``listTriple*`` lists that were never used; that dead work is removed.
    """
    def _first_int(fileName):
        # The *2id files store their record count on the first line.
        with open(os.path.join(datasetPath, fileName), 'r') as fr:
            for line in fr:
                return int(line)
        raise ValueError('%s is empty' % fileName)

    relationTotal = _first_int('relation2id.txt')
    entityTotal = _first_int('entity2id.txt')

    freqRel = [0] * relationTotal  # The frequency of each relation
    freqEnt = [0] * entityTotal    # The frequency of each entity
    tripleList = []

    def _read_split(fileName):
        # Parse one split file into tripleList, updating frequency counters.
        with open(os.path.join(datasetPath, fileName), 'r') as fr:
            next(fr, None)  # skip the leading count line
            for line in fr:
                parts = line.split()
                if len(parts) < 3:
                    continue  # tolerate blank/short trailing lines
                head, tail, rel = int(parts[0]), int(parts[1]), int(parts[2])
                tripleList.append(Triple(head, tail, rel))
                freqEnt[head] += 1
                freqEnt[tail] += 1
                freqRel[rel] += 1

    for split in ('train2id.txt', 'valid2id.txt', 'test2id.txt'):
        _read_split(split)

    # headDict[r][h] = set of tails reached from h via r;
    # tailDict[r][t] = set of heads reaching t via r.
    headDict = {}
    tailDict = {}
    for triple in tripleList:
        headDict.setdefault(triple.r, {}).setdefault(triple.h, set()).add(triple.t)
        tailDict.setdefault(triple.r, {}).setdefault(triple.t, set()).add(triple.h)

    # Average number of distinct tails per head / heads per tail, per relation.
    # Relations absent from the splits keep the default 0.
    tail_per_head = [0] * relationTotal
    head_per_tail = [0] * relationTotal
    for rel in headDict:
        totalTails = sum(len(tails) for tails in headDict[rel].values())
        tail_per_head[rel] = totalTails / len(headDict[rel])
    for rel in tailDict:
        totalHeads = sum(len(heads) for heads in tailDict[rel].values())
        head_per_tail[rel] = totalHeads / len(tailDict[rel])

    # Union of all tails (resp. heads) each relation connects to.
    connectedTailDict = {rel: set().union(*d.values()) for rel, d in headDict.items()}
    connectedHeadDict = {rel: set().union(*d.values()) for rel, d in tailDict.items()}

    print(tail_per_head)
    print(head_per_tail)

    with open(os.path.join(datasetPath, 'head_tail_proportion.pkl'), 'wb') as fw:
        pickle.dump(tail_per_head, fw)
        pickle.dump(head_per_tail, fw)

    with open(os.path.join(datasetPath, 'head_tail_connection.pkl'), 'wb') as fw:
        pickle.dump(connectedTailDict, fw)
        pickle.dump(connectedHeadDict, fw)
191+
192+
def getRel(triple):
    """Key function: return the relation id (``r``) of *triple*."""
    return triple.r
194+
195+
def getAnythingTotal(inPath, fileName):
    """Return the integer stored on the first line of ``inPath/fileName``.

    Returns None when the file is empty (same as the original fall-through).
    """
    with open(os.path.join(inPath, fileName), 'r') as fr:
        first = fr.readline()
        if first:
            return int(first)
        return None
199+
200+
def loadTriple(inPath, fileName):
    """Load a ``*2id`` triple file from ``inPath/fileName``.

    The file's first line holds the declared triple count; each following
    line is ``head tail relation`` as integers.

    Returns a 3-tuple ``(tripleTotal, tripleList, tripleDict)`` where
    ``tripleDict`` maps ``(h, t, r) -> True`` for O(1) membership tests.

    Fix: the original left ``tripleTotal`` unassigned when the file was
    empty, raising UnboundLocalError at the return; it now defaults to 0.
    """
    tripleTotal = 0
    tripleList = []
    with open(os.path.join(inPath, fileName), 'r') as fr:
        for i, line in enumerate(fr):
            if i == 0:
                # First line is the declared number of triples.
                tripleTotal = int(line)
            else:
                line_split = line.split()
                head = int(line_split[0])
                tail = int(line_split[1])
                rel = int(line_split[2])
                tripleList.append(Triple(head, tail, rel))

    tripleDict = {(t.h, t.t, t.r): True for t in tripleList}

    return tripleTotal, tripleList, tripleDict
220+
221+
def which_loss_type(num):
    """Map a numeric selector to the corresponding loss class.

    0 -> loss.marginLoss, 1 -> loss.EMLoss, 2 -> loss.WGANLoss,
    3 -> torch.nn.MSELoss.

    Raises:
        ValueError: for any other value of ``num`` (the original silently
            fell through and returned None, deferring the failure to the
            call site).
    """
    if num == 0:
        return loss.marginLoss
    elif num == 1:
        return loss.EMLoss
    elif num == 2:
        return loss.WGANLoss
    elif num == 3:
        # Bug fix: `nn` was never imported at module level (only `loss`
        # is), so this branch raised NameError. Import torch.nn locally.
        import torch.nn as nn
        return nn.MSELoss
    raise ValueError('unknown loss type: %r' % (num,))

0 commit comments

Comments
 (0)