# Source code for openmdao.recorders.sqlite_reader

"""
Definition of the SqliteCaseReader.
"""
from __future__ import print_function, absolute_import

import sqlite3
from collections import OrderedDict

from six import PY2, PY3, iteritems, string_types
from six.moves import range

import numpy as np

from openmdao.recorders.base_case_reader import BaseCaseReader
from openmdao.recorders.case import Case, PromAbsDict

from openmdao.utils.general_utils import simple_warning
from openmdao.utils.record_util import check_valid_sqlite3_db, get_source_system

from openmdao.recorders.sqlite_recorder import format_version

if PY2:
    import cPickle as pickle
    from openmdao.utils.general_utils import json_loads_byteified as json_loads
elif PY3:
    import pickle
    from json import loads as json_loads


class SqliteCaseReader(BaseCaseReader):
    """
    A CaseReader specific to files created with SqliteRecorder.

    Attributes
    ----------
    problem_metadata : dict
        Metadata about the problem, including the system hierarchy and connections.
    solver_metadata : dict
        The solver options for each solver in the recorded model.
    system_metadata : dict
        Metadata about each system in the recorded model, including options and scaling factors.
    _format_version : int
        The version of the format assumed when loading the file.
    _filename : str
        The path to the filename containing the recorded data.
    _abs2meta : dict
        Dictionary mapping variables to their metadata
    _abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    _prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    _output2meta : dict
        Dictionary mapping output variables to their metadata
    _input2meta : dict
        Dictionary mapping input variables to their metadata
    _driver_cases : DriverCases
        Helper object for accessing cases from the driver_iterations table.
    _system_cases : SystemCases
        Helper object for accessing cases from the system_iterations table.
    _solver_cases : SolverCases
        Helper object for accessing cases from the solver_iterations table.
    _problem_cases : ProblemCases or None
        Helper object for accessing cases from the problem_cases table
        (None when the file's format version is less than 2).
    _global_iterations : list
        List of iteration cases and the table and row in which they are found.
    """

    def __init__(self, filename, pre_load=False):
        """
        Initialize.

        Parameters
        ----------
        filename : str
            The path to the filename containing the recorded data.
        pre_load : bool
            If True, load all the data into memory during initialization.
        """
        super(SqliteCaseReader, self).__init__(filename, pre_load)

        check_valid_sqlite3_db(filename)

        # initialize private attributes
        self._filename = filename
        self._abs2prom = None
        self._prom2abs = None
        self._abs2meta = None
        self._output2meta = None
        self._input2meta = None
        self._global_iterations = None

        # collect metadata from database; the connection is closed in a finally
        # clause because sqlite3's context manager only manages the transaction
        con = sqlite3.connect(filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()

            # collect data from the metadata table. this includes:
            #   format_version
            #   VOI metadata, which is added to problem_metadata
            #   var name maps and metadata for all vars, which are saved as private attributes
            self._collect_metadata(cur)

            # collect data from the driver_metadata table. this includes:
            #   model viewer data, which is added to problem_metadata
            self._collect_driver_metadata(cur)

            # collect data from the system_metadata table. this includes:
            #   component metadata and scaling factors for each system,
            #   which is added to system_metadata
            self._collect_system_metadata(cur)

            # collect data from the solver_metadata table. this includes:
            #   solver class and options for each solver, which is saved as an attribute
            self._collect_solver_metadata(cur)

            # get the global iterations table, and save it as an attribute
            self._global_iterations = self._get_global_iterations(cur)
        finally:
            con.close()

        # create maps to facilitate accessing variable metadata using absolute or promoted name
        self._output2meta = PromAbsDict(self._abs2meta, self._prom2abs, self._abs2prom, 1)
        self._input2meta = PromAbsDict(self._abs2meta, self._prom2abs, self._abs2prom, 0)

        # create helper objects for accessing cases from the three iteration tables and
        # the problem cases table
        var_info = self.problem_metadata['variables']
        table_args = (filename, self._format_version, self._global_iterations,
                      self._prom2abs, self._abs2prom, self._abs2meta, var_info)

        self._driver_cases = DriverCases(*table_args)
        self._system_cases = SystemCases(*table_args)
        self._solver_cases = SolverCases(*table_args)
        if self._format_version >= 2:
            self._problem_cases = ProblemCases(*table_args)
        else:
            # always define the attribute so lookups on old files fail cleanly
            self._problem_cases = None

        # if requested, load all the iteration data into memory
        if pre_load:
            self._load_cases()

    def _collect_metadata(self, cur):
        """
        Load data from the metadata table.

        Populates the `format_version` attribute and the `variables` data in
        the `problem_metadata` attribute of this CaseReader. Also saves the
        variable name maps and variable metadata to private attributes.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.

        Raises
        ------
        ValueError
            If the recorded format version is not handled by this reader, or
            (Python 2 only) if the pickled maps use an unsupported protocol.
        """
        cur.execute('select * from metadata')
        row = cur.fetchone()

        # get format_version
        self._format_version = version = row['format_version']

        if version not in range(1, format_version + 1):
            raise ValueError('SQliteCaseReader encountered an unhandled '
                             'format version: {0}'.format(self._format_version))

        # add metadata for VOIs (des vars, objective, constraints) to problem metadata
        if version >= 4:
            self.problem_metadata['variables'] = json_loads(row['var_settings'])
        else:
            self.problem_metadata['variables'] = None

        # get variable name maps and metadata for all variables
        if version >= 3:
            self._abs2prom = json_loads(row['abs2prom'])
            self._prom2abs = json_loads(row['prom2abs'])
            self._abs2meta = json_loads(row['abs2meta'])

            # need to convert bounds to numpy arrays
            for _name, meta in iteritems(self._abs2meta):
                for bound in ('lower', 'upper'):
                    if bound in meta and meta[bound] is not None:
                        meta[bound] = np.resize(np.array(meta[bound]), meta['shape'])

        elif version in (1, 2):
            # NOTE: versions 1 and 2 store these maps as pickles. Unpickling can
            # execute arbitrary code, so only open recordings from trusted sources.
            abs2prom = row['abs2prom']
            prom2abs = row['prom2abs']
            abs2meta = row['abs2meta']

            if PY2:
                try:
                    self._abs2prom = pickle.loads(str(abs2prom))
                    self._prom2abs = pickle.loads(str(prom2abs))
                    self._abs2meta = pickle.loads(str(abs2meta))
                except ValueError as err:
                    if err.message.startswith('unsupported pickle protocol'):
                        raise ValueError("This data appears to have been recorded with "
                                         "Python 3 and cannot be read with Python 2 "
                                         "(%s)." % err.message)
                    else:
                        raise

            if PY3:
                try:
                    self._abs2prom = pickle.loads(abs2prom)
                    self._prom2abs = pickle.loads(prom2abs)
                    self._abs2meta = pickle.loads(abs2meta)
                except TypeError:
                    # Reading in a python 2 pickle recorded pre-OpenMDAO 2.4.
                    self._abs2prom = pickle.loads(abs2prom.encode())
                    self._prom2abs = pickle.loads(prom2abs.encode())
                    self._abs2meta = pickle.loads(abs2meta.encode())

    def _collect_driver_metadata(self, cur):
        """
        Load data from the driver_metadata table.

        Populates the `problem_metadata` attribute of this CaseReader.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.
        """
        cur.execute("SELECT model_viewer_data FROM driver_metadata")
        row = cur.fetchone()

        if row is not None:
            if self._format_version >= 3:
                driver_metadata = json_loads(row[0])
            elif self._format_version in (1, 2):
                # NOTE: pickled data; only read files from trusted sources.
                if PY2:
                    driver_metadata = pickle.loads(str(row[0]))
                if PY3:
                    driver_metadata = pickle.loads(row[0])

            self.problem_metadata.update(driver_metadata)

    def _collect_system_metadata(self, cur):
        """
        Load data from the system table.

        Populates the `system_metadata` attribute of this CaseReader.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.
        """
        cur.execute("SELECT id, scaling_factors, component_metadata FROM system_metadata")
        for row in cur:
            system_id = row[0]
            # NOTE: pickled data; only read files from trusted sources.
            if PY2:
                scaling_blob, options_blob = str(row[1]), str(row[2])
            if PY3:
                scaling_blob, options_blob = row[1], row[2]

            self.system_metadata[system_id] = {}
            self.system_metadata[system_id]['scaling_factors'] = pickle.loads(scaling_blob)
            self.system_metadata[system_id]['component_options'] = pickle.loads(options_blob)

    def _collect_solver_metadata(self, cur):
        """
        Load data from the solver_metadata table.

        Populates the `solver_metadata` attribute of this CaseReader.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.
        """
        cur.execute("SELECT id, solver_options, solver_class FROM solver_metadata")
        for row in cur:
            solver_id = row[0]
            # NOTE: pickled data; only read files from trusted sources.
            if PY2:
                solver_options = pickle.loads(str(row[1]))
            if PY3:
                solver_options = pickle.loads(row[1])
            solver_class = row[2]
            self.solver_metadata[solver_id] = {
                'solver_options': solver_options,
                'solver_class': solver_class,
            }

    def _get_global_iterations(self, cur):
        """
        Get the global iterations table.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.

        Returns
        -------
        list
            List of global iterations and the table and row where the associated case is found.
        """
        cur.execute('select * from global_iterations')
        return cur.fetchall()

    def _load_cases(self):
        """
        Load all driver, solver, and system cases into memory.
        """
        self._driver_cases._load_cases()
        self._solver_cases._load_cases()
        self._system_cases._load_cases()
        if self._format_version >= 2:
            self._problem_cases._load_cases()

    def list_sources(self):
        """
        List of all the different recording sources for which there is recorded data.

        Returns
        -------
        list
            One or more of: `problem`, `driver`, `<system hierarchy location>`,
            `<solver hierarchy location>`
        """
        sources = []

        if self._driver_cases.count() > 0:
            sources.extend(self._driver_cases.list_sources())
        if self._solver_cases.count() > 0:
            sources.extend(self._solver_cases.list_sources())
        if self._system_cases.count() > 0:
            sources.extend(self._system_cases.list_sources())
        if self._format_version >= 2 and self._problem_cases.count() > 0:
            sources.extend(self._problem_cases.list_sources())

        return sources

    def list_source_vars(self, source):
        """
        List of all inputs and outputs recorded by the specified source.

        Parameters
        ----------
        source : {'problem', 'driver', <system hierarchy location>, <solver hierarchy location>}
            Identifies the source for which to return information.

        Returns
        -------
        dict
            {'inputs':[list of keys], 'outputs':[list of keys]}. Does not recurse.

        Raises
        ------
        RuntimeError
            If the source is unknown or no cases were recorded for it.
        """
        dct = {
            'inputs': [],
            'outputs': [],
        }

        case = None

        if source == 'problem':
            # problem cases only exist from format version 2 onward
            if self._format_version >= 2 and self._problem_cases.count() > 0:
                case = self._problem_cases.get_case(0)
        elif source == 'driver':
            if self._driver_cases.count() > 0:
                case = self._driver_cases.get_case(0)
        elif source in self._system_cases.list_sources():
            source_cases = self._system_cases.list_cases(source)
            case = self._system_cases.get_case(source_cases[0])
        elif source in self._solver_cases.list_sources():
            source_cases = self._solver_cases.list_cases(source)
            case = self._solver_cases.get_case(source_cases[0])
        else:
            raise RuntimeError('Source not found: %s' % source)

        if case is None:
            raise RuntimeError('No cases recorded for %s' % source)

        if case.inputs:
            dct['inputs'] = list(case.inputs)
        if case.outputs:
            dct['outputs'] = list(case.outputs)

        return dct

    def list_cases(self, source=None, recurse=True, flat=True):
        """
        List the IDs of Driver, Solver and System cases in order.

        Parameters
        ----------
        source : {'problem', 'driver', <system hierarchy location>, <solver hierarchy location>,
            case name}
            If not None, only cases originating from the specified source or case are returned.
        recurse : bool, optional
            If True, will enable iterating over all successors in case hierarchy.
        flat : bool, optional
            If False and there are child cases, then a nested ordered dictionary
            is returned rather than a list.

        Returns
        -------
        list or dict
            A list or a nested dictionary of identified cases.

        Raises
        ------
        TypeError
            If source is not a string.
        RuntimeError
            If the source cannot be found, or a nested dictionary of all cases
            is requested but neither the driver nor the model was recorded.
        """
        # if source was not specified, return all cases
        if source is None:
            if flat:
                source = ''
            elif self._driver_cases.count() > 0:
                source = 'driver'
            elif 'root' in self._system_cases.list_sources():
                source = 'root'
            else:
                # if there are no driver or model cases, then we need
                # another starting point to build the nested dict.
                raise RuntimeError("A nested dictionary of all cases was requested, but "
                                   "neither the driver or the model was recorded. Please "
                                   "specify another source (system or solver) for the cases "
                                   "you want to see.")

        if not isinstance(source, string_types):
            raise TypeError("Source parameter must be a string, %s is type %s." %
                            (source, type(source).__name__))

        if not source:
            return self._list_cases_recurse_flat()

        if source == 'problem':
            if self._format_version >= 2:
                return self._problem_cases.list_cases()
            raise RuntimeError('No problem cases recorded (data format = %d).' %
                               self._format_version)

        # figure out which table has cases from the source
        if source == 'driver':
            case_table = self._driver_cases
        elif source in self._system_cases.list_sources():
            case_table = self._system_cases
        elif source in self._solver_cases.list_sources():
            case_table = self._solver_cases
        else:
            case_table = None

        if case_table is not None:
            if not recurse:
                # return list of cases from the source alone
                return case_table.list_cases(source)
            elif flat:
                # return list of cases from the source plus child cases
                cases = []
                for case in case_table.get_cases(source):
                    cases += self._list_cases_recurse_flat(case.name)
                return cases
            else:
                # return nested dict of cases from the source and child cases
                cases = OrderedDict()
                for case in case_table.get_cases(source):
                    cases.update(self._list_cases_recurse_nested(case.name))
                return cases
        elif '|' in source:
            # source is a coordinate
            # NOTE(review): with recurse=False this branch falls through and the
            # method returns None; behavior is kept as-is, confirm what is intended.
            if recurse:
                if flat:
                    return self._list_cases_recurse_flat(source)
                else:
                    return self._list_cases_recurse_nested(source)
        else:
            raise RuntimeError('Source not found: %s' % source)

    def _list_cases_recurse_flat(self, coord=None):
        """
        List recursively the IDs of Driver, Solver and System cases in order.

        Parameters
        ----------
        coord : an iteration coordinate
            Identifies the parent of the cases to return.

        Returns
        -------
        list
            The IDs of the identified cases, in global iteration order.

        Raises
        ------
        RuntimeError
            If the coordinate matches no recorded case, or the global
            iterations table names an unexpected table.
        """
        solver_cases = self._solver_cases.list_cases()
        system_cases = self._system_cases.list_cases()
        driver_cases = self._driver_cases.list_cases()

        global_iters = self._global_iterations

        if not coord:
            # will return all cases
            coord = ''
            parent_case_counter = len(global_iters)
        elif coord in driver_cases:
            parent_case_counter = self._driver_cases.get_case(coord).counter
        elif coord in system_cases:
            parent_case_counter = self._system_cases.get_case(coord).counter
        elif coord in solver_cases:
            parent_case_counter = self._solver_cases.get_case(coord).counter
        else:
            raise RuntimeError('Case not found for coordinate: %s' % coord)

        keys_by_table = {
            'solver': solver_cases,
            'system': system_cases,
            'driver': driver_cases,
        }

        cases = []

        # return all cases in the global iteration table that precede the given case
        # and whose coordinate is prefixed by the given coordinate
        for i in range(0, parent_case_counter):
            global_iter = global_iters[i]
            table, row = global_iter[1], global_iter[2]

            if table not in keys_by_table:
                raise RuntimeError('Unexpected table name in global iterations: %s' % table)

            # rows in the global iterations table are 1-based
            case_coord = keys_by_table[table][row - 1]
            if case_coord.startswith(coord):
                cases.append(case_coord)

        return cases

    def _list_cases_recurse_nested(self, coord=None):
        """
        List recursively the IDs of Driver, Solver and System cases as a nested dict.

        Parameters
        ----------
        coord : an iteration coordinate
            Identifies the parent of the cases to return.

        Returns
        -------
        dict
            A nested dictionary of identified cases.

        Raises
        ------
        RuntimeError
            If the coordinate matches no recorded case.
        """
        solver_cases = self._solver_cases.list_cases()
        system_cases = self._system_cases.list_cases()
        driver_cases = self._driver_cases.list_cases()

        global_iters = self._global_iterations

        if coord in driver_cases:
            parent_case = self._driver_cases.get_case(coord)
        elif coord in system_cases:
            parent_case = self._system_cases.get_case(coord)
        elif coord in solver_cases:
            parent_case = self._solver_cases.get_case(coord)
        else:
            raise RuntimeError('Case not found for coordinate: %s' % coord)

        cases = OrderedDict()
        children = OrderedDict()
        cases[parent_case.name] = children

        # return all cases in the global iteration table that precede the given case
        # and whose coordinate is prefixed by the given coordinate
        for i in range(0, parent_case.counter - 1):
            global_iter = global_iters[i]
            table, row = global_iter[1], global_iter[2]

            # only solver and system cases can be children; other tables are skipped
            if table == 'solver':
                case_coord = solver_cases[row - 1]
            elif table == 'system':
                case_coord = system_cases[row - 1]
            else:
                continue

            if case_coord.startswith(coord):
                # a direct child has the parent's coordinate followed by one name|count pair
                parent_coord = '|'.join(case_coord.split('|')[:-2])
                if parent_coord == coord:
                    children.update(self._list_cases_recurse_nested(case_coord))

        return cases

    def get_cases(self, source=None, recurse=True, flat=True):
        """
        Get the cases identified by source.

        Parameters
        ----------
        source : {'problem', 'driver', component pathname, solver pathname, case_name}
            Identifies which cases to return.
        recurse : bool, optional
            If True, will enable iterating over all successors in case hierarchy
        flat : bool, optional
            If False and there are child cases, then a nested ordered dictionary
            is returned rather than a list.

        Returns
        -------
        list or dict
            The cases identified by source
        """
        case_ids = self.list_cases(source, recurse, flat)
        if isinstance(case_ids, list):
            return [self.get_case(case_id) for case_id in case_ids]
        else:
            return self._get_cases_nested(case_ids, OrderedDict())

    def _get_cases_nested(self, case_ids, cases):
        """
        Populate a nested dictionary of cases matching the provided dictionary of case IDs.

        Parameters
        ----------
        case_ids : OrderedDict
            The nested dictionary of case IDs.
        cases : OrderedDict
            The nested dictionary of cases.

        Returns
        -------
        OrderedDict
            The nested dictionary of cases with cases added from case_ids.
        """
        for case_id in case_ids:
            case = self.get_case(case_id)
            # an empty child mapping simply yields an empty OrderedDict
            cases[case] = self._get_cases_nested(case_ids[case_id], OrderedDict())

        return cases

    def get_case(self, case_id, recurse=False):
        """
        Get case identified by case_id.

        Parameters
        ----------
        case_id : str or int
            The unique identifier of the case to return or an index into all cases.
        recurse : bool, optional
            If True, will return all successors to the case as well.

        Returns
        -------
        Case, list or dict
            The case identified by case_id, or the result of `get_cases` for
            that case when recurse is True.

        Raises
        ------
        IndexError
            If an integer case_id is beyond the end of the global iterations table.
        RuntimeError
            If no table contains the requested case.
        """
        if isinstance(case_id, int):
            # it's a global index rather than a coordinate
            global_iters = self._global_iterations

            if case_id > len(global_iters) - 1:
                raise IndexError("Invalid index into available cases: %s" % case_id)

            global_iter = global_iters[case_id]
            table, row = global_iter[1], global_iter[2]

            # NOTE(review): an index that lands on any other table is left as an
            # int and is then resolved per-table below; confirm that is intended.
            if table == 'solver':
                case_id = self._solver_cases.list_cases()[row - 1]
            elif table == 'system':
                case_id = self._system_cases.list_cases()[row - 1]
            elif table == 'driver':
                case_id = self._driver_cases.list_cases()[row - 1]

        if recurse:
            return self.get_cases(case_id, recurse=True)

        tables = [self._driver_cases, self._system_cases, self._solver_cases]
        if self._format_version >= 2:
            tables.append(self._problem_cases)

        for table in tables:
            case = table.get_case(case_id)
            if case:
                return case

        raise RuntimeError('Case not found: %s' % case_id)
class CaseTable(object):
    """
    Base class for wrapping case tables in a recording database.

    Attributes
    ----------
    _filename : str
        The name of the recording file from which to instantiate the case reader.
    _format_version : int
        The version of the format assumed when loading the file.
    _table_name : str
        The name of the table in the database.
    _index_name : str
        The name of the case index column in the table.
    _global_iterations : list
        List of iteration cases and the table and row in which they are found.
    _abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    _abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    _prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    _var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).
    _sources : set or None
        Cached set of sources of cases in the table (None until first requested).
    _keys : list or None
        Cached list of keys of cases in the table (None until first requested).
    _cases : dict
        Dictionary mapping keys to cases that have already been loaded.
    """

    def __init__(self, fname, ver, table, index, giter, prom2abs, abs2prom, abs2meta, var_info):
        """
        Initialize.

        Parameters
        ----------
        fname : str
            The name of the recording file from which to instantiate the case reader.
        ver : int
            The version of the format assumed when loading the file.
        table : str
            The name of the table in the database.
        index : str
            The name of the case index column in the table.
        giter : list of tuple
            The global iterations table.
        prom2abs : {'input': dict, 'output': dict}
            Dictionary mapping promoted names to absolute names.
        abs2prom : {'input': dict, 'output': dict}
            Dictionary mapping absolute names to promoted names.
        abs2meta : dict
            Dictionary mapping absolute variable names to variable metadata.
        var_info : dict
            Dictionary with information about variables (scaling, indices, execution order).
        """
        self._filename = fname
        self._format_version = ver
        self._table_name = table
        self._index_name = index
        self._global_iterations = giter
        self._prom2abs = prom2abs
        self._abs2prom = abs2prom
        self._abs2meta = abs2meta
        self._var_info = var_info

        # cached keys/cases
        self._sources = None
        self._keys = None
        self._cases = {}

    def count(self):
        """
        Get the number of cases recorded in the table.

        Returns
        -------
        int
            The number of cases recorded in the table.
        """
        con = sqlite3.connect(self._filename)
        try:
            cur = con.cursor()
            cur.execute("SELECT count(*) FROM %s" % self._table_name)
            rows = cur.fetchall()
        finally:
            con.close()

        return rows[0][0]

    def list_cases(self, source=None):
        """
        Get list of case IDs for cases in the table.

        Parameters
        ----------
        source : str, optional
            A source of cases or the iteration coordinate of a case.
            If not None, only cases originating from the specified source or case are returned.

        Returns
        -------
        list
            The cases from the table from the specified source or parent case.
        """
        # `is None` (rather than falsiness) so that an empty table is also cached
        if self._keys is None:
            con = sqlite3.connect(self._filename)
            try:
                cur = con.cursor()
                cur.execute("SELECT %s FROM %s ORDER BY id ASC" %
                            (self._index_name, self._table_name))
                rows = cur.fetchall()
            finally:
                con.close()

            # cache case list for future use
            self._keys = [row[0] for row in rows]

        if not source:
            # return all cases
            return self._keys
        elif '|' in source:
            # source is a coordinate
            return [key for key in self._keys if key.startswith(source)]
        else:
            # source is a system or solver
            return [key for key in self._keys if self._get_source(key) == source]

    def get_cases(self, source=None, recurse=False, flat=False):
        """
        Get the cases in the table, optionally restricted to a source.

        Parameters
        ----------
        source : str, optional
            If not None, only cases that have the specified source will be returned
        recurse : bool, optional
            If True, will enable iterating over all successors in case hierarchy
        flat : bool, optional
            If False and there are child cases, then a nested ordered dictionary
            is returned rather than a list.

        Returns
        -------
        list or dict
            The cases from the table that have the specified source.
        """
        if self._keys is None:
            self.list_cases()

        if not source:
            # return all cases
            return [self.get_case(key) for key in self._keys]

        if '|' in source:
            # source is a coordinate
            if recurse and not flat:
                cases = OrderedDict()
                for key in self._keys:
                    # strictly longer keys with this prefix are descendants
                    if len(key) > len(source) and key.startswith(source):
                        cases[key] = self.get_cases(key, recurse, flat)
                return cases
            return [self.get_case(key) for key in self._keys if key.startswith(source)]

        # source is a system or solver
        if not recurse:
            return [self.get_case(key) for key in self._keys
                    if self._get_source(key) == source]

        if flat:
            # return all cases under the source system
            source_sys = source.replace('.nonlinear_solver', '')
            return [self.get_case(key) for key in self._keys
                    if get_source_system(key).startswith(source_sys)]

        cases = OrderedDict()
        for key in self._keys:
            if self._get_source(key) == source:
                cases[key] = self.get_cases(key, recurse, flat)
        return cases

    def get_case(self, case_id, cache=False):
        """
        Get a case from the database.

        Parameters
        ----------
        case_id : str or int
            The string-identifier of the case to be retrieved or the index of the case.
        cache : bool
            If True, case will be cached for faster access by key.

        Returns
        -------
        Case or None
            The specified case from the table, or None if it is not in the table.
        """
        # an integer is an index into this table's keys
        if isinstance(case_id, int):
            case_id = self._get_iteration_coordinate(case_id)

        # if we've already cached this case, return the cached instance
        if case_id in self._cases:
            return self._cases[case_id]

        # we don't have it, so fetch it
        con = sqlite3.connect(self._filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()
            # identifiers cannot be bound, so the table and column names (set by
            # the constructor) are interpolated; the caller-supplied case id is
            # bound as a parameter so quotes in it cannot break the statement
            cur.execute("SELECT * FROM %s WHERE %s=?" %
                        (self._table_name, self._index_name), (case_id,))
            row = cur.fetchone()
        finally:
            con.close()

        if row is None:
            return None

        # extract the data and optionally cache the Case
        if self._format_version >= 5:
            source = self._get_row_source(row['id'])

            # check for situations where parsing the iter coord doesn't work correctly
            iter_source = self._get_source(row[self._index_name])
            if iter_source != source:
                simple_warning('Mismatched source for %d: %s = %s vs %s' %
                               (row['id'], row[self._index_name], iter_source, source))
        else:
            source = self._get_source(row[self._index_name])

        case = Case(source, row, self._prom2abs, self._abs2prom, self._abs2meta,
                    self._var_info, self._format_version)

        # cache it if requested
        if cache:
            self._cases[case_id] = case

        return case

    def _get_iteration_coordinate(self, case_idx):
        """
        Return the iteration coordinate for the indexed case (handles negative indices, etc.).

        Parameters
        ----------
        case_idx : int
            The case number that we want the iteration coordinate for.

        Returns
        -------
        iteration_coordinate : str
            The iteration coordinate.
        """
        # if keys have not been cached yet, get them now
        if self._keys is None:
            self.list_cases()

        return self._keys[case_idx]

    def cases(self, cache=False):
        """
        Iterate over all cases, optionally caching them into memory.

        Parameters
        ----------
        cache : bool
            If True, cases will be cached for faster access by key.

        Yields
        ------
        Case
            Each case in the table, ordered by row id.
        """
        con = sqlite3.connect(self._filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()
            cur.execute("SELECT * FROM %s ORDER BY id ASC" % self._table_name)

            for row in cur:
                case_id = row[self._index_name]
                source = self._get_source(case_id)
                case = Case(source, row, self._prom2abs, self._abs2prom, self._abs2meta,
                            self._var_info, self._format_version)
                if cache:
                    self._cases[case_id] = case
                yield case
        finally:
            # also runs when the generator is abandoned before exhaustion
            con.close()

    def _load_cases(self):
        """
        Load all cases into memory.
        """
        for _case in self.cases(cache=True):
            pass

    def list_sources(self):
        """
        Get the set of sources that recorded data in this table.

        Returns
        -------
        set
            Set of sources.
        """
        if self._sources is None:
            if self._format_version >= 5:
                table = self._table_name.split('_')[0]  # remove "_iterations" from table name
                sources = set()
                for global_iter in self._global_iterations:
                    record_type, source = global_iter[1], global_iter[3]
                    if record_type == table:
                        if not source.startswith('root'):
                            sources.add('root.' + source)
                        else:
                            sources.add(source)
                self._sources = sources
            else:
                self._sources = {self._get_source(case) for case in self.list_cases()}

        return self._sources

    def _get_source(self, iteration_coordinate):
        """
        Get the source of the iteration.

        Parameters
        ----------
        iteration_coordinate : str
            The full unique identifier for this iteration.

        Returns
        -------
        str
            The source of the iteration.
        """
        return get_source_system(iteration_coordinate)

    def _get_row_source(self, row_id):
        """
        Get the source of the case at the specified row of this table.

        Parameters
        ----------
        row_id : int
            The row_id of the case in the table.

        Returns
        -------
        str or None
            The source of the case, or None if the global iterations table
            has no entry for that row of this table.
        """
        table = self._table_name.split('_')[0]  # remove "_iterations" from table name

        for global_iter in self._global_iterations:
            record_type, row, source = global_iter[1], global_iter[2], global_iter[3]
            if record_type == table and row == row_id:
                return source

        return None

    def _get_first(self, source):
        """
        Get the first case from the specified source.

        Parameters
        ----------
        source : str
            The source.

        Returns
        -------
        Case or None
            The first case from the specified source, or None if there is none.
        """
        for case in self.cases():
            if case.source == source:
                return case

        return None
class DriverCases(CaseTable):
    """
    Cases specific to the entries that might be recorded in a Driver iteration.
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, var_info):
        """
        Initialize.

        Parameters
        ----------
        filename : str
            The name of the recording file from which to instantiate the case reader.
        format_version : int
            The version of the format assumed when loading the file.
        giter : list of tuple
            The global iterations table.
        prom2abs : {'input': dict, 'output': dict}
            Dictionary mapping promoted names to absolute names.
        abs2prom : {'input': dict, 'output': dict}
            Dictionary mapping absolute names to promoted names.
        abs2meta : dict
            Dictionary mapping absolute variable names to variable metadata.
        var_info : dict
            Dictionary with information about variables (scaling, indices, execution order).
        """
        # the base class stores var_info along with the other arguments
        super(DriverCases, self).__init__(filename, format_version,
                                          'driver_iterations', 'iteration_coordinate',
                                          giter, prom2abs, abs2prom, abs2meta, var_info)

    def cases(self, cache=False):
        """
        Iterate over all cases, optionally caching them into memory.

        Override base class to add derivatives from the derivatives table.

        Parameters
        ----------
        cache : bool
            If True, cases will be cached for faster access by key.

        Yields
        ------
        Case
            Each driver case, ordered by row id, with its jacobian when one was recorded.
        """
        con = sqlite3.connect(self._filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()
            cur.execute("SELECT * FROM %s ORDER BY id ASC" % self._table_name)
            # fetch everything first because the cursor is reused for derivatives below
            rows = cur.fetchall()

            for row in rows:
                if self._format_version > 1:
                    # fetch associated derivative data, if available
                    cur.execute("SELECT * FROM driver_derivatives WHERE "
                                "iteration_coordinate=:iteration_coordinate",
                                {"iteration_coordinate": row['iteration_coordinate']})
                    derivs_row = cur.fetchone()

                    if derivs_row:
                        # convert row to a regular dict and add jacobian
                        row = dict(zip(row.keys(), row))
                        row['jacobian'] = derivs_row['derivatives']

                case = Case('driver', row, self._prom2abs, self._abs2prom, self._abs2meta,
                            self._var_info, self._format_version)

                if cache:
                    self._cases[case.name] = case

                yield case
        finally:
            # also runs when the generator is abandoned before exhaustion
            con.close()

    def get_case(self, case_id, cache=False):
        """
        Get a case from the database.

        Parameters
        ----------
        case_id : int or str
            The integer index or string-identifier of the case to be retrieved.
        cache : bool
            If True, cache the case so it does not have to be fetched on next access.

        Returns
        -------
        Case or None
            The specified case from the driver_iterations and driver_derivatives tables,
            or None if it is not in the driver_iterations table.
        """
        # an integer is an index into this table's keys
        if isinstance(case_id, int):
            case_id = self._get_iteration_coordinate(case_id)

        # return cached case if present, else fetch it
        if case_id in self._cases:
            return self._cases[case_id]

        query_args = {"iteration_coordinate": case_id}

        con = sqlite3.connect(self._filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()

            # fetch driver iteration data
            cur.execute("SELECT * FROM driver_iterations WHERE "
                        "iteration_coordinate=:iteration_coordinate", query_args)
            row = cur.fetchone()

            # fetch associated derivative data, if available
            if row and self._format_version > 1:
                cur.execute("SELECT * FROM driver_derivatives WHERE "
                            "iteration_coordinate=:iteration_coordinate", query_args)
                derivs_row = cur.fetchone()

                if derivs_row:
                    # convert row to a regular dict and add jacobian
                    row = dict(zip(row.keys(), row))
                    row['jacobian'] = derivs_row['derivatives']
        finally:
            con.close()

        # if found, create Case object (and cache it if requested) else return None
        if not row:
            return None

        case = Case('driver', row, self._prom2abs, self._abs2prom, self._abs2meta,
                    self._var_info, self._format_version)

        if cache:
            self._cases[case_id] = case

        return case

    def list_sources(self):
        """
        Get the list of sources that recorded data in this table (just the driver).

        Returns
        -------
        list
            List of sources.
        """
        return ['driver']

    def _get_source(self, iteration_coordinate):
        """
        Get the source of the iteration.

        Parameters
        ----------
        iteration_coordinate : str
            The full unique identifier for this iteration.

        Returns
        -------
        str
            'driver' (all cases in this table come from the driver).
        """
        return 'driver'

    def _get_row_source(self, row_id):
        """
        Get the source of the case at the specified row of this table.

        Parameters
        ----------
        row_id : int
            The row_id of the case in the table.

        Returns
        -------
        str
            'driver' (all cases in this table come from the driver).
        """
        return 'driver'
class SystemCases(CaseTable):
    """
    Case table for the entries that might be recorded in a System iteration.
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, var_info):
        """
        Set up the case table for the 'system_iterations' data.

        Parameters
        ----------
        filename : str
            The name of the recording file from which to instantiate the case reader.
        format_version : int
            The version of the format assumed when loading the file.
        giter : list of tuple
            The global iterations table.
        prom2abs : {'input': dict, 'output': dict}
            Dictionary mapping promoted names to absolute names.
        abs2prom : {'input': dict, 'output': dict}
            Dictionary mapping absolute names to promoted names.
        abs2meta : dict
            Dictionary mapping absolute variable names to variable metadata.
        var_info : dict
            Dictionary with information about variables (scaling, indices, execution order).
        """
        # presumably the table name and the name of its identifying column,
        # as expected by CaseTable (its signature is not visible here)
        table = 'system_iterations'
        key_column = 'iteration_coordinate'

        super(SystemCases, self).__init__(filename, format_version, table, key_column,
                                          giter, prom2abs, abs2prom, abs2meta, var_info)
class SolverCases(CaseTable):
    """
    Case table for the entries that might be recorded in a Solver iteration.
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, var_info):
        """
        Set up the case table for the 'solver_iterations' data.

        Parameters
        ----------
        filename : str
            The name of the recording file from which to instantiate the case reader.
        format_version : int
            The version of the format assumed when loading the file.
        giter : list of tuple
            The global iterations table.
        prom2abs : {'input': dict, 'output': dict}
            Dictionary mapping promoted names to absolute names.
        abs2prom : {'input': dict, 'output': dict}
            Dictionary mapping absolute names to promoted names.
        abs2meta : dict
            Dictionary mapping absolute variable names to variable metadata.
        var_info : dict
            Dictionary with information about variables (scaling, indices, execution order).
        """
        # presumably the table name and the name of its identifying column,
        # as expected by CaseTable (its signature is not visible here)
        table = 'solver_iterations'
        key_column = 'iteration_coordinate'

        super(SolverCases, self).__init__(filename, format_version, table, key_column,
                                          giter, prom2abs, abs2prom, abs2meta, var_info)

    def _get_source(self, iteration_coordinate):
        """
        Get pathname of solver that is the source of the iteration.

        Parameters
        ----------
        iteration_coordinate : str
            The full unique identifier for this iteration.

        Returns
        -------
        str
            The pathname of the solver that is the source of the iteration.

        Raises
        ------
        RuntimeError
            If the number of coordinate nodes following the system's
            '_solve_nonlinear' entry matches neither recognized layout.
        """
        system_path = get_source_system(iteration_coordinate)

        # locate the end of "<last system name>._solve_nonlinear" in the coordinate
        solve_marker = system_path.rpartition('.')[2] + '._solve_nonlinear'
        marker_end = iteration_coordinate.index(solve_marker) + len(solve_marker)

        # the number of '|'-separated nodes is the separator count plus one;
        # the system portion is counted with one additional node
        system_nodes = iteration_coordinate.count('|', 0, marker_end) + 2
        total_nodes = iteration_coordinate.count('|') + 1

        # number of nodes beyond the system portion -> solver path suffix
        suffix_by_extra_nodes = {
            2: '.nonlinear_solver',
            4: '.nonlinear_solver.linesearch',
        }
        suffix = suffix_by_extra_nodes.get(total_nodes - system_nodes)

        if suffix is None:
            raise RuntimeError("Can't parse solver iteration coordinate: %s" % iteration_coordinate)

        return system_path + suffix
class ProblemCases(CaseTable):
    """
    Case table for the entries recorded in the 'problem_cases' table.
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, var_info):
        """
        Set up the case table for the 'problem_cases' data.

        Parameters
        ----------
        filename : str
            The name of the recording file from which to instantiate the case reader.
        format_version : int
            The version of the format assumed when loading the file.
        giter : list of tuple
            The global iterations table.
        prom2abs : {'input': dict, 'output': dict}
            Dictionary mapping promoted names to absolute names.
        abs2prom : {'input': dict, 'output': dict}
            Dictionary mapping absolute names to promoted names.
        abs2meta : dict
            Dictionary mapping absolute variable names to variable metadata.
        var_info : dict
            Dictionary with information about variables (scaling, indices, execution order).
        """
        # presumably the table name and the name of its identifying column,
        # as expected by CaseTable (its signature is not visible here)
        table = 'problem_cases'
        key_column = 'case_name'

        super(ProblemCases, self).__init__(filename, format_version, table, key_column,
                                           giter, prom2abs, abs2prom, abs2meta, var_info)

    def list_sources(self):
        """
        List the sources that recorded data in this table.

        Returns
        -------
        list
            A new single-element list containing 'problem'.
        """
        sources = ['problem']
        return sources

    def _get_source(self, iteration_coordinate):
        """
        Identify the source of an iteration.

        Parameters
        ----------
        iteration_coordinate : str
            The full unique identifier for this iteration (not inspected here).

        Returns
        -------
        str
            Always 'problem'.
        """
        source = 'problem'
        return source

    def _get_row_source(self, row_id):
        """
        Identify the source of the case stored at a given row of this table.

        Parameters
        ----------
        row_id : int
            The row_id of the case in the table (not inspected here).

        Returns
        -------
        str
            Always 'problem'.
        """
        source = 'problem'
        return source