"""
Functors are functions that take typically one or more other PDF. Prominent examples are a sum, convolution etc.
A FunctorBase class is provided to make handling the models easier.
Their implementation is often non-trivial.
"""
# Copyright (c) 2019 zfit
from collections import OrderedDict
import itertools
from typing import Union, List, Optional
import tensorflow as tf
import numpy as np
from zfit import ztf
from ..core.interfaces import ZfitPDF, ZfitModel, ZfitSpace
from ..core.limits import no_norm_range, supports
from ..core.basepdf import BasePDF
from ..core.parameter import Parameter, convert_to_parameter
from ..models.basefunctor import FunctorMixin
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.exception import (ExtendedPDFError, AlreadyExtendedPDFError, AxesNotUnambiguousError,
LimitsOverdefinedError,
ModelIncompatibleError, )
from ..util.temporary import TemporarilySet
from ..settings import ztypes, run
[docs]class BaseFunctor(FunctorMixin, BasePDF):
def __init__(self, pdfs, name="BaseFunctor", **kwargs):
self.pdfs = convert_to_container(pdfs)
super().__init__(models=self.pdfs, name=name, **kwargs)
self._set_norm_range_from_daugthers()
self._component_norm_range_holder = None
def _get_component_norm_range(self):
return self._component_norm_range_holder
def _set_component_norm_range(self, norm_range: ztyping.LimitsTypeInput):
norm_range = self._check_input_norm_range(norm_range=norm_range)
if norm_range.limits in (False, None):
if self._get_component_norm_range() is None:
raise RuntimeError("Cannot use `False` as `norm_range` without previously setting the "
"`component_norm_range`.")
def setter(value):
self._component_norm_range_holder = value
return TemporarilySet(value=norm_range, setter=setter, getter=self._get_component_norm_range)
def _set_norm_range_from_daugthers(self):
norm_range = super().norm_range
if norm_range.limits is None:
norm_range_candidat = self._infer_norm_range_from_daughters()
# if norm_range_candidat is False:
# raise LimitsOverdefinedError("Daughter pdfs do not agree on a `norm_range` and no `norm_range`"
# "has been explicitly set.")
if isinstance(norm_range_candidat, ZfitSpace): # TODO(Mayou36, #77): different obs?
norm_range = norm_range_candidat
self._norm_range = norm_range
def _infer_norm_range_from_daughters(self):
norm_ranges = set(model.norm_range for model in self.models)
obs = set(norm_range.obs for norm_range in norm_ranges)
if len(norm_ranges) == 1:
return norm_ranges.pop()
elif len(obs) > 1: # TODO(Mayou36, #77): different obs?
return None
else:
return False
def _single_hook_unnormalized_pdf(self, x, component_norm_range, name):
if component_norm_range.limits is not None:
with self._set_component_norm_range(norm_range=component_norm_range):
return super()._single_hook_unnormalized_pdf(x, component_norm_range, name)
else:
return super()._single_hook_unnormalized_pdf(x, component_norm_range, name)
def _single_hook_integrate(self, limits, norm_range, name='_hook_integrate'):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_integrate(limits, norm_range, name)
def _single_hook_analytic_integrate(self, limits, norm_range, name="_hook_analytic_integrate"):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_analytic_integrate(limits, norm_range, name)
def _single_hook_numeric_integrate(self, limits, norm_range, name='_hook_numeric_integrate'):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_numeric_integrate(limits, norm_range, name)
def _single_hook_partial_integrate(self, x, limits, norm_range, name='_hook_partial_integrate'):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_partial_integrate(x, limits, norm_range, name)
def _single_hook_partial_analytic_integrate(self, x, limits, norm_range, name='_hook_partial_analytic_integrate'):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_partial_analytic_integrate(x, limits, norm_range, name)
def _single_hook_partial_numeric_integrate(self, x, limits, norm_range, name='_hook_partial_numeric_integrate'):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_partial_numeric_integrate(x, limits, norm_range, name)
def _single_hook_normalization(self, limits, name="_hook_normalization"):
with self._set_component_norm_range(norm_range=limits):
return super()._single_hook_normalization(limits, name)
def _single_hook_pdf(self, x, norm_range, name="_hook_pdf"):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_pdf(x, norm_range, name)
def _single_hook_log_pdf(self, x, norm_range, name):
with self._set_component_norm_range(norm_range=norm_range):
return super()._single_hook_log_pdf(x, norm_range, name)
def _single_hook_sample(self, n, limits, name):
with self._set_component_norm_range(norm_range=limits):
return super()._single_hook_sample(n, limits, name)
@property
def pdfs_extended(self):
return [pdf.is_extended for pdf in self.pdfs]
@property
def _models(self) -> List[ZfitModel]:
return self.pdfs
[docs]class SumPDF(BaseFunctor):
def __init__(self, pdfs: List[ZfitPDF], fracs: Optional[ztyping.ParamTypeInput] = None,
obs: ztyping.ObsTypeInput = None,
name: str = "SumPDF"):
"""Create the sum of the `pdfs` with `fracs` as coefficients.
Args:
pdfs (pdf): The pdfs to add.
fracs (iterable): coefficients for the linear combination of the pdfs. If pdfs are
extended, this throws an error.
- len(frac) == len(basic) - 1 results in the interpretation of a non-extended pdf.
The last coefficient will equal to 1 - sum(frac)
- len(frac) == len(pdf) each pdf in `pdfs` will become an extended pdf with the
given yield.
name (str):
"""
# Check user input, improve TODO
self._fracs = None
set_yield_at_end = False
pdfs = convert_to_container(pdfs)
self.pdfs = pdfs
if len(pdfs) < 2:
raise ValueError("Cannot build a sum of a single pdf")
if fracs is not None:
fracs = convert_to_container(fracs)
fracs = [convert_to_parameter(frac) for frac in fracs]
# check if all extended
extended_pdfs = self.pdfs_extended
implicit = None
extended = None
if all(extended_pdfs):
implicit = True
extended = True
# all extended except one -> fraction
elif sum(extended_pdfs) == len(extended_pdfs) - 1:
implicit = True
extended = False
# no pdf is extended -> using `fracs`
elif not any(extended_pdfs) and fracs is not None:
# make extended
if len(fracs) == len(pdfs):
implicit = False
extended = True
elif len(fracs) == len(pdfs) - 1:
implicit = False
extended = False
# catch if args don't fit known case
value_error = implicit is None or extended is None
if (implicit and fracs is not None) or value_error:
raise ModelIncompatibleError("Wrong arguments. Either"
"\n a) `pdfs` are not extended and `fracs` is given with length pdfs "
"(-> pdfs get extended) or pdfs - 1 (fractions)"
"\n b) all or all except 1 `pdfs` are extended and fracs is None.")
# create fracs if one is not extended
if not extended and implicit:
fracs = []
not_extended_position = None
new_pdfs = []
for i, pdf in enumerate(pdfs):
if pdf.is_extended:
fracs.append(pdf.get_yield())
pdf = pdf.copy()
pdf._set_yield_inplace(None) # make non-extended
else:
fracs.append(tf.constant(0., dtype=ztypes.float))
not_extended_position = i
new_pdfs.append(pdf)
pdfs = new_pdfs
remaining_frac = tf.constant(1., dtype=ztypes.float) - tf.add_n(fracs)
if run.numeric_checks:
assert_op = tf.Assert(tf.greater_equal(remaining_frac, tf.constant(0., dtype=ztypes.float)),
data=[remaining_frac]) # check fractions
deps = [assert_op]
else:
deps = []
with tf.control_dependencies(deps):
# TODO(Mayou36): always last position?
fracs[not_extended_position] = tf.identity(remaining_frac)
implicit = False # now it's explicit
elif not extended and not implicit:
remaining_frac = tf.constant(1., dtype=ztypes.float) - tf.add_n(fracs)
if run.numeric_checks:
assert_op = tf.Assert(tf.greater_equal(remaining_frac, tf.constant(0., dtype=ztypes.float)),
data=[remaining_frac]) # check fractions
deps = [assert_op]
else:
deps = []
with tf.control_dependencies(deps):
fracs.append(tf.identity(remaining_frac))
# make extended
elif extended and not implicit:
yields = fracs
pdfs = [pdf.create_extended(yield_) for pdf, yield_ in zip(pdfs, yields)]
implicit = True
elif extended and implicit:
yields = [pdf.get_yield() for pdf in pdfs]
if extended:
# TODO(Mayou36): convert to correct dtype
sum_yields = tf.reduce_sum(
input_tensor=[tf.convert_to_tensor(value=y, dtype_hint=ztypes.float) for y in yields])
yield_fracs = [yield_ / sum_yields for yield_ in yields]
self.fracs = yield_fracs
# self.fracs = yield_fracs
set_yield_at_end = True
self._maybe_extended_fracs = [tf.constant(1, dtype=ztypes.float)] * len(self.pdfs)
else:
self._maybe_extended_fracs = fracs
self.pdfs = pdfs
params = OrderedDict()
for i, frac in enumerate(self._maybe_extended_fracs):
params['frac_{}'.format(i)] = frac
super().__init__(pdfs=pdfs, obs=obs, params=params, name=name)
if set_yield_at_end:
self._set_yield_inplace(sum_yields)
@property
def fracs(self):
fracs = self._fracs
if fracs is None:
fracs = self._maybe_extended_fracs
return fracs
@fracs.setter
def fracs(self, value):
self._fracs = value
def _apply_yield(self, value: float, norm_range: ztyping.LimitsType, log: bool):
if all(self.pdfs_extended):
return value
else:
return super()._apply_yield(value=value, norm_range=norm_range, log=log)
def _unnormalized_pdf(self, x):
norm_range = self._get_component_norm_range()
return self._pdf(x=x, norm_range=norm_range)
# raise NotImplementedError
# pdfs = self.pdfs
# fracs = self.fracs
# func = tf.accumulate_n(
# [scale * pdf.unnormalized_pdf(x) for pdf, scale in zip(pdfs, fracs)])
# return func
def _pdf(self, x, norm_range):
pdfs = self.pdfs
fracs = self.fracs
prob = tf.add_n([pdf.pdf(x, norm_range=norm_range) * frac for pdf, frac in zip(pdfs, fracs)])
# prob = tf.accumulate_n([pdf.pdf(x, norm_range=norm_range) * scale for pdf, scale in zip(pdfs, fracs)])
return prob
def _set_yield(self, value: Union[Parameter, None]):
# TODO: what happens now with the daughters?
if all(self.pdfs_extended) and self.is_extended and value is not None: # to be able to set the yield in the
# beginning
raise AlreadyExtendedPDFError("Cannot set the yield of a PDF with extended daughters.")
elif all(self.pdfs_extended) and self.is_extended and value is None: # not extended anymore
reciprocal_yield = tf.math.reciprocal(self.get_yield())
self._maybe_extended_fracs = [reciprocal_yield] * len(self._maybe_extended_fracs)
else:
super()._set_yield(value=value)
@supports(norm_range=True, multiple_limits=True)
def _integrate(self, limits, norm_range):
pdfs = self.pdfs
fracs = self._maybe_extended_fracs
assert norm_range not in (None, False), "Bug, who requested an unnormalized integral?"
integrals = [pdf.integrate(limits=limits, norm_range=norm_range) for pdf in pdfs]
integrals = [integral * frac for integral, frac in zip(integrals, fracs)]
integral = tf.reduce_sum(input_tensor=integrals)
return integral
@supports(norm_range=True, multiple_limits=True)
def _analytic_integrate(self, limits, norm_range):
pdfs = self.pdfs
fracs = self._maybe_extended_fracs
assert norm_range not in (None, False), "Bug, who requested an unnormalized integral?"
try:
integrals = [pdf.analytic_integrate(limits=limits, norm_range=norm_range) for pdf in pdfs]
except NotImplementedError as original_error:
raise NotImplementedError("analytic_integrate of pdf {name} is not implemented in this"
" SumPDF, as at least one sub-pdf does not implement it."
"Original message:\n{error}".format(name=self.name,
error=original_error))
integrals = [integral * frac for integral, frac in zip(integrals, fracs)]
integral = tf.reduce_sum(input_tensor=integrals)
return integral
@supports(norm_range=True, multiple_limits=True)
def _partial_integrate(self, x, limits, norm_range):
raise RuntimeError("Currently not available, cleanup with yields expected.")
# @supports()
# def _partial_analytic_integrate(self, x, limits, norm_range):
# pdfs = self.pdfs
# frac = self.fracs
# try:
# partial_integral = [pdf.analytic_integrate(limits=limits, norm_range=norm_range) for pdf in pdfs]
# except NotImplementedError as original_error:
# raise NotImplementedError("partial_analytic_integrate of pdf {name} is not implemented in this"
# " SumPDF, as at least one sub-pdf does not implement it."
# "Original message:\n{error}".format(name=self.name,
# error=original_error))
# partial_integral = tf.stack([partial_integral * s for pdf, s in zip(partial_integral, frac)])
# partial_integral = tf.reduce_sum(partial_integral, axis=0)
# return partial_integral
[docs]class ProductPDF(BaseFunctor): # TODO: unfinished
def __init__(self, pdfs: List[ZfitPDF], obs: ztyping.ObsTypeInput = None, name="ProductPDF"):
super().__init__(pdfs=pdfs, obs=obs, name=name)
def _unnormalized_pdf(self, x: ztyping.XType):
norm_range = self._get_component_norm_range()
return np.prod([pdf.unnormalized_pdf(x, component_norm_range=norm_range.get_subspace(obs=pdf.obs))
for pdf in self.pdfs])
def _pdf(self, x, norm_range):
if all(not dep for dep in self._model_same_obs):
probs = [pdf.pdf(x=x, norm_range=norm_range.get_subspace(obs=pdf.obs)) for pdf in self.pdfs]
return tf.reduce_prod(input_tensor=probs, axis=0)
else:
raise NotImplementedError