Source code for openmdao.lib.casehandlers.dbcase

"""A CaseRecorder and CaseIterator that store the cases in a relational
DB (Python's sqlite.)
"""

import sys
import sqlite3
import uuid
from cPickle import dumps, loads, HIGHEST_PROTOCOL, UnpicklingError
from optparse import OptionParser

# pylint: disable-msg=E0611,F0401
from openmdao.main.interfaces import implements, ICaseRecorder, ICaseIterator
from openmdao.main.case import Case

_casetable_attrs = set(['id','uuid','parent','label','msg','retries', \
                        'model_id','timeEnter'])
_vartable_attrs = set(['var_id','name','case_id','sense','value'])

def _query_split(query):
    """Return a tuple of lhs, relation, rhs after splitting on 
    a list of allowed operators.
    """
    # FIXME: make this more robust
    for op in ['<>', '<=', '>=', '==', '!=', '<', '>', '=']:
        parts = query.split(op, 1)
        if len(parts) > 1:
            return (parts[0].strip(), op, parts[1].strip())
    else:
        raise ValueError("No allowable operator found in query '%s'" % query)


[docs]class DBCaseIterator(object): """Pulls Cases from a relational DB (sqlite). It doesn't support general sql queries, but it does allow for a series of boolean selectors, e.g., 'x<=y', that are ANDed together. """ implements(ICaseIterator) def __init__(self, dbfile=':memory:', selectors=None, connection=None): if connection is not None: self._dbfile = dbfile self._connection = connection else: self._connection = None self.dbfile = dbfile self.selectors = selectors self._connection.text_factory = sqlite3.OptimizedUnicode @property def dbfile(self): """The name of the database. This can be a filename or :memory: for an in-memory database. """ return self._dbfile @dbfile.setter
[docs] def dbfile(self, value): """Set the DB file and connect to it.""" self._dbfile = value if self._connection: self._connection.close() self._connection = sqlite3.connect(value)
def __iter__(self): return self._next_case() def _next_case(self): """ Generator which returns Cases one at a time. """ # figure out which selectors are for cases and which are for variables sql = ["SELECT * FROM cases"] if self.selectors is not None: for sel in self.selectors: rhs,rel,lhs = _query_split(sel) if rhs in _casetable_attrs: if len(sql) == 1: sql.append("WHERE %s%s%s" % (rhs,rel,lhs)) else: sql.append("AND %s%s%s" % (rhs,rel,lhs)) casecur = self._connection.cursor() casecur.execute(' '.join(sql)) sql = ['SELECT var_id,name,case_id,sense,value from casevars WHERE case_id=%s'] if self.selectors is not None: for sel in self.selectors: rhs,rel,lhs = _query_split(sel) if rhs in _vartable_attrs: sql.append("AND %s%s%s" % (rhs,rel,lhs)) combined = ' '.join(sql) varcur = self._connection.cursor() for cid,text_id,parent,label,msg,retries,model_id,timeEnter in casecur: varcur.execute(combined % cid) inputs = [] outputs = [] for var_id, vname, case_id, sense, value in varcur: if not isinstance(value, (float,int,str)): try: value = loads(str(value)) except UnpicklingError as err: raise UnpicklingError("can't unpickle value '%s' for case '%s' from database: %s" % (vname, cname, str(err))) if sense=='i': inputs.append((vname, value)) else: outputs.append((vname, value)) if len(inputs) > 0 or len(outputs) > 0: yield Case(inputs=inputs, outputs=outputs, retries=retries,msg=msg,label=label, case_uuid=text_id, parent_uuid=parent)
[docs] def get_attributes(self, io_only=True): """ We need a custom get_attributes because we aren't using Traits to manage our changeable settings. This is unfortunate, and should be changed to something that automates this somehow.""" attrs = {} attrs['type'] = type(self).__name__ variables = [] attr = {} attr['name'] = "dbfile" attr['type'] = type(self.dbfile).__name__ attr['value'] = str(self.dbfile) attr['connected'] = '' attr['desc'] = 'Name of the database file to be iterated. Default ' + \ 'is ":memory:", which reads the database from memory.' variables.append(attr) attr = {} attr['name'] = "selectors" attr['type'] = type(self.selectors).__name__ attr['value'] = str(self.selectors) attr['connected'] = '' attr['desc'] = 'String of additional SQL queries to apply to the case selection.' variables.append(attr) attrs["Inputs"] = variables return attrs
[docs]class DBCaseRecorder(object): """Records Cases to a relational DB (sqlite). Values other than floats, ints or strings are pickled and are opaque to SQL queries. """ implements(ICaseRecorder) def __init__(self, dbfile=':memory:', model_id='', append=False): self.dbfile = dbfile # this creates the connection self.model_id = model_id if append: exstr = 'if not exists' else: exstr = '' self._connection.execute(""" create table %s cases( id INTEGER PRIMARY KEY, uuid TEXT, parent TEXT, label TEXT, msg TEXT, retries INTEGER, model_id TEXT, timeEnter TEXT )""" % exstr) self._connection.execute(""" create table %s casevars( var_id INTEGER PRIMARY KEY, name TEXT, case_id INTEGER, sense TEXT, value BLOB )""" % exstr) @property def dbfile(self): """The name of the database. This can be a filename or :memory: for an in-memory database. """ return self._dbfile @dbfile.setter
[docs] def dbfile(self, value): """Set the DB file and connect to it.""" self._dbfile = value self._connection = sqlite3.connect(value) self._iter_conn = sqlite3.connect(value)
[docs] def record(self, case): """Record the given Case.""" if self._connection is None: raise RuntimeError('Attempt to record on closed recorder') cur = self._connection.cursor() cur.execute("""insert into cases(id,uuid,parent,label,msg,retries,model_id,timeEnter) values (?,?,?,?,?,?,?,DATETIME('NOW'))""", (None, case.uuid, case.parent_uuid, case.label, case.msg or '', case.retries, self.model_id)) case_id = cur.lastrowid # insert the inputs and outputs into the vars table. Pickle them if they're not one of the # built-in types int, float, or str. for name,value in case.items(iotype='in'): if isinstance(value, (float,int,str)): v = (None, name, case_id, 'i', value) else: v = (None, name, case_id, 'i', sqlite3.Binary(dumps(value,HIGHEST_PROTOCOL))) cur.execute("insert into casevars(var_id,name,case_id,sense,value) values(?,?,?,?,?)", v) for name,value in case.items(iotype='out'): if isinstance(value, (float,int,str)): v = (None, name, case_id, 'o', value) else: v = (None, name, case_id, 'o', sqlite3.Binary(dumps(value,HIGHEST_PROTOCOL))) cur.execute("insert into casevars(var_id,name,case_id,sense,value) values(?,?,?,?,?)", v) self._connection.commit()
[docs] def close(self): """Commit and close DB connection if not using ``:memory:``.""" if self._connection is not None and self._dbfile != ':memory:': self._connection.commit() self._connection.close() self._connection = None
[docs] def get_iterator(self): """Return a DBCaseIterator that points to our current DB.""" return DBCaseIterator(dbfile=self._dbfile, connection=self._connection)
[docs] def get_attributes(self, io_only=True): """ We need a custom get_attributes because we aren't using Traits to manage our changeable settings. This is unfortunate, and should be changed to something that automates this somehow.""" attrs = {} attrs['type'] = type(self).__name__ variables = [] attr = {} attr['name'] = "dbfile" attr['type'] = type(self.dbfile).__name__ attr['value'] = str(self.dbfile) attr['connected'] = '' attr['desc'] = 'Name of the database file to be recorded. Default ' + \ 'is ":memory:", which writes the database to memory.' variables.append(attr) attrs["Inputs"] = variables return attrs
""" Utility functions related to plotting data """
[docs]def list_db_vars(dbname): """ Return the set of the names of the variables found in the specified case DB file. dbname: str The name of the sqlite DB file. """ connection = sqlite3.connect(dbname) varcur = connection.cursor() varcur.execute("SELECT name from casevars") varnames = set([v for v in varcur]) return varnames
[docs]def case_db_to_dict(dbname, varnames, case_sql='', var_sql='', include_errors=False): """ Retrieve the values of specified variables from a sqlite DB containing Case data. Returns a dict containing a list of values for each entry, keyed on variable name. Only data from cases containing ALL of the specified variables will be returned so that all data values with the same index will correspond to the same case. dbname: str The name of the sqlite DB file. varnames: list[str] Iterator of names of variables to be retrieved. case_sql: str (optional) SQL syntax that will be placed in the WHERE clause for Case retrieval. var_sql: str (optional) SQL syntax that will be placed in the WHERE clause for variable retrieval. include_errors: bool (optional) [False] If True, include data from cases that reported an error. """ connection = sqlite3.connect(dbname) vardict = dict([(name,[]) for name in varnames]) sql = ["SELECT id FROM cases"] qlist = [] if case_sql: qlist.append(case_sql) if not include_errors: qlist.append("msg = ''") if qlist: sql.append("WHERE %s" % ' AND '.join(qlist)) casecur = connection.cursor() casecur.execute(' '.join(sql)) sql = ["SELECT name, value from casevars WHERE case_id=%s"] vars_added = False for i,name in enumerate(vardict.keys()): if i==0: sql.append("AND (") else: sql.append("OR") sql.append("name='%s'" % name) vars_added = True if vars_added: sql.append(")") if var_sql: sql.append(" AND %s" % var_sql) combined = ' '.join(sql) varcur = connection.cursor() for case_id in casecur: casedict = {} varcur.execute(combined % case_id) for vname, value in varcur: if not isinstance(value, (float,int,str)): try: value = loads(str(value)) except UnpicklingError as err: raise UnpicklingError("can't unpickle value '%s' from database: %s" % (vname, str(err))) casedict[vname] = value if len(casedict) != len(vardict): continue # case doesn't contain a complete set of specified vars, so skip it to avoid data mismatches for name, value in casedict.items(): vardict[name].append(value) return vardict
def _get_lines(dbname, xnames, ynames, case_sql=None, var_sql=None): """Return a list of lines which will be fed to the plot function.""" vardict = case_db_to_dict(dbname, xnames+ynames, case_sql, var_sql) lines = [] yvals = [] xvals = [] for i,name in enumerate(ynames): yvals.append(vardict[name]) if len(xnames) == 0: xvals.append(range(len(vardict[name]))) elif len(xnames) == 1: xvals.append(vardict[xnames[0]]) else: xvals.append(vardict[xnames[i]]) for xdata,ydata in zip(xvals, yvals): lines.append((xdata, ydata)) return lines
[docs]def displayXY(dbname, xnames, ynames, case_sql=None, var_sql=None, title='', grid=False, xlabel='', ylabel=''): """Display an XY plot using Case data from a sqlite DB. dbname: str Name of the database file. xnames: list[str] Names of X variables. ynames: list[str] Names of Y variables. case_sql: str (optional) SQL syntax that will be placed in the WHERE clause for Case retrieval. var_sql: str (optional) SQL syntax that will be placed in the WHERE clause for variable retrieval. title: str (optional) Plot title. grid: bool (optional) If True, a grid is drawn on the plot. xlabel: str (optional) X axis label. ylabel: str (optional) Y axis label. """ try: if 'matplotlib' not in sys.modules: import matplotlib if sys.platform == 'darwin': matplotlib.use('MacOSX') else: try: import wx except ImportError: matplotlib.use('TkAgg') else: matplotlib.use('WxAgg') import matplotlib.pyplot as plt except ImportError: print 'matplotlib not found' return fig = plt.figure() fig.add_subplot(111) for i,line in enumerate(_get_lines(dbname, xnames, ynames, case_sql, var_sql)): args = [] kwargs = {} args.append(line[0]) args.append(line[1]) kwargs['label'] = '%s' % ynames[i] plt.plot(*args, **kwargs) if grid: plt.grid(True) if xlabel: plt.xlabel(xlabel) if ylabel: plt.ylabel(ylabel) if title: plt.title(title) plt.legend() plt.show()
[docs]def cmdlineXYplot(): """Based on command line options, display an XY plot using data from a sqlite Case DB. """ parser = OptionParser() parser.add_option("-x", "", action="store", type="string", dest="xnames", help="names of x variables") parser.add_option("-y", "", action="store", type="string", dest="ynames", help="names of y variables") parser.add_option("-d", "--dbfile", action="store", type="string", dest="dbname", help="database filename") parser.add_option("-t", "--title", action="store", type="string", dest="title", help="plot title",) parser.add_option("--xlabel", action="store", type="string", dest="xlabel", help="x axis label") parser.add_option("--ylabel", action="store", type="string", dest="ylabel", help="y axis label") parser.add_option("-g", "--grid", action="store_true", dest="grid", help="makes grid visible") parser.add_option("--cases", action="store", type="string", dest="case_sql", help="sql syntax to select certain cases") parser.add_option("--vars", action="store", type="string", dest="var_sql", help="sql syntax to select certain vars") parser.add_option("-l", "--list", action="store_true", dest="listvars", help="lists names of variables found in the database") (options, args) = parser.parse_args(sys.argv[1:]) if options.listvars: print for name in sorted(list_db_vars(options.dbname)): print name print sys.exit(0) if len(args) > 0 or not options.ynames or not options.dbname: parser.print_help() sys.exit(-1) if options.xnames: xs = options.xnames.split(',') else: xs = [] ys = options.ynames.split(',') if len(xs) > 1 and len(xs) != len(ys): print "Number of x variables doesn't match number of y variables." sys.exit(-1) displayXY(options.dbname, xs, ys, options.case_sql, options.var_sql, title=options.title, grid=options.grid, xlabel=options.xlabel, ylabel=options.ylabel)
OpenMDAO Home