# Copyright (c) 2020 zfit
from typing import Dict, Union, Callable, Iterable
import tensorflow as tf
from ..core.basefunc import BaseFunc
from ..core.basemodel import SimpleModelSubclassMixin
from ..core.dependents import _extract_dependencies
from ..core.interfaces import ZfitModel, ZfitFunc
from ..core.space import supports
from ..models.basefunctor import FunctorMixin
from ..util import ztyping
from ..util.container import convert_to_container
class SimpleFunc(BaseFunc):
    """Function model that wraps a plain callable."""

    def __init__(self, obs: ztyping.ObsTypeInput, func: Callable, name: str = "Function", **params):
        """Create a simple function out of `func` with the observables `obs` depending on `parameters`.

        Args:
            func (function): The callable that is evaluated in `_func`. It is
                called as `func(x)` or, if that raises a `TypeError`, as
                `func(self, x)`.
            obs (Union[str, Tuple[str]]): The observables, forwarded to the base class.
            name (str): Name of the function, forwarded to the base class.
            **params (): The parameters as keyword arguments. E.g. `mu=Parameter(...)`
        """
        super().__init__(name=name, obs=obs, params=params)
        # The callable is stored as whatever `_check_input_x_function` returns.
        checked_func = self._check_input_x_function(func)
        self._value_func = checked_func

    def _func(self, x):
        """Evaluate the wrapped callable at `x`.

        First tries the one-argument form `func(x)`; on a `TypeError` it is
        retried as `func(self, x)`.
        """
        value_func = self._value_func
        try:
            result = value_func(x)
        except TypeError:  # self requested, TODO maybe check signature?
            result = value_func(self, x)
        return result
class BaseFunctorFunc(FunctorMixin, BaseFunc):
    """Base class for funcs that are built out of other funcs."""

    def __init__(self, funcs, name="BaseFunctorFunc", params=None, **kwargs):
        """Store the given `funcs` and forward them as `models` to the base classes.

        Args:
            funcs: A single func or several funcs; passed through
                `convert_to_container` before being stored on `self.funcs`.
            name: Name forwarded to the base class.
            params: Optional parameters; an empty dict is used when `None`.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        func_container = convert_to_container(funcs)
        params = {} if params is None else params
        # NOTE: the parameters of the individual funcs are intentionally not
        # merged into `params` here.
        # `self.funcs` is assigned before the base initializer runs.
        self.funcs = func_container
        super().__init__(name=name, models=self.funcs, params=params, **kwargs)

    def _get_dependencies(self):  # TODO: change recursive to `only_floating`?
        """Return the base-class dependencies united with those extracted from `self.funcs`."""
        own_dependents = super()._get_dependencies()
        dependents_of_funcs = _extract_dependencies(self.funcs)
        return own_dependents.union(dependents_of_funcs)

    @property
    def _models(self) -> Dict[Union[float, int, str], ZfitModel]:
        """The funcs this functor is composed of (the object stored in `self.funcs`).

        NOTE(review): the annotation says `Dict`, but `self.funcs` holds the
        output of `convert_to_container` -- confirm which is intended.
        """
        models = self.funcs
        return models
class SumFunc(BaseFunctorFunc):
    """Functor func whose value is the sum of the values of its `funcs`."""

    def __init__(self, funcs: Iterable[ZfitFunc], obs: ztyping.ObsTypeInput = None, name: str = "SumFunc", **kwargs):
        """Forward all arguments to `BaseFunctorFunc`.

        Args:
            funcs: The funcs to be summed.
            obs: The observables, forwarded to the base class.
            name: Name of the function.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)

    def _func(self, x):
        """Evaluate every func at `x` and add the results with `tf.math.accumulate_n`."""
        values = []
        for func in self.funcs:
            values.append(func.func(x))
        return tf.math.accumulate_n(values)

    @supports()
    def _analytic_integrate(self, limits, norm_range):
        """Add up the analytic integrals of all funcs over `limits`."""
        # An AnalyticIntegralNotImplementedError raised by one of the funcs is
        # deliberately not caught here; it is meant to propagate.
        integrals = []
        for func in self.funcs:
            integrals.append(func.analytic_integrate(limits=limits, norm_range=norm_range))
        return tf.math.accumulate_n(integrals)
class ProdFunc(BaseFunctorFunc):
    """Functor func whose value is the product of the values of its `funcs`."""

    def __init__(self, funcs: Iterable[ZfitFunc], obs: ztyping.ObsTypeInput = None, name: str = "ProdFunc", **kwargs):
        """Forward all arguments to `BaseFunctorFunc`.

        Args:
            funcs: The funcs to be multiplied.
            obs: The observables, forwarded to the base class.
            name: Name of the function. Defaults to "ProdFunc" (previously the
                default was the copy-pasted "SumFunc").
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)

    def _func(self, x):
        """Evaluate every func at `x` and multiply the results together."""
        values = [func.func(x) for func in self.funcs]
        # The list of per-func values is reduced along axis 0, i.e. across the funcs.
        return tf.reduce_prod(values, axis=0)
class ZFunc(SimpleModelSubclassMixin, BaseFunc):
    """Base class for funcs defined by subclassing.

    Every subclass is checked with `_check_simple_model_subclass` at class
    creation time.
    """

    def __init__(self, obs: ztyping.ObsTypeInput, name: str = "ZFunc", **params):
        """Forward `obs`, `name` and the keyword parameters to the base classes.

        Args:
            obs: The observables of the function.
            name: Name of the function.
            **params: The parameters as keyword arguments.
        """
        super().__init__(name=name, obs=obs, **params)

    def __init_subclass__(cls, **kwargs):
        """Run the regular subclass hook, then validate the new subclass."""
        super().__init_subclass__(**kwargs)
        # Validation runs once per subclass definition, after the parent hooks.
        validate = cls._check_simple_model_subclass
        validate()