# Source code for zfit.models.functor

"""
Functors are functions that typically take one or more other PDFs. Prominent examples are sums, convolutions, etc.

A FunctorBase class is provided to make handling the models easier.

Their implementation is often non-trivial.
"""
#  Copyright (c) 2019 zfit

from collections import OrderedDict
import itertools
from typing import Union, List, Optional

import tensorflow as tf

import numpy as np

from zfit import z
from ..core.interfaces import ZfitPDF, ZfitModel, ZfitSpace
from ..core.limits import no_norm_range, supports
from ..core.basepdf import BasePDF
from ..core.parameter import Parameter, convert_to_parameter, ComposedParameter
from ..models.basefunctor import FunctorMixin
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.exception import (ExtendedPDFError, AlreadyExtendedPDFError, AxesNotUnambiguousError,
                              LimitsOverdefinedError,
                              ModelIncompatibleError, )
from ..util.temporary import TemporarilySet
from ..settings import ztypes, run


class BaseFunctor(FunctorMixin, BasePDF):
    """Base class for PDFs that are composed of one or more daughter PDFs.

    Besides the usual PDF machinery, a *component normalization range* is tracked: the
    `_single_hook_*` overrides below temporarily install the requested normalization range (or the
    limits, for `_single_hook_normalization` and `_single_hook_sample`) while the parent hook runs,
    so that implementations can read it back via `_get_component_norm_range`.
    """

    def __init__(self, pdfs, name="BaseFunctor", **kwargs):
        self.pdfs = convert_to_container(pdfs)
        super().__init__(models=self.pdfs, name=name, **kwargs)
        self._set_norm_range_from_daugthers()
        # no component normalization range is installed until a hook sets one temporarily
        self._component_norm_range_holder = None

    def _get_component_norm_range(self):
        """Return the currently installed component normalization range (initially ``None``)."""
        return self._component_norm_range_holder

    def _set_component_norm_range(self, norm_range: ztyping.LimitsTypeInput):
        """Create a `TemporarilySet` that installs `norm_range` as component normalization range.

        Args:
            norm_range: The normalization range; it is validated with `_check_input_norm_range`.

        Returns:
            A `TemporarilySet` (used as a context manager by the hooks) holding the checked range.

        Raises:
            RuntimeError: If the checked range has no limits (``False`` or ``None``) while no
                component normalization range has been set before.
        """
        checked_range = self._check_input_norm_range(norm_range=norm_range)
        range_without_limits = checked_range.limits in (False, None)
        if range_without_limits and self._get_component_norm_range() is None:
            raise RuntimeError("Cannot use `False` as `norm_range` without previously setting the "
                               "`component_norm_range`.")

        def _store(new_range):
            self._component_norm_range_holder = new_range

        return TemporarilySet(value=checked_range, setter=_store, getter=self._get_component_norm_range)

    def _set_norm_range_from_daugthers(self):
        """Set `_norm_range`, falling back to a normalization range inferred from the daughters.

        If the normalization range of the parent class has ``None`` limits, the result of
        `_infer_norm_range_from_daughters` is used, but only if it is a `ZfitSpace`. When the daughters
        disagree (``False``) or use different observables (``None``), no error is raised and the
        original range is kept.
        """
        chosen_range = super().norm_range
        if chosen_range.limits is None:
            candidate = self._infer_norm_range_from_daughters()
            if isinstance(candidate, ZfitSpace):  # TODO(Mayou36, #77): different obs?
                chosen_range = candidate
        self._norm_range = chosen_range

    def _infer_norm_range_from_daughters(self):
        """Infer a common normalization range from the daughter models.

        Returns:
            The single normalization range if all daughters share it; ``None`` if the daughters'
            normalization ranges span more than one set of observables; ``False`` otherwise.
        """
        distinct_ranges = {model.norm_range for model in self.models}
        distinct_obs = {rng.obs for rng in distinct_ranges}
        if len(distinct_ranges) == 1:
            return distinct_ranges.pop()
        if len(distinct_obs) > 1:  # TODO(Mayou36, #77): different obs?
            return None
        return False

    # --- hooks: run the parent hook with the component normalization range installed ---

    def _single_hook_unnormalized_pdf(self, x, component_norm_range, name):
        # only install a component normalization range if one with actual limits was requested
        if component_norm_range.limits is None:
            return super()._single_hook_unnormalized_pdf(x, component_norm_range, name)
        with self._set_component_norm_range(norm_range=component_norm_range):
            return super()._single_hook_unnormalized_pdf(x, component_norm_range, name)

    def _single_hook_integrate(self, limits, norm_range, name='_hook_integrate'):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_integrate(limits, norm_range, name)

    def _single_hook_analytic_integrate(self, limits, norm_range, name="_hook_analytic_integrate"):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_analytic_integrate(limits, norm_range, name)

    def _single_hook_numeric_integrate(self, limits, norm_range, name='_hook_numeric_integrate'):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_numeric_integrate(limits, norm_range, name)

    def _single_hook_partial_integrate(self, x, limits, norm_range, name='_hook_partial_integrate'):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_partial_integrate(x, limits, norm_range, name)

    def _single_hook_partial_analytic_integrate(self, x, limits, norm_range,
                                                name='_hook_partial_analytic_integrate'):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_partial_analytic_integrate(x, limits, norm_range, name)

    def _single_hook_partial_numeric_integrate(self, x, limits, norm_range,
                                               name='_hook_partial_numeric_integrate'):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_partial_numeric_integrate(x, limits, norm_range, name)

    def _single_hook_normalization(self, limits, name="_hook_normalization"):
        # here the limits themselves act as component normalization range
        with self._set_component_norm_range(norm_range=limits):
            return super()._single_hook_normalization(limits, name)

    def _single_hook_pdf(self, x, norm_range, name="_hook_pdf"):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_pdf(x, norm_range, name)

    def _single_hook_log_pdf(self, x, norm_range, name):
        with self._set_component_norm_range(norm_range=norm_range):
            return super()._single_hook_log_pdf(x, norm_range, name)

    def _single_hook_sample(self, n, limits, name):
        # the sampling limits act as component normalization range
        with self._set_component_norm_range(norm_range=limits):
            return super()._single_hook_sample(n, limits, name)

    @property
    def pdfs_extended(self):
        """List with one boolean per daughter pdf: whether that pdf is extended."""
        return [daughter.is_extended for daughter in self.pdfs]

    @property
    def _models(self) -> List[ZfitModel]:
        return self.pdfs
class SumPDF(BaseFunctor):
    """Sum of PDFs, weighted either by fractions or by the yields of extended daughters."""

    def __init__(self, pdfs: List[ZfitPDF], fracs: Optional[ztyping.ParamTypeInput] = None,
                 obs: ztyping.ObsTypeInput = None,
                 name: str = "SumPDF"):
        """Create the sum of the `pdfs` with `fracs` as coefficients.

        Supported combinations:

        - all `pdfs` extended, `fracs` is ``None``: the sum is extended with the sum of the yields;
          the pdf is weighted by ``yield_i / sum(yields)``.
        - all but one of the `pdfs` extended, `fracs` is ``None``: the yields of the extended pdfs
          are used as fractions (the pdfs are copied and made non-extended) and the remaining pdf
          gets ``1 - sum(fractions)``.
        - no pdf extended and `fracs` given: see `fracs` below.

        Args:
            pdfs: The pdfs to add. At least two are required.
            fracs: Coefficients for the linear combination of the pdfs. Must be ``None`` if all or
                all but one of the `pdfs` are extended. Otherwise:

                - ``len(fracs) == len(pdfs) - 1``: the coefficients are fractions and the sum is not
                  extended; the coefficient of the last pdf is ``1 - sum(fracs)``.
                - ``len(fracs) == len(pdfs)``: each pdf in `pdfs` becomes an extended pdf with the
                  corresponding entry as its yield.
            obs: Observables, passed on to the base class.
            name: Name of the pdf.

        Raises:
            ValueError: If fewer than two pdfs are given.
            ModelIncompatibleError: If `pdfs` and `fracs` match none of the supported combinations.
        """
        # Check user input, improve TODO
        self._fracs = None

        set_yield_at_end = False
        pdfs = convert_to_container(pdfs)
        self.pdfs = pdfs
        if len(pdfs) < 2:
            raise ValueError("Cannot build a sum of a single pdf")
        if fracs is not None:
            fracs = convert_to_container(fracs)
            fracs = [convert_to_parameter(frac) for frac in fracs]

        # check if all extended
        extended_pdfs = self.pdfs_extended

        implicit = None
        extended = None
        if all(extended_pdfs):
            implicit = True
            extended = True

        # all extended except one -> fraction
        elif sum(extended_pdfs) == len(extended_pdfs) - 1:
            implicit = True
            extended = False

        # no pdf is extended -> using `fracs`
        elif not any(extended_pdfs) and fracs is not None:
            # make extended
            if len(fracs) == len(pdfs):
                implicit = False
                extended = True
            elif len(fracs) == len(pdfs) - 1:
                implicit = False
                extended = False

        # catch if args don't fit known case
        value_error = implicit is None or extended is None
        if (implicit and fracs is not None) or value_error:
            raise ModelIncompatibleError("Wrong arguments. Either"
                                         "\n a) `pdfs` are not extended and `fracs` is given with length pdfs "
                                         "(-> pdfs get extended) or pdfs - 1 (fractions)"
                                         "\n b) all or all except 1 `pdfs` are extended and fracs is None.")

        # create fracs if one is not extended
        if not extended and implicit:
            fracs = []
            not_extended_position = None
            new_pdfs = []
            for i, pdf in enumerate(pdfs):
                if pdf.is_extended:
                    fracs.append(pdf.get_yield())
                    pdf = pdf.copy()
                    pdf._set_yield_inplace(None)  # make non-extended
                else:
                    # placeholder, replaced by the remaining fraction below
                    fracs.append(tf.constant(0., dtype=ztypes.float))
                    not_extended_position = i
                new_pdfs.append(pdf)
            pdfs = new_pdfs
            fracs[not_extended_position] = self._create_remaining_frac(fracs)
            implicit = False  # now it's explicit

        elif not extended and not implicit:
            fracs.append(self._create_remaining_frac(fracs))

        # make extended
        elif extended and not implicit:
            yields = fracs
            pdfs = [pdf.create_extended(yield_) for pdf, yield_ in zip(pdfs, yields)]

            implicit = True

        elif extended and implicit:
            yields = [pdf.get_yield() for pdf in pdfs]

        if extended:
            # TODO(Mayou36): convert to correct dtype
            def sum_yields_func():
                return tf.reduce_sum(
                    input_tensor=[tf.convert_to_tensor(value=y, dtype_hint=ztypes.float) for y in yields.copy()])

            sum_yields = convert_to_parameter(sum_yields_func, dependents=yields)
            # yield_i / sum(yields) depends on *all* yields (through `sum_yields`), not only on yield_i
            yield_fracs = [convert_to_parameter(lambda yield_=yield_: yield_ / sum_yields, dependents=yields)
                           for yield_ in yields]
            self.fracs = yield_fracs
            set_yield_at_end = True
            self._maybe_extended_fracs = [tf.constant(1, dtype=ztypes.float)] * len(self.pdfs)
        else:
            self._maybe_extended_fracs = fracs

        self.pdfs = pdfs

        params = OrderedDict()
        # TODO(Mayou36): this is not right. Where to create the params if extended? The correct fracs?
        for i, frac in enumerate(self._maybe_extended_fracs):
            params['frac_{}'.format(i)] = frac

        super().__init__(pdfs=pdfs, obs=obs, params=params, name=name)
        if set_yield_at_end:
            self._set_yield_inplace(sum_yields)

    @staticmethod
    def _create_remaining_frac(fracs):
        """Create the parameter ``1 - sum(fracs)`` for the pdf without an explicit coefficient.

        A copy of `fracs` is taken, so later changes of the list (e.g. appending or inserting the
        returned parameter) do not alter its value.

        Args:
            fracs: The known coefficients.

        Returns:
            A composed parameter depending on all entries of `fracs`.
        """
        copied_fracs = list(fracs)

        def remaining_frac_func():
            return tf.constant(1., dtype=ztypes.float) - tf.add_n(copied_fracs)

        remaining_frac = convert_to_parameter(remaining_frac_func,
                                              dependents=[convert_to_parameter(f) for f in copied_fracs])
        if run.numeric_checks:
            # check fractions
            # NOTE(review): the op is not attached as a control dependency, so it presumably only has
            # an effect where it executes immediately (eager mode) -- confirm intended behavior.
            tf.Assert(tf.greater_equal(remaining_frac, tf.constant(0., dtype=ztypes.float)),
                      data=[remaining_frac])
        return remaining_frac

    @property
    def fracs(self):
        """Coefficients used by `_pdf`: explicitly set fracs if any, else `_maybe_extended_fracs`."""
        fracs = self._fracs
        if fracs is None:
            fracs = self._maybe_extended_fracs
        return fracs

    @fracs.setter
    def fracs(self, value):
        self._fracs = value

    def _apply_yield(self, value: float, norm_range: ztyping.LimitsType, log: bool):
        # If all daughters are extended, the value is returned unchanged (presumably because the
        # daughters' yields are already contained in it); otherwise defer to the base class.
        if all(self.pdfs_extended):
            return value
        else:
            return super()._apply_yield(value=value, norm_range=norm_range, log=log)

    def _unnormalized_pdf(self, x):
        """Evaluate the weighted sum normalized over the current component normalization range."""
        norm_range = self._get_component_norm_range()
        return self._pdf(x=x, norm_range=norm_range)

    def _pdf(self, x, norm_range):
        """Return ``sum_i frac_i * pdf_i(x)``, each pdf normalized over `norm_range`."""
        pdfs = self.pdfs
        fracs = self.fracs
        prob = tf.add_n([pdf.pdf(x, norm_range=norm_range) * frac for pdf, frac in zip(pdfs, fracs)])
        return prob

    def _set_yield(self, value: Union[Parameter, None]):
        """Set the yield of the sum.

        Raises:
            AlreadyExtendedPDFError: If all daughters are extended, the sum is already extended and
                a new (non-``None``) yield is given.
        """
        # TODO: what happens now with the daughters?
        if all(self.pdfs_extended) and self.is_extended and value is not None:  # to be able to set the yield in the
            # beginning
            raise AlreadyExtendedPDFError("Cannot set the yield of a PDF with extended daughters.")
        elif all(self.pdfs_extended) and self.is_extended and value is None:  # not extended anymore
            # scale every daughter contribution by 1 / yield instead of removing the yield
            reciprocal_yield = convert_to_parameter(lambda: tf.math.reciprocal(self.get_yield()),
                                                    dependents=self.get_yield())
            self._maybe_extended_fracs = [reciprocal_yield] * len(self._maybe_extended_fracs)
        else:
            super()._set_yield(value=value)

    @supports(norm_range=True, multiple_limits=True)
    def _integrate(self, limits, norm_range):
        """Return ``sum_i frac_i * integral_i`` using `_maybe_extended_fracs` as weights."""
        pdfs = self.pdfs
        fracs = self._maybe_extended_fracs
        assert norm_range not in (None, False), "Bug, who requested an unnormalized integral?"
        integrals = [pdf.integrate(limits=limits, norm_range=norm_range) for pdf in pdfs]
        integrals = [integral * frac for integral, frac in zip(integrals, fracs)]
        integral = tf.reduce_sum(input_tensor=integrals)
        return integral

    @supports(norm_range=True, multiple_limits=True)
    def _analytic_integrate(self, limits, norm_range):
        """Analytic counterpart of `_integrate`.

        Raises:
            NotImplementedError: If any daughter pdf has no analytic integral; the daughter's error
                is chained as the cause.
        """
        pdfs = self.pdfs
        fracs = self._maybe_extended_fracs
        assert norm_range not in (None, False), "Bug, who requested an unnormalized integral?"
        try:
            integrals = [pdf.analytic_integrate(limits=limits, norm_range=norm_range) for pdf in pdfs]
        except NotImplementedError as original_error:
            raise NotImplementedError("analytic_integrate of pdf {name} is not implemented in this"
                                      " SumPDF, as at least one sub-pdf does not implement it. "
                                      "Original message:\n{error}".format(name=self.name,
                                                                         error=original_error)) from original_error

        integrals = [integral * frac for integral, frac in zip(integrals, fracs)]
        integral = tf.reduce_sum(input_tensor=integrals)
        return integral

    @supports(norm_range=True, multiple_limits=True)
    def _partial_integrate(self, x, limits, norm_range):
        raise RuntimeError("Currently not available, cleanup with yields expected.")
class ProductPDF(BaseFunctor):  # TODO: unfinished
    """Product of the daughter PDFs, each evaluated on the subspace of its own observables."""

    def __init__(self, pdfs: List[ZfitPDF], obs: ztyping.ObsTypeInput = None, name="ProductPDF"):
        super().__init__(pdfs=pdfs, obs=obs, name=name)

    def _unnormalized_pdf(self, x: ztyping.XType):
        """Multiply the unnormalized daughter pdfs.

        Each daughter receives the subspace of the current component normalization range that
        matches its observables.
        """
        # NOTE(review): assumes a component normalization range is installed (not None) -- confirm
        norm_range = self._get_component_norm_range()
        return tf.math.reduce_prod([pdf.unnormalized_pdf(x, component_norm_range=norm_range.get_subspace(obs=pdf.obs))
                                    for pdf in self.pdfs], axis=0)

    def _pdf(self, x, norm_range):
        """Multiply the daughter pdfs, each normalized over its subspace of `norm_range`.

        Raises:
            NotImplementedError: If any entry of `_model_same_obs` is truthy.
        """
        # NOTE(review): `_model_same_obs` is provided by FunctorMixin; presumably it flags daughters
        # that share observables -- confirm.
        if all(not dep for dep in self._model_same_obs):
            probs = [pdf.pdf(x=x, norm_range=norm_range.get_subspace(obs=pdf.obs)) for pdf in self.pdfs]
            return tf.reduce_prod(input_tensor=probs, axis=0)
        else:
            raise NotImplementedError("The normalized pdf of the ProductPDF {name} is not implemented for "
                                      "this combination of daughter pdfs.".format(name=self.name))