# Source code for zfit.models.basefunctor

#  Copyright (c) 2019 zfit

import abc
from typing import List, Union, Tuple

from ..core.basemodel import BaseModel
from ..core.dimension import get_same_obs, combine_spaces
from ..core.interfaces import ZfitFunctorMixin, ZfitModel
from ..core.limits import Space
from ..util.container import convert_to_container
from ..util.exception import NormRangeNotSpecifiedError, LimitsIncompatibleError


class FunctorMixin(ZfitFunctorMixin, BaseModel):
    """Mixin for models that are composed of other ("daughter") models.

    Stores the obs of every daughter, merges the daughters' dependents into the
    dependents of this model and exposes the daughters via `models` /
    `get_models`. Subclasses must implement the abstract property `_models`.
    """

    def __init__(self, models, obs, **kwargs):
        models = convert_to_container(models, container=list)
        # If `obs` is None, it is replaced by the space combined from the daughters.
        obs = self._check_extract_input_obs(obs=obs, models=models)
        super().__init__(obs=obs, **kwargs)
        # obs of every daughter, in the order the daughters were given
        self._model_obs = tuple(model.obs for model in models)

    def _infer_obs_from_daughters(self):
        """Return the obs shared by all daughters, or `False` if they differ."""
        obs = set(self._model_obs)
        if len(obs) == 1:
            return obs.pop()
        return False

    def _check_extract_input_obs(self, obs, models):
        """Return `obs`, or the space combined from `models` if `obs` is None.

        Args:
            obs: The obs requested for this functor, or None.
            models: The daughter models.

        Returns:
            `obs` unchanged if it is not None, otherwise the combined space of
            the daughters. If their limits cannot be combined
            (`LimitsIncompatibleError`), a `Space` built only from the union of
            the daughters' obs is used instead.
        """
        # combine spaces and limits; this is done eagerly (as before) even when
        # `obs` is given, so errors other than LimitsIncompatibleError still surface
        try:
            models_space = combine_spaces([model.space for model in models])
        except LimitsIncompatibleError:
            # limits are incompatible: fall back to a space with only the obs
            extracted_obs = _extract_common_obs(obs=tuple(model.obs for model in models))
            models_space = Space(obs=extracted_obs)

        if obs is None:
            obs = models_space
        # NOTE: a given `obs` is deliberately not compared against the daughters'
        # obs: they may legitimately differ (e.g. for a projection).
        return obs

    def _get_dependents(self):
        """Return the own dependents united with those of all daughter models."""
        dependents = super()._get_dependents()  # get the own parameter dependents
        model_dependents = self._extract_dependents(self.get_models())
        return dependents.union(model_dependents)

    @property
    def models(self) -> List[ZfitModel]:
        """Return the models of this `Functor`. Can be `pdfs` or `funcs`."""
        return self._models

    @property
    def _model_same_obs(self):
        return get_same_obs(self._model_obs)

    @property
    @abc.abstractmethod
    def _models(self) -> List[ZfitModel]:
        raise NotImplementedError

    def get_models(self, names=None) -> List[ZfitModel]:
        """Return the daughter models as a new list.

        Args:
            names: Selection by name is not implemented yet; must be None.

        Raises:
            ValueError: If `names` is not None.
        """
        if names is not None:
            # TODO: support selecting models by name
            raise ValueError("name not supported currently.")
        return list(self.models)

    def _check_input_norm_range_default(self, norm_range, caller_name="", none_is_error=True):
        """Check `norm_range`, falling back to `self.norm_range` if it is None.

        Raises:
            NormRangeNotSpecifiedError: If `norm_range` is None and this object
                has no `norm_range` attribute.
        """
        if norm_range is None:
            try:
                norm_range = self.norm_range
            except AttributeError as error:
                # chain the original error so the missing attribute stays visible
                raise NormRangeNotSpecifiedError(
                    "The normalization range is `None`, no default norm_range is set") from error
        return self._check_input_norm_range(norm_range=norm_range, caller_name=caller_name,
                                            none_is_error=none_is_error)
def _extract_common_obs(obs: Tuple[Union[Tuple[str], Space], ...]) -> Tuple[str, ...]:
    """Merge the observables of several spaces / obs tuples into one tuple.

    Args:
        obs: Each element is either a `Space` (its `.obs` is used) or an
            iterable of observable names.

    Returns:
        All observable names, without duplicates, in order of first appearance.
    """
    # resolve every element to its observable names first (as the original did)
    obs_iter = [space.obs if isinstance(space, Space) else space for space in obs]
    seen = set()  # O(1) membership checks; `unique_obs` keeps the order
    unique_obs = []
    for single_obs in obs_iter:  # distinct name: do not shadow the parameter `obs`
        for name in single_obs:
            if name not in seen:
                seen.add(name)
                unique_obs.append(name)
    return tuple(unique_obs)