-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathalias.py
90 lines (72 loc) · 2.25 KB
/
alias.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
class AliasSampler(object):
    """
    Alias sampler for arange topics.

    Draws an index ``k`` in ``range(len(p))`` with probability proportional
    to ``p[k]``. Building the tables costs O(K); each ``sample()`` call is
    O(1) (Vose's alias method).

    p: probability weights as an `np.ndarray` (or array-like), such as
       `[0.1, 0.5, 0.4]`. Weights must be non-negative with a positive,
       finite sum; they are normalised internally and the caller's array
       is not modified.

    Attributes (set by ``_build_table``):
        v: per-column acceptance threshold, in ``[0, 1]``.
        a: per-column alias index, used when the threshold test fails.
    """

    def __init__(self, p: np.ndarray):
        self._build_table(p)

    def sample(self):
        """Draw one index in ``range(K)``.

        Returns:
            int: the sampled index.
        """
        # A single uniform draw in [0, K) supplies both the column
        # (integer part) and the acceptance coin (fractional part).
        u, k = np.modf(np.random.rand() * self._K)
        k = int(k)
        if u < self.v[k]:
            return k
        # Cast so both branches return a plain Python int, not np.uint32.
        return int(self.a[k])

    def _build_table(self, p: np.ndarray):
        """Build the acceptance table ``self.v`` and alias table ``self.a``.

        Raises:
            ValueError: if ``p`` is empty, has a negative entry, or does not
                have a positive, finite sum.
        """
        # Work on a float64 copy: the caller's array is left untouched, and
        # integer arrays / plain lists are accepted (in-place `/=` on an
        # integer array raises).
        p = np.asarray(p, dtype=np.float64)
        if p.size == 0:
            raise ValueError("`p` must contain at least one weight.")
        if np.any(p < 0):
            raise ValueError("`p` must not contain negative weights.")
        total = p.sum()
        if not (np.isfinite(total) and total > 0):
            raise ValueError("`p` must have a positive, finite sum.")

        self._K = len(p)
        self.a = np.zeros(self._K, dtype=np.uint32)
        # Scale so the average column height is exactly 1.
        self.v = self._K * (p / total)

        large, small = [], []
        for k, vk in enumerate(self.v):
            if vk >= 1.0:
                large.append(k)
            else:
                small.append(k)

        while large and small:
            big = large.pop()
            tiny = small.pop()
            # Column `tiny` is topped up to height 1 with mass taken from `big`.
            self.a[tiny] = big
            self.v[big] -= 1.0 - self.v[tiny]
            if self.v[big] < 1.0:
                small.append(big)
            else:
                large.append(big)

        # Mathematically every remaining column has height exactly 1, but
        # floating-point round-off can leave it slightly off; a value just
        # below 1 would route a sliver of mass to the unset alias 0. Pin them.
        for k in large + small:
            self.v[k] = 1.0
            self.a[k] = k
class SparseAliasSampler(AliasSampler):
    """
    Alias sampler for not arange topics (sparse).

    Like ``AliasSampler``, but ``sample()`` returns ``topics[k]`` (drawn with
    probability proportional to ``p[k]``) instead of the index ``k``.

    p: probability weights as an `np.ndarray` (or array-like), such as
       `[0.1, 0.5, 0.4]`. Non-negative with a positive, finite sum;
       normalised internally, and the caller's array is not modified.
    topics: np.array or list such as `[100, 20, 1000]`, one value per weight.
       ``self.a`` stores alias topics using the topics' own dtype, so
       negative or very large ids are not squeezed into ``uint32``.
    """

    def __init__(self, p: np.ndarray, topics: np.ndarray):
        # The parent constructor is bypassed on purpose: it would call this
        # class's `_build_table` without `topics`.
        assert len(p) == len(topics), "the length of `p` and `topics` should be same. "
        self._build_table(p, topics)

    def sample(self):
        """Draw one topic.

        Returns:
            An element of ``topics``: either the column's own topic or its
            alias topic stored in ``self.a``.
        """
        # Integer part picks the column, fractional part is the acceptance coin.
        u, k = np.modf(np.random.rand() * self._K)
        k = int(k)
        if u < self.v[k]:
            return self.topics[k]
        return self.a[k]

    def _build_table(self, p: np.ndarray, topics: np.ndarray):
        """Build ``self.v`` (acceptance thresholds) and ``self.a`` (alias topics).

        Raises:
            ValueError: if ``p`` is empty, has a negative entry, or does not
                have a positive, finite sum.
        """
        # Float64 copy: never mutate the caller's array; accept ints/lists.
        p = np.asarray(p, dtype=np.float64)
        if p.size == 0:
            raise ValueError("`p` must contain at least one weight.")
        if np.any(p < 0):
            raise ValueError("`p` must not contain negative weights.")
        total = p.sum()
        if not (np.isfinite(total) and total > 0):
            raise ValueError("`p` must have a positive, finite sum.")

        self._K = len(p)
        self.topics = topics
        # Scale so the average column height is exactly 1.
        self.v = self._K * (p / total)

        # Build the alias table over *indices* first; columns never paired
        # keep their own index, which is harmless once their threshold is 1.
        alias = np.arange(self._K)
        large, small = [], []
        for k, vk in enumerate(self.v):
            if vk >= 1.0:
                large.append(k)
            else:
                small.append(k)

        while large and small:
            big = large.pop()
            tiny = small.pop()
            # Column `tiny` is topped up to height 1 with mass taken from `big`.
            alias[tiny] = big
            self.v[big] -= 1.0 - self.v[tiny]
            if self.v[big] < 1.0:
                small.append(big)
            else:
                large.append(big)

        # Round-off can leave leftover columns slightly off 1; a value just
        # below 1 would otherwise return a bogus alias. Pin them to 1.
        for k in large + small:
            self.v[k] = 1.0

        # Translate alias indices into topic values, keeping the topics' own
        # dtype instead of uint32 so negative / >32-bit / non-integer topics
        # are stored faithfully.
        self.a = np.asarray(topics)[alias]