"""MCMC samplers for Bayesian inference in zfit."""
# Copyright (c) 2025 zfit
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from .. import z
from .._interfaces import ZfitParameter
from ..z import numpy as znp
from .base_sampler import BaseMCMCSampler
if TYPE_CHECKING:
import emcee
import numpy.typing as npt
from zfit._bayesian.posterior import PosteriorSamples
from .._interfaces import ZfitLoss
[docs]
class EmceeSampler(BaseMCMCSampler):
"""MCMC sampler using emcee (https://emcee.readthedocs.io).
EmceeSampler is an ensemble sampler that uses multiple 'walkers' to explore
the posterior distribution. It's particularly good for problems with
strongly correlated parameters and doesn't require gradients of the
log-probability function.
This sampler requires all parameters to have priors defined, as it uses
the product of likelihood and prior for sampling.
Examples:
Basic usage with default settings:
>>> sampler = zfit.mcmc.EmceeSampler(nwalkers=32)
>>> result = sampler.sample(loss=nll, params=params, n_samples=1000)
With custom moves and settings:
>>> import emcee
>>> custom_moves = [(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)]
>>> sampler = zfit.mcmc.EmceeSampler(nwalkers=50, moves=custom_moves, verbosity=8)
>>> result = sampler.sample(loss=nll, params=params,
... n_samples=2000, n_warmup=500)
"""
def __init__(
    self,
    nwalkers: int | None = None,
    *,
    n_samples: int | None = None,
    n_warmup: int | None = None,
    moves: list[tuple[emcee.moves.Move, float]] | None = None,
    backend: emcee.backends.Backend | None = None,
    # NOTE: a ``pool`` argument is deliberately not exposed (open question whether it can be supported).
    name: str = "EmceeSampler",
    verbosity: int | None = None,
):
    """Create an ensemble sampler backed by emcee.

    The constructor only records the configuration; the emcee sampler itself
    is built when sampling is requested.

    Args:
        nwalkers: Number of walkers in the ensemble. When ``None``, the
            sampling step falls back to ``max(2 * n_dims, 5)`` with ``n_dims``
            the number of parameters. At least ``2 * n_dims`` walkers are
            recommended.
        n_samples: Default number of production steps, forwarded to the base
            sampler.
        n_warmup: Default number of warmup (burn-in) steps that are discarded,
            forwarded to the base sampler.
        moves: Proposal move(s) handed unchanged to ``emcee.EnsembleSampler``;
            either a single move or a list of ``(move, weight)`` tuples.
            ``None`` leaves the choice to emcee (its default is the
            StretchMove, see the emcee documentation).
        backend: emcee backend handed unchanged to ``emcee.EnsembleSampler``
            for storing the chain, e.g. for checkpointing long runs. ``None``
            leaves the choice to emcee (in-memory storage).
        name: Human-readable name of the sampler.
        verbosity: Verbosity level. Sampling phases are reported at level 7,
            progress bars are shown from level 8 on.

    Raises:
        ImportError: If emcee cannot be imported.
    """
    # Fail early, with installation instructions, when the optional dependency is missing.
    try:
        import emcee  # noqa: PLC0415, F401
    except ImportError as exc:
        raise ImportError("emcee is required for EmceeSampler. Install with 'pip install emcee'.") from exc

    super().__init__(n_samples=n_samples, n_warmup=n_warmup, name=name, verbosity=verbosity)

    # emcee-specific configuration, consumed when the ensemble sampler is created.
    self.nwalkers, self.moves, self.backend = nwalkers, moves, backend
def _sample(
    self,
    loss: ZfitLoss,
    params: list[ZfitParameter],
    n_samples: int,
    n_warmup: int,
    init: PosteriorSamples | None,
) -> PosteriorSamples:
    """Draw posterior samples with ``emcee.EnsembleSampler``.

    The log-probability is ``-loss.value()`` plus the sum of the parameters'
    log priors. Walkers start either from a previous result (``init``) or from
    random perturbations around the current parameter values.

    When ``init`` is given, the starting point is chosen in this order:

    1. the emcee state stored in ``init.info["state"]`` (used as-is when
       walker count and parameter order match, otherwise only its
       coordinates are reused, adapted and/or reordered),
    2. the last step of the chain of ``init.raw_result``,
    3. the flat samples of ``init``.

    Args:
        loss: Loss whose negated value is used as log-likelihood.
        params: Parameters to sample; each needs a ``prior``.
        n_samples: Number of production steps.
        n_warmup: Number of burn-in steps; their samples are discarded.
        init: Previous posterior samples to continue from, or ``None``.

    Returns:
        The flattened production chain wrapped in ``PosteriorSamples``.

    Raises:
        TypeError: If ``init`` is not a ``PosteriorSamples`` instance.
        ValueError: If the parameter names of ``init`` and ``params`` differ.

    Note:
        - The total number of samples in the result is n_samples * nwalkers
        - Sampling time scales roughly linearly with nwalkers, n_samples, and n_warmup
    """
    import emcee  # noqa: PLC0415

    import zfit  # noqa: PLC0415

    # Import here to avoid circular imports
    from zfit._bayesian.posterior import PosteriorSamples  # noqa: PLC0415

    n_dims = len(params)
    if (nwalkers := self.nwalkers) is None:
        nwalkers = max(2 * n_dims, 5)

    def calculate_priors(x):
        # Sum of the log prior densities, one per parameter, in `params` order.
        return znp.sum([param.prior.log_pdf(x[i]) for i, param in enumerate(params)])

    @z.function(wraps="tensor")
    def log_prob_jit(x):
        import zfit  # noqa: PLC0415

        zfit.param.assign_values_jit(params, x)
        # Log likelihood is the negative of the loss.
        log_likelihood = -loss.value()
        log_prior = calculate_priors(x)
        return log_likelihood + log_prior

    def log_prob(x):
        # emcee passes plain arrays; convert before entering the compiled function.
        return log_prob_jit(znp.asarray(x))

    def reorder_columns(positions):
        """Return ``(positions with columns in `params` order, whether a permutation was applied)``."""
        order = [init._position_by_name[param.name] for param in params]
        if order == list(range(n_dims)):
            return positions, False
        return positions[:, order], True

    using_emcee_state = False  # True only if the full stored emcee state is reused
    emcee_state = None
    pos = None

    if init is not None:
        if not isinstance(init, PosteriorSamples):
            msg = f"init must be a PosteriorSamples instance, not {type(init)}"
            raise TypeError(msg)

        init_param_names = set(init.param_names)
        current_param_names = {p.name for p in params}
        if init_param_names != current_param_names:
            msg = (
                f"Parameter names don't match. "
                f"Previous: {sorted(init_param_names)}, "
                f"Current: {sorted(current_param_names)}"
            )
            raise ValueError(msg)

        # 1) Prefer the state stored by a previous emcee run.
        stored_state = None
        if hasattr(init, "info") and init.info.get("type") == "emcee":
            stored_state = init.info.get("state")
        if stored_state is not None:
            self._print("Found previous emcee state - attempting optimized continuation", level=7)
            try:
                coords = getattr(stored_state, "coords", None)
                if coords is not None:
                    state_nwalkers, state_ndims = coords.shape[0], coords.shape[1]
                    if state_ndims != n_dims:
                        self._print(f"State dimension mismatch: {state_ndims} != {n_dims}", level=5)
                    elif state_nwalkers == nwalkers:
                        pos, reordered = reorder_columns(coords)
                        if reordered:
                            # The stored log_prob/random state belong to the old column
                            # order, so only the coordinates are reused.
                            self._print(
                                "Parameter order changed, reordering positions but not using full state",
                                level=7,
                            )
                        else:
                            emcee_state = stored_state
                            using_emcee_state = True
                            self._print(f"Using exact emcee state continuation with {nwalkers} walkers", level=7)
                    else:
                        # Different walker count: log_prob and random state are lost.
                        self._print(f"Adapting emcee state: {state_nwalkers} -> {nwalkers} walkers", level=7)
                        adapted = self._adapt_walker_positions(coords, nwalkers, n_dims)
                        pos, reordered = reorder_columns(adapted)
                        if reordered:
                            self._print("Parameter order changed, reordered positions after adaptation", level=7)
            except Exception as e:
                # Best effort: never continue with half-prepared (e.g. not yet
                # reordered) positions; reset so the fallbacks below take over.
                pos = None
                emcee_state = None
                using_emcee_state = False
                self._print(f"Could not use stored emcee state: {e}", level=5)

        # 2) / 3) Fall back to the raw chain, then to the flat samples.
        if not using_emcee_state and pos is None:
            raw_result = getattr(init, "raw_result", None)
            if raw_result is not None:
                try:
                    # chain has shape (n_steps, n_walkers, n_params); take the last step.
                    last_positions = raw_result.get_chain()[-1, :, :]
                    if last_positions.shape[0] != nwalkers:
                        last_positions = self._adapt_walker_positions(last_positions, nwalkers, n_dims)
                    pos, _ = reorder_columns(last_positions)
                except Exception as e:
                    self._print(f"Could not extract positions from raw sampler: {e}", level=5)
                    pos = self._extract_positions_from_samples(init, params, nwalkers, n_dims)
            else:
                pos = self._extract_positions_from_samples(init, params, nwalkers, n_dims)

        if using_emcee_state:
            self._print(f"Continuing from previous emcee state with {nwalkers} walkers", level=7)
        else:
            self._print(f"Initialized {nwalkers} walkers from previous posterior samples", level=7)
    else:
        # Scatter the walkers around the current parameter values, respecting bounds.
        initial_positions = np.array([param.value() for param in params])
        pos = np.zeros((nwalkers, n_dims))
        for i, param in enumerate(params):
            center = initial_positions[i]
            scale = param.stepsize
            perturbations = center + scale * np.random.randn(nwalkers)
            # Clip slightly inside the limits so no walker starts on a boundary.
            if param.has_limits:
                if param.lower is not None:
                    perturbations = np.maximum(perturbations, param.lower + 1e-8)
                if param.upper is not None:
                    perturbations = np.minimum(perturbations, param.upper - 1e-8)
            pos[:, i] = perturbations

    sampler = emcee.EnsembleSampler(
        nwalkers,
        n_dims,
        log_prob,
        moves=self.moves,
        backend=self.backend,
        vectorize=False,
    )

    # Evaluating the log-probability assigns values to the parameters; running
    # inside set_values is meant to bring them back afterwards (assumes the
    # context manager restores the previous values on exit).
    oldvals = np.array(params)
    with zfit.param.set_values(params, oldvals):
        show_progress = self.verbosity >= 8

        if n_warmup > 0:
            self._print(f"Running burn-in phase with {n_warmup} steps...", level=7)
            if using_emcee_state:
                # The full state object includes log_prob and random state.
                self._print("Starting from previous emcee state", level=7)
                state = sampler.run_mcmc(emcee_state, n_warmup, progress=show_progress)
            else:
                state = sampler.run_mcmc(pos, n_warmup, progress=show_progress)
            # Drop the burn-in samples; only the final state is carried over.
            sampler.reset()
        elif using_emcee_state:
            state = emcee_state
            self._print("Skipping burn-in, continuing from emcee state", level=7)
        else:
            state = pos
            self._print("Skipping burn-in", level=7)

        self._print(f"Running production phase with {n_samples} steps...", level=7)
        state = sampler.run_mcmc(state, n_samples, progress=show_progress)

    samples = sampler.get_chain(flat=True)
    return PosteriorSamples(
        info={"type": "emcee", "state": state, "sampler": sampler},
        samples=samples,
        params=params,
        loss=loss,
        sampler=self,
        n_warmup=n_warmup,
        n_samples=n_samples,
        raw_result=sampler,
    )
def _adapt_walker_positions(
    self,
    positions: npt.NDArray[np.float64],
    nwalkers: int,
    n_dims: int,
) -> npt.NDArray[np.float64]:
    """Adapt walker positions when the number of walkers changes.

    * Fewer (or equally many) walkers requested: a random subset of the
      existing walkers is returned, drawn without replacement.
    * More walkers requested: every existing walker is kept (first rows of the
      result, in their original order) and the missing ones are appended as
      randomly chosen copies of existing walkers with a small jitter, so that
      no distinct walker is lost and no two walkers coincide exactly.

    The input array is never modified. Randomness comes from NumPy's global
    random state.

    Args:
        positions: Array of shape (prev_nwalkers, n_dims) with walker positions.
        nwalkers: Target number of walkers.
        n_dims: Number of dimensions.

    Returns:
        Float array of shape (nwalkers, n_dims) with the adapted positions.
    """
    # Work in floating point so that jitter can be added to integer-typed input too.
    positions = np.asarray(positions, dtype=np.float64)
    prev_nwalkers = positions.shape[0]

    if nwalkers <= prev_nwalkers:
        keep = np.random.choice(prev_nwalkers, nwalkers, replace=False)
        return positions[keep]

    n_extra = nwalkers - prev_nwalkers
    donors = np.random.choice(prev_nwalkers, n_extra, replace=True)
    # Jitter relative to the per-dimension spread of the ensemble; the floor
    # avoids a zero scale when all walkers agree in a dimension.
    noise_scale = 1e-4 * np.maximum(np.std(positions, axis=0), 1e-8)
    extras = positions[donors] + noise_scale * np.random.randn(n_extra, n_dims)
    return np.concatenate([positions, extras], axis=0)
def _extract_positions_from_samples(
    self,
    init: PosteriorSamples,
    params: list[ZfitParameter],
    nwalkers: int,
    n_dims: int,
) -> npt.NDArray[np.float64]:
    """Build initial walker positions from flat posterior samples.

    If ``init`` holds at least ``nwalkers`` samples, the last ``nwalkers`` rows
    are used. Otherwise all samples are kept and the missing walkers are
    appended as randomly chosen copies with a small jitter that is scaled to
    the per-parameter spread of the samples (so the perturbation stays small
    for parameters of any magnitude). Finally the columns are permuted from
    the parameter order of ``init`` to the order of ``params``.

    Args:
        init: Previous result; its ``samples`` (rows are samples, columns are
            parameters) and ``_position_by_name`` (parameter name to column
            index) are read.
        params: Current parameter list; defines the column order of the result.
        nwalkers: Number of walkers needed.
        n_dims: Number of dimensions.

    Returns:
        Array of shape (nwalkers, n_dims) with initial positions.
    """
    samples = np.asarray(init.samples, dtype=np.float64)
    total_samples = samples.shape[0]

    if total_samples >= nwalkers:
        positions = samples[-nwalkers:, :]
    else:
        self._print(f"Resampling from {total_samples} samples to get {nwalkers} walkers", level=7)
        n_extra = nwalkers - total_samples
        donors = np.random.choice(total_samples, n_extra, replace=True)
        # Same jitter rule as used when adapting walker counts: relative to the
        # per-dimension spread, with a floor to avoid a zero scale.
        noise_scale = 1e-4 * np.maximum(np.std(samples, axis=0), 1e-8)
        extras = samples[donors] + noise_scale * np.random.randn(n_extra, n_dims)
        positions = np.concatenate([samples, extras], axis=0)

    # Map init's column order onto the order of `params`.
    column_order = [init._position_by_name[param.name] for param in params]
    if column_order != list(range(n_dims)):
        positions = positions[:, column_order]
    return positions