Source code for zfit.models.dist_tfp

"""A rich selection of analytically implemented Distributions (models) are available in `TensorFlow Probability.

<https://github.com/tensorflow/probability>`_. While their API is slightly different from the zfit models, it is similar
enough to be easily wrapped.

Therefore a convenient wrapper as well as a lot of implementations are provided.
"""
#  Copyright (c) 2022 zfit

from collections import OrderedDict

import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_probability.python.distributions as tfd

from zfit import z
from zfit.util.exception import (
    AnalyticSamplingNotImplemented,
)
from ..core.basepdf import BasePDF
from ..core.parameter import convert_to_parameter
from ..core.space import Space, supports
from ..settings import ztypes
from ..util import ztyping


# TODO: improve? while loop over `.sample`? Maybe as a fallback if not implemented?


def tfd_analytic_sample(n: int, dist: tfd.Distribution, limits: ztyping.ObsTypeInput):
    """Sample analytically with a `tfd.Distribution` within the limits. No preprocessing.

    Uniform numbers are drawn between the CDF values of the lower and the upper limit and are mapped back
    through the quantile function of `dist`.

    Args:
        n: Number of samples to get
        dist: Distribution to sample from
        limits: Limits to sample from within

    Returns:
        The sampled data with the number of samples and the number of observables.

    Raises:
        AnalyticSamplingNotImplemented: If `dist.quantile` raises a `NotImplementedError`.
    """
    lower_bound, upper_bound = limits.rect_limits
    lower_prob_lim = dist.cdf(lower_bound)
    upper_prob_lim = dist.cdf(upper_bound)

    shape = (n, 1)
    prob_sample = z.random.uniform(
        shape=shape, minval=lower_prob_lim, maxval=upper_prob_lim
    )
    prob_sample.set_shape((None, 1))
    try:
        sample = dist.quantile(prob_sample)
    except NotImplementedError as error:
        # Chain explicitly so the traceback shows the missing quantile as the direct cause.
        raise AnalyticSamplingNotImplemented from error
    sample.set_shape((None, limits.n_obs))
    return sample


class WrapDistribution(BasePDF):  # TODO: extend functionality of wrapper, like icdf
    """Baseclass to wrap tensorflow-probability distributions automatically.

    The wrapped distribution is not stored as an instance: it is created from `distribution` with the current
    `dist_params` and `dist_kwargs` every time the `distribution` property is accessed.
    """

    def __init__(
        self,
        distribution,
        dist_params,
        obs,
        params=None,
        dist_kwargs=None,
        dtype=ztypes.float,
        name=None,
        **kwargs
    ):
        """Store the distribution factory together with its arguments and initialize the PDF.

        Args:
            distribution: Callable that is invoked with the distribution parameters, the distribution keyword
                arguments and a `name` to create the actual distribution.
            dist_params: Dict with the arguments for `distribution` or a callable returning such a dict.
                `None` is treated as an empty dict.
            obs: Observables, passed on to the base class.
            params: Mapping from name to parameter; every value is converted with `convert_to_parameter`.
                If `None`, the items of `dist_params` are used instead, which therefore has to be a dict
                in that case.
            dist_kwargs: Dict with additional keyword arguments for `distribution` or a callable returning
                such a dict. `None` is treated as an empty dict.
            dtype: Passed on to the base class.
            name: Name of the PDF. Falls back to `distribution.name` if not given.
            **kwargs: Passed on to the base class.
        """
        # Check if subclass of distribution?

        if dist_kwargs is None:
            dist_kwargs = {}

        if dist_params is None:
            dist_params = {}
        # NOTE(review): if `distribution` is a class, `distribution.name` may not be a plain string -- confirm.
        name = name or distribution.name
        if params is None:
            params = OrderedDict(dist_params.items())
        else:
            params = OrderedDict(
                (k, convert_to_parameter(p)) for k, p in params.items()
            )

        super().__init__(obs=obs, dtype=dtype, name=name, params=params, **kwargs)

        self._distribution = distribution
        self.dist_params = dist_params
        self.dist_kwargs = dist_kwargs
        self._inverse_analytic_integral = []

    @property
    def distribution(self):
        """Newly created distribution using the current parameters and keyword arguments.

        `dist_params` and `dist_kwargs` are called first if they are callables.
        """
        params = self.dist_params
        if callable(params):
            params = params()
        kwargs = self.dist_kwargs
        if callable(kwargs):
            kwargs = kwargs()
        return self._distribution(**params, **kwargs, name=self.name + "_tfp")

    def _unnormalized_pdf(self, x: "zfit.Data"):
        """Evaluate `prob` of the wrapped distribution at the unstacked `x`."""
        value = z.unstack_x(x)  # TODO: use this? change shaping below?
        return self.distribution.prob(value=value, name="unnormalized_pdf")

    # TODO: register integral?
    @supports()
    def _analytic_integrate(self, limits, norm):
        """Integrate over `limits` using the CDF of the wrapped distribution.

        Args:
            limits: Limits to integrate over; the lower and upper bounds are read from `_rect_limits_tf`.
            norm: Not used in this implementation.

        Returns:
            The CDF at the upper limit minus the CDF at the lower limit.
        """
        lower, upper = limits._rect_limits_tf
        lower = z.unstack_x(lower)
        upper = z.unstack_x(upper)
        tf.debugging.assert_all_finite(
            (lower, upper), "Are infinite limits needed? Causes troubles with NaNs"
        )
        # Every access of `self.distribution` creates a new distribution, so build it only once.
        dist = self.distribution
        return dist.cdf(upper) - dist.cdf(lower)

    def _analytic_sample(self, n, limits: Space):
        """Draw `n` samples within `limits` by delegating to `tfd_analytic_sample`."""
        return tfd_analytic_sample(n=n, dist=self.distribution, limits=limits)
# class KernelDensityTFP(WrapDistribution): # # def __init__(self, loc: ztyping.ParamTypeInput, scale: ztyping.ParamTypeInput, obs: ztyping.ObsTypeInput, # kernel: tfp.distributions.Distribution = tfp.distributions.Normal, # weights: Union[None, np.ndarray, tf.Tensor] = None, name: str = "KernelDensity"): # """Kernel Density Estimation of loc and either a broadcasted or a per-loc scale with a Distribution as kernel. # # Args: # loc: 1-D Tensor-like. The positions of the `kernel`. Determines how many kernels will be created. # scale: Broadcastable to the batch and event shape of the distribution. A scalar will simply broadcast # to `loc` for a 1-D distribution. # obs: Observables # kernel: Distribution that is used as kernel # weights: Weights of each `loc`, can be None or Tensor-like with shape compatible with loc # name: Name of the PDF # """ # if not isinstance(kernel, # tfp.distributions.Distribution) and False: # HACK remove False, why does test not work? # raise TypeError("Currently, only tfp distributions are supported as kernels. 
Please open an issue if this " # "is too restrictive.") # # if isinstance(loc, ZfitData): # if loc.weights is not None: # if weights is not None: # raise OverdefinedError("Cannot specify weights and use a `ZfitData` with weights.") # else: # weights = loc.weights # # if weights is None: # weights = tf.ones_like(loc, dtype=tf.float64) # self._weights_loc = weights # self._weights_sum = z.reduce_sum(weights) # self._latent_loc = loc # params = {"scale": scale} # dist_params = {"loc": loc, "scale": scale} # super().__init__(distribution=kernel, dist_params=dist_params, obs=obs, params=params, dtype=ztypes.float, # name=name) # # def _unnormalized_pdf(self, x: "zfit.Data", norm_range=False): # value = znp.expand_dims(x.value(), -2) # new_shape = znp.concatenate([tf.shape(value)[:2], [tf.shape(self._latent_loc)[0], 4]], axis=0) # value = tf.broadcast_to(value, new_shape) # probs = self.distribution.prob(value=value, name="unnormalized_pdf") # # weights = znp.expand_dims(self._weights_loc, axis=-1) # weights = self._weights_loc # probs = z.reduce_sum(probs * weights, axis=-1) / self._weights_sum # return probs # # @supports() # def _analytic_integrate(self, limits, norm_range): # lower, upper = limits.limits # if np.all(-np.array(lower) == np.array(upper)) and np.all(np.array(upper) == np.infty): # return z.reduce_sum(self._weights_loc) # tfp distributions are normalized to 1 # lower = z.to_real(lower[0], dtype=self.dtype) # # lower = tf.broadcast_to(lower, shape=(tf.shape(self._latent_loc)[0], limits.n_obs,)) # remove # upper = z.to_real(upper[0], dtype=self.dtype) # integral = self.distribution.cdf(upper) - self.distribution.cdf(lower) # integral = z.reduce_sum(integral * self._weights_loc, axis=-1) / self._weights_sum # return integral # TODO: generalize for VectorSpaces
class Gauss(WrapDistribution):
    _N_OBS = 1

    def __init__(
        self,
        mu: ztyping.ParamTypeInput,
        sigma: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        name: str = "Gauss",
    ):
        r"""Normal (Gaussian) distribution described by its mean (mu) and its standard deviation (sigma).

        The shape of the gaussian is

        .. math::
            f(x \mid \mu, \sigma^2) = e^{ -\frac{(x - \mu)^{2}}{2\sigma^2} }

        and the normalization over [-inf, inf] is

        .. math::
            \frac{1}{\sqrt{2\pi\sigma^2} }

        For other normalization ranges, the normalization differs.

        Args:
            mu: Mean of the gaussian dist
            sigma: Standard deviation or spread of the gaussian
            obs: Observables and normalization range the pdf is defined in
            name: Name of the pdf
        """
        mu, sigma = self._check_input_params(mu, sigma)

        def tfp_arguments():
            # Called by the base class whenever the tfp distribution is built.
            return {"loc": mu.value(), "scale": sigma.value()}

        super().__init__(
            distribution=tfp.distributions.Normal,
            dist_params=tfp_arguments,
            obs=obs,
            params=OrderedDict([("mu", mu), ("sigma", sigma)]),
            name=name,
        )
class ExponentialTFP(WrapDistribution):
    _N_OBS = 1

    def __init__(
        self,
        tau: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        name: str = "Exponential",
    ):
        """Exponential distribution wrapping `tfp.distributions.Exponential`.

        Args:
            tau: Forwarded to the tfp distribution as its `rate` argument.
            obs: Observables and normalization range the pdf is defined in
            name: Name of the pdf
        """
        (tau,) = self._check_input_params(tau)
        params = OrderedDict((("tau", tau),))
        # Lazy callable returning the parameter value, consistent with the other wrapped distributions
        # in this module.
        dist_params = lambda: dict(rate=tau.value())
        distribution = tfp.distributions.Exponential
        super().__init__(
            distribution=distribution,
            dist_params=dist_params,
            obs=obs,
            params=params,
            name=name,
        )
class Uniform(WrapDistribution):
    _N_OBS = 1

    def __init__(
        self,
        low: ztyping.ParamTypeInput,
        high: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        name: str = "Uniform",
    ):
        """Flat distribution that is constant from `low` to `high` and vanishes everywhere else.

        Args:
            low: Below this value, the pdf is zero.
            high: Above this value, the pdf is zero.
            obs: Observables and normalization range the pdf is defined in
            name: Name of the pdf
        """
        low, high = self._check_input_params(low, high)

        def tfp_arguments():
            # Called by the base class whenever the tfp distribution is built.
            return {"low": low.value(), "high": high.value()}

        super().__init__(
            distribution=tfp.distributions.Uniform,
            dist_params=tfp_arguments,
            obs=obs,
            params=OrderedDict([("low", low), ("high", high)]),
            name=name,
        )
class TruncatedGauss(WrapDistribution):
    _N_OBS = 1

    def __init__(
        self,
        mu: ztyping.ParamTypeInput,
        sigma: ztyping.ParamTypeInput,
        low: ztyping.ParamTypeInput,
        high: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        name: str = "TruncatedGauss",
    ):
        """Gaussian distribution that vanishes outside of `low`, `high`.

        Equivalent to the product of Gauss and Uniform.

        Args:
            mu: Mean of the gaussian dist
            sigma: Standard deviation or spread of the gaussian
            low: Below this value, the pdf is zero.
            high: Above this value, the pdf is zero.
            obs: Observables and normalization range the pdf is defined in
            name: Name of the pdf
        """
        mu, sigma, low, high = self._check_input_params(mu, sigma, low, high)

        def tfp_arguments():
            # Called by the base class whenever the tfp distribution is built.
            return {
                "loc": mu.value(),
                "scale": sigma.value(),
                "low": low.value(),
                "high": high.value(),
            }

        named_params = OrderedDict(
            [("mu", mu), ("sigma", sigma), ("low", low), ("high", high)]
        )
        super().__init__(
            distribution=tfp.distributions.TruncatedNormal,
            dist_params=tfp_arguments,
            obs=obs,
            params=named_params,
            name=name,
        )
class Cauchy(WrapDistribution):
    _N_OBS = 1

    def __init__(
        self,
        m: ztyping.ParamTypeInput,
        gamma: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        name: str = "Cauchy",
    ):
        r"""Cauchy PDF, i.e. the non-relativistic Breit-Wigner shape of a decaying particle's energy distribution.

        Without normalization, the non-relativistic Breit-Wigner reads

        .. math::
            \frac{1}{\gamma \left[1 + \left(\frac{x - m}{\gamma}\right)^2\right]}

        where :math:`m` is the mean and :math:`\gamma` is the width of the distribution.

        Args:
            m: Invariant mass of the unstable particle.
            gamma: Width of the shape.
            obs: Observables and normalization range the pdf is defined in
            name: Name of the PDF
        """
        m, gamma = self._check_input_params(m, gamma)

        def tfp_arguments():
            # Called by the base class whenever the tfp distribution is built.
            return {"loc": m.value(), "scale": gamma.value()}

        super().__init__(
            distribution=tfp.distributions.Cauchy,
            dist_params=tfp_arguments,
            obs=obs,
            params=OrderedDict([("m", m), ("gamma", gamma)]),
            name=name,
        )
class Poisson(WrapDistribution):
    _N_OBS = 1

    def __init__(
        self,
        lamb: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        name: str = "Poisson",
    ):
        r"""Poisson distribution with the event rate (lamb) as its single parameter.

        Its probability mass function is

        .. math::
            f(x, \lambda) = \frac{\lambda^{x}e^{-\lambda}}{x!}

        Args:
            lamb: the event rate
            obs: Observables and normalization range the pdf is defined in
            name: Name of the PDF
        """
        (lamb,) = self._check_input_params(lamb)

        def tfp_arguments():
            # Called by the base class whenever the tfp distribution is built.
            return {"rate": lamb.value()}

        super().__init__(
            distribution=tfp.distributions.Poisson,
            dist_params=tfp_arguments,
            obs=obs,
            params=OrderedDict([("lamb", lamb)]),
            name=name,
        )