Skip to content

Commit

Permalink
Configurable pooling operator for SharedProdParameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
gvtulder committed Mar 29, 2013
1 parent 749e20c commit e76e87a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions morb/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def energy_term(self, vmap):


class SharedProdParameters(Parameters):
def __init__(self, rbm, units_list, dimensions, shared_dimensions, W, name=None, energy_multiplier=1):
def __init__(self, rbm, units_list, dimensions, shared_dimensions, W, name=None, energy_multiplier=1, pooling_operator=T.mean):
super(SharedProdParameters, self).__init__(rbm, units_list, name=name, energy_multiplier = energy_multiplier)
assert len(units_list) == 2
self.var = W
Expand All @@ -168,8 +168,10 @@ def __init__(self, rbm, units_list, dimensions, shared_dimensions, W, name=None,
self.hsd = shared_dimensions
self.hnd = self.hud - self.hsd

self.pooling_operator = pooling_operator

def from_hu(m, vmap):
return T.mean(m, axis=self._shared_axes(vmap))
return self.pooling_operator(m, axis=self._shared_axes(vmap))

def to_hu(m):
return T.shape_padright(m, self.hsd)
Expand All @@ -195,7 +197,7 @@ def weights_for(self, units):
return self.var.dimshuffle('x', 0, 'x', 'x', 1)

def energy_term(self, vmap):
return - T.sum(T.dot(vmap[self.vu], self.var) * T.mean(vmap[self.hu], axis=self._shared_axes(vmap)), axis=1)
return - T.sum(T.dot(vmap[self.vu], self.var) * self.pooling_operator(vmap[self.hu], axis=self._shared_axes(vmap)), axis=1)

t = tensordot(vmap[self.hu], self.var, axes=(range(1, self.hnd+1), range(0, self.hnd)))
axes = range(t.ndim - self.hsd, t.ndim)
Expand Down

0 comments on commit e76e87a

Please sign in to comment.