# Copyright (c) 2024 zfit
from __future__ import annotations
import abc
import collections
from collections.abc import Callable
from typing import Iterable, Literal, Mapping, Optional
import numpy as np
import pydantic.v1 as pydantic
import tensorflow as tf
import tensorflow_probability as tfp
from ordered_set import OrderedSet
import zfit.z.numpy as znp
from .. import z
from ..serialization.serializer import BaseRepr, Serializer
from ..settings import ztypes
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.deprecation import deprecated_args
from ..util.exception import ShapeIncompatibleError
from .baseobject import BaseNumeric
from .interfaces import ZfitConstraint, ZfitParameter
from .serialmixin import SerializableMixin
tfd = tfp.distributions
class BaseConstraintRepr(BaseRepr):
    # Serialization representation shared by the constraints of this module. Subclasses set
    # ``_implementation`` to the constraint class they represent and override ``hs3_type``.
    _implementation = None
    # pydantic private attribute (not a model field); a fresh dict per instance via default_factory
    _owndict = pydantic.PrivateAttr(default_factory=dict)
    # type tag of the representation, exposed under the alias ``type``
    hs3_type: Literal["BaseConstraint"] = pydantic.Field("BaseConstraint", alias="type")
class BaseConstraint(ZfitConstraint, BaseNumeric):
    def __init__(
        self,
        params: dict[str, ZfitParameter] | None = None,
        name: str = "BaseConstraint",
        dtype=ztypes.float,
        **kwargs,
    ):
        """Base class for constraints.

        Args:
            params: Mapping from the internal name of each parameter to the parameter
                itself that the constraint depends on.
            name: The name of the constraint.
            dtype: The dtype of the constraint.
            **kwargs: Forwarded unchanged to the parent initializer.
        """
        super().__init__(params=params, name=name, dtype=dtype, **kwargs)

    def value(self):
        """Return the value of the constraint as computed by ``_value``."""
        constraint_value = self._value()
        return constraint_value

    @abc.abstractmethod
    def _value(self):
        """Compute the constraint value; has to be implemented by subclasses."""
        raise NotImplementedError
# TODO: improve arbitrary constraints, should we allow only functions that have a `params` argument?
[docs]
class SimpleConstraint(BaseConstraint):
    def __init__(
        self,
        func: Callable,
        params: (Mapping[str, ztyping.ParameterType] | Iterable[ztyping.ParameterType] | ztyping.ParameterType | None),
        *,
        name: str | None = None,
    ):
        """Constraint from a (function returning a) Tensor.

        Args:
            func: Callable that constructs the constraint and returns a tensor. For the expected signature,
                see below in ``params``.
            params: The parameters of the loss. If given as a list (or a single parameter), the parameters
                are named "param_{i}" and the function does not take any arguments. If given as a dict,
                the function expects the parameter as the first argument (``params``).
                ``None`` means the constraint depends on no parameters; ``func`` then takes no arguments.
            name: The name of the constraint. Defaults to "SimpleConstraint".
        """
        if name is None:
            name = "SimpleConstraint"
        self._simple_func = func
        self._func_params = None
        if params is None:
            # ``convert_to_container`` passes ``None`` through, which would break ``enumerate`` below
            params = []
        elif isinstance(params, collections.abc.Mapping):
            # keep the mapping so that it can be handed to ``func`` as-is
            self._func_params = params
            params = list(params.values())
        self._simple_func_dependents = convert_to_container(params, container=OrderedSet)

        params = convert_to_container(params, container=list)
        if self._func_params is None:
            params = {f"param_{i}": p for i, p in enumerate(params)}
        else:
            params = self._func_params
        super().__init__(name=name, params=params)

    def _value(self):
        """Evaluate ``func``: without arguments, or with the parameter mapping if one was given."""
        if self._func_params is None:
            return self._simple_func()
        return self._simple_func(self._func_params)
class ProbabilityConstraint(BaseConstraint):
    def __init__(
        self,
        observation: ztyping.NumericalScalarType | ZfitParameter,
        params: dict[str, ZfitParameter] | None = None,
        name: str = "ProbabilityConstraint",
        dtype=ztypes.float,
        **kwargs,
    ):
        """Base class for constraints using a probability density function.

        Args:
            observation: Observed values of the parameters to constrain, obtained from auxiliary
                measurements. Has to have the same length as the constrained parameters.
            params: The parameters to constrain. Either a mapping from internal names to parameters
                (only entries whose name starts with ``"param_"`` are constrained, in mapping order)
                or a (sequence of) parameter(s), which are then named ``"param_{i}"``.
                ``None`` is treated as no parameters.
            name: The name of the constraint.
            dtype: The dtype of the constraint.
            **kwargs: Forwarded to :py:class:`BaseConstraint`.

        Raises:
            ShapeIncompatibleError: If ``observation`` and the constrained parameters differ in length.
        """
        # TODO: proper handling of input params, arrays. ArrayParam?
        if isinstance(params, collections.abc.Mapping):
            params_dict = params
            params = [param for key, param in params.items() if key.startswith("param_")]
        else:
            if params is None:
                params = ()
            else:
                params = convert_to_container(params, ignore=np.ndarray, container=tuple)
            params_dict = {f"param_{i}": p for i, p in enumerate(params)}
        super().__init__(name=name, dtype=dtype, params=params_dict, **kwargs)

        observation = convert_to_container(observation, tuple, ignore=np.ndarray)
        if len(observation) != len(params):
            msg = (
                "observation and params have to be the same length. Currently "
                f"observation: {len(observation)}, params: {len(params)}"
            )
            raise ShapeIncompatibleError(msg)
        # exposed through the ``observation`` property
        self._observation = observation
        # constrained parameters in the order matching ``observation``
        self._ordered_params = params

    @property
    def observation(self):
        """Return the observed values of the parameters constrained."""
        return self._observation

    def value(self):
        return self._value()

    @abc.abstractmethod
    def _value(self):
        raise NotImplementedError

    def sample(self, n):
        """Sample ``n`` points from the probability density function for the observed value of the parameters.

        Args:
            n: The number of samples to be generated.

        Returns:
            A dict mapping each constrained parameter to its ``n`` sampled values; the ``i``-th
            parameter gets column ``i`` of the drawn samples.
        """
        samples = self._sample(n=n)
        return {param: samples[:, i] for i, param in enumerate(self._ordered_params)}

    @abc.abstractmethod
    def _sample(self, n):
        raise NotImplementedError

    @property
    def _params_array(self):
        """The constrained parameters, in order, converted to a single array."""
        return znp.asarray(self._ordered_params)
class TFProbabilityConstraint(ProbabilityConstraint):
    def __init__(
        self,
        observation: ztyping.NumericalScalarType | ZfitParameter,
        params: dict[str, ZfitParameter],
        distribution: tfd.Distribution,
        dist_params,
        dist_kwargs=None,
        name: str = "DistributionConstraint",
        dtype=ztypes.float,
        **kwargs,
    ):
        """Base class for constraints built on a ``tensorflow_probability`` distribution.

        Args:
            observation: Observed values of the parameters; forwarded to :py:class:`ProbabilityConstraint`.
            params: The parameters to constrain; forwarded to :py:class:`ProbabilityConstraint`.
            distribution: The distribution class (or other callable) that is called with the resolved
                ``dist_params`` and ``dist_kwargs`` to build the distribution.
            dist_params: Mapping of distribution parameters, or a callable that receives ``observation``
                and returns such a mapping. Every value is converted to a float tensor.
            dist_kwargs: Extra keyword arguments for ``distribution``, or a callable without arguments
                returning them. ``None`` means no extra arguments.
            name: The name of the constraint.
            dtype: The dtype of the constraint.
            **kwargs: Forwarded to :py:class:`ProbabilityConstraint`.
        """
        super().__init__(observation=observation, params=params, name=name, dtype=dtype, **kwargs)
        self._distribution = distribution
        self.dist_params = dist_params
        self.dist_kwargs = {} if dist_kwargs is None else dist_kwargs

    @property
    def distribution(self):
        """Build the distribution from the (resolved) ``dist_params`` and ``dist_kwargs``."""
        param_spec = self.dist_params
        resolved_params = param_spec(self.observation) if callable(param_spec) else param_spec
        kwarg_spec = self.dist_kwargs
        resolved_kwargs = kwarg_spec() if callable(kwarg_spec) else kwarg_spec
        tensor_params = {key: znp.asarray(val, ztypes.float) for key, val in resolved_params.items()}
        return self._distribution(**tensor_params, **resolved_kwargs, name=f"{self.name}_tfp")

    def _value(self):
        """Sum of the negative log-probabilities of the current parameter values."""
        param_values = znp.asarray(self._params_array, ztypes.float)
        neg_log_probs = -self.distribution.log_prob(param_values)
        return tf.reduce_sum(neg_log_probs)

    def _sample(self, n):
        """Draw ``n`` samples from the distribution."""
        dist = self.distribution
        return dist.sample(n)
def _preprocess_gaussian_constr_sigma_var(cov, sigma, legacy_uncertainty):
    """Normalize the ``sigma`` / ``cov`` arguments of a Gaussian constraint.

    Exactly one of ``sigma``, ``cov`` or the deprecated ``uncertainty`` (signalled by
    ``legacy_uncertainty``) has to be given.

    Args:
        cov: Covariance matrix (2D) or its diagonal (scalar or 1D), or ``None``.
        sigma: Standard deviations (scalar or 1D), or ``None``.
        legacy_uncertainty: Whether the deprecated ``uncertainty`` argument was given.

    Returns:
        Tuple ``(sigma, cov)`` with ``sigma`` as a 1D tensor and ``cov`` as a 2D tensor. In the legacy
        ``uncertainty`` case both are the placeholder ``-999``; the covariance is then built by the caller.

    Raises:
        ValueError: If more than one or none of ``sigma``, ``cov`` and ``uncertainty`` is given, if a
            ``ZfitParameter`` is passed, or if the number of dimensions is not supported.
    """
    if sigma is not None:
        if legacy_uncertainty:
            msg = "Either `sigma` or `uncertainty` can be given, not both. Use `sigma`. `uncertainty` is deprecated."
            raise ValueError(msg)
        if cov is not None:
            msg = "Either `sigma` or `cov` can be given, not both."
            raise ValueError(msg)
        if any(isinstance(s, ZfitParameter) for s in convert_to_container(sigma)):
            msg = "sigma has to be a scalar or a 1D tensor, not a ZfitParameter (if this feature is needed, please open an issue on github with zfit)."
            raise ValueError(msg)
        sigma = znp.atleast_1d(znp.asarray(sigma, ztypes.float))
        ndims = sigma.shape.ndims
        if ndims == 2:
            msg = f"sigma has to be a scalar or a 1D tensor, not a {ndims}D tensor. Use `cov` instead."
            raise ValueError(msg)
        if ndims > 2:
            msg = f"sigma has to be a scalar or a 1D tensor, not {ndims}D."
            raise ValueError(msg)
        # uncorrelated: the variances sit on the diagonal
        cov = znp.diag(znp.square(sigma))
    elif cov is not None:
        if any(isinstance(c, ZfitParameter) for c in convert_to_container(cov)):
            msg = "cov has to be a scalar, a 1D tensor or a 2D tensor, not a ZfitParameter (if this feature is needed, please open an issue on github with zfit)."
            raise ValueError(msg)
        if legacy_uncertainty:
            msg = "Either `cov` or `uncertainty` can be given, not both. Use `cov`. `uncertainty` is deprecated."
            raise ValueError(msg)
        cov = znp.atleast_1d(znp.asarray(cov, ztypes.float))
        ndims = cov.shape.ndims
        if ndims == 1:
            # a 1D ``cov`` is interpreted as the diagonal of the covariance matrix
            cov = znp.diag(cov)
        elif ndims > 2:
            msg = f"cov has to be a scalar, a 1D tensor or a 2D tensor, not {ndims}D."
            raise ValueError(msg)
        sigma = znp.sqrt(znp.diag(cov))
    elif legacy_uncertainty:  # legacy 3
        # placeholders only: the covariance is built from ``uncertainty`` by the caller
        sigma = -999
        cov = -999
        # end legacy 3
    else:
        msg = "Either `sigma` or `cov` has to be given."
        raise ValueError(msg)
    return sigma, cov
[docs]
class GaussianConstraint(TFProbabilityConstraint, SerializableMixin):
    @deprecated_args(None, "Use `sigma` or `cov` instead.", "uncertainty")
    def __init__(
        self,
        params: ztyping.ParamTypeInput,
        observation: ztyping.NumericalScalarType,
        *,
        uncertainty: ztyping.NumericalScalarType = None,
        sigma: ztyping.NumericalScalarType = None,
        cov: ztyping.NumericalScalarType = None,
    ):
        r"""Gaussian constraints on a list of parameters to some observed values with uncertainties.

        A Gaussian constraint is defined as the likelihood of ``params`` given the ``observations`` and ``sigma`` or ``cov``
        from a different measurement. It is implemented as a multivariate normal distribution whose
        covariance is built from ``sigma`` or ``cov``.

        .. math::
            \text{constraint} = \text{Gauss}(\text{observation}; \text{params}, \text{cov})

        Args:
            params: The parameters to constraint; corresponds to x in the Gaussian
                distribution.
            observation: observed values of the parameter; corresponds to mu
                in the Gaussian distribution.
            uncertainty: Deprecated, use ``sigma`` or ``cov`` instead. If given, a 1D (or scalar) value is
                interpreted as standard deviations and a value with more dimensions as the covariance matrix.
            sigma: Typically the uncertainties of the observed values. Can either be a single value,
                a list of values, an array or a tensor (scalar or 1D).
                Either `sigma` or `cov` can be given, not both.
                ``sigma`` is the square root of the diagonal of the covariance matrix.
            cov: The covariance matrix of the observed values. Can either be a single value,
                a list of values, an array or a tensor that are either 1 or 2 dimensional. If 1D, it is interpreted
                as the diagonal of the covariance matrix.
                Either ``sigma`` or ``cov`` can be given, not both.
                ``cov`` is a 2D matrix with the shape `(n, n)` where `n` is the number of parameters and ``sigma``
                squared on the diagonal.

        Raises:
            ShapeIncompatibleError: If params, observation and the uncertainties have incompatible shapes
                (for the legacy ``uncertainty``, this check happens when the covariance is built).
            ValueError: If an invalid combination of ``sigma``, ``cov`` and ``uncertainty`` is given, or if
                ``sigma``/``cov`` contain a ``ZfitParameter`` or have an unsupported number of dimensions.
        """
        observation = convert_to_container(observation, tuple, ignore=np.ndarray)
        params = convert_to_container(params, tuple, ignore=np.ndarray)
        # kept for ``create_covariance_legacy`` (``params`` is rebound to a dict further down)
        params_tuple_legacy = params
        # legacy start 1
        # ``legacy_uncertainty`` is True iff the deprecated ``uncertainty`` argument was given
        if legacy_uncertainty := uncertainty is not None:
            uncertainty = convert_to_container(uncertainty, tuple, ignore=np.ndarray)
            # a single array/tensor is unpacked by iterating over its first axis
            if isinstance(uncertainty[0], (np.ndarray, tf.Tensor)) and len(uncertainty) == 1:
                uncertainty = tuple(uncertainty[0])

            def create_covariance_legacy(mu, sigma):
                # Build the covariance from the legacy ``uncertainty``: values with more than one dimension
                # are used as the covariance directly, 1D/0D values are squared onto the diagonal.
                mu = z.convert_to_tensor(mu)
                sigma = znp.asarray(
                    sigma
                )  # otherwise TF complains that the shape got changed from [2] to [2, 2] (if we have a tuple of two arrays)
                sigma = z.convert_to_tensor(sigma)
                params_tensor = z.convert_to_tensor(params_tuple_legacy)
                if sigma.shape.ndims > 1:
                    covariance = sigma
                elif sigma.shape.ndims == 1:
                    covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))
                else:
                    sigma = znp.reshape(sigma, [1])
                    covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))
                if not params_tensor.shape[0] == mu.shape[0] == covariance.shape[0] == covariance.shape[1]:
                    # NOTE(review): the message lacks a space after "Currently"
                    msg = (
                        f"params_tensor, observation and uncertainty have to have the"
                        " same length. Currently"
                        f"param: {params_tensor.shape[0]}, mu: {mu.shape[0]}, "
                        f"covariance (from uncertainty): {covariance.shape[0:2]}"
                    )
                    raise ShapeIncompatibleError(msg)
                return covariance

        # legacy end 1
        # arguments as given (after container conversion), stored for serialization
        original_init = {
            "observation": observation,
            "params": params,
            "uncertainty": uncertainty,
            "sigma": sigma,
            "cov": cov,
        }
        sigma, cov = _preprocess_gaussian_constr_sigma_var(cov, sigma, legacy_uncertainty)
        # double underscore: name-mangled to ``_GaussianConstraint__cov`` / ``_GaussianConstraint__sigma``
        self.__cov = cov
        self.__sigma = sigma

        distribution = tfd.MultivariateNormalTriL

        # ``self`` is bound as a default argument; the covariance is read via the ``covariance``
        # property each time the distribution is built.
        def dist_params(observation, *, self=self):
            return {"loc": observation, "scale_tril": tf.linalg.cholesky(self.covariance)}

        dist_kwargs = {"validate_args": True}
        params = {f"param_{i}": p for i, p in enumerate(params)}
        super().__init__(
            name="GaussianConstraint",
            observation=observation,
            params=params,
            distribution=distribution,
            dist_params=dist_params,
            dist_kwargs=dist_kwargs,
        )
        self.hs3.original_init.update(original_init)
        # ``_covariance`` takes no argument in the legacy case and the preprocessed ``cov`` otherwise;
        # see the ``covariance`` property.
        if legacy_uncertainty:
            self._covariance = lambda: create_covariance_legacy(self.observation, uncertainty)
        else:
            self._covariance = lambda cov: znp.asarray(cov, ztypes.float)
        self._legacy_uncertainty = legacy_uncertainty

    @property
    def covariance(self):
        """Return the covariance matrix of the observed values of the parameters constrained."""
        # legacy start 2
        if self._legacy_uncertainty:
            return self._covariance()
        # legacy end 2
        return self._covariance(cov=self.__cov)
class GaussianConstraintRepr(BaseConstraintRepr):
    # serialization representation of ``GaussianConstraint``
    _implementation = GaussianConstraint
    hs3_type: Literal["GaussianConstraint"] = pydantic.Field("GaussianConstraint", alias="type")
    params: list[Serializer.types.ParamInputTypeDiscriminated]
    observation: list[Serializer.types.ParamInputTypeDiscriminated]
    uncertainty: Optional[list[Serializer.types.ParamInputTypeDiscriminated]]
    sigma: Optional[list[Serializer.types.ParamInputTypeDiscriminated]]
    cov: Optional[list[Serializer.types.ParamInputTypeDiscriminated]]

    @pydantic.root_validator(pre=True)
    def get_init_args(cls, values):
        """When dumping a constraint object, use the arguments it was initialized with."""
        return values["hs3"].original_init if cls.orm_mode(values) else values

    @pydantic.validator("params", "observation", "uncertainty", "sigma", "cov")
    def validate_params(cls, value):
        """Turn arrays into (nested) lists and anything else into a list container."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        return convert_to_container(value, list)
[docs]
class PoissonConstraint(TFProbabilityConstraint, SerializableMixin):
    def __init__(self, params: ztyping.ParamTypeInput, observation: ztyping.NumericalScalarType):
        r"""Poisson constraints on a list of parameters to some observed values.

        Constraints parameters that can be counts (i.e. from a histogram) or, more generally, are
        Poisson distributed. This is often used in the case of histogram templates which are obtained
        from simulation and have a poisson uncertainty due to limited statistics.

        .. math::
            \text{constraint} = \text{Poisson}(\text{observation}; \text{params})

        The ``observation`` values are used as the ``rate`` of ``tfd.Poisson`` and the
        log-probability is evaluated at the values of ``params``.

        Args:
            params: The parameters to constrain.
            observation: Observed values of the parameters; used as the Poisson rate. Has to have the
                same length as ``params``.

        Raises:
            ShapeIncompatibleError: If params and observation have incompatible shapes.
        """
        observation = convert_to_container(observation, tuple)
        params = convert_to_container(params, tuple)
        original_init = {"observation": observation, "params": params}
        super().__init__(
            name="PoissonConstraint",
            observation=observation,
            params=params,
            distribution=tfd.Poisson,
            dist_params={"rate": observation},
            dist_kwargs={"validate_args": False},
        )
        # record the init arguments for serialization
        self.hs3.original_init.update(original_init)
class PoissonConstraintRepr(BaseConstraintRepr):
    # serialization representation of ``PoissonConstraint``
    _implementation = PoissonConstraint
    hs3_type: Literal["PoissonConstraint"] = pydantic.Field("PoissonConstraint", alias="type")
    params: list[Serializer.types.ParamInputTypeDiscriminated]
    observation: list[Serializer.types.ParamInputTypeDiscriminated]

    @pydantic.root_validator(pre=True)
    def get_init_args(cls, values):
        """When dumping a constraint object, use the arguments it was initialized with."""
        return values["hs3"].original_init if cls.orm_mode(values) else values

    @pydantic.validator("params", "observation")
    def validate_params(cls, value):
        """Turn arrays into (nested) lists and anything else into a list container."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        return convert_to_container(value, list)
[docs]
class LogNormalConstraint(TFProbabilityConstraint, SerializableMixin):
    def __init__(
        self,
        params: ztyping.ParamTypeInput,
        observation: ztyping.NumericalScalarType,
        uncertainty: ztyping.NumericalScalarType,
    ):
        r"""Log-normal constraints on a list of parameters to some observed values.

        Constraints parameters that can be counts (i.e. from a histogram) or, more generally, are
        LogNormal distributed. This is often used in the case of histogram templates which are obtained
        from simulation and have a log-normal uncertainty due to a multiplicative uncertainty.

        .. math::
            \text{constraint} = \text{LogNormal}(\text{params}; \text{observation}, \text{uncertainty})

        The constraint value is the sum of the negative log-probabilities of ``params`` under
        ``tfd.LogNormal(loc=observation, scale=uncertainty)``. Note that ``tfd.LogNormal`` interprets
        ``loc`` and ``scale`` as the mean and standard deviation of the *logarithm* of the variable,
        i.e. ``observation`` and ``uncertainty`` are given in log-space.

        Args:
            params: The parameters to constrain; the log-normal density is evaluated at their values.
            observation: Observed values of the parameters; used as ``loc`` of the log-normal
                distribution. Has to have the same length as ``params``.
            uncertainty: Uncertainties of the observed values; used as ``scale`` of the log-normal
                distribution.

        Raises:
            ShapeIncompatibleError: If ``params`` and ``observation`` have different lengths.
        """
        observation = convert_to_container(observation, tuple)
        params = convert_to_container(params, tuple)
        uncertainty = convert_to_container(uncertainty, tuple)
        # init arguments recorded for serialization
        original_init = {
            "observation": observation,
            "params": params,
            "uncertainty": uncertainty,
        }

        distribution = tfd.LogNormal

        def dist_params(observation):
            return {"loc": observation, "scale": uncertainty}

        dist_kwargs = {"validate_args": False}
        super().__init__(
            name="LogNormalConstraint",
            observation=observation,
            params=params,
            distribution=distribution,
            dist_params=dist_params,
            dist_kwargs=dist_kwargs,
        )
        self.hs3.original_init.update(original_init)
class LogNormalConstraintRepr(BaseConstraintRepr):
    # serialization representation of ``LogNormalConstraint``
    _implementation = LogNormalConstraint
    hs3_type: Literal["LogNormalConstraint"] = pydantic.Field("LogNormalConstraint", alias="type")
    params: list[Serializer.types.ParamInputTypeDiscriminated]
    observation: list[Serializer.types.ParamInputTypeDiscriminated]
    uncertainty: list[Serializer.types.ParamInputTypeDiscriminated]

    @pydantic.root_validator(pre=True)
    def get_init_args(cls, values):
        """When dumping a constraint object, use the arguments it was initialized with."""
        return values["hs3"].original_init if cls.orm_mode(values) else values

    @pydantic.validator("params", "observation", "uncertainty")
    def validate_params(cls, value):
        """Turn arrays into (nested) lists and anything else into a list container."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        return convert_to_container(value, list)