Source code for zfit.util.graph
# Copyright (c) 2019 zfit
from typing import List
import tensorflow as tf
def all_parents(op, current_obs=None):
    """Return every op that `op` transitively depends on through its data inputs.

    The graph is walked iteratively with an explicit stack instead of by
    recursion. Arbitrarily deep graphs therefore do not hit Python's recursion
    limit. Each op is expanded at most once, so ancestors shared by several
    paths (diamond-shaped graphs) are not traversed repeatedly.

    Args:
        op: The op whose ancestors are collected. Only ``op.inputs`` and the
            ``.op`` attribute of each input are used (presumably a
            ``tf.Operation``). Control inputs are not followed.
        current_obs: Optional iterable of ops to treat as already observed.
            These ops are neither returned nor expanded. It is not mutated.

    Returns:
        set: All ops reachable from `op` via data inputs, excluding those in
        `current_obs`. `op` itself is only included if it is reachable from
        its own inputs.
    """
    # Copy so the caller's collection is never mutated.
    visited = set() if current_obs is None else set(current_obs)
    parents = set()
    stack = [op]
    while stack:
        node = stack.pop()
        for input_ in node.inputs:
            parent = input_.op
            if parent not in visited:
                # Mark on discovery so each op is pushed/expanded only once.
                visited.add(parent)
                parents.add(parent)
                stack.append(parent)
    return parents
def get_dependents_auto(tensor: tf.Tensor, candidates: List[tf.Tensor]) -> List[tf.Tensor]:
    """Return the nodes in `candidates` that `tensor` depends on.

    A candidate counts as a dependency when its producing op (``cand.op``) is
    among the ancestors of ``tensor.op`` found by `all_parents`, which follows
    data inputs only. `tensor` itself is therefore not reported unless its op
    is reachable from its own inputs. The order of `candidates` is preserved.

    Args:
        tensor: The tensor whose dependencies are searched.
        candidates: Tensors to test for being dependencies of `tensor`.

    Returns:
        The sublist of `candidates` that `tensor` depends on.

    Raises:
        ValueError: If walking the graph raises a `RuntimeError` (for example a
            `RecursionError` from a too deeply nested graph).
    """
    try:
        dependent_ops = all_parents(tensor.op)
    except RuntimeError as error:  # RecursionError is a subclass of RuntimeError
        raise ValueError("Tensor too deeply nested, recursion limit exceeded. In the future, "
                         "implementation will be different and any dependents can be found. "
                         "Currently, specify dependents explicitly if needed. "
                         "Original error: {}".format(error)) from error
    return [cand for cand in candidates if cand.op in dependent_ops]