Source code for zfit.util.graph
# Copyright (c) 2019 zfit
from typing import List
import tensorflow as tf
def all_parents(op, current_obs=None):
    """Return every op that `op` transitively depends on through its inputs.

    The graph is walked iteratively with an explicit stack, so deep graphs do not
    hit the interpreter recursion limit. Each op is expanded at most once, so
    ancestors shared by several paths (diamond shapes) are not re-traversed.

    Args:
        op: Operation whose ancestors are collected. Only `op.inputs` and, for
            each input, `input_.op` are accessed.
        current_obs (set, optional): Ops to treat as already observed. They are
            neither returned nor traversed through. The given collection is not
            modified.

    Returns:
        set: All ops reachable from `op` by repeatedly following input edges,
        excluding those in `current_obs`. `op` itself is only included if it is
        reachable from itself through a cycle.
    """
    # Copy so the caller's collection is never mutated.
    observed = set() if current_obs is None else set(current_obs)
    parents = set()
    stack = [input_.op for input_ in op.inputs]
    while stack:
        parent = stack.pop()
        if parent in observed:
            # Already expanded (or excluded by the caller): skip to avoid rework and cycles.
            continue
        observed.add(parent)
        parents.add(parent)
        stack.extend(input_.op for input_ in parent.inputs)
    return parents
def get_dependents_auto(tensor: tf.Tensor, candidates: List[tf.Tensor]) -> List[tf.Tensor]:
    """Return the nodes in `candidates` that `tensor` depends on.

    A candidate counts as a dependency when its `op` is among the ancestor ops of
    `tensor.op`, as collected by `all_parents`.

    Args:
        tensor (tf.Tensor): Tensor whose ancestry is inspected through `tensor.op`
            (presumably a graph-mode tensor; TODO confirm for eager tensors).
        candidates (list): Objects exposing an `.op` attribute, such as tensors or
            variables, to test for membership in the ancestry of `tensor`.

    Returns:
        list: The candidates that `tensor` depends on, in the order given.

    Raises:
        ValueError: If walking the ancestors exceeds the interpreter recursion limit.
    """
    try:
        dependent_ops = all_parents(tensor.op)
    except RecursionError as error:
        # Only a recursion-limit failure matches this message; other RuntimeErrors
        # propagate unchanged instead of being mislabelled.
        raise ValueError("Tensor too deeply nested, recursion limit exceeded. In the future, "
                         "implementation will be different and any dependents can be found. "
                         "Currently, specify dependents explicitly if needed. "
                         "Original Error: {}".format(error)) from error
    return [cand for cand in candidates if cand.op in dependent_ops]
if __name__ == '__main__':
    # Small demo of `get_dependents_auto`. It is built inside an explicit graph
    # because the dependency search walks `tensor.op`, which needs graph (non-eager) tensors.
    with tf.Graph().as_default():
        # Unrelated random node: neither a candidate nor an ancestor of `e`.
        a = tf.compat.v1.distributions.Normal(1., 3.).sample() * 5.
        # The second positional parameter of `get_variable` is `shape`, so the
        # starting values must be passed as `initializer`.
        var1 = tf.compat.v1.get_variable('a1', initializer=tf.constant(1.))
        var2 = tf.compat.v1.get_variable('a2', initializer=tf.constant(2.))
        var3 = tf.compat.v1.get_variable('a3', initializer=tf.constant(3.))
        b = tf.constant(3.) + 4 * var1
        c = 5. * b
        d = c + b * var2
        e = c * 3.
        # e = 3 * c = 15 * b, so b, c and var1 should be reported; d, var2, var3 should not.
        print(get_dependents_auto(e, [b, c, d, var1, var2, var3]))
        print(get_dependents_auto(e, [var1, var2, var3]))