Skip to content

Commit

Permalink
Update bijectors to lower-bound STS scale parameters away from zero.
Browse files (browse the repository at this point in the history)
This may help avoid numerical issues when fitting constant or nearly constant series.

PiperOrigin-RevId: 396407949
  • Loading branch information
davmre authored and tensorflower-gardener committed Sep 13, 2021
1 parent 4b675da commit bd917e0
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 20 deletions.
9 changes: 9 additions & 0 deletions tensorflow_probability/python/sts/components/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand All @@ -73,6 +74,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand All @@ -98,6 +100,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand All @@ -123,6 +126,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand All @@ -148,6 +152,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand Down Expand Up @@ -175,6 +180,7 @@ py_library(
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:docstring_util",
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand All @@ -201,6 +207,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand All @@ -226,6 +233,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand All @@ -251,6 +259,7 @@ py_library(
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/sts/internal",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def __init__(self,
coefficient_constraining_bijector),
Parameter('level_scale', level_scale_prior,
tfb.Chain([tfb.Scale(scale=observed_stddev),
tfb.Softplus()]))
tfb.Softplus(low=dtype_util.eps(dtype))]))
],
latent_size=order,
init_parameters=init_parameters,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def __init__(self,
parameters=[
Parameter('drift_scale', drift_scale_prior,
tfb.Chain([tfb.Scale(scale=observed_stddev),
tfb.Softplus()]))
tfb.Softplus(low=dtype_util.eps(dtype))]))
],
latent_size=num_features,
init_parameters=init_parameters,
Expand Down
16 changes: 9 additions & 7 deletions tensorflow_probability/python/sts/components/local_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,17 @@ def __init__(self,
init_parameters = dict(locals())
with tf.name_scope(name or 'LocalLevel') as name:

dtype = dtype_util.common_dtype([level_scale_prior, initial_level_prior])

if observed_time_series is not None:
_, observed_stddev, observed_initial = (
sts_util.empirical_statistics(observed_time_series))
else:
observed_stddev, observed_initial = (tf.convert_to_tensor(
value=1., dtype=dtype), tf.convert_to_tensor(
value=0., dtype=dtype))
observed_stddev, observed_initial = 1., 0.

dtype = dtype_util.common_dtype([level_scale_prior,
initial_level_prior,
observed_stddev,
observed_initial],
dtype_hint=tf.float32)

# Heuristic default priors. Overriding these may dramatically
# change inference performance and results.
Expand All @@ -303,7 +305,7 @@ def __init__(self,
name='level_scale_prior')
if initial_level_prior is None:
self._initial_state_prior = tfd.MultivariateNormalDiag(
loc=observed_initial[..., tf.newaxis],
loc=tf.convert_to_tensor(observed_initial)[..., tf.newaxis],
scale_diag=(
tf.abs(observed_initial) + observed_stddev)[..., tf.newaxis],
name='initial_level_prior')
Expand All @@ -316,7 +318,7 @@ def __init__(self,
parameters=[
Parameter('level_scale', level_scale_prior,
tfb.Chain([tfb.Scale(scale=observed_stddev),
tfb.Softplus()])),
tfb.Softplus(low=dtype_util.eps(dtype))])),
],
latent_size=1,
init_parameters=init_parameters,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow_probability.python import distributions as tfd
from tensorflow_probability.python.distributions import linear_gaussian_ssm
from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers

Expand Down Expand Up @@ -377,7 +378,7 @@ def __init__(self,
initial_slope_prior = tfd.Normal(
loc=0., scale=observed_stddev, name='initial_slope_prior')

tf.debugging.assert_same_float_dtype([
dtype = dtype_util.common_dtype([
level_scale_prior, slope_scale_prior, initial_level_prior,
initial_slope_prior
])
Expand All @@ -393,7 +394,7 @@ def __init__(self,
], axis=-1))

scaled_softplus = tfb.Chain([tfb.Scale(scale=observed_stddev),
tfb.Softplus()])
tfb.Softplus(low=dtype_util.eps(dtype))])
super(LocalLinearTrend, self).__init__(
parameters=[
Parameter('level_scale', level_scale_prior, scaled_softplus),
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_probability/python/sts/components/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow_probability.python import distributions as tfd

from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.sts.structural_time_series import Parameter
from tensorflow_probability.python.sts.structural_time_series import StructuralTimeSeries

Expand Down Expand Up @@ -446,22 +447,22 @@ def __init__(self,
prior=tfd.InverseGamma(
0.5 * ones_like_weights_batch,
0.5 * ones_like_weights_batch),
bijector=tfb.Softplus()),
bijector=tfb.Softplus(low=dtype_util.eps(dtype))),
Parameter('global_scale_noncentered',
prior=tfd.HalfNormal(
scale=ones_like_weights_batch),
bijector=tfb.Softplus()),
bijector=tfb.Softplus(low=dtype_util.eps(dtype))),
Parameter('local_scale_variances',
prior=tfd.Independent(tfd.InverseGamma(
0.5 * ones_like_weights,
0.5 * ones_like_weights),
reinterpreted_batch_ndims=1),
bijector=tfb.Softplus()),
bijector=tfb.Softplus(low=dtype_util.eps(dtype))),
Parameter('local_scales_noncentered',
prior=tfd.Independent(tfd.HalfNormal(
scale=ones_like_weights),
reinterpreted_batch_ndims=1),
bijector=tfb.Softplus()),
bijector=tfb.Softplus(low=dtype_util.eps(dtype))),
Parameter('weights_noncentered',
prior=tfd.Independent(tfd.Normal(
loc=tf.zeros_like(ones_like_weights),
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_probability/python/sts/components/seasonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.internal import docstring_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.sts.internal import util as sts_util
from tensorflow_probability.python.sts.structural_time_series import Parameter
from tensorflow_probability.python.sts.structural_time_series import StructuralTimeSeries
Expand Down Expand Up @@ -829,7 +830,7 @@ def __init__(self,
loc=observed_initial,
scale=tf.abs(observed_initial) + observed_stddev)

dtype = tf.debugging.assert_same_float_dtype(
dtype = dtype_util.common_dtype(
[drift_scale_prior, initial_effect_prior])

if isinstance(initial_effect_prior, tfd.Normal):
Expand Down Expand Up @@ -868,7 +869,7 @@ def __init__(self,
parameters.append(Parameter(
'drift_scale', drift_scale_prior,
tfb.Chain([tfb.Scale(scale=observed_stddev),
tfb.Softplus()])))
tfb.Softplus(low=dtype_util.eps(dtype))])))
self._allow_drift = allow_drift

super(Seasonal, self).__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python import distributions as tfd
from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.internal import dtype_util

from tensorflow_probability.python.sts.internal import util as sts_util
from tensorflow_probability.python.sts.structural_time_series import Parameter
Expand Down Expand Up @@ -397,6 +398,12 @@ def __init__(self,
if initial_slope_prior is None:
initial_slope_prior = tfd.Normal(loc=0., scale=observed_stddev)

dtype = dtype_util.common_dtype([level_scale_prior,
slope_scale_prior,
autoregressive_coef_prior,
initial_level_prior,
initial_slope_prior])

self._initial_state_prior = tfd.MultivariateNormalDiag(
loc=tf.stack(
[initial_level_prior.mean(),
Expand All @@ -418,7 +425,8 @@ def __init__(self,
autoregressive_coef_bijector = tfb.Identity() # unconstrained

stddev_preconditioner = tfb.Scale(scale=observed_stddev)
scaled_softplus = tfb.Chain([stddev_preconditioner, tfb.Softplus()])
scaled_softplus = tfb.Chain([stddev_preconditioner,
tfb.Softplus(low=dtype_util.eps(dtype))])
super(SemiLocalLinearTrend, self).__init__(
parameters=[
Parameter('level_scale', level_scale_prior, scaled_softplus),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def __init__(self,
initial_state_prior = tfd.MultivariateNormalDiag(
scale_diag=initial_state_scale * ones)

dtype = dtype_util.common_dtype([drift_scale_prior, initial_state_prior])

self._initial_state_prior = initial_state_prior
self._period = period
self._frequency_multipliers = frequency_multipliers
Expand All @@ -430,7 +432,7 @@ def __init__(self,
parameters.append(Parameter(
'drift_scale', drift_scale_prior,
tfb.Chain([tfb.Scale(scale=observed_stddev),
tfb.Softplus()])))
tfb.Softplus(low=dtype_util.eps(dtype))])))
self._allow_drift = allow_drift

super(SmoothSeasonal, self).__init__(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/python/sts/components/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def __init__(self,
observation_noise_scale_prior,
tfb.Chain([
tfb.Scale(scale=observed_stddev),
tfb.Softplus()]))]
tfb.Softplus(low=dtype_util.eps(dtype))]))]
for component in components:
for parameter in component.parameters:
parameters.append(Parameter(
Expand Down

0 comments on commit bd917e0

Please sign in to comment.