# Copyright (c) 2025 zfit
from __future__ import annotations
import typing
from collections.abc import Callable
from typing import Literal
import pydantic.v1 as pydantic
import tensorflow as tf
import zfit.z.numpy as znp
from .. import z
from ..core.serialmixin import SerializableMixin
from ..core.space import supports
from ..serialization import Serializer # noqa: F401
from ..settings import ztypes
from ..util import ztyping
from ..util.exception import AnalyticGradientNotAvailable
from .basefunctor import FunctorPDFRepr
from .functor import BaseFunctor
if typing.TYPE_CHECKING:
import zfit # noqa: F401
def get_value(cache: tf.Variable, flag: tf.Variable, func: Callable):
    """Return the cached value if ``flag`` is set, otherwise evaluate ``func`` and store its result.

    Args:
        cache: Variable holding the last computed value.
        flag: Boolean variable. If ``True``, ``cache`` is returned as-is; otherwise ``func()`` is
            evaluated, assigned to ``cache`` and the updated value is returned.
        func: Zero-argument callable that computes the value on a cache miss.

    Returns:
        The cached value or the freshly computed (and now cached) value.

    Raises:
        AnalyticGradientNotAvailable: If a gradient through the returned value is requested;
            analytic gradients are not implemented through the cache.
    """

    @tf.custom_gradient
    def actual_func():
        def autoset_func():
            # Cache miss: recompute and store; ``read_value=True`` returns the updated value.
            return cache.assign(func(), read_value=True)

        def use_cache():
            return cache

        val = tf.cond(flag, use_cache, autoset_func)

        def grad_fn(dval, variables):  # noqa: ARG001
            msg = (
                "The analytic gradient is not implemented for caching PDF. Use the numerical gradient instead. "
                "(either using zfit.run.set_autograd_mode(False) and/or by using the minimizer internal numerical gradient)"
            )
            raise AnalyticGradientNotAvailable(msg)

        return val, grad_fn

    return actual_func()
[docs]
class CachedPDF(BaseFunctor, SerializableMixin):
    """Wrapper PDF that caches the last ``pdf`` and ``integrate`` results of the wrapped PDF."""

    def __init__(
        self,
        pdf: ztyping.PDFInputType,
        *,
        epsilon: float | None = None,
        extended: ztyping.ExtendedInputType = None,
        norm: ztyping.NormInputType = None,
        cache_tol=None,
        name: str | None = None,
        label: str | None = None,
    ):
        """Creates a PDF where ``pdf`` and ``integrate`` methods are cacheable.

        .. note::
            Analytic gradients are not supported for cached PDFs. Use numerical gradients instead.

        The method stores the last calculated value of a function for a specific dataset and
        returns it when the input data and the parameters are the same. This can be useful when
        the pdf is called multiple times with the same data and parameters, for example in the
        minimization process when a numerical gradient is used.

        Args:
            pdf: PDF whose ``pdf`` and ``integrate`` methods are cached.
            epsilon: Absolute tolerance used when comparing the arguments (data, parameters,
                integration limits) with the cached ones. Defaults to ``1e-8``.
            cache_tol: Alias for ``epsilon``. If given, it takes precedence over ``epsilon``.
            extended: |@doc:pdf.init.extended| The overall yield of the PDF.
               If this is parameter-like, it will be used as the yield,
               the expected number of events, and the PDF will be extended.
               An extended PDF has additional functionality, such as the
               ``ext_*`` methods and the ``counts`` (for binned PDFs). |@docend:pdf.init.extended|
            norm: |@doc:pdf.init.norm| Normalization of the PDF.
               By default, this is the same as the default space of the PDF. |@docend:pdf.init.norm|
            name: |@doc:pdf.init.name| Name of the PDF.
               Maybe has implications on the serialization and deserialization of the PDF.
               For a human-readable name, use the label. |@docend:pdf.init.name|
            label: |@doc:pdf.init.label| Human-readable name
               or label of
               the PDF for a better description, to be used with plots etc.
               Has no programmatical functional purpose as identification. |@docend:pdf.init.label|
        """
        obs = pdf.space
        if cache_tol is not None:  # ``cache_tol`` overrides ``epsilon``
            epsilon = cache_tol
        hs3_init = {
            "obs": obs,
            "extended": extended,
            "norm": norm,
            "epsilon": epsilon,
            "name": name,
        }
        name = name or pdf.name
        super().__init__(pdfs=pdf, obs=obs, name=name, extended=extended, norm=norm, label=label, autograd_params=[])

        # Parameter caches hold the stacked parameter values of the wrapped PDF
        # (``None`` presumably disables the floating/yield filter of ``get_params``).
        params = list(pdf.get_params(floating=None, is_yield=None, extract_independent=True))
        # Guard the parameter-free case, where there is nothing to stack.
        param_shape = tf.shape(znp.stack(params)) if params else (0,)

        def make_param_cache():
            # Shape is not validated so that a parameter vector of another length can be stored later.
            return tf.Variable(
                znp.zeros(shape=param_shape, dtype=ztypes.float),
                trainable=False,
                validate_shape=False,
                dtype=ztypes.float,
            )

        self._cached_pdf_params = make_param_cache()
        self._cached_pdf_params_for_integration = make_param_cache()
        # Caches whose shape depends on the data/limits are created lazily in ``_pdf``/``_integrate``.
        self._pdf_cache = None
        self._cached_x = None
        self._pdf_cache_valid = tf.Variable(initial_value=False, trainable=False)
        self._cached_integral_limits = None
        self._integral_cache = None
        self._integral_cache_valid = tf.Variable(initial_value=False, trainable=False)
        self._cache_tolerance = 1e-8 if epsilon is None else epsilon
        self.hs3.original_init.update(hs3_init)

    @supports(norm="space")
    @z.function(autograph=True)
    def _pdf(self, x, norm):
        """Return the pdf of the wrapped PDF, reusing the cached value if data and parameters are unchanged.

        NOTE(review): ``norm`` is not part of the cache key; a call with identical data and parameters
        but a different ``norm`` would return the previously cached value. Confirm this cannot occur.
        """
        x = x.value()
        xlen = tf.shape(x)[0]
        # Created on the first (traced) call since the shape depends on ``x``;
        # ``validate_shape=False`` allows storing data of a different length later on.
        if self._pdf_cache is None:
            self._pdf_cache = tf.Variable(
                -999.0 * znp.ones(shape=xlen),  # sentinel: an unrealistic pdf value
                trainable=False,
                validate_shape=False,
                shape=tf.TensorShape([None]),
                dtype=ztypes.float,
            )
        if self._cached_x is None:
            self._cached_x = tf.Variable(
                x + 19.0,  # shifted so that the first call never matches the cache
                trainable=False,
                validate_shape=False,
                shape=tf.TensorShape([None, x.shape[1]]),
                dtype=ztypes.float,
            )

        # Compare only the common leading rows (so the subtraction is well-defined for data of a
        # different length), then additionally require equal lengths.
        cachedxlen = tf.shape(self._cached_x)[0]
        minlen = znp.min([xlen, cachedxlen])
        xtrunc = x[:minlen]
        xcachedtrunc = self._cached_x[:minlen]
        with tf.control_dependencies([xtrunc, xcachedtrunc]):  # required! Otherwise would use previous shape (bug?)
            xtrunc_diff = xtrunc - xcachedtrunc
            xtruncdiff_abs = znp.abs(xtrunc_diff)
            xtrunc_lt = xtruncdiff_abs < self._cache_tolerance
            x_same = tf.math.reduce_all(xtrunc_lt)
        x_same = znp.logical_and(x_same, xlen == cachedxlen)

        # Use the same parameter selection as ``__init__`` and ``_integrate`` (``floating=None``) rather
        # than the ``get_params`` default, so that value changes of non-floating parameters also
        # invalidate the cache and the parameter list does not depend on the floating state at trace time.
        pdf_params = list(self.pdfs[0].get_params(floating=None))
        paramlen = len(pdf_params)
        hasparams = paramlen > 0
        if hasparams:
            stacked_pdf_params = znp.stack(pdf_params)
            cachedparamlen = tf.shape(self._cached_pdf_params)[0]
            minparamlen = znp.min([paramlen, cachedparamlen])
            stackedpdf_trunc = stacked_pdf_params[:minparamlen]
            cachedpdf_trunc = self._cached_pdf_params[:minparamlen]
            with tf.control_dependencies([stackedpdf_trunc, cachedpdf_trunc]):
                params_same = tf.math.reduce_all(znp.abs(stackedpdf_trunc - cachedpdf_trunc) < self._cache_tolerance)
            same_args = tf.math.logical_and(params_same, x_same)
            same_args = znp.logical_and(same_args, paramlen == cachedparamlen)
        else:
            same_args = x_same
        assign1 = self._pdf_cache_valid.assign(same_args, read_value=False)

        def value_update_func():
            # Cache miss: remember the new arguments and evaluate the wrapped PDF.
            if hasparams:
                self._cached_pdf_params.assign(stacked_pdf_params, read_value=False)
            self._cached_x.assign(x, read_value=False)
            return self.pdfs[0].pdf(x, norm)

        with tf.control_dependencies([assign1]):
            return get_value(self._pdf_cache, self._pdf_cache_valid, value_update_func)

    @supports(norm="space")
    def _integrate(self, limits, norm, options=None):
        """Return the integral of the wrapped PDF, reusing the cached value if limits and parameters are unchanged.

        NOTE(review): neither ``norm`` nor ``options`` is part of the cache key; confirm that they cannot
        differ between calls with identical limits and parameters.
        """
        if self._cached_integral_limits is None:
            self._cached_integral_limits = tf.Variable(
                tf.stack(limits.v1.limits) + 19.0,  # shifted so that the first call never matches the cache
                trainable=False,
                validate_shape=False,
                dtype=ztypes.float,
            )
        if self._integral_cache is None:
            self._integral_cache = tf.Variable(
                znp.zeros(shape=tf.shape([1])),
                trainable=False,
                validate_shape=False,
                dtype=ztypes.float,
            )
        stacked_integral_limits = tf.stack(limits.v1.limits)
        limits_same = tf.math.reduce_all(
            znp.abs(stacked_integral_limits - self._cached_integral_limits) < self._cache_tolerance
        )
        params = list(self.pdfs[0].get_params(floating=None))
        hasparams = len(params) > 0
        if hasparams:
            stacked_pdf_params = znp.stack(params)
            params_same = znp.all(
                znp.abs(stacked_pdf_params - self._cached_pdf_params_for_integration) < self._cache_tolerance
            )
            same_args = znp.logical_and(params_same, limits_same)
        else:
            same_args = limits_same
        assign1 = self._integral_cache_valid.assign(same_args, read_value=False)

        def value_update_func():
            # Cache miss: remember the new arguments and integrate the wrapped PDF.
            if hasparams:
                self._cached_pdf_params_for_integration.assign(stacked_pdf_params, read_value=False)
            self._cached_integral_limits.assign(stacked_integral_limits, read_value=False)
            return self.pdfs[0].integrate(limits, norm, options=options)

        with tf.control_dependencies([assign1]):
            return get_value(self._integral_cache, self._integral_cache_valid, value_update_func)
class CachedPDFRepr(FunctorPDFRepr):
    """Serialization (HS3) representation of ``CachedPDF``."""

    _implementation = CachedPDF
    hs3_type: Literal["CachedPDF"] = pydantic.Field("CachedPDF", alias="type")

    def _to_orm(self, init) -> CachedPDF:
        """Map the serialized ``init`` mapping onto the ``CachedPDF`` constructor arguments.

        ``obs`` is dropped because ``CachedPDF`` derives its observables from the wrapped PDF, and
        the first (presumably only) entry of ``pdfs`` is passed on as ``pdf``.
        """
        del init["obs"]
        wrapped_pdfs = init.pop("pdfs")
        init["pdf"] = wrapped_pdfs[0]
        return super()._to_orm(init)