Source code for openmdao.recorders.sqlite_reader

"""
Definition of the SqliteCaseReader.
"""
import pathlib
import sqlite3
from collections import OrderedDict

import sys
import numpy as np
import io

from openmdao.recorders.base_case_reader import BaseCaseReader
from openmdao.recorders.case import Case
from openmdao.core.constants import _DEFAULT_OUT_STREAM
from openmdao.utils.variable_table import write_source_table
from openmdao.utils.record_util import check_valid_sqlite3_db, get_source_system
from openmdao.utils.om_warnings import issue_warning, CaseRecorderWarning

from openmdao.recorders.sqlite_recorder import format_version, META_KEY_SEP

from openmdao.utils.notebook_utils import notebook, display, HTML
from openmdao.visualization.tables.table_builder import generate_table

import pickle
import zlib
import re
from json import loads as json_loads
from io import TextIOBase


class UnknownType:
    """
    Placeholder returned by the restricted unpickler for classes it cannot provide.

    It indicates that the unpickler can't generate an instance of a class
    whose class definition is not available.

    Parameters
    ----------
    *args : list
        Positional args.
    **kwargs : dict
        Keyword args.
    """

    def __init__(*args, **kwargs):
        """
        Construct an object representing an unknown type.

        Every argument is accepted and ignored; the instance itself arrives as
        the first element of ``*args``.

        Parameters
        ----------
        *args : list
            Positional args.
        **kwargs : dict
            Keyword args.

        Returns
        -------
        object
            The returned UnknownType object
        """
        return None
# globals that must never be resolved while unpickling recorded data
_unsafe = (
    ('builtins', 'eval'),
    ('builtins', 'exec'),
    ('posix', 'system'),
    ('nt', 'system'),
)


class _RestrictedUnpicklerForCaseReader(pickle.Unpickler):
    """
    Unpickler that refuses a few unsafe globals and tolerates missing modules.

    Problems found while resolving globals are accumulated, comma separated,
    in the `error_strings` attribute instead of being raised.
    """

    def __init__(self, file, *, fix_imports=True, encoding="ASCII", errors="strict",
                 buffers=None):
        super().__init__(file, fix_imports=fix_imports, encoding=encoding, errors=errors,
                         buffers=buffers)
        # Used to document which classes are not available
        self.error_strings = ''

    def _note_error(self, message):
        # append the message, separating it from earlier ones with a comma
        prefix = f"{self.error_strings}, " if self.error_strings else ''
        self.error_strings = prefix + message

    def find_class(self, module, name):
        # Disallow some unsafe function calls during unpickling.
        if (module, name) in _unsafe:
            self._note_error(f"Error unpickling global, '{module}.{name}' is forbidden")
            return UnknownType

        try:
            found = super().find_class(module, name)
        except ModuleNotFoundError as err:
            self._note_error(str(err))
            # UnknownType acts as a flag: the unpickler can't generate instances
            # of classes whose class definition is not available
            return UnknownType
        return found

    def loads_and_return_errors(self):
        # load first so that error_strings is complete when it is read
        contents = self.load()
        return contents, self.error_strings


def _safer_unpickle(s, desc, compressed=False):
    """
    Unpickle input and issue a warning for any errors.

    Analogous to pickle.loads().
    """
    payload = zlib.decompress(s) if compressed else s
    stream = io.BytesIO(payload)

    # the unpickler hands back both the value and the collected error strings
    unpickler = _RestrictedUnpicklerForCaseReader(stream)
    data, errs = unpickler.loads_and_return_errors()

    if errs:
        issue_warning(f"While reading {desc} from case recorder, the "
                      f"following errors occurred: {errs}", category=RuntimeWarning)
    return data
[docs]class SqliteCaseReader(BaseCaseReader):
    """
    A CaseReader specific to files created with SqliteRecorder.

    Parameters
    ----------
    filename : str or pathlib.Path
        The path to the filename containing the recorded data.
    pre_load : bool
        If True, load all the data into memory during initialization.
    metadata_filename : str
        The path to the filename containing the recorded metadata, if separate.

    Attributes
    ----------
    problem_metadata : dict
        Metadata about the problem, including the system hierarchy and connections.
    solver_metadata : dict
        The solver options for each solver in the recorded model.
    _system_options : dict
        Metadata about each system in the recorded model, including options and scaling factors.
    _format_version : int
        The version of the format assumed when loading the file.
    _filename : str
        The path to the filename containing the recorded data.
    _abs2meta : dict
        Dictionary mapping variables to their metadata.
    _abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    _prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    _conns : dict
        Dictionary of all model connections.
    _driver_cases : DriverCases
        Helper object for accessing cases from the driver_iterations table.
    _system_cases : SystemCases
        Helper object for accessing cases from the system_iterations table.
    _solver_cases : SolverCases
        Helper object for accessing cases from the solver_iterations table.
    _problem_cases : ProblemCases
        Helper object for accessing cases from the problem_cases table.
    _global_iterations : list
        List of iteration cases and the table and row in which they are found.
    """
def __init__(self, filename, pre_load=False, metadata_filename=None):
    """
    Initialize.

    Parameters
    ----------
    filename : str or pathlib.Path
        The path to the filename containing the recorded data.
    pre_load : bool
        If True, load all the data into memory during initialization.
    metadata_filename : str
        The path to the filename containing the recorded metadata, if separate.
    """
    super().__init__(filename, pre_load)

    check_valid_sqlite3_db(filename)
    if metadata_filename:
        check_valid_sqlite3_db(metadata_filename)

    # initialize private attributes
    self._filename = pathlib.Path(filename)
    self._abs2prom = None
    self._prom2abs = None
    self._abs2meta = None
    self._conns = None
    self._global_iterations = None

    filename = str(filename)

    # The connection is closed in a `finally` clause so that it is released
    # even when one of the queries below raises.
    con = sqlite3.connect(filename)
    try:
        con.row_factory = sqlite3.Row
        cur = con.cursor()

        # get the global iterations table, and save it as an attribute
        self._global_iterations = self._get_global_iterations(cur)

        # If separate metadata not specified, check the current db
        # to make sure it's there
        if metadata_filename is None:
            cur.execute("SELECT count(name) FROM sqlite_master "
                        "WHERE type='table' AND name='metadata'")

            # If not, take a guess at the filename:
            if cur.fetchone()[0] == 0:
                metadata_filename = re.sub(r'^(.*)_(\d+)', r'\1_meta', filename)
                check_valid_sqlite3_db(metadata_filename)
            else:
                metadata_filename = filename
    finally:
        con.close()

    # collect metadata from database
    con = sqlite3.connect(metadata_filename)
    try:
        con.row_factory = sqlite3.Row
        cur = con.cursor()

        # collect data from the metadata table. this includes:
        #   format_version
        #   openmdao_version
        #   VOI metadata, which is added to problem_metadata
        #   var name maps and metadata for all vars, which are saved as private attributes
        self._collect_metadata(cur)

        # collect data from the driver_metadata table. this includes:
        #   model viewer data, which is added to problem_metadata
        self._collect_driver_metadata(cur)

        # collect data from the system_metadata table. this includes:
        #   component metadata and scaling factors for each system,
        #   which is added to _system_options
        self._collect_system_metadata(cur)

        # collect data from the solver_metadata table. this includes:
        #   solver class and options for each solver, which is saved as an attribute
        self._collect_solver_metadata(cur)
    finally:
        con.close()

    # create helper objects for accessing cases from the three iteration tables and
    # the problem cases table; they all take the same constructor arguments
    var_info = self.problem_metadata['variables']
    table_args = (filename, self._format_version, self._global_iterations,
                  self._prom2abs, self._abs2prom, self._abs2meta, self._conns, var_info)

    self._driver_cases = DriverCases(*table_args)
    self._system_cases = SystemCases(*table_args)
    self._solver_cases = SolverCases(*table_args)
    if self._format_version >= 2:
        self._problem_cases = ProblemCases(*table_args)

    # if requested, load all the iteration data into memory
    if pre_load:
        self._load_cases()
def _collect_metadata(self, cur):
    """
    Load data from the metadata table.

    Populates the `format_version` attribute and the `variables` data in the
    `problem_metadata` attribute of this CaseReader. Also saves the variable name maps
    and variable metadata to private attributes.

    Parameters
    ----------
    cur : sqlite3.Cursor
        Database cursor to use for reading the data.
    """
    cur.execute('select * from metadata')
    row = cur.fetchone()

    version = row['format_version']
    self._format_version = version

    if version >= 13:
        self._openmdao_version = row['openmdao_version']

    if version not in range(1, format_version + 1):
        raise ValueError('SqliteCaseReader encountered an unhandled '
                         'format version: {0}'.format(self._format_version))

    is_compressed = version >= 14

    def _from_json(column):
        # JSON columns are zlib-compressed starting with format version 14
        text = row[column]
        if is_compressed:
            text = zlib.decompress(text).decode('ascii')
        return json_loads(text)

    if version >= 11:
        # Auto-IVC
        self._conns = _from_json('conns')

    # add metadata for VOIs (des vars, objective, constraints) to problem metadata
    self.problem_metadata['variables'] = _from_json('var_settings') if version >= 4 else None

    # get variable name maps and metadata for all variables
    if version >= 3:
        self._abs2prom = _from_json('abs2prom')
        self._prom2abs = _from_json('prom2abs')
        self._abs2meta = _from_json('abs2meta')

        # need to convert bounds to numpy arrays
        for meta in self._abs2meta.values():
            for bound in ('lower', 'upper'):
                if meta.get(bound) is not None:
                    meta[bound] = np.resize(np.array(meta[bound]), meta['shape'])

    elif version in (1, 2):
        raw_abs2prom = row['abs2prom']
        raw_prom2abs = row['prom2abs']
        raw_abs2meta = row['abs2meta']

        try:
            self._abs2prom = _safer_unpickle(raw_abs2prom, 'abs2prom dictionary')
            self._prom2abs = _safer_unpickle(raw_prom2abs, 'prom2abs dictionary')
            self._abs2meta = _safer_unpickle(raw_abs2meta, 'abs2meta dictionary')
        except TypeError:
            # Reading in a python 2 pickle recorded pre-OpenMDAO 2.4.
            self._abs2prom = _safer_unpickle(raw_abs2prom.encode(), 'abs2prom dictionary')
            self._prom2abs = _safer_unpickle(raw_prom2abs.encode(), 'prom2abs dictionary')
            self._abs2meta = _safer_unpickle(raw_abs2meta.encode(), 'abs2meta dictionary')

    self.problem_metadata['abs2prom'] = self._abs2prom


def _collect_driver_metadata(self, cur):
    """
    Load data from the driver_metadata table.

    Populates the `problem_metadata` attribute of this CaseReader.

    Parameters
    ----------
    cur : sqlite3.Cursor
        Database cursor to use for reading the data.
    """
    cur.execute("SELECT model_viewer_data FROM driver_metadata")
    row = cur.fetchone()
    if row is None:
        # nothing was recorded for the driver
        return

    version = self._format_version
    if version >= 3:
        driver_metadata = json_loads(row[0])
    elif version in (1, 2):
        driver_metadata = _safer_unpickle(row[0], 'driver metadata')

    self.problem_metadata.update(driver_metadata)


def _collect_system_metadata(self, cur):
    """
    Load data from the system table.

    Populates the `_system_options` attribute of this CaseReader.

    Parameters
    ----------
    cur : sqlite3.Cursor
        Database cursor to use for reading the data.
    """
    cur.execute("SELECT id, scaling_factors, component_metadata FROM system_metadata")
    for row in cur:
        sys_id = row[0]
        # the entry is registered before its contents are unpickled
        entry = self._system_options[sys_id] = {}
        compressed = self._format_version >= 14
        entry['scaling_factors'] = _safer_unpickle(row[1], f'{sys_id} scaling factors',
                                                   compressed)
        entry['component_options'] = _safer_unpickle(row[2], f'{sys_id} component options',
                                                     compressed)


def _collect_solver_metadata(self, cur):
    """
    Load data from the solver_metadata table.

    Populates the `solver_metadata` attribute of this CaseReader.

    Parameters
    ----------
    cur : sqlite3.Cursor
        Database cursor to use for reading the data.
    """
    cur.execute("SELECT id, solver_options, solver_class FROM solver_metadata")
    for row in cur:
        solver_id = row[0]
        compressed = self._format_version >= 14
        options = _safer_unpickle(row[1], f'{solver_id} solver options', compressed)
        self.solver_metadata[solver_id] = {
            'solver_options': options,
            'solver_class': row[2],
        }


def _get_global_iterations(self, cur):
    """
    Get the global iterations table.

    Parameters
    ----------
    cur : sqlite3.Cursor
        Database cursor to use for reading the data.

    Returns
    -------
    list
        List of global iterations and the table and row where the associated case is found.
    """
    cur.execute('select * from global_iterations')
    all_rows = cur.fetchall()
    return all_rows


def _load_cases(self):
    """
    Load all driver, solver, and system cases into memory.

    Problem cases are loaded as well when the format version is at least 2.
    """
    case_tables = [self._driver_cases, self._solver_cases, self._system_cases]
    if self._format_version >= 2:
        case_tables.append(self._problem_cases)

    for case_table in case_tables:
        case_table._load_cases()
def list_sources(self, out_stream=_DEFAULT_OUT_STREAM):
    """
    List of all the different recording sources for which there is recorded data.

    Parameters
    ----------
    out_stream : file-like object
        Where to send human readable output. Default is sys.stdout.
        Set to None to suppress.

    Returns
    -------
    list
        One or more of: `problem`, `driver`, `<system hierarchy location>`,
        `<solver hierarchy location>`
    """
    case_tables = [self._driver_cases, self._solver_cases, self._system_cases]
    if self._format_version >= 2:
        case_tables.append(self._problem_cases)

    # only tables that actually hold cases contribute sources
    found = []
    for case_table in case_tables:
        if case_table.count() > 0:
            found.extend(case_table.list_sources())

    if not out_stream:
        return found

    if notebook and out_stream is _DEFAULT_OUT_STREAM:
        # in a notebook the default output is rendered as an HTML table
        table_rows = [[name] for name in found]
        html = generate_table(table_rows, headers=['Sources'], tablefmt='html')
        display(HTML(str(html)))
        return found

    if out_stream is _DEFAULT_OUT_STREAM:
        out_stream = sys.stdout
    elif not isinstance(out_stream, TextIOBase):
        raise TypeError("Invalid output stream specified for 'out_stream'.")

    for name in found:
        out_stream.write('{}\n'.format(name))

    return found
def list_source_vars(self, source, out_stream=_DEFAULT_OUT_STREAM):
    """
    List of all inputs and outputs recorded by the specified source.

    The variable names are taken from the first case recorded by the source.

    Parameters
    ----------
    source : {'problem', 'driver', <system hierarchy location>, <solver hierarchy location>}
        Identifies the source for which to return information.
    out_stream : file-like object
        Where to send human readable output. Default is sys.stdout.
        Set to None to suppress.

    Returns
    -------
    dict
        {'inputs':[key list], 'outputs':[key list], 'residuals':[key list]}. No recurse.

    Raises
    ------
    RuntimeError
        If the source is not found or no cases were recorded for it.
    """
    dct = {
        'inputs': [],
        'outputs': [],
        'residuals': [],
    }

    case = None

    if source == 'problem':
        # the problem case table only exists for format version 2 and later
        if self._format_version >= 2 and self._problem_cases.count() > 0:
            case = self._problem_cases.get_case(0)
    elif source == 'driver':
        if self._driver_cases.count() > 0:
            case = self._driver_cases.get_case(0)
    elif source in self._system_cases.list_sources():
        source_cases = self._system_cases.list_cases(source)
        if source_cases:
            case = self._system_cases.get_case(source_cases[0])
    elif source in self._solver_cases.list_sources():
        source_cases = self._solver_cases.list_cases(source)
        if source_cases:
            case = self._solver_cases.get_case(source_cases[0])
    else:
        raise RuntimeError('Source not found: %s' % source)

    if case is None:
        raise RuntimeError('No cases recorded for %s' % source)

    # empty or missing variable sets keep the empty-list default
    if case.inputs:
        dct['inputs'] = list(case.inputs)
    if case.outputs:
        dct['outputs'] = list(case.outputs)
    if case.residuals:
        dct['residuals'] = list(case.residuals)

    if out_stream:
        write_source_table(dct, out_stream)

    return dct
def systems(self, tree=None, path=None, paths=None):
    """
    List pathnames of systems in the system hierarchy.

    Parameters
    ----------
    tree : dict
        Nested dictionary of system information. If None, the 'tree' entry of
        the problem metadata is used.
    path : str or None
        Pathname of root system (None for the root model).
    paths : list or None
        List to which pathnames are appended. If None, a new list is created
        for this call.

    Returns
    -------
    list
        List of pathnames of systems.
    """
    # A fresh list per call: a shared mutable default would keep accumulating
    # the pathnames of every previous call.
    if paths is None:
        paths = []

    if tree is None:
        tree = self.problem_metadata['tree']

    path = '.'.join([path, tree['name']]) if path else tree['name']
    paths.append(path)

    # only children of type 'subsystem' are part of the system hierarchy
    for child in tree.get('children', ()):
        if child['type'] == 'subsystem':
            self.systems(child, path, paths)

    return paths
def list_model_options(self, run_number=0, system=None, out_stream=_DEFAULT_OUT_STREAM):
    """
    List model options for the specified run.

    Parameters
    ----------
    run_number : int
        Run_driver or run_model iteration to inspect.
    system : str or None
        Pathname of system (None for all systems).
    out_stream : file-like object
        Where to send human readable output. Default is sys.stdout.
        Set to None to suppress.

    Returns
    -------
    dict
        {system: {key: val}}.
    """
    dct = {}

    if not self._system_options:
        issue_warning("System options not recorded.", category=CaseRecorderWarning)
        return dct

    if out_stream is _DEFAULT_OUT_STREAM:
        out_stream = sys.stdout

    # need to handle edge case for v11 recording
    SEP = '_' if self._format_version < 12 else META_KEY_SEP

    last_run_written = None

    for key, recorded in self._system_options.items():
        # keys hold the system name followed by the separator and the run number
        if key.find(SEP) > 0:
            name, num = key.rsplit(SEP, 1)
        else:
            name, num = key, 0

        if (system is not None and system != name) or run_number != int(num):
            continue

        if out_stream:
            if last_run_written != num:
                out_stream.write(f"Run Number: {num}\n")
                last_run_written = num
            out_stream.write(f" Subsystem : {name}\n")

        options = dct[name] = {}
        for opt, val in recorded['component_options'].items():
            options[opt] = val
            if out_stream:
                out_stream.write(f" {opt}: {val}\n")

    return dct
def list_solver_options(self, run_number=0, solver=None, out_stream=_DEFAULT_OUT_STREAM):
    """
    List solver options for the specified run.

    Parameters
    ----------
    run_number : int
        Run_driver or run_model iteration to inspect.
    solver : str or None
        Pathname of solver (None for all solvers).
    out_stream : file-like object
        Where to send human readable output. Default is sys.stdout.
        Set to None to suppress.

    Returns
    -------
    dict
        {solver: {key: val}}.
    """
    dct = {}

    if not self.solver_metadata:
        issue_warning("Solver options not recorded.", category=CaseRecorderWarning)
        return dct

    if out_stream is _DEFAULT_OUT_STREAM:
        out_stream = sys.stdout

    last_run_written = None

    for key, recorded in self.solver_metadata.items():
        # keys hold the solver name followed by the separator and the run number
        if key.find(META_KEY_SEP) > 0:
            name, num = key.rsplit(META_KEY_SEP, 1)
        else:
            name, num = key, 0

        if (solver is not None and solver != name) or run_number != int(num):
            continue

        if out_stream:
            if last_run_written != num:
                out_stream.write(f"Run Number: {num}\n")
                last_run_written = num
            out_stream.write(f" Solver: {name}\n")

        options = dct[name] = {}
        for opt, val in recorded['solver_options'].items():
            options[opt] = val
            if out_stream:
                out_stream.write(f" {opt}: {val}\n")

    return dct
def list_cases(self, source=None, recurse=True, flat=True, out_stream=_DEFAULT_OUT_STREAM):
    """
    Iterate over Driver, Solver and System cases in order.

    Parameters
    ----------
    source : 'problem', 'driver', component pathname, solver pathname, case_name
        If not None, only cases originating from the specified source or case are returned.
    recurse : bool, optional
        If True, will enable iterating over all successors in case hierarchy.
    flat : bool, optional
        If False and there are child cases, then a nested ordered dictionary
        is returned rather than an iterator.
    out_stream : file-like object
        Where to send human readable output. Default is sys.stdout.
        Set to None to suppress.

    Returns
    -------
    iterator or dict
        An iterator or a nested dictionary of identified cases.

    Raises
    ------
    TypeError
        If source is not a string.
    RuntimeError
        If the source or case cannot be found, or if a nested dictionary of all
        cases is requested but neither the driver nor the model was recorded.
    """
    # if source was not specified, return all cases
    if source is None:
        if flat:
            source = ''
        elif self._driver_cases.count() > 0:
            source = 'driver'
        elif 'root' in self._system_cases.list_sources():
            source = 'root'
        else:
            # if there are no driver or model cases, then we need
            # another starting point to build the nested dict.
            raise RuntimeError("A nested dictionary of all cases was requested, but "
                               "neither the driver or the model was recorded. Please "
                               "specify another source (system or solver) for the cases "
                               "you want to see.")

    if not isinstance(source, str):
        raise TypeError("Source parameter must be a string, %s is type %s." %
                        (source, type(source).__name__))

    if not source:
        cases = self._list_cases_recurse_flat(out_stream=None)
    elif source == 'problem':
        if self._format_version >= 2:
            cases = self._problem_cases.list_cases()
        else:
            raise RuntimeError('No problem cases recorded (data format = %d).' %
                               self._format_version)
    else:
        # figure out which table has cases from the source
        if source == 'driver':
            case_table = self._driver_cases
        elif source in self._system_cases.list_sources():
            case_table = self._system_cases
        elif source in self._solver_cases.list_sources():
            case_table = self._solver_cases
        else:
            case_table = None

        if case_table is not None:
            if not recurse:
                # return list of cases from the source alone
                cases = case_table.list_cases(source)
            elif flat:
                # return list of cases from the source plus child cases
                cases = []
                for case in case_table.get_cases(source):
                    cases += self._list_cases_recurse_flat(case.name, out_stream=None)
            else:
                # return nested dict of cases from the source and child cases
                cases = OrderedDict()
                for case in case_table.get_cases(source):
                    cases.update(self._list_cases_recurse_nested(case.name))
                return cases
        elif '|' in source:
            # source is a coordinate
            if not recurse:
                # Only the case itself is wanted. This branch used to leave `cases`
                # unassigned; it now returns the coordinate alone once it is confirmed
                # to be a recorded case.
                tables = [self._driver_cases, self._system_cases, self._solver_cases]
                if self._format_version >= 2:
                    tables.append(self._problem_cases)
                if not any(source in table.list_cases() for table in tables):
                    raise RuntimeError('Case not found for coordinate: %s' % source)
                cases = [source]
            elif flat:
                cases = self._list_cases_recurse_flat(source, out_stream=None)
            else:
                return self._list_cases_recurse_nested(source)
        else:
            raise RuntimeError('Source not found: %s' % source)

    if out_stream:
        if not source:
            # all cases were requested: write one table per source, using the
            # bookkeeping left behind by _list_cases_recurse_flat
            for table_name, subcases in self.source_cases_table.items():
                if subcases:
                    write_source_table({table_name: subcases}, out_stream)
            del self.source_cases_table
        else:
            write_source_table({source: cases}, out_stream)

    return cases
def _list_cases_recurse_flat(self, coord=None, out_stream=_DEFAULT_OUT_STREAM):
    """
    Iterate recursively over Driver, Solver and System cases in order.

    As a side effect, the `source_cases_table` attribute is set to a dictionary
    that groups the returned coordinates by the table they were found in.

    Parameters
    ----------
    coord : an iteration coordinate
        Identifies the parent of the cases to return.
    out_stream : file-like object
        Where to send human readable output. Default is sys.stdout.
        Set to None to suppress.

    Returns
    -------
    list
        The iteration coordinates of the identified cases, in recorded order.

    Raises
    ------
    RuntimeError
        If no case is found for the coordinate or the global iterations table
        refers to an unexpected table name.
    """
    solver_cases = self._solver_cases.list_cases()
    system_cases = self._system_cases.list_cases()
    driver_cases = self._driver_cases.list_cases()
    # problem cases only exist for format version 2 and later; use an empty list
    # otherwise so that the lookups below never hit an undefined name
    problem_cases = self._problem_cases.list_cases() if self._format_version >= 2 else []

    global_iters = self._global_iterations

    if not coord:
        # will return all cases
        coord = ''
        parent_case_counter = len(global_iters)
    elif coord in driver_cases:
        parent_case_counter = self._driver_cases.get_case(coord).counter
    elif coord in system_cases:
        parent_case_counter = self._system_cases.get_case(coord).counter
    elif coord in solver_cases:
        parent_case_counter = self._solver_cases.get_case(coord).counter
    elif coord in problem_cases:
        parent_case_counter = self._problem_cases.get_case(coord).counter
    else:
        raise RuntimeError('Case not found for coordinate:', coord)

    coords_by_table = {
        'solver': solver_cases,
        'system': system_cases,
        'driver': driver_cases,
        'problem': problem_cases,
    }

    cases = []
    self.source_cases_table = {'solver': [], 'system': [], 'driver': [], 'problem': []}
    source_cases = []

    # return all cases in the global iteration table that precede the given case
    # and whose coordinate is prefixed by the given coordinate
    current_table = None
    current_cases = []

    for i in range(0, parent_case_counter):
        global_iter = global_iters[i]
        table, row = global_iter[1], global_iter[2]

        table_coords = coords_by_table.get(table)
        if table_coords is None:
            raise RuntimeError('Unexpected table name in global iterations:', table)
        case_coord = table_coords[row - 1]

        if not case_coord.startswith(coord):
            continue

        cases.append(case_coord)
        self.source_cases_table[table].append(case_coord)

        if out_stream:
            # group consecutive cases from the same table for the printed output
            if current_cases and table == current_table:
                current_cases[table].append(case_coord)
            else:
                if current_cases:
                    source_cases.append(current_cases)
                current_table = table
                current_cases = {table: [case_coord]}

    if out_stream:
        if current_cases:
            source_cases.append(current_cases)
        write_source_table(source_cases, out_stream)

    return cases


def _list_cases_recurse_nested(self, coord=None):
    """
    Iterate recursively over Driver, Solver and System cases in order.

    Parameters
    ----------
    coord : an iteration coordinate
        Identifies the parent of the cases to return.

    Returns
    -------
    dict
        A nested dictionary of identified cases.

    Raises
    ------
    RuntimeError
        If no driver, system or solver case is found for the coordinate.
    """
    solver_cases = self._solver_cases.list_cases()
    system_cases = self._system_cases.list_cases()
    driver_cases = self._driver_cases.list_cases()

    global_iters = self._global_iterations

    if coord in driver_cases:
        parent_case = self._driver_cases.get_case(coord)
    elif coord in system_cases:
        parent_case = self._system_cases.get_case(coord)
    elif coord in solver_cases:
        parent_case = self._solver_cases.get_case(coord)
    else:
        raise RuntimeError('Case not found for coordinate:', coord)

    cases = OrderedDict()
    children = OrderedDict()
    cases[parent_case.name] = children

    # return all cases in the global iteration table that precede the given case
    # and whose coordinate is prefixed by the given coordinate
    for i in range(0, parent_case.counter - 1):
        global_iter = global_iters[i]
        table, row = global_iter[1], global_iter[2]

        # only solver and system cases are considered as children
        if table == 'solver':
            case_coord = solver_cases[row - 1]
        elif table == 'system':
            case_coord = system_cases[row - 1]
        else:
            continue

        if case_coord.startswith(coord):
            # a direct child has exactly two more '|'-separated fields than its parent
            parent_coord = '|'.join(case_coord.split('|')[:-2])
            if parent_coord == coord:
                children.update(self._list_cases_recurse_nested(case_coord))

    return cases
def get_cases(self, source=None, recurse=True, flat=True):
    """
    Iterate over the cases.

    Parameters
    ----------
    source : 'problem', 'driver', component pathname, solver pathname, case_name
        Identifies which cases to return.
    recurse : bool, optional
        If True, will enable iterating over all successors in case hierarchy.
    flat : bool, optional
        If False and there are child cases, then a nested ordered dictionary
        is returned rather than an iterator.

    Returns
    -------
    list or dict
        The cases identified by source.
    """
    identified = self.list_cases(source, recurse, flat, out_stream=None)

    if not isinstance(identified, list):
        # a nested dictionary of case IDs becomes a nested dictionary of cases
        return self._get_cases_nested(identified, OrderedDict())

    return list(map(self.get_case, identified))
def _get_cases_nested(self, case_ids, cases):
    """
    Populate a nested dictionary of cases matching the provided dictionary of case IDs.

    Parameters
    ----------
    case_ids : OrderedDict
        The nested dictionary of case IDs.
    cases : OrderedDict
        The nested dictionary of cases.

    Returns
    -------
    OrderedDict
        The nested dictionary of cases with cases added from case_ids.
    """
    for case_id, child_ids in case_ids.items():
        case = self.get_case(case_id)
        if child_ids.keys():
            nested = self._get_cases_nested(child_ids, OrderedDict())
        else:
            nested = OrderedDict()
        cases[case] = nested

    return cases
def get_case(self, case_id, recurse=False):
    """
    Get case identified by case_id.

    Parameters
    ----------
    case_id : str or int
        The unique identifier of the case to return or an index into all cases.
    recurse : bool, optional
        If True, will return all successors to the case as well.

    Returns
    -------
    Case or list
        The case identified by case_id, or the result of `get_cases` for that
        case when recurse is True.

    Raises
    ------
    IndexError
        If an integer case_id is beyond the end of the available cases.
    RuntimeError
        If the case cannot be found.
    """
    if isinstance(case_id, int):
        # it's a global index rather than a coordinate
        global_iters = self._global_iterations

        if case_id > len(global_iters) - 1:
            raise IndexError("Invalid index into available cases:", case_id)

        global_iter = global_iters[case_id]
        table, row = global_iter[1], global_iter[2]

        # translate the (table, row) pair into the iteration coordinate of the case
        if table == 'solver':
            case_id = self._solver_cases.list_cases()[row - 1]
        elif table == 'system':
            case_id = self._system_cases.list_cases()[row - 1]
        elif table == 'driver':
            case_id = self._driver_cases.list_cases()[row - 1]
        elif table == 'problem':
            # problem cases are also listed in the global iterations table
            case_id = self._problem_cases.list_cases()[row - 1]
        else:
            raise RuntimeError('Unexpected table name in global iterations:', table)

    if recurse:
        return self.get_cases(case_id, recurse=True)

    tables = [self._driver_cases, self._system_cases, self._solver_cases]
    if self._format_version >= 2:
        tables.append(self._problem_cases)

    for table in tables:
        case = table.get_case(case_id)
        if case:
            return case

    raise RuntimeError('Case not found:', case_id)
[docs]class CaseTable(object):
    """
    Base class for wrapping case tables in a recording database.

    Parameters
    ----------
    fname : str
        The name of the recording file from which to instantiate the case reader.
    ver : int
        The version of the format assumed when loading the file.
    table : str
        The name of the table in the database.
    index : str
        The name of the case index column in the table.
    giter : list of tuple
        The global iterations table.
    prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    conns : dict
        Dictionary of all model connections.
    var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).

    Attributes
    ----------
    _filename : str
        The name of the recording file from which to instantiate the case reader.
    _format_version : int
        The version of the format assumed when loading the file.
    _table_name : str
        The name of the table in the database.
    _index_name : str
        The name of the case index column in the table.
    _global_iterations : list
        List of iteration cases and the table and row in which they are found.
    _abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    _abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    _prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    _conns : dict
        Dictionary of all model connections.
    _var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).
    _sources : list
        List of sources of cases in the table.
    _keys : list
        List of keys of cases in the table.
    _cases : dict
        Dictionary mapping keys to cases that have already been loaded.
    """
def __init__(self, fname, ver, table, index, giter,
             prom2abs, abs2prom, abs2meta, conns, var_info):
    """
    Initialize.

    Stores every argument on the instance and starts with empty caches.
    """
    # where the cases live
    self._filename, self._format_version = fname, ver
    self._table_name, self._index_name = table, index
    self._global_iterations = giter

    # variable name maps, metadata and connections shared with the reader
    self._prom2abs, self._abs2prom = prom2abs, abs2prom
    self._abs2meta = abs2meta
    self._conns = conns
    self._var_info = var_info

    # cached keys/cases
    self._sources = self._keys = None
    self._cases = {}
def count(self):
    """
    Get the number of cases recorded in the table.

    Returns
    -------
    int
        The number of cases recorded in the table.
    """
    con = sqlite3.connect(self._filename)
    try:
        cur = con.cursor()
        cur.execute(f"SELECT count(*) FROM {self._table_name}")  # nosec: trusted input
        rows = cur.fetchall()
    finally:
        # release the connection even when the query fails
        con.close()

    return rows[0][0]
def list_cases(self, source=None):
    """
    Get list of case IDs for cases in the table.

    The case IDs are read from the database on the first call and cached.

    Parameters
    ----------
    source : str, optional
        A source of cases or the iteration coordinate of a case.
        If not None, only cases originating from the specified source or case are returned.

    Returns
    -------
    list
        The cases from the table from the specified source or parent case.
    """
    # Test against None rather than truthiness: an empty table yields an empty
    # list, which must count as cached instead of triggering a new query each call.
    if self._keys is None:
        con = sqlite3.connect(self._filename)
        try:
            cur = con.cursor()
            cur.execute(f"SELECT {self._index_name} FROM {self._table_name}"
                        " ORDER BY id ASC")  # nosec trusted input
            rows = cur.fetchall()
        finally:
            # release the connection even when the query fails
            con.close()

        # cache case list for future use
        self._keys = [row[0] for row in rows]

    if not source:
        # return all cases
        return self._keys
    elif '|' in source:
        # source is a coordinate
        return [key for key in self._keys if key.startswith(source)]
    else:
        # source is a system or solver
        return [key for key in self._keys if self._get_source(key) == source]
def get_cases(self, source=None, recurse=False, flat=False):
    """
    Get list of case names for cases in the table.

    Parameters
    ----------
    source : str, optional
        If not None, only cases that have the specified source will be returned.
    recurse : bool, optional
        If True, will enable iterating over all successors in case hierarchy.
    flat : bool, optional
        If False and there are child cases, then a nested ordered dictionary
        is returned rather than an iterator.

    Returns
    -------
    list or dict
        The cases from the table that have the specified source.
    """
    if self._keys is None:
        # populates self._keys as a side effect
        self.list_cases()

    nested = recurse and not flat

    if not source:
        # return all cases
        return [self.get_case(key) for key in self._keys]

    if '|' in source:
        # source is a coordinate
        if nested:
            return OrderedDict(
                (key, self.get_cases(key, recurse, flat))
                for key in self._keys
                if len(key) > len(source) and key.startswith(source)
            )
        return [self.get_case(key) for key in self._keys if key.startswith(source)]

    # source is a system or solver
    if not recurse:
        return [self.get_case(key) for key in self._keys if self._get_source(key) == source]

    if flat:
        # return all cases under the source system
        source_sys = source.replace('.nonlinear_solver', '')
        return [self.get_case(key) for key in self._keys
                if get_source_system(key).startswith(source_sys)]

    return OrderedDict(
        (key, self.get_cases(key, recurse, flat))
        for key in self._keys
        if self._get_source(key) == source
    )
def get_case(self, case_id, cache=False):
    """
    Get a case from the database.

    Parameters
    ----------
    case_id : str or int
        The string-identifier of the case to be retrieved or the index of the case.
    cache : bool
        If True, case will be cached for faster access by key.

    Returns
    -------
    Case
        The specified case from the table, or None if no row matches case_id.
    """
    # an integer is an index into the table's keys; convert it to a coordinate
    if isinstance(case_id, int):
        case_id = self._get_iteration_coordinate(case_id)

    # if we've already cached this case, return the cached instance
    if case_id in self._cases:
        return self._cases[case_id]

    # we don't have it, so fetch it
    con = sqlite3.connect(self._filename)
    try:
        con.row_factory = sqlite3.Row
        cur = con.cursor()
        cur.execute(f"SELECT * FROM {self._table_name} "  # nosec: trusted input
                    f"WHERE {self._index_name}=?", (case_id, ))
        row = cur.fetchone()
    finally:
        # release the connection even when the query fails
        con.close()

    if row is None:
        return None

    # extract the data and optionally cache the Case
    if self._format_version >= 5:
        source = self._get_row_source(row['id'])

        # check for situations where parsing the iter coord doesn't work correctly
        iter_source = self._get_source(row[self._index_name])
        if iter_source != source:
            msg = f'Mismatched source for {row["id"]}: {row[self._index_name]} = ' \
                  f'{iter_source} vs {source}'
            issue_warning(msg, category=CaseRecorderWarning)
    else:
        source = self._get_source(row[self._index_name])

    case = Case(source, row, self._prom2abs, self._abs2prom, self._abs2meta,
                self._conns, self._var_info, self._format_version)

    # cache it if requested
    if cache:
        self._cases[case_id] = case

    return case
def _get_iteration_coordinate(self, case_idx): """ Return the iteration coordinate for the indexed case (handles negative indices, etc.). Parameters ---------- case_idx : int The case number that we want the iteration coordinate for. Returns ------- iteration_coordinate : str The iteration coordinate. """ # if keys have not been cached yet, get them now if self._keys is None: self.list_cases() return self._keys[case_idx]
def cases(self, cache=False):
    """
    Iterate over all cases, optionally caching them into memory.

    Parameters
    ----------
    cache : bool
        If True, cases will be cached for faster access by key.

    Yields
    ------
    case
    """
    # A sqlite3 connection used as a context manager only manages the transaction and
    # never closes the connection. Closing in a finally clause guarantees the connection
    # is released when the query raises or when the caller stops iterating early.
    con = sqlite3.connect(self._filename)
    try:
        con.row_factory = sqlite3.Row
        cur = con.cursor()
        cur.execute(f"SELECT * FROM {self._table_name} ORDER BY id ASC")  # nosec: trusted input

        # rows are streamed from the cursor rather than fetched all at once
        for row in cur:
            case_id = row[self._index_name]
            source = self._get_source(case_id)
            case = Case(source, row, self._prom2abs, self._abs2prom, self._abs2meta,
                        self._conns, self._var_info, self._format_version)

            if cache:
                self._cases[case_id] = case

            yield case
    finally:
        con.close()
def _load_cases(self): """ Load all cases into memory. """ for case in self.cases(cache=True): pass
def list_sources(self):
    """
    Get the sources that recorded data in this table.

    The result is computed on first use and stored for subsequent calls.

    Returns
    -------
    set
        The sources found for this table.
    """
    if self._sources is not None:
        return self._sources

    if self._format_version >= 5:
        table = self._table_name.split('_')[0]  # remove "_iterations" from table name
        found = set()
        for entry in self._global_iterations:
            record_type, source = entry[1], entry[3]
            if record_type != table:
                continue
            # sources that do not already start with 'root' get a 'root.' prefix
            found.add(source if source.startswith('root') else 'root.' + source)
        self._sources = found
    else:
        # older formats: derive the source from each case name
        self._sources = {self._get_source(case) for case in self.list_cases()}

    return self._sources
def _get_source(self, iteration_coordinate): """ Get the source of the iteration. Parameters ---------- iteration_coordinate : str The full unique identifier for this iteration. Returns ------- str The source of the iteration. """ return get_source_system(iteration_coordinate) def _get_row_source(self, row_id): """ Get the source of the case at the specified row of this table. Parameters ---------- row_id : int The row_id of the case in the table. Returns ------- str The source of the case. """ table = self._table_name.partition('_')[0] # remove "_iterations" from table name for global_iter in self._global_iterations: record_type, row, source = global_iter[1], global_iter[2], global_iter[3] if record_type == table and row == row_id: return source return None def _get_first(self, source): """ Get the first case from the specified source. Parameters ---------- source : str The source. Returns ------- Case The first case from the specified source. """ for case in self.cases(): if case.source == source: return case return None
class DriverCases(CaseTable):
    """
    Cases specific to the entries that might be recorded in a Driver iteration.

    Parameters
    ----------
    filename : str
        The name of the recording file from which to instantiate the case reader.
    format_version : int
        The version of the format assumed when loading the file.
    giter : list of tuple
        The global iterations table.
    prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    conns : dict
        Dictionary of all model connections.
    var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, conns,
                 var_info):
        """
        Initialize.
        """
        super().__init__(filename, format_version,
                         'driver_iterations', 'iteration_coordinate', giter,
                         prom2abs, abs2prom, abs2meta, conns, var_info)

    def cases(self, cache=False):
        """
        Iterate over all cases, optionally caching them into memory.

        Override base class to add derivatives from the derivatives table.

        Parameters
        ----------
        cache : bool
            If True, cases will be cached for faster access by key.

        Yields
        ------
        case
        """
        # A sqlite3 connection used as a context manager only manages the transaction and
        # never closes the connection. Closing in a finally clause guarantees the
        # connection is released when a query raises or the caller stops iterating early.
        con = sqlite3.connect(self._filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()
            cur.execute(f"SELECT * FROM {self._table_name} ORDER BY id ASC")  # nosec: trusted input

            # all rows are fetched up front because the cursor is reused below for the
            # per-case derivatives query
            rows = cur.fetchall()

            for row in rows:
                if self._format_version > 1:
                    # fetch associated derivative data, if available
                    cur.execute("SELECT * FROM driver_derivatives WHERE iteration_coordinate=?",
                                (row['iteration_coordinate'], ))
                    derivs_row = cur.fetchone()

                    if derivs_row:
                        # convert row to a regular dict and add jacobian
                        row = dict(zip(row.keys(), row))
                        row['jacobian'] = derivs_row['derivatives']

                case = Case('driver', row, self._prom2abs, self._abs2prom, self._abs2meta,
                            self._conns, self._var_info, self._format_version)

                if cache:
                    self._cases[case.name] = case

                yield case
        finally:
            con.close()

    def get_case(self, case_id, cache=False):
        """
        Get a case from the database.

        Parameters
        ----------
        case_id : int or str
            The integer index or string-identifier of the case to be retrieved.
        cache : bool
            If True, cache the case so it does not have to be fetched on next access.

        Returns
        -------
        Case
            The specified case from the driver_iterations and driver_derivatives tables,
            or None if no row matches.
        """
        # an integer is an index into the list of case names; translate it to a name
        if isinstance(case_id, int):
            case_id = self._get_iteration_coordinate(case_id)

        # return cached case if present, else fetch it
        if case_id in self._cases:
            return self._cases[case_id]

        # close the connection explicitly, even if a query raises; the sqlite3
        # connection context manager would not do it
        con = sqlite3.connect(self._filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()

            # fetch driver iteration data
            cur.execute("SELECT * FROM driver_iterations WHERE "
                        "iteration_coordinate=:iteration_coordinate",
                        {"iteration_coordinate": case_id})
            row = cur.fetchone()

            # fetch associated derivative data, if available
            if row and self._format_version > 1:
                cur.execute("SELECT * FROM driver_derivatives WHERE "
                            "iteration_coordinate=:iteration_coordinate",
                            {"iteration_coordinate": case_id})
                derivs_row = cur.fetchone()

                if derivs_row:
                    # convert row to a regular dict and add jacobian
                    row = dict(zip(row.keys(), row))
                    row['jacobian'] = derivs_row['derivatives']
        finally:
            con.close()

        # if found, create Case object (and cache it if requested) else return None
        if not row:
            return None

        case = Case('driver', row, self._prom2abs, self._abs2prom, self._abs2meta,
                    self._conns, self._var_info, self._format_version)

        if cache:
            self._cases[case_id] = case

        return case

    def list_sources(self):
        """
        Get the list of sources that recorded data in this table (just the driver).

        Returns
        -------
        list
            List of sources.
        """
        return ['driver']

    def _get_source(self, iteration_coordinate):
        """
        Get the source of the iteration.

        Parameters
        ----------
        iteration_coordinate : str
            The full unique identifier for this iteration.

        Returns
        -------
        str
            'driver' (all cases in this table come from the driver).
        """
        return 'driver'

    def _get_row_source(self, row_id):
        """
        Get the source of the case at the specified row of this table.

        Parameters
        ----------
        row_id : int
            The row_id of the case in the table.

        Returns
        -------
        str
            'driver' (all cases in this table come from the driver).
        """
        return 'driver'
class SystemCases(CaseTable):
    """
    Cases specific to the entries that might be recorded in a System iteration.

    Parameters
    ----------
    filename : str
        The name of the recording file from which to instantiate the case reader.
    format_version : int
        The version of the format assumed when loading the file.
    giter : list of tuple
        The global iterations table.
    prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    conns : dict
        Dictionary of all model connections.
    var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, conns,
                 var_info):
        """
        Initialize.
        """
        # table and index column that the base class reads for system cases
        table_name, index_name = 'system_iterations', 'iteration_coordinate'
        super().__init__(filename, format_version, table_name, index_name, giter,
                         prom2abs, abs2prom, abs2meta, conns, var_info)
class SolverCases(CaseTable):
    """
    Cases specific to the entries that might be recorded in a Solver iteration.

    Parameters
    ----------
    filename : str
        The name of the recording file from which to instantiate the case reader.
    format_version : int
        The version of the format assumed when loading the file.
    giter : list of tuple
        The global iterations table.
    prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    conns : dict
        Dictionary of all model connections.
    var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, conns,
                 var_info):
        """
        Initialize.
        """
        # table and index column that the base class reads for solver cases
        table_name, index_name = 'solver_iterations', 'iteration_coordinate'
        super().__init__(filename, format_version, table_name, index_name, giter,
                         prom2abs, abs2prom, abs2meta, conns, var_info)

    def _get_source(self, iteration_coordinate):
        """
        Get pathname of solver that is the source of the iteration.

        Parameters
        ----------
        iteration_coordinate : str
            The full unique identifier for this iteration.

        Returns
        -------
        str
            The pathname of the solver that is the source of the iteration.

        Raises
        ------
        RuntimeError
            If the number of coordinate nodes following the system's
            '_solve_nonlinear' entry matches neither a solver nor a linesearch.
        """
        source_system = get_source_system(iteration_coordinate)

        # find where the '<system>._solve_nonlinear' entry ends in the coordinate
        marker = source_system.rpartition('.')[2] + '._solve_nonlinear'
        marker_end = iteration_coordinate.index(marker) + len(marker)

        # separators up to the end of that entry, plus two, gives the same node count
        # as splitting the prefix on '|' and adding one
        system_coord_nodes = iteration_coordinate.count('|', 0, marker_end) + 2
        num_coord_nodes = iteration_coordinate.count('|') + 1

        extra_nodes = num_coord_nodes - system_coord_nodes
        if extra_nodes == 2:
            return source_system + '.nonlinear_solver'
        if extra_nodes == 4:
            return source_system + '.nonlinear_solver.linesearch'

        raise RuntimeError("Can't parse solver iteration coordinate: %s" % iteration_coordinate)
class ProblemCases(CaseTable):
    """
    Cases specific to the entries that might be recorded by a Problem.

    Parameters
    ----------
    filename : str
        The name of the recording file from which to instantiate the case reader.
    format_version : int
        The version of the format assumed when loading the file.
    giter : list of tuple
        The global iterations table.
    prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    conns : dict
        Dictionary of all model connections.
    var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).
    """

    def __init__(self, filename, format_version, giter, prom2abs, abs2prom, abs2meta, conns,
                 var_info):
        """
        Initialize.
        """
        # problem cases are indexed by case name rather than by iteration coordinate
        table_name, index_name = 'problem_cases', 'case_name'
        super().__init__(filename, format_version, table_name, index_name, giter,
                         prom2abs, abs2prom, abs2meta, conns, var_info)

    def list_sources(self):
        """
        Get the list of sources that recorded data in this table (just the problem).

        Returns
        -------
        list
            List of sources.
        """
        return ['problem']

    def _get_source(self, iteration_coordinate):
        """
        Get the source of the iteration.

        Parameters
        ----------
        iteration_coordinate : str
            The full unique identifier for this iteration.

        Returns
        -------
        str
            'problem' (all cases in this table come from the problem).
        """
        return 'problem'

    def _get_row_source(self, row_id):
        """
        Get the source of the case at the specified row of this table.

        Parameters
        ----------
        row_id : int
            The row_id of the case in the table.

        Returns
        -------
        str
            'problem' (all cases in this table come from the problem).
        """
        return 'problem'