# Source code for zfit.minimizers.strategy

#  Copyright (c) 2025 zfit

from __future__ import annotations

import abc
import typing
from abc import abstractmethod
from collections.abc import Mapping

import numpy as np

from zfit._interfaces import ZfitLoss, ZfitParameter

from ..util import ztyping
from .fitresult import FitResult

if typing.TYPE_CHECKING:
    import zfit  # noqa: F401


class FailMinimizeNaN(Exception):
    """Raised by a strategy's ``minimize_nan`` to abort a minimization because of NaN values."""


class ZfitStrategy(abc.ABC):
    """Abstract interface for strategies that customize how a minimization handles NaNs and evaluated values.

    Concrete strategies must implement both :py:meth:`callback` and :py:meth:`minimize_nan`.
    """

    @abstractmethod
    def callback(
        self,
        value: float | None,
        gradient: np.ndarray | None,
        hessian: np.ndarray | None,
        params: list[ZfitParameter],
        loss: ZfitLoss,
    ) -> tuple[float, np.ndarray, np.ndarray]:
        """Hook that receives an evaluated ``value``, ``gradient`` and ``hessian`` of ``loss`` at ``params``.

        Returns:
            The ``(value, gradient, hessian)`` triple to continue with (possibly modified).
        """
        raise NotImplementedError

    @abstractmethod
    def minimize_nan(self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None) -> float:
        """Handle a NaN that occurred while minimizing ``loss`` with respect to ``params``.

        Args:
            loss: The loss being minimized.
            params: The parameters being minimized.
            values: Optional mapping with additional information about the current state.
        """
        raise NotImplementedError


class BaseStrategy(ZfitStrategy):
    """Default strategy: aborts on NaNs and passes callback values through unchanged.

    Attributes:
        fit_result: Initialized to ``None``; subclasses (e.g. ``ToyStrategyFail``) store a ``FitResult``
            describing a failed minimization here.
        error: Initialized to ``None``.
    """

    def __init__(self) -> None:
        self.fit_result = None
        self.error = None
        super().__init__()

    def minimize_nan(self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None) -> float:  # noqa: ARG002
        """Handle a NaN encountered during minimization by aborting.

        Raises:
            FailMinimizeNaN: Always.
        """
        raise FailMinimizeNaN()

    def callback(self, value, gradient, hessian, params, loss):
        """Return ``value``, ``gradient`` and ``hessian`` unchanged; ``params`` and ``loss`` are ignored."""
        del params, loss  # unused
        return value, gradient, hessian

    def __str__(self) -> str:
        # The bare class name (e.g. "BaseStrategy"). ``__name__`` excludes module and enclosing scopes, so this
        # also gives e.g. "PushbackStrategy" for classes created inside a factory function, without parsing repr().
        return type(self).__name__


class ToyStrategyFail(BaseStrategy):
    """Strategy that records an invalid :py:class:`FitResult` before aborting on NaNs.

    Like :py:class:`BaseStrategy`, ``minimize_nan`` raises :py:class:`FailMinimizeNaN`. Before raising, it stores
    an invalid, non-converged ``FitResult`` in ``self.fit_result`` so the failure can be inspected or counted
    (e.g. in toy studies).
    """

    def __init__(self) -> None:
        super().__init__()
        # Placeholder invalid result, available from construction on (``BaseStrategy`` leaves ``None``).
        self.fit_result = FitResult(
            params={},
            edm=None,
            fminopt=None,
            status=None,
            converged=False,
            info={},
            valid=False,
            message="NaN produced, ToyStrategy fails",
            niter=None,
            loss=None,
            minimizer=None,
            criterion=None,
        )

    def minimize_nan(self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None) -> float:
        """Store an invalid ``FitResult`` for ``params`` in ``self.fit_result`` and abort.

        Args:
            loss: Loss that was being minimized; stored in the result.
            params: Parameters that were being minimized. The result maps each parameter to the corresponding
                entry of ``np.asarray(params)``.
            values: Unused.

        Raises:
            FailMinimizeNaN: Always, after ``self.fit_result`` has been set.
        """
        del values  # unused
        # NOTE(review): assumes numpy converts the parameters to their current numeric values — confirm.
        param_vals = np.asarray(params)
        param_vals = dict(zip(params, param_vals, strict=True))
        self.fit_result = FitResult(
            params=param_vals,
            edm=None,
            fminopt=None,
            status=9,
            converged=False,
            info={},
            loss=loss,
            valid=False,
            message="Failed on too many NaNs",
            niter=None,
            criterion=None,
            minimizer=None,
        )
        raise FailMinimizeNaN()


def make_pushback_strategy(
    nan_penalty: float | int = 100,
    nan_tol: int = 30,
    base: type[ZfitStrategy] = BaseStrategy,
) -> type[ZfitStrategy]:
    """Create a strategy class that pushes the minimizer back when NaNs are encountered.

    Instances of the returned class replace a NaN loss by ``old_loss + nan_penalty * nan_counter``. Here
    ``nan_counter`` (read from the ``values`` mapping) is the number of NaNs that occurred in a row. The sign of
    the previous gradient is also flipped. Once ``nan_counter`` reaches ``nan_tol``, the ``minimize_nan`` of the
    next class in the MRO is used as the fallback. For the default ``BaseStrategy`` base, that fallback raises
    :py:class:`FailMinimizeNaN`.

    Args:
        nan_penalty: Value that is multiplied by the NaN counter and added to the previous loss to penalize
            the step taken.
        nan_tol: Once the number of NaNs encountered in a row reaches this number, the fallback is used.
        base: Strategy class the returned class inherits from.

    Returns:
        A new subclass of ``base`` named ``PushbackStrategy``.
    """

    class PushbackStrategy(base):
        """Pushback by adding ``nan_penalty * counter`` to the loss if NaNs are encountered.

        The counter indicates how many NaNs occurred in a row. When it reaches ``nan_tol``, the fallback of the
        base strategy is used.
        """

        def __init__(self):
            super().__init__()
            self.nan_penalty = nan_penalty
            self.nan_tol = nan_tol

        def minimize_nan(
            self, loss: ZfitLoss, params: ztyping.ParamTypeInput, values: Mapping | None = None
        ) -> tuple[float, np.ndarray | None] | None:
            """Return a penalized replacement for a NaN loss, or defer to the fallback after too many NaNs.

            Args:
                loss: The loss being minimized (only forwarded to the fallback).
                params: The parameters being minimized (only forwarded to the fallback).
                values: Mapping that must contain ``"nan_counter"``. It may also contain ``"old_loss"``,
                    ``"old_grad"`` and ``"loss"``.

            Returns:
                ``(loss_value, gradient)``, where the gradient is the negated ``"old_grad"`` (or ``None``).
                Returns ``None`` only if the fallback itself returns instead of raising.

            Raises:
                ValueError: If ``values`` is missing or lacks ``"nan_counter"``.
                RuntimeError: If no previous loss exists and ``values["loss"]`` is a string.
            """
            # Explicit check instead of ``assert``: asserts are stripped under ``python -O``, and ``in None``
            # would otherwise fail with an unrelated TypeError.
            if values is None or "nan_counter" not in values:
                msg = "'nan_counter' not in values, minimizer not correctly implemented"
                raise ValueError(msg)
            nan_counter = values["nan_counter"]

            if nan_counter >= self.nan_tol:
                # Too many NaNs in a row: defer to the next strategy in the MRO (BaseStrategy raises FailMinimizeNaN).
                super().minimize_nan(loss=loss, params=params, values=values)
                return None

            last_loss = values.get("old_loss")
            last_grad = values.get("old_grad")
            if last_grad is not None:
                # Flip the previous gradient, presumably to push the minimizer back toward the previous point.
                last_grad = -last_grad
            if last_loss is not None:
                # The penalty grows with every consecutive NaN.
                loss_evaluated = last_loss + self.nan_penalty * nan_counter
            else:
                loss_evaluated = values.get("loss")
            # NOTE(review): a string under "loss" appears to be a marker that the very first evaluation was
            # already NaN — confirm against the minimizer implementation.
            if isinstance(loss_evaluated, str):
                msg = "Loss starts already with NaN, cannot minimize."
                raise RuntimeError(msg)
            return loss_evaluated, last_grad

    return PushbackStrategy


# Default pushback strategy: the factory's default penalty and tolerance, falling back to BaseStrategy
# (which raises FailMinimizeNaN) once too many NaNs occur in a row.
PushbackStrategy = make_pushback_strategy(base=BaseStrategy)


class DefaultToyStrategy(PushbackStrategy, ToyStrategyFail):
    """Same as :py:class:`PushbackStrategy`, but does not raise an error on full failure; instead, it returns
    an invalid FitResult.

    Once too many NaNs occur in a row, the fallback of :py:class:`ToyStrategyFail` (next in the MRO) stores an
    invalid ``FitResult`` in ``fit_result`` and signals the failure with :py:class:`FailMinimizeNaN`.

    This can be useful for toy studies, where multiple fits are done and a failure should simply be counted
    as a failure instead of raising an error.
    """