# Source code for zfit.models.binned_functor
# Copyright (c) 2023 zfit
from __future__ import annotations
from collections.abc import Iterable
from .basefunctor import FunctorMixin, _preprocess_init_sum
from .. import z
from ..core.binnedpdf import BaseBinnedPDFV1
from ..core.interfaces import ZfitPDF
from ..core.space import supports
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.deprecation import deprecated_norm_range
from ..util.exception import NormNotImplemented
from ..z import numpy as znp
class BaseBinnedFunctorPDF(FunctorMixin, BaseBinnedPDFV1):
    """Base class for binned functors.

    A functor here is a binned PDF that is built from other models. The
    wrapped models are additionally exposed under the ``pdfs`` attribute.
    """

    def __init__(self, models, obs, **kwargs):
        """Initialize the functor from its component models.

        Args:
            models: The models this functor is built from; forwarded
                positionally to the parent constructor.
            obs: Observables; forwarded positionally to the parent constructor.
            **kwargs: Any further keyword arguments, passed through unchanged
                to the parent constructor.
        """
        super().__init__(models, obs, **kwargs)
        # Alias: ``pdfs`` refers to the very same object as ``models``.
        self.pdfs = self.models
# [docs]  (Sphinx "view source" link marker; not part of the module code)
class BinnedSumPDF(BaseBinnedFunctorPDF):
    """Sum of several binned PDFs.

    The interpretation of ``pdfs`` and ``fracs`` is delegated to
    ``_preprocess_init_sum``. The parameters it returns become the ``params``
    of this PDF and are used as per-model weights in ``_pdf`` and
    ``_rel_counts``. ``_ext_pdf`` and ``_counts`` add up the models' own
    ``ext_pdf`` / ``counts`` values without applying those weights.
    """

    def __init__(
        self,
        pdfs: Iterable[ZfitPDF],
        fracs: ztyping.ParamTypeInput | None = None,
        obs: ztyping.ObsTypeInput = None,
        name: str = "BinnedSumPDF",
    ):
        """Build the sum from its component PDFs.

        Args:
            pdfs: The PDFs to add up; converted to a container first.
            fracs: Fractions of the PDFs, forwarded unchanged to
                ``_preprocess_init_sum``.
                NOTE(review): the rules for how many fractions are required
                live in that helper and are not visible here -- confirm there.
            obs: Observables, forwarded to ``_preprocess_init_sum`` and to the
                base-class constructor.
            name: Name handed to the base-class constructor.
        """
        self._fracs = None
        pdfs = convert_to_container(pdfs)
        # Assigned before the base-class constructor runs; that constructor
        # later re-assigns ``self.pdfs`` from ``self.models``.
        self.pdfs = pdfs
        (
            all_extended,
            fracs_cleaned,
            param_fracs,
            params,
            sum_yields,
            _frac_param_created,  # currently not used
        ) = _preprocess_init_sum(fracs, obs, pdfs)

        self._fracs = param_fracs
        self._original_fracs = fracs_cleaned
        # Only pass a yield on when every component PDF is extended.
        extended = sum_yields if all_extended else None
        super().__init__(
            models=pdfs, obs=obs, params=params, name=name, extended=extended
        )

    def _raise_if_norm_mismatch(self, norm):
        """Reject a requested ``norm`` that the component PDFs do not all share.

        Args:
            norm: The normalization requested by the caller.

        Raises:
            NormNotImplemented: If ``norm`` is truthy and ``norm`` together
                with the ``norm`` of every PDF in ``self.pdfs`` is not one
                single value.
        """
        # The set is built unconditionally (before looking at the truthiness
        # of ``norm``), matching the order of the checks this helper replaces.
        distinct_norms = {pdf.norm for pdf in self.pdfs}
        distinct_norms.add(norm)
        equal_norm_ranges = len(distinct_norms) == 1
        if norm and not equal_norm_ranges:
            raise NormNotImplemented

    @supports(norm=True)
    def _pdf(self, x, norm):
        """Return the sum of ``pdf.pdf(x) * frac`` over all component PDFs.

        Args:
            x: Data forwarded to each component's ``pdf``.
            norm: Requested normalization; only used for the compatibility
                check, it is not passed on to the components.

        Raises:
            NormNotImplemented: See ``_raise_if_norm_mismatch``.
        """
        self._raise_if_norm_mismatch(norm)
        fracs = self.params.values()
        probs = [pdf.pdf(x) * frac for pdf, frac in zip(self.pdfs, fracs)]
        return z.convert_to_tensor(znp.sum(probs, axis=0))

    @deprecated_norm_range
    def _ext_pdf(self, x, norm, *, norm_range=None):
        """Return the sum of ``model.ext_pdf(x)`` over all component models.

        Args:
            x: Data forwarded to each component's ``ext_pdf``.
            norm: Requested normalization; only used for the compatibility
                check.
            norm_range: Not read in this method; present in the signature
                alongside the ``deprecated_norm_range`` decorator.

        Raises:
            NormNotImplemented: See ``_raise_if_norm_mismatch``.
        """
        self._raise_if_norm_mismatch(norm)
        prob = znp.sum([model.ext_pdf(x) for model in self.models], axis=0)
        return z.convert_to_tensor(prob)

    def _counts(self, x, norm=None):
        """Return the sum of ``model.counts(x)`` over all component models.

        Args:
            x: Data forwarded to each component's ``counts``.
            norm: Requested normalization; only used for the compatibility
                check.

        Raises:
            NormNotImplemented: See ``_raise_if_norm_mismatch``.
        """
        self._raise_if_norm_mismatch(norm)
        return znp.sum([model.counts(x) for model in self.models], axis=0)

    def _rel_counts(self, x, norm=None):
        """Return the sum of ``model.rel_counts(x) * frac`` over all models.

        Args:
            x: Data forwarded to each component's ``rel_counts``.
            norm: Requested normalization; only used for the compatibility
                check.

        Raises:
            NormNotImplemented: See ``_raise_if_norm_mismatch``.
        """
        self._raise_if_norm_mismatch(norm)
        fracs = self.params.values()
        return znp.sum(
            [model.rel_counts(x) * frac for model, frac in zip(self.models, fracs)],
            axis=0,
        )