#  Source code for zfit.models.morphing

#  Copyright (c) 2024 zfit

from __future__ import annotations

from collections.abc import Mapping, Iterable

import tensorflow as tf
from uhi.typing.plottable import PlottableHistogram

import zfit.z.numpy as znp
from zfit import z
from zfit.core.binnedpdf import BaseBinnedPDFV1
from ..core import parameter
from ..core.interfaces import ZfitBinnedPDF
from ..util import ztyping
from ..util.exception import SpecificFunctionNotImplemented
from ..z.interpolate_spline import interpolate_spline


@z.function(wraps="tensor", keepalive=True)
def spline_interpolator(alpha, alphas, densities):
    """Interpolate a set of densities at ``alpha`` using an order-2 spline.

    Args:
        alpha: Point at which the interpolation is evaluated. It is reshaped to
            ``[1, -1, 1]`` and only the first query point's result is returned,
            so a single value is expected.
        alphas: 1D tensor of interpolation nodes; presumably one node per entry
            of ``densities`` (TODO confirm with callers).
        densities: Sequence of tensors, one per node. Each is flattened and the
            results are stacked, so all must have the same number of elements.

    Returns:
        The interpolated values, reshaped to the shape of ``densities[0]``.
    """
    target_shape = tf.shape(densities[0])
    # One flattened row per node: (n_nodes, n_elements).
    stacked = znp.stack([znp.reshape(dens, [-1]) for dens in densities], axis=0)

    # A leading axis of size one is added to every argument; presumably the
    # batch axis that interpolate_spline expects.
    interpolated = interpolate_spline(
        train_points=alphas[None, :, None],
        train_values=stacked[None, ...],
        query_points=znp.reshape(alpha, [1, -1, 1]),
        order=2,
    )
    # Select the first leading-axis entry and the first query point.
    return tf.reshape(interpolated[0, 0], target_shape)


class SplineMorphingPDF(BaseBinnedPDFV1):
    """Binned PDF that morphs between template histograms with a spline in ``alpha``."""

    # Called as ``(alpha, alphas, densities)``. Kept as a class attribute,
    # presumably so subclasses can swap the interpolation scheme.
    _morphing_interpolator = staticmethod(spline_interpolator)

    def __init__(
        self,
        alpha: ztyping.ParamTypeInput,
        hists: (
            Mapping[float | int, ZfitBinnedPDF | PlottableHistogram]
            | list[ZfitBinnedPDF]
            | tuple[ZfitBinnedPDF]
        ),
        extended: ztyping.ExtendedInputType = None,
        norm: ztyping.NormInputType = None,
    ):
        """Morphing a set of histograms with a spline interpolation.

        Args:
            alpha: Parameter for the spline interpolation.
            hists: A mapping of alpha values to histograms. This allows for arbitrary interpolation points.
                If a list or tuple of exactly three PDFs is given, this corresponds to the histograms at alpha equal
                to -1, 0 and 1 respectively. UHI histograms are accepted as well and are wrapped into a
                ``HistogramPDF``. The passed mapping itself is not modified.
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|

        Raises:
            ValueError: If a list/tuple whose length is not three or an empty mapping is given,
                or if ``extended`` is ``True`` but not all histograms are extended.
            TypeError: If an entry of ``hists`` is neither a ``ZfitBinnedPDF`` nor a UHI histogram.
        """
        if isinstance(hists, (list, tuple)):
            if len(hists) != 3:
                raise ValueError(
                    "If hists is a list, it is assumed to correspond to an alpha of -1, 0 and 1."
                    f" hists is {hists} and has length {len(hists)}."
                )
            hists = {float(i - 1): hist for i, hist in enumerate(hists)}  # mapping to -1, 0, 1
        if not hists:
            raise ValueError("hists must contain at least one histogram, but it is empty.")

        # Collect into a new dict rather than writing back into ``hists``: the
        # argument may be the caller's own (or a read-only) mapping.
        hists_clean = {}
        for a, hist in hists.items():
            if isinstance(hist, PlottableHistogram):
                from zfit.models.histogram import HistogramPDF

                hist = HistogramPDF(hist)
            if not isinstance(hist, ZfitBinnedPDF):
                raise TypeError(f"hist {hist} is not a ZfitBinnedPDF or a UHI histogram.")
            hists_clean[a] = hist
        hists = hists_clean

        self.hists = hists
        self.alpha = alpha
        obs = next(iter(hists.values())).space
        all_extended = all(hist.is_extended for hist in hists.values())
        if extended is None:  # TODO: yields?
            extended = all_extended
        # True: yield interpolated from the templates' yields; False: user-given
        # yield; None: not extended.
        self._automatically_extended = None
        if extended is True:  # create the yield automatically
            self._automatically_extended = True
            if not all_extended:
                raise ValueError(
                    "If extended is True, all PDFs must be extended to create the yield automatically."
                )
            alphas = znp.array(list(hists.keys()), dtype=znp.float64)
            # Use the same interpolator as the shape so yield and counts agree.
            interpolator = self._morphing_interpolator
            n_hists = len(hists)

            def interpolated_yield(params):
                # params holds one yield per histogram (keys "0".."n-1") plus "alpha".
                densities = tuple(params[f"{i}"] for i in range(n_hists))
                return interpolator(params["alpha"], alphas, densities)

            number = parameter.get_auto_number()
            yields = {f"{i}": hist.get_yield() for i, hist in enumerate(hists.values())}
            yields["alpha"] = alpha
            extended = parameter.ComposedParameter(
                f"AUTOGEN_{number}_interpolated_yield",
                interpolated_yield,
                params=yields,
            )
        elif extended is not False:
            self._automatically_extended = False
        # NOTE(review): the name says "LinearMorphing" although this is a spline
        # morphing; kept unchanged for backward compatibility.
        super().__init__(
            obs=obs,
            extended=extended,
            norm=norm,
            params={"alpha": alpha},
            name="LinearMorphing",
        )

    def _morph_densities(self, densities):
        """Interpolate per-template ``densities`` (ordered like ``self.hists``) at the current alpha."""
        alphas = znp.array(list(self.hists.keys()), dtype=znp.float64)
        alpha = self.params["alpha"].value()
        return self._morphing_interpolator(alpha, alphas, densities)

    def _counts(self, x, norm):
        # Only defined when the yield was interpolated automatically.
        if not self._automatically_extended:
            raise SpecificFunctionNotImplemented
        return self._morph_densities([hist.counts(x, norm=norm) for hist in self.hists.values()])

    def _rel_counts(self, x, norm):
        return self._morph_densities([hist.rel_counts(x, norm=norm) for hist in self.hists.values()])

    def _ext_pdf(self, x, norm):
        # Only defined when the yield was interpolated automatically.
        if not self._automatically_extended:
            raise SpecificFunctionNotImplemented
        return self._morph_densities([hist.ext_pdf(x, norm=norm) for hist in self.hists.values()])

    def _pdf(self, x, norm):
        return self._morph_densities([hist.pdf(x, norm=norm) for hist in self.hists.values()])