"""
Define the LinearRHSChecker class.
LinearRHSChecker manages caching of solutions and right-hand sides for linear solves.
"""
from collections import deque
import atexit
import numpy as np
from math import isclose
from openmdao.utils.array_utils import allclose, allzero
from openmdao.utils.om_warnings import issue_warning, SolverWarning
from openmdao.visualization.tables.table_builder import generate_table
# Registry of cache statistics, keyed by problem name and then by system pathname.  Each leaf is
# the per-checker stats dict created in LinearRHSChecker.__init__ when collect_stats is True.
_cache_stats = {}
def _print_stats():
    """
    Print out cache statistics at the end of the run.

    One table is displayed per problem, with one row per system that collected statistics.
    Nothing is printed if no statistics were collected.
    """
    if not _cache_stats:
        return

    # order of the stat keys matches the order of the column headers below
    stat_keys = ('eqhits', 'neghits', 'parhits', 'zerohits', 'misses', 'resets')
    headers = ['System', 'Eq Hits', 'Neg Hits', 'Parallel Hits', 'Zero Hits', 'Misses',
               'Resets']

    for prob_name, sys_stats in _cache_stats.items():
        rows = [[syspath] + [stats[key] for key in stat_keys]
                for syspath, stats in sys_stats.items()]
        print(f"\nCache Statistics for Problem '{prob_name}':")
        generate_table(rows, tablefmt='simple_grid', headers=headers).display()
class LinearRHSChecker(object):
    """
    Class that manages caching of linear solutions.

    Parameters
    ----------
    system : System
        The system that owns the solver that owns this LinearRHSChecker.
    max_cache_entries : int
        Maximum number of solutions to cache. Defaults to 3.
    check_zero : bool
        If True, check if the RHS vector is zero. Defaults to False.
    rtol : float
        Relative tolerance for equivalence checks. Defaults to 3e-16.
    atol : float
        Absolute tolerance for equivalence checks. Defaults to 3e-16.
    collect_stats : bool
        If True, collect cache statistics. Defaults to False.
    verbose : bool
        If True, print out whenever a cache hit occurs. Defaults to False.

    Attributes
    ----------
    _caches : deque
        Bounded deque of cached (rhs, solution) pairs, oldest entry first.
    _ncompute_totals : int
        Total number of compute_totals calls. Used to determine when to
        reset the cache.
    _check_zero : bool
        If True, check if the RHS vector is zero.
    _rtol : float
        Relative tolerance for equivalence check.
    _atol : float
        Absolute tolerance for equivalence check.
    _stats : dict or None
        Dictionary to store cache statistics.
    _verbose : bool
        If True, print out whenever a cache hit occurs.
    _solver_msginfo : str
        The message info for the solver that owns this LinearRHSChecker.
    """

    # keys accepted in an 'rhs_checking' options dict.  'auto' is consumed by create() and is
    # not an __init__ argument.
    options = ('check_zero', 'rtol', 'atol', 'max_cache_entries', 'collect_stats',
               'auto', 'verbose')

    def __init__(self, system, max_cache_entries=3, check_zero=False, rtol=3e-16, atol=3e-16,
                 collect_stats=False, verbose=False):
        """
        Initialize the LinearRHSChecker.
        """
        self._caches = deque(maxlen=max_cache_entries)
        self._ncompute_totals = system._problem_meta['ncompute_totals']
        self._check_zero = check_zero
        self._rtol = rtol
        self._atol = atol

        if collect_stats:
            self._stats = {
                'eqhits': 0, 'neghits': 0, 'parhits': 0, 'zerohits': 0, 'misses': 0, 'resets': 0
            }
            prob_name = system._problem_meta['name']
            # print out cache stats at the end of the run.  Only register the exit hook when
            # the registry is empty so that it is not registered once per checker.
            if not _cache_stats:
                atexit.register(_print_stats)
            if prob_name not in _cache_stats:
                _cache_stats[prob_name] = {}
            _cache_stats[prob_name][system.pathname] = self._stats
        else:
            self._stats = None

        self._verbose = verbose
        self._solver_msginfo = system.linear_solver.msginfo

    @staticmethod
    def check_options(system, options):
        """
        Check the options dictionary for the presence of LinearRHSChecker options.

        Parameters
        ----------
        system : System
            The system that owns the solver that owns this LinearRHSChecker.
        options : dict
            The options dictionary.

        Raises
        ------
        ValueError
            If any key in options is not a recognized 'rhs_checking' option.
        """
        invalid = set(options).difference(LinearRHSChecker.options)
        if not invalid:
            return

        if len(invalid) == 1:
            bad = f" '{invalid.pop()}'"
        else:
            bad = f"s {sorted(invalid)}"
        raise ValueError(f"{system.linear_solver.msginfo}: unrecognized 'rhs_checking' "
                         f"option{bad}. Valid options are {LinearRHSChecker.options}.")

    @staticmethod
    def create(system, opts):
        """
        Conditionally create a LinearRHSChecker instance.

        Parameters
        ----------
        system : System
            The system that owns the solver that owns this LinearRHSChecker.
        opts : dict or bool
            Options for the LinearRHSChecker. If True, the LinearRHSChecker will be created
            with default options. If a dict, the values will override the defaults.  A dict
            passed in is never modified.

        Returns
        -------
        LinearRHSChecker or None
            A LinearRHSChecker instance if it was created, None otherwise.
        """
        redundant_adj = system.pathname in system._relevance.get_redundant_adjoint_systems()

        if isinstance(opts, dict):
            LinearRHSChecker.check_options(system, opts)
            # Always work on a private copy: 'auto' is a valid option key but is not an
            # __init__ argument, so it must be stripped whether it is True or False, and
            # 'max_cache_entries' may be overridden below without touching the caller's dict.
            opts = opts.copy()
            if opts.pop('auto', False):
                if redundant_adj:
                    print(f"Using automated rhs checking for '{system.linear_solver.msginfo}' "
                          "because it has redundant adjoint solves and 'auto' was set in the "
                          "'rhs_checking' options.")
                else:
                    return None
        elif not opts:
            if redundant_adj:
                print(f"\n'rhs_checking' is disabled for '{system.linear_solver.msginfo}'"
                      " but that solver has redundant adjoint solves. If it is "
                      "expensive to compute derivatives for this solver, turning on "
                      "'rhs_checking' may improve performance.\n")
            return None
        else:
            opts = dict(max_cache_entries=3, check_zero=False, rtol=3e-16, atol=3e-16,
                        collect_stats=False, verbose=False)

        if redundant_adj:
            return LinearRHSChecker(system, **opts)

        if opts.get('max_cache_entries', 3) > 0:
            issue_warning(f"{system.linear_solver.msginfo}: 'rhs_checking' is active "
                          "but no redundant adjoint dependencies were found, so caching"
                          " has been disabled.", category=SolverWarning)

        if opts.get('check_zero', False):
            # zero checking is still useful without redundant solves, so keep a checker
            # around but with caching turned off.
            opts['max_cache_entries'] = 0
            return LinearRHSChecker(system, **opts)

        return None

    def clear(self):
        """
        Clear the cache.
        """
        self._caches.clear()

    def add_solution(self, rhs, solution, copy):
        """
        Add a solution to the cache.

        Nothing is stored if the cache was created with a maximum size of zero.

        Parameters
        ----------
        rhs : ndarray
            The RHS vector.
        solution : ndarray
            The solution vector.
        copy : bool
            If True, make a copy of the RHS and solution vectors before storing them.
        """
        if self._caches.maxlen > 0:
            if copy:
                rhs = rhs.copy()
                solution = solution.copy()
            self._caches.append((rhs, solution))

    def get_solution(self, rhs_arr, system):
        """
        Return a cached solution if the RHS vector matches a cached vector.

        Also indicates if the RHS vector is zero.

        Parameters
        ----------
        rhs_arr : ndarray
            The RHS vector.
        system : System
            The system that owns the solver that owns this LinearRHSChecker.

        Returns
        -------
        ndarray or None
            The cached solution if the RHS vector matches a cached vector, or None if no match
            is found.
        bool
            True if the rhs array is zero.
        """
        if system.under_complex_step:
            return None, False

        if self._check_zero:
            is_zero = allzero(rhs_arr)
            if system.comm.size > 1:
                # check if the whole distributed array is zero
                is_zero = system.comm.allreduce(int(is_zero)) == system.comm.size
            if is_zero:
                if self._stats is not None:
                    self._stats['zerohits'] += 1
                if self._verbose:
                    print(f"{self._solver_msginfo}: Skipping linear solve. RHS is zero.")
                return None, True

        if self._caches.maxlen == 0:
            return None, False

        # if there is no intersection between the current seed vars and the responses that cause
        # redundant adjoint solves, then we don't need to check the cache.
        seed_vars = system._problem_meta['seed_vars']
        try:
            redundant = system._relevance.get_redundant_adjoint_systems()[system.pathname]
        except KeyError:
            return None, False

        if seed_vars is None or not redundant.intersection(seed_vars):
            return None, False

        ncompute_totals = system._problem_meta['ncompute_totals']
        if self._ncompute_totals != ncompute_totals:
            # reset the cache if we've run compute_totals since the last time we used the cache
            self.clear()
            self._ncompute_totals = ncompute_totals
            if self._stats is not None:
                self._stats['resets'] += 1

        sol_array = None
        # Which counter to bump and what to report if the match is confirmed.  These are applied
        # only after the cross-rank check below, so that a call is never counted as both a hit
        # and a miss and "Skipping linear solve" is never printed for a solve that still runs.
        hit_key = None
        hit_msg = None

        # search from the most recently added entry to the oldest
        for rhs_cache, sol_cache in reversed(self._caches):
            # Check if the RHS vector is the same as a cached vector. This part is not necessary,
            # but is less expensive than checking if two vectors are parallel.
            if allclose(rhs_arr, rhs_cache, rtol=self._rtol, atol=self._atol):
                sol_array = sol_cache
                hit_key = 'eqhits'
                hit_msg = "RHS matches previous solution."
                break

            # Check if the RHS vector is equal to -1 * cached vector.
            if allclose(rhs_arr, -rhs_cache, rtol=self._rtol, atol=self._atol):
                sol_array = -sol_cache
                hit_key = 'neghits'
                hit_msg = "RHS matches negative of previous solution."
                break

            # Check if the RHS vector and a cached vector are parallel, i.e. |a.b| == |a||b|.
            # A zero-norm cached RHS cannot be rescaled, so it is skipped.
            # NOTE(review): this test uses only the local part of the vectors; with a
            # distributed RHS the scaler could differ between ranks - confirm with callers.
            dot_product = np.dot(rhs_arr, rhs_cache)
            rhs_norm = np.linalg.norm(rhs_arr)
            rhs_cache_norm = np.linalg.norm(rhs_cache)
            if rhs_cache_norm > 0.0 and isclose(abs(dot_product), rhs_norm * rhs_cache_norm,
                                                rel_tol=self._rtol, abs_tol=self._atol):
                scaler = dot_product / rhs_cache_norm**2
                sol_array = sol_cache * scaler
                hit_key = 'parhits'
                hit_msg = f"RHS is parallel to previous solution. (scaler={scaler})"
                break

        matched_cache = sol_array is not None

        if system.comm.size > 1:
            # only match if the entire distributed array matches the cache
            matched_cache = system.comm.allreduce(int(matched_cache)) == system.comm.size

        if not matched_cache:
            if self._stats is not None:
                self._stats['misses'] += 1
            return None, False

        if self._stats is not None:
            self._stats[hit_key] += 1
        if self._verbose:
            print(f"{self._solver_msginfo}: Skipping linear solve. {hit_msg}")

        return sol_array, False