# Copyright (c) 2025 zfit
from __future__ import annotations
import typing
from collections.abc import Callable, Iterable
if typing.TYPE_CHECKING:
import zfit # noqa: F401
import tensorflow as tf
import zfit.z.numpy as znp
from ..core.basefunc import BaseFuncV1
from ..core.basemodel import SimpleModelSubclassMixin
from ..core.interfaces import ZfitFunc
from ..core.space import supports
from ..models.basefunctor import FunctorMixin
from ..util import ztyping
from ..util.container import convert_to_container
[docs]
class SimpleFuncV1(BaseFuncV1):
def __init__(
self,
obs: ztyping.ObsTypeInput,
func: Callable,
name: str = "Function",
**params,
):
"""Create a simple function out of of `func` with the observables `obs` depending on `parameters`.
Args:
func:
obs:
name:
**params: The parameters as keyword arguments. E.g. `mu=Parameter(...)`
"""
super().__init__(name=name, obs=obs, params=params)
self._value_func = self._check_input_x_function(func)
def _func(self, x):
try:
return self._value_func(x)
except TypeError: # self requested, TODO maybe check signature?
return self._value_func(self, x)
class BaseFunctorFuncV1(FunctorMixin, BaseFuncV1):
def __init__(self, funcs, name="BaseFunctorFunc", params=None, **kwargs):
funcs = convert_to_container(funcs)
if params is None:
params = {}
# for func in funcs:
# params.update(func.params)
self.funcs = funcs
super().__init__(name=name, models=self.funcs, params=params, **kwargs)
self._models = self.funcs
[docs]
class SumFunc(BaseFunctorFuncV1):
def __init__(
self,
funcs: Iterable[ZfitFunc],
obs: ztyping.ObsTypeInput = None,
name: str = "SumFunc",
**kwargs,
):
super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)
def _func(self, x):
# sum_funcs = tf.add_n([func.value(x) for func in self.funcs])
funcs = [func.func(x) for func in self.funcs]
return tf.math.accumulate_n(funcs)
@supports()
def _analytic_integrate(self, limits, norm):
# below may raises AnalyticIntegralNotImplementedError, that's fine. We don't wanna catch that.
integrals = [func.analytic_integrate(limits=limits, norm=norm) for func in self.funcs]
return tf.math.accumulate_n(integrals)
[docs]
class ProdFunc(BaseFunctorFuncV1):
def __init__(
self,
funcs: Iterable[ZfitFunc],
obs: ztyping.ObsTypeInput = None,
name: str = "SumFunc",
**kwargs,
):
super().__init__(funcs=funcs, obs=obs, name=name, **kwargs)
def _func(self, x):
funcs = [func.func(x) for func in self.funcs]
return znp.prod(funcs, axis=0)
class ZFuncV1(SimpleModelSubclassMixin, BaseFuncV1):
def __init__(self, obs: ztyping.ObsTypeInput, name: str = "ZFunc", **params):
super().__init__(obs=obs, name=name, **params)
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._check_simple_model_subclass()