"""A rich selection of analytically implemented distributions (models) is available in `TensorFlow Probability
<https://github.com/tensorflow/probability>`_. While their API is slightly different from the zfit models, it is similar
enough to be easily wrapped.
Therefore, a convenient wrapper as well as many implementations are provided.
"""
# Copyright (c) 2023 zfit
from __future__ import annotations
from collections import OrderedDict
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_probability.python.distributions as tfd
from pydantic import Field
from typing import Literal
from zfit import z
from zfit.util.exception import (
AnalyticSamplingNotImplemented,
)
from ..core.basepdf import BasePDF
from ..core.interfaces import ZfitData
from ..core.parameter import convert_to_parameter
from ..core.serialmixin import SerializableMixin
from ..core.space import Space, supports
from ..serialization import SpaceRepr, Serializer
from ..serialization.pdfrepr import BasePDFRepr
from ..settings import ztypes
from ..util import ztyping
from ..util.deprecation import deprecated_args
from ..util.ztyping import ExtendedInputType, NormInputType
# TODO: improve? while loop over `.sample`? Maybe as a fallback if not implemented?
def tfd_analytic_sample(n: int, dist: tfd.Distribution, limits: Space):
    """Sample analytically with a `tfd.Distribution` within the limits. No preprocessing.

    Uses inverse transform sampling. Uniform probabilities are drawn between the cdf values
    of the lower and the upper limit. They are then mapped back to the observable through
    the quantile (inverse cdf) of ``dist``.

    Args:
        n: Number of samples to get
        dist: Distribution to sample from; has to implement ``cdf`` and ``quantile``.
        limits: Limits to sample from within; its ``rect_limits`` and ``n_obs`` are used.

    Returns:
        The sampled data with shape ``(n, limits.n_obs)``.

    Raises:
        AnalyticSamplingNotImplemented: If ``dist`` does not implement ``quantile``.
    """
    lower_bound, upper_bound = limits.rect_limits
    lower_prob_lim = dist.cdf(lower_bound)
    upper_prob_lim = dist.cdf(upper_bound)

    # One uniform probability per requested sample, restricted to the cdf range of the limits.
    prob_sample = z.random.uniform(
        shape=(n, 1), minval=lower_prob_lim, maxval=upper_prob_lim
    )
    # `n` may be dynamic; only the number of columns is known statically.
    prob_sample.set_shape((None, 1))
    try:
        sample = dist.quantile(prob_sample)
    except NotImplementedError as error:
        # Chain the original error so the failing tfp call stays visible in the traceback.
        raise AnalyticSamplingNotImplemented(
            f"Distribution {dist.name} does not implement `quantile`, "
            f"analytic sampling is not possible."
        ) from error
    sample.set_shape((None, limits.n_obs))
    return sample
[docs]
class WrapDistribution(BasePDF):  # TODO: extend functionality of wrapper, like icdf
    """Baseclass to wrap tensorflow-probability distributions automatically.

    The wrapped distribution is re-instantiated on every access of :py:attr:`distribution`.
    It is built from ``dist_params`` and ``dist_kwargs``. Each of these can be a mapping or a
    zero-argument callable returning a mapping. Callables are evaluated at access time, so
    they can read the current parameter values.
    """

    def __init__(
        self,
        distribution,
        dist_params,
        obs,
        params=None,
        dist_kwargs=None,
        dtype=ztypes.float,
        name=None,
        **kwargs,
    ):
        """Wrap ``distribution`` as a zfit PDF.

        Args:
            distribution: Callable, typically a ``tfp.distributions.Distribution`` subclass.
                It is called with the entries of ``dist_params`` and ``dist_kwargs`` plus a
                ``name`` keyword to build the distribution.
            dist_params: Keyword arguments for ``distribution``, either as a mapping or as a
                zero-argument callable returning a mapping. ``None`` means no arguments.
            obs: Observables of the PDF.
            params: Mapping of parameter names to parameter-likes. Each value is converted
                with ``convert_to_parameter``. If ``None``, the entries of ``dist_params``
                are used as they are, which requires ``dist_params`` to be a mapping.
            dist_kwargs: Additional keyword arguments for ``distribution``, either as a
                mapping or as a zero-argument callable. ``None`` means no arguments.
            dtype: dtype of the PDF.
            name: Name of the PDF. Defaults to the name of ``distribution``.
            **kwargs: Passed on to :py:class:`BasePDF`.

        Raises:
            ValueError: If ``params`` is ``None`` while ``dist_params`` is a callable.
        """
        # Check if subclass of distribution?
        if dist_kwargs is None:
            dist_kwargs = {}
        if dist_params is None:
            dist_params = {}
        if not name:
            # On a tfd.Distribution *class*, `name` is a property object rather than a
            # string, so fall back to the class (or callable) `__name__` in that case.
            name = getattr(distribution, "name", None)
            if not isinstance(name, str):
                name = getattr(distribution, "__name__", type(distribution).__name__)
        if params is None:
            if callable(dist_params):
                raise ValueError(
                    "If `dist_params` is a callable, `params` has to be given explicitly."
                )
            params = OrderedDict(dist_params.items())
        else:
            params = OrderedDict(
                (k, convert_to_parameter(p)) for k, p in params.items()
            )
        super().__init__(obs=obs, dtype=dtype, name=name, params=params, **kwargs)
        self._distribution = distribution
        self.dist_params = dist_params
        self.dist_kwargs = dist_kwargs
        self._inverse_analytic_integral = []

    @property
    def distribution(self):
        """Return a newly built instance of the wrapped distribution.

        Callable ``dist_params``/``dist_kwargs`` are evaluated here, on every access.
        """
        params = self.dist_params
        if callable(params):
            params = params()
        kwargs = self.dist_kwargs
        if callable(kwargs):
            kwargs = kwargs()
        return self._distribution(**params, **kwargs, name=self.name + "_tfp")

    def _unnormalized_pdf(self, x: ZfitData):
        """Evaluate ``prob`` of the wrapped distribution on the data ``x``."""
        value = z.unstack_x(x)  # TODO: use this? change shaping below?
        return self.distribution.prob(value=value, name="unnormalized_pdf")

    # TODO: register integral?
    @supports()
    def _analytic_integrate(self, limits, norm):
        """Integrate over the rectangular ``limits`` as ``cdf(upper) - cdf(lower)``.

        Only finite limits are supported; infinite ones fail the finiteness assertion.
        """
        lower, upper = limits._rect_limits_tf
        lower = z.unstack_x(lower)
        upper = z.unstack_x(upper)
        tf.debugging.assert_all_finite(
            (lower, upper), "Are infinite limits needed? Causes troubles with NaNs"
        )
        # Build the distribution once instead of once per cdf evaluation.
        dist = self.distribution
        return dist.cdf(upper) - dist.cdf(lower)

    def _analytic_sample(self, n, limits: Space):
        """Sample via inverse transform sampling of the wrapped distribution."""
        return tfd_analytic_sample(n=n, dist=self.distribution, limits=limits)
# class KernelDensityTFP(WrapDistribution):
#
# def __init__(self, loc: ztyping.ParamTypeInput, scale: ztyping.ParamTypeInput, obs: ztyping.ObsTypeInput,
# kernel: tfp.distributions.Distribution = tfp.distributions.Normal,
# weights: Union[None, np.ndarray, tf.Tensor] = None, name: str = "KernelDensity"):
# """Kernel Density Estimation of loc and either a broadcasted or a per-loc scale with a Distribution as kernel.
#
# Args:
# loc: 1-D Tensor-like. The positions of the `kernel`. Determines how many kernels will be created.
# scale: Broadcastable to the batch and event shape of the distribution. A scalar will simply broadcast
# to `loc` for a 1-D distribution.
# obs: Observables
# kernel: Distribution that is used as kernel
# weights: Weights of each `loc`, can be None or Tensor-like with shape compatible with loc
# name: Name of the PDF
# """
# if not isinstance(kernel,
# tfp.distributions.Distribution) and False: # HACK remove False, why does test not work?
# raise TypeError("Currently, only tfp distributions are supported as kernels. Please open an issue if this "
# "is too restrictive.")
#
# if isinstance(loc, ZfitData):
# if loc.weights is not None:
# if weights is not None:
# raise OverdefinedError("Cannot specify weights and use a `ZfitData` with weights.")
# else:
# weights = loc.weights
#
# if weights is None:
# weights = tf.ones_like(loc, dtype=tf.float64)
# self._weights_loc = weights
# self._weights_sum = z.reduce_sum(weights)
# self._latent_loc = loc
# params = {"scale": scale}
# dist_params = {"loc": loc, "scale": scale}
# super().__init__(distribution=kernel, dist_params=dist_params, obs=obs, params=params, dtype=ztypes.float,
# name=name)
#
# def _unnormalized_pdf(self, x: "zfit.Data", norm_range=False):
# value = znp.expand_dims(x.value(), -2)
# new_shape = znp.concatenate([tf.shape(value)[:2], [tf.shape(self._latent_loc)[0], 4]], axis=0)
# value = tf.broadcast_to(value, new_shape)
# probs = self.distribution.prob(value=value, name="unnormalized_pdf")
# # weights = znp.expand_dims(self._weights_loc, axis=-1)
# weights = self._weights_loc
# probs = z.reduce_sum(probs * weights, axis=-1) / self._weights_sum
# return probs
#
# @supports()
# def _analytic_integrate(self, limits, norm_range):
# lower, upper = limits.limits
# if np.all(-np.array(lower) == np.array(upper)) and np.all(np.array(upper) == np.infty):
# return z.reduce_sum(self._weights_loc) # tfp distributions are normalized to 1
# lower = z.to_real(lower[0], dtype=self.dtype)
# # lower = tf.broadcast_to(lower, shape=(tf.shape(self._latent_loc)[0], limits.n_obs,)) # remove
# upper = z.to_real(upper[0], dtype=self.dtype)
# integral = self.distribution.cdf(upper) - self.distribution.cdf(lower)
# integral = z.reduce_sum(integral * self._weights_loc, axis=-1) / self._weights_sum
# return integral # TODO: generalize for VectorSpaces
[docs]
class Gauss(WrapDistribution, SerializableMixin):
    _N_OBS = 1

    def __init__(
        self,
        mu: ztyping.ParamTypeInput,
        sigma: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        extended: ExtendedInputType = None,
        norm: NormInputType = None,
        name: str = "Gauss",
    ):
        """Gaussian or Normal distribution with a mean (mu) and a standard deviation (sigma).

        The gaussian shape is defined as

        .. math::
            f(x \\mid \\mu, \\sigma^2) = e^{ -\\frac{(x - \\mu)^{2}}{2\\sigma^2} }

        with the normalization over [-inf, inf] of

        .. math::
            \\frac{1}{\\sqrt{2\\pi\\sigma^2} }

        The normalization changes for different normalization ranges.

        Args:
            mu: Mean of the gaussian dist
            sigma: Standard deviation or spread of the gaussian
            obs: Observables and normalization range the pdf is defined in
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|
            name: |@doc:model.init.name| Human-readable name
               or label of
               the PDF for better identification.
               Has no programmatical functional purpose as identification. |@docend:model.init.name|
        """
        mu, sigma = self._check_input_params(mu, sigma)
        params = OrderedDict((("mu", mu), ("sigma", sigma)))

        def dist_params():
            # Called on every access of ``self.distribution``, so the current
            # parameter values are used.
            return dict(loc=mu.value(), scale=sigma.value())

        distribution = tfp.distributions.Normal
        super().__init__(
            distribution=distribution,
            dist_params=dist_params,
            obs=obs,
            params=params,
            name=name,
            extended=extended,
            norm=norm,
        )
class GaussPDFRepr(BasePDFRepr):
    """Serialization representation of the :py:class:`Gauss` PDF."""

    # PDF class this representation belongs to (presumably used by BasePDFRepr to rebuild it).
    _implementation = Gauss
    # Type discriminator; pydantic (de)serializes this field under the key "type".
    hs3_type: Literal["Gauss"] = Field("Gauss", alias="type")
    x: SpaceRepr  # presumably the observable space of the PDF -- TODO confirm
    # NOTE(review): uses ParamInputTypeDiscriminated while the Cauchy/Poisson reprs use
    # ParamTypeDiscriminated -- confirm the difference is intended.
    mu: Serializer.types.ParamInputTypeDiscriminated
    sigma: Serializer.types.ParamInputTypeDiscriminated
class ExponentialTFP(WrapDistribution):
    """Exponential distribution wrapping ``tfp.distributions.Exponential``.

    ``tau`` is used as the *rate* of the distribution. The unnormalized shape is therefore
    :math:`\\tau e^{-\\tau x}`, normalized over the normalization range of the PDF.
    """

    _N_OBS = 1

    def __init__(
        self,
        tau: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        name: str = "Exponential",
        *,
        extended: ExtendedInputType = None,
        norm: NormInputType = None,
    ):
        """Create an exponential PDF with rate ``tau``.

        Args:
            tau: Rate of the exponential.
            obs: Observables and normalization range the pdf is defined in
            name: Name of the PDF
            extended: The overall yield of the PDF. If this is parameter-like, it will be used
                as the yield and the PDF will be extended.
            norm: Normalization of the PDF. By default, this is the same as the default space
                of the PDF.
        """
        (tau,) = self._check_input_params(tau)
        params = OrderedDict((("tau", tau),))

        def dist_params():
            # Evaluated lazily on every access of ``self.distribution``, like in the
            # other wrapped distributions, so the current value of ``tau`` is used.
            return dict(rate=tau.value())

        distribution = tfp.distributions.Exponential
        super().__init__(
            distribution=distribution,
            dist_params=dist_params,
            obs=obs,
            params=params,
            name=name,
            extended=extended,
            norm=norm,
        )
[docs]
class TruncatedGauss(WrapDistribution):
    _N_OBS = 1

    def __init__(
        self,
        mu: ztyping.ParamTypeInput,
        sigma: ztyping.ParamTypeInput,
        low: ztyping.ParamTypeInput,
        high: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        *,
        extended: ExtendedInputType = None,
        norm: NormInputType = None,
        name: str = "TruncatedGauss",
    ):
        """Gaussian distribution that vanishes outside the interval [``low``, ``high``].

        Equivalent to the product of a Gauss and a Uniform on [``low``, ``high``].

        Args:
            mu: Mean of the gaussian dist
            sigma: Standard deviation or spread of the gaussian
            low: Below this value, the pdf is zero.
            high: Above this value, the pdf is zero.
            obs: Observables and normalization range the pdf is defined in
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|
            name: |@doc:model.init.name| Human-readable name
               or label of
               the PDF for better identification.
               Has no programmatical functional purpose as identification. |@docend:model.init.name|
        """
        mu_param, sigma_param, low_param, high_param = self._check_input_params(
            mu, sigma, low, high
        )

        def current_tfp_kwargs():
            # Read the parameter values at distribution-construction time.
            return {
                "loc": mu_param.value(),
                "scale": sigma_param.value(),
                "low": low_param.value(),
                "high": high_param.value(),
            }

        super().__init__(
            distribution=tfp.distributions.TruncatedNormal,
            dist_params=current_tfp_kwargs,
            obs=obs,
            params=OrderedDict(
                [
                    ("mu", mu_param),
                    ("sigma", sigma_param),
                    ("low", low_param),
                    ("high", high_param),
                ]
            ),
            name=name,
            extended=extended,
            norm=norm,
        )
[docs]
class Cauchy(WrapDistribution, SerializableMixin):
    _N_OBS = 1

    def __init__(
        self,
        m: ztyping.ParamTypeInput,
        gamma: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        *,
        extended: ExtendedInputType = None,
        norm: NormInputType = None,
        name: str = "Cauchy",
    ):
        r"""Non-relativistic Breit-Wigner (Cauchy) PDF, describing the energy distribution of a decaying particle.

        Its (unnormalized) shape is

        .. math::
            \frac{1}{\gamma \left[1 + \left(\frac{x - m}{\gamma}\right)^2\right]}

        where :math:`m` is the location (mean) and :math:`\gamma` the width.

        Args:
            m: Invariant mass of the unstable particle.
            gamma: Width of the shape.
            obs: Observables and normalization range the pdf is defined in
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|
            name: |@doc:model.init.name| Human-readable name
               or label of
               the PDF for better identification.
               Has no programmatical functional purpose as identification. |@docend:model.init.name|
        """
        m_param, gamma_param = self._check_input_params(m, gamma)

        def current_tfp_kwargs():
            # tfp's Cauchy is parametrized by location and scale.
            return {"loc": m_param.value(), "scale": gamma_param.value()}

        super().__init__(
            distribution=tfp.distributions.Cauchy,
            dist_params=current_tfp_kwargs,
            obs=obs,
            params=OrderedDict([("m", m_param), ("gamma", gamma_param)]),
            name=name,
            extended=extended,
            norm=norm,
        )
class CauchyPDFRepr(BasePDFRepr):
    """Serialization representation of the :py:class:`Cauchy` PDF."""

    # PDF class this representation belongs to (presumably used by BasePDFRepr to rebuild it).
    _implementation = Cauchy
    # Type discriminator; pydantic (de)serializes this field under the key "type".
    hs3_type: Literal["Cauchy"] = Field("Cauchy", alias="type")
    x: SpaceRepr  # presumably the observable space of the PDF -- TODO confirm
    m: Serializer.types.ParamTypeDiscriminated
    gamma: Serializer.types.ParamTypeDiscriminated
[docs]
class Poisson(WrapDistribution, SerializableMixin):
    _N_OBS = 1

    @deprecated_args(None, "Use lam instead", "lamb")
    def __init__(
        self,
        lam: ztyping.ParamTypeInput = None,
        obs: ztyping.ObsTypeInput = None,
        *,
        extended: ExtendedInputType = None,
        norm: NormInputType = None,
        name: str = "Poisson",
        lamb=None,
    ):
        """Poisson distribution, parametrized with an event rate parameter (lam).

        The probability mass function of the Poisson distribution is given by

        .. math::
            f(x, \\lambda) = \\frac{\\lambda^{x}e^{-\\lambda}}{x!}

        Args:
            lam: the event rate
            obs: Observables and normalization range the pdf is defined in
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|
            name: Name of the PDF
            lamb: Deprecated alias of ``lam``; use ``lam`` instead.

        Raises:
            ValueError: If both ``lam`` and the deprecated ``lamb`` are given.
            TypeError: If neither ``lam`` nor ``lamb`` is given.
        """
        if lamb is not None:
            # Previously the deprecated alias silently overrode `lam`; refuse ambiguous input.
            if lam is not None:
                raise ValueError(
                    "Both `lam` and the deprecated `lamb` were given, use only `lam`."
                )
            lam = lamb
        del lamb
        if lam is None:
            raise TypeError("Poisson requires the event rate `lam` to be given.")
        (lam,) = self._check_input_params(lam)
        params = OrderedDict((("lam", lam),))

        def dist_params():
            # Called on every access of ``self.distribution`` to use the current rate.
            return dict(rate=lam.value())

        distribution = tfp.distributions.Poisson
        super().__init__(
            distribution=distribution,
            dist_params=dist_params,
            obs=obs,
            params=params,
            name=name,
            extended=extended,
            norm=norm,
        )
class PoissonPDFRepr(BasePDFRepr):
    """Serialization representation of the :py:class:`Poisson` PDF."""

    # PDF class this representation belongs to (presumably used by BasePDFRepr to rebuild it).
    _implementation = Poisson
    # Type discriminator; pydantic (de)serializes this field under the key "type".
    hs3_type: Literal["Poisson"] = Field("Poisson", alias="type")
    x: SpaceRepr  # presumably the observable space of the PDF -- TODO confirm
    lam: Serializer.types.ParamTypeDiscriminated