# Copyright (c) 2024 zfit
from __future__ import annotations
import abc
import collections
from collections.abc import Callable, Iterable, Mapping
from typing import Literal, Optional
import numpy as np
import pydantic
import tensorflow as tf
import tensorflow_probability as tfp
from ordered_set import OrderedSet
import zfit.z.numpy as znp
from .. import z
from ..serialization.serializer import BaseRepr, Serializer
from ..settings import ztypes
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.deprecation import deprecated_args
from ..util.exception import ShapeIncompatibleError
from .baseobject import BaseNumeric
from .dependents import _extract_dependencies
from .interfaces import ZfitConstraint, ZfitParameter
from .serialmixin import SerializableMixin
tfd = tfp.distributions
class BaseConstraintRepr(BaseRepr):
_implementation = None
_owndict = pydantic.PrivateAttr(default_factory=dict)
hs3_type: Literal["BaseConstraint"] = pydantic.Field("BaseConstraint", alias="type")
class BaseConstraint(ZfitConstraint, BaseNumeric):
def __init__(
self,
params: dict[str, ZfitParameter] | None = None,
name: str = "BaseConstraint",
dtype=ztypes.float,
**kwargs,
):
"""Base class for constraints.

        Args:
            dtype: the dtype of the constraint
            name: the name of the constraint
            params: A dictionary with the internal name of the
                parameter and the parameter itself that the constraint depends on
"""
super().__init__(name=name, dtype=dtype, params=params, **kwargs)
    def value(self):
        """Return the value of the constraint."""
        return self._value()
@abc.abstractmethod
def _value(self):
raise NotImplementedError
def _get_dependencies(self) -> ztyping.DependentsType:
return _extract_dependencies(self.get_params(floating=None))
# TODO: improve arbitrary constraints, should we allow only functions that have a `params` argument?
class SimpleConstraint(BaseConstraint):
def __init__(
self,
func: Callable,
params: (Mapping[str, ztyping.ParameterType] | Iterable[ztyping.ParameterType] | ztyping.ParameterType | None),
*,
name: str | None = None,
):
"""Constraint from a (function returning a) Tensor.

        Args:
            func: Callable that constructs the constraint and returns a tensor. For the expected
                signature, see ``params`` below.
            params: The parameters of the constraint. If given as a list, the parameters are named
                ``param_{i}`` and the function takes no arguments. If given as a dict, the function
                receives the parameter mapping as its first argument (``params``).
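
        Example:
            A minimal sketch of a custom penalty term; the parameter value, function name and
            width are illustrative::

                import zfit

                mu = zfit.Parameter("mu", 1.2)

                def penalty(params):
                    # chi2-like penalty pulling ``mu`` towards 1.0 with a width of 0.1
                    return (params["mu"] - 1.0) ** 2 / (2 * 0.1**2)

                constraint = zfit.constraint.SimpleConstraint(penalty, params={"mu": mu})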
"""
if name is None:
name = "SimpleConstraint"
self._simple_func = func
self._func_params = None
if isinstance(params, collections.abc.Mapping):
self._func_params = params
params = list(params.values())
self._simple_func_dependents = convert_to_container(params, container=OrderedSet)
params = convert_to_container(params, container=list)
params = {f"param_{i}": p for i, p in enumerate(params)} if self._func_params is None else self._func_params
super().__init__(name=name, params=params)
def _value(self):
if self._func_params is None:
return self._simple_func()
else:
return self._simple_func(self._func_params)
class ProbabilityConstraint(BaseConstraint):
def __init__(
self,
observation: ztyping.NumericalScalarType | ZfitParameter,
params: dict[str, ZfitParameter] | None = None,
name: str = "ProbabilityConstraint",
dtype=ztypes.float,
**kwargs,
):
"""Base class for constraints using a probability density function.

        Args:
            dtype: the dtype of the constraint
            name: the name of the constraint
            params: The parameters to constrain
            observation: Observed values of the parameters
                to constrain, obtained from auxiliary measurements.
"""
# TODO: proper handling of input params, arrays. ArrayParam?
if isinstance(params, collections.abc.Mapping):
params_dict = params
params = [p for name, p in params.items() if name.startswith("param_")]
else:
params = convert_to_container(params, ignore=np.ndarray, container=tuple)
params_dict = {f"param_{i}": p for i, p in enumerate(params)}
super().__init__(name=name, dtype=dtype, params=params_dict, **kwargs)
observation = convert_to_container(observation, tuple, ignore=np.ndarray)
if len(observation) != len(params):
msg = (
"observation and params have to be the same length. Currently"
f"observation: {len(observation)}, params: {len(params)}"
)
raise ShapeIncompatibleError(msg)
self._observation = observation # TODO: needed below? Why?
# for obs, p in zip(observation, params):
# obs = convert_to_parameter(obs, f"{p.name}_obs", prefer_constant=False)
# obs.floating = False
# self._observation.append(obs)
self._ordered_params = params
@property
def observation(self):
"""Return the observed values of the parameters constrained."""
return self._observation
def value(self):
return self._value()
@abc.abstractmethod
def _value(self):
raise NotImplementedError
def _get_dependencies(self) -> ztyping.DependentsType:
return _extract_dependencies(self.get_params())
def sample(self, n):
"""Sample ``n`` points from the probability density function for the observed value of the parameters.

        Args:
            n: The number of samples to be generated.

        Returns:
            A mapping of each constrained parameter to its sampled values.
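
        Example:
            A minimal sketch, assuming ``constraint`` is a ``GaussianConstraint`` built from two
            parameters::

                samples = constraint.sample(1000)
                for param, values in samples.items():
                    print(param.name, values.shape)  # one tensor of 1000 draws per parameter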
"""
sample = self._sample(n=n)
return {p: sample[:, i] for i, p in enumerate(self._ordered_params)}
@abc.abstractmethod
def _sample(self, n):
raise NotImplementedError
@property
def _params_array(self):
return znp.asarray(self._ordered_params)
class TFProbabilityConstraint(ProbabilityConstraint):
def __init__(
self,
observation: ztyping.NumericalScalarType | ZfitParameter,
params: dict[str, ZfitParameter],
distribution: tfd.Distribution,
dist_params,
dist_kwargs=None,
name: str = "DistributionConstraint",
dtype=ztypes.float,
**kwargs,
):
"""Base class for constraints using a probability density function from ``tensorflow_probability``.

        Args:
            distribution: The probability distribution class
                used to constrain the parameters.
            dist_params: The parameters for ``distribution``; either a mapping of keyword
                arguments or a callable taking the observation and returning such a mapping.
            dist_kwargs: Additional keyword arguments for ``distribution``; either a mapping
                or a callable returning one.
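
        Example:
            A sketch of how a concrete constraint wires in a ``tfp`` distribution, mirroring
            ``PoissonConstraint`` below; ``params`` and ``observation`` are placeholders defined
            as described above::

                import tensorflow_probability as tfp

                constraint = TFProbabilityConstraint(
                    observation=observation,
                    params=params,
                    distribution=tfp.distributions.Poisson,
                    dist_params={"rate": observation},
                )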
"""
super().__init__(observation=observation, params=params, name=name, dtype=dtype, **kwargs)
self._distribution = distribution
self.dist_params = dist_params
self.dist_kwargs = dist_kwargs if dist_kwargs is not None else {}
@property
def distribution(self):
params = self.dist_params
if callable(params):
params = params(self.observation)
kwargs = self.dist_kwargs
if callable(kwargs):
kwargs = kwargs()
params = {k: znp.asarray(v, ztypes.float) for k, v in params.items()}
return self._distribution(**params, **kwargs, name=f"{self.name}_tfp")
def _value(self):
array = znp.asarray(self._params_array, ztypes.float)
value = -self.distribution.log_prob(array)
return tf.reduce_sum(value)
def _sample(self, n):
return self.distribution.sample(n)
def _preprocess_gaussian_constr_sigma_var(cov, sigma, legacy_uncertainty):
    """Convert ``sigma`` or ``cov`` into a consistent ``(sigma, cov)`` pair, enforcing that only one is given."""
if sigma is not None:
if legacy_uncertainty:
msg = "Either `sigma` or `uncertainty` can be given, not both. Use `sigma`. `uncertainty` is deprecated."
raise ValueError(msg)
if cov is not None:
msg = "Either `sigma` or `cov` can be given, not both."
raise ValueError(msg)
if any(isinstance(s, ZfitParameter) for s in convert_to_container(sigma)):
msg = "sigma has to be a scalar or a 1D tensor, not a ZfitParameter (if this feature is needed, please open an issue on github with zfit."
raise ValueError(msg)
sigma = znp.asarray(sigma, ztypes.float)
sigma = znp.atleast_1d(sigma)
if (ndims := sigma.shape.ndims) == 2:
msg = f"sigma has to be a scalar or a 1D tensor, not a {ndims}D tensor. Use `cov` instead."
raise ValueError(msg)
if ndims < 2:
cov = znp.diag(znp.square(sigma))
else:
msg = f"sigma has to be a scalar, a 1D tensor or a 2D tensor, not {ndims}D."
raise ValueError(msg)
elif cov is not None:
if any(isinstance(c, ZfitParameter) for c in convert_to_container(cov)):
msg = "cov has to be a scalar, a 1D tensor or a 2D tensor, not a ZfitParameter (if this feature is needed, please open an issue on github with zfit."
raise ValueError(msg)
if legacy_uncertainty:
msg = "Either `cov` or `uncertainty` can be given, not both. Use `cov`. `uncertainty` is deprecated."
raise ValueError(msg)
cov = znp.atleast_1d(znp.asarray(cov, ztypes.float))
if cov.shape.ndims == 1:
cov = znp.diag(cov)
sigma = znp.sqrt(znp.diag(cov))
else: # legacy 3
sigma = -999
cov = -999
# end legacy 3
return sigma, cov
class GaussianConstraint(TFProbabilityConstraint, SerializableMixin):
@deprecated_args(None, "Use `sigma` or `cov` instead.", "uncertainty")
def __init__(
self,
params: ztyping.ParamTypeInput,
observation: ztyping.NumericalScalarType,
*,
uncertainty: ztyping.NumericalScalarType = None,
sigma: ztyping.NumericalScalarType = None,
cov: ztyping.NumericalScalarType = None,
):
r"""Gaussian constraints on a list of parameters to some observed values with uncertainties.
        A Gaussian constraint is defined as the likelihood of ``params`` given the ``observation``
        and ``sigma`` or ``cov`` from a different measurement.

        .. math::
            \text{constraint} = \text{Gauss}(\text{observation}; \text{params}, \text{uncertainty})

        Args:
            params: The parameters to constrain; correspond to x in the Gaussian
                distribution.
            observation: Observed values of the parameters; correspond to mu
                in the Gaussian distribution.
            sigma: Typically the uncertainties of the observed values. Can be a single value,
                a list of values, an array or a tensor; must be broadcastable to the shape of the
                parameters. ``sigma`` is the square root of the diagonal of the covariance matrix.
                Either ``sigma`` or ``cov`` can be given, not both.
            cov: The covariance matrix of the observed values. Can be a single value, a list of
                values, an array or a tensor that is either 1- or 2-dimensional. If 1D, it is
                interpreted as the diagonal of the covariance matrix. If 2D, it is a matrix of shape
                ``(n, n)``, where ``n`` is the number of parameters, with ``sigma`` squared on the
                diagonal. Either ``sigma`` or ``cov`` can be given, not both.

Raises:
ShapeIncompatibleError: If params, mu and sigma have incompatible shapes.
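
        Example:
            A minimal sketch; names and values are illustrative::

                import zfit

                mu = zfit.Parameter("mu", 1.2)
                sigma_mu = zfit.Parameter("sigma_mu", 0.45)

                # constrain both parameters to auxiliary measurements
                constraint = zfit.constraint.GaussianConstraint(
                    params=[mu, sigma_mu], observation=[1.0, 0.5], sigma=[0.1, 0.05]
                )
                nll_term = constraint.value()  # negative log of the Gaussian likelihood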
"""
observation = convert_to_container(observation, tuple, ignore=np.ndarray)
params = convert_to_container(params, tuple, ignore=np.ndarray)
params_tuple_legacy = params
# legacy start 1
if legacy_uncertainty := uncertainty is not None:
uncertainty = convert_to_container(uncertainty, tuple, ignore=np.ndarray)
if isinstance(uncertainty[0], (np.ndarray, tf.Tensor)) and len(uncertainty) == 1:
uncertainty = tuple(uncertainty[0])
def create_covariance_legacy(mu, sigma):
mu = z.convert_to_tensor(mu)
sigma = znp.asarray(
sigma
) # otherwise TF complains that the shape got changed from [2] to [2, 2] (if we have a tuple of two arrays)
sigma = z.convert_to_tensor(sigma)
params_tensor = z.convert_to_tensor(params_tuple_legacy)
if sigma.shape.ndims > 1:
covariance = sigma
elif sigma.shape.ndims == 1:
covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))
else:
sigma = znp.reshape(sigma, [1])
covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))
if not params_tensor.shape[0] == mu.shape[0] == covariance.shape[0] == covariance.shape[1]:
msg = (
f"params_tensor, observation and uncertainty have to have the"
" same length. Currently"
f"param: {params_tensor.shape[0]}, mu: {mu.shape[0]}, "
f"covariance (from uncertainty): {covariance.shape[0:2]}"
)
raise ShapeIncompatibleError(msg)
return covariance
# legacy end 1
original_init = {
"observation": observation,
"params": params,
"uncertainty": uncertainty,
"sigma": sigma,
"cov": cov,
}
sigma, cov = _preprocess_gaussian_constr_sigma_var(cov, sigma, legacy_uncertainty)
self.__cov = cov
self.__sigma = sigma
distribution = tfd.MultivariateNormalTriL
def dist_params(observation, *, self=self):
return {"loc": observation, "scale_tril": tf.linalg.cholesky(self.covariance)}
dist_kwargs = {"validate_args": True}
params = {f"param_{i}": p for i, p in enumerate(params)}
super().__init__(
name="GaussianConstraint",
observation=observation,
params=params,
distribution=distribution,
dist_params=dist_params,
dist_kwargs=dist_kwargs,
)
self.hs3.original_init.update(original_init)
if legacy_uncertainty:
self._covariance = lambda: create_covariance_legacy(self.observation, uncertainty)
else:
self._covariance = lambda cov: znp.asarray(cov, ztypes.float)
self._legacy_uncertainty = legacy_uncertainty
@property
def covariance(self):
"""Return the covariance matrix of the observed values of the parameters constrained."""
# legacy start 2
if self._legacy_uncertainty:
return self._covariance()
# legacy end 2
return self._covariance(cov=self.__cov)
class GaussianConstraintRepr(BaseConstraintRepr):
_implementation = GaussianConstraint
hs3_type: Literal["GaussianConstraint"] = pydantic.Field("GaussianConstraint", alias="type")
params: list[Serializer.types.ParamInputTypeDiscriminated]
observation: list[Serializer.types.ParamInputTypeDiscriminated]
uncertainty: Optional[list[Serializer.types.ParamInputTypeDiscriminated]]
sigma: Optional[list[Serializer.types.ParamInputTypeDiscriminated]]
cov: Optional[list[Serializer.types.ParamInputTypeDiscriminated]]
@pydantic.root_validator(pre=True)
def get_init_args(cls, values):
if cls.orm_mode(values):
values = values["hs3"].original_init
return values
@pydantic.validator("params", "observation", "uncertainty", "sigma", "cov")
def validate_params(cls, v):
return v.tolist() if isinstance(v, np.ndarray) else convert_to_container(v, list)
class PoissonConstraint(TFProbabilityConstraint, SerializableMixin):
def __init__(self, params: ztyping.ParamTypeInput, observation: ztyping.NumericalScalarType):
r"""Poisson constraints on a list of parameters to some observed values.
        Constrains parameters that can be counts (i.e. from a histogram) or, more generally, are
        Poisson distributed. This is often used for histogram templates which are obtained
        from simulation and have a Poisson uncertainty due to limited statistics.

        .. math::
            \text{constraint} = \text{Poisson}(\text{observation}; \text{params})

        Args:
            params: The parameters to constrain; correspond to the mu in the Poisson
                distribution.
            observation: Observed values of the parameters; correspond to the rate lambda
                in the Poisson distribution.

Raises:
ShapeIncompatibleError: If params and observation have incompatible shapes.
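
        Example:
            A minimal sketch; names and values are illustrative::

                import zfit

                n_bkg = zfit.Parameter("n_bkg", 102.0)
                constraint = zfit.constraint.PoissonConstraint(n_bkg, observation=100.0)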
"""
observation = convert_to_container(observation, tuple)
params = convert_to_container(params, tuple)
original_init = {"observation": observation, "params": params}
distribution = tfd.Poisson
dist_params = {"rate": observation}
dist_kwargs = {"validate_args": False}
super().__init__(
name="PoissonConstraint",
observation=observation,
params=params,
distribution=distribution,
dist_params=dist_params,
dist_kwargs=dist_kwargs,
)
self.hs3.original_init.update(original_init)
class PoissonConstraintRepr(BaseConstraintRepr):
_implementation = PoissonConstraint
hs3_type: Literal["PoissonConstraint"] = pydantic.Field("PoissonConstraint", alias="type")
params: list[Serializer.types.ParamInputTypeDiscriminated]
observation: list[Serializer.types.ParamInputTypeDiscriminated]
@pydantic.root_validator(pre=True)
def get_init_args(cls, values):
if cls.orm_mode(values):
values = values["hs3"].original_init
return values
@pydantic.validator("params", "observation")
def validate_params(cls, v):
return v.tolist() if isinstance(v, np.ndarray) else convert_to_container(v, list)
class LogNormalConstraint(TFProbabilityConstraint, SerializableMixin):
def __init__(
self,
params: ztyping.ParamTypeInput,
observation: ztyping.NumericalScalarType,
uncertainty: ztyping.NumericalScalarType,
):
r"""Log-normal constraints on a list of parameters to some observed values.
        Constrains parameters that can be counts (i.e. from a histogram) or, more generally, are
        log-normally distributed. This is often used for histogram templates which are obtained
        from simulation and have a log-normal uncertainty due to a multiplicative uncertainty.

        .. math::
            \text{constraint} = \text{LogNormal}(\text{observation}; \text{params}, \text{uncertainty})

        Args:
            params: The parameters to constrain; the values at which the log-normal
                distribution is evaluated.
            observation: Observed values of the parameters; correspond to the location mu
                of the log-normal distribution.
            uncertainty: Uncertainty of the observed values; corresponds to the scale sigma
                of the log-normal distribution.

Raises:
ShapeIncompatibleError: If params, mu and sigma have incompatible shapes.
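
        Example:
            A minimal sketch, assuming ``LogNormalConstraint`` is exposed under
            ``zfit.constraint``; names and values are illustrative::

                import zfit

                eff = zfit.Parameter("eff", 0.92)
                constraint = zfit.constraint.LogNormalConstraint(
                    params=eff, observation=0.9, uncertainty=0.05
                )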
"""
observation = convert_to_container(observation, tuple)
params = convert_to_container(params, tuple)
uncertainty = convert_to_container(uncertainty, tuple)
original_init = {
"observation": observation,
"params": params,
"uncertainty": uncertainty,
}
distribution = tfd.LogNormal
def dist_params(observation):
return {"loc": observation, "scale": uncertainty}
dist_kwargs = {"validate_args": False}
super().__init__(
name="LogNormalConstraint",
observation=observation,
params=params,
distribution=distribution,
dist_params=dist_params,
dist_kwargs=dist_kwargs,
)
self.hs3.original_init.update(original_init)
class LogNormalConstraintRepr(BaseConstraintRepr):
_implementation = LogNormalConstraint
hs3_type: Literal["LogNormalConstraint"] = pydantic.Field("LogNormalConstraint", alias="type")
params: list[Serializer.types.ParamInputTypeDiscriminated]
observation: list[Serializer.types.ParamInputTypeDiscriminated]
uncertainty: list[Serializer.types.ParamInputTypeDiscriminated]
@pydantic.root_validator(pre=True)
def get_init_args(cls, values):
if cls.orm_mode(values):
values = values["hs3"].original_init
return values
@pydantic.validator("params", "observation", "uncertainty")
def validate_params(cls, v):
return v.tolist() if isinstance(v, np.ndarray) else convert_to_container(v, list)