Source code for zfit.models.functions

#  Copyright (c) 2024 zfit

from __future__ import annotations

from collections.abc import Callable, Iterable

import tensorflow as tf

import zfit.z.numpy as znp

from ..core.basefunc import BaseFuncV1
from ..core.basemodel import SimpleModelSubclassMixin
from ..core.dependents import _extract_dependencies
from ..core.interfaces import ZfitFunc
from ..core.space import supports
from ..models.basefunctor import FunctorMixin
from ..util import ztyping
from ..util.container import convert_to_container


class SimpleFuncV1(BaseFuncV1):
    """Function model whose value is computed by a user-supplied callable."""

    def __init__(
        self,
        obs: ztyping.ObsTypeInput,
        func: Callable,
        name: str = "Function",
        **params,
    ):
        """Create a simple function out of `func` with the observables `obs` depending on `params`.

        Args:
            obs: Observables, forwarded unchanged to the base class.
            func: Callable used to evaluate the function. It is called as `func(x)`; if that call
                raises `TypeError`, it is called again as `func(self, x)`.
            name: Name forwarded to the base class.
            **params: The parameters as keyword arguments. E.g. `mu=Parameter(...)`
        """
        super().__init__(name=name, obs=obs, params=params)
        # Whatever the check returns is what `_func` calls later.
        # NOTE(review): presumably validates the signature of `func` -- confirm in the base class.
        checked_func = self._check_input_x_function(func)
        self._value_func = checked_func

    def _func(self, x):
        """Evaluate the stored callable at `x`.

        The callable is first tried with `x` alone. A `TypeError` is taken as a sign that the
        callable also wants the instance, so it is retried with `self` as first argument.
        """
        value_func = self._value_func
        try:
            result = value_func(x)
        except TypeError:  # self requested, TODO maybe check signature?
            result = value_func(self, x)
        return result
class BaseFunctorFuncV1(FunctorMixin, BaseFuncV1):
    """Base class for functions that are composed of other functions."""

    def __init__(self, funcs, name="BaseFunctorFunc", params=None, **kwargs):
        """Store the component functions and initialise the base classes.

        Args:
            funcs: One function or several; passed through `convert_to_container` before storing.
            name: Name forwarded to the base class.
            params: Own parameters of the functor; an empty dict is used if `None`.
            **kwargs: Forwarded unchanged to the base class initialiser.
        """
        self.funcs = convert_to_container(funcs)
        # The parameters of the component functions are not merged into the own parameters here.
        own_params = {} if params is None else params
        super().__init__(name=name, models=self.funcs, params=own_params, **kwargs)
        self._models = self.funcs

    def _get_dependencies(self):  # TODO: change recursive to `only_floating`?
        """Return the own dependencies united with those extracted from the component functions."""
        own_dependents = super()._get_dependencies()  # dependents of the own parameters
        component_dependents = _extract_dependencies(self.funcs)  # flattened over all components
        return own_dependents.union(component_dependents)
class SumFunc(BaseFunctorFuncV1):
    """Functor whose value is the sum of the values of its component functions."""

    def __init__(
        self,
        funcs: Iterable[ZfitFunc],
        obs: ztyping.ObsTypeInput = None,
        name: str = "SumFunc",
        **kwargs,
    ):
        """Create the sum of `funcs`.

        Args:
            funcs: The functions to add up.
            obs: Observables, forwarded to the base class.
            name: Name forwarded to the base class.
            **kwargs: Forwarded unchanged to the base class initialiser.
        """
        super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)

    def _func(self, x):
        """Evaluate every component at `x` and add the results with `tf.math.accumulate_n`."""
        return tf.math.accumulate_n([component.func(x) for component in self.funcs])

    @supports()
    def _analytic_integrate(self, limits, norm):
        """Add up the analytic integrals of all components over `limits`.

        A component may raise AnalyticIntegralNotImplementedError; it is intentionally not caught
        here and propagates to the caller.
        """
        return tf.math.accumulate_n(
            [component.analytic_integrate(limits=limits, norm=norm) for component in self.funcs]
        )
class ProdFunc(BaseFunctorFuncV1):
    """Functor whose value is the product of the values of its component functions."""

    def __init__(
        self,
        funcs: Iterable[ZfitFunc],
        obs: ztyping.ObsTypeInput = None,
        name: str = "ProdFunc",
        **kwargs,
    ):
        """Create the product of `funcs`.

        Args:
            funcs: The functions to multiply.
            obs: Observables, forwarded to the base class.
            name: Name forwarded to the base class. Defaults to "ProdFunc".
            **kwargs: Forwarded unchanged to the base class initialiser.
        """
        super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)

    def _func(self, x):
        """Evaluate every component at `x` and multiply the results along the first axis."""
        values = [func.func(x) for func in self.funcs]
        # axis=0 reduces over the list of component values, not over the entries of each value.
        return znp.prod(values, axis=0)
class ZFuncV1(SimpleModelSubclassMixin, BaseFuncV1):
    """Base class for functions that are defined by subclassing."""

    def __init__(self, obs: ztyping.ObsTypeInput, name: str = "ZFunc", **params):
        """Initialise the function.

        Args:
            obs: Observables, forwarded to the base class.
            name: Name forwarded to the base class.
            **params: The parameters as keyword arguments, forwarded unchanged to the base class.
        """
        super().__init__(obs=obs, name=name, **params)

    def __init_subclass__(cls, **kwargs):
        """Run the simple-model subclass check whenever a subclass is created.

        NOTE(review): what exactly `_check_simple_model_subclass` verifies is defined outside this
        module -- presumably in `SimpleModelSubclassMixin`; confirm there.
        """
        super().__init_subclass__(**kwargs)
        cls._check_simple_model_subclass()