Source code for zfit.z.wrapping_tf

#  Copyright (c) 2019 zfit

import functools
from typing import Any

from tensorflow.compat.v1 import DType
import tensorflow as tf



from .tools import _auto_upcast
from . import zextension
from ..settings import ztypes


[docs]def log(x, name=None): x = _auto_upcast(x) return _auto_upcast(tf.math.log(x=x, name=name))
[docs]def exp(x, name=None): return _auto_upcast(tf.exp(x=x, name=name))
@functools.wraps(tf.convert_to_tensor) def convert_to_tensor(value, dtype=ztypes.float, name=None, preferred_dtype=ztypes.float): return tf.convert_to_tensor(value=value, dtype=dtype, name=name, dtype_hint=preferred_dtype)
[docs]def random_normal(shape, mean=0.0, stddev=1.0, dtype=ztypes.float, seed=None, name=None): return tf.random.normal(shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed, name=name)
[docs]def random_uniform(shape, minval=0, maxval=None, dtype=ztypes.float, seed=None, name=None): return tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed, name=name)
[docs]def random_poisson(lam: Any, shape: Any, dtype: DType = ztypes.float, seed: Any = None, name: Any = None): return tf.random.poisson(lam=lam, shape=shape, dtype=dtype, seed=seed, name=name)
[docs]def square(x, name=None): return _auto_upcast(tf.square(x, name))
[docs]def sqrt(x, name=None): return _auto_upcast(tf.sqrt(x, name=name))
[docs]def pow(x, y, name=None): return _auto_upcast(tf.pow(x, y, name=name))
[docs]def complex(real, imag, name=None): real = _auto_upcast(real) imag = _auto_upcast(imag) return _auto_upcast(tf.complex(real=real, imag=imag, name=name))
[docs]def check_numerics(tensor: Any, message: Any, name: Any = None): """Check whether a tensor is finite and not NaN. Extends TF by accepting complex types as well. Args: tensor (:py:class:~`tensorflow.python.framework.ops.Tensor`): message (str): name (Union[None, None, None]): Returns: tensorflow.python.framework.ops.Tensor: """ if tensor.dtype in (tf.complex64, tf.complex128): real_check = tf.debugging.check_numerics(tensor=tf.math.real(tensor), message=message, name=name) imag_check = tf.debugging.check_numerics(tensor=tf.math.imag(tensor), message=message, name=name) check_op = tf.group(real_check, imag_check) else: check_op = tf.debugging.check_numerics(tensor=tensor, message=message, name=name) return check_op
# # @functools.wraps(tf.reduce_sum) # def reduce_sum(*args, **kwargs): # return tf.reduce_sum(*args, **kwargs) reduce_sum = tf.reduce_sum reduce_prod = tf.reduce_prod