Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
This document represents the TFP standards for all members of the
tfp.distributions
module. These standards only apply to the
tfp.distributions
module and is intentionally a "high bar" meant to protect
users of the module. The "TensorFlow Distributions" whitepaper is an exposition on the standards
described here (see Section 3). These standards have been in effect since
16-aug-2016.
You are encouraged to subclass tfp.distributions.Distribution
and invited to
disregard any/all of these standards. We especially recommend you ignore point
#4.
All TFP code (including tfp.distributions
) follows the TFP Style Guide.
-
A
Distribution
subclass must implement asample
function. Thesample
function shall generate random outcomes which in expectation correspond to other properties of theDistribution
. Other properties must assume the semantics implied by the output ofsample
. -
A
Distribution
must implement alog_prob
function. This function is the log probability mass function or probability density function, depending on the underlying measure. -
All member functions must be efficiently computable. To us, efficiently computable means: computable in (at most) expected polynomial time.
-
All member functions (except sample) must be deterministic, reproducible, and platform invariant. The means no stochastic approximations (even seeded). For example, it is acceptable to implement an efficiently computable analytical upper bound on entropy but it is not acceptable to implement a Monte Carlo estimate of entropy, even if that Monte Carlo estimate uses a seeded RNG for reproducibility. It also acceptable to implement an approximation of statistic (e.g., approximation of LogitNormal mean), as long as this can be computed non-stochastically.
-
Implementing other distribution properties is highly encouraged iff they are mathematically well-defined and efficiently computable. For example
mean_log_det
is a meaningful property for a Wishart but not for a (scalar) Normal. -
A
Distribution
's inputs/outputs areTensor
s with "Distribution
shape semantics." For example, functions likeprob
andcdf
accept aTensor
with "Distribution
shape semantics" whereassample
returns aTensor
with "Distribution shape semantics."- Exception:
tfd.JointDistribution
takes/returns alist
-like ofTensor
s.
- Exception:
-
ADistribution
'sevent_ndims
must be known statically. For example,Wishart
hasevent_ndims=2
,MultivariateNormalDiag
hasevent_ndims=1
,Normal
hasevent_ndims=0
,Categorical
hasevent_ndims=0
andOneHotCategorical
hasevent_ndims=1
. Theevent_shape
need not be known statically, i.e., this might only be known at runtime (in Eager mode) or at graph execution time (in graph mode). Often aDistribution
'sevent_ndims
will be self-evident from the class name itself.- Redacted 07-nov-2019.
-
All
Distribution
s are implicitly or explicitly conditioned on global or local (per-sample) parameters. ADistribution
infers dtype and batch/event shape from its global parameters. For example, a scalarDistribution
'sevent_shape
is implicitly inferrable (event_shape=[]
) thus always known statically; the same is not necessarily true of aMultivariateNormalDiag
. -
When possible, a
Distribution
must support Numpy-like broadcasting for all arguments. When broadcasting is not possible, arguments must be validated. -
All possible effort is made to validate arguments prior to graph execution. Any validation requiring graph execution must be gated by a Boolean, global parameter.
-
Distribution
parameters are descriptive English (i.e., not Greek letters) and draw upon a relatively small, shared lexicon. Examples include:loc
,scale
,concentration
,rate
,probs
,logits
,df
. When forced to choose between mathematical purity and conveying intuitive meaning, prefer the latter but provide extensive documentation. -
TFP
Distribution
s guarantee that input arguments are not manipulated by__init__
except to convert non-TF-derived inputs totf.Tensor
s.Among other things, this contract implies:
-
tf.Variable
-derived__init__
arguments are not read ("concretized") until some computation is requested, e.g. a member function is called. We call this "maximally deferred read" idea: "tf.Variable
safety." -
No additional computation result is stored in lieu of
__init__
arguments except to convert them totf.Tensor
s (if they aren't already). For reasons made clear below, we call this "non manipulation" idea: "tf.GradientTape
safety".
The above contract ensures several desirable features of
Distribution
s:-
Evaluations of mutable arguments (including assertions) are re-run any time the underlying values could possibly change. Example:
loc = tf.constant(0.) scale = tf.Variable(1.) d = tfp.distributions.Normal(loc, scale, validate_args=True) d.log_prob(0.) # ==> -0.918938 d.scale.assign(-1.) d.log_prob(0.) # ==> InvalidArgumentError: Argument `scale` must be positive.
-
Gradients of public
Distribution
methods with respect to__init__
arguments are valid regardless of theDistribution
being created inside or outside thetf.GradientTape
. Example:loc = tf.constant(0.) scale = tf.Variable(1.) d = tfp.distributions.Normal(loc, scale, validate_args=True) with tf.GradientTape() as tape: tape.watch(loc) # `tape.watch(scale)` is not required since `tf.GradientTape` # automatically watches `tf.Variable` dependencies (by default). x = -d.log_prob(1.) grad = tape.gradient(x, [loc, scale, d.loc, d.scale]) assert all([g is not None for g in grad])
Note that both of these properties would be lost if
__init__
memoized any derived computation in lieu of the originalTensor
-convertible arguments. -
-
Distribution
and subclasses'@property
methods shall never execute TF ops. For example, in graph execution regime this implies calling@property
will never mutate the graph.
In this section we list items which have historically been presumed true of
tfp.distributions
but are not official requirements.
-
Mutable state is not explicitly disallowed. However, it is highly discouraged as it makes reasoning about the object more challenging (for both API owner and user). As of 11-nov-2019, no
tfp.distributions
member has its own mutable state although all distributions do mutate the global random number generate state on access tosample
. -
Subclasses are free to override public base class members. I.e., you don't have to follow the "public calls private" pattern. (However, as of 11-nov-2019, there has not yet been a reason to deviate from this pattern.)