Source code for zfit.util.plotter

#  Copyright (c) 2025 zfit
from __future__ import annotations

import typing
from collections.abc import Callable, Mapping

import numpy as np

from zfit._interfaces import ZfitBinnedData, ZfitData, ZfitPDF, ZfitUnbinnedData

from ..core.space import convert_to_space
from . import ztyping
from .checks import RuntimeDependency
from .warnings import warn_experimental_feature

if typing.TYPE_CHECKING:
    import zfit  # noqa: F401

try:
    import matplotlib.pyplot as plt
except ImportError as error:
    plt = RuntimeDependency("plt", error_msg=str(error))

try:
    import mplhep
except ImportError as error:
    mplhep = RuntimeDependency("mplhep", error_msg=str(error))


[docs] def plot_sumpdf_components_pdfV1( model, *, plotfunc: Callable | None = None, scale=1, ax=None, linestyle=None, plotkwargs: Mapping[str, object] | None = None, extended: bool | None = None, ): """Plot the components of a sum pdf. Args: model: A zfit SumPDF. plotfunc: A plotting function that takes the `ax` to plot on x, y, and additional arguments. scale: An overall scale factor to apply to the components. ax: A matplotlib Axes object to plot on. linestyle: A linestyle to use for the components. Default is "--". plotkwargs: Additional keyword arguments to pass to the plotting function. extended: If True, plot extended components. If None, uses the model's extended state. """ import zfit # noqa: PLC0415 if not isinstance(model, zfit.pdf.SumPDF): msg = f"model must be a ZfitPDF, not a {type(model)}. Model is {model}." raise ValueError(msg) if linestyle is None: linestyle = "--" if plotkwargs is None: plotkwargs = {} if extended is None: extended = model.is_extended plotfunc = plot_model_pdf if plotfunc is None else plotfunc # Check if the SumPDF is automatically extended is_auto_extended = hasattr(model, "_automatically_extended") and model._automatically_extended # For automatically extended SumPDFs, we need to handle components differently if extended and is_auto_extended: for mod in model.pdfs: plotfunc(mod, scale=scale, ax=ax, linestyle=linestyle, extended=True, plotkwargs=plotkwargs) else: if extended: scale *= model.get_yield() # For non-extended or manually extended SumPDFs, use fractions # Force components to be non-extended and scale by fractions for mod, frac in zip(model.pdfs, model.params.values(), strict=True): plotfunc(mod, scale=frac * scale, ax=ax, linestyle=linestyle, extended=False, plotkwargs=plotkwargs) return ax
[docs] def plot_model_pdf( model: ZfitPDF, *, plotfunc: Callable | None = None, extended: bool | None = None, obs: ztyping.ObsTypeInput = None, scale: float | int | None = None, ax: plt.Axes | None = None, num: int | None = None, full: bool | None = None, linestyle=None, plotkwargs=None, ): """Plot the 1 dimensional density of a model, possibly scaled by the yield if extended. Args: model: An unbinned ZfitPDF. plotfunc: A plotting function that takes the ``ax`` to plot on, and x, y, and additional arguments. Default is ``ax.plot``. extended: If True, plot the extended pdf. If False, plot the pdf. obs: The observable to plot the pdf for. If None, the model's space is used. scale: An overall scale factor to apply to the pdf. ax: A matplotlib Axes object to plot on. num: The number of points to evaluate the pdf at. Default is 300. full: If True, set the x and y labels and the legend. Default is True. linestyle: A linestyle to use for the pdf. plotkwargs: Additional keyword arguments to pass to the plotting function. Returns: """ import zfit.z.numpy as znp # noqa: PLC0415 if not isinstance(model, ZfitPDF): msg = f"model must be a ZfitPDF, not a {type(model)}. Model is {model}." raise TypeError(msg) if scale is None: scale = 1 if num is None: num = 300 if full is None: full = True if plotkwargs is None: plotkwargs = {} if obs is None: obs = model.space if not obs.has_limits: msg = "Observables must have limits to be plotted. Either provide the limits with `obs` or use a model that has limits." raise ValueError(msg) else: obs = convert_to_space(obs) if not obs.has_limits: obs = model.space.with_obs(obs) if not obs.has_limits: msg = "Observables must have limits to be plotted. Either provide the limits with `obs` or use a model that has limits." raise ValueError(msg) if obs.n_obs != 1: msg = "obs must be 1D to be plotted." raise ValueError(msg) if model.space.n_obs != 1: if obs is None: msg = "1D space must be provided for multi-dimensional models to provide a 1D projection." raise ValueError(msg) model = model.create_projection_pdf(obs=obs, label=model.label) lower, upper = obs.v1.limits x = znp.linspace(lower, upper, num=num) y = model.ext_pdf(x) if extended else model.pdf(x) y *= scale if ax is None: ax = plt.gca() elif not isinstance(ax, plt.Axes): msg = "ax must be a matplotlib Axes object" raise ValueError(msg) plotfunc = ax.plot if plotfunc is None else plotfunc if "label" not in plotkwargs and full: plotkwargs["label"] = model.label plotfunc(x, y, linestyle=linestyle, **plotkwargs) if full: ax.set_xlabel(obs.label) ylabel = "Probability density" if not extended else "Extended probability density" ax.set_ylabel(ylabel) plt.legend() return ax
def assert_initialized(func): def wrapper(self, *args, **kwargs): if self.pdf is None: msg = "PDFPlotter is not initialized with a PDF." raise ValueError(msg) return func(self, *args, **kwargs) return wrapper class ZfitPDFPlotter: @warn_experimental_feature @assert_initialized def plotpdf( self, data: ZfitData | None = None, *, depth: int | None = None, density: bool | None = None, plotfunc: Callable | None = None, extended: bool | None = None, obs: ztyping.ObsTypeInput = None, scale: float | int | None = None, ax: plt.Axes | None = None, num: int | None = None, full: bool | None = None, linestyle=None, plotkwargs: Mapping[str, object] | None = None, histplotkwargs: Mapping[str, object] | None = None, ): """Plot the 1 dimensional density of the PDF, possibly scaled by the yield if extended. This is the main plotting method for PDFs in zfit. It provides a quick way to visualize the probability density function. Examples: Basic usage:: # Plot a simple PDF pdf.plot.plotpdf() # Plot extended PDF (scaled by yield) pdf.plot.plotpdf(extended=True) # Custom styling pdf.plot.plotpdf(color='red', linestyle='--', label='My PDF') For composite PDFs like SumPDF:: # Plot the sum sumpdf.plot.plotpdf() # Plot components sumpdf.plot.comp.plotpdf(linestyle='--') Args: data: An optional `ZfitData` object to plot alongside the PDF. If provided, the PDF will be scaled to match the data's normalization (either count or density). If `data` is unbinned, it will be binned automatically for plotting with 50 bins. To provide a custom binning, use ``pdf.plot(data=data.to_binned(...), ...)``. depth: The depth to plot if the PDF is made up of a Sum pdf with components. density: If True, the data will be plotted as a density histogram. If False, it will be plotted. as a count histogram. If `data` is provided and `extended` is True, `density` defaults to False. If `data` is provided and `extended` is False, `density` defaults to True. extended: If True, plot the extended pdf (multiplied by the yield). If False, plot the normalized pdf. If None, uses the PDF's `is_extended` property. obs: The observable to plot the pdf for. If None, the model's space is used. scale: An overall scale factor to apply to the pdf. Useful for plotting multiple PDFs with different normalizations. ax: A matplotlib Axes object to plot on. If None, uses the current axes (plt.gca()). num: The number of points to evaluate the pdf at. Default is 300. full: If True, set the x and y labels and the legend. Default is True. linestyle: A linestyle to use for the pdf (e.g., '-', '--', '-.', ':'). plotfunc: A plotting function that takes the `ax` to plot on, and x, y, and additional arguments. Default is `ax.plot`. plotkwargs: Additional keyword arguments to pass to the plotting function (e.g., color, alpha, linewidth, label). histplotkwargs: Additional keyword arguments to pass to `mplhep.histplot` when plotting data. Returns: matplotlib.axes.Axes: The matplotlib Axes object used for plotting. See Also: zfit.plot.plot_model_pdf: The underlying plotting function. SumPDF.plot.comp.plotpdf: For plotting components of composite PDFs. """ extended = self._preprocess_args_extended(extended) if depth is None: depth = 1 if scale is None: scale = 1 if data is None: if density is not None: msg = "Density argument is only supported when data is provided." raise ValueError(msg) if histplotkwargs is not None: msg = "histplotkwargs argument is only supported when data is provided." raise ValueError(msg) else: if density is None: density = not extended normalize = not extended ax, newscale = self._plot_scale_data( data, density=density, normalize=normalize, ax=ax, histplotkwargs=histplotkwargs ) scale *= newscale return self._plotpdf( depth=depth, plotfunc=plotfunc, extended=extended, obs=obs, scale=scale, ax=ax, num=num, full=full, linestyle=linestyle, plotkwargs=plotkwargs, ) def _plotpdf(self, **kwargs): raise NotImplementedError @property def comp(self): return None def __call__(self, data=None, **kwargs): return self.plotpdf(data=data, **kwargs) def _plot_scale_data( self, data: ZfitData, density=None, normalize=None, ax=None, histplotkwargs=None ) -> (plt.Axes, float): """Plots the scaled data. This method plots the given data with optional density and normalization. Args ---------- data: The ZfitBinnedData object to be plotted. density: If True, the plot will show the density of the data. Defaults to False. normalize: If True, the plot will normalize the data. Defaults to True. ax: The matplotlib axes object to plot on. If None, a new axes object will be created. Defaults to None. Notes ----- The plot will be scaled based on the provided normalization. The density of the data will be displayed if the density parameter is set to True. """ import zfit.z.numpy as znp # noqa: PLC0415 if histplotkwargs is None: histplotkwargs = {} if density is None: density = False if normalize is None: normalize = True if ax is None: ax = plt.gca() elif not isinstance(ax, plt.Axes): msg = "ax must be a matplotlib Axes object" raise ValueError(msg) if not isinstance(data, ZfitBinnedData) and isinstance(data, ZfitUnbinnedData): data = data.to_binned(50) values = data.values() binwidths = np.prod(data.binning.widths, axis=0) edges = data.binning.edges errors = None if (variances := data.variances()) is not None: errors = variances**0.5 scale = 1 nvals = None if density or normalize: nvals = znp.sum(values) if density: values /= binwidths if errors is not None: errors /= binwidths else: scale *= np.mean(binwidths) # converting the PDF density to counts if normalize: values /= nvals if errors is not None: errors /= nvals # plot values mplhep.histplot((values, edges), yerr=errors, **histplotkwargs, label=data.label, ax=ax) return ax, scale def _preprocess_args_extended(self, extended): if extended is None: extended = self.pdf.is_extended if extended and not self.pdf.is_extended: msg = "Provided extended as argument for plotting, but pdf is not extended." raise ValueError(msg) return extended class PDFPlotter(ZfitPDFPlotter): def __init__( self, pdf: ZfitPDF | None, pdfplotter: Callable | None = None, componentplotter: ZfitPDFPlotter = None, defaults: Mapping[str, object] | None = None, ): self.defaults = {} if defaults is None else defaults self.pdf = pdf if pdfplotter is not None and not callable(pdfplotter): msg = f"pdfplotter must be a callable, is {type(pdfplotter)}." raise TypeError(msg) self._pdfplotter = plot_model_pdf if pdfplotter is None else pdfplotter if componentplotter is not None and not isinstance(componentplotter, ZfitPDFPlotter): msg = f"componentplotter must be a ZfitPDFPlotter, is {type(componentplotter)}." raise TypeError(msg) self._componentplotter = componentplotter def _plotpdf(self, depth: int | None = None, **kwargs): if depth is None: depth = 1 kwargs |= self.defaults ax = plot_model_pdf(self.pdf, **kwargs) _ = kwargs.pop("ax", None) if kwargs.get("linestyle") is None: kwargs["linestyle"] = ":" if depth and self.comp is not None: return self.comp(depth=depth - 1, ax=ax, **kwargs) return ax @property @assert_initialized def comp(self): return self._componentplotter class SumCompPlotter(ZfitPDFPlotter): def __init__( self, pdf: ZfitPDF | None, *args, **kwargs, ): if pdf is not None and not isinstance(pdf, ZfitPDF): msg = f"pdf must be a ZfitPDF, is {type(pdf)}." raise TypeError(msg) self.pdf = pdf super().__init__(*args, **kwargs) def _plotpdf(self, data=None, *, depth: int | None = None, **kwargs): # noqa: ARG002 import zfit # noqa: PLC0415 if not isinstance(pdf := self.pdf, zfit.pdf.SumPDF): # we can relax this later with duck typing msg = f"pdf must be a SumPDF, is {type(pdf)}." raise TypeError(msg) assert isinstance(pdf, zfit.pdf.SumPDF), "pdf must be a SumPDF" scale = kwargs.pop("scale", None) if scale is None: scale = 1 if depth is None: depth = 0 if depth < 0: ax = kwargs.get("ax") if ax is None: msg = ( "ax is None. Either there is an issue with the depth argument or an internal error. " "Make sure `depth` is at least 0, if that's the case, please open a bug report " "with zfit." ) raise RuntimeError(msg) return ax if kwargs.pop("extended", False): scale *= pdf.get_yield() kwargs["extended"] = False # we manually scale the components, this should always hold assert len(pdf.pdfs) > 0, "INTERNAL ERROR: pdfs cannot be empty" values = pdf.params.values() assert len(values) > 0, "INTERNAL ERROR: values cannot be empty" for mod, frac in zip(pdf.pdfs, values, strict=True): ax = mod.plot.plotpdf(scale=frac * scale, **kwargs, depth=depth - 1) return ax