Source code for zfit.util.execution
# Copyright (c) 2019 zfit
import contextlib
import copy
import multiprocessing
import os
import sys
from typing import List
import warnings
import tensorflow as tf
import zfit
from .temporary import TemporarilySet
from .container import DotDict
[docs]class RunManager:
def __init__(self, n_cpu='auto'):
"""Handle the resources and runtime specific options. The `run` method is equivalent to `sess.run`"""
self.MAX_CHUNK_SIZE = sys.maxsize
self.sess = None
self._sess_kwargs = {}
self.chunking = DotDict()
self._cpu = []
self.numeric_checks = True
self.set_n_cpu(n_cpu=n_cpu)
# HACK
self._enable_parameter_autoconversion = True
# HACK END
# set default values
self.chunking.active = False # not yet implemented the chunking...
self.chunking.max_n_points = 1000000
[docs] def auto_initialize(self, variable: tf.Variable):
self(variable.initializer)
@property
def chunksize(self):
if self.chunking.active:
return self.chunking.max_n_points
else:
return self.MAX_CHUNK_SIZE
@property
def n_cpu(self):
return len(self._cpu)
[docs] def set_n_cpu(self, n_cpu='auto'):
if n_cpu == 'auto':
try:
cpu = sorted(os.sched_getaffinity(0))
except AttributeError:
cpu = range(multiprocessing.cpu_count())
warnings.warn("Not running on Linux. Determining available cpus for thread can fail"
"and be overestimated. Workaround (only if too many cpus are used):"
"`zfit.run.set_n_cpu(your_cpu_number)`")
elif isinstance(n_cpu, int):
cpu = range(n_cpu)
self._cpu = ['dummy_cpu{}'.format(i) for i in cpu]
[docs] @contextlib.contextmanager
def aquire_cpu(self, max_cpu: int = -1) -> List[str]:
if isinstance(max_cpu, int):
if max_cpu < 0:
max_cpu = max((self.n_cpu + 1 + max_cpu, 0)) # -1 means all
if max_cpu == 0:
cpu = []
else:
n_cpu = min((max_cpu, self.n_cpu))
cpu = self._cpu[-n_cpu:]
self._cpu = self._cpu[:-n_cpu]
yield cpu
self._cpu.extend(cpu)
def __call__(self, *args, **kwargs):
return self.sess.run(*args, **kwargs)
# def close(self):
# """Closes the current session."""
[docs] def reset(self):
if self._sess is not None:
self.sess.close()
tf.compat.v1.reset_default_graph()
[docs] def create_session(self, *args, close_current=True, reset_graph=False, **kwargs):
"""Create a new session (or replace the current one). Arguments will overwrite the already set arguments.
Args:
close_current (bool): Closes the current open session before replacement. Has no effect if
no session was created before.
reset_graph (bool): Resets the current (default) graph before creating a new :py:class:`tf.compat.v1.Session`.
*args ():
**kwargs ():
Returns:
:py:class:`tf.compat.v1.Session`
"""
sess_kwargs = copy.deepcopy(self._sess_kwargs)
sess_kwargs.update(kwargs)
if close_current and self._sess is not None:
self.sess.close()
if reset_graph:
tf.compat.v1.reset_default_graph()
from zfit.core.parameter import ZfitParameterMixin
ZfitParameterMixin._existing_names = set() # TODO(Mayou36): better hook for reset?
self.sess = tf.compat.v1.Session(*args, **sess_kwargs)
from ..settings import ztypes
tf.compat.v1.get_variable_scope().set_use_resource(True)
tf.compat.v1.get_variable_scope().set_dtype(ztypes.float)
return self.sess
@property
def sess(self):
if self._sess is None:
self.create_session()
return self._sess
@sess.setter
def sess(self, value):
self._sess = value
[docs]class SessionHolderMixin:
def __init__(self, *args, **kwargs):
"""Creates a `self.sess` attribute, a setter `set_sess` (with a fallback to the zfit default session)."""
super().__init__(*args, **kwargs)
self._sess = None
[docs] def set_sess(self, sess: tf.compat.v1.Session):
"""Set the session (temporarily) for this instance. If None, the auto-created default is taken.
Args:
sess (tf.compat.v1.Session):
"""
if not isinstance(sess, tf.compat.v1.Session):
raise TypeError("`sess` has to be a TensorFlow Session but is {}".format(sess))
def getter():
return self._sess # use private attribute! self.sess creates default session
def setter(value):
self.sess = value
return TemporarilySet(value=sess, setter=setter, getter=getter)
@property
def sess(self):
sess = self._sess
if sess is None:
sess = zfit.run.sess
return sess
@sess.setter
def sess(self, sess: tf.compat.v1.Session):
self._sess = sess