Source code for zfit.core.space

#  Copyright (c) 2024 zfit

from __future__ import annotations

from typing import TYPE_CHECKING, Union

import pandas as pd
from numpy import ndarray
from pandas import DataFrame
from tensorflow import Tensor

from .serialmixin import SerializableMixin

if TYPE_CHECKING:
    import zfit

import functools
import inspect
import itertools
import warnings
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping
from contextlib import suppress

import numpy as np
import tensorflow as tf

import zfit
import zfit.z.numpy as znp

from .. import z
from .._variables.axis import Binnings, RegularBinning, histaxes_to_binning
from ..settings import ztypes
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.deprecation import deprecated, deprecated_args, deprecated_norm_range
from ..util.exception import (
    AxesIncompatibleError,
    AxesNotSpecifiedError,
    BreakingAPIChangeError,
    CannotConvertToNumpyError,
    CoordinatesIncompatibleError,
    CoordinatesUnderdefinedError,
    IllegalInGraphModeError,
    IntentionAmbiguousError,
    InvalidLimitSubspaceError,
    LimitsIncompatibleError,
    LimitsNotSpecifiedError,
    LimitsUnderdefinedError,
    MultipleLimitsNotImplemented,
    NormNotImplemented,
    NumberOfEventsIncompatibleError,
    ObsIncompatibleError,
    ObsNotSpecifiedError,
    OverdefinedError,
    ShapeIncompatibleError,
    SpaceIncompatibleError,
)
from .baseobject import BaseObject
from .coordinates import (
    Coordinates,
    _convert_obs_to_str,
    convert_to_axes,
    convert_to_obs_str,
)
from .dimension import common_axes, common_obs, limits_overlap
from .interfaces import (
    ZfitData,
    ZfitLimit,
    ZfitOrderableDimensional,
    ZfitPDF,
    ZfitSpace,
)


class LimitRangeDefinition:
    pass


# Singleton
class Any(LimitRangeDefinition):
    _singleton_instance = None

    def __new__(cls, *_, **__):
        instance = cls._singleton_instance
        if instance is None:
            instance = super().__new__(cls)
            cls._singleton_instance = instance

        return instance

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._singleton_instance = None  # each subclass is a singleton of "itself"

    def __repr__(self):
        return "<Any>"

    def __lt__(self, other):
        return True

    def __le__(self, other):
        return True

    def __ge__(self, other):
        return True

    def __gt__(self, other):
        return True


class AnyLower(Any):
    def __repr__(self):
        return "<Any Lower Limit>"


class AnyUpper(Any):
    def __repr__(self):
        return "<Any Upper Limit>"


ANY = Any()
ANY_LOWER = AnyLower()
ANY_UPPER = AnyUpper()


class V1Space:
    def __init__(self, space: ZfitSpace):
        self.space = space

    @property
    def lower(self):
        return znp.atleast_1d(self.space.lower[0])

    @property
    def upper(self):
        return znp.atleast_1d(self.space.upper[0])

    @property
    def limits(self):
        return self.lower, self.upper

    @property
    def volume(self):
        return znp.atleast_1d(self.space._legacy_area())

    def area(self):
        return znp.atleast_1d(self.space.area())  # intentional, we want the deprecated one to trigger


class V0Space:
    def __init__(self, space: ZfitSpace):
        self.space = space

    @property
    def lower(self):
        return self.space.lower

    @property
    def upper(self):
        return self.space.upper

    @property
    def limits(self):
        return self.space.limits

    @property
    def volume(self):
        return self.space._legacy_area()

    def area(self):
        self.space.area()  # intentional, we want the deprecated one to trigger


class VectorizeLimits:
    def __init__(self, space):
        self.space = space

    @property
    def lower(self):
        return self.space.v1.lower[None, :]

    @property
    def upper(self):
        return self.space.v1.upper[None, :]

    @property
    def limits(self):
        return znp.stack([self.lower, self.upper], axis=0)

    @property
    def volume(self):
        return self.space.v1.volume[None, :]


# TODO(warning): set a changeable warning system in zfit
def fail_not_rect(func):
    def wrapped_func(*args, **kwargs):
        self = args[0]
        if self.has_limits and not self.has_rect_limits:
            msg = (
                f"Cannot call {func} as the space {self} has functional,"
                f" not rectangular limits. Use `rect_*` functions to obtain the"
                f" rectangular limits/area or `inside`/`filter` to test if values are"
                f" inside of the space."
            )
            raise RuntimeError(msg)

        return func(*args, **kwargs)

    return wrapped_func


@z.function(wraps="tensor", keepalive=True)
def calculate_rect_area(rect_limits):
    lower, upper = rect_limits
    diff = upper - lower
    return z.unstable.reduce_prod(diff, axis=-1)


@z.function(wraps="tensor", keepalive=True)
def inside_rect_limits(x, rect_limits):
    if (ndims := x.get_shape().ndims) is not None and ndims <= 1:
        msg = (
            "x has ndims <= 1, which is most probably not wanted. The default shape for array-like"
            " structures is (nevents, n_obs)."
        )
        raise ValueError(msg)
    lower, upper = z.unstack_x(rect_limits, axis=0)
    lower = z.convert_to_tensor(lower)
    upper = z.convert_to_tensor(upper)
    below_upper = znp.all(znp.less_equal(x, upper), axis=-1)  # if all obs inside
    above_lower = znp.all(znp.greater_equal(x, lower), axis=-1)
    return znp.logical_and(above_lower, below_upper)


@z.function(wraps="tensor", keepalive=True)
def filter_rect_limits(x, rect_limits, axis=None):
    return tf.boolean_mask(tensor=x, mask=inside_rect_limits(x, rect_limits=rect_limits, axis=axis))


def convert_to_tensor_or_numpy(obj, dtype=ztypes.float):
    if contains_tensor(obj):
        return znp.asarray(obj, dtype=dtype)
    else:
        with suppress(AttributeError):
            dtype = dtype.as_numpy_dtype
        return np.array(obj, dtype=dtype)


def _sanitize_x_input(x, n_obs):
    if isinstance(x, ZfitData):
        x = x.value()
    x = z.convert_to_tensor(x)
    if not x.shape.ndims > 1 and n_obs > 1:
        msg = (
            "x has ndims <= 1, which is most probably not wanted. The default shape for array-like"
            " structures is (nevents, n_obs)."
        )
        raise ValueError(msg)
    if x.shape.ndims <= 1 and n_obs == 1:
        x = tf.broadcast_to(x, (1, 1)) if x.shape.ndims == 0 else znp.expand_dims(x, axis=-1)
    if tf.get_static_value(x.shape[-1]) != n_obs:
        msg = (
            f"n_obs ({n_obs}) and the last dim of x (shape: {x.shape}) do not agree. Assuming x has shape (..., n_obs)"
        )
        raise ShapeIncompatibleError(msg)
    return x


def is_range_definition(limit):
    if isinstance(limit, (LimitRangeDefinition, ZfitSpace)):
        return True
    elif (isinstance(limit, np.ndarray) and limit.dtype != object) or tf.is_tensor(limit):
        return False
    try:
        return any(is_range_definition(lim) for lim in limit)
    except TypeError:
        return False  # not iterable and was not a LimitRangeDefinition in the beginning


# @tfp.experimental.auto_composite_tensor()
class Limit(
    ZfitLimit,
    # tfp.experimental.AutoCompositeTensor
):
    _experimental_allow_vectors = False

    def __init__(
        self,
        limit_fn: ztyping.LimitsFuncTypeInput = None,
        rect_limits: ztyping.LimitsTypeInput = None,
        n_obs: int | None = None,
    ):
        """Specify a limit with rectangular limits (and possiblty an arbitrary function).

        Args:
            limit_fn: Function that works as ``inside``: return true if a point is inside of the limits.
                The function should take one tensor-like argument with shape (..., n_obs) and should return
                a shape without the last dimension.
            rect_limits: Rectangular limits, a tuple of tensor-like objects with shape (typically) (1, n_obs) or similar
                such as only a tuple/list of values that will be interpreted as the last dimension. They should cover an
                area that includes ``limit_fn`` fully.
            n_obs: dimensionality of the Limits, the last dimension.
        """
        super().__init__()
        (
            limit_fn,
            rect_limits,
            n_obs,
            is_rect,
            sublimits,
        ) = self._check_convert_input_limits(limit_fn=limit_fn, rect_limits=rect_limits, n_obs=n_obs)
        self._limit_fn = limit_fn
        self._rect_limits = rect_limits
        self._n_obs = n_obs
        self._is_rect = is_rect
        self._sublimits = sublimits

    def _check_convert_input_limits(self, limit_fn, rect_limits, n_obs):
        if isinstance(limit_fn, ZfitLimit):
            if not isinstance(limit_fn, Limit):  # because of the limit_fn, that is private. Maybe use `inside` instead?
                msg = "If limits_fn is an instance of ZfitLimit, it has to be an instance of Limit (currently)"
                raise TypeError(msg)
            if rect_limits is not None or n_obs != limit_fn.n_obs:
                msg = "limits_fn is a ZfitLimit. rect_limits and n_obs must not be specified" "(or n_obs coincide)."
                raise OverdefinedError(msg)
            limit = limit_fn

            limit_fn = limit.limit_fn
            rect_limits = limit.rect_limits
            n_obs = limit.n_obs
        limits_are_rect = True

        # if the limits are False or None, we can take a shortcut and don't need to do any preprocessing
        return_limits_short = False
        if limit_fn is False:
            if rect_limits in (False, None):
                limits_short = False
                return_limits_short = True
        elif limit_fn is None:
            if rect_limits is False:
                limits_short = False
                return_limits_short = True
            elif rect_limits is None:
                limits_short = None
                return_limits_short = True
            else:  # start from limits are anything, rect is None
                limit_fn = rect_limits
                rect_limits = None
        if return_limits_short:
            sublimits = [type(self)(limit_fn=limits_short, n_obs=1) for _ in range(n_obs)] if n_obs > 1 else (self,)
            return limits_short, limits_short, n_obs, limits_short, sublimits

        if not callable(limit_fn):  # limits_fn is actually rect_limits
            rect_limits = limit_fn
            limit_fn = None

        else:
            limits_are_rect = False
            if rect_limits in (None, False):
                msg = "Limits given as a function need also rect_limits, cannot be None or False"
                raise ValueError(msg)
        try:
            lower, upper = rect_limits
        except TypeError as err:
            msg = "The outermost shape of `rect_limits` has to be 2 to represent (lower, upper)."
            raise TypeError(msg) from err

        lower = self._sanitize_rect_limit(lower)
        upper = self._sanitize_rect_limit(upper)

        # vectors means more than one n_events, in the first dim
        if not self._experimental_allow_vectors:
            lower_nevents = tf.get_static_value(lower.shape[0])
            upper_nevents = tf.get_static_value(upper.shape[0])
            if lower_nevents != 1 or upper_nevents != 1:
                msg = (
                    "Vectors (limits with n_events != 1) are not allowed. Experimental"
                    " flag (_experimental_allow_vectors) can be switched on if desired."
                    " This happened most likely due to the new Space limits layout:"
                    " To create multiple limits, use the addition operator of simple spaces."
                )
                raise LimitsIncompatibleError(msg)

        lower_nobs = tf.get_static_value(lower.shape[-1])

        if lower_nobs != (upper_nobs := tf.get_static_value(upper.shape[-1])):
            msg = f"Last dimension of lower ({lower_nobs}) and upper ({upper_nobs}) have to coincide."
            raise ShapeIncompatibleError(msg)
        if n_obs is not None and lower_nobs != n_obs:
            msg = f"Inferred last dimension ({lower_nobs}) does not coincide with " f"given n_obs ({n_obs})"
            raise ShapeIncompatibleError(msg)

        if not any(is_range_definition(limit) for limit in (lower, upper)):
            tf.assert_greater(
                upper,
                lower,
                message="All upper limits have to be larger than the lower limits and are"
                " given as (lower, upper). Maybe (upper, lower) was entered?",
            )

        n_obs = lower_nobs  # in case it was None
        rect_limits = (lower, upper)

        # It can be that there is a function that depends on multiple dimensions, e.g. if we have
        # a `limit_fn` and n_obs > 1. But if we have only rectangular limits, we can split them up
        # which allows later (the Space) to better combine and get subspaces

        # create sublimits to iterate if possible
        sublimits = []
        if limits_are_rect and n_obs > 1:
            for i in range(n_obs):
                low = z.unstable.gather(lower, (i,), axis=-1)
                up = z.unstable.gather(upper, (i,), axis=-1)
                sublimits.append(type(self)(rect_limits=(low, up), n_obs=1))
        else:
            sublimits.append(self)

        sublimits = tuple(sublimits)

        return limit_fn, rect_limits, n_obs, limits_are_rect, sublimits

    @staticmethod
    def _sanitize_rect_limit(limit) -> ztyping.RectLowerReturnType:
        """Sanitize the input limit and return if it is numerical or not.

        Args:
            limit:

        Returns:
        """
        dtype = object if is_range_definition(limit) else ztypes.float  # as the above ANY
        limit = convert_to_tensor_or_numpy(limit, dtype=dtype)
        if len(limit.shape) == 0:
            limit = z.unstable.broadcast_to(limit, shape=(1, 1))
        if len(limit.shape) == 1:
            limit = z.unstable.expand_dims(limit, axis=0)
        return limit

    @property
    def has_rect_limits(self) -> bool:
        """If the limits are rectangular."""
        return self.has_limits and self._is_rect

    @property
    def limits(self) -> ztyping.LimitsReturnType:
        """Return the limits as a tuple of the limit function and the rectangular limits.

        Returns:
            A tuple of the limit function and the rectangular limits.
        """
        return self.rect_limits

    @property  # todo: remove, legacy object
    def rect_limits(self) -> ztyping.RectLimitsReturnType:
        """Return the rectangular limits as ``np.ndarray/tf.Tensor`` if they are set and not false.

            The rectangular limits can be used for sampling. They do not in general represent the limits
            of the object as a functional limit can be set and to check if something is inside the limits,
            the method :py:meth:`~Limit.inside` should be used.

            In order to test if the limits are False or None, it is recommended to use the appropriate methods
            ``limits_are_false`` and ``limits_are_set``.

        Returns:
            The lower and upper limits.
        Raises:
            LimitsNotSpecifiedError: If there are not limits set or they are False.
        """
        if not self.has_limits:
            msg = "Limits are False or not set, cannot return the rectangular limits."
            raise LimitsNotSpecifiedError(msg)
        return self._rect_limits

    @property
    def _rect_limits_tf(self) -> ztyping.RectLimitsTFReturnType:
        rect_limits = self._rect_limits
        if rect_limits in (None, False):
            return rect_limits
        lower = z.convert_to_tensor(rect_limits[0])
        upper = z.convert_to_tensor(rect_limits[1])
        return z.convert_to_tensor((lower, upper))

    @property
    def rect_limits_np(self) -> ztyping.RectLimitsNPReturnType:
        """Return the rectangular limits as ``np.ndarray``. Raise error if not possible.

        Rectangular limits are returned as numpy arrays which can be useful when doing checks that do not
        need to be involved in the computation later on as they allow direct interaction with Python as
        compared to ``tf.Tensor`` inside a graph function.


        Returns:
            A tuple of two ``np.ndarray`` with shape (1, n_obs) typically. The last
                dimension is always ``n_obs``, the first can be vectorized. This allows unstacking
                with `z.unstack_x()` as can be done with data.

        Raises:
            CannotConvertToNumpyError: In case the conversion fails.
        """
        lower, upper = self._rect_limits

        lower = z.unstable._try_convert_numpy(lower)
        upper = z.unstable._try_convert_numpy(upper)
        return lower, upper

    @property
    def rect_lower(self) -> ztyping.RectLowerReturnType:
        """The lower, rectangular limits, equivalent to `rect_limits[0] with shape (..., n_obs)`

        Returns:
            The lower, rectangular limits as `np.ndarray` or `tf.Tensor`
        """
        return self.rect_limits[0]

    @property
    def rect_upper(self) -> ztyping.RectUpperReturnType:
        """The upper, rectangular limits, equivalent to `rect_limits[1]` with shape (..., n_obs)

        Returns:
            The lower, rectangular limits as `np.ndarray` or `tf.Tensor`
        """
        return self.rect_limits[1]

    def rect_area(self) -> float | np.ndarray | znp.array:
        """Calculate the total rectangular area of all the limits and axes.

        Useful, for example, for MC integration.
        """
        return calculate_rect_area(rect_limits=self._rect_limits_tf)

    def inside(self, x: ztyping.XTypeInput, guarantee_limits: bool = False) -> ztyping.XTypeReturnNoData:
        """Test if `x` is inside the limits.

        This function should be used to test if values are inside the limits. If the given x is already inside
        the rectangular limits, e.g. because it was sampled from within them

        Args:
            x: Values to be checked whether they are inside of the limits. The shape is expected to have the last
                dimension equal to n_obs.
            guarantee_limits: Guarantee that the values are already inside the rectangular limits.

        Returns:
            Return a boolean tensor-like object with the same shape as the input `x` except of the
                last dimension removed.
        """
        x = _sanitize_x_input(x, n_obs=self.n_obs)
        if not self.has_limits:
            msg = "Cannot call `inside` without limits defined."
            raise LimitsNotSpecifiedError(msg)
        if guarantee_limits and self.has_rect_limits:
            return tf.broadcast_to(True, x.shape)
        else:
            return self._inside(x, guarantee_limits)

    def _inside(self, x, guarantee_limits):
        del guarantee_limits
        if self.has_rect_limits:
            return inside_rect_limits(x, rect_limits=self._rect_limits_tf)
        else:
            return self._limit_fn(x)

    def filter(
        self,
        x: ztyping.XTypeInput,
        guarantee_limits: bool = False,
        axis: int | None = None,
    ) -> ztyping.XTypeReturnNoData:
        """Filter `x` by removing the elements along `axis` that are not inside the limits.

        This is similar to `tf.boolean_mask`.

        Args:
            x: Values to be checked whether they are inside of the limits. If not, the corresonding element (in the
                specified `axis`) is removed. The shape is expected to have the last dimension equal to n_obs.
            guarantee_limits: Guarantee that the values are already inside the rectangular limits.
            axis: The axis to remove the elements from. Defaults to 0.

        Returns:
            Return an object with the same shape as `x` except that along `axis` elements have been
                removed.
        """

        if not self.has_limits:
            msg = "Cannot call `filter` without limits defined."
            raise LimitsNotSpecifiedError(msg)
        x = _sanitize_x_input(x, n_obs=self.n_obs)

        # shortcut, everything already inside
        if guarantee_limits and self.has_rect_limits:
            return x

        return self._filter(x, guarantee_limits, axis=axis)

    def _filter(self, x, guarantee_limits, axis):
        return tf.boolean_mask(tensor=x, mask=self.inside(x, guarantee_limits=guarantee_limits), axis=axis)

    @property
    def limit_fn(self):
        return self._limit_fn

    @property
    def rect_limits_are_tensors(self) -> bool:
        """Return True if the rectangular limits are tensors.

        If a limit with tensors is evaluated inside a graph context, comparison operations will fail.

        Returns:
            If the rectangular limits are tensors.
        """
        try:
            _ = self.rect_limits_np
        except CannotConvertToNumpyError:
            return True
        else:
            return False

    @property
    def limits_are_set(self) -> bool:
        """If the limits have never explicitly been set to a limit or to False.

        Returns:
        """
        return self._rect_limits is not None

    @property
    def limits_are_false(self) -> bool:
        """If the limits have been set to False, so the object on purpose does not contain limits.

        Returns:
        """
        return self._rect_limits is False

    @property
    def has_limits(self) -> bool:
        """If there are limits set and they are not false.

        Returns:
        """
        return not (self.limits_are_false or (not self.limits_are_set))

    @property
    def n_obs(self) -> int:
        """Dimensionality, the number of observables, of the limits. Equals to the last axis in rectangular limits.

        Returns:
            Dimensionality of the limits.
        """
        return self._n_obs

    @property
    def n_events(self) -> int | None:
        """

        Returns:
            Return the number of events, the dimension of the first shape. If this is > 1 or None,
                it's vectorized.
        """
        if not self.has_limits:
            return 1
        return self.rect_lower.shape[0]

    def equal(self, other: object, allow_graph: bool = True) -> znp.array:
        """Compare the limits on equality. For ANY objects, this also returns true.

        If called inside a graph context *and* the limits are tensors, this will return a symbolic `tf.Tensor`.

        Args:
            other: Any other object to compare with
            allow_graph: If False and the function returns a symbolic tensor, raise IllegalInGraphModeError instead.

        Returns:
            A ``znp.array`` with the result of the comparison.
         Raises:
             IllegalInGraphModeError: if `allow_graph`
        """
        if not isinstance(other, ZfitLimit):
            return np.array(False)
        return equal_limits(self, other, allow_graph=allow_graph)

    def __eq__(self, other: object) -> bool:
        """Compares two Limits for equality without graph mode allowed.

        Returns:

        Raises:
             IllegalInGraphModeError: it the comparison happens with tensors in a graph context.
        """
        if not isinstance(other, ZfitLimit):
            return NotImplemented
        return self.equal(other, allow_graph=False)

    def less_equal(self, other: object, allow_graph: bool = True) -> znp.array:
        """Set-like comparison for compatibility. If an object is less_equal to another, the limits are combatible.

        This can be used to determine whether a fitting range specification can handle another limit.

        If called inside a graph context *and* the limits are tensors, this will return a symbolic `tf.Tensor`.


        Args:
            other: Any other object to compare with
            allow_graph: If False and the function returns a symbolic tensor, raise IllegalInGraphModeError instead.


        Returns:
            Result of the comparison
        Raises:
             IllegalInGraphModeError: it the comparison happens with tensors in a graph context.
        """
        if not isinstance(other, ZfitLimit):
            return np.array(False)
        return less_equal_limits(self, other, allow_graph=allow_graph)

    def __le__(self, other: object) -> bool:
        """Set-like comparison for compatibility. If an object is less_equal to another, the limits are combatible.

        This can be used to determine whether a fitting range specification can handle another limit.

        Returns:
            Result of the comparison
        Raises:
             IllegalInGraphModeError: it the comparison happens with tensors in a graph context.
        """
        if not isinstance(other, ZfitLimit):
            return NotImplemented
        return self.less_equal(other, allow_graph=False)

    def get_sublimits(self) -> Iterable[ZfitLimit]:
        """Splits itself into multiple sublimits with smaller n_obs.

        If this is not possible, if the limits are not rectangular, just returns itself.

        Returns:
            The sublimits if it was able to split.
        """
        return self._sublimits

    def __hash__(self) -> int:
        objects = (
            self._limit_fn,
            self.n_obs,
        )  # not rect limits, not hashable and unprecise
        return hash(tuple(objects))

    def __repr__(self) -> str:
        class_name = str(self.__class__).split(".")[-1].split("'")[0]
        if not self.limits_are_set:
            limits = None
        elif self.limits_are_false:
            limits = False
        elif self.n_obs < 5 and not self.n_events > 1:
            limits = self.v1.limits
        else:
            limits = "rectangular"

        return f"<zfit {class_name} rect_limits={limits}, limit_fn={not self.has_rect_limits}>"


def rect_limits_are_any(limit: ZfitLimit) -> bool:
    """True if all limits in limit are ANY objects."""
    if limit.rect_limits_are_tensors:
        return False
    return bool(all(isinstance(ele, Any) for lim in limit.rect_limits_np for ele in lim.flatten()))


def less_equal_limits(limit1: Limit, limit2: Limit, allow_graph=True) -> znp.array:
    if rect_limits_are_any(limit1) or rect_limits_are_any(limit2):
        return np.array(True)

    try:
        lower1, upper1 = limit1.rect_limits_np
        lower2, upper2 = limit2.rect_limits_np
    except CannotConvertToNumpyError as error:
        if not allow_graph:
            msg = (
                "Cannot use equality in graph mode, e.g. inside a `tf.function` decorated "
                "function. To retrieve a symbolic Tensor, use `.equal(..., allow_graph=True)`"
            )
            raise IllegalInGraphModeError(msg) from error

        lower1, upper1 = limit1.rect_limits
        lower2, upper2 = limit2.rect_limits

    lower_le = z.unstable.reduce_all(z.unstable.less_equal(lower1, lower2), axis=-1)
    upper_le = z.unstable.reduce_all(z.unstable.less_equal(upper1, upper2), axis=-1)
    rect_limits_le = z.unstable.logical_and(lower_le, upper_le)
    # if both are functional, they have to coincide
    if not (limit1.has_rect_limits or limit2.has_rect_limits):
        funcs_equal = limit1.limit_fn == limit2.limit_fn

    # if one is functional, one is rect: the bigger one can be rect
    elif not limit1.has_rect_limits and limit2.has_rect_limits:
        funcs_equal = np.array(True)
    else:
        funcs_equal = limit1.limit_fn == limit2.limit_fn
    return z.unstable.logical_and(rect_limits_le, funcs_equal)


def equal_limits(limit1: Limit, limit2: Limit, allow_graph=True) -> bool:
    # if both are functional, we just need to compare their functions; the rect limits are "irrelevant"
    if not (limit1.has_rect_limits or limit2.has_rect_limits):
        return np.array(limit1.limit_fn == limit2.limit_fn)

    # if one is functional, one is rect: they are not the same
    elif limit1.has_rect_limits ^ limit2.has_rect_limits:
        return np.array(False)

    try:
        lower, upper = limit1.rect_limits_np
        lower_other, upper_other = limit2.rect_limits_np
    except CannotConvertToNumpyError as error:
        # todo: why does it fail without False?
        if not allow_graph:
            msg = (
                "Cannot use equality in graph mode, e.g. inside a `tf.function` decorated "
                "function. To retrieve a symbolic Tensor, use `.equal(..., allow_graph=True)`"
            )
            raise IllegalInGraphModeError(msg) from error

        lower, upper = limit1.rect_limits
        lower_other, upper_other = limit2.rect_limits

    # TODO add tols
    lower_limits_equal = z.unstable.reduce_all(z.unstable.allclose_anyaware(lower, lower_other))
    upper_limits_equal = z.unstable.reduce_all(z.unstable.allclose_anyaware(upper, upper_other))
    rect_limits_equal = z.unstable.logical_and(lower_limits_equal, upper_limits_equal)
    funcs_equal = limit1.limit_fn == limit2.limit_fn
    return z.unstable.logical_and(rect_limits_equal, funcs_equal)


def deprecate_multispace(func):
    msg = (
        "The functionality of using a MultiSpace, that is, a space with multiple limits,"
        " has been deprecated in favor of using the `TruncatedPDF`, or generally, a list of spaces, if it is supported."
        "The `MultiSpace` will be removed in the future, as it is too complicated with too little usage."
        "If you think a use-case is missing or have any questions, please open an issue on GitHub: https://github.com/zfit/zfit/issues/new"
    )

    return deprecated(None, msg)(func)


warned_new_limits = False


def warn_new_limits(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global warned_new_limits
        if not warned_new_limits:
            msg = (
                "The attribute `limits`, `lower` and `upper` will have a changed behavior with the future 1.0 release."
                "They will have one dimension less, making them more intuitive: `limits` will return a tuple of one dimensional arrays,"
                " each dimension corresponding to one observable. `lower` and `upper` will return a one dimensional array for each observable."
                "Furthermore"
                "To stay compatible and not experience breaking changes, use `Space.v1.limits`, `Space.v1.lower` and `Space.v1.upper` for the future behavior, "
                "or use `Space.v0.limits`, `Space.v0.lower` and `Space.v0.upper` to keep the current behavior (discouraged)."
            )
            warnings.warn(msg, FutureWarning, stacklevel=2)
            warned_new_limits = True
        return func(*args, **kwargs)

    return wrapper


class BaseSpace(ZfitSpace, BaseObject):
    def __init__(self, obs, axes, name, **kwargs):
        super().__init__(name, **kwargs)
        coords = Coordinates(obs, axes)
        self.coords = coords

    @property
    def is_binned(self):
        return self.binning is not None

    def inside(self, x: ztyping.XTypeInput, guarantee_limits: bool = False) -> ztyping.XTypeReturn:
        """Test if `x` is inside the limits.

        This function should be used to test if values are inside the limits. If the given x is already inside
        the rectangular limits, e.g. because it was sampled from within them

        Args:
            x: Values to be checked whether they are inside of the limits. The shape is expected to have the last
                dimension equal to n_obs.
            guarantee_limits: Guarantee that the values are already inside the rectangular limits.

        Returns:
            Return a boolean tensor-like object with the same shape as the input `x` except of the
                last dimension removed.
        """
        x = _sanitize_x_input(x, n_obs=self.n_obs)
        if self.has_rect_limits and guarantee_limits:
            return tf.broadcast_to(True, x.shape)
        return self._inside(x, guarantee_limits)

    @deprecated(
        None, "Use the `volume` attribute instead (don't call it, i.e. convert `space.area()` to `space.volume`."
    )
    def area(self) -> float | znp.array:
        return self._legacy_area()

    @property
    def volume(self) -> float | znp.array:
        return self._legacy_area()

    @abstractmethod
    def _inside(self, x, guarantee_limits):
        raise NotImplementedError

    def filter(
        self,
        x: ztyping.XTypeInput,
        guarantee_limits: bool = False,
        axis: int | None = None,
    ) -> ndarray | Tensor | DataFrame | Any:
        """Filter `x` by removing the elements along `axis` that are not inside the limits.

        This is similar to `tf.boolean_mask`.

        Args:
            x: Values to be checked whether they are inside of the limits. If not, the corresonding element (in the
                specified `axis`) is removed. The shape is expected to have the last dimension equal to n_obs.
            guarantee_limits: Guarantee that the values are already inside the rectangular limits.
            axis: The axis to remove the elements from. Defaults to 0.

        Returns:
            Return an object with the same shape as `x` except that along `axis` elements have been
                removed.
        """
        if axis is not None:
            msg = "Axis is not yet implemented."
            raise ValueError(msg)
        if self.has_rect_limits and guarantee_limits:
            return x
        if isinstance(x, pd.DataFrame):
            exprs = []
            if missing_obs := set(self.obs) - set(x.columns):  # todo: expose the expression?
                msg = f"Dataframe is missing the following observables: {missing_obs}"
                raise ValueError(msg)
            for ob, lower, upper in zip(self.obs, self.v1.lower, self.v1.upper):
                expr = f"({lower} <= {ob} <= {upper})"
                exprs.append(expr)
            expr = " & ".join(exprs)
            return x.query(expr)
        return self._filter(x, guarantee_limits)

    def _filter(self, x, guarantee_limits):
        if isinstance(x, ZfitData):
            x = x.value()
        return tf.boolean_mask(tensor=x, mask=self.inside(x, guarantee_limits=guarantee_limits))

    @property
    def n_obs(self) -> int:
        """Return the number of observables/axes.

        Returns:
        Returns:
            int >= 1
        """
        return self.coords.n_obs

    @property
    def obs(self) -> ztyping.ObsTypeReturn:
        """The observables ("axes with str")the space is defined in.

        Returns:
        """
        return self.coords.obs

    @property
    def axes(self) -> ztyping.AxesTypeReturn:
        """The axes ("obs with int") the space is defined in.

        Returns:
        """
        return self.coords.axes

    @property
    @deprecated(
        None,
        "Multiple limits won't be supported anymore in the future. For alternatives, see the announcement: https://github.com/zfit/zfit/discussions/533",
    )
    def n_limits(self) -> int:
        return self._depr_n_limits()

    @property
    def _depr_n_limits(self):
        return len(tuple(self))

    def __iter__(self) -> Iterable[ZfitSpace]:
        yield self

    def get_reorder_indices(self, obs: ztyping.ObsTypeInput = None, axes: ztyping.AxesTypeInput = None) -> tuple[int]:
        """Indices that would order the instances obs as `obs` respectively the instances axes as `axes`.

        Args:
            obs: Observables that the instances obs should be ordered to. Does not reorder, but just
                return the indices that could be used to reorder.
            axes: Axes that the instances obs should be ordered to. Does not reorder, but just
                return the indices that could be used to reorder.

        Returns:
            New indices that would reorder the instances obs to be obs respectively axes.

        Raises:
            CoordinatesUnderdefinedError: If neither `obs` nor `axes` is given
        """
        return self.coords.get_reorder_indices(obs=obs, axes=axes)

    # TODO: remove, in coords
    def _check_convert_input_axes(
        self, axes: ztyping.AxesTypeInput, allow_none: bool = False
    ) -> ztyping.AxesTypeReturn:
        if axes is None:
            if allow_none:
                return None
            else:
                msg = "TODO: Cannot be None"
                raise AxesNotSpecifiedError(msg)
        # TODO(Mayou36): extend like _check_obs?
        return axes.axes if isinstance(axes, ZfitSpace) else convert_to_container(value=axes, container=tuple)

    # TODO: remove, in coords
    def _check_convert_input_obs(self, obs: ztyping.ObsTypeInput, allow_none: bool = False) -> ztyping.ObsTypeReturn:
        """Input check: Convert `NOT_SPECIFIED` to None or check if obs are all strings.

        Args:
            obs:

        Returns:
        """
        if obs is None:
            if allow_none:
                return None
            else:
                msg = "TODO: Cannot be None"
                raise ObsNotSpecifiedError(msg)

        if isinstance(obs, ZfitSpace):
            obs = obs.obs
        else:
            obs = convert_to_container(obs, container=tuple)
            obs_not_str = tuple(o for o in obs if not isinstance(o, str))
            if obs_not_str:
                msg = f"The following observables are not strings: {obs_not_str}"
                raise ValueError(msg)
        return obs

    def _check_coords_allowed(
        self,
        obs: ztyping.ObsTypeInput = None,
        axes: ztyping.AxesTypeInput = None,
        allow_superset: bool = True,
        allow_subset: bool = True,
    ):
        to_check = []
        if obs is not None and self.obs is not None:
            to_check.append(obs, self.obs)
        if axes is not None and self.axes is not None:
            to_check.append(axes, self.axes)

        for coord, self_coord in to_check:
            coord = frozenset(coord)
            self_coord = frozenset(self_coord)
            if coord != self_coord:
                if not allow_superset and coord.issuperset(self_coord):
                    msg = f"Superset is not allowed, but {coord} is a superset" f" of {self_coord}"
                    raise CoordinatesIncompatibleError(msg)

                if not allow_subset and coord.issubset(self_coord):
                    msg = f"subset is not allowed, but {coord} is a subset" f" of {self_coord}"
                    raise CoordinatesIncompatibleError(msg)

    def __repr__(self):
        class_name = str(self.__class__).split(".")[-1].split("'")[0]
        if not self.limits_are_set:
            limits = None
        elif self.limits_are_false:
            limits = False
        elif self.has_rect_limits:
            limits = self.rect_limits if self.n_obs < 3 and not self.n_events > 1 else "rectangular"
        else:
            limits = "functional"
        return f"<zfit {class_name} obs={self.obs}, axes={self.axes}, limits={limits}, binned={self.is_binned}>"

    @deprecated(
        None,
        "Multiple limits won't be supported anymore in the future. For alternatives, see the announcement: https://github.com/zfit/zfit/discussions/533",
    )
    def __add__(self, other):
        if not isinstance(other, ZfitSpace):
            msg = f"Cannot add a {type(self)} and a {type(other)}"
            raise TypeError(msg)
        return add_spaces(self, other)

    def get_sublimits(self):
        limits = self.extract_limits()
        return list(limits.values())

    @deprecate_multispace
    def add(self, *other: ztyping.SpaceOrSpacesTypeInput):
        """Add the limits of the spaces. Only works for the same obs.

        In case the observables are different, the order of the first space is taken.

        Args:
            other:

        Returns:
            :py:class:`~zfit.Space`:
        """
        # other = convert_to_container(other, container=list)
        return add_spaces(self, *other)

    def combine(self, *other: ztyping.SpaceOrSpacesTypeInput) -> ZfitSpace:
        """Combine spaces with different obs (but consistent limits).

        Args:
            other:

        Returns:
            :py:class:`~zfit.Space`:
        """
        # other = convert_to_container(other, container=list)
        return combine_spaces(self, *other)

    def __mul__(self, other):
        return self.combine(other)

    def __ge__(self, other):
        return NotImplemented

    def equal(self, other: object, allow_graph: bool) -> znp.array:
        """Compare the limits on equality. For ANY objects, this also returns true.

        If called inside a graph context *and* the limits are tensors, this will return a symbolic `tf.Tensor`.

        Returns:
            Result of the comparison
        Raises:
             IllegalInGraphModeError: it the comparison happens with tensors in a graph context.
        """
        if not isinstance(other, ZfitSpace):
            return False
        return equal_space(self, other, allow_graph=allow_graph)

    def __eq__(self, other: object) -> bool:
        """Compares two Limits for equality without graph mode allowed.

        Returns:

        Raises:
             IllegalInGraphModeError: it the comparison happens with tensors in a graph context.
        """
        if not isinstance(other, ZfitSpace):
            return NotImplemented
        return self.equal(other=other, allow_graph=False)

    def less_equal(self, other, allow_graph):
        """Set-like comparison for compatibility. If an object is less_equal to another, the limits are combatible.

        This can be used to determine whether a fitting range specification can handle another limit.

        If called inside a graph context *and* the limits are tensors, this will return a symbolic `tf.Tensor`.


        Args:
            other: Any other object to compare with
            allow_graph: If False and the function returns a symbolic tensor, raise IllegalInGraphModeError instead.

        Returns:
            Result of the comparison
        Raises:
             IllegalInGraphModeError: it the comparison happens with tensors in a graph context.
        """
        if not isinstance(other, ZfitSpace):
            return False
        return less_equal_space(self, other, allow_graph=allow_graph)

    def __le__(self, other: object) -> bool:
        """Set-like comparison for compatibility. If an object is less_equal to another, the limits are combatible.

        This can be used to determine whether a fitting range specification can handle another limit.

        Returns:
            Result of the comparison
        Raises:
             IllegalInGraphModeError: it the comparison happens with tensors in a graph context.
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.less_equal(other, allow_graph=False)

    def __hash__(self):
        limits_frozen = tuple(((key, tuple(ldict.items())) for key, ldict in self._limits_dict.items()))
        return hash((limits_frozen, hash(self.coords), hash(self.binning)))

    def reorder_x(self, x, x_obs, x_axes, func_obs, func_axes):
        return self.coords.reorder_x(x, x_obs=x_obs, x_axes=x_axes, func_obs=func_obs, func_axes=func_axes)

    @deprecated(
        None,
        "Multiple limits won't be supported anymore in the future. For alternatives, see the announcement: https://github.com/zfit/zfit/discussions/533",
    )
    def __len__(self):
        if not self:
            return 0
        else:
            return sum(1 for _ in self)

    def __bool__(self):
        return self.has_limits


# @tfp.experimental.auto_composite_tensor()
def _legacy_get_arguments_space(obs, args, limits, binning, axes, rect_limits, lower, upper):  # noqa: ARG001
    """Legacy function to get the arguments of a `Space`, i.e. making the transition to allow "lower" and "upper"
    smooth."""

    if len(args) > 2:  # catch legacy behavior
        msg = "The API for `Space` has changed and takes now lower, upper as the first two arguments. Cannot take more than 3 positional arguments"
        raise BreakingAPIChangeError(msg)

    if len(args) == 2:
        if all(isinstance(lim, (int, float, Any)) for lim in args):
            limits = args
        else:
            limits, binning = args
    elif len(args) == 1:
        limits = args[0]
    elif lower is not None and upper is not None:
        limits = [lower, upper]
        # elif lower is None and upper is None and limits is None and rect_limits is None:
        #     raise BreakingAPIChangeError(
        #         "The API for `Space` has changed and takes now lower, upper. Could not deduce the intention, as only one of lower, upper was given")

    return limits, binning


[docs] class Space( BaseSpace, SerializableMixin, # tfp.experimental.AutoCompositeTensor ): AUTO_FILL = object() ANY = ANY ANY_LOWER = ANY_LOWER # TODO: needed? or move everything inside? ANY_UPPER = ANY_UPPER @deprecated_args(None, "Use `lower` and `upper` instead", ("rect_limits", "limits")) def __init__( self, obs: ztyping.ObsTypeInput | None = None, *args, limits: ztyping.LimitsTypeInput | None = None, binning: ztyping.BinningTypeInput = None, axes=None, rect_limits=None, name: str | None = None, label: Union[str, Iterable[str]] | None = None, lower: ztyping.LimitsTypeInputV1 | None = None, upper: ztyping.LimitsTypeInputV1 | None = None, ): """Define a space with the name (`obs`) of the axes (and it's number) and possibly it's limits. A space can be thought of as coordinates, possibly with the definition of a range (limits). For most use-cases, it is sufficient to specify a `Space` via observables; simple string identifiers. They can be multidimensional. Observables are like the columns of a spreadsheet/dataframe, and are therefore needed for any object that does numerical operations or holds data in order to match the right axes. On object creation, the observables are assigned using a `Space`. This is often used as the default space of an object and can be used as the default `norm_range`, sampling limits etc. Axes are the same concept as observables, but numbers, indexes, and are used *inside* an object. There, axes 0 corresponds to the 0th data column we get (which corresponds to a certain observable). Args: obs: |@doc:space.init.obs| Observable of the space. Serves as the "variable". |@docend:space.init.obs| lower, upper: |@doc:space.init.lowerupper| Lower and upper limits of the space, respectively. Each of them should be a scalar-like object. |@docend:space.init.lowerupper| limits: |@doc:space.init.limits| A tuple-like object of the limits of the space. These are the lower and upper limits. |@docend:space.init.limits| binning: |@doc:space.init.binning| Binning of the space. Currently, only regular and variable binning *with a name* is supported. If an integer or a list of integers is given with lengths equal to the number of observables, it is interpreted as the number of bins and a regular binning is automatically created using the limits as the start and end points. |@docend:space.init.binning| name: |@doc:space.init.name| Name of the space. Maybe has implications on the serialization and deserialization of the space. For a human-readable name, use the label. |@docend:space.init.name| Raises TypeError: If the axes in the binning do not have a name. ObsIncompatibleError: If the obs do not agree with the name of the binning. ShapeIncompatibleError: If the shape of the limits or the binnings do not match the shape of the obs. """ self.v1 = V1Space(self) self.v0 = V0Space(self) self.vec = VectorizeLimits(self) limits, binning = _legacy_get_arguments_space(obs, args, limits, binning, axes, rect_limits, lower, upper) # limits = [lower, upper] # temporary if binning is None: binning = False if name is None: name = "Space" if binning is not False: integer_autobinning = isinstance(binning, int) or ( isinstance(binning, (list, tuple)) and all(isinstance(b, int) for b in binning) ) if not integer_autobinning: if not isinstance(binning, Binnings): binning = convert_to_container(binning) if binning is not None: binning = histaxes_to_binning(binning) if not all(binning.name): msg = f"Axes must have a name. Missing: {[axis for axis in binning if not hasattr(axis, 'name')]}" raise TypeError(msg) if obs is None and axes is None: obs = [axis.name for axis in binning] super().__init__(obs=obs, axes=axes, name=name) label = convert_to_container(label, container=tuple) if label is not None and len(label) != self.n_obs: msg = f"Number of labels ({len(label)}) does not match the number of observables ({self.n_obs})" raise ValueError(msg) if self.obs is not None: # if None, the obs are not given, only axes -> no labels if label is None: label = self.obs elif len(label) != self.n_obs: msg = f"Number of labels ({label}) does not match the number of observables ({self.obs})" raise ValueError(msg) label = dict(zip(self.obs, label)) self._labels = label if binning is not False and not isinstance(binning, int) and limits is None and rect_limits is None: limits = [[], []] for axis in binning: limits[0].append(axis.edges[0]) limits[1].append(axis.edges[-1]) limits_dict = self._check_convert_input_limits( limit=limits, rect_limits=rect_limits, obs=self.obs, axes=self.axes, n_obs=self.n_obs, ) self._limits_dict = limits_dict if binning is not False and isinstance(binning, int): binning = [binning] if binning is not False and integer_autobinning: if len(binning) != self.n_obs: msg = ( f"Wrong number ({len(binning)}) of integers given for regular binning" f" ({binning}) with {self.n_obs} observables ({self.obs})." f" Numbers have to match the number of observables." ) raise ShapeIncompatibleError(msg) regular_binnings = [] for i, nbins in enumerate(binning): if nbins < 1: msg = "If binning is an integer, it must be > 0" raise ValueError(msg) lower = self.lower[0][i] upper = self.upper[0][i] regular_binnings.append(RegularBinning(bins=nbins, start=lower, stop=upper, name=self.obs[i])) binning = Binnings(regular_binnings) if binning is not False: bining_names = set(binning.name) obs = set(self.obs) wrong_names = bining_names - obs if wrong_names: msg = f"Binning names ({wrong_names}) do not match observables ({obs}), {wrong_names} not in space." raise ObsIncompatibleError(msg) missing_obs = obs - bining_names if missing_obs: msg = f"Binning names ({missing_obs}) do not match observables ({obs}), missing {missing_obs}." raise ObsIncompatibleError(msg) binning = Binnings([binning[ob] for ob in self.obs]) self._binning = None if binning is False else binning @property def labels(self): if (obs := self.obs) is not None: return tuple(self._labels[ob] for ob in obs) return None # we have axis -> no labels @property def label(self): if self.n_obs > 1: msg = f"{self} has more than one observable, use `labels` instead." raise ValueError(msg) return self.labels[0] @property def binning(self): return self._binning # if binning_out is not None: # binning_out = @property def is_binned(self): return self.binning is not None def _check_convert_input_limits( self, limit: ztyping.LowerTypeInput | ztyping.UpperTypeInput, rect_limits, obs, axes, n_obs, ) -> ztyping.LowerTypeReturn | ztyping.UpperTypeReturn: """Check and sanitize the input limits as well as the rectangular limits. Args: limit: Returns: Limits dictionary containing the observables and/or the axes as a key matching `ZfitLimits` objects. """ limits_dict = defaultdict(dict) input_limits = limit if isinstance(input_limits, Space): space = input_limits # get the subset of obs/axes, then drop the other coord if not given if obs and not axes: space = space.with_obs(obs, allow_subset=True, allow_superset=True) space = space.with_axes(None) elif not obs and axes: space = space.with_axes(axes, allow_subset=True, allow_superset=True) space = space.with_obs(None) elif obs and axes: coords = Coordinates(obs=obs, axes=axes) space = space.with_coords(coords, allow_superset=True, allow_subset=True) input_limits = space.get_limits() obs = space.obs axes = space.axes if not isinstance(input_limits, dict) and not isinstance(rect_limits, dict): # if not input_limits and rect_limits: # input_limits = rect_limits limit = Limit(limit_fn=limit, rect_limits=rect_limits, n_obs=n_obs) i_old = 0 for lim in limit.get_sublimits(): # split into smaller ones if possible i = i_old + lim.n_obs if obs is not None: limits_dict["obs"][obs[i_old:i]] = lim if axes is not None: limits_dict["axes"][axes[i_old:i]] = lim i_old = i input_limits = limits_dict if isinstance(input_limits, dict): input_limits = input_limits.copy() elif isinstance(rect_limits, dict): input_limits = rect_limits.copy() if "axes" not in input_limits and "obs" not in input_limits: msg = "Probably internal error: wrong format of limits_dict" raise ValueError(msg) # check if obs is in the limits dict. If not, copy it from the axes if obs: if "obs" in input_limits: obs_limit_dict = input_limits["obs"] obs_limit_dict = {ob: lim for ob, lim in obs_limit_dict.items() if ob[0] in obs} else: obs_limit_dict = {} for axes_lim, lim in input_limits["axes"].items(): obs_coords = tuple(obs[axes.index(ax)] for ax in axes_lim) if isinstance(lim, ZfitOrderableDimensional): lim = lim.with_coords(self.space) obs_limit_dict[obs_coords] = lim limits_dict["obs"] = obs_limit_dict if axes: if "axes" in input_limits: axes_limit_dict = input_limits["axes"] axes_limit_dict = {axis: lim for axis, lim in axes_limit_dict.items() if axis[0] in axes} else: axes_limit_dict = {} for obs_lim, lim in input_limits["obs"].items(): axes_coords = tuple(axes[obs.index(ob)] for ob in obs_lim) if isinstance(lim, ZfitOrderableDimensional): lim = lim.with_coords(self.space) axes_limit_dict[axes_coords] = lim limits_dict["axes"] = axes_limit_dict if not axes and "axes" in limits_dict: limits_dict.pop("axes") if not obs and "obs" in limits_dict: limits_dict.pop("obs") return limits_dict def get_limits( self, obs: ztyping.ObsTypeInput = None, axes: ztyping.AxesTypeInput = None ) -> ztyping.LimitsDictWithCoords | ztyping.LimitsDictNoCoords: return_dict = {} both_none_or_true = obs is None and axes is None or obs is True and axes is True if obs is True and axes is None or both_none_or_true: if not self.obs: if obs is True: msg = "Obs are not defined for this instance, no limits set for obs." raise ObsIncompatibleError(msg) else: return_dict["obs"] = self._limits_dict["obs"].copy() if axes is True and obs is None or both_none_or_true: if not self.axes: if axes is True: msg = "Axes are not defined for this instance, no limits set for axes." raise AxesIncompatibleError(msg) else: return_dict["axes"] = self._limits_dict["axes"].copy() else: if obs: return_dict["obs"] = extract_limits_from_dict(self._limits_dict, obs=obs) if axes: return_dict["axes"] = extract_limits_from_dict(self._limits_dict, axes=axes) return return_dict @property @fail_not_rect def limits(self) -> ztyping.LimitsTypeReturn: """Return the limits. Returns: """ return self.rect_limits @property @deprecated( None, "Use for compatibility the v1.limits, v1.lower, v1.upper, as described here: https://github.com/zfit/zfit/discussions/533", ) def rect_limits(self) -> ztyping.RectLimitsReturnType: """Return the rectangular limits as `np.ndarray``tf.Tensor` if they are set and not false. The rectangular limits can be used for sampling. They do not in general represent the limits of the object as a functional limit can be set and to check if something is inside the limits, the method :py:meth:`~Limit.inside` should be used. In order to test if the limits are False or None, it is recommended to use the appropriate methods `limits_are_false` and `limits_are_set`. Returns: The lower and upper limits. Raises: LimitsNotSpecifiedError: If there are not limits set or they are False. """ return self._depr_rect_limits @property def _depr_rect_limits(self): if not self.has_limits: msg = "Limits are False or not set, cannot return the rectangular limits." raise LimitsNotSpecifiedError(msg) lower_ordered, upper_ordered = self._rect_limits_z() return lower_ordered, upper_ordered @property def _rect_limits_tf(self) -> ztyping.LimitsTypeReturn: """Return the limits as `tf.Tensor`. Returns: """ if not self.has_limits: msg = "Limits are False or not set, cannot return the rectangular limits." raise LimitsNotSpecifiedError(msg) lower_ordered, upper_ordered = self._rect_limits_z() return znp.asarray(lower_ordered), znp.asarray(upper_ordered) @property @deprecated( None, "Use for compatibility the v1.limits, v1.lower, v1.upper, as described here: https://github.com/zfit/zfit/discussions/533", ) def rect_limits_np(self) -> ztyping.RectLimitsNPReturnType: """Return the rectangular limits as `np.ndarray`. Raise error if not possible. Rectangular limits are returned as numpy arrays which can be useful when doing checks that do not need to be involved in the computation later on as they allow direct interaction with Python as compared to `tf.Tensor` inside a graph function. In order to test if the limits are False or None, it is recommended to use the appropriate methods `limits_are_false` and `limits_are_set`. Returns: A tuple of two `np.ndarray` with shape (1, n_obs) typically. The last dimension is always `n_obs`, the first can be vectorized. This allows unstacking with `z.unstack_x()` as can be done with data. Raises: CannotConvertToNumpyError: In case the conversion fails. LimitsNotSpecifiedError: If the limits are not set """ lower, upper = self._rect_limits_z() lower = z.unstable._try_convert_numpy(lower) upper = z.unstable._try_convert_numpy(upper) return lower, upper @property @deprecated( None, "Use for compatibility the v1.limits, v1.lower, v1.upper, as described here: https://github.com/zfit/zfit/discussions/533", ) def rect_upper(self) -> ztyping.UpperTypeReturn: """The upper, rectangular limits, equivalent to `rect_limits[1]` with shape (..., n_obs) Returns: The upper, rectangular limits as `np.ndarray` or `tf.Tensor` Raises: LimitsNotSpecifiedError: If the limits are not set or are false """ return self.v0.limits[1] @property @deprecated( None, "Use for compatibility the v1.limits, v1.lower, v1.upper, as described here: https://github.com/zfit/zfit/discussions/533", ) def rect_lower(self) -> ztyping.RectLowerReturnType: """The lower, rectangular limits, equivalent to `rect_limits[0]` with shape (..., n_obs) Returns: The lower, rectangular limits as `np.ndarray` or `tf.Tensor` Raises: LimitsNotSpecifiedError: If the limits are not set or are false """ return self.v0.limits[0] def _rect_limits_z(self): limits_coords = [] rect_lower_unordered = [] rect_upper_unordered = [] obs_in_use = self.obs is not None limits_dict = self._limits_dict["obs" if obs_in_use else "axes"] for ( coord_limit, limit, ) in limits_dict.items(): # TODO: maybe refactor with extract limits? limits_coords.extend(coord_limit) lower, upper = limit.rect_limits # to get the numpy or tensor rect_lower_unordered.append(lower) rect_upper_unordered.append(upper) reorder_kwargs = {"x_obs" if obs_in_use else "x_axes": limits_coords} # stack the limits and reorder them according to the own coords lower_stacked = z.unstable.concat(rect_lower_unordered, axis=-1) lower_ordered = self.reorder_x(lower_stacked, **reorder_kwargs) upper_stacked = z.unstable.concat(rect_upper_unordered, axis=-1) upper_ordered = self.reorder_x(upper_stacked, **reorder_kwargs) return lower_ordered, upper_ordered
[docs] @deprecated( None, "Use for compatibility the v1.limits, v1.lower, v1.upper, as described here: https://github.com/zfit/zfit/discussions/533", ) def rect_area(self) -> float | np.ndarray | znp.array: """Calculate the total rectangular area of all the limits and axes. Useful, for example, for MC integration. """ return self._legacy_area()
@property def rect_limits_are_tensors(self) -> bool: """Return True if the rectangular limits are tensors. If a limit with tensors is evaluated inside a graph context, comparison operations will fail. Returns: If the rectangular limits are tensors. """ try: _ = self.rect_limits_np except CannotConvertToNumpyError: return True else: return False @property def has_rect_limits(self) -> bool: """If there are limits and whether they are rectangular.""" limits_dict = self._limits_dict.get("obs") if self.obs is not None else self._limits_dict.get("axes") if not limits_dict: return False rect_limits = [limit.has_rect_limits for limit in limits_dict.values()] all_rect_limits = all(rect_limits) return all_rect_limits and len(rect_limits) > 0 @property def limits_are_false(self) -> bool: """If the limits have been set to False, so the object on purpose does not contain limits. Returns: True if limits is False """ return all(limit.limits_are_false for limit in self._limits_dict["obs" if self.obs else "axes"].values()) @property def has_limits(self) -> bool: """Whether there are limits set and they are not false. Returns: """ return self.limits_are_set and not self.limits_are_false @property def limits_are_set(self): return all( limit.limits_are_set for limit in self._limits_dict["obs" if self.obs else "axes"].values() if limit is not self ) @property def n_events(self) -> int | None: """Return the number of events, the dimension of the first shape. Returns: Number of events, the dimension of the first shape. If this is > 1 or None, it's vectorized. """ if not self.has_limits: return 1 return self.rect_lower.shape[0] @property @fail_not_rect def lower(self) -> ztyping.LowerTypeReturn: """Return the lower limits. Returns: """ return self.rect_lower @property @fail_not_rect def upper(self) -> ztyping.UpperTypeReturn: """Return the upper limits. Returns: """ return self.rect_upper @property @deprecated( None, "Multiple limits won't be supported anymore in the future. For alternatives, see the announcement:" "https://github.com/zfit/zfit/discussions/533", ) def n_limits(self) -> int: """The number of different limits. Returns: int >= 1 """ return len(tuple(self)) # def with_limits( # self, # *args, # limits: ztyping.LimitsTypeInput = None, # rect_limits: ztyping.RectLimitsInputType | None = None, # lower: ztyping.LimitsTypeInputV1 | None = None, # upper: ztyping.LimitsTypeInputV1 | None = None, # name: str | None = None, # ) -> ZfitSpace: # """Return a copy of the space with the new `limits` (and the new `name`). # # Args: # limits: Limits to use. Can be rectangular, a function (requires to also specify `rect_limits` # or an instance of ZfitLimit. # rect_limits: Rectangular limits that will be assigned with the instance # name: Human readable name # # Returns: # Copy of the current object with the new limits. # """ # if len(args) > 2: # msg = "Cannot take more than 3 positional arguments" # raise ValueError(msg) # if len(args) == 2: # limits, rect_limits = args # elif len(args) == 1: # limits = args[0] # if limits is not None and rect_limits is not None and all(isinstance(lim, (int, float, Any)) for lim in (limits, rect_limits)): # if lower is not None or upper is not None: # msg = "Cannot set both limits and lower, upper." # raise ValueError(msg) # lower, upper = limits, rect_limits # # limits = [lower, upper] # rect_limits = limits # # return type(self)( # obs=self.coords, # limits=limits, # rect_limits=rect_limits, # binning=self.binning, # name=name, # )
[docs] def with_limits( self, limits: ztyping.LimitsTypeInput = None, rect_limits: ztyping.RectLimitsInputType | None = None, name: str | None = None, ) -> ZfitSpace: """Return a copy of the space with the new `limits` (and the new `name`). Args: limits: Limits to use. Can be rectangular, a function (requires to also specify `rect_limits` or an instance of ZfitLimit. rect_limits: Rectangular limits that will be assigned with the instance name: Human readable name Returns: Copy of the current object with the new limits. """ return type(self)( obs=self.coords, limits=limits, rect_limits=rect_limits, binning=self.binning, name=name, label=self.labels )
[docs] def reorder_x( self, x: tf.Tensor | np.ndarray, *, x_obs: ztyping.ObsTypeInput = None, x_axes: ztyping.AxesTypeInput = None, func_obs: ztyping.ObsTypeInput = None, func_axes: ztyping.AxesTypeInput = None, ) -> ztyping.XTypeReturnNoData: """Reorder x in the last dimension either according to its own obs or assuming a function ordered with func_obs. There are two obs or axes around: the one associated with this Coordinate object and the one associated with x. If x_obs or x_axes is given, then this is assumed to be the obs resp. the axes of x and x will be reordered according to `self.obs` resp. `self.axes`. If func_obs resp. func_axes is given, then x is assumed to have `self.obs` resp. `self.axes` and will be reordered to align with a function ordered with `func_obs` resp. `func_axes`. Switching `func_obs` for `x_obs` resp. `func_axes` for `x_axes` inverts the reordering of x. Args: x: Tensor to be reordered, last dimension should be n_obs resp. n_axes x_obs: Observables associated with x. If both, x_obs and x_axes are given, this has precedency over the latter. x_axes: Axes associated with x. func_obs: Observables associated with a function that x will be given to. Reorders x accordingly and assumes self.obs to be the obs of x. If both, `func_obs` and `func_axes` are given, this has precedency over the latter. func_axes: Axe associated with a function that x will be given to. Reorders x accordingly and assumes self.axes to be the axes of x. Returns: The reordered array-like object """ return self.coords.reorder_x(x=x, x_obs=x_obs, x_axes=x_axes, func_obs=func_obs, func_axes=func_axes)
[docs] def with_obs( self, obs: ztyping.ObsTypeInput | None, allow_superset: bool = True, allow_subset: bool = True, ) -> ZfitSpace: """Create a new Space that has `obs`; sorted by or set or dropped. The behavior is as follows: * obs are already set: * input obs are None: the observables will be dropped. If no axes are set, an error will be raised, as no coordinates will be assigned to this instance anymore. * input obs are not None: the instance will be sorted by the incoming obs. If axes or other objects have an associated order (e.g. data, limits,...), they will be reordered as well. If a strict subset is given (and allow_subset is True), only a subset will be returned. This can be used to take a subspace of limits, data etc. If a strict superset is given (and allow_superset is True), the obs will be sorted accordingly as if the obs not contained in the instances obs were not in the input obs. * obs are not set: * if the input obs are None, the same object is returned. * if the input obs are not None, they will be set as-is and now correspond to the already existing axes in the object. Args: obs: Observables to sort/associate this instance with allow_superset: if False and a strict superset of the own observables is given, an error is raised. allow_subset:if False and a strict subset of the own observables is given, an error is raised. Returns: A copy of the object with the new ordering/observables Raises: CoordinatesUnderdefinedError: if obs is None and the instance does not have axes ObsIncompatibleError: if `obs` is a superset and allow_superset is False or a subset and allow_allow_subset is False """ if obs is None: # drop obs, check if there are axes if self.obs is None: return self if self.axes is None: msg = "Cannot remove obs (using None) for a Space without axes" raise AxesIncompatibleError(msg) new_limits = self._limits_dict.copy() new_space = self.copy(obs=obs, limits=new_limits) else: obs = _convert_obs_to_str(obs) coords = self.coords.with_obs(obs, allow_superset=allow_superset, allow_subset=allow_subset) binning = self.binning if binning is not None: binning = [binning[ob] for ob in obs if ob in self.obs] if (newlabels := self.labels) is not None: newlabels = tuple(self._labels[ob] for ob in coords.obs) # use just the obs that are available new_space = type(self)(coords, limits=self._limits_dict, binning=binning, label=newlabels) return new_space
[docs] def with_axes( self, axes: ztyping.AxesTypeInput | None, allow_superset: bool = True, allow_subset: bool = True, ) -> ZfitSpace: """Create a new instance that has `axes`; sorted by or set or dropped. The behavior is as follows: * axes are already set: * input axes are None: the axes will be dropped. If no observables are set, an error will be raised, as no coordinates will be assigned to this instance anymore. * input axes are not None: the instance will be sorted by the incoming axes. If obs or other objects have an associated order (e.g. data, limits,...), they will be reordered as well. If a strict subset is given (and allow_subset is True), only a subset will be returned. This can be used to retrieve a subspace of limits, data etc. If a strict superset is given (and allow_superset is True), the axes will be sorted accordingly as if the axes not contained in the instances axes were not present in the input axes. * axes are not set: * if the input axes are None, the same object is returned. * if the input axes are not None, they will be set as-is and now correspond to the already existing obs in the object. Args: axes: Axes to sort/associate this instance with allow_superset: if False and a strict superset of the own axeservables is given, an error is raised. allow_subset:if False and a strict subset of the own axeservables is given, an error is raised. Returns: A copy of the object with the new ordering/axes Raises: CoordinatesUnderdefinedError: if obs is None and the instance does not have axes AxesIncompatibleError: if `axes` is a superset and allow_superset is False or a subset and allow_allow_subset is False """ if axes is None: # drop axes if self.axes is None: return self if self.obs is None: msg = "Cannot remove axes (using None) for a Space without obs" raise ObsIncompatibleError(msg) new_limits = self._limits_dict.copy() if (newlabels := self.labels) is not None and (obs := self.obs) is not None: newlabels = tuple(self._labels[ob] for ob in obs) new_space = self.copy(axes=axes, limits=new_limits, label=newlabels) else: axes = convert_to_axes(axes) if self.axes is None: if len(axes) != len(self.obs): msg = f"Trying to set axes {axes} to object with obs {self.obs}" raise AxesIncompatibleError(msg) if (newlabels := self.labels) is not None and (obs := self.obs) is not None: newlabels = tuple(self._labels[ob] for ob in obs) new_space = self.copy(axes=axes, limits=self._limits_dict, label=newlabels) else: coords = self.coords.with_axes(axes=axes, allow_superset=allow_superset, allow_subset=allow_subset) if (newlabels := self.labels) is not None and (obs := coords.obs) is not None: newlabels = tuple(self._labels[ob] for ob in obs) new_space = type(self)(coords, limits=self._limits_dict, binning=self.binning, label=newlabels) return new_space
[docs] def with_coords( self, coords: ZfitOrderableDimensional, allow_superset: bool = True, allow_subset: bool = True, ) -> Space: """Create a new :py:class:`~zfit.Space` with reordered observables and/or axes. The behavior is that _at least one coordinate (obs or axes) has to be set in both instances (the space itself or in `coords`). If both match, observables is taken as the defining coordinate. The space is sorted according to the defining coordinate and the other coordinate is sorted as well. If either the space did not have the "weaker coordinate" (e.g. both have observables, but only coords has axes), then the resulting Space will have both. If both have both coordinates, obs and axes, and sorting for obs results in non-matchin axes results in axes being dropped. Args: coords: An instance of :py:class:`Coordinates` allow_superset: If `False` and a strict superset is given, an error is raised allow_subset: If `False` and a strict subset is given, an error is raised Returns: :py:class:`~zfit.Space`: Raises: CoordinatesUnderdefinedError: if neither both obs or axes are specified. CoordinatesIncompatibleError: if `coords` is a superset and allow_superset is False or a subset and allow_allow_subset is False """ if self.obs is not None and coords.obs is not None: new_space_obs = self.with_obs(coords.obs, allow_superset=allow_superset, allow_subset=allow_subset) if coords.axes is not None: # use this axes: first drop the other one if new_space_obs.axes is not None: new_space_obs = new_space_obs.with_axes(None) # filter in case there are super/subsets coords_axes = coords.with_obs( new_space_obs.obs, allow_superset=allow_superset, allow_subset=allow_subset, ) new_space_obs = new_space_obs.with_axes(coords_axes.axes) # are the same or self.axes is None new_space = new_space_obs elif self.axes is not None and coords.axes is not None: new_space_axes = self.with_axes(coords.axes, allow_superset=allow_superset, allow_subset=allow_subset) if coords.obs is not None: # filter in case there are super/subsets coords_obs = coords.with_axes( new_space_axes.axes, allow_superset=allow_superset, allow_subset=allow_subset, ) new_space_axes = new_space_axes.with_obs(coords_obs.obs) new_space = new_space_axes else: msg = f"Neither the axes nor the obs are specified in both objects" f" {self} and {coords}" raise CoordinatesUnderdefinedError(msg) return new_space
[docs] def with_autofill_axes(self, overwrite: bool = False) -> zfit.Space: """Overwrite the axes of the current object with axes corresponding to range(len(n_obs)). This effectively fills with (0, 1, 2,...) and can be used mostly when an object enters a PDF or similar. `overwrite` allows to remove the axis first in case there are already some set. .. code-block:: object.obs -> ('x', 'z', 'y') object.axes -> None object.with_autofill_axes() object.obs -> ('x', 'z', 'y') object.axes -> (0, 1, 2) Args: overwrite: If axes are already set, replace the axes with the autofilled ones. If axes is already set and `overwrite` is False, raise an error. Returns: The object with the new axes Raises: AxesIncompatibleError: if the axes are already set and `overwrite` is False. """ new_coords = self.coords.with_autofill_axes(overwrite=overwrite) # new_space = self.with_coords(new_coords) return self.copy(axes=new_coords.axes) if self.axes is None or overwrite else self
[docs] def get_subspace( self, obs: ztyping.ObsTypeInput = None, axes: ztyping.AxesTypeInput = None, name: str | None = None, ) -> Space: """Create a :py:class:`~zfit.Space` consisting of only a subset of the `obs`/`axes` (only one allowed). Args: obs: Observables of the subspace to return. axes: Axes of the subspace to return. name: Human readable names Returns: A space containing only a subspace (and sublimits etc.) """ del name # should be label if obs is not None and axes is not None: msg = "Cannot specify `obs` *and* `axes` to get subspace." raise ValueError(msg) if axes is None and obs is None: msg = "Either `obs` or `axes` has to be specified and not None" raise ValueError(msg) # try to use observables to get index obs = self._check_convert_input_obs(obs=obs, allow_none=True) axes = self._check_convert_input_axes(axes=axes, allow_none=True) if obs is not None: limits_dict = self.get_limits(obs=obs) new_coords = self.coords.with_obs(obs, allow_subset=True, allow_superset=True) else: limits_dict = self.get_limits(axes=axes) new_coords = self.coords.with_axes(axes=axes, allow_subset=True, allow_superset=True) if (newlabels := self.labels) is not None and (obs := new_coords.obs) is not None: newlabels = tuple(self._labels[ob] for ob in obs) return type(self)(obs=new_coords, limits=limits_dict, label=newlabels)
@fail_not_rect def _legacy_area(self) -> float: """Return the total area of all the limits and axes. Useful, for example, for MC integration. """ return calculate_rect_area(rect_limits=self._rect_limits_tf) def with_binning(self, binning): return self.copy(binning=binning) # Operators
[docs] def copy(self, **overwrite_kwargs) -> zfit.Space: """Create a new :py:class:`~zfit.Space` using the current attributes and overwriting with `overwrite_overwrite_kwargs`. Args: name: The new name. If not given, the new instance will be named the same as the current one. **overwrite_kwargs: Returns: :py:class:`~zfit.Space` """ kwargs = { "name": self.name, "limits": self._limits_dict, "binning": self.binning, "axes": self.axes, "obs": self.obs, "label": self.labels, } kwargs.update(overwrite_kwargs) if set(overwrite_kwargs) - set(kwargs): msg = f"Not usable keys in `overwrite_kwargs`: {set(overwrite_kwargs) - set(kwargs)}" raise KeyError(msg) kwargs.get("binning") return self.__class__(**kwargs)
def _inside(self, x, guarantee_limits): del guarantee_limits # not used xs_inside = [] obs_in_use = self.obs is not None limits_dict = self._limits_dict["obs" if obs_in_use else "axes"] for coords, limit in limits_dict.items(): reorder_kwargs = {"func_obs" if obs_in_use else "func_axes": coords} x_sub = self.reorder_x(x, **reorder_kwargs) x_inside = limit.inside(x_sub) xs_inside.append(x_inside) return znp.all(xs_inside, axis=0) @property # TODO(1.0): deprecate limits, see also https://github.com/zfit/zfit/discussions/533 # @deprecated(date=None, instructions="use lower, upper, limits, area; for transition, see also the announcement: https://github.com/zfit/zfit/discussions/533") @fail_not_rect def limit1d(self) -> tuple[float, float]: """Simplified limits getter for 1 obs, 1 limit only: return the tuple(lower, upper). Returns: So :code:`lower, upper = space.limit1d` for a simple, 1 obs limit. Raises: RuntimeError: if the conditions (n_obs or n_limits) are not satisfied. """ if self.n_obs > 1: msg = f"Cannot call `limit1d, as `Space` has more than one observables: {self.n_obs}" raise RuntimeError(msg) if self._depr_n_limits > 1: msg = f"Cannot call `limit1d, as `Space` has several limits: {self._depr_n_limits}" raise RuntimeError(msg) lower, upper = self.rect_limits return lower[0][0], upper[0][0]
[docs] @classmethod @deprecated( date=None, instructions="Use directly the class to create a Space. E.g. zfit.Space(axes=(0, 1), ...)", ) def from_axes( cls, axes: ztyping.AxesTypeInput, # noqa: ARG003 limits: ztyping.LimitsTypeInput | None = None, # noqa: ARG003 rect_limits=None, # noqa: ARG003 name: str | None = None, # noqa: ARG003 ) -> zfit.Space: """Create a space from `axes` instead of from `obs`. Args: rect_limits: axes: limits: name: Returns: :py:class:`~zfit.Space` """ msg = "from_axes is not needed anymore, create a Space directly." raise BreakingAPIChangeError(msg)
def extract_limits_from_dict(limits_dict, obs=None, axes=None): if (obs is None) and (axes is None): msg = "Need to specify at least one, obs or axes." raise ValueError(msg) if (obs is not None) and (axes is not None): axes = None # obs has precedency if obs is None: obs_in_use = False coords_to_extract = axes else: obs_in_use = True coords_to_extract = obs coords_to_extract = convert_to_container(coords_to_extract) coords_to_extract = set(coords_to_extract) limits_to_eval = {} limit_dict = limits_dict["obs" if obs_in_use else "axes"].items() keys_sorted = sorted(limit_dict, key=lambda x: len(x[0]), reverse=True) for key_coords, limit in keys_sorted: coord_intersec = frozenset(key_coords).intersection(coords_to_extract) if not coord_intersec: # this limit does not contain any requested obs continue if coord_intersec == frozenset(key_coords): if isinstance(limit, ZfitOrderableDimensional): # drop coordinates if given limit = limit.with_axes(None) if obs_in_use else limit.with_obs(None) limits_to_eval[key_coords] = limit else: coord_limit = [coord for coord in key_coords if coord in coord_intersec] kwargs = {"obs" if obs_in_use else "axes": coord_limit} try: sublimit = limit.get_subspace(**kwargs) except InvalidLimitSubspaceError: msg = f"Cannot extract {coord_intersec} from limit {limit}." raise InvalidLimitSubspaceError(msg) from None sublimit_coord = limit.obs if obs_in_use else limit.axes if isinstance(sublimit, ZfitOrderableDimensional): # drop coordinates if given sublimit = sublimit.with_axes(None) if obs_in_use else sublimit.with_obs(None) limits_to_eval[sublimit_coord] = sublimit coords_to_extract -= coord_intersec return limits_to_eval def add_spaces(*spaces: Iterable[ZfitSpace], name=None): """Add two spaces and merge their limits if possible or return False. Args: spaces: Returns: Union[None, :py:class:`~zfit.Space`, bool]: Raises: LimitsIncompatibleError: if limits of the `spaces` cannot be merged because they overlap """ # spaces = convert_to_container(spaces) if not all(isinstance(space, ZfitSpace) for space in spaces): msg = f"Can only add type ZfitSpace, not {spaces}" raise TypeError(msg) return MultiSpace(spaces, name=name) def get_coord(space, obs_in_use=True): if obs_in_use: return space.obs else: return space.axes
[docs] def combine_spaces(*spaces: Iterable[Space]): """Combine spaces with different `obs` and `limits` to one `space`. Checks if the limits in each obs coincide *exactly*. If this is not the case, the combination is not unambiguous and `False` is returned Args: spaces: Returns: Returns False if the limits don't coincide in one or more obs. Otherwise return the :py:class:`~zfit.Space` with all obs from `spaces` sorted by the order of `spaces` and with the combined limits. Raises: ValueError: if only one space is given LimitsIncompatibleError: If the limits of one or more spaces (or within a space) overlap LimitsNotSpecifiedError: If the limits for one or more obs but not all are None or False. """ spaces = convert_to_container(spaces, container=tuple) # if len(spaces) <= 1: # return spaces # raise ValueError("Need at least two spaces to test limit consistency.") # TODO: allow? usecase? common_obs_ordered = common_obs(spaces=spaces) common_axes_ordered = common_axes(spaces=spaces) all_spaces_binned = all(space.is_binned for space in spaces) all_spaces_unbinned = not any(space.is_binned for space in spaces) if not (all_spaces_binned or all_spaces_unbinned): msg = ( f"Some spaces are binned {[s for s in spaces if s.is_binned]}" f" while others are not {[s for s in spaces if not s.is_binned]}. Cannot mix." ) raise ValueError(msg) all_labels_map = {} for space in spaces: if (lab := space._labels) is not None: all_labels_map.update(lab) if all_spaces_binned: binnings_ordererd = [] for ob in common_obs_ordered: for space in spaces: if ob in space.obs: binning = space.binning[ob] if binning not in binnings_ordererd: binnings_ordererd.append(binning) binning = Binnings(binnings_ordererd) else: binning = None using_obs = bool(common_obs_ordered) common_coords_ordered = common_obs_ordered if using_obs else common_axes_ordered # sort the spaces if using_obs: spaces = tuple(space.with_obs(common_obs_ordered, allow_superset=True) for space in spaces) all_coords = [space.obs for space in spaces] elif common_axes_ordered: spaces = tuple(space.with_axes(common_axes_ordered, allow_superset=True) for space in spaces) all_coords = [space.axes for space in spaces] else: msg = "Neither `obs` nor `axes` exist in all spaces." raise CoordinatesUnderdefinedError(msg) all_limits_false = all(space.limits_are_false for space in spaces) all_limits_not_set = all(not space.limits_are_set for space in spaces) has_limits = [space.has_limits for space in spaces] if all_limits_false: limits = False elif all_limits_not_set: limits = None elif not all(has_limits): msg = "Limits either have to be set, not set, or False for all spaces to be combined." raise LimitsNotSpecifiedError(msg) else: space_combinations = tuple(itertools.product(*spaces)) if len(space_combinations) > 1: # there are MultiSpaces in there all_combinations = [] for spa in space_combinations: with suppress(LimitsIncompatibleError): all_combinations.append(combine_spaces(*spa)) if not all_combinations: msg = f"The limits of {spaces} are all not compatible to be combined." raise LimitsIncompatibleError(msg) # filter, as can be False: non-overlapping limits e.g. if we have two MultiSpace [space for space in all_combinations if space is not False] return MultiSpace( spaces=all_combinations, obs=common_obs_ordered if common_obs_ordered else None, axes=common_axes_ordered if common_axes_ordered else None, ) # TODO: spaces that have multidim limits? limits_dict = {} non_unique_coords = set() unique_coords = set() for coord in common_coords_ordered: if sum(coord in coords for coords in all_coords) > 1: non_unique_coords.add(coord) else: unique_coords.add(coord) for coord in common_coords_ordered: if coord in unique_coords: space = next(space for space in spaces if coord in get_coord(space, using_obs)) space = space.get_subspace( obs=unique_coords if using_obs else None, axes=None if using_obs else unique_coords, ) limits_dict.update(space.get_limits()["obs" if using_obs else "axes"]) for coord in get_coord(space, using_obs): unique_coords.remove(coord) elif coord in non_unique_coords: non_unique_spaces = [space for space in spaces if coord in get_coord(space, using_obs)] common_coords_non_unique = list( set.intersection(*(set(get_coord(space, using_obs)) for space in non_unique_spaces)) ) # do the below to check if we can take the subspace non_unique_subspaces = [ space.get_subspace( obs=common_coords_non_unique if using_obs else None, axes=None if using_obs else common_coords_non_unique, ) for space in non_unique_spaces ] # TODO compare limits any_non_equal = any(non_unique_subspaces[0] != space for space in non_unique_subspaces[1:]) if any_non_equal: msg = ( f"Limits in coord {common_coords_non_unique} do not match for spaces" f" {non_unique_subspaces}" ) raise LimitsIncompatibleError(msg) non_unique_subspace = non_unique_subspaces[0] limits_dict.update(non_unique_subspace.get_limits()["obs" if using_obs else "axes"]) for coord in get_coord(non_unique_subspace, using_obs): non_unique_coords.remove(coord) else: pass # fine, since it is already satisfied by a space # # for coord in common_coords_ordered: # space_with_coord = [space for space in spaces if coord in get_coord(space, using_obs)] # assert not any(isinstance(space, MultiSpace) for space in space_with_coord), "bug, should be caught before." # assert space_with_coord, "empty, cannot be. This is a bug." # limits_coord = [] # for space in space_with_coord: # if type(space) == Space: # has to be the exact type, we use an implementation detail here # limits_coord.append(space.get_limits(obs=coord if using_obs else None, # axes=coord if not using_obs else None)) # else: # raise WorkInProgressError # # limits_coord.append(space.with_obs(obs=coord) if using_obs else space.with_axes(axes=coord)) # any_non_equal = any([limits_coord[0][(coord,)] != limit[(coord,)] for limit in limits_coord[1:]]) # if any_non_equal: # raise LimitsIncompatibleError(f"Limits in coord {coord} do not match for spaces {limits_coord}") # limits_dict.update(limits_coord[0]) limits = {"obs" if using_obs else "axes": limits_dict} # all_lower = [] # all_upper = [] # # # create the lower and upper limits with all obs replacing missing dims with None # # With this, all limits have the same length # # TODO? # # if limits_overlap(spaces=spaces, allow_exact_match=True): # # raise LimitsIncompatibleError("Limits overlap") # # for space in flatten_spaces(spaces): # if space.limits is None: # continue # lowers, uppers = space.limits # lower = [tuple(low[space.obs.index(ob)] for low in lowers) if ob in space.obs else None for ob in all_obs] # upper = [tuple(up[space.obs.index(ob)] for up in uppers) if ob in space.obs else None for ob in all_obs] # all_lower.append(lower) # all_upper.append(upper) # # def check_extract_limits(limits_spaces): # new_limits = [] # # if not limits_spaces: # return None # for index, obs in enumerate(all_obs): # current_limit = None # for limit in limits_spaces: # lim = limit[index] # # if lim is not None: # if current_limit is None: # current_limit = lim # elif not np.allclose(current_limit, lim): # return False # else: # if current_limit is None: # raise LimitsNotSpecifiedError("Limits in obs {} are not specified".format(obs)) # new_limits.append(current_limit) # # n_limits = int(np.prod(tuple(len(lim) for lim in new_limits))) # new_limits_comb = [[] for _ in range(n_limits)] # for limit in new_limits: # for lim in limit: # for i in range(int(n_limits / len(limit))): # new_limits_comb[i].append(lim) # # new_limits = tuple(tuple(limit) for limit in new_limits_comb) # return new_limits # new_lower = check_extract_limits(all_lower) # new_upper = check_extract_limits(all_upper) # assert not (new_lower is None) ^ (new_upper is None), "Bug, please report issue. either both are defined or None." # if new_lower is None: # limits = None # elif new_lower is False: # return False # else: # limits = (new_lower, new_upper) return Space( obs=common_obs_ordered if using_obs else None, axes=None if using_obs else common_axes_ordered, binning=binning, limits=limits, label=tuple(all_labels_map[ob] for ob in common_coords_ordered) if using_obs else None, )
def less_equal_space(space1, space2, allow_graph=True): return compare_multispace( space1=space1, space2=space2, comparator=lambda limit1, limit2: limit1.less_equal(limit2, allow_graph=allow_graph), ) def equal_space(space1, space2, allow_graph=True): return compare_multispace( space1=space1, space2=space2, comparator=lambda limit1, limit2: limit1.equal(limit2, allow_graph=allow_graph), ) def compare_multispace(space1: ZfitSpace, space2: ZfitSpace, comparator: Callable): """Compare multiple spaces if they have the same obs, axes, and, if a comparator is given, limits. It is automatically checked if the limits are set resp. are False Args: space1: space2: comparator: Returns: """ axes_not_none = space1.axes is not None and space2.axes is not None obs_not_none = space1.obs is not None and space2.obs is not None if not (axes_not_none or obs_not_none): # if both are None return False if obs_not_none: if set(space1.obs) != set(space2.obs): return False elif axes_not_none and set(space1.axes) != set(space2.axes): # axes only matter if there are no obs return False if space1.binning != space2.binning: return False # check limits if not space1.limits_are_set: return bool(not space2.limits_are_set) elif space1.limits_are_false: return bool(space2.limits_are_false) return compare_limits_multispace(space1, space2, comparator=comparator) def compare_limits_multispace(space1: ZfitSpace, space2: ZfitSpace, comparator: Callable) -> bool: if len(space1) != len(space2): return False if not (space1.has_limits and space2.has_limits): return False space2_reordered = space2.with_coords(space1) comparison = [] for space11 in space1: compare_spaces2 = [] for space22 in space2_reordered: compare_spaces2.append( compare_limits_coords_dict(space11.get_limits(), space22.get_limits(), comparator=comparator) ) comparison.append(compare_spaces2) comparison = convert_to_tensor_or_numpy(comparison, dtype=tf.bool) space1_matches = z.unstable.reduce_any(comparison, axis=1) # reduce over axis containing space2, has to match with # at least one space2. space2_matches = z.unstable.reduce_any(comparison, axis=0) all_space1_match = z.unstable.reduce_all(space1_matches, axis=0) all_space2_match = z.unstable.reduce_all(space2_matches, axis=0) return z.unstable.logical_and(all_space1_match, all_space2_match) def compare_limits_coords_dict( limits1: Mapping[str, Mapping[Iterable, ZfitLimit]], limits2: Mapping[str, Mapping[Iterable, ZfitLimit]], comparator: Callable, require_all_coord_types: bool = False, ) -> bool: if limits1.keys() != limits2.keys() and require_all_coord_types: return False equal = [] for coord_type, limit1_dict in limits1.items(): limit2_dict = limits2.get(coord_type) if limit2_dict is None: continue equal.append(compare_limits_dict(limit1_dict, limit2_dict, comparator=comparator)) return z.unstable.reduce_all(equal) def compare_limits_dict(dict1: Mapping, dict2: Mapping, comparator: Callable) -> bool: comparison = [] limits2_to_check = dict2.copy() for coord, limit1 in dict1.items(): for limit2cord, limit2 in limits2_to_check.items(): if set(limit2cord) == set(coord): limit2 = limits2_to_check.pop(limit2cord) comparison.append(comparator(limit1, limit2)) break else: # no break, nothing matched return False return z.unstable.reduce_all(comparison) def flatten_spaces(spaces): return tuple(s for space in spaces for s in space) class MultiSpace(BaseSpace): def __new__( cls, spaces: Iterable[ZfitSpace], obs: ztyping.ObsTypeInput = None, binning: ztyping.BinningTypeInput = None, axes: ztyping.AxesTypeInput = None, name: str | None = None, # noqa: ARG003 ) -> Space | MultiSpace: if binning is not None: msg = ( "Binning not yet implemented for MultiSpace, won't ever be. Use the new truncated, multidim PDF instead" ) raise RuntimeError(msg) spaces, obs, axes = cls._check_convert_input_spaces_obs_axes(spaces, obs, axes) if len(spaces) == 1: return spaces[0] space = super().__new__(cls) # for the __init__ below, see there space._tmp_store_spaces_obs_axes = spaces, obs, axes return space @deprecated( date=None, instructions="MultiSpace is deprecated, see announcement here: https://github.com/zfit/zfit/discussions/533" "If you're use-case is not covered by this or you have questions, please open an issue " " at https://github.com/zfit/zfit/issues/new?assignees=mayou36&labels=discussion&projects=&template=behaviorunderdiscussion.md&title=%5BMultiSpace%20deprecation%5D", ) def __init__(self, spaces: Iterable[ZfitSpace], obs=None, axes=None, name: str | None = None) -> None: # Since __new__ returns an instance of MultiSpace, __init__ is invoked. We don't want to reprocess # the input arguments here, so we store them above in the dummy attribute. del spaces, obs, axes # not needed, we take the already preprocessed. spaces, obs, axes = self._tmp_store_spaces_obs_axes del self._tmp_store_spaces_obs_axes if name is None: name = "MultiSpace" super().__init__(obs, axes, name) self._labels = {ob: ob for ob in self.obs} if self.obs is not None else None self.spaces = spaces self.v0 = V0Space(self) @property def label(self): return None @property def labels(self): return self.obs if self.obs is not None else None @staticmethod def _initialize_space(space, spaces, obs, axes): space._obs = obs space._axes = axes space.spaces = spaces return space @staticmethod def _check_convert_input_spaces_obs_axes(spaces, obs, axes): # TODO: do something with axes spaces = flatten_spaces(spaces) all_have_obs = all(space.obs is not None for space in spaces) all_have_axes = all(space.axes is not None for space in spaces) all_binnings_compatible = len({space.binning for space in spaces}) == 1 n_events = [space.n_events in (spaces[0].n_events, None) for space in spaces] all_nevents_compatible = all(n_events) if not all_binnings_compatible: msg = "Binnings not compatible, maybe this needs to be better care taken." raise ValueError(msg) if not all_nevents_compatible: msg = "The number of events of the spaces do not coincide" raise NumberOfEventsIncompatibleError(msg) if all_have_axes: axes = spaces[0].axes if axes is None else convert_to_axes(axes) if all_have_obs: obs = spaces[0].obs if obs is None else convert_to_obs_str(obs) spaces = [space.with_obs(obs, allow_subset=False, allow_superset=False) for space in spaces] if not (all_have_axes and all(space.axes == axes for space in spaces)): # obs coincide, axes don't -> drop spaces = [space.with_axes(None) for space in spaces] elif all_have_axes: if all(space.obs is None for space in spaces): spaces = [space.with_axes(axes, allow_superset=False, allow_subset=False) for space in spaces] if not (all_have_obs and all(space.obs == obs for space in spaces)): # axes coincide, obs don't -> drop spaces = [space.with_obs(None) for space in spaces] else: msg = "Spaces do not have consistent obs and/or axes." raise SpaceIncompatibleError(msg) if all(space.has_limits for space in spaces): # check overlap, reduce common limits pass elif not any(space.has_limits for space in spaces): spaces = [spaces[0]] # if all are None, then nothing to add else: # some have limits, some don't -> does not really make sense (or just drop the ones without limits?) msg = ( "Some spaces have limits, other don't. This behavior may change in the future " "to allow spaces with None to be simply ignored.\n" "If you prefer this behavior, please open an issue on github." ) raise LimitsIncompatibleError(msg) spaces = tuple(spaces) return spaces, obs, axes @property def binning(self): return self.spaces[0].binning # noinspection PyPropertyDefinition @property @fail_not_rect def limits(self) -> None: self._raise_limits_not_implemented() # noinspection PyPropertyDefinition @property def rect_limits(self): self._raise_limits_not_implemented() # noinspection PyPropertyDefinition @property @fail_not_rect def lower(self) -> None: self._raise_limits_not_implemented() # noinspection PyPropertyDefinition @property @fail_not_rect def upper(self) -> None: self._raise_limits_not_implemented() # noinspection PyPropertyDefinition @property def rect_limits_np(self): self._raise_limits_not_implemented() # noinspection PyPropertyDefinition @property def rect_lower(self): self._raise_limits_not_implemented() # noinspection PyPropertyDefinition @property def rect_upper(self): self._raise_limits_not_implemented() def rect_area(self) -> float | np.ndarray | tf.Tensor: """Calculate the total rectangular area of all the limits and axes. Useful, for example, for MC integration. """ return z.reduce_sum([space._legacy_area() for space in self], axis=0) @property def rect_limits_are_tensors(self) -> bool: """Return True if the rectangular limits are tensors. If a limit with tensors is evaluated inside a graph context, comparison operations will fail. Returns: If the rectangular limits are tensors. """ return all(space.limits_are_tensors for space in self) @property def has_rect_limits(self) -> bool: """If there are limits and whether they are rectangular.""" return all(space.has_rect_limits for space in self.spaces) # noinspection PyPropertyDefinition @property def limits_are_false(self) -> bool: """If the limits have been set to False, so the object on purpose does not contain limits. Returns: True if limits is False """ return all(space.limits_are_false for space in self.spaces) # noinspection PyPropertyDefinition @property def has_limits(self) -> bool: """Whether there are limits set and they are not false. Returns: """ try: return self.limits_are_set and not self.limits_are_false except MultipleLimitsNotImplemented: return True @property def limits_are_set(self) -> bool: """If the limits have been set to a limit or are False. Returns: Whether the limits have been set or not. """ return all(space.limits_are_set for space in self.spaces) @property def n_events(self) -> int | None: """Shape of the first dimension, usually reflects the number of events. Returns: Return the number of events, the dimension of the first shape. If this is > 1 or None, it's vectorized. """ # get the first numeric n_events. Is None if a Tensor and not specified yet. n_events_first = [space.n_events for space in self if space.n_events is not None] return None if not n_events_first else n_events_first[0] def with_limits( self, limits: ztyping.LimitsTypeInput = None, rect_limits: ztyping.RectLimitsInputType | None = None, name: str | None = None, ) -> ZfitSpace: """Return a copy of the space with the new `limits` (and the new `name`). Args: limits: Limits to use. Can be rectangular, a function (requires to also specify `rect_limits` or an instance of ZfitLimit. rect_limits: Rectangular limits that will be assigned with the instance name: Human readable name Returns: Copy of the current object with the new limits. """ return self.copy( spaces=[space.with_limits(limits=limits, rect_limits=rect_limits) for space in self], name=name, ) @fail_not_rect def _legacy_area(self) -> float: return self.rect_area() def with_obs( self, obs: ztyping.ObsTypeInput | None, allow_superset: bool = True, allow_subset: bool = True, ) -> MultiSpace: """Create a new Space that has `obs`; sorted by or set or dropped. The behavior is as follows: * obs are already set: * input obs are None: the observables will be dropped. If no axes are set, an error will be raised, as no coordinates will be assigned to this instance anymore. * input obs are not None: the instance will be sorted by the incoming obs. If axes or other objects have an associated order (e.g. data, limits,...), they will be reordered as well. If a strict subset is given (and allow_subset is True), only a subset will be returned. This can be used to take a subspace of limits, data etc. If a strict superset is given (and allow_superset is True), the obs will be sorted accordingly as if the obs not contained in the instances obs were not in the input obs. * obs are not set: * if the input obs are None, the same object is returned. * if the input obs are not None, they will be set as-is and now correspond to the already existing axes in the object. Args: obs: Observables to sort/associate this instance with allow_superset: if False and a strict superset of the own observables is given, an error is raised. allow_subset:if False and a strict subset of the own observables is given, an error is raised. Returns: A copy of the object with the new ordering/observables Raises: CoordinatesUnderdefinedError: if obs is None and the instance does not have axes ObsIncompatibleError: if `obs` is a superset and allow_superset is False or a subset and allow_allow_subset is False """ spaces = [ space.with_obs(obs, allow_superset=allow_superset, allow_subset=allow_subset) for space in self.spaces ] coords = self.coords.with_obs(obs, allow_subset=allow_subset, allow_superset=allow_superset) return self.copy(spaces=spaces, obs=coords.obs, axes=coords.axes) def with_axes( self, axes: ztyping.AxesTypeInput | None, allow_superset: bool = True, allow_subset: bool = True, ) -> MultiSpace: """Create a new instance that has `axes`; sorted by or set or dropped. The behavior is as follows: * axes are already set: * input axes are None: the axes will be dropped. If no observables are set, an error will be raised, as no coordinates will be assigned to this instance anymore. * input axes are not None: the instance will be sorted by the incoming axes. If obs or other objects have an associated order (e.g. data, limits,...), they will be reordered as well. If a strict subset is given (and allow_subset is True), only a subset will be returned. This can be used to retrieve a subspace of limits, data etc. If a strict superset is given (and allow_superset is True), the axes will be sorted accordingly as if the axes not contained in the instances axes were not present in the input axes. * axes are not set: * if the input axes are None, the same object is returned. * if the input axes are not None, they will be set as-is and now correspond to the already existing obs in the object. Args: axes: Axes to sort/associate this instance with allow_superset: if False and a strict superset of the own axeservables is given, an error is raised. allow_subset:if False and a strict subset of the own axeservables is given, an error is raised. Returns: A copy of the object with the new ordering/axes Raises: CoordinatesUnderdefinedError: if obs is None and the instance does not have axes AxesIncompatibleError: if `axes` is a superset and allow_superset is False or a subset and allow_allow_subset is False """ spaces = [ space.with_axes(axes, allow_superset=allow_superset, allow_subset=allow_subset) for space in self.spaces ] coords = self.coords.with_axes(axes, allow_subset=allow_subset, allow_superset=allow_superset) return self.copy(spaces=spaces, obs=coords.obs, axes=coords.axes) def with_coords( self, coords: ZfitOrderableDimensional, allow_superset: bool = True, allow_subset: bool = True, ) -> MultiSpace: """Create a new :py:class:`~zfit.Space` with reordered observables and/or axes. The behavior is that _at least one coordinate (obs or axes) has to be set in both instances (the space itself or in `coords`). If both match, observables is taken as the defining coordinate. The space is sorted according to the defining coordinate and the other coordinate is sorted as well. If either the space did not have the "weaker coordinate" (e.g. both have observables, but only coords has axes), then the resulting Space will have both. If both have both coordinates, obs and axes, and sorting for obs results in non-matchin axes results in axes being dropped. Args: coords: An instance of :py:class:`Coordinates` allow_superset: If false and a strict superset is given, an error is raised allow_subset: If false and a strict subset is given, an error is raised Returns: :py:class:`~zfit.Space`: Raises: CoordinatesUnderdefinedError: if neither both obs or axes are specified. CoordinatesIncompatibleError: if `coords` is a superset and allow_superset is False or a subset and allow_allow_subset is False """ new_spaces = [ space.with_coords(coords, allow_superset=allow_superset, allow_subset=allow_subset) for space in self ] return type(self)(spaces=new_spaces) def with_autofill_axes(self, overwrite: bool = False) -> MultiSpace: """Overwrite the axes of the current object with axes corresponding to range(len(n_obs)). This effectively fills with (0, 1, 2,...) and can be used mostly when an object enters a PDF or similar. `overwrite` allows to remove the axis first in case there are already some set. .. code-block:: object.obs -> ('x', 'z', 'y') object.axes -> None object.with_autofill_axes() object.obs -> ('x', 'z', 'y') object.axes -> (0, 1, 2) Args: overwrite: If axes are already set, replace the axes with the autofilled ones. If axes is already set and `overwrite` is False, raise an error. Returns: The object with the new axes Raises: AxesIncompatibleError: if the axes are already set and `overwrite` is False. """ spaces = [space.with_autofill_axes(overwrite=overwrite) for space in self] return self.copy(spaces=spaces) def get_subspace( self, obs: ztyping.ObsTypeInput = None, axes: ztyping.AxesTypeInput = None, name: str | None = None, ) -> MultiSpace: """Create a :py:class:`~zfit.Space` consisting of only a subset of the `obs`/`axes` (only one allowed). Args: obs: Observables of the subspace to return. axes: Axes of the subspace to return. name: Human readable names Returns: A space containing only a subspace (and sublimits etc.) """ spaces = [space.get_subspace(obs=obs, axes=axes) for space in self.spaces] return self.copy(spaces=spaces, name=name) def copy(self, *, deep: bool = False, name: str | None = None, **overwrite_params) -> MultiSpace: assert not deep, "deep not explicitly implemented, should not be needed for immutable objects" kwargs = { "spaces": tuple(self), "obs": self.obs, "axes": self.axes, } kwargs.update(overwrite_params) kwargs["name"] = self.name if name is None else name return type(self)(**kwargs) def _raise_limits_not_implemented(self): msg = ( "Limits/lower/upper not implemented for MultiSpace. This error is either caught" " automatically as part of the codes logic or the MultiLimit case should" " be considered. To do that, simply iterate through the MultiSpace, which returns" " a simple space. Iterating through a Spaces also works" "for simple spaces." ) raise MultipleLimitsNotImplemented(msg) def _inside(self, x, guarantee_limits): inside_limits = [space.inside(x, guarantee_limits=guarantee_limits) for space in self] return znp.any(inside_limits, axis=0) # has to be inside one limit def __iter__(self) -> ZfitSpace: yield from self.spaces def __repr__(self): class_name = str(self.__class__).split(".")[-1].split("'")[0] if not self.limits_are_set: limits = None elif self.limits_are_false: limits = False elif self.has_rect_limits: if self.n_obs < 3 and not self.n_events > 1 and self._depr_n_limits <= 3: limits = [lim.v0.limits for lim in self] else: limits = "rectangular" else: limits = "functional" return f"<zfit {class_name} obs={self.obs}, axes={self.axes}, limits={limits}>" def __eq__(self, other): # in principle, two disjoint regions could have coinciding lower and upper limits equaling to an actually larger # space. However, we ignore this case and say that l1 to u1 and l2 to u2 is never the same as l1 to u2, # even if u1 == l2 (and they mathematically coincide). if not isinstance(other, MultiSpace): warnings.warn("Multispace limits compare never equal to Space.", stacklevel=2) return equal_space(self, other, allow_graph=False) def __le__(self, other): # in principle, two disjoint regions could have coinciding lower and upper limits equaling to an actually larger # space. However, we ignore this case and say that l1 to u1 and l2 to u2 is never the same as l1 to u2, # even if u1 == l2 (and they mathematically coincide). if not isinstance(other, MultiSpace): warnings.warn("Multispace limits compare never equal to Space.", stacklevel=2) return less_equal_space(self, other, allow_graph=False) def __hash__(self): return hash(self.spaces) def convert_to_space( obs: ztyping.ObsTypeInput | None = None, axes: ztyping.AxesTypeInput | None = None, limits: ztyping.LimitsTypeInput | None = None, *, overwrite_limits: bool = False, one_dim_limits_only: bool = True, simple_limits_only: bool = True, ) -> None | ZfitSpace | bool: """Convert *limits* to a :py:class:`~zfit.Space` object if not already None or False. Args: obs: limits: axes: overwrite_limits: If `obs` or `axes` is a :py:class:`~zfit.Space` _and_ `limits` are given, return an instance of :py:class:`~zfit.Space` with the new limits. If the flag is `False`, the `limits` argument will be ignored if one_dim_limits_only: simple_limits_only: Returns: Union[:py:class:`~zfit.Space`, False, None]: Raises: OverdefinedError: if `obs` or `axes` is a :py:class:`~zfit.Space` and `axes` respectively `obs` is not `None`. """ space = None # Test if already `Space` and handle if isinstance(obs, ZfitSpace): if axes is not None: msg = "if `obs` is a `Space`, `axes` cannot be defined." raise OverdefinedError(msg) space = obs elif isinstance(axes, ZfitSpace): if obs is not None: msg = "if `axes` is a `Space`, `obs` cannot be defined." raise OverdefinedError(msg) space = axes elif isinstance(limits, ZfitSpace): return limits if space is not None: # set the limits if given if limits is not None and (overwrite_limits or not space.limits_are_set): if isinstance(limits, ZfitSpace): # figure out if compatible if limits is `Space` if not ( limits.obs == space.obs or (limits.axes == space.axes and limits.obs is None and space.obs is None) ): msg = "`obs`/`axes` is a `Space` as well as the `limits`, but the " "obs/axes of them do not match" raise IntentionAmbiguousError(msg) limits = False if limits.limits_are_false else limits.limits space = space.with_limits(limits=limits) return space # space is None again if not (obs is None and axes is None): # check if limits are allowed space = Space(obs=obs, axes=axes, limits=limits) # create and test if valid if one_dim_limits_only and space.n_obs > 1 and space.has_limits: msg = "Limits more sophisticated than 1-dim cannot be auto-created from tuples. Use `Space` instead." raise LimitsUnderdefinedError(msg) if simple_limits_only and space.has_limits and space._depr_n_limits > 1: msg = "Limits with multiple limits cannot be auto-created" " from tuples. Use `Space` instead." raise LimitsUnderdefinedError(msg) return space def check_norm(supports=None): if supports is None: supports = False supports = convert_to_container(supports, convert_none=True) def no_norm_range(func): """Decorator: Catch the 'norm' kwargs. If not None, raise `NormNotImplemented`. """ parameters = inspect.signature(func).parameters keys = list(parameters.keys()) norm_range_index = None norm_index = None if "norm_range" in keys: norm_range_index = keys.index("norm_range") if "norm" in keys: norm_index = keys.index("norm") @functools.wraps(func) def new_func(*args, **kwargs): self = args[0] if len(args) > 0 else None norm_range = kwargs.get("norm_range") norm = kwargs.get("norm") norm_is_arg = False if norm_range is not None: norm = norm_range elif norm is not None: pass else: if norm_range_index is not None: norm_is_arg = len(args) > norm_range_index if norm_is_arg: norm = args[norm_range_index] if norm_index is not None: norm_is_arg = len(args) > norm_index if norm_is_arg: norm = args[norm_index] args = list(args) # if not norm_is_arg: # TODO: remove? why is this here? # if 'norm_range' in kwargs: # kwargs['norm_range'] = False # if 'norm' in kwargs: # kwargs['norm'] = False # assume it's not supported. Switch if we find that it is supported. norm_not_supported = supports[0] is not True if isinstance(norm, ZfitSpace): if "space" in supports and isinstance(self, ZfitPDF) and self.space == norm: norm_not_supported = False if "norm" in supports and isinstance(self, ZfitPDF) and self.norm == norm: norm_not_supported = False if norm_not_supported: norm_not_supported = not norm.limits_are_false if norm.limits_are_false: if not norm_is_arg: # TODO: remove? why is this here? if "norm_range" in kwargs: kwargs["norm_range"] = False if "norm" in kwargs: kwargs["norm"] = False elif norm_range_index is not None: args[norm_range_index] = False elif norm_index is not None: args[norm_index] = False elif norm_not_supported: norm_not_supported = not (norm is None or norm is False) if norm_not_supported: raise NormNotImplemented() try: return func(*args, **kwargs) except TypeError as error: if "got an unexpected keyword argument 'norm_range'" in str(error): kwargs.pop("norm_range") elif "got an unexpected keyword argument 'norm'" in str(error): kwargs.pop("norm") else: raise return func(*args, **kwargs) return new_func return no_norm_range def param_args_supported(func): """Decorator: Catch the 'params' kwargs if not supported.""" @functools.wraps(func) def new_func(*args, **kwargs): self = args[0] if len(args) > 0 else None if "params" not in kwargs and self is not None: params = {key: znp.asarray(val) for key, val in self.params.items()} kwargs["params"] = params try: return func(*args, **kwargs) except TypeError as error: if "got an unexpected keyword argument 'params'" in str(error): kwargs.pop("params") else: raise return func(*args, **kwargs) return new_func def no_multiple_limits(func): """Decorator: Catch the 'limits' kwargs. If it contains multiple limits, raise MultipleLimitsNotImplementedError. """ parameters = inspect.signature(func).parameters if "limits" in (keys := list(parameters.keys())): limits_index = keys.index("limits") else: return func # no limits as parameters -> no problem @functools.wraps(func) def new_func(*args, **kwargs): limits_is_arg = len(args) > limits_index limits = args[limits_index] if limits_is_arg else kwargs["limits"] if limits._depr_n_limits > 1: raise MultipleLimitsNotImplemented return func(*args, **kwargs) return new_func @deprecated_norm_range def supports( *, norm: bool | str | Iterable[str] | None = None, multiple_limits: bool | None = None, ) -> Callable: """Decorator: Add (mandatory for some methods) on a method to control what it can handle. If any of the flags is set to False, it will check the arguments and, in case they match a flag (say if a *norm* is passed while the *norm* flag is set to `False`), it will raise a corresponding exception (in this example a `NormRangeNotImplementedError`) that will be catched by an earlier function that knows how to handle things. Args: norm: If False, no norm_range argument will be passed through resp. will be `None`. Other options include `'space'` or `'norm'`, which will check if the norm is equal to the space or norm of the PDF. If they are, it is assumed to be supported. multiple_limits: If False, only simple limits are to be expected and no iteration is therefore required. """ if norm is None: norm = False if multiple_limits is None: multiple_limits = False decorator_stack = [] if not multiple_limits: decorator_stack.append(no_multiple_limits) if norm is not None: # check True. Could also be a str decorator_stack.append(check_norm(norm)) decorator_stack.append(param_args_supported) def create_deco_stack(func): for decorator in reversed(decorator_stack): func = decorator(func) func.__wrapped__ = supports return func return create_deco_stack def contains_tensor(objects): tensor_found = tf.is_tensor(objects) with suppress(TypeError): for obj in objects: if tensor_found: break tensor_found += contains_tensor(obj) return tensor_found def shape_np_tf(objects): return tuple(tf.convert_to_tensor(objects).shape.as_list()) if contains_tensor(objects) else np.shape(objects) def limits_consistent(spaces: Iterable[zfit.Space]): """Check if space limits are the *exact* same in each obs they are defined and therefore are compatible. In this case, if a space has several limits, e.g. from -1 to 1 and from 2 to 3 (all in the same observable), to be consistent with this limits, other limits have to have (in this obs) also the limits from -1 to 1 and from 2 to 3. Only having the limit -1 to 1 _or_ 2 to 3 is considered _not_ consistent. This function is useful to check if several spaces with *different* observables can be _combined_. Args: spaces: Returns: """ try: _ = combine_spaces(*spaces) except LimitsIncompatibleError: return False else: return True def add_spaces_old(spaces: Iterable[zfit.Space]): """Add two spaces and merge their limits if possible or return False. Args: spaces: Returns: Union[None, :py:class:`~zfit.Space`, bool]: Raises: LimitsIncompatibleError: if limits of the `spaces` cannot be merged because they overlap """ spaces = convert_to_container(spaces) if not all(isinstance(space, ZfitSpace) for space in spaces): msg = "Cannot only add type ZfitSpace" raise TypeError(msg) if len(spaces) <= 1: msg = "Need at least two spaces to be added." raise ValueError(msg) # TODO: allow? usecase? obs = frozenset(frozenset(space.obs) for space in spaces) if len(obs) != 1: return False obs1 = spaces[0].obs spaces = [space.with_obs(obs=obs1) if space.obs != obs1 else space for space in spaces] if limits_overlap(spaces=spaces, allow_exact_match=True): msg = "Limits of spaces overlap, cannot merge spaces." raise LimitsIncompatibleError(msg) lowers = [] uppers = [] for space in spaces: if not space.limits_are_set: continue for lower, upper in space: for other_lower, other_upper in zip(lowers, uppers): lower_same = np.allclose(lower, other_lower) upper_same = np.allclose(upper, other_upper) assert not lower_same ^ upper_same, "Bug, please report as issue. limits_overlap did not catch right." if lower_same and upper_same: break else: lowers.append(lower) uppers.append(upper) lowers = tuple(lowers) uppers = tuple(uppers) limits = None if len(lowers) == 0 else (lowers, uppers) return zfit.Space(obs=spaces[0].obs, limits=limits)