# Source code for openmdao.drivers.analysis_generator

"""
Provide generators for use with AnalysisDriver.

These generators are pythonic, lazy generators which, when provided with a dictionary
of variables and values to be tested, produce some set of sample values to be evaluated.

"""

from collections.abc import Iterator
import csv
import itertools


class AnalysisGenerator(Iterator):
    """
    Base class for the iterators which feed case data to AnalysisDriver.

    Parameters
    ----------
    var_dict : dict
        A dictionary whose keys are promoted paths of variables to be set, and whose
        values are the arguments to `set_val`.

    Attributes
    ----------
    _iter : Iterator
        The underlying iterator for variable values.
    _run_count : int
        A running count of the samples obtained from the iterator.
    _var_dict : dict
        The var_dict used to create the generator (held by reference).
    """

    def __init__(self, var_dict):
        """
        Instantiate the base class for AnalysisGenerators.

        Parameters
        ----------
        var_dict : dict
            A dictionary mapping a variable name to values to be assumed, as well
            as optional units and indices specifications.
        """
        super().__init__()
        self._var_dict = var_dict
        self._iter = None
        self._run_count = 0
        # Subclasses build self._iter in their _setup override.
        self._setup()

    def _setup(self):
        """
        Zero the run counter.

        Subclasses of AnalysisGenerator should override this method (calling this
        implementation first) in order to define self._iter.
        """
        self._run_count = 0

    def _get_sampled_vars(self):
        """
        Return the names of the variables whose values are provided by this generator.

        Returns
        -------
        dict_keys
            The keys view of the var_dict given at instantiation.
        """
        return self._var_dict.keys()

    def __next__(self):
        """
        Provide a dictionary describing the next point to be analyzed.

        Each key is the promoted path of a variable to be set. Each value is a dict
        holding the value to set ('val'), along with 'units' and 'indices', which are
        None when not given in the var_dict.

        Raises
        ------
        StopIteration
            When the underlying iterator has been exhausted.

        Returns
        -------
        dict
            A dictionary containing the promoted paths of variables to
            be set by the AnalysisDriver.
        """
        # Pull the sample first; on exhaustion StopIteration propagates before
        # the run counter is touched.
        sample = next(self._iter)

        # The i-th entry of the sample belongs to the i-th variable of the var_dict.
        case = {
            name: {'val': sample[pos],
                   'units': meta.get('units'),
                   'indices': meta.get('indices')}
            for pos, (name, meta) in enumerate(self._var_dict.items())
        }

        self._run_count += 1
        return case


class ZipGenerator(AnalysisGenerator):
    """
    A generator which provides case data for AnalysisDriver by zipping values of each factor.

    Parameters
    ----------
    var_dict : dict
        A dictionary which maps promoted path names of variables to be set in each
        iteration with their values to be assumed (required), units (optional), and
        indices (optional).
    """

    def _setup(self):
        """
        Set up the iterator which provides each case.

        Raises
        ------
        ValueError
            Raised if the 'val' entries in var_dict do not all have the same length.
        """
        super()._setup()

        # Validate before building the iterator: every variable must supply the same
        # number of values, otherwise zip would silently truncate to the shortest.
        lengths = {name: len(meta['val']) for name, meta in self._var_dict.items()}
        if len(set(lengths.values())) != 1:
            raise ValueError('ZipGenerator requires that val '
                             f'for all var_dict have the same length:\n{lengths}')

        # The n-th case takes the n-th value of every variable.
        self._iter = zip(*(meta['val'] for meta in self._var_dict.values()))


class ProductGenerator(AnalysisGenerator):
    """
    A generator which provides full-factorial case data for AnalysisDriver.

    Parameters
    ----------
    var_dict : dict
        A dictionary which maps promoted path names of variables to be set in each
        iteration with their values to be assumed (required), units (optional), and
        indices (optional).
    """

    def _setup(self):
        """
        Set up the iterator which provides each case.

        The cases are the Cartesian product of the 'val' entries of every variable
        in var_dict, taken in the order the variables appear in var_dict.
        """
        super()._setup()
        factor_levels = [meta['val'] for meta in self._var_dict.values()]
        self._iter = itertools.product(*factor_levels)


class CSVGenerator(AnalysisGenerator):
    """
    A generator which provides cases for AnalysisDriver by pulling rows from a CSV file.

    Parameters
    ----------
    filename : str
        The filename for the CSV file containing the variable data.
    has_units : bool
        If True, the second line of the CSV contains the units of each variable.
    has_indices : bool
        If True, the line after units (if present) contains the indices being set.

    Attributes
    ----------
    _filename : str
        The filename of the CSV file providing the samples.
    _has_units : bool
        True if the CSV file contains a row of the units for each variable.
    _has_indices : bool
        True if the CSV file contains a row of indices being provided for each variable.
        If units are present, indices will be on the line following units.
    _csv_file : file
        The file object for the CSV file.
    _csv_reader : DictReader
        The reader object for the CSV file.
    _var_names : set of str
        The set of variable names provided by this CSVGenerator.
    _ret_val : dict
        Per-variable template holding the 'units' and 'indices' read from the file
        header rows. Each sample returned by __next__ is built from this template.
    """

    def __init__(self, filename, has_units=False, has_indices=False):
        """
        Instantiate CSVGenerator.

        The file is opened here and the header row, plus the optional units and
        indices rows, are consumed immediately. The remaining rows are read lazily.

        Parameters
        ----------
        filename : str
            The filename for the CSV file containing the variable data.
        has_units : bool
            If True, the second line of the CSV contains the units of each variable.
        has_indices : bool
            If True, the line after units (if present) contains the indices being set.
        """
        self._filename = filename
        self._has_units = has_units
        self._has_indices = has_indices

        # The csv module requires files to be opened with newline='' so that newlines
        # embedded in quoted fields are interpreted correctly.
        self._csv_file = open(self._filename, 'r', newline='')

        try:
            self._csv_reader = csv.DictReader(self._csv_file)

            self._var_names = set(self._csv_reader.fieldnames)
            self._ret_val = {var: {'units': None, 'indices': None}
                             for var in self._csv_reader.fieldnames}

            if self._has_units:
                for var, units in next(self._csv_reader).items():
                    # An empty or missing cell means the variable has no units.
                    self._ret_val[var]['units'] = units if units else None

            if self._has_indices:
                for var, idxs in next(self._csv_reader).items():
                    # NOTE(security): the indices cell is evaluated as a Python expression.
                    # Builtins are removed from the scope, but this is still not safe for
                    # CSV files from untrusted sources.
                    # An empty or missing cell means no indices, mirroring units.
                    self._ret_val[var]['indices'] = (
                        eval(idxs, {'__builtins__': {}})  # nosec: scope limited
                        if idxs else None
                    )
        except Exception:
            # Don't leak the file handle if the header rows are missing or malformed.
            self._csv_file.close()
            raise

    def _get_sampled_vars(self):
        """
        Return the set of variable names whose values are provided by this generator.

        Returns
        -------
        set of str
            The column names found in the header row of the CSV file.
        """
        return self._var_names

    def __next__(self):
        """
        Provide the data from the next row of the csv file.

        Raises
        ------
        StopIteration
            When all rows of the file have been consumed. The file is closed at
            that point, and any later call raises StopIteration again.

        Returns
        -------
        dict
            A new dict mapping each variable name to a dict with keys 'units',
            'indices' and 'val'. The value is the raw string read from the file.
        """
        # Once exhausted the file is closed. Keep raising StopIteration, as the
        # iterator protocol requires, rather than failing on the closed file.
        if self._csv_file.closed:
            raise StopIteration

        try:
            row = next(self._csv_reader)
        except StopIteration:
            # Close the file and propagate the exception
            self._csv_file.close()
            raise

        # Build a fresh dict per row so that previously returned samples are not
        # overwritten by later rows (e.g. when the generator is collected in a list).
        return {var: {**self._ret_val[var], 'val': val} for var, val in row.items()}

    def __del__(self):
        """
        Ensure the file is closed if we don't exhaust the iterator.
        """
        # _csv_file does not exist if open() raised in __init__.
        csv_file = getattr(self, '_csv_file', None)
        if csv_file is not None and not csv_file.closed:
            csv_file.close()


class SequenceGenerator:
    """
    A generator which provides samples from python lists or tuples.

    Each element of the container is one sample: a dict keyed by the names of the
    variables being set. The container is iterated lazily and is never copied.

    Parameters
    ----------
    container : container
        A python container, excluding strings, bytes, or bytearray.

    Attributes
    ----------
    _sampled_vars : list(str)
        A list of the variables in the model being sampled. These are the keys of
        the first sample, or an empty list if the container is empty.
    _iter : Iterator
        The internal iterator over the users case data.
    """

    def __init__(self, container):
        """
        Instantiate a SequenceGenerator with the given container of samples.

        Parameters
        ----------
        container : container
            A python container, excluding strings, bytes, or bytearray.
        """
        samples = iter(container)

        # Peek at the first sample to learn the variable names, then chain it back in
        # front of the rest. This avoids copying the container and also works when
        # the container is a one-shot iterator.
        try:
            first = next(samples)
        except StopIteration:
            # No samples at all: nothing is sampled and iteration ends immediately.
            self._sampled_vars = []
            self._iter = samples
        else:
            self._sampled_vars = list(first.keys())
            self._iter = itertools.chain((first,), samples)

    def __iter__(self):
        """
        Provide the python iterator for this instance.

        Returns
        -------
        SequenceGenerator
            This instance, which is its own iterator.
        """
        return self

    def __next__(self):
        """
        Provide the next values for the variables in the generator.

        Raises
        ------
        StopIteration
            When the given container is exhausted.

        Returns
        -------
        dict
            The next sample from the container, unmodified.
        """
        return next(self._iter)

    def _get_sampled_vars(self):
        """
        Return the names of the variables whose values are provided by this generator.

        Returns
        -------
        list(str)
            The keys of the first sample in the container.
        """
        return self._sampled_vars