# Copyright (c) 2024 zfit
from __future__ import annotations
import abc
import collections
from collections import OrderedDict
from collections.abc import Callable
from typing import Mapping, Iterable, List
import numpy as np
import pydantic
from .serialmixin import SerializableMixin
from ..serialization.serializer import BaseRepr, Serializer
from typing import Literal
import tensorflow as tf
import tensorflow_probability as tfp
from ordered_set import OrderedSet
import zfit.z.numpy as znp
from zfit import z
from .baseobject import BaseNumeric
from .dependents import _extract_dependencies
from .interfaces import ZfitConstraint, ZfitParameter
from ..settings import ztypes
from ..util import ztyping
from ..util.container import convert_to_container
from ..util.exception import ShapeIncompatibleError
tfd = tfp.distributions
# TODO(serialization): add to serializer
class BaseConstraintRepr(BaseRepr):
_implementation = None
_owndict = pydantic.PrivateAttr(default_factory=dict)
hs3_type: Literal["BaseConstraint"] = pydantic.Field("BaseConstraint", alias="type")
class BaseConstraint(ZfitConstraint, BaseNumeric):
def __init__(
self,
params: dict[str, ZfitParameter] = None,
name: str = "BaseConstraint",
dtype=ztypes.float,
**kwargs,
):
"""Base class for constraints.
Args:
dtype: the dtype of the constraint
name: the name of the constraint
params: A dictionary with the internal name of the
parameter and the parameters itself the constrains depends on
"""
super().__init__(name=name, dtype=dtype, params=params, **kwargs)
def value(self):
return self._value()
@abc.abstractmethod
def _value(self):
raise NotImplementedError
def _get_dependencies(self) -> ztyping.DependentsType:
return _extract_dependencies(self.get_params(floating=None))
# TODO: improve arbitrary constraints, should we allow only functions that have a `params` argument?
[docs]
class SimpleConstraint(BaseConstraint):
def __init__(
self,
func: Callable,
params: (
Mapping[str, ztyping.ParameterType]
| Iterable[ztyping.ParameterType]
| ztyping.ParameterType
| None
),
*,
name: str = None,
):
"""Constraint from a (function returning a) Tensor.
Args:
func: Callable that constructs the constraint and returns a tensor. For the expected signature,
see below in ``params``.
params: The parameters of the loss. If given as a list, the parameters are named "param_{i}"
and the function does not take any arguments. If given as a dict, the function expects
the parameter as the first argument (``params``).
"""
if name is None:
name = "SimpleConstraint"
self._simple_func = func
self._func_params = None
if isinstance(params, collections.abc.Mapping):
self._func_params = params
params = list(params.values())
self._simple_func_dependents = convert_to_container(
params, container=OrderedSet
)
params = convert_to_container(params, container=list)
if self._func_params is None:
params = OrderedDict((f"param_{i}", p) for i, p in enumerate(params))
else:
params = self._func_params
super().__init__(name=name, params=params)
def _value(self):
if self._func_params is None:
return self._simple_func()
else:
return self._simple_func(self._func_params)
class ProbabilityConstraint(BaseConstraint):
def __init__(
self,
observation: ztyping.NumericalScalarType | ZfitParameter,
params: dict[str, ZfitParameter] = None,
name: str = "ProbabilityConstraint",
dtype=ztypes.float,
**kwargs,
):
"""Base class for constraints using a probability density function.
Args:
dtype: the dtype of the constraint
name: the name of the constraint
params: The parameters to constraint
observation: Observed values of the parameter
to constraint obtained from auxiliary measurements.
"""
# TODO: proper handling of input params, arrays. ArrayParam?
params = convert_to_container(params)
params_dict = {f"param_{i}": p for i, p in enumerate(params)}
super().__init__(name=name, dtype=dtype, params=params_dict, **kwargs)
params = tuple(self.params.values())
observation = convert_to_container(observation, tuple)
if len(observation) != len(params):
raise ShapeIncompatibleError(
"observation and params have to be the same length. Currently"
f"observation: {len(observation)}, params: {len(params)}"
)
self._observation = observation # TODO: needed below? Why?
# for obs, p in zip(observation, params):
# obs = convert_to_parameter(obs, f"{p.name}_obs", prefer_constant=False)
# obs.floating = False
# self._observation.append(obs)
self._ordered_params = params
@property
def observation(self):
"""Return the observed values of the parameters constrained."""
return self._observation
def value(self):
return self._value()
@abc.abstractmethod
def _value(self):
raise NotImplementedError
def _get_dependencies(self) -> ztyping.DependentsType:
return _extract_dependencies(self.get_params())
def sample(self, n):
"""Sample ``n`` points from the probability density function for the observed value of the parameters.
Args:
n: The number of samples to be generated.
Returns:
"""
sample = self._sample(n=n)
return {p: sample[:, i] for i, p in enumerate(self._ordered_params)}
@abc.abstractmethod
def _sample(self, n):
raise NotImplementedError
@property
def _params_array(self):
return znp.asarray(self._ordered_params)
class TFProbabilityConstraint(ProbabilityConstraint):
def __init__(
self,
observation: ztyping.NumericalScalarType | ZfitParameter,
params: dict[str, ZfitParameter],
distribution: tfd.Distribution,
dist_params,
dist_kwargs=None,
name: str = "DistributionConstraint",
dtype=ztypes.float,
**kwargs,
):
"""Base class for constraints using a probability density function from ``tensorflow_probability``.
Args:
distribution: The probability density function
used to constraint the parameters
"""
super().__init__(
observation=observation, params=params, name=name, dtype=dtype, **kwargs
)
self._distribution = distribution
self.dist_params = dist_params
self.dist_kwargs = dist_kwargs if dist_kwargs is not None else {}
@property
def distribution(self):
params = self.dist_params
if callable(params):
params = params(self.observation)
kwargs = self.dist_kwargs
if callable(kwargs):
kwargs = kwargs()
params = {k: tf.cast(v, ztypes.float) for k, v in params.items()}
return self._distribution(**params, **kwargs, name=f"{self.name}_tfp")
def _value(self):
array = tf.cast(self._params_array, ztypes.float)
value = -self.distribution.log_prob(array)
return tf.reduce_sum(value)
def _sample(self, n):
return self.distribution.sample(n)
[docs]
class GaussianConstraint(TFProbabilityConstraint, SerializableMixin):
def __init__(
self,
params: ztyping.ParamTypeInput,
observation: ztyping.NumericalScalarType,
uncertainty: ztyping.NumericalScalarType,
):
r"""Gaussian constraints on a list of parameters to some observed values with uncertainties.
A Gaussian constraint is defined as the likelihood of ``params`` given the ``observations`` and ``uncertainty`` from
a different measurement.
.. math::
\text{constraint} = \text{Gauss}(\text{observation}; \text{params}, \text{uncertainty})
Args:
params: The parameters to constraint; corresponds to x in the Gaussian
distribution.
observation: observed values of the parameter; corresponds to mu
in the Gaussian distribution.
uncertainty: Uncertainties or covariance/error
matrix of the observed values. Can either be a single value, a list of values, an array or a tensor.
Corresponds to the sigma of the Gaussian distribution.
Raises:
ShapeIncompatibleError: If params, mu and sigma have incompatible shapes.
"""
observation = convert_to_container(observation, tuple)
params = convert_to_container(params, tuple)
uncertainty = convert_to_container(uncertainty, tuple)
if (
isinstance(uncertainty[0], (np.ndarray, tf.Tensor))
and len(uncertainty) == 1
):
uncertainty = tuple(uncertainty[0])
original_init = {
"observation": observation,
"params": params,
"uncertainty": uncertainty,
}
def create_covariance(mu, sigma):
mu = z.convert_to_tensor(mu)
sigma = znp.asarray(
sigma
) # otherwise TF complains that the shape got changed from [2] to [2, 2] (if we have a tuple of two arrays)
sigma = z.convert_to_tensor(sigma) # TODO (Mayou36): fix as above?
params_tensor = z.convert_to_tensor(params)
if sigma.shape.ndims > 1:
covariance = sigma # TODO: square as well?
elif sigma.shape.ndims == 1:
covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))
else:
sigma = znp.reshape(sigma, [1])
covariance = tf.linalg.tensor_diag(z.pow(sigma, 2.0))
if (
not params_tensor.shape[0]
== mu.shape[0]
== covariance.shape[0]
== covariance.shape[1]
):
raise ShapeIncompatibleError(
f"params_tensor, observation and uncertainty have to have the"
" same length. Currently"
f"param: {params_tensor.shape[0]}, mu: {mu.shape[0]}, "
f"covariance (from uncertainty): {covariance.shape[0:2]}"
)
return covariance
distribution = tfd.MultivariateNormalTriL
covariance = create_covariance(observation, uncertainty)
dist_params = lambda observation: dict(
loc=observation,
scale_tril=tf.linalg.cholesky(covariance),
)
dist_kwargs = dict(validate_args=True)
super().__init__(
name="GaussianConstraint",
observation=observation,
params=params,
distribution=distribution,
dist_params=dist_params,
dist_kwargs=dist_kwargs,
)
self.hs3.original_init.update(original_init)
self._covariance = lambda: create_covariance(self.observation, uncertainty)
@property
def covariance(self):
"""Return the covariance matrix of the observed values of the parameters constrained."""
return self._covariance()
class GaussianConstraintRepr(BaseConstraintRepr):
_implementation = GaussianConstraint
hs3_type: Literal["GaussianConstraint"] = pydantic.Field(
"GaussianConstraint", alias="type"
)
params: List[Serializer.types.ParamInputTypeDiscriminated]
observation: List[Serializer.types.ParamInputTypeDiscriminated]
uncertainty: List[Serializer.types.ParamInputTypeDiscriminated]
@pydantic.root_validator(pre=True)
def get_init_args(cls, values):
if cls.orm_mode(values):
values = values["hs3"].original_init
return values
@pydantic.validator("params", "observation", "uncertainty")
def validate_params(cls, v):
if isinstance(v, np.ndarray):
v = v.tolist()
else:
v = convert_to_container(v, list)
return v
[docs]
class PoissonConstraint(TFProbabilityConstraint, SerializableMixin):
def __init__(
self, params: ztyping.ParamTypeInput, observation: ztyping.NumericalScalarType
):
r"""Poisson constraints on a list of parameters to some observed values.
Constraints parameters that can be counts (i.e. from a histogram) or, more generally, are
Poisson distributed. This is often used in the case of histogram templates which are obtained
from simulation and have a poisson uncertainty due to limited statistics.
.. math::
\text{constraint} = \text{Poisson}(\text{observation}; \text{params})
Args:
params: The parameters to constraint; corresponds to the mu in the Poisson
distribution.
observation: observed values of the parameter; corresponds to lambda
in the Poisson distribution.
Raises:
ShapeIncompatibleError: If params and observation have incompatible shapes.
"""
observation = convert_to_container(observation, tuple)
params = convert_to_container(params, tuple)
original_init = {"observation": observation, "params": params}
distribution = tfd.Poisson
dist_params = dict(rate=observation)
dist_kwargs = dict(validate_args=False)
super().__init__(
name="PoissonConstraint",
observation=observation,
params=params,
distribution=distribution,
dist_params=dist_params,
dist_kwargs=dist_kwargs,
)
self.hs3.original_init.update(original_init)
class PoissonConstraintRepr(BaseConstraintRepr):
_implementation = PoissonConstraint
hs3_type: Literal["PoissonConstraint"] = pydantic.Field(
"PoissonConstraint", alias="type"
)
params: List[Serializer.types.ParamInputTypeDiscriminated]
observation: List[Serializer.types.ParamInputTypeDiscriminated]
@pydantic.root_validator(pre=True)
def get_init_args(cls, values):
if cls.orm_mode(values):
values = values["hs3"].original_init
return values
@pydantic.validator("params", "observation")
def validate_params(cls, v):
if isinstance(v, np.ndarray):
v = v.tolist()
else:
v = convert_to_container(v, list)
return v
[docs]
class LogNormalConstraint(TFProbabilityConstraint, SerializableMixin):
def __init__(
self,
params: ztyping.ParamTypeInput,
observation: ztyping.NumericalScalarType,
uncertainty: ztyping.NumericalScalarType,
):
r"""Log-normal constraints on a list of parameters to some observed values.
Constraints parameters that can be counts (i.e. from a histogram) or, more generally, are
LogNormal distributed. This is often used in the case of histogram templates which are obtained
from simulation and have a log-normal uncertainty due to a multiplicative uncertainty.
.. math::
\text{constraint} = \text{LogNormal}(\text{observation}; \text{params})
Args:
params: The parameters to constraint; corresponds to the mu in the Poisson
distribution.
observation: observed values of the parameter; corresponds to lambda
in the Poisson distribution.
uncertainty: uncertainty of the observed values of the parameter; corresponds to sigma
in the Poisson distribution.
Raises:
ShapeIncompatibleError: If params, mu and sigma have incompatible shapes.
"""
observation = convert_to_container(observation, tuple)
params = convert_to_container(params, tuple)
uncertainty = convert_to_container(uncertainty, tuple)
original_init = {
"observation": observation,
"params": params,
"uncertainty": uncertainty,
}
distribution = tfd.LogNormal
dist_params = lambda observation: dict(loc=observation, scale=uncertainty)
dist_kwargs = dict(validate_args=False)
super().__init__(
name="LogNormalConstraint",
observation=observation,
params=params,
distribution=distribution,
dist_params=dist_params,
dist_kwargs=dist_kwargs,
)
self.hs3.original_init.update(original_init)
class LogNormalConstraintRepr(BaseConstraintRepr):
_implementation = LogNormalConstraint
hs3_type: Literal["LogNormalConstraint"] = pydantic.Field(
"LogNormalConstraint", alias="type"
)
params: List[Serializer.types.ParamInputTypeDiscriminated]
observation: List[Serializer.types.ParamInputTypeDiscriminated]
uncertainty: List[Serializer.types.ParamInputTypeDiscriminated]
@pydantic.root_validator(pre=True)
def get_init_args(cls, values):
if cls.orm_mode(values):
values = values["hs3"].original_init
return values
@pydantic.validator("params", "observation", "uncertainty")
def validate_params(cls, v):
if isinstance(v, np.ndarray):
v = v.tolist()
else:
v = convert_to_container(v, list)
return v