Source code for zfit.core.data

#  Copyright (c) 2024 zfit

from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Literal, Optional, Union

import pydantic
import xxhash
from pydantic import Field
from tensorflow.python.util.deprecation import deprecated, deprecated_args

from ..serialization import SpaceRepr
from ..serialization.serializer import BaseRepr, to_orm_init
from .parameter import set_values
from .serialmixin import SerializableMixin, ZfitSerializable

if TYPE_CHECKING:
    import zfit

from collections import OrderedDict
from collections.abc import Callable, Mapping

import numpy as np
import pandas as pd
import tensorflow as tf
import uproot

import zfit
import zfit.z.numpy as znp

from .. import z
from ..settings import run, ztypes
from ..util import ztyping
from ..util.cache import GraphCachable, invalidate_graph
from ..util.container import convert_to_container
from ..util.exception import (
    ObsIncompatibleError,
    ShapeIncompatibleError,
    WorkInProgressError,
)
from ..util.temporary import TemporarilySet
from .baseobject import BaseObject
from .coordinates import convert_to_obs_str
from .dimension import BaseDimensional
from .interfaces import ZfitSpace, ZfitUnbinnedData
from .space import Space, convert_to_space


def convert_to_data(data, obs=None):
    """Convert ``data`` (e.g. a DataFrame, numpy array, tensor or scalar) to a ``Data`` object, if possible."""
    if isinstance(data, ZfitUnbinnedData):
        return data
    elif isinstance(data, (tf.data.Dataset, LightDataset)):
        return Data(dataset=data)
    elif isinstance(data, pd.DataFrame):
        return Data.from_pandas(df=data, obs=obs)

    if obs is None:
        msg = f"If data is not a Data-like object, obs has to be specified. Data is {data} and obs is {obs}."
        raise ValueError(msg)
    if isinstance(data, (int, float)):
        data = znp.array([data])
    elif isinstance(data, Iterable) and not isinstance(data, (np.ndarray, tf.Tensor, tf.Variable)):
        data = znp.array(data)
    if isinstance(data, np.ndarray):
        return Data.from_numpy(obs=obs, array=data)
    if isinstance(data, (tf.Tensor, znp.ndarray, tf.Variable)):
        return Data.from_tensor(obs=obs, tensor=data)

    msg = f"Cannot convert {data} to a Data object."
    raise TypeError(msg)
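
# A minimal usage sketch of ``convert_to_data`` (illustrative comments only; "x" is a placeholder
# observable name):
#
#     obs = zfit.Space("x", limits=(-1, 1))
#     data = convert_to_data(np.random.uniform(-1, 1, size=100), obs=obs)  # array -> Data
#     data = convert_to_data(0.5, obs=obs)                                 # scalar -> one event
#     data = convert_to_data(pd.DataFrame({"x": [0.1, 0.2]}))              # obs from columns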


# TODO: make cut only once, then remember
class Data(
    ZfitUnbinnedData,
    BaseDimensional,
    BaseObject,
    GraphCachable,
    SerializableMixin,
    ZfitSerializable,
):
    BATCH_SIZE = 1000000  # 1 mio

    def __init__(
        self,
        dataset: tf.data.Dataset | LightDataset,
        obs: ztyping.ObsTypeInput = None,
        name: str | None = None,
        weights=None,
        dtype: tf.DType = None,
        use_hash: bool | None = None,
    ):
        """Create a data holder from a ``dataset`` used to feed into ``models``.

        Args:
            dataset: A dataset storing the actual values
            obs: Observables where the data is defined in
            name: Name of the ``Data``
            weights: Weights of the data
            dtype: |dtype_arg_descr|
            use_hash: Whether to use a hash for caching
        """
        if use_hash is None:
            use_hash = run.hashing_data()
        self._use_hash = use_hash
        if name is None:
            name = "Data"
        if dtype is None:
            dtype = ztypes.float
        super().__init__(name=name)
        self._permutation_indices_data = None
        self._next_batch = None
        self._dtype = dtype
        self._nevents = None
        self._weights = None
        self._data_range = None
        self._set_space(obs)
        self._original_space = self.space
        # TODO: proper data cuts; currently set so that the cuts in all dims are applied
        self._data_range = self.space
        self.dataset = dataset.batch(100_000_000)
        self._name = name
        self._set_weights(weights=weights)
        self._hashint = None
        self._update_hash()

    @property
    def nevents(self):
        nevents = self._nevents
        if nevents is None:
            nevents = self._get_nevents()
        return nevents

    @property
    def hashint(self) -> int | None:
        return self._hashint

    # TODO: which naming? nevents or n_events
    @property
    def _approx_nevents(self):
        return self.nevents

    @property
    def n_events(self):
        return self.nevents

    @property
    def has_weights(self):
        return self._weights is not None

    @property
    def dtype(self):
        return self._dtype

    def _set_space(self, obs: Space):
        obs = convert_to_space(obs)
        self._check_n_obs(space=obs)
        obs = obs.with_autofill_axes(overwrite=True)
        self._space = obs
        self._update_hash()

    @property
    def data_range(self):
        data_range = self._data_range
        if data_range is None:
            data_range = self.space
        return data_range

    @invalidate_graph
    def set_data_range(self, data_range):
        data_range = self._check_input_data_range(data_range=data_range)

        def setter(value):
            self._data_range = value
            self._update_hash()

        def getter():
            return self._data_range

        return TemporarilySet(value=data_range, setter=setter, getter=getter)

    @property
    def weights(self):
        """Get the weights of the data."""
        # TODO: refactor below more general, when to apply a cut?
        if self.data_range.has_limits and self.has_weights:
            raw_values = self._value_internal(obs=self.data_range.obs, filter=False)
            is_inside = self.data_range.inside(raw_values)
            weights = self._weights[is_inside]
        else:
            weights = self._weights
        return weights

    @deprecated(None, "Do not set the weights on a data set, create a new one instead.")
    @invalidate_graph
    def set_weights(self, weights: ztyping.WeightsInputType):
        """Set (temporarily) the weights of the dataset.

        Args:
            weights:
        """
        # weights = self._set_weights(weights)

        def setter(value):
            self._set_weights(value)

        def getter():
            return self.weights

        return TemporarilySet(value=weights, getter=getter, setter=setter)

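    # Sketch of the (deprecated) temporary-weights behavior; the returned ``TemporarilySet`` is
    # meant to be usable as a context manager (assuming ``new_weights`` matches the event count):
    #
    #     with data.set_weights(new_weights):
    #         ...  # ``data.weights`` is ``new_weights`` inside the block
    #     # afterwards, the previous weights are restored
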
    def _set_weights(self, weights):
        if weights is not None:
            weights = z.convert_to_tensor(weights)
            weights = z.to_real(weights)
            if weights.shape.ndims != 1:
                if weights.shape.ndims == 2 and weights.shape[1] == 1:
                    weights = znp.reshape(weights, (-1,))
                else:
                    msg = "Weights have to be 1-Dim objects."
                    raise ShapeIncompatibleError(msg)
        self._weights = weights
        self._update_hash()
        return weights

    @property
    def space(self) -> ZfitSpace:
        return self._space

    @classmethod
    def from_pandas(
        cls,
        df: pd.DataFrame,
        obs: ztyping.ObsTypeInput = None,
        weights: ztyping.WeightsInputType | str = None,
        name: str | None = None,
        dtype: tf.DType = None,
        use_hash: bool | None = None,
    ):
        """Create a ``Data`` from a pandas DataFrame. If ``obs`` is ``None``, columns are used as obs.

        Args:
            df: pandas DataFrame that contains the data. If ``obs`` is ``None``, columns are used as obs.
                Can be a superset of obs.
            obs: obs to use for the data. obs have to be the columns in the data frame.
                If ``None``, columns are used as obs.
            weights: Weights of the data. Has to be 1-D and match the shape of the data (nevents) or a string
                that is a column in the dataframe. By default, looks for a column ``""``, i.e. an empty string.
            name: Name of the data.
            dtype: dtype of the data
            use_hash: If ``True``, a hash of the data is created and is used to identify it in caching.
        """
        weights_requested = weights is not None
        if weights is None:
            weights = ""
        if isinstance(df, pd.Series):  # convert first, a Series has no columns
            df = df.to_frame()
        if obs is None:
            obs = list(df.columns)
        obs = convert_to_space(obs)
        not_in_df = set(obs.obs) - set(df.columns)
        if not_in_df:
            msg = f"Observables {not_in_df} not in dataframe with columns {df.columns}"
            raise ValueError(msg)
        space = obs
        if isinstance(weights, str):  # it's in the df
            if weights not in df.columns:
                if weights_requested:
                    msg = f"Weights {weights} is a string and not in dataframe with columns {df.columns}"
                    raise ValueError(msg)
                weights = None
            else:
                obs = [o for o in space.obs if o != weights]
                weights = df[weights]
                space = space.with_obs(obs=obs)

        not_in_df = set(space.obs) - set(df.columns)
        if not_in_df:
            msg = f"Observables {not_in_df} not in dataframe with columns {df.columns}"
            raise ValueError(msg)
        array = df[list(space.obs)].to_numpy()

        return Data.from_numpy(  # *not* cls: if subclassed, keep the base constructor
            obs=space,
            array=array,
            weights=weights,
            name=name,
            dtype=dtype,
            use_hash=use_hash,
        )

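    # A minimal ``from_pandas`` sketch (column names are placeholders):
    #
    #     df = pd.DataFrame({"x": [0.1, 0.5], "y": [1.2, 0.3]})
    #     data = zfit.Data.from_pandas(df)                        # obs taken from the columns
    #     data = zfit.Data.from_pandas(df, obs="x")               # use a subset of the columns
    #     data = zfit.Data.from_pandas(df, obs="x", weights="y")  # a column as per-event weights
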
    @classmethod
    @deprecated_args(None, "Use obs instead.", "branches")
    @deprecated_args(
        None,
        "Use obs_alias instead and make sure to invert the logic! I.e. it's a mapping from"
        " the observable name to the actual branch name.",
        "branches_alias",
    )
    def from_root(
        cls,
        path: str,
        treepath: str,
        obs: ZfitSpace = None,
        *,
        weights: ztyping.WeightsStrInputType = None,
        obs_alias: Mapping[str, str] | None = None,
        name: str | None = None,
        dtype: tf.DType = None,
        root_dir_options=None,
        use_hash: bool | None = None,
        # deprecated
        branches: list[str] | None = None,
        branches_alias: dict | None = None,
    ) -> Data:
        """Create a ``Data`` from a ROOT file. The arguments are passed to ``uproot`` directly.

        Args:
            path: Path to the root file.
            treepath: Name of the tree in the root file.
            obs: Observables of the data. These will also be the columns of the data if no *obs_alias* is given.
            weights: Weights of the data. Has to be 1-D and match the shape of the data (nevents). Can be a
                column of the ROOT file by using a string corresponding to a column.
            obs_alias: A mapping from the ``obs`` (as keys) to the actual ``branches`` (as values) in the root
                file. This allows having different ``observable`` names, independent of the branch name
                in the file.
            name: Name of the data.
            root_dir_options: Options passed to ``uproot``.

        Returns:
            ``zfit.Data``: A ``Data`` object containing the unbinned data.
        """
        # begin deprecated legacy arguments
        if branches:
            obs = branches
            del branches
        if branches_alias is not None:
            if obs_alias is not None:
                msg = "Cannot use both `branches_alias` and `obs_alias`."
                raise ValueError(msg)
            obs_alias = {obs: branch for branch, obs in branches_alias.items()}
        del branches_alias
        # end legacy

        if root_dir_options is None:
            root_dir_options = {}
        if obs_alias is None and obs is None:
            msg = "Either `obs` or `obs_alias` has to be specified."
            raise ValueError(msg)
        if obs_alias is None:
            obs_alias = {}
        if obs is None:
            obs = list(obs_alias.values())

        obs = convert_to_space(obs)
        branches = [obs_alias.get(branch, branch) for branch in obs.obs]

        weights_are_branch = isinstance(weights, str)

        def uproot_loader():
            with uproot.open(path, **root_dir_options)[treepath] as root_tree:
                branches_with_weights = [*branches, weights] if weights_are_branch else branches
                branches_with_weights = tuple(branches_with_weights)
                data = root_tree.arrays(expressions=branches_with_weights, library="pd")
            data_np = data[branches].to_numpy()
            weights_np = data[weights] if weights_are_branch else None
            return data_np, weights_np

        data, weights_np = uproot_loader()
        if not weights_are_branch:
            weights_np = weights
        dataset = LightDataset.from_tensor(data)

        return Data(  # *not* cls: if subclassed, keep the base constructor
            dataset=dataset,
            obs=obs,
            weights=weights_np,
            name=name,
            dtype=dtype,
            use_hash=use_hash,
        )

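    # A hedged ``from_root`` sketch; "data.root", "tree" and the branch names are hypothetical:
    #
    #     data = zfit.Data.from_root(
    #         path="data.root",
    #         treepath="tree",
    #         obs=["mass", "momentum"],   # branches of the same name are read
    #         weights="eventweight",      # optional: a branch used as per-event weights
    #     )
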
    @classmethod
    def from_numpy(
        cls,
        obs: ztyping.ObsTypeInput,
        array: np.ndarray,
        weights: ztyping.WeightsInputType = None,
        name: str | None = None,
        dtype: tf.DType = None,
        use_hash=None,
    ):
        """Create ``Data`` from a ``np.array``.

        Args:
            obs: Observables of the data. They will be matched to the data in the same order.
            array: Numpy array containing the data.
            weights: Weights of the data. Has to be 1-D and match the shape of the data (nevents).
            name: Name of the data.
            dtype: dtype of the data.
            use_hash: If ``True``, a hash of the data is created and is used to identify it in caching.

        Returns:
            ``zfit.Data``: A ``Data`` object containing the unbinned data.
        """
        if not isinstance(array, np.ndarray) and not (tf.is_tensor(array) and hasattr(array, "numpy")):
            msg = f"`array` has to be a `np.ndarray`. Is currently {type(array)}"
            raise TypeError(msg)
        if dtype is None:
            dtype = ztypes.float
        array = znp.asarray(array)
        tensor = tf.cast(array, dtype=dtype)
        return Data.from_tensor(  # *not* cls: if subclassed, keep the base constructor
            obs=obs,
            tensor=tensor,
            weights=weights,
            name=name,
            dtype=dtype,
            use_hash=use_hash,
        )

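    # A minimal ``from_numpy`` sketch; the space and array below are placeholders:
    #
    #     obs = zfit.Space("x", limits=(-5, 5)) * zfit.Space("y", limits=(-5, 5))
    #     array = np.random.normal(size=(1000, 2))   # shape (nevents, nobs)
    #     data = zfit.Data.from_numpy(obs=obs, array=array)
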
    @classmethod
    def from_tensor(
        cls,
        obs: ztyping.ObsTypeInput,
        tensor: tf.Tensor,
        weights: ztyping.WeightsInputType = None,
        name: str | None = None,
        dtype: tf.DType = None,
        use_hash=None,
    ) -> Data:
        """Create a ``Data`` from a ``tf.Tensor``. ``value`` simply returns the tensor (in the right order).

        Args:
            obs: Observables of the data. They will be matched to the data in the same order.
            tensor: Tensor containing the data.
            weights: Weights of the data. Has to be 1-D and match the shape of the data (nevents).
            name: Name of the data.

        Returns:
            ``zfit.Data``: A ``Data`` object containing the unbinned data.
        """
        if dtype is None:
            dtype = ztypes.float
        tensor = tf.cast(tensor, dtype=dtype)
        if len(tensor.shape) == 0:
            tensor = znp.expand_dims(tensor, -1)
        if len(tensor.shape) == 1:
            tensor = znp.expand_dims(tensor, -1)
        dataset = LightDataset.from_tensor(tensor)

        return Data(  # *not* cls: if subclassed, keep the base constructor
            dataset=dataset,
            obs=obs,
            name=name,
            weights=weights,
            dtype=dtype,
            use_hash=use_hash,
        )

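    # A minimal ``from_tensor`` sketch; a 0-D or 1-D tensor is auto-expanded to shape (nevents, 1):
    #
    #     obs = zfit.Space("x", limits=(-3, 3))
    #     tensor = znp.linspace(-3, 3, num=100)   # 1-D -> becomes (100, 1)
    #     data = zfit.Data.from_tensor(obs=obs, tensor=tensor)
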
    def _update_hash(self):
        if not run.executing_eagerly() or not self._use_hash:
            self._hashint = None
        else:
            try:
                hashval = xxhash.xxh128(np.asarray(self.value()))
                if self.has_weights:
                    hashval.update(np.asarray(self.weights))
                self._hashint = hashval.intdigest()
            except AttributeError:
                # if the dataset is not yet initialized; this is allowed
                self._hashint = None

    def with_obs(self, obs):
        """Create a new ``Data`` with a subset of the data using the *obs*.

        Args:
            obs: Observables to return. Has to be a subset of the original observables.

        Returns:
            ``zfit.Data``: A new ``Data`` object containing the subset of the data.
        """
        space = self.space if obs is None else self.space.with_obs(obs, allow_subset=True)
        values = self.value(obs)
        return type(self).from_tensor(obs=space, tensor=values, weights=self.weights, name=self.name)

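    # Sketch: restrict an existing ``Data`` to some of its observables (assuming ``data2d`` is a
    # ``Data`` defined in the two obs "x" and "y"):
    #
    #     data_x = data2d.with_obs("x")           # new Data containing only the "x" column
    #     data_yx = data2d.with_obs(["y", "x"])   # same data, columns reordered
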
    def to_pandas(self, obs: ztyping.ObsTypeInput = None, weightsname: str | None = None):
        """Create a ``pd.DataFrame`` from ``obs`` as columns and return it.

        Args:
            obs: The observables to use as columns. If ``None``, all observables are used.
            weightsname: The name of the weights column if the data has weights. If ``None``, defaults
                to ``""``, an empty string.

        Returns:
            ``pd.DataFrame``: A ``pd.DataFrame`` containing the data and the weights (if present).
        """
        values = self.value(obs=obs)
        if obs is None:
            obs = self.obs
        obs_str = list(convert_to_obs_str(obs))
        values = values.numpy()
        if self.has_weights:
            weights = self.weights.numpy()
            if weightsname is None:
                weightsname = ""
            values = np.concatenate((values, weights[:, None]), axis=1)
            obs_str = [*obs_str, weightsname]
        return pd.DataFrame(data=values, columns=obs_str)

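    # Round-trip sketch: weights, if present, end up in an extra column named "" by default:
    #
    #     df = data.to_pandas()                       # columns are the obs (+ "" for weights)
    #     df = data.to_pandas(weightsname="weight")   # explicit name for the weights column
    #     data2 = zfit.Data.from_pandas(df)           # from_pandas picks the "" column up again
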
    def unstack_x(self, obs: ztyping.ObsTypeInput = None, always_list=None):
        """Return the unstacked data: a list of tensors or a single Tensor.

        Args:
            obs: Observables to return. If ``None``, all observables are returned. Can be a subset of the
                original observables.
            always_list: If ``True``, always return a list, even if only one observable is requested.

        Returns:
            List(tf.Tensor)
        """
        value = self.value(obs=obs)
        if len(value.shape) == 1:
            value = znp.expand_dims(value, -1)  # to make sure we can unstack it again
        return z.unstack_x(value, always_list=always_list)

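    # Sketch: column-wise access, assuming ``data`` has the two obs "x" and "y":
    #
    #     x, y = data.unstack_x()                               # one 1-D tensor per observable
    #     columns = data.unstack_x(obs="x", always_list=True)   # [x] as a one-element list
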
    def value(self, obs: ztyping.ObsTypeInput = None):
        """Return the data as a numpy-like object in ``obs`` order.

        Args:
            obs: Observables to return. If ``None``, all observables are returned. Can be a subset of the
                original observables. If a string is given, a 1-D array is returned with shape (nevents,).
                If a list of strings or a ``zfit.Space`` is given, a 2-D array is returned with
                shape (nevents, nobs).

        Returns:
            A numpy-like array with the data values.
        """
        out = znp.asarray(self._value_internal(obs=obs))
        if isinstance(obs, str):
            out = znp.squeeze(out, axis=-1)
        return out

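    # Sketch of the return shapes of ``value``, assuming ``data`` has the obs ("x", "y"):
    #
    #     data.value()            # shape (nevents, 2), in the original obs order
    #     data.value(["y", "x"])  # shape (nevents, 2), columns reordered
    #     data.value("x")         # shape (nevents,), squeezed to 1-D for a single string
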
    def numpy(self):
        return self.value().numpy()

    def _cut_data(self, value, obs=None):
        if self.data_range.has_limits:
            data_range = self.data_range.with_obs(obs=obs)
            value = data_range.filter(value)
        return value

    def _value_internal(self, obs: ztyping.ObsTypeInput = None, filter: bool = True):
        if obs is not None:
            obs = convert_to_obs_str(obs)
        # for raw_value in self.dataset:
        #     value = self._check_convert_value(raw_value)
        value = self.dataset.value()
        if filter:
            value = self._cut_data(value, obs=self._original_space.obs)
        return self._sort_value(value=value, obs=obs)

    def _check_convert_value(self, value):
        # TODO(Mayou36): add conversion to right dimension? (n_events, n_obs)?
        # check if 1-D?
        if len(value.shape.as_list()) == 0:
            value = znp.expand_dims(value, -1)
        if len(value.shape.as_list()) == 1:
            value = znp.expand_dims(value, -1)
        # cast data to right type
        if value.dtype != self.dtype:
            value = tf.cast(value, dtype=self.dtype)
        return value

    def _sort_value(self, value, obs: tuple[str]):
        obs = convert_to_container(value=obs, container=tuple)
        # TODO CURRENT: deactivated below!
        perm_indices = self.space.axes if self.space.axes != tuple(range(value.shape[-1])) else False

        if obs:
            if not frozenset(obs) <= frozenset(self.obs):
                msg = (
                    f"The observable(s) {frozenset(obs) - frozenset(self.obs)} are not contained in the dataset. "
                    f"Only the following are: {self.obs}"
                )
                raise ValueError(msg)
            perm_indices = self.space.get_reorder_indices(obs=obs)
        if perm_indices:
            value = z.unstack_x(value, always_list=True)
            value = [value[i] for i in perm_indices]
            value = z.stack_x(value)

        return value

    # TODO(Mayou36): use Space to permute data?
    # TODO(Mayou36): raise error if not obs <= self.obs?
    @invalidate_graph
    def sort_by_axes(self, axes: ztyping.AxesTypeInput, allow_superset: bool = True):
        if not allow_superset and not frozenset(axes) <= frozenset(self.axes):
            msg = (
                f"The axes {frozenset(axes) - frozenset(self.axes)} are not contained in the dataset. "
                f"Only the following are: {self.axes}"
            )
            raise ValueError(msg)
        space = self.space.with_axes(axes=axes, allow_subset=True)

        def setter(value):
            self._space = value

        def getter():
            return self.space

        return TemporarilySet(value=space, setter=setter, getter=getter)

    # @invalidate_graph
    def sort_by_obs(self, obs: ztyping.ObsTypeInput, allow_superset: bool = False):
        if not allow_superset and not frozenset(obs) <= frozenset(self.obs):
            msg = (
                f"The observable(s) {frozenset(obs) - frozenset(self.obs)} are not contained in the dataset. "
                f"Only the following are: {self.obs}"
            )
            raise ValueError(msg)
        space = self.space.with_obs(obs=obs, allow_subset=True, allow_superset=allow_superset)

        def setter(value):
            self._space = value

        def getter():
            return self.space

        return TemporarilySet(value=space, setter=setter, getter=getter)

    def _check_input_data_range(self, data_range):
        data_range = self._convert_sort_space(limits=data_range)
        if frozenset(self.data_range.obs) != frozenset(data_range.obs):
            msg = (
                f"Data range has to cover the full observable space {self.data_range.obs}, not "
                f"only {data_range.obs}"
            )
            raise ObsIncompatibleError(msg)
        return data_range

    # TODO(Mayou36): refactor with pdf or other range things?
    def _convert_sort_space(
        self,
        obs: ztyping.ObsTypeInput = None,
        axes: ztyping.AxesTypeInput = None,
        limits: ztyping.LimitsTypeInput = None,
    ) -> Space | None:
        """Convert the inputs (possibly using ``obs``, ``axes``) to :py:class:`~zfit.Space` and sort
        them according to the data's own ``obs``.
        Args:
            obs:
            axes:
            limits:

        Returns:
        """
        if obs is None:  # for simple limits to convert them
            obs = self.obs
        space = convert_to_space(obs=obs, axes=axes, limits=limits)

        if self.space is not None:
            space = space.with_coords(self.space, allow_subset=True)
        return space

    def _get_nevents(self):
        return tf.shape(input=self.value())[0]

    def __str__(self) -> str:
        return f"<zfit.Data: {self.name} obs={self.obs}>"

    def to_binned(self, space):
        from zfit._data.binneddatav1 import BinnedData

        return BinnedData.from_unbinned(space=space, data=self)

    def __getitem__(self, item):
        try:
            value = getitem_obs(self, item)
        except Exception as error:
            msg = (
                f"Failed to retrieve {item} from data {self}. This is changed behavior (since zfit 0.11): data can"
                f" no longer be accessed numpy-like but instead the 'obs' can be used, i.e. strings or spaces. This"
                f" resembles more closely the behavior of a pandas DataFrame."
            )
            raise RuntimeError(msg) from error
        return value


# TODO(serialization): add to serializer
class DataRepr(BaseRepr):
    _implementation = Data
    _owndict = pydantic.PrivateAttr(default_factory=dict)
    hs3_type: Literal["Data"] = Field("Data", alias="type")

    data: np.ndarray
    space: Union[SpaceRepr, list[SpaceRepr]]
    name: Optional[str] = None
    weights: Optional[np.ndarray] = None

    @pydantic.root_validator(pre=True)
    def extract_data(cls, values):
        if cls.orm_mode(values):
            values = dict(values)
            values["data"] = values["value"]()
        return values

    @pydantic.validator("space", pre=True)
    def flatten_spaces(cls, v):
        if cls.orm_mode(v):
            v = [v.get_subspace(o) for o in v.obs]
        return v

    @pydantic.validator("data", pre=True)
    def convert_data(cls, v):
        return np.asarray(v)

    @pydantic.validator("weights", pre=True)
    def convert_weights(cls, v):
        if v is not None:
            v = np.asarray(v)
        return v

    @to_orm_init
    def _to_orm(self, init):
        dataset = LightDataset(znp.asarray(init.pop("data")))
        init["dataset"] = dataset
        init["obs"] = init.pop("space")

        spaces = init["obs"]
        space = spaces[0]
        for sp in spaces[1:]:
            space *= sp
        init["obs"] = space
        return super()._to_orm(init)


def getitem_obs(self, item):
    if not isinstance(item, str):
        item = convert_to_obs_str(item)
    return self.value(item)


class SampleData(Data):
    _cache_counting = 0

    def __init__(
        self,
        dataset: tf.data.Dataset | LightDataset,
        obs: ztyping.ObsTypeInput = None,
        weights=None,
        name: str | None = None,
        dtype: tf.DType = ztypes.float,
        use_hash: bool | None = None,
    ):
        super().__init__(
            dataset,
            obs,
            name=name,
            weights=weights,
            dtype=dtype,
            use_hash=use_hash,
        )

    @classmethod
    def get_cache_counting(cls):
        counting = cls._cache_counting
        cls._cache_counting += 1
        return counting

    @classmethod
    def from_sample(  # TODO: deprecate and remove? use normal data?
        cls,
        sample: tf.Tensor,
        obs: ztyping.ObsTypeInput,
        name: str | None = None,
        weights=None,
        use_hash: bool | None = None,
    ):
        return Data.from_tensor(tensor=sample, obs=obs, name=name, weights=weights, use_hash=use_hash)


class Sampler(Data):
    _cache_counting = 0

    def __init__(
        self,
        dataset: LightDataset,
        sample_func: Callable,
        sample_holder: tf.Variable,
        n: ztyping.NumericalScalarType | Callable,
        weights=None,
        fixed_params: dict[zfit.Parameter, ztyping.NumericalScalarType] | None = None,
        obs: ztyping.ObsTypeInput = None,
        name: str | None = None,
        dtype: tf.DType = ztypes.float,
        use_hash: bool | None = None,
    ):
        super().__init__(
            dataset=dataset,
            obs=obs,
            name=name,
            weights=weights,
            dtype=dtype,
            use_hash=use_hash,
        )
        if fixed_params is None:
            fixed_params = OrderedDict()
        if isinstance(fixed_params, (list, tuple)):
            fixed_params = OrderedDict(
                (param, param.numpy()) for param in fixed_params
            )  # TODO: numpy -> read_value?

        self._initial_resampled = False

        self.fixed_params = fixed_params
        self.sample_holder = sample_holder
        self.sample_func = sample_func
        self.n = n
        self._n_holder = n
        self.resample()  # to be used for precompilations etc

    @property
    def n_samples(self):
        return self._n_holder

    @property
    def _approx_nevents(self):
        nevents = super()._approx_nevents
        if nevents is None:
            nevents = self.n
        return nevents

    def _value_internal(self, obs: ztyping.ObsTypeInput = None, filter: bool = True):
        if not self._initial_resampled:
            msg = (
                "No data generated yet. Use `resample()` to generate samples or directly use `model.sample()`"
                " for single-time sampling."
            )
            raise RuntimeError(msg)
        return super()._value_internal(obs=obs, filter=filter)

    @property
    def hashint(self) -> int | None:
        # since the variable can be changed, the hash may stay static; and using 128 bits,
        # we can't have a tf.Variable that keeps the int
        return None

    @classmethod
    def get_cache_counting(cls):
        counting = cls._cache_counting
        cls._cache_counting += 1
        return counting

    @classmethod
    def from_sample(
        cls,
        sample_func: Callable,
        n: ztyping.NumericalScalarType,
        obs: ztyping.ObsTypeInput,
        fixed_params=None,
        name: str | None = None,
        weights=None,
        dtype=None,
        use_hash: bool | None = None,
    ):
        obs = convert_to_space(obs)

        if fixed_params is None:
            fixed_params = []
        if dtype is None:
            dtype = ztypes.float
        sample_holder = tf.Variable(
            initial_value=sample_func(),
            dtype=dtype,
            trainable=False,
            shape=(None, obs.n_obs),
            name=f"sample_data_holder_{cls.get_cache_counting()}",
        )
        dataset = LightDataset.from_tensor(sample_holder)

        return cls(
            dataset=dataset,
            sample_holder=sample_holder,
            sample_func=sample_func,
            fixed_params=fixed_params,
            n=n,
            obs=obs,
            name=name,
            weights=weights,
            use_hash=use_hash,
        )

    def resample(self, param_values: Mapping | None = None, n: int | tf.Tensor = None):
        """Update the sample by newly sampling. This affects any object that used this data already.

        All params that are not in the attribute ``fixed_params`` will use their current value for
        the creation of the new sample. The value can also be overwritten for one sampling by providing a
        mapping with ``param_values`` from ``Parameter`` to the temporary ``value``.

        Args:
            param_values: a mapping from :py:class:`~zfit.Parameter` to a `value`. For the current sampling,
                `Parameter` will use the `value`.
            n: the number of samples to produce. If the `Sampler` was created with anything else than a numerical
                or tf.Tensor, this can't be used.
        """
        if n is None:
            n = self.n

        temp_param_values = self.fixed_params.copy()
        if param_values is not None:
            temp_param_values.update(param_values)

        with set_values(list(temp_param_values.keys()), list(temp_param_values.values())):
            # if not (n and self._initial_resampled):  # we want to load and make sure that it's initialized
            #     # means it's handled inside the function
            #     # TODO(Mayou36): check logic; what if new_samples loaded? gets overwritten by initializer
            #     # fixed with self.n, needs cleanup
            #     if not (isinstance(self.n_samples, str) or self.n_samples is None):
            #         self.sess.run(self.n_samples.initializer)
            # if n:
            #     if not isinstance(self.n_samples, tf.Variable):
            #         raise RuntimeError("Cannot set a new `n` if not a Tensor-like object was given")
            #     self.n_samples.assign(n)

            new_sample = self.sample_func(n)
            # self.sample_holder.assign(new_sample)
            self.sample_holder.assign(new_sample, read_value=False)
            self._initial_resampled = True
            self._update_hash()

    def __str__(self) -> str:
        return f"<Sampler: {self.name} obs={self.obs}>"


# register_tensor_conversion(Data, name="Data", overload_operators=True)


class LightDataset:
    def __init__(self, tensor):
        if not isinstance(tensor, (tf.Tensor, tf.Variable)):
            tensor = z.convert_to_tensor(tensor)
        self.tensor = tensor

    def batch(self, _):  # ad-hoc, just a no-op mimicking the tf.data.Dataset interface
        return self

    def __iter__(self):
        yield self.value()

    @classmethod
    def from_tensor(cls, tensor):
        return cls(tensor=tensor)

    def value(self):
        return self.tensor


def sum_samples(
    sample1: ZfitUnbinnedData,
    sample2: ZfitUnbinnedData,
    obs: ZfitSpace,
    shuffle: bool = False,
):
    samples = [sample1, sample2]
    if obs is None:
        raise WorkInProgressError
    sample2 = sample2.value(obs=obs)
    if shuffle:
        sample2 = z.random.shuffle(sample2)
    sample1 = sample1.value(obs=obs)
    tensor = sample1 + sample2
    if any(s.weights is not None for s in samples):
        msg = "Cannot combine weights currently"
        raise WorkInProgressError(msg)
    weights = None

    return SampleData.from_sample(sample=tensor, obs=obs, weights=weights)
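

# A hedged end-to-end sketch of the ``Sampler`` workflow; ``create_sampler`` on a PDF is the usual
# entry point (names and values below are placeholders):
#
#     obs = zfit.Space("x", limits=(-5, 5))
#     mu = zfit.Parameter("mu", 0.0)
#     sigma = zfit.Parameter("sigma", 1.0)
#     gauss = zfit.pdf.Gauss(mu=mu, sigma=sigma, obs=obs)
#
#     sampler = gauss.create_sampler(n=1000, fixed_params=[mu, sigma])
#     sampler.resample()                          # draw a fresh sample in-place
#     sampler.resample(param_values={mu: 0.5})    # temporarily override a parameter for this draw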