Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cast x to logits.dtype in onehot_categorical.
Otherwise, dist.log_prob(dist.sample()) has an error in reduce_logsumexp inside _assert_valid_sample when validate_args=True (TypeError: Value passed to parameter 'x' has DataType int32 not in list of allowed values: bfloat16, float16, float32, float64). PiperOrigin-RevId: 239992007
- Loading branch information