# Source code for openmdao.lib.drivers.caseiterdriver

"""
.. _`caseiterdriver.py`:

"""

import logging
import os.path
import Queue
import sys
import thread
import threading
import traceback

from openmdao.main.datatypes.api import Bool, Dict, Enum, Int, Slot

from openmdao.main.api import Driver
from openmdao.main.exceptions import RunStopped, TracedError, traceback_str
from openmdao.main.expreval import ExprEvaluator
from openmdao.main.interfaces import ICaseIterator, ICaseRecorder, ICaseFilter
from openmdao.main.rbac import get_credentials, set_credentials
from openmdao.main.resource import ResourceAllocationManager as RAM
from openmdao.main.resource import LocalAllocator
from openmdao.util.filexfer import filexfer

from openmdao.util.decorators import add_delegate
from openmdao.main.hasparameters import HasParameters

from openmdao.lib.casehandlers.api import ListCaseRecorder

_EMPTY     = 'empty'
_LOADING   = 'loading'
_EXECUTING = 'executing'

class _ServerError(Exception):
    """ Raised when a server thread has problems. """
    pass


class CaseIterDriverBase(Driver):
    """
    A base class for Drivers that run sets of cases in a manner similar
    to the ROSE framework. Concurrent evaluation is supported, with the
    various evaluations executed across servers obtained from the
    :class:`ResourceAllocationManager`.
    """

    sequential = Bool(True, iotype='in',
                      desc='If True, evaluate cases sequentially.')

    reload_model = Bool(True, iotype='in',
                        desc='If True, reload the model between executions.')

    error_policy = Enum(values=('ABORT', 'RETRY'), iotype='in',
                        desc='If ABORT, any error stops the evaluation of the'
                             ' whole set of cases.')

    max_retries = Int(1, low=0, iotype='in',
                      desc='Maximum number of times to retry a failed case.')

    extra_resources = Dict(iotype='in',
                           desc='Extra resource requirements (unusual).')

    ignore_egg_requirements = Bool(False, iotype='in',
                                   desc='If True, no distribution or orphan'
                                        ' requirements will be included in the'
                                        ' generated egg.')

    def __init__(self, *args, **kwargs):
        """Initialize all bookkeeping state used by the case-evaluation machinery."""
        super(CaseIterDriverBase, self).__init__(*args, **kwargs)

        self._iter = None        # Set to None when iterator is empty.
        self._seqno = 0          # Used to set itername for case.
        self._replicants = 0     # Counts egg replications (version tag).
        self._abort_exc = None   # Set if error_policy == ABORT.

        # Egg metadata produced by setup() for concurrent evaluation.
        self._egg_file = None
        self._egg_required_distributions = None
        self._egg_orphan_modules = None

        self._reply_q = None      # Replies from server threads.
        self._server_lock = None  # Lock for server data.

        # Various per-server data keyed by server name.
        self._servers = {}
        self._top_levels = {}
        self._server_info = {}
        self._queues = {}
        self._in_use = {}
        self._server_states = {}
        self._server_cases = {}
        self._exceptions = {}
        self._load_failures = {}

        self._todo = []       # Cases grabbed during server startup.
        self._rerun = []      # Cases that failed and should be retried.
        self._generation = 0  # Used to keep worker names unique.
[docs] def execute(self): """ Runs all cases and records results in `recorder`. Uses :meth:`setup` and :meth:`resume` with default arguments. """ self.setup() self.resume()
    def resume(self, remove_egg=True):
        """
        Resume execution.

        remove_egg: bool
            If True, then the egg file created for concurrent evaluation
            is removed at the end of the run. Re-using the egg file can
            eliminate a lot of startup overhead.
        """
        self._stop = False
        self._abort_exc = None
        # _iter is None only after a run has fully drained the iterator.
        if self._iter is None:
            self.raise_exception('Run already complete', RuntimeError)

        try:
            if self.sequential:
                self._logger.info('Start sequential evaluation.')
                # step() advances one case at a time; StopIteration marks
                # the end of the case set.
                while self._iter is not None:
                    if self._stop:
                        break
                    try:
                        self.step()
                    except StopIteration:
                        break
            else:
                self._logger.info('Start concurrent evaluation.')
                self._start()
        finally:
            # Always tear down per-server state (and optionally the egg),
            # even when a worker raised.
            self._cleanup(remove_egg)

        if self._stop:
            if self._abort_exc is None:
                self.raise_exception('Run stopped', RunStopped)
            else:
                self.raise_exception('Run aborted: %s'
                                     % traceback_str(self._abort_exc),
                                     RuntimeError)
    def step(self):
        """ Evaluate the next case. """
        self._stop = False
        self._abort_exc = None
        if self._iter is None:
            self.setup()

        try:
            case = self._iter.next()
        except StopIteration:
            if not self._rerun:
                self._iter = None
                self._seqno = 0
                raise
        # NOTE(review): if StopIteration fires while self._rerun is
        # non-empty, control appears to fall through here with 'case'
        # unbound, which would raise NameError on the append below —
        # confirm against the pristine (un-scraped) source indentation.
        self._seqno += 1
        self._todo.append((case, self._seqno))
        # server key None means "run locally in this process".
        self._server_cases[None] = None
        self._server_states[None] = _EMPTY
        # Drive the local pseudo-server state machine until it goes idle.
        while self._server_ready(None, stepping=True):
            pass
[docs] def stop(self): """ Stop evaluating cases. """ # Necessary to avoid default driver handling of stop signal. self._stop = True
    def setup(self, replicate=True):
        """
        Setup to begin new run.

        replicate: bool
            If True, then replicate the model and save to an egg file first
            (for concurrent evaluation).
        """
        self._cleanup(remove_egg=replicate)
        if not self.sequential:
            if replicate or self._egg_file is None:
                # Save model to egg.
                # Must do this before creating any locks or queues.
                self._replicants += 1
                version = 'replicant.%d' % (self._replicants)

                # If only local host will be used, we can skip determining
                # distributions required by the egg.
                allocators = RAM.list_allocators()
                need_reqs = False
                if not self.ignore_egg_requirements:
                    for allocator in allocators:
                        if not isinstance(allocator, LocalAllocator):
                            need_reqs = True
                            break

                # Temporarily swap in a plain Driver so the saved egg will
                # execute this driver's workflow exactly once when run.
                driver = self.parent.driver
                self.parent.add('driver', Driver())  # this driver will execute the workflow once
                self.parent.driver.workflow = self.workflow
                try:
                    #egg_info = self.model.save_to_egg(self.model.name, version)
                    # FIXME: what name should we give to the egg?
                    egg_info = self.parent.save_to_egg(self.name, version,
                                                       need_requirements=need_reqs)
                finally:
                    # Restore the real driver regardless of save outcome.
                    self.parent.driver = driver

                self._egg_file = egg_info[0]
                self._egg_required_distributions = egg_info[1]
                self._egg_orphan_modules = [name for name, path in egg_info[2]]

        self._iter = self.get_case_iterator()
        self._seqno = 0
[docs] def get_case_iterator(self): """Returns a new iterator over the Case set.""" raise NotImplementedError('get_case_iterator')
    def _start(self):
        """ Start evaluating cases concurrently. """
        # Need credentials in case we're using a PublicKey server.
        credentials = get_credentials()

        # Determine maximum number of servers available.
        resources = {
            'required_distributions': self._egg_required_distributions,
            'orphan_modules': self._egg_orphan_modules,
            'python_version': sys.version[:3]}
        if self.extra_resources:
            resources.update(self.extra_resources)

        max_servers = RAM.max_servers(resources)
        self._logger.debug('max_servers %d', max_servers)
        if max_servers <= 0:
            msg = 'No servers supporting required resources %s' % resources
            self.raise_exception(msg, RuntimeError)

        # Kick off initial wave of cases.
        self._server_lock = threading.Lock()
        self._reply_q = Queue.Queue()
        self._generation += 1

        n_servers = 0
        while n_servers < max_servers:
            if not self._more_to_go():
                break

            # Get next case. Limits servers started if max_servers > cases.
            try:
                case = self._iter.next()
            except StopIteration:
                if not self._rerun:
                    self._iter = None
                    self._seqno = 0
                    break
            # NOTE(review): as in step(), 'case' looks unbound here when
            # StopIteration fires with pending reruns — confirm against the
            # pristine source.
            self._seqno += 1
            self._todo.append((case, self._seqno))

            # Start server worker thread.
            n_servers += 1
            name = '%s_%d_%d' % (self.name, self._generation, n_servers)
            self._logger.debug('starting worker for %r', name)
            self._servers[name] = None
            self._in_use[name] = True
            self._server_cases[name] = None
            self._server_states[name] = _EMPTY
            self._load_failures[name] = 0
            server_thread = threading.Thread(target=self._service_loop,
                                             args=(name, resources,
                                                   credentials, self._reply_q))
            server_thread.daemon = True
            try:
                server_thread.start()
            except thread.error:
                self._logger.warning('worker thread startup failed for %r',
                                     name)
                self._in_use[name] = False
                break

            if sys.platform != 'win32':
                # Process any pending events.
                while self._busy():
                    try:
                        name, result, exc = self._reply_q.get(True, 0.01)
                    except Queue.Empty:
                        break  # Timeout.
                    else:
                        # Difficult to force startup failure.
                        if self._servers[name] is None:  #pragma nocover
                            self._logger.debug('server startup failed for %r',
                                               name)
                            self._in_use[name] = False
                        else:
                            self._in_use[name] = self._server_ready(name)

        if sys.platform == 'win32':  #pragma no cover
            # Don't start server processing until all servers are started,
            # otherwise we have egg removal issues.
            for name in self._in_use.keys():
                name, result, exc = self._reply_q.get()
                if self._servers[name] is None:
                    self._logger.debug('server startup failed for %r', name)
                    self._in_use[name] = False

            # Kick-off started servers.
            for name in self._in_use.keys():
                if self._in_use[name]:
                    self._in_use[name] = self._server_ready(name)

        # Continue until no servers are busy.
        while self._busy():
            if self._more_to_go():
                timeout = None
            else:
                # Don't wait indefinitely for a server we don't need.
                # This has happened with a server that got 'lost'
                # in RAM.allocate()
                timeout = 60
            try:
                name, result, exc = self._reply_q.get(timeout=timeout)
            # Hard to force worker to hang, which is handled here.
            except Queue.Empty:  #pragma no cover
                # Timed out with nothing left to do: any server not in a
                # known-busy state is considered lost and dropped.
                msgs = []
                for name, in_use in self._in_use.items():
                    if in_use:
                        try:
                            server = self._servers[name]
                            info = self._server_info[name]
                        except KeyError:
                            msgs.append('%r: no startup reply' % name)
                            self._in_use[name] = False
                        else:
                            state = self._server_states[name]
                            if state not in (_LOADING, _EXECUTING):
                                msgs.append('%r: %r %s %s'
                                            % (name, self._servers[name],
                                               state, self._server_info[name]))
                                self._in_use[name] = False
                if msgs:
                    self._logger.error('Timeout waiting with nothing left to do:')
                    for msg in msgs:
                        self._logger.error(' %s', msg)
            else:
                self._in_use[name] = self._server_ready(name)

        # Shut-down (started) servers.
        self._logger.debug('Shut-down (started) servers')
        for queue in self._queues.values():
            queue.put(None)  # None is the worker shutdown sentinel.
        for i in range(len(self._queues)):
            try:
                name, status, exc = self._reply_q.get(True, 60)
            # Hard to force worker to hang, which is handled here.
            except Queue.Empty:  #pragma no cover
                pass
            else:
                if name in self._queues:  # 'Stale' worker can reply *late*.
                    del self._queues[name]
        # Hard to force worker to hang, which is handled here.
        for name in self._queues.keys():  #pragma no cover
            self._logger.warning('Timeout waiting for %r to shut-down.', name)

    def _busy(self):
        """ Return True while at least one server is in use. """
        return any(self._in_use.values())

    def _cleanup(self, remove_egg=True):
        """
        Cleanup internal state, and egg file if necessary.
        Note: this happens unconditionally, so it will cause issues
        for workers which haven't shut down by now.
        """
        self._reply_q = None
        self._server_lock = None

        self._servers = {}
        self._top_levels = {}
        self._server_info = {}
        self._queues = {}
        self._in_use = {}
        self._server_states = {}
        self._server_cases = {}
        self._exceptions = {}
        self._load_failures = {}

        self._todo = []
        self._rerun = []

        if self._egg_file and os.path.exists(self._egg_file):
            os.remove(self._egg_file)
            self._egg_file = None

    def _server_ready(self, server, stepping=False):
        """
        Responds to asynchronous callbacks during :meth:`execute` to run cases
        retrieved from `self._iter`. Results are processed by `recorder`.
        If `stepping`, then we don't grab any new cases.
        Returns True if this server is still in use.

        `server` is a worker name, or None for the local pseudo-server.
        """
        state = self._server_states[server]
        self._logger.debug('server %r state %s', server, state)
        in_use = True

        if state == _EMPTY:
            if server is None or server in self._queues:
                if self._more_to_go(stepping):
                    self._logger.debug(' load_model')
                    self._load_model(server)
                    self._server_states[server] = _LOADING
                else:
                    self._logger.debug(' no more cases')
                    in_use = False
            # Difficult to force startup failure.
            else:  #pragma nocover
                in_use = False  # Never started.

        elif state == _LOADING:
            exc = self._model_status(server)
            if exc is None:
                in_use = self._start_next_case(server, stepping)
            else:
                self._logger.debug(' exception while loading: %r', exc)
                if self.error_policy == 'ABORT':
                    if self._abort_exc is None:
                        self._abort_exc = exc
                    self._stop = True
                    self._server_states[server] = _EMPTY
                    in_use = False
                else:
                    # RETRY policy: tolerate up to two load failures.
                    self._load_failures[server] += 1
                    if self._load_failures[server] < 3:
                        in_use = self._start_processing(server, stepping)
                    else:
                        self._logger.debug(' too many load failures')
                        self._server_states[server] = _EMPTY
                        in_use = False

        elif state == _EXECUTING:
            case, seqno = self._server_cases[server]
            self._server_cases[server] = None
            exc = self._model_status(server)
            if exc is None:
                # Grab the data from the model.
                scope = self.parent if server is None else self._top_levels[server]
                try:
                    case.update_outputs(scope)
                except Exception as exc:
                    msg = 'Exception getting case outputs: %s' % exc
                    self._logger.debug(' %s', msg)
                    case.msg = '%s: %s' % (self.get_pathname(), msg)
            else:
                self._logger.debug(' exception while executing: %r', exc)
                case.msg = str(exc)

            if case.msg is not None and self.error_policy == 'ABORT':
                if self._abort_exc is None:
                    self._abort_exc = exc
                self._stop = True

            # Record the data.
            self._record_case(case, seqno)

            # Set up for next case.
            in_use = self._start_processing(server, stepping, reload=True)

        # Just being defensive, should never happen.
        else:  #pragma no cover
            msg = 'unexpected state %r for server %r' % (state, server)
            self._logger.error(msg)
            if self.error_policy == 'ABORT':
                if self._abort_exc is None:
                    self._abort_exc = RuntimeError(msg)
                self._stop = True
            in_use = False

        return in_use

    def _more_to_go(self, stepping=False):
        """ Return True if there's more work to do. """
        if self._stop:
            return False
        if self._todo or self._rerun:
            return True
        # When stepping we never pull fresh cases from the iterator.
        if not stepping and self._iter is not None:
            return True
        return False

    def _start_processing(self, server, stepping, reload=False):
        """
        If there's something to do, start processing by either loading
        the model, or going straight to running it.
        """
        if self._more_to_go(stepping):
            if reload:
                if self.reload_model:
                    self._logger.debug(' reload')
                    self._load_model(server)
                    self._server_states[server] = _LOADING
                    in_use = True
                else:
                    in_use = self._start_next_case(server)
            else:
                self._logger.debug(' load')
                self._load_model(server)
                self._server_states[server] = _LOADING
                in_use = True
        else:
            self._logger.debug(' no more cases')
            self._server_states[server] = _EMPTY
            in_use = False
        return in_use

    def _start_next_case(self, server, stepping=False):
        """ Look for the next case and start it. """
        # Priority: startup backlog, then retries, then fresh cases.
        if self._todo:
            self._logger.debug(' run startup case')
            case, seqno = self._todo.pop(0)
            in_use = self._run_case(case, seqno, server)
        elif self._rerun:
            self._logger.debug(' rerun case')
            case, seqno = self._rerun.pop(0)
            in_use = self._run_case(case, seqno, server, rerun=True)
        elif self._iter is None:
            self._logger.debug(' no more cases')
            in_use = False
        elif stepping:
            in_use = False
        else:
            try:
                case = self._iter.next()
            except StopIteration:
                self._logger.debug(' no more cases')
                self._iter = None
                self._seqno = 0
                in_use = False
            else:
                self._logger.debug(' run next case')
                self._seqno += 1
                in_use = self._run_case(case, self._seqno, server)
        return in_use

    def _run_case(self, case, seqno, server, rerun=False):
        """ Setup and start a case. Returns True if started. """
        if not rerun:
            if not case.max_retries:
                case.max_retries = self.max_retries
            case.retries = 0
        case.msg = None
        case.parent_uuid = self._case_id

        # Additional user-requested variables
        # These must be added here so that the outputs are in the cases
        # before they are in the server list.
        for printvar in self.printvars:
            if '*' in printvar:
                printvars = self._get_all_varpaths(printvar)
            else:
                printvars = [printvar]
            for var in printvars:
                val = ExprEvaluator(var, scope=self.parent).evaluate()
                case.add_output(var, val)

        try:
            for event in self.get_events():
                try:
                    self._model_set(server, event, None, True)
                except Exception as exc:
                    msg = 'Exception setting %r: %s' % (event, exc)
                    self._logger.debug(' %s', msg)
                    self.raise_exception(msg, _ServerError)
            try:
                scope = self.parent if server is None else self._top_levels[server]
                case.apply_inputs(scope)
            except Exception as exc:
                msg = 'Exception setting case inputs: %s' % exc
                self._logger.debug(' %s', msg)
                self.raise_exception(msg, _ServerError)
            self._server_cases[server] = (case, seqno)
            self._model_execute(server)
            self._server_states[server] = _EXECUTING
        except _ServerError as exc:
            # Case setup failed: record (possibly queueing a retry) and
            # immediately look for other work for this server.
            case.msg = str(exc)
            self._record_case(case, seqno)
            return self._start_processing(server, stepping=False)
        else:
            return True

    def _record_case(self, case, seqno):
        """ If successful, record the case. Otherwise possibly retry. """
        if case.msg and case.retries < case.max_retries:
            case.msg = None
            case.retries += 1
            self._rerun.append((case, seqno))
        else:
            for recorder in self.recorders:
                recorder.record(case)

    def _service_loop(self, name, resource_desc, credentials, reply_q):
        """ Each server has an associated thread executing this. """
        set_credentials(credentials)

        server, server_info = RAM.allocate(resource_desc)
        # Just being defensive, this should never happen.
        if server is None:  #pragma no cover
            self._logger.error('Server allocation for %r failed :-(', name)
            reply_q.put((name, False, None))
            return
        else:
            # Clear egg re-use indicator.
            server_info['egg_file'] = None
            self._logger.debug('%r using %r', name, server_info['name'])
            if self._logger.level == logging.NOTSET:
                # By default avoid lots of protocol messages.
                # NOTE(review): setting DEBUG here looks inconsistent with
                # the comment above (DEBUG is *more* verbose) — confirm
                # against the pristine source.
                server.set_log_level(logging.DEBUG)
            else:
                server.set_log_level(self._logger.level)

        request_q = Queue.Queue()

        try:
            with self._server_lock:
                self._servers[name] = server
                self._server_info[name] = server_info
                self._queues[name] = request_q

            reply_q.put((name, True, None))  # ACK startup.

            # Serve (callable, arg) requests until the None sentinel arrives.
            while True:
                request = request_q.get()
                if request is None:
                    break
                try:
                    result = request[0](request[1])
                except Exception as req_exc:
                    self._logger.error('%r: %s caused %r',
                                       name, request[0], req_exc)
                    result = None
                else:
                    req_exc = None
                reply_q.put((name, result, req_exc))
        except Exception as exc:  # pragma no cover
            # This can easily happen if we take a long time to allocate and
            # we get 'cleaned-up' before we get started.
            if self._server_lock is not None:
                self._logger.error('%r: %r', name, exc)
        finally:
            self._logger.debug('%r releasing server', name)
            RAM.release(server)
            reply_q.put((name, True, None))  # ACK shutdown.

    def _load_model(self, server):
        """ Load a model into a server. """
        self._exceptions[server] = None
        # Local (None) pseudo-server needs no load; remote load is queued.
        if server is not None:
            self._queues[server].put((self._remote_load_model, server))

    def _remote_load_model(self, server):
        """ Load model into remote server. """
        egg_file = self._server_info[server].get('egg_file', None)
        if egg_file is None or egg_file is not self._egg_file:
            # Only transfer if changed.
            try:
                filexfer(None, self._egg_file,
                         self._servers[server], self._egg_file, 'b')
            # Difficult to force model file transfer error.
            except Exception as exc:  #pragma nocover
                self._logger.error('server %r filexfer of %r failed: %r',
                                   server, self._egg_file, exc)
                self._top_levels[server] = None
                self._exceptions[server] = TracedError(exc,
                                                       traceback.format_exc())
                return
            else:
                self._server_info[server]['egg_file'] = self._egg_file
        try:
            tlo = self._servers[server].load_model(self._egg_file)
        # Difficult to force load error.
        except Exception as exc:  #pragma nocover
            self._logger.error('server.load_model of %r failed: %r',
                               self._egg_file, exc)
            self._top_levels[server] = None
            self._exceptions[server] = TracedError(exc, traceback.format_exc())
        else:
            self._top_levels[server] = tlo

    def _model_set(self, server, name, index, value):
        """ Set value in server's model. """
        if server is None:
            self.parent.set(name, value, index)
        else:
            self._top_levels[server].set(name, value, index)

    def _model_execute(self, server):
        """ Execute model in server. """
        self._exceptions[server] = None
        if server is None:
            # Local execution happens inline; errors are captured, not raised.
            try:
                self.workflow.run(case_id=self._server_cases[server][0].uuid)
            except Exception as exc:
                self._exceptions[server] = TracedError(exc,
                                                       traceback.format_exc())
                self._logger.critical('Caught exception: %r' % exc)
        else:
            self._queues[server].put((self._remote_model_execute, server))

    def _remote_model_execute(self, server):
        """ Execute model in remote server. """
        case, seqno = self._server_cases[server]
        try:
            self._top_levels[server].set_itername(self.get_itername(), seqno)
            self._top_levels[server].run(case_id=case.uuid)
        except Exception as exc:
            self._exceptions[server] = TracedError(exc, traceback.format_exc())
            self._logger.error('Caught exception from server %r, PID %d on %s: %r',
                               self._server_info[server]['name'],
                               self._server_info[server]['pid'],
                               self._server_info[server]['host'], exc)

    def _model_status(self, server):
        """ Return execute status from model. """
        return self._exceptions[server]
class CaseIteratorDriver(CaseIterDriverBase):
    """
    Run a set of cases provided by an :class:`ICaseIterator`. Concurrent
    evaluation is supported, with the various evaluations executed across
    servers obtained from the :class:`ResourceAllocationManager`.
    """

    iterator = Slot(ICaseIterator, iotype='in',
                    desc='Iterator supplying Cases to evaluate.')

    evaluated = Slot(ICaseIterator, iotype='out',
                     desc='Iterator supplying evaluated Cases.')

    filter = Slot(ICaseFilter, iotype='in',
                  desc='Filter used to select cases to evaluate.')
[docs] def get_case_iterator(self): """Returns a new iterator over the Case set.""" if self.iterator is not None: if self.filter is None: return iter(self.iterator) else: return self._select_cases() else: self.raise_exception("iterator has not been set", ValueError)
def _select_cases(self): """ Select cases to be evaluated. """ for i, case in enumerate(iter(self.iterator)): if self.filter.select(i, case): yield case
[docs] def execute(self): """ Evaluate cases from `iterator` and place in `evaluated`. """ self.evaluated = None self.recorders.append(ListCaseRecorder()) try: super(CaseIteratorDriver, self).execute() finally: self.evaluated = self.recorders.pop().get_iterator()
# OpenMDAO Home