Source code for zfit.models.functions

#  Copyright (c) 2025 zfit

from __future__ import annotations

import typing
from collections.abc import Callable, Iterable

if typing.TYPE_CHECKING:
    import zfit  # noqa: F401

import tensorflow as tf

import zfit.z.numpy as znp
from zfit._interfaces import ZfitFunc

from ..core.basefunc import BaseFuncV1
from ..core.basemodel import SimpleModelSubclassMixin
from ..core.space import supports
from ..models.basefunctor import FunctorMixin
from ..util import ztyping
from ..util.container import convert_to_container


class SimpleFuncV1(BaseFuncV1):
    """Function whose value is computed by a user-supplied callable."""

    def __init__(
        self,
        obs: ztyping.ObsTypeInput,
        func: Callable,
        name: str = "Function",
        **params,
    ):
        """Create a simple function out of `func` with the observables `obs` depending on `params`.

        Args:
            obs: Observables of the function; forwarded to the base class initializer.
            func: Callable used to evaluate the function. It is called as ``func(x)``;
                if that call raises ``TypeError`` it is called as ``func(self, x)``.
            name: Name of the function; forwarded to the base class initializer.
            **params: The parameters as keyword arguments. E.g. `mu=Parameter(...)`
        """
        super().__init__(obs=obs, params=params, name=name)
        # `_check_input_x_function` is defined outside this class; whatever it
        # returns is the callable that `_func` evaluates.
        self._value_func = self._check_input_x_function(func)

    def _func(self, x):
        """Evaluate the stored callable at `x`.

        Args:
            x: Input handed to the stored callable.

        Returns:
            Whatever the stored callable returns.
        """
        value_func = self._value_func
        try:
            result = value_func(x)
        except TypeError:
            # A TypeError is taken to mean that the callable requests `self` as
            # its first argument. TODO maybe check signature?
            result = value_func(self, x)
        return result
class BaseFunctorFuncV1(FunctorMixin, BaseFuncV1):
    """Base class for functions that are built from other functions."""

    def __init__(self, funcs, name="BaseFunctorFunc", params=None, **kwargs):
        """Store the component functions and initialize the base classes.

        Args:
            funcs: Component function(s); normalized with ``convert_to_container``.
            name: Name forwarded to the base class initializer.
            params: Parameters forwarded to the base class initializer. ``None`` is
                replaced by an empty dict.
            **kwargs: Forwarded unchanged to the base class initializer.
        """
        # Must be set before the base initializer runs, since it is passed on as `models`.
        self.funcs = convert_to_container(funcs)
        # NOTE: the parameters of the individual funcs are not merged into `params` here.
        init_params = {} if params is None else params
        super().__init__(name=name, models=self.funcs, params=init_params, **kwargs)
        self._models = self.funcs
class SumFunc(BaseFunctorFuncV1):
    """Functor function whose value is the sum of its component functions."""

    def __init__(
        self,
        funcs: Iterable[ZfitFunc],
        obs: ztyping.ObsTypeInput = None,
        name: str = "SumFunc",
        **kwargs,
    ):
        """Create the sum of `funcs`.

        Args:
            funcs: The functions to add up.
            obs: Observables forwarded to the base class initializer.
            name: Name forwarded to the base class initializer.
            **kwargs: Forwarded unchanged to the base class initializer.
        """
        super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)

    def _func(self, x):
        """Evaluate every component function at `x` and add the results.

        Args:
            x: Input handed to ``func.func`` of each component.

        Returns:
            The elementwise sum of the component values.
        """
        values = [func.func(x) for func in self.funcs]
        # `tf.math.accumulate_n` is deprecated in favour of `tf.math.add_n`.
        return tf.math.add_n(values)

    @supports()
    def _analytic_integrate(self, limits, norm):
        """Add up the analytic integrals of all component functions.

        Args:
            limits: Forwarded to ``analytic_integrate`` of each component.
            norm: Forwarded to ``analytic_integrate`` of each component.

        Returns:
            The sum of the component integrals.
        """
        # below may raise AnalyticIntegralNotImplementedError, that's fine. We don't want to catch that.
        integrals = [func.analytic_integrate(limits=limits, norm=norm) for func in self.funcs]
        # `tf.math.accumulate_n` is deprecated in favour of `tf.math.add_n`.
        return tf.math.add_n(integrals)
class ProdFunc(BaseFunctorFuncV1):
    """Functor function whose value is the product of its component functions."""

    def __init__(
        self,
        funcs: Iterable[ZfitFunc],
        obs: ztyping.ObsTypeInput = None,
        name: str = "ProdFunc",  # was "SumFunc": copy-paste slip from the sum class
        **kwargs,
    ):
        """Create the product of `funcs`.

        Args:
            funcs: The functions to multiply.
            obs: Observables forwarded to the base class initializer.
            name: Name forwarded to the base class initializer.
            **kwargs: Forwarded unchanged to the base class initializer.
        """
        super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)

    def _func(self, x):
        """Evaluate every component function at `x` and multiply the results.

        Args:
            x: Input handed to ``func.func`` of each component.

        Returns:
            The product of the component values, reduced over the first axis of
            the stacked list of values.
        """
        values = [func.func(x) for func in self.funcs]
        return znp.prod(values, axis=0)
class ZFuncV1(SimpleModelSubclassMixin, BaseFuncV1):
    """Base class for functions defined by subclassing."""

    def __init__(
        self,
        obs: ztyping.ObsTypeInput,
        name: str = "ZFunc",
        **params,
    ):
        """Forward all arguments to the base class initializer.

        Args:
            obs: Observables of the function.
            name: Name of the function.
            **params: The parameters as keyword arguments; passed on individually
                as keyword arguments.
        """
        super().__init__(obs=obs, name=name, **params)

    def __init_subclass__(cls, **kwargs):
        """Run the simple-model subclass check whenever a subclass is defined.

        Args:
            **kwargs: Class keyword arguments, forwarded to the parent hook.
        """
        super().__init_subclass__(**kwargs)
        # Performed after the parent hook has run.
        cls._check_simple_model_subclass()