# Copyright (c) 2024 zfit
from __future__ import annotations
import abc
from abc import abstractmethod
from collections.abc import Mapping
import numpy as np
from ..core.interfaces import ZfitLoss, ZfitParameter
from ..util import ztyping
from .fitresult import FitResult
class FailMinimizeNaN(Exception):
    """Signal that a strategy gave up on handling NaN values during minimization."""
class ZfitStrategy(abc.ABC):
    """Abstract interface for minimization strategies.

    A strategy decides what happens when NaN values are encountered during a minimization
    (:py:meth:`minimize_nan`) and may post-process the value, gradient and hessian of every
    evaluation (:py:meth:`callback`).
    """

    @abstractmethod
    def minimize_nan(self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None) -> float:
        """React to NaN values encountered while minimizing ``loss``.

        Args:
            loss: The loss that is being minimized.
            params: The parameters that are being minimized.
            values: Optional mapping with extra information provided by the caller
                (its keys depend on the concrete strategy).

        Returns:
            A replacement for the loss evaluation; concrete strategies define the exact form.
        """
        raise NotImplementedError

    @abstractmethod
    def callback(
        self,
        value: float | None,
        gradient: np.ndarray | None,
        hessian: np.ndarray | None,
        params: list[ZfitParameter],
        loss: ZfitLoss,
    ) -> tuple[float, np.ndarray, np.ndarray]:
        """Post-process one evaluation of the loss.

        Args:
            value: The evaluated loss value, if available.
            gradient: The evaluated gradient, if available.
            hessian: The evaluated hessian, if available.
            params: The parameters the loss was evaluated at.
            loss: The loss that was evaluated.

        Returns:
            The (possibly modified) ``(value, gradient, hessian)`` triple.
        """
        raise NotImplementedError
class BaseStrategy(ZfitStrategy):
    """Minimal concrete strategy: gives up on NaNs and leaves evaluations untouched.

    Attributes are initialized to ``None`` here; subclasses may fill them in:

    - ``fit_result``: a result object describing a failed fit (``None`` by default).
    - ``error``: ``None`` by default.
    """

    def __init__(self) -> None:
        self.fit_result = None
        self.error = None
        super().__init__()

    def minimize_nan(self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None) -> float:  # noqa: ARG002
        """Refuse to handle NaN values.

        Args:
            loss: Ignored.
            params: Ignored.
            values: Ignored.

        Raises:
            FailMinimizeNaN: Always.
        """
        raise FailMinimizeNaN()

    def callback(self, value, gradient, hessian, params, loss):
        """Return ``(value, gradient, hessian)`` unchanged; ``params`` and ``loss`` are ignored."""
        del params, loss  # unused
        return value, gradient, hessian

    def __str__(self) -> str:
        # Use the class name directly instead of slicing ``repr(self.__class__)``: the repr-based
        # version returns garbage for classes whose metaclass defines a custom ``__repr__``.
        # For ordinary (and closure-defined) classes the result is identical.
        return self.__class__.__name__
class ToyStrategyFail(BaseStrategy):
    """Strategy that gives up on NaNs, but records an invalid :py:class:`FitResult` first.

    ``fit_result`` starts out as a placeholder invalid result. :py:meth:`minimize_nan` replaces it
    with an invalid result that carries the loss and the parameters at the point of failure, and
    then raises :py:class:`FailMinimizeNaN`.
    """

    def __init__(self) -> None:
        super().__init__()
        # Placeholder so that ``fit_result`` is already an (invalid) FitResult before any failure.
        self.fit_result = FitResult(
            params={},
            edm=None,
            fminopt=None,
            status=None,
            converged=False,
            info={},
            valid=False,
            message="NaN produced, ToyStrategy fails",
            niter=None,
            loss=None,
            minimizer=None,
            criterion=None,
        )

    def minimize_nan(self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None) -> float:
        """Store an invalid FitResult for the current parameters and give up.

        Args:
            loss: The loss being minimized; stored on the recorded result.
            params: The parameters being minimized; used as keys of the recorded ``params``.
            values: Ignored.

        Raises:
            FailMinimizeNaN: Always, after ``self.fit_result`` has been updated.
        """
        del values  # unused
        # NOTE(review): ``np.asarray`` presumably converts the parameter objects to their current
        # numeric values, giving a {param: value} mapping — confirm against the parameter type.
        param_vals = np.asarray(params)
        param_vals = dict(zip(params, param_vals))
        self.fit_result = FitResult(
            params=param_vals,
            edm=None,
            fminopt=None,
            status=9,
            converged=False,
            info={},
            loss=loss,
            valid=False,
            message="Failed on too many NaNs",
            niter=None,
            criterion=None,
            minimizer=None,
        )
        raise FailMinimizeNaN()
def make_pushback_strategy(
    nan_penalty: float | int = 100,
    nan_tol: int = 30,
    base: type[ZfitStrategy] = BaseStrategy,
):
    """Create a strategy class that pushes the minimizer back out of NaN regions.

    When a NaN is reported, the returned strategy replaces the loss by the previous loss
    (``values["old_loss"]``) plus ``nan_penalty * nan_counter`` and flips the sign of the previous
    gradient (``values["old_grad"]``). ``nan_counter`` indicates how many NaNs occurred in a row.
    Once it reaches ``nan_tol``, handling is delegated to ``base``'s ``minimize_nan`` (the fallback).

    Args:
        nan_penalty: Value, multiplied by the NaN counter, added to the previous loss in order to
            penalize the step taken.
        nan_tol: If the number of NaNs encountered in a row reaches this number, the fallback is used.
        base: Strategy class the returned class inherits from; its ``minimize_nan`` is the fallback.

    Returns:
        A new strategy class named ``PushbackStrategy`` that subclasses ``base``.
    """

    class PushbackStrategy(base):
        def __init__(self):
            """Pushback by adding ``nan_penalty * counter`` to the loss if NaNs are encountered.

            ``nan_penalty`` and ``nan_tol`` are taken from the :py:func:`make_pushback_strategy`
            call that created this class.
            """
            super().__init__()
            self.nan_penalty = nan_penalty
            self.nan_tol = nan_tol

        def minimize_nan(
            self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None
        ) -> tuple[float | None, np.ndarray | None] | None:
            """Return a penalized ``(loss, gradient)`` pair, or fall back after too many NaNs.

            Args:
                loss: The loss being minimized (only forwarded to the fallback).
                params: The parameters being minimized (only forwarded to the fallback).
                values: Mapping that must contain ``"nan_counter"``; may contain ``"old_loss"``,
                    ``"old_grad"`` and ``"loss"``.

            Returns:
                ``(loss_evaluated, last_grad)`` while the NaN counter is below ``nan_tol``; otherwise
                whatever remains after the fallback (``None``), if the fallback does not raise.

            Raises:
                RuntimeError: If ``values`` lacks ``"nan_counter"``, or if there is no previous loss
                    and the current one is a string marker.
            """
            # Explicit check instead of ``assert``: must survive ``python -O`` and also covers the
            # ``values=None`` default, which would otherwise fail with an opaque TypeError.
            if values is None or "nan_counter" not in values:
                msg = "'nan_counter' not in values, minimizer not correctly implemented"
                raise RuntimeError(msg)
            nan_counter = values["nan_counter"]
            if nan_counter < self.nan_tol:
                last_loss = values.get("old_loss")
                last_grad = values.get("old_grad")
                if last_grad is not None:
                    # Point the gradient the other way to push away from the NaN region.
                    last_grad = -last_grad
                if last_loss is not None:
                    # Penalty grows with the number of consecutive NaNs.
                    loss_evaluated = last_loss + self.nan_penalty * nan_counter
                else:
                    loss_evaluated = values.get("loss")
                    # NOTE(review): a string here presumably marks that the very first evaluation
                    # was already NaN — confirm against the minimizer that fills ``values``.
                    if isinstance(loss_evaluated, str):
                        msg = "Loss starts already with NaN, cannot minimize."
                        raise RuntimeError(msg)
                return loss_evaluated, last_grad
            # Too many NaNs in a row: delegate to the next strategy in the MRO (the fallback).
            super().minimize_nan(loss=loss, params=params, values=values)
            return None

    return PushbackStrategy
# Default pushback strategy (nan_penalty=100, nan_tol=30) on top of BaseStrategy, whose fallback
# raises FailMinimizeNaN once too many NaNs occurred in a row.
PushbackStrategy = make_pushback_strategy()
class DefaultToyStrategy(PushbackStrategy, ToyStrategyFail):
    """Same as :py:class:`PushbackStrategy`, but records an invalid FitResult on full failure.

    Through the method resolution order, the fallback used once too many NaNs occurred in a row is
    :py:meth:`ToyStrategyFail.minimize_nan`. It stores an invalid :py:class:`FitResult`
    (``valid=False``, ``status=9``) in ``self.fit_result`` before raising
    :py:class:`FailMinimizeNaN`. The caller can then presumably return that result instead of
    propagating an error.

    This can be useful for toy studies, where multiple fits are done and a failure should simply be
    counted as a failure instead of raising an error.
    """