Skip to content

Commit

Permalink
working 1st version
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmah committed Feb 9, 2016
1 parent 2176d21 commit 60b3686
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
14 changes: 10 additions & 4 deletions theano/sandbox/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,16 @@ def perform(self, node, ins, outs):
(pvals, unis, n_samples) = ins
(z,) = outs

if n_samples > pvals.shape[1]:
raise ValueError("Cannot sample without replacement n samples bigger "
"than the size of the distribution.")

if unis.shape[0] != pvals.shape[0] * n_samples:
raise ValueError("unis.shape[0] != pvals.shape[0] * n_samples",
unis.shape[0], pvals.shape[0], n_samples)
if z[0] is None or numpy.any(z[0].shape != [pvals.shape[0], n_samples]):
z[0] = numpy.zeros((pvals.shape[0], n_samples), dtype=node.outputs[0].dtype)

if z[0] is None or not numpy.all(z[0].shape == [pvals.shape[0], n_samples]):
z[0] = -1 * numpy.ones((pvals.shape[0], n_samples), dtype=node.outputs[0].dtype)

nb_multi = pvals.shape[0]
nb_outcomes = pvals.shape[1]
Expand All @@ -255,9 +260,10 @@ def perform(self, node, ins, outs):
cummul += pvals[n, m]
if (cummul > unis_n):
z[0][n, c] = m
# set to zero so that it's not selected again
# set to zero and re-normalize so that it's not selected again
pvals[n, m] = 0.

pvals[n] /= pvals[n].sum()
break

class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
"""
Expand Down
8 changes: 3 additions & 5 deletions theano/sandbox/rng_mrg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,8 +1366,9 @@ def multinomial(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
def weighted_selection(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None):
"""
Sample `n` times (`n` needs to be in [1, m], where m is pvals.shape[1], default 1)
*WITHOUT replacement* from a multinomial distribution defined by probabilities pvals.
Sample `n` times *WITHOUT replacement* from a multinomial distribution
defined by probabilities pvals. `n` needs to be in [1, m], where m is the number of
elements to select from, i.e. m == pvals.shape[1]. By default n = 1.
Example : WRITEME
Expand All @@ -1387,9 +1388,6 @@ def weighted_selection(self, size=None, n=1, pvals=None, ndim=None, dtype='int64
raise TypeError("You have to specify pvals")
pvals = as_tensor_variable(pvals)

if n > pvals.shape[1]:
raise ValueError("Cannot sample without replacement n samples bigger "
"than the size of the distribution.")
if size is not None:
raise ValueError("Provided a size argument to "
"MRG_RandomStreams.weighted_selection, which does not use "
Expand Down

0 comments on commit 60b3686

Please sign in to comment.