zextension

class zfit.z.zextension.FunctionWrapperRegistry(**kwargs_user)[source]

Bases: object

tf.function-like decorator with additional cache-invalidation functionality.

Parameters: **kwargs_user – arguments to tf.function
classmethod check_wrapped_functions_registered()[source]
registries = [<zfit.z.zextension.FunctionWrapperRegistry object>]
reset(**kwargs_user)[source]
wrapped_functions = []
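The cache-invalidation idea can be pictured with a pure-Python analogy. This is a sketch under assumptions, not zfit's implementation: `CachingWrapperRegistry` is a hypothetical name, it does not use TensorFlow, and it caches each call's result as a stand-in for the traced graphs that tf.function caches per input signature. Like the documented class, it keeps a class-level `registries` list and a per-instance `wrapped_functions` list, and `reset` clears what was cached.

```python
import functools


class CachingWrapperRegistry:
    """Hypothetical pure-Python analogue of a decorator with cache invalidation."""

    # All registry instances, so every cache can be found and reset later.
    registries = []

    def __init__(self):
        self.wrapped_functions = []
        CachingWrapperRegistry.registries.append(self)

    def __call__(self, func):
        cache = {}  # stands in for tf.function's per-signature trace cache

        @functools.wraps(func)
        def wrapper(*args):
            if args not in cache:
                cache[args] = func(*args)  # run the Python body once per args
            return cache[args]

        wrapper.cache = cache
        self.wrapped_functions.append(wrapper)
        return wrapper

    def reset(self):
        """Invalidate every cache so the next call re-runs the Python body."""
        for wrapper in self.wrapped_functions:
            wrapper.cache.clear()
```

Decorating a function with a registry instance runs its body once per distinct argument tuple; after `reset()`, the next call runs the body again, just as a retrace would pick up changed Python-side state.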
zfit.z.zextension.abs_square(x)[source]
zfit.z.zextension.constant(value, dtype=tf.float64, shape=None, name='Const', verify_shape=None)[source]
zfit.z.zextension.convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None)[source]
zfit.z.zextension.nth_pow(x, n, name=None)[source]

Calculate the nth power of the complex Tensor x.

Parameters:
  • x (tf.Tensor, complex) – The complex tensor to raise to the power n.
  • n (int >= 0) – Power
  • name (str) – No effect, for API compatibility with tf.pow
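For an integer n >= 0, the power can be formed by repeated multiplication, which needs nothing beyond complex multiplication. The pure-Python sketch below illustrates that approach on Python `complex` numbers; `nth_pow_sketch` is a hypothetical name, and this is not necessarily how zfit computes the power on tensors.

```python
def nth_pow_sketch(x: complex, n: int) -> complex:
    """Raise complex x to the integer power n >= 0 by repeated multiplication."""
    if not isinstance(n, int) or n < 0:
        raise ValueError("n must be an integer >= 0")
    result = complex(1.0, 0.0)  # x**0 == 1
    for _ in range(n):
        result *= x
    return result
```

For example, `nth_pow_sketch(1j, 2)` gives `(-1+0j)`, and any `x` with `n=0` gives `(1+0j)`.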
zfit.z.zextension.run_no_nan(func, x)[source]
zfit.z.zextension.safe_where(condition: tensorflow.python.framework.ops.Tensor, func: Callable, safe_func: Callable, values: tensorflow.python.framework.ops.Tensor, value_safer: Callable = <function ones_like_v2>) → tensorflow.python.framework.ops.Tensor[source]

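Like tf.where(), but avoids NaN gradients when func produces NaN for certain values. Since tf.where() evaluates func on all values, a NaN from entries where condition is False would otherwise still reach the gradient.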

Parameters:
  • condition (tf.Tensor) – Same argument as to tf.where(), a boolean tf.Tensor
  • func (Callable) – Function taking values as argument and returning the tensor _in case condition is True_. Equivalent to the x argument of tf.where(), but as a function.
  • safe_func (Callable) – Function taking values as argument and returning the tensor _in case condition is False_. Equivalent to the y argument of tf.where(), but as a function.
  • values (tf.Tensor) – Values to be evaluated by either func or safe_func, depending on condition.
  • value_safer (Callable) – Function taking values as argument and returning “safe” values that cause no trouble when passed to func or when taking the gradient of func(value_safer(values)).
Returns:

func(values) where condition is True and safe_func(values) where it is False.

Return type:

tf.Tensor
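The NaN arises because reverse-mode differentiation of a where multiplies each branch's local derivative by a 0/1 mask, and 0 * inf (or 0 * NaN) is NaN. The pure-Python sketch below reproduces this with hand-written derivatives for func = sqrt and safe_func = 0. `where_grad` and `sqrt_with_grad` are hypothetical helpers, not zfit or TensorFlow API. Passing a `value_safer` that maps unselected values to 1.0 (mirroring the default, which produces ones) keeps the masked derivative finite.

```python
import math


def sqrt_with_grad(x):
    """Return sqrt(x) and its derivative 0.5 / sqrt(x), which is infinite at x == 0."""
    y = math.sqrt(x)
    dy = 0.5 / y if y > 0 else math.inf
    return y, dy


def where_grad(condition, values, value_safer=None):
    """Per-value gradient of where(condition, sqrt(v), 0), computed reverse-mode style.

    The upstream gradient is masked (1 for the selected branch, 0 for the other)
    and each mask multiplies that branch's local derivative.
    """
    grads = []
    for cond, v in zip(condition, values):
        if value_safer is not None and not cond:
            v = value_safer(v)  # feed func a harmless value where its result is unused
        _, d_func = sqrt_with_grad(v)
        d_safe = 0.0  # safe_func(v) = 0 has zero derivative
        mask = 1.0 if cond else 0.0
        grads.append(mask * d_func + (1.0 - mask) * d_safe)
    return grads
```

With `condition=[False, True]` and `values=[0.0, 4.0]`, the naive call yields NaN for the first entry (0.0 * inf), while `where_grad(..., value_safer=lambda v: 1.0)` yields `[0.0, 0.25]`.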

zfit.z.zextension.stack_x(values, axis: int = -1, name: str = 'stack_x')[source]
zfit.z.zextension.to_complex(number, dtype=tf.complex128)[source]
zfit.z.zextension.to_real(x, dtype=tf.float64)[source]
zfit.z.zextension.unstack_x(value: Any, num: Any = None, axis: int = -1, always_list: bool = False, name: str = 'unstack_x')[source]

Unstack a Data object and return a list of tensors (or a single tensor) in the right order.

Parameters:
  • value (Any) –
  • num (Any) –
  • axis (int) –
  • always_list (bool) – If True, also return a list if only one element.
  • name (str) –
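Returns:

The unstacked tensors as a list, or a single tensor if there is only one element and always_list is False.
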
Return type:

Union[List[tensorflow.python.framework.ops.Tensor], tensorflow.python.framework.ops.Tensor, None]
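The list-or-single return behavior of unstack_x along the last axis can be pictured on plain nested lists shaped [n_events, n_obs]. This is a sketch under assumptions: `unstack_columns` and `stack_columns` are hypothetical helpers, not zfit API, and treating stack_x as the inverse of unstack_x along axis -1 is assumed here rather than documented above.

```python
def unstack_columns(rows, always_list=False):
    """Split a 2D [n_events, n_obs] nested list into its n_obs columns.

    A single column is returned on its own unless always_list is True,
    mirroring the always_list parameter of unstack_x.
    """
    columns = [list(col) for col in zip(*rows)]
    if len(columns) == 1 and not always_list:
        return columns[0]
    return columns


def stack_columns(columns):
    """Inverse of unstack_columns: combine columns back into rows (last axis)."""
    return [list(row) for row in zip(*columns)]
```

For two observables, `unstack_columns([[1.0, 10.0], [2.0, 20.0]])` gives `[[1.0, 2.0], [10.0, 20.0]]`, and `stack_columns` restores the original rows.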