zextension

zfit.ztf.zextension.constant(value, dtype=tf.float64, shape=None, name='Const', verify_shape=None)

zfit.ztf.zextension.nth_pow(x, n, name=None)
Calculate the nth power of the complex Tensor x.
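As a rough illustration of what the nth power of a complex value means, the pure-Python sketch below multiplies x by itself n times. It works on Python complex numbers rather than tensors. The name nth_pow_sketch, the restriction to integer n >= 0 and the repeated-multiplication strategy are assumptions for illustration, not taken from the zfit source.

```python
def nth_pow_sketch(x: complex, n: int) -> complex:
    """Return x**n for an integer n >= 0 by repeated multiplication (hypothetical helper)."""
    if n < 0:
        raise ValueError(f"n must be >= 0, got {n}")
    result = complex(1.0, 0.0)  # x**0 == 1
    for _ in range(n):
        result *= x             # one factor of x per iteration
    return result


print(nth_pow_sketch(1 + 1j, 4))  # (-4+0j)
```

Repeated multiplication keeps every step an ordinary complex product, which is easy to reason about for small integer powers.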

zfit.ztf.zextension.safe_where(condition: tensorflow.python.framework.ops.Tensor, func: Callable, safe_func: Callable, values: tensorflow.python.framework.ops.Tensor, value_safer: Callable = <function add_dispatch_support.<locals>.wrapper>) → tensorflow.python.framework.ops.Tensor
Like tf.where(), but avoids NaN gradients when func (or its gradient) produces NaN for certain values, including entries where condition is False and the output of func is discarded.
Parameters:
- condition (tf.Tensor) – Same argument as to tf.where(): a boolean tf.Tensor.
- func (Callable) – Function taking values as argument and returning the tensor in case condition is True. Equivalent to x of tf.where(), but as a function.
- safe_func (Callable) – Function taking values as argument and returning the tensor in case condition is False. Equivalent to y of tf.where(), but as a function.
- values (tf.Tensor) – Values to be evaluated by either func or safe_func, depending on condition.
- value_safer (Callable) – Function taking values as argument and returning “safe” values that cause no trouble when passed to func or when taking the gradient of func(value_safer(values)).
Return type: tf.Tensor
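A plain tf.where() can produce NaN gradients because reverse-mode differentiation sends a zero upstream gradient into the unselected branch and then multiplies it by that branch's derivative, and 0 * inf is NaN. The pure-Python sketch below reproduces this by hand for func = sqrt, safe_func = 0 and condition x > 0; the inner where that replaces unselected entries by a constant plays the role of value_safer. The helper names, the hand-written derivatives and the constant safe value 1.0 are assumptions for illustration; this does not use TensorFlow and is not the zfit implementation.

```python
import math
from typing import Callable, List, Sequence

Deriv = Callable[[float], float]


def dsqrt_ieee(x: float) -> float:
    """d/dx sqrt(x) = 1 / (2 sqrt(x)), with IEEE-style inf at 0 and nan below 0."""
    if x < 0:
        return math.nan
    return math.inf if x == 0 else 0.5 / math.sqrt(x)


def where_grad(cond: Sequence[bool], xs: Sequence[float],
               dfunc: Deriv, dsafe_func: Deriv) -> List[float]:
    """Gradient of sum(where(cond, func(x), safe_func(x))) w.r.t. each x.

    Reverse mode sends the upstream gradient 1.0 to the selected branch and 0.0
    to the other, then multiplies by each branch's derivative at x.
    """
    return [(1.0 if c else 0.0) * dfunc(x) + (0.0 if c else 1.0) * dsafe_func(x)
            for c, x in zip(cond, xs)]


def safe_where_grad(cond: Sequence[bool], xs: Sequence[float],
                    dfunc: Deriv, dsafe_func: Deriv,
                    safe_value: float = 1.0) -> List[float]:
    """Same gradient, but func is evaluated at where(cond, x, safe_value).

    The inner where has derivative 1.0 w.r.t. x where cond holds and 0.0
    elsewhere (safe_value is a constant), so an unselected entry contributes
    0.0 * dfunc(safe_value) * 0.0 = 0.0 as long as dfunc(safe_value) is finite.
    """
    grads = []
    for c, x in zip(cond, xs):
        safe_x = x if c else safe_value   # inner where(cond, x, safe_value)
        d_safe_x = 1.0 if c else 0.0      # its derivative w.r.t. x
        g_func = 1.0 if c else 0.0        # upstream gradient reaching func
        grads.append(g_func * dfunc(safe_x) * d_safe_x
                     + (1.0 - g_func) * dsafe_func(x))
    return grads


xs = [4.0, 0.0]
cond = [x > 0 for x in xs]  # [True, False]
print(where_grad(cond, xs, dsqrt_ieee, lambda x: 0.0))       # [0.25, nan]
print(safe_where_grad(cond, xs, dsqrt_ieee, lambda x: 0.0))  # [0.25, 0.0]
```

Both versions agree wherever condition is True; they differ only at x = 0, where the unselected sqrt branch would otherwise inject 0 * inf into the gradient.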

zfit.ztf.zextension.unstack_x(value: Any, num: Any = None, axis: int = -1, always_list: bool = False, name: str = 'unstack_x')
Unstack a Data object and return a list of tensors (or a single tensor) in the right order.
Return type: Union[List[tensorflow.python.framework.ops.Tensor], tensorflow.python.framework.ops.Tensor, None]
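A plain-Python sketch of the unstacking idea, assuming the data is a 2-D table of shape (n_events, n_obs) stored as a list of rows and unstacked along the last axis (axis=-1, the default above). The helper name unstack_rows and the reading of always_list (a single column is returned on its own unless always_list=True) are assumptions inferred from the signature and return type, not taken from the zfit source.

```python
from typing import List, Sequence, Union

Column = List[float]


def unstack_rows(rows: Sequence[Sequence[float]],
                 always_list: bool = False) -> Union[List[Column], Column]:
    """Split an (n_events, n_obs) table into one column per observable (hypothetical helper)."""
    n_obs = len(rows[0]) if rows else 0
    # Column i collects the i-th entry of every event, keeping the observable order.
    columns = [[row[i] for row in rows] for i in range(n_obs)]
    # Assumed semantics of always_list: unwrap a single column unless asked not to.
    if len(columns) == 1 and not always_list:
        return columns[0]
    return columns


print(unstack_rows([[1.0, 2.0], [3.0, 4.0]]))  # [[1.0, 3.0], [2.0, 4.0]]
```

Returning a bare column for one observable lets single-dimensional callers skip indexing, while always_list=True gives callers a uniform list type.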