Source code for zfit.core.baseobject
"""Baseclass for most objects appearing in zfit."""
# Copyright (c) 2019 zfit
import abc
from collections import OrderedDict
import itertools
from typing import List, Set
import tensorflow as tf
from ordered_set import OrderedSet
import zfit
from ..util.cache import Cachable
from ..util import ztyping
from .interfaces import ZfitObject, ZfitNumeric, ZfitDependentsMixin
from ..util.container import convert_to_container, DotDict
_COPY_DOCSTRING = """Creates a copy of the {zfit_type}.
Note: the copy {zfit_type} may continue to depend on the original
initialization arguments.
Args:
name (str):
**overwrite_parameters: String/value dictionary of initialization
arguments to override with new value.
Returns:
{zfit_type}: A new instance of `type(self)` initialized from the union
of self.parameters and override_parameters_kwargs, i.e.,
`dict(self.parameters, **overwrite_params)`.
"""
class BaseObject(ZfitObject):
    """Common base of zfit objects: stores a name and provides copy/equality hooks."""

    def __init__(self, name, **kwargs):
        # `super().__init__()` below receives no kwargs, so anything still left here would be
        # silently dropped; fail loudly instead.
        # NOTE(review): `assert` is stripped under `python -O`, which disables this check.
        assert not kwargs, "kwargs not empty, the following arguments are not captured: {}".format(kwargs)
        super().__init__()
        self._name = name  # TODO: uniquify name?

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass gets its own DotDict whose items `__eq__` compares.
        # NOTE(review): this is a *class* attribute. Writes through `self._repr.<key> = ...` land
        # in the dict shared by all instances of that class. Confirm this is intended.
        cls._repr = DotDict()  # TODO: make repr more sophisticated

    @property
    def name(self) -> str:
        """str: The name given to the object at construction."""
        return self._name

    def copy(self, deep: bool = False, name: str = None, **overwrite_params) -> "ZfitObject":
        """Create a copy of the object by delegating to `_copy`.

        Args:
            deep (bool): Request a deep copy.
            name (str): Name of the copy; `None` lets `_copy` choose one.
            **overwrite_params: Initialization arguments to replace in the copy.

        Returns:
            ZfitObject: Whatever `_copy` returns. The implementation in this class always raises.
        """
        return self._copy(deep=deep, name=name, overwrite_params=overwrite_params)

    def _copy(self, deep, name, overwrite_params):  # TODO(Mayou36) L: representation?
        """Copy hook. The base implementation is a stub that always raises.

        Raises:
            NotImplementedError: If `deep` is truthy.
            RuntimeError: Otherwise, because this base implementation must not be used.
        """
        if deep:
            raise NotImplementedError("Unfortunately, this feature is not implemented.")
        # The default copy name is still derived, although it is not used yet. The intended
        # implementation would rebuild `type(self)(name=name, **params)` with the
        # overwritten parameters applied.
        name = self.name + "_copy" if name is None else name  # TODO: improve name mangling
        raise RuntimeError("This copy should not be used.")

    def __eq__(self, other: object) -> bool:
        """Compare by class and by the entries of `_repr`.

        `self` must be an instance of `type(other)`. In addition, every `_repr` entry of `self`
        must compare equal to the value found via `other._repr.get(key)`.
        """
        if isinstance(self, type(other)):
            # TODO: make repr better
            return all(own_value == other._repr.get(key) for key, own_value in self._repr.items())
        return False

    def __hash__(self):
        # Identity-based hash, as for `object`.
        # NOTE(review): objects that compare equal via `__eq__` can still hash differently.
        return object.__hash__(self)
class BaseDependentsMixin(ZfitDependentsMixin):
    """Mixin that implements `get_dependents` on top of an abstract `_get_dependents`."""

    @abc.abstractmethod
    def _get_dependents(self) -> ztyping.DependentsType:
        """Return the parameters this object depends on, before any floating filter is applied."""
        raise NotImplementedError

    def get_dependents(self, only_floating: bool = True) -> ztyping.DependentsType:
        """Return a set of all independent :py:class:`~zfit.Parameter` that this object depends on.

        Args:
            only_floating (bool): If `True`, only return floating :py:class:`~zfit.Parameter`.

        Returns:
            The result of `_get_dependents` unchanged if `only_floating` is `False`. Otherwise an
            `OrderedSet` of its floating members, in their original order.
        """
        all_dependents = self._get_dependents()
        if not only_floating:
            return all_dependents
        return OrderedSet(dep for dep in all_dependents if dep.floating)

    @staticmethod
    def _extract_dependents(zfit_objects: List[ZfitObject]) -> ztyping.DependentsType:
        """Merge the dependents of several objects into one ordered set.

        Calls :py:meth:`~BaseDependentsMixin.get_dependents` with ``only_floating=False`` on every
        object and flattens the results. The first-seen order is kept and duplicates are dropped.

        Args:
            zfit_objects: A single object or a container of objects, normalized with
                `convert_to_container`.

        Returns:
            set(zfit.Parameter): An ordered set of independent Parameters.
        """
        return OrderedSet(dep
                          for obj in convert_to_container(zfit_objects)
                          for dep in obj.get_dependents(only_floating=False))
class BaseNumeric(Cachable, BaseDependentsMixin, ZfitNumeric, BaseObject):
    """Base class for objects that carry a dtype and a name-sorted mapping of parameters."""

    def __init__(self, name, dtype, params, **kwargs):
        """
        Args:
            name (str): Name of the object.
            dtype (tf.DType): dtype of the object.
            params (dict or None): Mapping from parameter name to parameter-like value. Each value
                is converted with `convert_to_parameter`. A falsy value means no parameters.
            **kwargs: Forwarded to the next class in the MRO.
        """
        super().__init__(name=name, **kwargs)
        # Function-scope import; presumably avoids an import cycle with zfit.core.parameter.
        from zfit.core.parameter import convert_to_parameter
        self._dtype = dtype
        params = params or OrderedDict()
        # Sort by name so the parameter order is always consistent. Keying on the name alone
        # means the converted parameter objects are never compared with each other.
        params = OrderedDict(sorted(((n, convert_to_parameter(p)) for n, p in params.items()),
                                    key=lambda item: item[0]))
        self.add_cache_dependents(params.values())
        self._params = params
        # NOTE(review): unless `_repr` was rebound on the instance, this writes to the class-level
        # DotDict created in `BaseObject.__init_subclass__`, which all instances share. Confirm.
        self._repr.params = self.params

    @property
    def dtype(self) -> tf.DType:
        """The dtype of the object"""
        return self._dtype

    @property
    def params(self) -> ztyping.ParametersType:
        """OrderedDict: The parameters of the object, keyed by name."""
        return self._params

    def get_params(self, only_floating: bool = False, names: ztyping.ParamsNameOpt = None) -> List["ZfitParameter"]:
        """Return the parameters of the object, optionally selected by name.

        Args:
            only_floating (bool): If True, return only the floating parameters.
            names (str or iterable of str or None): The names of the parameters to return, in this
                order. A single string selects one parameter. `None` selects all parameters in the
                order of `params`.

        Returns:
            list(`ZfitParameters`):

        Raises:
            KeyError: If any entry of `names` is not a parameter name of this object.
        """
        if names is None:
            params = list(self.params.values())
        else:
            # Materialize once. A one-shot iterator would otherwise be exhausted by the
            # validation below and silently yield an empty selection.
            names = (names,) if isinstance(names, str) else tuple(names)
            missing_names = set(names).difference(self.params.keys())
            if missing_names:
                raise KeyError("The following names are not valid parameter names: {}".format(
                    sorted(missing_names, key=str)))
            params = [self.params[name] for name in names]
        if only_floating:
            params = self._filter_floating_params(params=params)
        return params

    @staticmethod
    def _filter_floating_params(params):
        """Return the members of `params` whose `floating` attribute is truthy, keeping order."""
        params = [param for param in params if param.floating]
        return params