"""Posterior class for Bayesian inference results in zfit."""
# Copyright (c) 2025 zfit
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import zfit
from zfit.util.container import convert_to_container, is_container
if TYPE_CHECKING:
from collections.abc import Iterable
import arviz as az
import numpy.typing as npt
import pandas as pd
from zfit.mcmc import MCMCSampler
from .._interfaces import ZfitLoss, ZfitParameter
from .priors import KDE
[docs]
class PosteriorSamples:
    """Posterior samples from an MCMC sampler together with summaries and diagnostics.

    The samples are stored as a 2D array with one column per sampled parameter. Parameters can be
    selected by name, by parameter object, or by integer position in :py:attr:`params`.
    """

    def __init__(
        self,
        samples: npt.NDArray[np.float64],
        params: Iterable[ZfitParameter],
        loss: ZfitLoss,
        sampler: MCMCSampler,
        n_warmup: int,
        n_samples: int,
        raw_result: object | None = None,
        info: dict | None = None,
    ):
        """Posterior samples from MCMC bayesian inference.

        Args:
            samples: Array of shape (n_total_samples, n_params) containing the MCMC samples (all
                walkers); column ``i`` belongs to the ``i``-th entry of ``params``.
            params: Parameters that were sampled. Names must be unique.
            loss: The ZfitLoss that was sampled.
            sampler: The sampler that generated the samples.
            n_warmup: Number of warmup/burn-in steps.
            n_samples: Number of posterior samples per walker.
            raw_result: Raw result from the sampler.
            info: Additional information dictionary.

        Raises:
            ValueError: If ``n_warmup`` is negative or not an integer, ``n_samples`` is not a positive
                integer, ``samples`` is not a non-empty 2D array, ``params`` is empty or contains
                duplicate names, or the number of sample columns differs from the number of params.
        """
        if not isinstance(n_warmup, (int, np.integer)) or n_warmup < 0:
            msg = f"n_warmup must be a non-negative integer, got {n_warmup}"
            raise ValueError(msg)
        if not isinstance(n_samples, (int, np.integer)) or n_samples <= 0:
            msg = f"n_samples must be a positive integer, got {n_samples}"
            raise ValueError(msg)
        if info is None:
            info = {}

        self.samples = np.asarray(samples)
        if self.samples.ndim != 2:
            msg = f"samples must be a 2D array, got shape {self.samples.shape}"
            raise ValueError(msg)
        if len(self.samples) == 0:
            msg = "samples cannot be empty"
            raise ValueError(msg)

        self._params = convert_to_container(params)
        if not self._params:
            msg = "params cannot be empty"
            raise ValueError(msg)
        if self.samples.shape[1] != len(self._params):
            msg = f"Number of parameters in samples ({self.samples.shape[1]}) does not match number of parameters ({len(self._params)})"
            raise ValueError(msg)
        names = [param.name for param in self._params]
        # The lookups below are keyed by name; duplicates would silently map to the wrong column.
        if len(set(names)) != len(names):
            msg = f"Parameter names must be unique, got {names}"
            raise ValueError(msg)

        self._loss = loss
        self._sampler = sampler
        self.n_warmup = n_warmup
        self.n_samples = n_samples
        self.raw_result = raw_result
        self.info = info

        # name -> param, param -> name and name -> column position, for O(1) lookups.
        self._param_by_name = dict(zip(names, self._params, strict=True))
        self._name_by_param = {param: param.name for param in self._params}
        self._position_by_name = {name: i for i, name in enumerate(names)}

        self._compute_convergence_diagnostics()

    # Core statistical methods

    def mean(
        self, params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None
    ) -> float | npt.NDArray[np.float64]:
        """Posterior mean(s).

        Args:
            params: Parameter name, object, index, or list thereof. If None, return all means.

        Returns:
            Mean value(s).
            - Single parameter: returns float.
            - Collection of parameters: returns array.

        Raises:
            ValueError: If there are no samples, ``params`` is an empty collection or a
                parameter is unknown.
        """
        if len(self.samples) == 0:
            msg = "Cannot compute mean of empty samples"
            raise ValueError(msg)
        positions, was_container = self._resolve_params(
            params, "No valid parameters specified for mean calculation"
        )
        means = np.mean(np.asarray(self.samples)[:, positions], axis=0)
        # Single param not in container -> scalar, collection -> array
        return means if was_container else float(means[0])

    def symerr(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None,
        *,
        sigma: float | None = None,
    ) -> float | npt.NDArray[np.float64]:
        """Symmetric error (a multiple of the posterior standard deviation).

        Args:
            params: Parameter name, object, index, or list thereof. If None, return all errors.
            sigma: Number of standard deviations. Default is 1; e.g. ``sigma=2`` returns
                two standard deviations. Must be in (0, 20].

        Returns:
            ``sigma * std(params)`` as float (single parameter) or array (collection).

        Raises:
            TypeError: If ``sigma`` cannot be converted to float.
            ValueError: If ``sigma`` is not positive (including NaN) or larger than 20.
        """
        if sigma is None:
            sigma = 1
        try:
            sigma = float(sigma)
        except (TypeError, ValueError) as error:
            msg = f"sigma must be convertible to float, got {sigma}"
            raise TypeError(msg) from error
        # `not sigma > 0` (rather than `sigma <= 0`) also rejects NaN.
        if not sigma > 0:
            msg = f"sigma must be positive, got {sigma}"
            raise ValueError(msg)
        if sigma > 20:
            msg = f"A sigma value of {sigma} is larger than 20. This is not a realistic value and most likely a bug."
            raise ValueError(msg)
        return sigma * self.std(params=params)

    def std(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None,
    ) -> float | npt.NDArray[np.float64]:
        """Standard deviation of posterior samples (``np.std`` with its default ``ddof=0``).

        Args:
            params: Parameter name, object, index, or list thereof. If None, return all stds.

        Returns:
            Standard deviation(s).
            - Single parameter: returns float.
            - Collection of parameters: returns array.

        Examples:
            >>> result.std()  # All parameters
            array([0.102, 0.234])
            >>> result.std(['mu', 'sigma'])  # Multiple parameters
            array([0.102, 0.234])
            >>> result.std('mu')  # Single parameter
            0.102
            >>> result.std(['mu'])  # Single parameter in list
            array([0.102])
        """
        if len(self.samples) == 0:
            msg = "Cannot compute standard deviation of empty samples"
            raise ValueError(msg)
        positions, was_container = self._resolve_params(
            params, "No valid parameters specified for std calculation"
        )
        stds = np.std(np.asarray(self.samples)[:, positions], axis=0)
        return stds if was_container else float(stds[0])

    def credible_interval(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None,
        *,
        alpha: float | None = None,
        sigma: float | None = None,
    ) -> tuple[float | npt.NDArray[np.float64], float | npt.NDArray[np.float64]]:
        """Equal-tailed credible interval(s) from sample percentiles.

        Args:
            params: Parameter name, object, index, or list thereof. If None, return all intervals.
            alpha: Significance level in (0, 1). Default is 0.05 for a 95% interval.
            sigma: Number of standard deviations (e.g., 1 for ~68%, 2 for ~95%), converted to the
                equivalent two-tailed normal ``alpha``. Mutually exclusive with ``alpha``. For very
                large sigma the tail probability underflows and the interval becomes the sample
                minimum and maximum.

        Returns:
            Tuple (lower, upper).
            - Single parameter: returns tuple of floats.
            - Collection of parameters: returns tuple of arrays.

        Raises:
            TypeError: If ``alpha``/``sigma`` cannot be converted to float.
            ValueError: If both are given, ``sigma`` is not positive, or ``alpha`` is outside (0, 1).
        """
        import scipy.stats  # noqa: PLC0415

        if sigma is not None and alpha is not None:
            msg = "Cannot specify both sigma and alpha. Choose one."
            raise ValueError(msg)
        if sigma is not None:
            try:
                sigma = float(sigma)
            except (TypeError, ValueError) as error:
                msg = f"sigma must be convertible to float, got {sigma}"
                raise TypeError(msg) from error
            if not sigma > 0:  # also rejects NaN
                msg = f"sigma must be positive, got {sigma}"
                raise ValueError(msg)
            # Two-tailed normal probability outside +-sigma. `sf` avoids the cancellation of
            # `1 - cdf`, which rounds to exactly 0 for sigma above roughly 8.
            alpha = 2 * float(scipy.stats.norm.sf(sigma))
        elif alpha is None:
            alpha = 0.05
        else:
            try:
                alpha = float(alpha)
            except (TypeError, ValueError) as error:
                msg = f"alpha must be convertible to float, got {alpha}"
                raise TypeError(msg) from error
            if not 0 < alpha < 1:
                msg = f"alpha must be between 0 and 1, got {alpha}"
                raise ValueError(msg)

        lower_percentile = 100 * alpha / 2
        upper_percentile = 100 * (1 - alpha / 2)

        positions, was_container = self._resolve_params(
            params, "No parameters provided for credible interval calculation"
        )
        selected = np.asarray(self.samples)[:, positions]
        lowers = np.percentile(selected, lower_percentile, axis=0)
        uppers = np.percentile(selected, upper_percentile, axis=0)
        if not was_container:
            return float(lowers[0]), float(uppers[0])
        return lowers, uppers

    def get_samples(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None,
    ) -> npt.NDArray[np.float64]:
        """Get posterior samples.

        Args:
            params: Parameter name, object, index, or list thereof. If None, return all samples.

        Returns:
            Array of samples.
            - Single parameter: returns 1D array.
            - Collection of parameters: returns 2D array with shape (n_samples, n_params).
        """
        if len(self.samples) == 0:
            msg = "Cannot get samples from empty posterior"
            raise ValueError(msg)
        positions, was_container = self._resolve_params(params, "No valid parameters specified")
        samples_np = np.asarray(self.samples)
        if was_container:
            return samples_np[:, positions]
        return samples_np[:, positions[0]]

    def as_prior(self, param: str | ZfitParameter | int) -> KDE:
        """Get the posterior of one parameter as a KDE prior for hierarchical modeling.

        Args:
            param: Parameter name, object, or index (a single-element collection is unwrapped).

        Returns:
            KDE prior created from the 1D posterior samples, named ``"<param name>_posterior_prior"``.

        Raises:
            ValueError: If ``param`` is None, an empty collection, or more than one parameter.
        """
        if param is None:
            msg = "param cannot be None. Must specify a single parameter."
            raise ValueError(msg)
        if is_container(param):
            param = convert_to_container(param)
            if len(param) > 1:
                msg = "as_prior() only supports single parameter, not multiple parameters"
                raise ValueError(msg)
            if not param:
                msg = "No valid parameters specified"
                raise ValueError(msg)
            # Unwrap so that the samples are 1D and the name comes from the parameter itself.
            (param,) = param

        (position,) = self._get_param_positions(param)
        samples = self.get_samples(position)

        # Import here to avoid circular imports
        from .priors import KDE  # noqa: PLC0415

        param_name = self._params[position].name
        return KDE(samples, name=f"{param_name}_posterior_prior")

    # ArviZ integration

    @staticmethod
    def _import_arviz(feature: str):
        """Import ArviZ, raising an ImportError naming ``feature`` if it is not installed."""
        try:
            import arviz as az  # noqa: PLC0415
        except ImportError as error:
            msg = f"ArviZ is required for {feature}(). Install with 'pip install arviz'."
            raise ImportError(msg) from error
        return az

    def to_arviz(self) -> az.InferenceData:
        """Convert to ArviZ InferenceData format.

        The samples are split into ``sampler.nwalkers`` chains if the sampler exposes that
        attribute and the number of rows is divisible by it; otherwise a single chain is used
        (with a warning when the split is impossible).

        Returns:
            ArviZ InferenceData object for advanced analysis.

        Raises:
            ImportError: If ArviZ is not installed.
        """
        az = self._import_arviz("to_arviz")

        samples_np = np.asarray(self.samples)
        total_samples, n_params = samples_np.shape

        nwalkers = getattr(self._sampler, "nwalkers", None)
        nwalkers = 1 if nwalkers is None else int(nwalkers)
        if nwalkers < 1 or total_samples % nwalkers != 0:
            import warnings  # noqa: PLC0415

            warnings.warn(
                f"Cannot split {total_samples} samples evenly into {nwalkers} chains; treating them as a single chain.",
                stacklevel=2,
            )
            nwalkers = 1
        ndraws = total_samples // nwalkers

        # Explicit n_params (instead of -1) so a mismatched split can never silently
        # produce a wrong parameter axis.
        # NOTE(review): assumes rows are stored chain-major (all draws of walker 0, then walker 1,
        # ...). If the sampler flattens step-major (as emcee's get_chain(flat=True) does), this
        # reshape mixes walkers and R-hat is meaningless — confirm against the sampler.
        samples_reshaped = samples_np.reshape(nwalkers, ndraws, n_params)

        return az.from_dict(
            {param.name: samples_reshaped[:, :, i] for i, param in enumerate(self._params)},
            coords={"chain": range(nwalkers), "draw": range(ndraws)},
        )

    # Parameter management

    def update_params(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None,
        *,
        what: str | None = None,
    ) -> PosteriorSamples:
        """Set parameters to a posterior point estimate.

        Args:
            params: Parameters to update (name, object, index, or list thereof). If None, all
                sampled parameters are updated.
            what: Which estimate to use. Only ``"mean"`` (the default) is supported.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If ``what`` is not a supported option.
        """
        if what is None:
            what = "mean"
        valid_options = ["mean"]  # Can be extended in future
        if what not in valid_options:
            msg = f"Invalid 'what' option: {what}. Valid options are: {valid_options}"
            raise ValueError(msg)
        if what == "mean":
            return self._set_params_to_mean(params)
        msg = "This should never be reached, internal error."
        raise AssertionError(msg)

    def _set_params_to_mean(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None,
    ) -> PosteriorSamples:
        """Set the selected parameters (all if None) to their posterior means and return ``self``."""
        if params is None:
            params = self._params
        params = convert_to_container(params)
        positions = self._get_param_positions(params)
        # `params` is a container, so `mean` returns an array (and raises for an empty selection).
        means = np.atleast_1d(self.mean(params))
        for position, mean_val in zip(positions, means, strict=True):
            self._params[position].set_value(float(mean_val))
        return self

    def __enter__(self):
        """Context manager: set all parameters to their posterior means."""
        self._old_values = [param.value() for param in self._params]
        self.update_params()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager: restore the parameter values saved in ``__enter__``."""
        from zfit.core.parameter import set_values  # noqa: PLC0415

        set_values(self._params, self._old_values)

    # Essential properties

    @property
    def params(self) -> list[ZfitParameter]:
        """Parameters used in the sampling."""
        return self._params

    @property
    def param_names(self) -> list[str]:
        """Names of the parameters used in the sampling."""
        return [param.name for param in self._params]

    @property
    def sampler(self) -> MCMCSampler:
        """Sampler used to generate samples."""
        return self._sampler

    @property
    def loss(self) -> ZfitLoss:
        """Loss function that was sampled."""
        return self._loss

    @property
    def valid(self) -> bool:
        """Whether the MCMC results are valid (no NaN/inf values)."""
        return self._valid

    @property
    def converged(self) -> bool:
        """Whether the MCMC chains have converged based on diagnostics.

        Convergence requires:
        - No NaN or infinite values
        - R-hat < 1.1 for all parameters
        - Effective sample size > 100 for all parameters

        It is False when the diagnostics cannot be computed (invalid samples or ArviZ missing).
        """
        return self._converged

    def covariance(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None = None,
    ) -> float | npt.NDArray[np.float64]:
        """Covariance matrix from posterior samples (unbiased, ``ddof=1``).

        Args:
            params: Parameters to include. If None, use all parameters.

        Returns:
            - Single parameter: returns scalar (variance).
            - Collection of parameters: returns 2D covariance matrix.

        Raises:
            ValueError: If fewer than 2 samples exist or the selection is empty/unknown.
        """
        if len(self.samples) == 0:
            msg = "Cannot compute covariance of empty samples"
            raise ValueError(msg)
        if len(self.samples) < 2:
            msg = f"Need at least 2 samples to compute covariance, got {len(self.samples)}"
            raise ValueError(msg)
        positions, was_container = self._resolve_params(
            params, "No valid parameters specified for covariance calculation"
        )
        selected = np.asarray(self.samples)[:, positions]
        if not was_container:
            return float(np.var(selected[:, 0], ddof=1))
        # np.cov returns a 0-d array for one column; keep the matrix shape consistent.
        return np.atleast_2d(np.cov(selected, rowvar=False))

    # Utility methods

    def summary(self, round_to: int | None = None) -> pd.DataFrame:
        """Summary statistics computed by ArviZ.

        Args:
            round_to: Number of decimals to round to. If None, no rounding.

        Returns:
            ArviZ summary DataFrame.

        Raises:
            ImportError: If ArviZ is not installed.
        """
        az = self._import_arviz("summary")
        idata = self.to_arviz()
        # ArviZ substitutes its own default rounding for None; "none" requests raw numbers.
        return az.summary(idata, round_to="none" if round_to is None else round_to)

    def _resolve_params(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int] | None,
        empty_msg: str,
    ) -> tuple[list[int], bool]:
        """Map a parameter selection to column positions plus "was a collection requested".

        ``None`` selects all parameters and counts as a collection. Raises ValueError with
        ``empty_msg`` if an empty collection is given.
        """
        if params is None:
            return list(range(len(self._params))), True
        # Must be evaluated before the selection is consumed (it may be an iterator).
        was_container = is_container(params)
        positions = self._get_param_positions(params)
        if not positions:
            raise ValueError(empty_msg)
        return positions, was_container

    def _get_param_positions(
        self,
        params: str | ZfitParameter | int | Iterable[str | ZfitParameter | int],
    ) -> list[int]:
        """Get parameter positions in the samples array.

        Args:
            params: Single parameter or collection of parameters, given as name(s), parameter
                object(s) or integer position(s) (negative positions count from the end).

        Returns:
            List of parameter positions in the samples array.

        Raises:
            TypeError: For a dict or an element of unsupported type.
            ValueError: For an unknown name or parameter object.
            IndexError: For an out-of-range integer position.
        """
        if isinstance(params, dict):
            msg = f"Invalid parameter type: {type(params)}"
            raise TypeError(msg)

        n_params = len(self._params)
        positions = []
        for param in convert_to_container(params):
            if isinstance(param, str):
                if param not in self._position_by_name:
                    msg = f"Parameter '{param}' not found"
                    raise ValueError(msg)
                positions.append(self._position_by_name[param])
            elif isinstance(param, zfit.Parameter):
                if param not in self._name_by_param:
                    msg = f"Parameter {param} not found in posterior samples"
                    raise ValueError(msg)
                positions.append(self._position_by_name[param.name])
            elif isinstance(param, (int, np.integer)) and not isinstance(param, bool):
                index = int(param)
                if not -n_params <= index < n_params:
                    msg = f"Parameter index {index} out of range for {n_params} parameters"
                    raise IndexError(msg)
                positions.append(index % n_params)
            else:
                msg = (
                    f"Invalid parameter type: {type(param)}. Expected string (parameter name), "
                    "integer index or ZfitParameter object."
                )
                raise TypeError(msg)
        return positions

    def __repr__(self) -> str:
        return f"PosteriorSamples(n_samples={len(self.samples)}, params={[param.name for param in self._params]})"

    def __str__(self) -> str:
        """Colored, tabulated report of diagnostics and per-parameter statistics."""
        import colored  # noqa: PLC0415
        from colorama import Style  # noqa: PLC0415
        from tabulate import tabulate  # noqa: PLC0415

        string = Style.BRIGHT + "PosteriorSamples" + Style.NORMAL + f" from\n{self.loss} \nwith\n{self.sampler}\n\n"

        def color_on_bool(value, on_true=None, on_false=None):
            """Color boolean values.

            Args:
                value: Boolean value to color
                on_true: Color for True values. Defaults to green background.
                on_false: Color for False values. Defaults to red background; ``False`` disables it.
            """
            if on_true is None:
                on_true = colored.bg("green")
            if on_false is None:
                on_false = colored.bg("red")
            if on_false is False:
                on_false = ""
            text = "True" if value else "False"
            color = on_true if value else on_false
            return f"{color}{text}{Style.RESET_ALL}"

        # Main diagnostics table; R-hat/ESS are None when samples are invalid or ArviZ is missing.
        rhat_str = "N/A"
        if self._rhat is not None:
            max_rhat = np.max(self._rhat)
            rhat_str = f"{max_rhat:.4f}"
            if max_rhat > 1.1:
                rhat_str = colored.fg("red") + rhat_str + Style.RESET_ALL
        ess_str = "N/A"
        if self._ess is not None:
            min_ess = np.min(self._ess)
            ess_str = f"{min_ess:.0f}"
            if min_ess < 100:
                ess_str = colored.fg("red") + ess_str + Style.RESET_ALL

        string += tabulate(
            [
                [
                    color_on_bool(self.valid),
                    color_on_bool(self.converged, on_true=colored.bg("green"), on_false=colored.bg("yellow")),
                    rhat_str,
                    ess_str,
                    f"{len(self.samples):>13} | {self.n_warmup:>6} | {self.n_samples:>10}",
                ]
            ],
            [
                "valid",
                "converged",
                "max R̂",
                "min ESS",
                "total samples | warmup | per walker",
            ],
            tablefmt="fancy_grid",
            disable_numparse=True,
            colalign=["center", "center", "center", "center", "right"],
        )

        # Parameters table
        string += "\n\n" + Style.BRIGHT + "Parameters\n" + Style.NORMAL
        means = self.mean()
        stds = self.std()
        lower, upper = self.credible_interval(alpha=0.05)
        lower = np.atleast_1d(lower)
        upper = np.atleast_1d(upper)

        # Pick a column width for the interval bounds from their largest magnitude.
        max_abs_ci = max(abs(float(val)) for val in (*lower, *upper))
        if max_abs_ci >= 1000:
            ci_width = 10  # For large numbers like n_sig, n_bkg
        elif max_abs_ci >= 10:
            ci_width = 8  # For moderate numbers
        else:
            ci_width = 7  # For small numbers like mu, sigma

        param_data = []
        for i, param in enumerate(self._params):
            rhat_param = f"{self._rhat[i]:.3f}" if self._rhat is not None else "N/A"
            ess_param = f"{self._ess[i]:.0f}" if self._ess is not None else "N/A"
            param_data.append(
                [
                    param.name,
                    f"{means[i]:>8.4f}",
                    f"± {stds[i]:.4f}",
                    f"[{lower[i]:>{ci_width}.4f}, {upper[i]:>{ci_width}.4f}]",
                    rhat_param,
                    ess_param,
                ]
            )
        string += tabulate(
            param_data,
            headers=["parameter", "mean", "std", "95% CI", "R̂", "ESS"],
            tablefmt="simple",
            floatfmt=".4f",
            colalign=["left", "right", "right", "right", "right", "right"],
        )

        nwalkers = getattr(self._sampler, "nwalkers", None)
        if nwalkers is not None:
            string += f"\n\nSampler: {self._sampler.__class__.__name__} with {nwalkers} walkers"
        return string

    def _repr_pretty_(self, p, cycle):
        """IPython/Jupyter pretty display."""
        if cycle:
            p.text(self.__repr__())
            return
        p.text(self.__str__())

    def _compute_convergence_diagnostics(self):
        """Compute validity, R-hat, ESS and the convergence flag.

        R-hat/ESS stay None (and ``converged`` False) if the samples contain NaN/inf or if the
        optional ArviZ dependency is not installed.
        """
        self._valid = bool(np.all(np.isfinite(self.samples)))
        self._converged = False
        self._rhat = None
        self._ess = None
        if not self._valid:
            return

        try:
            import arviz as az  # noqa: PLC0415
        except ImportError:
            # ArviZ is optional: without it convergence cannot be established.
            return

        idata = self.to_arviz()
        rhat_data = az.rhat(idata)
        ess_data = az.ess(idata)
        self._rhat = np.array([rhat_data[param.name].values for param in self._params])
        self._ess = np.array([ess_data[param.name].values for param in self._params])

        # NaN diagnostics compare False and therefore count as not converged.
        rhat_converged = bool(np.all(self._rhat < 1.1))
        ess_converged = bool(np.all(self._ess > 100))
        self._converged = rhat_converged and ess_converged

    @property
    def rhat(self) -> npt.NDArray[np.float64] | None:
        """R-hat convergence diagnostic per parameter, as computed by ``arviz.rhat``.

        Values < 1.1 indicate good convergence. None if the samples contain NaN/inf or
        ArviZ is not installed.
        """
        return self._rhat

    @property
    def ess(self) -> npt.NDArray[np.float64] | None:
        """Effective sample size per parameter, as computed by ``arviz.ess``.

        Accounts for autocorrelation in MCMC chains; higher values indicate more independent
        samples. None if the samples contain NaN/inf or ArviZ is not installed.
        """
        return self._ess

    def convergence_summary(self) -> dict:
        """Summary of convergence diagnostics as plain dictionaries.

        Raises:
            ImportError: If ArviZ is not installed.
        """
        az = self._import_arviz("convergence_summary")
        idata = self.to_arviz()
        return {
            "valid": self._valid,
            "converged": self._converged,
            "rhat": az.rhat(idata).to_dict(),
            "ess_bulk": az.ess(idata, method="bulk").to_dict(),
            "ess_tail": az.ess(idata, method="tail").to_dict(),
            "mcse_mean": az.mcse(idata, method="mean").to_dict(),
            "mcse_sd": az.mcse(idata, method="sd").to_dict(),
        }

    def diagnostics(self) -> dict:
        """Comprehensive diagnostics report.

        Returns:
            Dictionary with all available diagnostics (ArviZ objects).

        Raises:
            ImportError: If ArviZ is not installed.
        """
        az = self._import_arviz("diagnostics")
        idata = self.to_arviz()
        return {
            "valid": self.valid,
            "converged": self.converged,
            "summary": az.summary(idata),
            "rhat": az.rhat(idata),
            "ess_bulk": az.ess(idata, method="bulk"),
            "ess_tail": az.ess(idata, method="tail"),
            "mcse_mean": az.mcse(idata, method="mean"),
            "mcse_sd": az.mcse(idata, method="sd"),
            "loo": None,  # Placeholder for future LOO-CV integration
        }