# Source code for zfit.models.physics

#  Copyright (c) 2024 zfit

from __future__ import annotations

import numpy as np
import pydantic
import tensorflow as tf

from typing import Literal

import zfit.z.numpy as znp
from zfit import z
from ..core.basepdf import BasePDF
from ..core.serialmixin import SerializableMixin
from ..core.space import ANY_LOWER, ANY_UPPER, Space
from ..serialization import SpaceRepr, Serializer
from ..serialization.pdfrepr import BasePDFRepr
from ..util import ztyping
from ..util.ztyping import ExtendedInputType, NormInputType


def _powerlaw(x, a, k):
    """Return ``a`` multiplied by ``x`` raised to the power ``k``."""
    raised = znp.power(x, k)
    return a * raised


@z.function(wraps="tensor", keepalive=True)
def crystalball_func(x, mu, sigma, alpha, n):
    """Evaluate the unnormalized Crystal Ball shape at ``x``.

    The shape is a Gaussian core ``exp(-t**2 / 2)`` joined to a power-law tail
    ``a * (b - t)**(-n)``, where ``t`` is the standardized coordinate multiplied
    by the sign of ``alpha``. The tail is used where ``t < -|alpha|``.

    Args:
        x: Points at which to evaluate the shape.
        mu: Location of the Gaussian core.
        sigma: Width of the Gaussian core.
        alpha: Switching point between core and tail; its sign selects the
            side on which the tail sits.
        n: Exponent of the power-law tail.

    Returns:
        The shape values, clipped from below at zero.
    """
    abs_alpha = znp.abs(alpha)
    n_over_alpha = n / abs_alpha

    # constants of the power-law tail
    a = znp.power(n_over_alpha, n) * znp.exp(-0.5 * znp.square(alpha))
    b = n_over_alpha - abs_alpha

    # standardized coordinate, mirrored through the sign of alpha
    t = (x - mu) / sigma * tf.sign(alpha)

    def tail(t):
        return _powerlaw(b - t, a, -n)

    def core(t):
        return znp.exp(-0.5 * znp.square(t))

    def tail_safe_input(t):
        # constant stand-in values (b - 2) with the shape of t
        # NOTE(review): presumably fed to the tail branch where it is not selected - confirm in z.safe_where
        return znp.ones_like(t) * (b - 2)

    in_tail = tf.less(t, -abs_alpha)
    shape = z.safe_where(in_tail, tail, core, values=t, value_safer=tail_safe_input)
    return znp.maximum(shape, znp.zeros_like(shape))


@z.function(wraps="tensor", keepalive=True, stateless_args=False)
def double_crystalball_func(x, mu, sigma, alphal, nl, alphar, nr):
    """Evaluate the unnormalized double-sided Crystal Ball shape at ``x``.

    Points with ``x < mu`` take the value of ``crystalball_func`` evaluated with
    ``(alphal, nl)``; all other points take the value of ``crystalball_func``
    evaluated with ``(-alphar, nr)``.

    Args:
        x: Points at which to evaluate the shape.
        mu: Location shared by both halves.
        sigma: Width shared by both halves.
        alphal: Switching parameter used for ``x < mu``.
        nl: Tail exponent used for ``x < mu``.
        alphar: Switching parameter used otherwise (passed on negated).
        nr: Tail exponent used otherwise.

    Returns:
        The shape values selected per point from the two halves.
    """
    left_half = crystalball_func(x, mu, sigma, alphal, nl)
    right_half = crystalball_func(x, mu, sigma, -alphar, nr)
    is_left = tf.less(x, mu)
    return tf.where(is_left, left_half, right_half)


# created with the help of TensorFlow autograph used on python code converted from ShapeCB of RooFit
def crystalball_integral(limits, params, model):
    """Integrate the Crystal Ball shape over ``limits``.

    Args:
        limits: Object providing the integration bounds through its
            ``_rect_limits_tf`` attribute as a ``(lower, upper)`` pair.
        params: Mapping containing the entries ``"mu"``, ``"sigma"``,
            ``"alpha"`` and ``"n"``.
        model: Not used by this function.

    Returns:
        The result of ``crystalball_integral_func`` for the given parameters
        and bounds.
    """
    mu, sigma, alpha, n = (params[key] for key in ("mu", "sigma", "alpha", "n"))
    lower, upper = limits._rect_limits_tf
    return crystalball_integral_func(mu, sigma, alpha, n, lower, upper)


@z.function(wraps="tensor", keepalive=True)
# @tf.function  # BUG? TODO: problem with tf.function and input signature
def crystalball_integral_func(mu, sigma, alpha, n, lower, upper):
    """Integrate the unnormalized Crystal Ball shape from ``lower`` to ``upper``.

    All candidate results are computed unconditionally and the applicable one is
    picked with ``tf.where``. In standardized coordinates (``tmin``, ``tmax``),
    mirrored when ``alpha`` is negative, the cases are:

    * ``tmin >= -|alpha|``: only the Gaussian core contributes.
    * ``tmax <= -|alpha|``: only the power-law tail contributes.
    * otherwise: tail from ``tmin`` to ``-|alpha|`` plus core from ``-|alpha|``
      to ``tmax``.

    When ``n`` is within ``1e-05`` of 1 the logarithmic form of the tail
    integral is used instead of the power form, which divides by ``1 - n``.

    Args:
        mu: Location of the Gaussian core.
        sigma: Width of the Gaussian core; only its absolute value is used.
        alpha: Switching point between core and tail; a negative value mirrors
            the integration bounds.
        n: Exponent of the power-law tail.
        lower: Lower integration bound.
        upper: Upper integration bound.

    Returns:
        The integral. If the result has a non-zero rank, index 0 along the last
        axis is taken.
    """
    sqrt_pi_over_two = np.sqrt(np.pi / 2)
    sqrt2 = np.sqrt(2)

    # near n == 1 the power-form antiderivative divides by (1 - n); switch to the log form there
    use_log = tf.less(znp.abs(n - 1.0), 1e-05)
    abs_sigma = znp.abs(sigma)
    abs_alpha = znp.abs(alpha)
    # integration bounds in units of sigma relative to mu
    tmin = (lower - mu) / abs_sigma
    tmax = (upper - mu) / abs_sigma

    alpha_negative = tf.less(alpha, 0)
    # do not move on two lines, logic will fail...
    # (both right-hand sides must be built from the old tmin/tmax before either name is rebound)
    tmax, tmin = znp.where(alpha_negative, -tmin, tmax), znp.where(
        alpha_negative, -tmax, tmin
    )

    # case: both bounds inside the Gaussian core
    if_true_4 = (
        abs_sigma
        * sqrt_pi_over_two
        * (tf.math.erf(tmax / sqrt2) - tf.math.erf(tmin / sqrt2))
    )

    # constants of the power-law tail a * (b - t) ** (-n)
    a = znp.power(n / abs_alpha, n) * znp.exp(-0.5 * tf.square(abs_alpha))
    b = n / abs_alpha - abs_alpha

    # gradients from tf.where can be NaN if the non-selected branch is NaN
    # https://github.com/tensorflow/tensorflow/issues/42889
    # solution is to provide safe values for the non-selected branch to never make them become NaNs
    b_tmin = b - tmin
    safe_b_tmin_ones = znp.where(b_tmin > 0, b_tmin, znp.ones_like(b_tmin))
    b_tmax = b - tmax
    safe_b_tmax_ones = znp.where(b_tmax > 0, b_tmax, znp.ones_like(b_tmax))

    # case: both bounds inside the tail, log form (n close to 1)
    if_true_1 = a * abs_sigma * (znp.log(safe_b_tmin_ones) - znp.log(safe_b_tmax_ones))

    # case: both bounds inside the tail, power form
    if_false_1 = (
        a
        * abs_sigma
        / (1.0 - n)
        * (
            1.0 / znp.power(safe_b_tmin_ones, n - 1.0)
            - 1.0 / znp.power(safe_b_tmax_ones, n - 1.0)
        )
    )

    if_true_3 = tf.where(use_log, if_true_1, if_false_1)

    # case: bounds straddle the switching point; tail part from tmin to -abs_alpha,
    # where b - (-abs_alpha) equals n / abs_alpha
    if_true_2 = a * abs_sigma * (znp.log(safe_b_tmin_ones) - znp.log(n / abs_alpha))
    if_false_2 = (
        a
        * abs_sigma
        / (1.0 - n)
        * (
            1.0 / znp.power(safe_b_tmin_ones, n - 1.0)
            - 1.0 / znp.power(n / abs_alpha, n - 1.0)
        )
    )
    term1 = tf.where(use_log, if_true_2, if_false_2)
    # Gaussian part from -abs_alpha to tmax
    term2 = (
        abs_sigma
        * sqrt_pi_over_two
        * (tf.math.erf(tmax / sqrt2) - tf.math.erf(-abs_alpha / sqrt2))
    )
    if_false_3 = term1 + term2

    # tail-only result if the upper bound is still in the tail, otherwise the straddling result
    if_false_4 = tf.where(tf.less_equal(tmax, -abs_alpha), if_true_3, if_false_3)

    # core-only result if the lower bound is already in the core, otherwise one of the tail cases
    result = tf.where(tf.greater_equal(tmin, -abs_alpha), if_true_4, if_false_4)
    if not result.shape.rank == 0:
        # NOTE(review): assumes the last axis has length 1 (inherited from the limits) - confirm
        result = tf.gather(result, 0, axis=-1)  # remove last dim, should vanish
    return result


def double_crystalball_mu_integral(limits, params, model):
    """Integrate the double-sided Crystal Ball shape over ``limits``.

    Args:
        limits: Object providing the integration bounds through its
            ``_rect_limits_tf`` attribute as a ``(lower, upper)`` pair; column 0
            of each bound is used.
        params: Mapping containing the entries ``"mu"``, ``"sigma"``,
            ``"alphal"``, ``"nl"``, ``"alphar"`` and ``"nr"``.
        model: Not used by this function.

    Returns:
        The result of ``double_crystalball_mu_integral_func`` for the given
        parameters and bounds.
    """
    param_names = ("mu", "sigma", "alphal", "nl", "alphar", "nr")
    shape_params = {key: params[key] for key in param_names}

    lower, upper = limits._rect_limits_tf
    first_lower = lower[:, 0]
    first_upper = upper[:, 0]

    return double_crystalball_mu_integral_func(
        lower=first_lower, upper=first_upper, **shape_params
    )


@z.function(wraps="tensor", keepalive=True)
def double_crystalball_mu_integral_func(
    mu, sigma, alphal, nl, alphar, nr, lower, upper
):
    """Integrate the double-sided Crystal Ball shape from ``lower`` to ``upper``.

    The range is split at ``mu``. The part below ``mu`` is integrated with
    ``(alphal, nl)``, the part above with ``(-alphar, nr)``, both through
    ``crystalball_integral_func``. A part is replaced by zeros when ``mu`` lies
    outside the range on the corresponding side.

    Args:
        mu: Location shared by both halves; also the split point.
        sigma: Width shared by both halves.
        alphal: Switching parameter of the half below ``mu``.
        nl: Tail exponent of the half below ``mu``.
        alphar: Switching parameter of the half above ``mu`` (passed on negated).
        nr: Tail exponent of the half above ``mu``.
        lower: Lower integration bound.
        upper: Upper integration bound.

    Returns:
        The sum of the two partial integrals.
    """
    # part below mu: integrate from lower up to min(mu, upper)
    left_upper = znp.minimum(mu, upper)
    left_raw = crystalball_integral_func(
        mu=mu, sigma=sigma, alpha=alphal, n=nl, lower=lower, upper=left_upper
    )
    mu_below_range = tf.less(mu, lower)
    left = tf.where(mu_below_range, znp.zeros_like(left_raw), left_raw)

    # part above mu: integrate from max(mu, lower) up to upper
    right_lower = znp.maximum(mu, lower)
    right_raw = crystalball_integral_func(
        mu=mu, sigma=sigma, alpha=-alphar, n=nr, lower=right_lower, upper=upper
    )
    mu_above_range = tf.greater(mu, upper)
    right = tf.where(mu_above_range, znp.zeros_like(right_raw), right_raw)

    return left + right


class CrystalBall(BasePDF, SerializableMixin):
    _N_OBS = 1

    def __init__(
        self,
        mu: ztyping.ParamTypeInput,
        sigma: ztyping.ParamTypeInput,
        alpha: ztyping.ParamTypeInput,
        n: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        *,
        extended: ExtendedInputType = None,
        norm: NormInputType = None,
        name: str = "CrystalBall",
    ):
        """Crystal Ball shaped PDF. A combination of a Gaussian with a powerlaw tail.

        The function is defined as follows:

        .. math::
            f(x;\\mu, \\sigma, \\alpha, n) =  \\begin{cases} \\exp(- \\frac{(x - \\mu)^2}{2 \\sigma^2}),
            & \\mbox{for }\\frac{x - \\mu}{\\sigma} \\geqslant -\\alpha \\newline
            A \\cdot (B - \\frac{x - \\mu}{\\sigma})^{-n}, & \\mbox{for }\\frac{x - \\mu}{\\sigma}
             < -\\alpha \\end{cases}

        with

        .. math::
            A = \\left(\\frac{n}{\\left| \\alpha \\right|}\\right)^n \\cdot
            \\exp\\left(- \\frac {\\left|\\alpha \\right|^2}{2}\\right)

            B = \\frac{n}{\\left| \\alpha \\right|}  - \\left| \\alpha \\right|

        Args:
            mu: The mean of the gaussian
            sigma: Standard deviation of the gaussian
            alpha: parameter where to switch from a gaussian to the powertail
            n: Exponent of the powertail
            obs: |@doc:pdf.init.obs| Observables of the
               model. This will be used as the default space of the PDF and,
               if not given explicitly, as the normalization range.

               The default space is used for example in the sample method: if no
               sampling limits are given, the default space is used.

               The observables are not equal to the domain as it does not restrict or
               truncate the model outside this range. |@docend:pdf.init.obs|
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|
            name: |@doc:pdf.init.name| Human-readable name
               or label of
               the PDF for better identification.
               Has no programmatical functional purpose as identification. |@docend:pdf.init.name|

        .. _CBShape: https://en.wikipedia.org/wiki/Crystal_Ball_function
        """
        params = {"mu": mu, "sigma": sigma, "alpha": alpha, "n": n}
        super().__init__(
            obs=obs, name=name, params=params, extended=extended, norm=norm
        )

    def _unnormalized_pdf(self, x):
        """Evaluate ``crystalball_func`` at ``x`` with the current parameter values."""
        mu = self.params["mu"].value()
        sigma = self.params["sigma"].value()
        alpha = self.params["alpha"].value()
        n = self.params["n"].value()
        x = x.unstack_x()
        return crystalball_func(x=x, mu=mu, sigma=sigma, alpha=alpha, n=n)
# Serialization representation of CrystalBall: one field per constructor parameter,
# discriminated by the "type" alias of ``hs3_type``.
class CrystalBallPDFRepr(BasePDFRepr):
    _implementation = CrystalBall
    hs3_type: Literal["CrystalBall"] = pydantic.Field("CrystalBall", alias="type")
    x: SpaceRepr
    mu: Serializer.types.ParamTypeDiscriminated
    sigma: Serializer.types.ParamTypeDiscriminated
    alpha: Serializer.types.ParamTypeDiscriminated
    n: Serializer.types.ParamTypeDiscriminated


# Limits on axis 0 with ANY_LOWER / ANY_UPPER as bounds; shared by the analytic
# integral registrations in this module.
crystalball_integral_limits = Space(
    axes=(0,), limits=(((ANY_LOWER,),), ((ANY_UPPER,),))
)

CrystalBall.register_analytic_integral(
    func=crystalball_integral, limits=crystalball_integral_limits
)
class DoubleCB(BasePDF, SerializableMixin):
    _N_OBS = 1

    def __init__(
        self,
        mu: ztyping.ParamTypeInput,
        sigma: ztyping.ParamTypeInput,
        alphal: ztyping.ParamTypeInput,
        nl: ztyping.ParamTypeInput,
        alphar: ztyping.ParamTypeInput,
        nr: ztyping.ParamTypeInput,
        obs: ztyping.ObsTypeInput,
        *,
        extended: ExtendedInputType = None,
        norm: NormInputType = None,
        name: str = "DoubleCB",
    ):
        """Double-sided Crystal Ball shaped PDF. A combination of two CB using the **mu** (not a frac) on each side.

        The function is defined as follows:

        .. math::
            f(x;\\mu, \\sigma, \\alpha_{L}, n_{L}, \\alpha_{R}, n_{R}) =  \\begin{cases}
            A_{L} \\cdot (B_{L} - \\frac{x - \\mu}{\\sigma})^{-n_{L}},
             & \\mbox{for }\\frac{x - \\mu}{\\sigma} < -\\alpha_{L} \\newline
            \\exp(- \\frac{(x - \\mu)^2}{2 \\sigma^2}),
             & \\mbox{for }-\\alpha_{L} \\leqslant \\frac{x - \\mu}{\\sigma} \\leqslant \\alpha_{R} \\newline
            A_{R} \\cdot (B_{R} + \\frac{x - \\mu}{\\sigma})^{-n_{R}},
             & \\mbox{for }\\frac{x - \\mu}{\\sigma} > \\alpha_{R}
            \\end{cases}

        with

        .. math::
            A_{L/R} = \\left(\\frac{n_{L/R}}{\\left| \\alpha_{L/R} \\right|}\\right)^{n_{L/R}} \\cdot
            \\exp\\left(- \\frac {\\left|\\alpha_{L/R} \\right|^2}{2}\\right)

            B_{L/R} = \\frac{n_{L/R}}{\\left| \\alpha_{L/R} \\right|}  - \\left| \\alpha_{L/R} \\right|

        Args:
            mu: The mean of the gaussian
            sigma: Standard deviation of the gaussian
            alphal: parameter where to switch from a gaussian to the powertail on the left
                side
            nl: Exponent of the powertail on the left side
            alphar: parameter where to switch from a gaussian to the powertail on the right
                side
            nr: Exponent of the powertail on the right side
            obs: |@doc:pdf.init.obs| Observables of the
               model. This will be used as the default space of the PDF and,
               if not given explicitly, as the normalization range.

               The default space is used for example in the sample method: if no
               sampling limits are given, the default space is used.

               The observables are not equal to the domain as it does not restrict or
               truncate the model outside this range. |@docend:pdf.init.obs|
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|
            name: |@doc:pdf.init.name| Human-readable name
               or label of
               the PDF for better identification.
               Has no programmatical functional purpose as identification. |@docend:pdf.init.name|
        """
        params = {
            "mu": mu,
            "sigma": sigma,
            "alphal": alphal,
            "nl": nl,
            "alphar": alphar,
            "nr": nr,
        }
        super().__init__(
            obs=obs, name=name, params=params, extended=extended, norm=norm
        )

    def _unnormalized_pdf(self, x):
        """Evaluate ``double_crystalball_func`` at ``x`` with the current parameter values."""
        mu = self.params["mu"].value()
        sigma = self.params["sigma"].value()
        alphal = self.params["alphal"].value()
        nl = self.params["nl"].value()
        alphar = self.params["alphar"].value()
        nr = self.params["nr"].value()
        x = x.unstack_x()
        return double_crystalball_func(
            x=x, mu=mu, sigma=sigma, alphal=alphal, nl=nl, alphar=alphar, nr=nr
        )
# Serialization representation of DoubleCB: one field per constructor parameter,
# discriminated by the "type" alias of ``hs3_type``.
class DoubleCBPDFRepr(BasePDFRepr):
    _implementation = DoubleCB
    hs3_type: Literal["DoubleCB"] = pydantic.Field("DoubleCB", alias="type")
    x: SpaceRepr
    mu: Serializer.types.ParamTypeDiscriminated
    sigma: Serializer.types.ParamTypeDiscriminated
    alphal: Serializer.types.ParamTypeDiscriminated
    nl: Serializer.types.ParamTypeDiscriminated
    alphar: Serializer.types.ParamTypeDiscriminated
    nr: Serializer.types.ParamTypeDiscriminated


# Uses the same limits object as the single-sided Crystal Ball registration.
DoubleCB.register_analytic_integral(
    func=double_crystalball_mu_integral, limits=crystalball_integral_limits
)