"""
Utilities for the use of jax in combination with OpenMDAO.
"""
import sys
import os
import ast
import textwrap
import inspect
import weakref
from itertools import chain
from collections import defaultdict
import importlib
import numpy as np
from openmdao.visualization.tables.table_builder import generate_table
from openmdao.utils.code_utils import _get_long_name, remove_src_blocks, replace_src_block, \
get_partials_deps
from openmdao.utils.file_utils import get_module_path, _load_and_exec
[docs]
def jit_stub(f, *args, **kwargs):
    """
    Stand in for ``jax.jit`` when jax cannot be imported.

    The callable is handed back untouched.  Extra positional and keyword arguments are
    accepted only so that call sites written for the real decorator keep working; they
    are ignored.

    Parameters
    ----------
    f : Callable
        The function or method to be wrapped.
    *args : list
        Positional arguments (ignored).
    **kwargs : dict
        Keyword arguments (ignored).

    Returns
    -------
    Callable
        The very same object that was passed in as `f`.
    """
    undecorated = f
    return undecorated
# jax is an optional dependency.  When it is importable we expose `jax`, `jnp`, `jit` and
# `tree_util`; otherwise we fall back to numpy and a do-nothing jit so that this module can
# still be imported.  The rest of the module uses `jax is None` as the "jax unavailable" flag.
try:
    import jax
    jax.config.update("jax_enable_x64", True)  # jax by default uses 32 bit floats
    import jax.numpy as jnp
    from jax import jit, tree_util
except ImportError:
    # NOTE: `tree_util` is intentionally left undefined on this path.
    jax = None
    jnp = np
    jit = jit_stub
[docs]
def register_jax_component(comp_class):
    """
    Class decorator that registers `comp_class` with jax as a pytree node.

    Registration lets jax jit-compile methods of the class that reference attributes of the
    instance itself, such as `self.options`.  The decorator is not needed when no jitted
    method of the class references `self`.

    Parameters
    ----------
    comp_class : class
        The decorated class.

    Returns
    -------
    object
        The same class given as an argument.

    Raises
    ------
    NotImplementedError
        If this class does not define the `_tree_flatten` and `_tree_unflatten` methods.
    RuntimeError
        If jax is not available.
    """
    if jax is None:
        raise RuntimeError("jax is not available. "
                           "Try 'pip install openmdao[jax]' with Python>=3.8.")

    # Check the required hooks in a fixed order (flatten first).  The second entry of each
    # pair reproduces the slightly different wording of the two historical messages.
    for hook, noun in (('_tree_flatten', ''), ('_tree_unflatten', 'class ')):
        if not hasattr(comp_class, hook):
            raise NotImplementedError(f'class {comp_class} does not implement method {hook}.'
                                      f'\nCannot register {noun}{comp_class} as a jax '
                                      f'jit-compatible component.')

    flatten = comp_class._tree_flatten
    unflatten = comp_class._tree_unflatten
    jax.tree_util.register_pytree_node(comp_class, flatten, unflatten)

    return comp_class
[docs]
def dump_jaxpr(closed_jaxpr):
    """
    Print out the contents of a Jaxpr.

    Prints the input/output/constant variables, the abstract input and output values, one
    line per equation, and finally the jaxpr itself.  The dtype of the first abstract input
    is appended to the `in_avals` line when there is at least one input.

    Parameters
    ----------
    closed_jaxpr : jax.core.ClosedJaxpr
        The Jaxpr to be examined.
    """
    jaxpr = closed_jaxpr.jaxpr
    in_avals = closed_jaxpr.in_avals

    print("invars:", jaxpr.invars)
    if len(in_avals) > 0:
        print("in_avals", in_avals, in_avals[0].dtype)
    else:
        # a jaxpr traced from a zero-argument function has no inputs, so there is no
        # first entry whose dtype could be shown
        print("in_avals", in_avals)
    print("outvars:", jaxpr.outvars)
    print("out_avals:", closed_jaxpr.out_avals)
    print("constvars:", jaxpr.constvars)
    for eqn in jaxpr.eqns:
        print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
    print()
    print("jaxpr:", jaxpr)
[docs]
class CompJaxifyBase(ast.NodeTransformer):
    """
    An ast.NodeTransformer that transforms a function definition to jax compatible form.

    So original func becomes compute_primal(self, arg1, arg2, ...).

    If the component has discrete inputs, they will be passed individually into compute_primal
    *after* the continuous inputs (the ordering is supplied by the subclass'
    `_get_compute_primal_args`).  If the component has discrete outputs, they will be assigned
    to local variables of the same name within the function and set back into the discrete
    outputs dict just prior to the return from the function.

    Subclasses must provide `_get_compute_primal_args`, `_get_compute_primal_returns` and
    `_get_del_methods`; this base class calls them but does not define them.

    Parameters
    ----------
    comp : Component
        The Component whose function is to be transformed.  This NodeTransformer may only
        be used after the Component has had its _setup_var_data method called, because that
        determines the ordering of the inputs and outputs.
    funcname : str
        The name of the function to be transformed.
    verbose : bool
        If True, the transformed function will be printed to stdout.

    Attributes
    ----------
    _comp : weakref.ref
        A weak reference to the Component whose function is being transformed.
    _funcname : str
        The name of the function being transformed.
    compute_primal : function
        The compute_primal function created from the original function.
    _orig_args : list
        The original argument names of the original function.
    _new_ast : ast node
        The new ast node created from the original function.
    get_self_statics : function or None
        A function that returns the static args for the Component as a single tuple, or None
        if the original function references no attributes of `self`.
    """

    # these ops require static objects so their args should not be traced.  Traced array ops
    # should use jnp and static ones should use np.
    _static_ops = {'reshape'}
    _np_names = {'np', 'numpy'}

    def __init__(self, comp, funcname, verbose=False):  # noqa
        self._comp = weakref.ref(comp)
        self._funcname = funcname
        func = getattr(comp, funcname)

        # the transformed source refers to 'jnp', so make sure that name resolves in the
        # namespace the new function will be executed in
        if 'jnp' not in func.__globals__:
            func.__globals__['jnp'] = jnp
        namespace = func.__globals__.copy()

        static_attrs, static_dcts = get_self_static_attrs(func)
        if static_attrs or static_dcts:
            self.get_self_statics = self._get_self_statics_func(static_attrs, static_dcts)
        else:
            self.get_self_statics = None

        self._orig_args = list(inspect.signature(func).parameters)

        node = self.visit(ast.parse(textwrap.dedent(inspect.getsource(func)), mode='exec'))
        self._new_ast = ast.fix_missing_locations(node)

        code = compile(self._new_ast, '<ast>', 'exec')
        exec(code, namespace)  # nosec
        self.compute_primal = namespace['compute_primal']

        if verbose:
            print(f"\n{comp.pathname}:\n{self.get_compute_primal_src()}\n")

    def get_compute_primal_src(self):
        """
        Return the source code of the transformed function.

        Returns
        -------
        str
            The source code of the transformed function.
        """
        return ast.unparse(self._new_ast)

    def get_class_src(self):
        """
        Return the source code of the class containing the transformed function.

        The original function is replaced by the generated compute_primal source and the
        methods named by `_get_del_methods` are removed.

        Returns
        -------
        str
            The source code of the class containing the transformed function.

        Raises
        ------
        RuntimeError
            If the source of the component's class can't be obtained.
        """
        try:
            class_src = textwrap.dedent(inspect.getsource(self._comp().__class__))
        except Exception as err:
            raise RuntimeError("Couldn't obtain class source for "
                               f"{self._comp().__class__}.") from err

        compute_primal_src = textwrap.indent(textwrap.dedent(self.get_compute_primal_src()),
                                             ' ' * 4)

        class_src = replace_src_block(class_src, self._funcname, compute_primal_src,
                                      block_start_tok='def')
        class_src = remove_src_blocks(class_src, self._get_del_methods(), block_start_tok='def')

        return class_src.rstrip()

    def _get_self_statics_func(self, static_attrs, static_dcts):
        """
        Build a `get_self_statics(self)` function returning the static values as one tuple.

        Parameters
        ----------
        static_attrs : list of str
            Names of attributes accessed on `self`.
        static_dcts : list of (str, list of str)
            For each dict-like attribute of `self`, the string keys that are accessed.

        Returns
        -------
        function
            Function taking the component instance and returning a tuple of static values.
        """
        fsrc = ['def get_self_statics(self):']
        tupargs = [f"self.{attr}" for attr in static_attrs]
        for name, entries in static_dcts:
            for entry in entries:
                tupargs.append(f"self.{name}['{entry}']")

        # The decision must be based on the total number of tuple members, not on the size
        # of the last dict's entry list, otherwise a lone static attribute yields '(self.x)'
        # which is not a tuple.
        if len(tupargs) == 1:
            tupargs.append('')  # so we'll get a trailing comma for a 1 item tuple

        fsrc.append(f'    return ({", ".join(tupargs)})')
        fsrc = '\n'.join(fsrc)

        namespace = getattr(self._comp(), self._funcname).__globals__.copy()
        exec(fsrc, namespace)  # nosec
        return namespace['get_self_statics']

    def _get_pre_body(self):
        """
        Return statements to insert at the top of compute_primal.

        Returns
        -------
        list
            Empty if the component has no discrete outputs, else a single assignment that
            unpacks `self._discrete_outputs.values()` into locals named after the outputs.
        """
        if not self._comp()._discrete_outputs:
            return []

        # add a statement to pull individual values out of the discrete outputs
        elts = [ast.Name(id=name, ctx=ast.Store()) for name in self._comp()._discrete_outputs]
        return [
            ast.Assign(targets=[ast.Tuple(elts=elts, ctx=ast.Store())],
                       value=ast.Call(
                           func=ast.Attribute(value=ast.Attribute(value=ast.Name(id='self',
                                                                                 ctx=ast.Load()),
                                                                  attr='_discrete_outputs',
                                                                  ctx=ast.Load()),
                                              attr='values', ctx=ast.Load()),
                           args=[], keywords=[]))]

    def _get_post_body(self):
        """
        Return statements to insert just before the return of compute_primal.

        Returns
        -------
        list
            Empty if the component has no discrete outputs, else a single call to
            `self._discrete_outputs.set_vals` with a tuple of the discrete output locals.
        """
        if not self._comp()._discrete_outputs:
            return []

        # add a statement to set the values of self._discrete outputs
        elts = [ast.Name(id=name, ctx=ast.Load()) for name in self._comp()._discrete_outputs]
        args = [ast.Tuple(elts=elts, ctx=ast.Load())]
        return [ast.Expr(value=ast.Call(func=ast.Attribute(
            value=ast.Attribute(value=ast.Name(id='self', ctx=ast.Load()),
                                attr='_discrete_outputs', ctx=ast.Load()),
            attr='set_vals', ctx=ast.Load()), args=args, keywords=[]))]

    def _make_return(self):
        """
        Build the return statement of compute_primal.

        Returns
        -------
        ast.Return
            A return of a tuple of the names given by `_get_compute_primal_returns`.
        """
        val = ast.Tuple([ast.Name(id=n, ctx=ast.Load())
                         for n in self._get_compute_primal_returns()], ctx=ast.Load())
        return ast.Return(val)

    def _get_new_args(self):
        """
        Build the argument list of compute_primal.

        Returns
        -------
        ast.arguments
            `self` followed by the names given by `_get_compute_primal_args`.
        """
        new_args = [ast.arg('self', annotation=None)]
        for arg_name in self._get_compute_primal_args():
            new_args.append(ast.arg(arg=arg_name, annotation=None))
        return ast.arguments(args=new_args, posonlyargs=[], vararg=None, kwonlyargs=[],
                             kw_defaults=[], kwarg=None, defaults=[])

    def visit_FunctionDef(self, node):
        """
        Transform the compute function definition.

        The function will be transformed from compute(self, inputs, outputs, ...) or
        apply_nonlinear(self, ...) to compute_primal(self, arg1, arg2, ...) where args are the
        input values in the order they are stored in inputs.  All subscript accesses into the
        input args will be replaced with the name of the key being accessed, e.g.,
        inputs['foo'] becomes foo.  The new function will return a tuple of the output values
        in the order they are stored in outputs.  If compute has the additional args
        discrete_inputs and discrete_outputs, they will be handled similarly.

        Parameters
        ----------
        node : ast.FunctionDef
            The FunctionDef node being visited.

        Returns
        -------
        ast.FunctionDef
            The transformed node.
        """
        newbody = self._get_pre_body()
        for statement in node.body:
            newnode = self.visit(statement)
            if newnode is not None:  # None means the statement was removed (see visit_Assign)
                newbody.append(newnode)

        newbody.extend(self._get_post_body())
        # add a return statement for the outputs
        newbody.append(self._make_return())

        newargs = self._get_new_args()
        return ast.FunctionDef('compute_primal', newargs, newbody, node.decorator_list,
                               node.returns, node.type_comment)

    def visit_Subscript(self, node):
        """
        Translate a Subscript node into a Name node with the name of the subscript variable.

        Parameters
        ----------
        node : ast.Subscript
            The Subscript node being visited.

        Returns
        -------
        ast.Any
            The transformed node.
        """
        # if we encounter a subscript of any of the input args, then replace arg['name'] or
        # arg["name"] with name.
        # NOTE: this will only work if the subscript is a string constant.  If the subscript is
        # a variable or some other expression, then we don't modify it.
        if (isinstance(node.value, ast.Name) and node.value.id in self._orig_args and
                isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str)):
            return ast.copy_location(ast.Name(id=_fixname(node.slice.value), ctx=node.ctx), node)

        return self.generic_visit(node)

    def visit_Attribute(self, node):
        """
        Translate any non-static use of 'numpy' or 'np' to 'jnp'.

        Parameters
        ----------
        node : ast.Attribute
            The Attribute node being visited.

        Returns
        -------
        ast.Any
            The transformed node.
        """
        if isinstance(node.value, ast.Name) and node.value.id in self._np_names:
            if node.attr not in self._static_ops:
                return ast.copy_location(ast.Attribute(value=ast.Name(id='jnp', ctx=ast.Load()),
                                                       attr=node.attr, ctx=node.ctx), node)
        return self.generic_visit(node)

    def visit_Assign(self, node):
        """
        Translate an Assign node into an Assign node with the subscript replaced with the name.

        A single-target assignment that collapses to `x = x` after conversion is dropped.

        Parameters
        ----------
        node : ast.Assign
            The Assign node being visited.

        Returns
        -------
        ast.Any
            The transformed node, or None if the assignment is to be removed.
        """
        if len(node.targets) == 1:
            nodeval = self.visit(node.value)
            tgt = node.targets[0]
            if isinstance(tgt, ast.Name) and isinstance(nodeval, ast.Name):
                if tgt.id == nodeval.id:
                    return None  # get rid of any 'x = x' assignments after conversion

        return self.generic_visit(node)
[docs]
class ExplicitCompJaxify(CompJaxifyBase):
    """
    An ast.NodeTransformer that converts an ExplicitComponent's `compute` to jax form.

    compute(self, inputs, outputs) becomes compute_primal(self, arg1, arg2, ...), where the
    args are the input values in the order they are stored in inputs.  The new function
    returns a tuple of the output values in the order they are stored in outputs.

    Discrete inputs, if any, are passed individually into compute_primal *after* the
    continuous inputs.  Discrete outputs, if any, are assigned to local variables of the
    same name inside the function and written back into the discrete outputs dict just
    before the function returns.

    Parameters
    ----------
    comp : ExplicitComponent
        The Component whose compute function is to be transformed.  This NodeTransformer may
        only be used after the Component has had its _setup_var_data method called, because
        that determines the ordering of the inputs and outputs.
    verbose : bool
        If True, the transformed function will be printed to stdout.
    """

    def __init__(self, comp, verbose=False):  # noqa
        super().__init__(comp, funcname='compute', verbose=verbose)

    def _get_compute_primal_args(self):
        # Argument order has to mirror the inputs vector exactly; discrete inputs follow.
        owner = self._comp()
        continuous = owner._var_rel_names['input']
        return chain(continuous, owner._discrete_inputs)

    def _get_compute_primal_returns(self):
        # Return order has to mirror the outputs vector exactly; discrete outputs follow.
        owner = self._comp()
        continuous = owner._var_rel_names['output']
        return chain(continuous, owner._discrete_outputs)

    def _get_del_methods(self):
        # methods stripped from the regenerated class source
        doomed = ['compute', 'compute_partials', 'compute_jacvec_product']
        return doomed
[docs]
class ImplicitCompJaxify(CompJaxifyBase):
    """
    A NodeTransformer that converts an ImplicitComponent's `apply_nonlinear` to jax form.

    apply_nonlinear(self, inputs, outputs, residuals) becomes
    compute_primal(self, arg1, arg2, ...), where the args are the input and output values in
    the order they are stored in their respective Vectors.  The new function returns a tuple
    of the residual values in the order they are stored in the residuals Vector.

    Discrete inputs, if any, are passed individually into compute_primal *after* the
    continuous inputs and outputs.  Discrete outputs, if any, are assigned to local variables
    of the same name inside the function and written back into the discrete outputs dict just
    before the function returns.

    Parameters
    ----------
    comp : ImplicitComponent
        The Component whose apply_nonlinear function is to be transformed.  This
        NodeTransformer may only be used after the Component has had its _setup_var_data
        method called, because that determines the ordering of the inputs, outputs, and
        residuals.
    verbose : bool
        If True, the transformed function will be printed to stdout.
    """

    def __init__(self, comp, verbose=False):  # noqa
        super().__init__(comp, funcname='apply_nonlinear', verbose=verbose)

    def _get_compute_primal_args(self):
        # Argument order has to mirror the inputs vector followed by the outputs vector;
        # discrete inputs come last.
        owner = self._comp()
        rel_names = owner._var_rel_names
        return chain(rel_names['input'], rel_names['output'], owner._discrete_inputs)

    def _get_compute_primal_returns(self):
        # The returned names are the output names, followed by the discrete outputs.
        owner = self._comp()
        return chain(owner._var_rel_names['output'], owner._discrete_outputs)

    def _get_del_methods(self):
        # methods stripped from the regenerated class source
        doomed = ['apply_nonlinear', 'linearize', 'apply_linear']
        return doomed
[docs]
class SelfAttrFinder(ast.NodeVisitor):
    """
    An ast.NodeVisitor that collects all attribute names that are accessed on `self`.

    The source of `method` is parsed and visited at construction time; the results are
    available afterwards in the attributes listed below.

    Parameters
    ----------
    method : method
        The method to be analyzed.

    Attributes
    ----------
    _attrs : set
        The set of attribute names accessed on `self`.
    _funcs : set
        The set of method names accessed on `self`.
    _dcts : defaultdict
        Maps each attribute of `self` that is subscripted with a string constant to the set
        of string keys used.
    """

    # TODO: need to support intermediate variables, e.g., foo = self.options, x = foo['blah']
    # TODO: need to support self.options[var], where var is an attr, not a string.
    # TODO: even if we can't handle the above, at least detect and flag them and warn that
    #       auto-converter can't handle them.

    def __init__(self, method):  # noqa
        self._attrs = set()
        self._funcs = set()
        self._dcts = defaultdict(set)
        source = textwrap.dedent(inspect.getsource(method))
        self.visit(ast.parse(source, mode='exec'))

    def visit_Attribute(self, node):
        """
        Visit an Attribute node.

        If the attribute is accessed on `self`, add the attribute name to the set of
        attributes.

        Parameters
        ----------
        node : ast.Attribute
            The Attribute node being visited.
        """
        dotted = _get_long_name(node)
        if dotted is not None and dotted.startswith('self.'):
            # record everything after the leading 'self.'
            self._attrs.add(dotted.partition('.')[2])

    def visit_Subscript(self, node):
        """
        Visit a Subscript node.

        If the subscripted object is an attribute of `self`, record it: string-constant keys
        go into `_dcts`, anything else marks the attribute itself in `_attrs`.  The subscript
        expression is then visited as well.

        Parameters
        ----------
        node : ast.Subscript
            The Subscript node being visited.
        """
        dotted = _get_long_name(node.value)
        if dotted is None:
            return

        if dotted.startswith('self.'):
            attr = dotted.partition('.')[2]
            key = node.slice
            if isinstance(key, ast.Constant) and isinstance(key.value, str):
                self._dcts[attr].add(key.value)
            else:
                self._attrs.add(attr)

        self.visit(node.slice)

    def visit_Call(self, node):
        """
        Visit a Call node.

        If the function is accessed on `self`, add the function name to the set of functions.
        For a call on a nested attribute (self.a.b()), the owning attribute path is added to
        the set of attributes instead.  Positional arguments are then visited.

        Parameters
        ----------
        node : ast.Call
            The Call node being visited.
        """
        dotted = _get_long_name(node.func)
        if dotted is not None and dotted.startswith('self.'):
            pieces = dotted.split('.')
            if len(pieces) == 2:
                self._funcs.add(pieces[1])
            else:
                self._attrs.add('.'.join(pieces[1:-1]))

        for positional in node.args:
            self.visit(positional)
[docs]
class ReturnChecker(ast.NodeVisitor):
    """
    An ast.NodeVisitor that determines if a method returns a tuple or not.

    The source of `method` is parsed and visited at construction time.

    Parameters
    ----------
    method : method
        The method to be analyzed.

    Attributes
    ----------
    _returns : list
        One boolean per return statement found, True when that statement returns a tuple
        expression.
    """

    def __init__(self, method):  # noqa
        self._returns = []
        source = textwrap.dedent(inspect.getsource(method))
        self.visit(ast.parse(source, mode='exec'))

    def returns_tuple(self):
        """
        Return whether or not the method returns a tuple.

        Returns
        -------
        bool
            True if the method returns a tuple, False otherwise (including when the method
            has no return statement).

        Raises
        ------
        RuntimeError
            If some return statements return a tuple and others do not.
        """
        if not self._returns:
            return False

        first, *others = self._returns
        if any(flag != first for flag in others):
            raise RuntimeError("ReturnChecker can't handle a method with multiple return "
                               "statements that return different types.")
        return first

    def visit_Return(self, node):
        """
        Visit a Return node.

        Parameters
        ----------
        node : ASTnode
            The return node being visited.
        """
        is_tuple = isinstance(node.value, ast.Tuple)
        self._returns.append(is_tuple)
[docs]
def get_self_static_attrs(method):
    """
    Get the attribute names accessed on `self` in the given method.

    Parameters
    ----------
    method : method
        The method to be analyzed.

    Returns
    -------
    list
        Sorted names of the attributes accessed on `self`.
    list
        Tuples of (attribute name, sorted list of string keys) for attributes of `self` that
        are subscripted with a string, ordered by attribute name.
    """
    finder = SelfAttrFinder(method)
    dcts = finder._dcts
    attr_names = sorted(finder._attrs)
    # dict keys are unique, so iterating the sorted keys gives the same order as sorting
    # the items by key
    dct_entries = [(dname, sorted(dcts[dname])) for dname in sorted(dcts)]
    return attr_names, dct_entries
# characters that may appear in a variable name but are not legal in a Python identifier
_invalid = frozenset((':', '(', ')', '[', ']', '{', '}', ' ', '-',
                      '+', '*', '/', '^', '%', '!', '<', '>', '='))


def _fixname(name):
    """
    Convert (if necessary) the given name into a valid Python variable name.

    Every character found in `_invalid` is replaced by an underscore.  A name containing
    none of those characters is returned unchanged.

    Parameters
    ----------
    name : str
        The name to be fixed.

    Returns
    -------
    str
        The fixed name.
    """
    if _invalid.isdisjoint(name):
        return name
    return ''.join('_' if ch in _invalid else ch for ch in name)
[docs]
def benchmark_component(comp_class, methods=(None, 'cs', 'jax'), initial_vals=None, repeats=2,
                        mode='auto', table_format='simple_grid', **kwargs):
    """
    Benchmark the performance of a Component using different methods for computing derivatives.

    For each method a fresh Problem containing a single instance of `comp_class` is set up and
    run, then `compute` and the model's `_linearize` are each timed over `repeats` calls.

    Parameters
    ----------
    comp_class : class
        The class of the Component to be benchmarked.
    methods : tuple of str
        The methods to be benchmarked. Options are 'cs', 'jax', and None.
    initial_vals : dict or None
        Initial values for the input variables.
    repeats : int
        The number of times to run compute/compute_partials.
    mode : str
        The preferred derivative direction for the Problem.
    table_format : str or None
        If not None, the format of the table to be displayed.
    **kwargs : dict
        Additional keyword arguments to be passed to the Component.

    Returns
    -------
    list
        One row per (method, function) pair, each of the form
        [method, function name, iterations, elapsed time in seconds, memory delta].
    """
    import time
    from openmdao.core.problem import Problem
    from openmdao.devtools.memory import mem_usage

    verbose = table_format is not None
    results = []
    for method in methods:
        mem_start = mem_usage()
        p = Problem()
        comp = p.model.add_subsystem('comp', comp_class(**kwargs))
        comp.options['derivs_method'] = method
        if method in ('cs', 'fd'):
            comp._has_approx = True
            comp._get_approx_scheme(method)

        if initial_vals:
            for name, val in initial_vals.items():
                p.model.set_val('comp.' + name, val)

        p.setup(mode=mode, force_alloc_complex='cs' in methods)
        p.run_model()

        # mem_usage must be *called* here: the value is printed below and later subtracted
        # from another mem_usage() reading.
        model_mem = mem_usage()

        if verbose:
            print(f"\nModel memory usage: {model_mem} MB")
            print(f"\nTiming {repeats} compute calls for {comp_class.__name__} using "
                  f"{method} method.")

        start = time.perf_counter()
        for _ in range(repeats):
            comp.compute(comp._inputs, comp._outputs)
            if verbose:
                print('.', end='', flush=True)
        # record the number of iterations actually performed (not the last loop index,
        # which is off by one and undefined when repeats == 0)
        results.append([method, 'compute', repeats, time.perf_counter() - start, None])
        diff_mem = mem_usage() - mem_start
        results[-1][-1] = diff_mem

        if verbose:
            print(f"\n\nTiming {repeats} compute_partials calls for {comp_class.__name__} using "
                  f"{method} method.")

        start = time.perf_counter()
        for _ in range(repeats):
            p.model._linearize(None)
            if verbose:
                print('.', end='', flush=True)
        results.append([method, 'compute_partials', repeats, time.perf_counter() - start, None])
        diff_mem = mem_usage() - model_mem
        results[-1][-1] = diff_mem

        # drop the Problem before the next method's starting memory reading is taken
        del p

    if verbose:
        print('\n')
        headers = ['Method', 'Function', 'Iterations', 'Time (s)', 'Memory (MB)']
        generate_table(results, tablefmt=table_format, headers=headers).display()

    return results
[docs]
def jax_deriv_shape(derivs):
    """
    Get the shape of the derivatives from a jax derivative calculation.

    Parameters
    ----------
    derivs : tuple
        The tuple of derivatives.  A single array is also accepted.

    Returns
    -------
    list
        The shape of the derivatives.  An array contributes its shape; any other entry is
        treated as a nested collection and contributes a nested list built the same way.
    """
    if isinstance(derivs, jnp.ndarray):
        return [derivs.shape]

    # not an array, so treat it as a (possibly nested) tuple of arrays
    return [item.shape if isinstance(item, jnp.ndarray) else jax_deriv_shape(item)
            for item in derivs]
if jax is None or bool(os.environ.get('JAX_DISABLE_JIT', '')):
    # jax is missing, or JAX_DISABLE_JIT is set to any non-empty string: nothing to register
    def _jax_register_pytree_class(cls):
        """
        Do nothing; stand-in used when registration with jax is not performed.

        Parameters
        ----------
        cls : class
            The class that would have been registered.
        """
        return None
else:
    # classes already handed to jax, so that each one is registered only once
    _registered_classes = set()

    def _jax_register_pytree_class(cls):
        """
        Register a class with jax so that it can be used with jax.jit.

        A class that was already registered through this function is skipped.

        Parameters
        ----------
        cls : class
            The class to be registered.  It must provide `_tree_flatten` and
            `_tree_unflatten`.
        """
        if cls in _registered_classes:
            return

        # register with jax so we can flatten/unflatten self
        tree_util.register_pytree_node(cls, cls._tree_flatten, cls._tree_unflatten)
        _registered_classes.add(cls)
# we define compute_partials here instead of making this the base class version as we
# did with compute, because the existence of a compute_partials method that is not the
# base class method is used to determine if a given component computes its own partials.
[docs]
def compute_partials(inst, inputs, partials, discrete_inputs=None):
    """
    Compute sub-jacobian parts. The model is assumed to be in an unscaled state.

    The full set of derivatives is obtained from the function returned by
    `inst._get_jac_func()` and then copied, one (of, wrt) pair at a time, into `partials`.

    Parameters
    ----------
    inst : Component
        The component instance (presumably an ExplicitComponent, since this mirrors the
        compute_partials API - TODO confirm).
    inputs : Vector
        Unscaled, dimensional input variables read via inputs[key].
    partials : Jacobian
        Sub-jac components written to partials[output_name, input_name].
    discrete_inputs : dict or None
        If not None, dict containing discrete input values.  NOTE(review): this argument is
        not read here; `inst._discrete_inputs` is used instead.
    """
    deriv_vals = inst._get_jac_func()(*inst._get_compute_primal_invals(inputs,
                                                                       inst._discrete_inputs))
    # True when deriv_vals is indexed [of][wrt] rather than just [wrt]
    nested_tup = isinstance(deriv_vals, tuple) and len(deriv_vals) > 0 and \
        isinstance(deriv_vals[0], tuple)

    nof = len(inst._var_rel_names['output'])
    for ofidx, ofname in enumerate(inst._var_rel_names['output']):
        ofmeta = inst._var_rel2meta[ofname]
        for wrtidx, wrtname in enumerate(inst._var_rel_names['input']):
            key = (ofname, wrtname)
            if key not in partials:
                # FIXME: this means that we computed a derivative that we didn't need
                continue

            wrtmeta = inst._var_rel2meta[wrtname]
            dvals = deriv_vals
            # if there's only one 'of' value, we only take the indexed value if the
            # return value of compute_primal is single entry tuple. If a single array or
            # scalar is returned, we don't apply the 'of' index.
            if nof > 1 or nested_tup:
                dvals = dvals[ofidx]

            # flatten the sub-jacobian to 2D: (size of 'of', size of 'wrt')
            dvals = dvals[wrtidx].reshape(ofmeta['size'], wrtmeta['size'])

            sjmeta = partials.get_metadata(key)
            rows = sjmeta['rows']
            if rows is None:
                partials[ofname, wrtname] = dvals
            else:
                # rows/cols are declared for this sub-jacobian, so only those entries are stored
                partials[ofname, wrtname] = dvals[rows, sjmeta['cols']]
[docs]
def compute_jacvec_product(inst, inputs, d_inputs, d_outputs, mode, discrete_inputs=None):
    r"""
    Compute jac-vector product. The model is assumed to be in an unscaled state.

    If mode is:
        'fwd': d_inputs \|-> d_outputs

        'rev': d_outputs \|-> d_inputs

    Forward mode uses `jax.jvp` on `inst.compute_primal`; any other mode uses `jax.vjp`, whose
    resulting function is cached on `inst` and rebuilt only when the input hash changes.

    Parameters
    ----------
    inst : Component
        The component instance (presumably an ExplicitComponent, since this mirrors the
        compute_jacvec_product API - TODO confirm).
    inputs : Vector
        Unscaled, dimensional input variables read via inputs[key].
    d_inputs : Vector
        See inputs; product must be computed only if var_name in d_inputs.
    d_outputs : Vector
        See outputs; product must be computed only if var_name in d_outputs.
    mode : str
        Either 'fwd' or 'rev'.
    discrete_inputs : dict or None
        If not None, dict containing discrete input values.
    """
    if mode == 'fwd':
        dx = tuple(d_inputs.values())
        full_invals = tuple(inst._get_compute_primal_invals(inputs, discrete_inputs))
        # the first len(dx) values are differentiated; the remaining ones are passed
        # through to compute_primal unchanged via the closure
        x = full_invals[:len(dx)]
        other = full_invals[len(dx):]
        _, deriv_vals = jax.jvp(lambda *args: inst.compute_primal(*args, *other),
                                primals=x, tangents=dx)
        d_outputs.set_vals(deriv_vals)
    else:
        # key identifying the point at which the cached vjp function was built
        inhash = ((inputs.get_hash(),) + tuple(inst._discrete_inputs.values()) +
                  inst.get_self_statics())
        if inhash != inst._vjp_hash:
            dx = tuple(d_inputs.values())
            full_invals = tuple(inst._get_compute_primal_invals(inputs, discrete_inputs))
            x = full_invals[:len(dx)]
            other = full_invals[len(dx):]
            # recompute vjp function if inputs have changed
            _, inst._vjp_fun = jax.vjp(lambda *args: inst.compute_primal(*args, *other), *x)
            inst._vjp_hash = inhash

        if inst._compute_primal_returns_tuple:
            deriv_vals = inst._vjp_fun(tuple(d_outputs.values()) +
                                       tuple(inst._discrete_outputs.values()))
        else:
            # compute_primal returns a bare value, so pass the single cotangent unwrapped
            deriv_vals = inst._vjp_fun(tuple(d_outputs.values())[0])

        d_inputs.set_vals(deriv_vals)
# we define linearize here instead of making this the base class version as we
# did with apply_nonlinear, because the existence of a linearize method that is not the
# base class method is used to determine if a given component computes its own partials.
[docs]
def linearize(inst, inputs, outputs, partials, discrete_inputs=None, discrete_outputs=None):
    """
    Compute sub-jacobian parts and any applicable matrix factorizations.

    The model is assumed to be in an unscaled state.

    The full set of derivatives is obtained from the function returned by
    `inst._get_jac_func()` and then copied, one (of, wrt) pair at a time, into `partials`.
    The 'wrt' variables are the inputs followed by the outputs.

    Parameters
    ----------
    inst : ImplicitComponent
        The component instance.
    inputs : Vector
        Unscaled, dimensional input variables read via inputs[key].
    outputs : Vector
        Unscaled, dimensional output variables read via outputs[key].
    partials : partial Jacobian
        Sub-jac components written to jacobian[output_name, input_name].
    discrete_inputs : dict or None
        If not None, dict containing discrete input values.
    discrete_outputs : dict or None
        If not None, dict containing discrete output values.  Not read in this function.
    """
    deriv_vals = inst._get_jac_func()(*inst._get_compute_primal_invals(inputs, outputs,
                                                                       discrete_inputs))
    # True when deriv_vals is indexed [of][wrt] rather than just [wrt]
    nested_tup = isinstance(deriv_vals, tuple) and len(deriv_vals) > 0 and \
        isinstance(deriv_vals[0], tuple)

    nof = len(inst._var_rel_names['output'])
    # the 'of' index starts at the number of discrete outputs (it is incremented before
    # first use) - presumably the leading entries of deriv_vals belong to the discrete
    # outputs; TODO confirm
    ofidx = len(inst._discrete_outputs) - 1
    for ofname in inst._var_rel_names['output']:
        ofidx += 1
        ofmeta = inst._var_rel2meta[ofname]
        for wrtidx, wrtname in enumerate(chain(inst._var_rel_names['input'],
                                               inst._var_rel_names['output'])):
            key = (ofname, wrtname)
            if key not in partials:
                # FIXME: this means that we computed a derivative that we didn't need
                continue

            wrtmeta = inst._var_rel2meta[wrtname]
            dvals = deriv_vals
            # if there's only one 'of' value, we only take the indexed value if the
            # return value of compute_primal is single entry tuple. If a single array or
            # scalar is returned, we don't apply the 'of' index.
            if nof > 1 or nested_tup:
                dvals = dvals[ofidx]

            # print(ofidx, ofname, ofmeta['shape'], wrtidx, wrtname, wrtmeta['shape'],
            #       'subjac_shape', dvals[wrtidx].shape)
            # flatten the sub-jacobian to 2D: (size of 'of', size of 'wrt')
            dvals = dvals[wrtidx].reshape(ofmeta['size'], wrtmeta['size'])

            sjmeta = partials.get_metadata(key)
            rows = sjmeta['rows']
            if rows is None:
                partials[ofname, wrtname] = dvals
            else:
                # rows/cols are declared for this sub-jacobian, so only those entries are stored
                partials[ofname, wrtname] = dvals[rows, sjmeta['cols']]
[docs]
def apply_linear(inst, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
    r"""
    Compute jac-vector product. The model is assumed to be in an unscaled state.

    If mode is:
        'fwd': (d_inputs, d_outputs) \|-> d_residuals

        'rev': d_residuals \|-> (d_inputs, d_outputs)

    Forward mode uses `jax.jvp` on `inst.compute_primal`; any other mode uses `jax.vjp`, whose
    resulting function is cached on `inst` and rebuilt only when the input/output hash changes.

    Parameters
    ----------
    inst : ImplicitComponent
        The component instance.
    inputs : Vector
        Unscaled, dimensional input variables read via inputs[key].
    outputs : Vector
        Unscaled, dimensional output variables read via outputs[key].
    d_inputs : Vector
        See inputs; product must be computed only if var_name in d_inputs.
    d_outputs : Vector
        See outputs; product must be computed only if var_name in d_outputs.
    d_residuals : Vector
        See outputs.
    mode : str
        Either 'fwd' or 'rev'.
    """
    if mode == 'fwd':
        dx = tuple(chain(d_inputs.values(), d_outputs.values()))
        full_invals = tuple(inst._get_compute_primal_invals(inputs, outputs,
                                                            inst._discrete_inputs))
        # the first len(dx) values are differentiated; the remaining ones are passed
        # through to compute_primal unchanged via the closure
        x = full_invals[:len(dx)]
        other = full_invals[len(dx):]
        _, deriv_vals = jax.jvp(lambda *args: inst.compute_primal(*args, *other),
                                primals=x, tangents=dx)
        if isinstance(deriv_vals, tuple):
            d_residuals.set_vals(deriv_vals)
        else:
            # a single array came back, so write it straight into the flat residual array
            d_residuals.asarray()[:] = deriv_vals.flatten()
    else:
        # key identifying the point at which the cached vjp function was built
        inhash = (inputs.get_hash(), outputs.get_hash()) + tuple(inst._discrete_inputs.values())
        if inhash != inst._vjp_hash:
            # recompute vjp function only if inputs or outputs have changed
            dx = tuple(chain(d_inputs.values(), d_outputs.values()))
            full_invals = tuple(inst._get_compute_primal_invals(inputs, outputs,
                                                                inst._discrete_inputs))
            x = full_invals[:len(dx)]
            other = full_invals[len(dx):]
            _, inst._vjp_fun = jax.vjp(lambda *args: inst.compute_primal(*args, *other), *x)
            inst._vjp_hash = inhash

            # NOTE(review): this block needs `x` and `other`, which only exist inside this
            # branch, so it is nested here - confirm against the upstream layout.
            if inst._compute_primals_out_shape is None:
                shape = jax.eval_shape(lambda *args: inst.compute_primal(*args, *other), *x)
                # cache (output shape(s), whether compute_primal returns a tuple, number of
                # continuous inputs)
                if isinstance(shape, tuple):
                    shape = (tuple(s.shape for s in shape), True,
                             len(inst._var_rel_names['input']))
                else:
                    shape = (shape.shape, False, len(inst._var_rel_names['input']))
                inst._compute_primals_out_shape = shape

        shape, istup, ninputs = inst._compute_primals_out_shape

        if istup:
            deriv_vals = (inst._vjp_fun(tuple(d_residuals.values())))
        else:
            deriv_vals = inst._vjp_fun(tuple(d_residuals.values())[0])

        # the vjp results are ordered inputs first, then outputs
        d_inputs.set_vals(deriv_vals[:ninputs])
        d_outputs.set_vals(deriv_vals[ninputs:])
def _to_compute_primal_setup_parser(parser):
    """
    Set up the command line options for the compute_primal conversion command line tool.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The parser to which the arguments are added.
    """
    parser.add_argument('file', nargs=1, help='Python file or module containing the class.')
    parser.add_argument('-c', '--class', action='store', dest='klass',
                        help='Component class to be converted.')
    # --import is a plain on/off switch: the module to import comes from the 'file'
    # argument, so the option must not consume a value of its own.
    parser.add_argument('-i', '--import', action='store_true', dest='imported',
                        help='Try to import the file as a module and convert the specified class.'
                        ' This requires that the class be initializable with no arguments.')
    parser.add_argument('-v', '--verbose', action='store_true', dest='verbose',
                        help='Print status information.')
    parser.add_argument('-o', '--outfile', action='store', dest='outfile',
                        default='stdout', help='Output file. Defaults to stdout.')
def _to_compute_primal_exec(options, user_args):
    """
    Process command line args and call convert on the specified class.

    Two modes are supported.  With `options.imported` set, the file/module is imported, the
    named class is instantiated with no arguments inside a throwaway Problem and converted.
    Otherwise the file is executed as a script and hooks are registered so that the first
    matching component found in the model is converted right after the model's
    _setup_var_data.

    Parameters
    ----------
    options : argparse Namespace
        Command line options.  Reads `klass`, `imported`, `file`, `outfile` and `verbose`.
    user_args : list of str
        Args to be passed to the user script.
    """
    from openmdao.core.component import Component
    from openmdao.core.problem import Problem
    import openmdao.utils.hooks as hooks

    if not options.klass:
        raise RuntimeError("Must specify a class to convert.")

    if options.imported:
        fname = options.file[0]
        if fname.endswith('.py'):
            fname = options.file[0]
            if not os.path.exists(fname):
                raise FileNotFoundError(f"File '{fname}' not found.")
            modpath = get_module_path(fname)
            if modpath is None:
                # no module path could be determined for the file, so make its directory
                # importable and import it by its bare name ('.py' stripped)
                modpath = fname
                moddir = os.path.dirname(modpath)
                sys.path = [moddir] + sys.path
                modpath = os.path.basename(modpath)[:-3]
        else:
            # not a .py file name, so treat the argument as a module path
            modpath = options.file[0]

        try:
            mod = importlib.import_module(modpath)
        except ImportError as err:
            print(f"Can't import module '{modpath}': {err}")
            return

        for name, klass in inspect.getmembers(mod, inspect.isclass):
            if name == options.klass:
                if not issubclass(klass, Component):
                    print(f"Class '{options.klass}' is not a subclass of Component.")
                    return

                # try to instantiate class with no args
                try:
                    inst = klass()
                except Exception as err:
                    # NOTE(review): the hint below mentions '--instance', which is not among
                    # the options defined by _to_compute_primal_setup_parser.
                    print(f"Can't instantiate class '{options.klass}' with default args: {err}")
                    print("Try using --instance instead and specify the path to an instance.")
                    return

                # the component must be set up before it can be converted
                p = Problem()
                p.model.add_subsystem('comp', inst)
                p.setup()
                to_compute_primal(inst, outfile=options.outfile)
                break
        else:
            # loop finished without a break, i.e. no class of that name exists in the module
            print(f"Class '{options.klass}' not found in module '{modpath}'")
            return
    else:
        def _to_compute_primal(model):
            # Convert the first component in the model that has the requested class anywhere
            # in its MRO.  The class may be given as a bare name or as 'module.path.Class'.
            found = False
            classpath = options.klass.split('.')
            cname = classpath[-1]
            cmod = '.'.join(classpath[:-1])
            npaths = len(classpath)
            for s in model.system_iter(recurse=True, typ=Component):
                for cls in inspect.getmro(type(s)):
                    if cls.__name__ == cname:
                        if npaths == 1 or cls.__module__ == cmod:
                            if options.verbose:
                                print(f"Converting class '{options.klass}' compute method to "
                                      f"compute_primal method for instance '{s.pathname}'.")
                            to_compute_primal(s, outfile=options.outfile,
                                              verbose=options.verbose)
                            found = True
                            break
                if found:
                    break
            else:
                print(f"Class '{options.klass}' not found in the model.")
                return

        def _set_dyn_hook(prob):
            # set the _to_compute_primal hook to be called right after _setup_var_data on the model
            prob.model.pathname = ''
            hooks._register_hook('_setup_var_data', class_name='Group', inst_id='',
                                 post=_to_compute_primal, exit=True)
            hooks._setup_hooks(prob.model)

        # register the hook to be called right after setup on the problem
        hooks._register_hook('setup', 'Problem', pre=_set_dyn_hook, ncalls=1)

        _load_and_exec(options.file[0], user_args)
[docs]
def to_compute_primal(inst, outfile='stdout', verbose=False):
    """
    Convert the given Component's compute method to a compute_primal method that works with jax.

    The regenerated class source is printed to stdout or written to `outfile`.  If `inst` is
    neither an ImplicitComponent nor an ExplicitComponent, a message is printed and nothing
    is converted.

    Parameters
    ----------
    inst : Component
        The Component to be converted.
    outfile : str
        The name of the file to write the converted class to. Defaults to 'stdout'.
    verbose : bool
        If True, print status information.
    """
    from openmdao.core.implicitcomponent import ImplicitComponent
    from openmdao.core.explicitcomponent import ExplicitComponent

    classname = type(inst).__name__
    if verbose:
        print(f"Converting class '{classname}' compute method to compute_primal method.")
        print(f"Output will be written to '{outfile}'.")

    # pick the transformer matching the component type; implicit is tested first
    converters = ((ImplicitComponent, ImplicitCompJaxify),
                  (ExplicitComponent, ExplicitCompJaxify))
    for comp_type, jaxify in converters:
        if isinstance(inst, comp_type):
            jaxer = jaxify(inst)
            break
    else:
        print(f"'{classname}' is not an ImplicitComponent or ExplicitComponent.")
        return

    if outfile != 'stdout':
        with open(outfile, 'w') as f:
            print(jaxer.get_class_src(), file=f)
        return

    print(jaxer.get_class_src())
# Small manual demo, run only when this file is executed as a script.
if __name__ == '__main__':
    import openmdao.api as om

    def func(x, y):  # noqa: D103
        z = jnp.sin(x) * y
        q = x * 1.5
        zz = q + x * 1.5
        return z, zz

    # show which outputs of func depend on which inputs
    print('partials are:\n', list(get_partials_deps(func, ('z', 'zz'))))

    p = om.Problem()
    comp = p.model.add_subsystem('comp', om.ExecComp('y = 2.0*x', x=np.ones(3), y=np.ones(3)))
    # NOTE(review): benchmark_component in this file selects the method through
    # comp.options['derivs_method']; confirm that setting a plain attribute here has the
    # intended effect.
    comp.derivs_method = 'jax'
    p.setup()
    p.run_model()

    print(p.compute_totals(of=['comp.y'], wrt=['comp.x']))