
completed more modular approach
andhus committed Jul 31, 2017
1 parent 36a0ea4 commit b148b3e
Showing 2 changed files with 49 additions and 45 deletions.
11 changes: 8 additions & 3 deletions examples/distribution/mog.py
@@ -5,7 +5,10 @@
 from keras.engine import Model
 from keras.layers import Lambda, TimeDistributed
 
-from extkeras.layers.distribution import MoGParams1D, neg_log_mog_pdf_1d
+from extkeras.layers.distribution import (
+    MixtureOfGaussian1D,
+    DistributionOutputLayer
+)
 
 n_timesteps = 3
 n_features = 5
@@ -15,10 +18,12 @@
 features = Input((n_timesteps, n_features), name='features')
 target = Input((n_timesteps, 1), name='target')
 
-params_layer = TimeDistributed(MoGParams1D(components=n_components))
+distribution = MixtureOfGaussian1D(n_components=n_components)
+
+params_layer = TimeDistributed(DistributionOutputLayer(distribution))
 params = params_layer(features)
 
-loss_layer = Lambda(lambda y_true_pred: neg_log_mog_pdf_1d(*y_true_pred))
+loss_layer = Lambda(lambda y_true_pred: distribution.loss(*y_true_pred))
 loss = loss_layer([target, params])
 
 params_and_loss_model = Model(inputs=[target, features], outputs=[params, loss])
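Since the loss is now a model output rather than a compile-time objective, the rewired example above can be sanity-checked with a plain forward pass. A minimal sketch (not part of the commit; the batch size and random data are assumptions):

import numpy as np

batch_size = 2  # assumed for illustration
features_data = np.random.randn(batch_size, n_timesteps, n_features)
target_data = np.random.randn(batch_size, n_timesteps, 1)

# input order follows the Model definition above: [target, features]
params_out, loss_out = params_and_loss_model.predict([target_data, features_data])
# params_out: (batch_size, n_timesteps, 3 * n_components), i.e. one weight,
# mu and sigma per mixture component
# loss_out: (batch_size, n_timesteps), the per-timestep negative
# log-likelihood of target under the predicted mixture
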
83 changes: 41 additions & 42 deletions extkeras/layers/distribution.py
@@ -36,25 +36,39 @@ class DistributionBase(object):
 
     @abc.abstractmethod
     def activation(self, x):
+        """apply activation function on input tensor"""
         pass
 
     @abc.abstractmethod
-    def loss(self, ):
+    def loss(self, y_true, y_pred):
+        """Implementation of the standard loss for this distribution, normally
+        -log(pdf(y_true)) where pdf is parameterized by y_pred
+        """
         pass
 
+    @abc.abstractproperty
+    def n_params(self):
+        pass
+
 
-class MixtureParamsActivationBase(FunctionalBlock):
+class MixtureDistributionBase(DistributionBase):
     __metaclass__ = abc.ABCMeta
 
-    param_type_names = ('mixture_weight',)
-    n_param_types = len(param_type_names)
+    n_param_types = None
 
-    def __init__(self):
-        self.mixture_weight_activation = softmax
+    def __init__(self, n_components):
+        self.n_components = n_components
+
+    @property
+    def mixture_weight_activation(self):
+        return softmax
 
     @classmethod
-    def split_params(cls, x):
-        """Splits input tensor into the """
+    def split_param_types(cls, x):
+        """Splits input tensor into the different param types.
+        Assumes the same number of parameters for each param type.
+        """
+        # TODO use n_components instead?
         dim = x.shape[-1].value
         if not dim % cls.n_param_types == 0:
             raise ValueError(
@@ -65,40 +79,34 @@ def split_params(cls, x):
             )
         components = dim // cls.n_param_types
         param_types = [
-            x[..., i*components:(i+1)*components] for i in range(cls.n_param_types)
+            x[..., i*components:(i+1)*components]
+            for i in range(cls.n_param_types)
         ]
 
         return param_types
 
-    @abc.abstractmethod
-    def __call__(self, x):
-        """apply activation function on input tensor"""
-        pass
-
-    @classmethod
-    def loss(cls, y_true, y_pred):
-        """Implementation of the standard loss for this distribution, normally
-        -log(pdf(y_true)) where pdf is parameterized by y_pred
-        """
-        raise NotImplementedError('')
+    @property
+    def n_params(self):
+        return self.n_param_types * self.n_components
 
 
-class MoGParams1DActivation(MixtureParamsActivationBase):
+class MixtureOfGaussian1D(MixtureDistributionBase):
 
     param_type_names = ('mixture_weight', 'mu', 'sigma')
     n_param_types = len(param_type_names)
 
     def __init__(
         self,
+        n_components,
         mu_activation=None,
         sigma_activation=None,
     ):
-        super(MoGParams1DActivation, self).__init__()
+        super(MixtureOfGaussian1D, self).__init__(n_components)
         self.mu_activation = mu_activation or (lambda x: x)
         self.sigma_activation = sigma_activation or ScaledExponential()
 
-    def __call__(self, x):
-        _mixture_weights, _mu, _sigma = self.split_params(x)
+    def activation(self, x):
+        _mixture_weights, _mu, _sigma = self.split_param_types(x)
         mixture_weights = self.mixture_weight_activation(_mixture_weights)
         mu = self.mu_activation(_mu)
         sigma = self.sigma_activation(_sigma)
@@ -109,7 +117,7 @@ def __call__(self, x):
     def loss(cls, y_true, y_pred):
         """TODO document and check dims of inputs
         """
-        mixture_weights, mu, sigma, = cls.split_params(y_pred)
+        mixture_weights, mu, sigma, = cls.split_param_types(y_pred)
         norm = 1. / (np.sqrt(2. * np.pi) * sigma)
         exponent = -(
             K.square(y_true - mu) / (2. * K.square(sigma)) -
@@ -119,22 +127,13 @@ def loss(cls, y_true, y_pred):
         return -K.logsumexp(exponent, axis=-1)
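An aside on this return value: the exponent folds the mixture weights and the Gaussian normalizer into a single log-domain term, so -K.logsumexp(exponent, axis=-1) is the negative log of sum_k w_k * N(y_true | mu_k, sigma_k), evaluated stably. A plain-numpy reference implementation of the same quantity (illustrative only, not part of the commit):

import numpy as np

def mog_nll_reference(y, weights, mu, sigma):
    # log of each weighted component: log(w_k) + log(N(y | mu_k, sigma_k))
    log_norm = -0.5 * np.log(2. * np.pi) - np.log(sigma)
    exponent = (
        np.log(weights) + log_norm
        - np.square(y - mu) / (2. * np.square(sigma))
    )
    # stable log-sum-exp over the component axis
    m = exponent.max(axis=-1, keepdims=True)
    return -(m.squeeze(-1) + np.log(np.exp(exponent - m).sum(axis=-1)))

# a single standard Gaussian (weight 1) at y=0 gives NLL = 0.5 * log(2 * pi)
assert np.isclose(
    mog_nll_reference(np.array([0.]), np.array([[1.]]), np.array([[0.]]), np.array([[1.]])),
    0.5 * np.log(2. * np.pi),
)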


-neg_log_mog_pdf_1d = MoGParams1DActivation.loss
+class DistributionOutputLayer(Dense):
 
-
-class MoGParams1D(Dense):
-
-    def __init__(
-        self,
-        components,
-        mu_activation=None,
-        sigma_activation=None
-    ):
-        super(MoGParams1D, self).__init__(
-            units=3 * components,
-            activation=MoGParams1DActivation(
-                mu_activation=mu_activation,
-                sigma_activation=sigma_activation,
-            )
+    def __init__(self, distribution, **kwargs):
+        self.distribution = distribution
+        if 'units' in kwargs or 'activation' in kwargs:
+            raise ValueError('')  # TODO
+        super(DistributionOutputLayer, self).__init__(
+            units=distribution.n_params,
+            activation=distribution.activation
         )
-        self.n_components = components

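Taken together, the refactor makes the distribution object the single source of truth: it knows its parameter count (n_params), its activation, and its loss, while DistributionOutputLayer is just a Dense layer configured from it. A hypothetical standalone usage sketch (non-recurrent input shape assumed; not from the commit):

from keras.engine import Model
from keras.layers import Input

from extkeras.layers.distribution import (
    MixtureOfGaussian1D,
    DistributionOutputLayer
)

distribution = MixtureOfGaussian1D(n_components=4)
features = Input((8,))
# a Dense layer with distribution.n_params = 3 * 4 = 12 units and the
# distribution's own activation
params = DistributionOutputLayer(distribution)(features)
model = Model(inputs=features, outputs=params)

# distribution.loss has the (y_true, y_pred) signature Keras expects,
# so it should also work as an ordinary compile-time objective:
model.compile(optimizer='sgd', loss=distribution.loss)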