Source code for zfit.z.random
# Copyright (c) 2019 zfit
from typing import Union, Iterable, Sized
import tensorflow_probability as tfp
import tensorflow as tf
from .zextension import tf_function as function
from .wrapping_tf import convert_to_tensor
from ..util.container import convert_to_container
# Public API of this module: the names exported by a star-import.
__all__ = ["counts_multinomial"]
def counts_multinomial(total_count: Union[int, tf.Tensor], probs: Iterable[Union[float, tf.Tensor]] = None,
                       logits: Iterable[Union[float, tf.Tensor]] = None, dtype=tf.int32) -> tf.Tensor:
    """Get the number of counts for different classes with given probs/logits.

    Exactly one of `probs` or `logits` has to be given.

    Args:
        total_count (int): The total number of draws.
        probs: Length k (number of classes) object where the entry at index k-1 contains the
            probability to get a single draw from the class k. Have to be from [0, 1] and sum up to 1.
        logits: Same as `probs` but from [-inf, inf] (will be transformed to [0, 1]).
        dtype: Type the sampled counts are cast to before they are returned.

    Returns:
        :py:class:`tf.Tensor`: shape (k,) tensor containing the number of draws.

    Raises:
        ValueError: If neither or both of `probs` and `logits` are specified.
    """
    # Enforce the documented "exactly one" contract up front instead of letting the
    # both-given case slip through to the distribution constructor inside the wrapped function.
    if (probs is None) == (logits is None):
        raise ValueError("Exactly one of `probs` or`logits` have to be specified")

    # Everything handed to the distribution is cast to float32 so that all of its
    # parameters share a single floating dtype.
    total_count = tf.cast(tf.convert_to_tensor(total_count), dtype=tf.float32)
    if probs is not None:
        probs = tf.cast(tf.convert_to_tensor(probs), tf.float32)
    else:
        logits = tf.cast(tf.convert_to_tensor(logits), tf.float32)

    # Only `total_count` is registered as a control dependency of the sampling op.
    control_deps = [total_count]

    @function
    def wrapped_func(control_deps, dtype, logits, probs, total_count):
        with tf.control_dependencies(control_deps):
            dist = tfp.distributions.Multinomial(total_count=total_count, probs=probs, logits=logits)
            counts = dist.sample()
        # The sample is produced in the distribution's floating dtype; convert to the requested one.
        counts = tf.cast(counts, dtype=dtype)
        return counts

    return wrapped_func(control_deps, dtype, logits, probs, total_count)