jax_utils.py
Utilities for the use of jax in combination with OpenMDAO.
- class openmdao.utils.jax_utils.CompJaxifyBase(comp, funcname, verbose=False)
Bases: NodeTransformer
An ast.NodeTransformer that transforms a function definition to jax compatible form.
So the original function becomes compute_primal(self, arg1, arg2, …).
If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.
- Parameters:
- comp : Component
The Component whose function is to be transformed. This NodeTransformer may only be used after the Component has had its _setup_var_data method called, because that determines the ordering of the inputs and outputs.
- funcname : str
The name of the function to be transformed.
- verbose : bool
If True, the transformed function will be printed to stdout.
- Attributes:
- _comp : weakref.ref
A weak reference to the Component whose function is being transformed.
- _funcname : str
The name of the function being transformed.
- compute_primal : function
The compute_primal function created from the original function.
- _orig_args : list
The original argument names of the original function.
- _new_ast : ast node
The new ast node created from the original function.
- get_self_statics : function
A function that returns the static args for the Component as a single tuple.
- __init__(comp, funcname, verbose=False)
- generic_visit(node)
Called if no explicit visitor function exists for a node.
- get_class_src()
Return the source code of the class containing the transformed function.
- Returns:
- str
The source code of the class containing the transformed function.
- get_compute_primal_src()
Return the source code of the transformed function.
- Returns:
- str
The source code of the transformed function.
- visit(node)
Visit a node.
- visit_Assign(node)
Translate an Assign node into an Assign node with the subscript replaced with the name.
- Parameters:
- node : ast.Assign
The Assign node being visited.
- Returns:
- ast.Any
The transformed node.
- visit_Attribute(node)
Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.
- Parameters:
- node : ast.Attribute
The Attribute node being visited.
- Returns:
- ast.Any
The transformed node.
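Below is a minimal sketch of the np/numpy to jnp rename that visit_Attribute describes. It uses only the standard-library ast module. The class name NumpyToJnp is hypothetical. Unlike the documented method, this sketch does not try to tell static uses of numpy apart from non-static ones: it renames every attribute access whose base is the bare name np or numpy.

```python
import ast


class NumpyToJnp(ast.NodeTransformer):
    """Hypothetical sketch: rewrite np.<attr> and numpy.<attr> to jnp.<attr>."""

    def visit_Attribute(self, node):
        # Recurse first so chained accesses such as numpy.linalg.norm are handled.
        self.generic_visit(node)
        if isinstance(node.value, ast.Name) and node.value.id in ('np', 'numpy'):
            # Unlike the real method, this ignores whether the use is static.
            node.value = ast.Name(id='jnp', ctx=ast.Load())
        return node


src = "y = np.sin(x) + numpy.linalg.norm(x)"
tree = ast.fix_missing_locations(NumpyToJnp().visit(ast.parse(src)))
print(ast.unparse(tree))  # y = jnp.sin(x) + jnp.linalg.norm(x)
```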
- visit_Constant(node)
- visit_FunctionDef(node)
Transform the compute function definition.
The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs[‘foo’] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.
- Parameters:
- node : ast.FunctionDef
The FunctionDef node being visited.
- Returns:
- ast.FunctionDef
The transformed node.
- visit_Subscript(node)
Translate a Subscript node into a Name node with the name of the subscript variable.
- Parameters:
- node : ast.Subscript
The Subscript node being visited.
- Returns:
- ast.Any
The transformed node.
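The visit_FunctionDef and visit_Subscript behavior described above can be sketched with a single ast.NodeTransformer. This is an illustrative sketch, not this class's implementation, and it rests on several assumptions:

- ComputeToPrimal is a hypothetical name.
- The input and output names are passed in explicitly. The real transformer takes their order from the Component after _setup_var_data.
- Discrete variables, the numpy-to-jnp rename and static self attributes are ignored.

```python
import ast


class ComputeToPrimal(ast.NodeTransformer):
    """Hypothetical sketch of the compute -> compute_primal rewrite."""

    def __init__(self, in_names, out_names):
        self.in_names = list(in_names)    # assumed order of variables in `inputs`
        self.out_names = list(out_names)  # assumed order of variables in `outputs`

    def visit_Subscript(self, node):
        self.generic_visit(node)
        # inputs['foo'] or outputs['foo'] -> foo, keeping the Load/Store context.
        if (isinstance(node.value, ast.Name)
                and node.value.id in ('inputs', 'outputs')
                and isinstance(node.slice, ast.Constant)
                and isinstance(node.slice.value, str)):
            return ast.Name(id=node.slice.value, ctx=node.ctx)
        return node

    def visit_FunctionDef(self, node):
        self.generic_visit(node)  # rewrite the body first
        node.name = 'compute_primal'
        # New signature: self followed by the inputs in their stored order.
        node.args = ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg=name, annotation=None) for name in ['self'] + self.in_names],
            vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
        # Return the outputs as a tuple in their stored order.
        node.body.append(ast.Return(value=ast.Tuple(
            elts=[ast.Name(id=name, ctx=ast.Load()) for name in self.out_names],
            ctx=ast.Load())))
        return node


src = """
def compute(self, inputs, outputs):
    outputs['y'] = 2.0 * inputs['x'] + inputs['z']
"""

tree = ast.fix_missing_locations(ComputeToPrimal(['x', 'z'], ['y']).visit(ast.parse(src)))
print(ast.unparse(tree))  # the rewritten function source

namespace = {}
exec(compile(tree, '<compute_primal>', 'exec'), namespace)
compute_primal = namespace['compute_primal']
print(compute_primal(None, 1.5, 0.5))  # (3.5,)
```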
- class openmdao.utils.jax_utils.ExplicitCompJaxify(comp, verbose=False)
Bases: CompJaxifyBase
An ast.NodeTransformer that transforms a compute function definition to jax compatible form.
So compute(self, inputs, outputs) becomes compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. The new function will return a tuple of the output values in the order they are stored in outputs.
If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.
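To make this calling convention concrete, here is a hand-written pair of hypothetical classes; it is not generated by ExplicitCompJaxify:

- OriginalComp has a dict-based compute with one discrete input.
- PrimalComp has an equivalent compute_primal. It takes the continuous inputs in storage order, then the discrete input, and returns the outputs as a tuple.

The variable names and the storage order (a, b, then n) are assumptions for illustration. Discrete outputs are left out.

```python
class OriginalComp:
    """Hypothetical component-like class using the dict-based compute API."""

    def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None):
        outputs['y'] = inputs['a'] * discrete_inputs['n'] + inputs['b']
        outputs['z'] = inputs['a'] - inputs['b']


class PrimalComp:
    """Hand-written equivalent in compute_primal form (not generated output)."""

    def compute_primal(self, a, b, n):
        # Continuous inputs (a, b) in their assumed storage order, then the discrete input n.
        y = a * n + b
        z = a - b
        # Outputs are returned as a tuple in their assumed storage order (y, z).
        return y, z


inputs = {'a': 3.0, 'b': 1.0}
outputs = {}
OriginalComp().compute(inputs, outputs, discrete_inputs={'n': 2})
print(outputs)                                   # {'y': 7.0, 'z': 2.0}
print(PrimalComp().compute_primal(3.0, 1.0, 2))  # (7.0, 2.0)
```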
- Parameters:
- comp : ExplicitComponent
The Component whose compute function is to be transformed. This NodeTransformer may only be used after the Component has had its _setup_var_data method called, because that determines the ordering of the inputs and outputs.
- verbose : bool
If True, the transformed function will be printed to stdout.
- __init__(comp, verbose=False)
- generic_visit(node)
Called if no explicit visitor function exists for a node.
- get_class_src()
Return the source code of the class containing the transformed function.
- Returns:
- str
The source code of the class containing the transformed function.
- get_compute_primal_src()
Return the source code of the transformed function.
- Returns:
- str
The source code of the transformed function.
- visit(node)
Visit a node.
- visit_Assign(node)
Translate an Assign node into an Assign node with the subscript replaced with the name.
- Parameters:
- node : ast.Assign
The Assign node being visited.
- Returns:
- ast.Any
The transformed node.
- visit_Attribute(node)
Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.
- Parameters:
- node : ast.Attribute
The Attribute node being visited.
- Returns:
- ast.Any
The transformed node.
- visit_Constant(node)
- visit_FunctionDef(node)
Transform the compute function definition.
The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs[‘foo’] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.
- Parameters:
- node : ast.FunctionDef
The FunctionDef node being visited.
- Returns:
- ast.FunctionDef
The transformed node.
- visit_Subscript(node)
Translate a Subscript node into a Name node with the name of the subscript variable.
- Parameters:
- node : ast.Subscript
The Subscript node being visited.
- Returns:
- ast.Any
The transformed node.
- class openmdao.utils.jax_utils.ImplicitCompJaxify(comp, verbose=False)
Bases: CompJaxifyBase
A NodeTransformer that transforms an apply_nonlinear function definition to jax compatible form.
So apply_nonlinear(self, inputs, outputs, residuals) becomes compute_primal(self, arg1, arg2, …) where args are the input and output values in the order they are stored in their respective Vectors. The new function will return a tuple of the residual values in the order they are stored in the residuals Vector.
If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.
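The implicit case follows the same pattern, with outputs added to the arguments and residuals returned. The sketch below is a hand-written pair of hypothetical classes, not generated output. It assumes a single input a and a single state y whose residual is y**2 - a. The inputs come first in the compute_primal signature, then the outputs, and the residuals are returned as a one-element tuple.

```python
class OriginalImplicit:
    """Hypothetical implicit component: the state y satisfies y**2 - a = 0."""

    def apply_nonlinear(self, inputs, outputs, residuals):
        residuals['y'] = outputs['y'] ** 2 - inputs['a']


class PrimalImplicit:
    """Hand-written equivalent in compute_primal form (not generated output)."""

    def compute_primal(self, a, y):
        # Inputs first (a), then outputs (y); residuals come back as a tuple.
        return (y ** 2 - a,)


residuals = {}
OriginalImplicit().apply_nonlinear({'a': 4.0}, {'y': 2.5}, residuals)
print(residuals)                                  # {'y': 2.25}
print(PrimalImplicit().compute_primal(4.0, 2.5))  # (2.25,)
```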
- Parameters:
- comp : ImplicitComponent
The Component whose apply_nonlinear function is to be transformed. This NodeTransformer may only be used after the Component has had its _setup_var_data method called, because that determines the ordering of the inputs, outputs, and residuals.
- verbose : bool
If True, the transformed function will be printed to stdout.
- __init__(comp, verbose=False)
- generic_visit(node)
Called if no explicit visitor function exists for a node.
- get_class_src()
Return the source code of the class containing the transformed function.
- Returns:
- str
The source code of the class containing the transformed function.
- get_compute_primal_src()
Return the source code of the transformed function.
- Returns:
- str
The source code of the transformed function.
- visit(node)
Visit a node.
- visit_Assign(node)
Translate an Assign node into an Assign node with the subscript replaced with the name.
- Parameters:
- node : ast.Assign
The Assign node being visited.
- Returns:
- ast.Any
The transformed node.
- visit_Attribute(node)
Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.
- Parameters:
- node : ast.Attribute
The Attribute node being visited.
- Returns:
- ast.Any
The transformed node.
- visit_Constant(node)
- visit_FunctionDef(node)
Transform the compute function definition.
The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs[‘foo’] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.
- Parameters:
- node : ast.FunctionDef
The FunctionDef node being visited.
- Returns:
- ast.FunctionDef
The transformed node.
- visit_Subscript(node)
Translate a Subscript node into a Name node with the name of the subscript variable.
- Parameters:
- node : ast.Subscript
The Subscript node being visited.
- Returns:
- ast.Any
The transformed node.
- class openmdao.utils.jax_utils.ReturnChecker(method)
Bases: NodeVisitor
An ast.NodeVisitor that determines if a method returns a tuple or not.
- Parameters:
- method : method
The method to be analyzed.
- Attributes:
- _returns : list
The list of boolean values indicating whether or not the method returns a tuple. One entry for each return statement in the method.
- __init__(method)
- generic_visit(node)
Called if no explicit visitor function exists for a node.
- returns_tuple()
Return whether or not the method returns a tuple.
- Returns:
- bool
True if the method returns a tuple, False otherwise.
- visit(node)
Visit a node.
- visit_Constant(node)
- visit_Return(node)
Visit a Return node.
- Parameters:
- node : AST node
The return node being visited.
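Below is a minimal sketch of the kind of check ReturnChecker performs. It makes a few assumptions:

- TupleReturnChecker is a hypothetical name.
- It takes a source string instead of a method object, so it does not depend on inspect.getsource.
- It records one boolean per return statement, as described for _returns.
- "Returns a tuple" is taken to mean that every return statement returns a tuple literal.

Returns inside nested functions are counted too, which the real class may handle differently.

```python
import ast
import textwrap


class TupleReturnChecker(ast.NodeVisitor):
    """Hypothetical sketch: does every return statement return a tuple literal?"""

    def __init__(self, src):
        self._returns = []  # one bool per return statement, as described for _returns
        self.visit(ast.parse(textwrap.dedent(src)))

    def visit_Return(self, node):
        # A bare `return` has node.value None, which counts as "not a tuple".
        self._returns.append(isinstance(node.value, ast.Tuple))
        self.generic_visit(node)

    def returns_tuple(self):
        return bool(self._returns) and all(self._returns)


print(TupleReturnChecker("def f(a):\n    return a, 2 * a\n").returns_tuple())  # True
print(TupleReturnChecker("def g(a):\n    return a\n").returns_tuple())          # False
```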
- class openmdao.utils.jax_utils.SelfAttrFinder(method)
Bases: NodeVisitor
An ast.NodeVisitor that collects all attribute names that are accessed on self.
- Parameters:
- method : method
The method to be analyzed.
- Attributes:
- _attrs : set
The set of attribute names accessed on self.
- _funcs : set
The set of method names accessed on self.
- _dcts : dict
The set of attribute names accessed on self that are subscripted.
- __init__(method)
- generic_visit(node)
Called if no explicit visitor function exists for a node.
- visit(node)
Visit a node.
- visit_Attribute(node)
Visit an Attribute node.
If the attribute is accessed on self, add the attribute name to the set of attributes.
- Parameters:
- node : ast.Attribute
The Attribute node being visited.
- visit_Call(node)
Visit a Call node.
If the function is accessed on self, add the function name to the set of functions.
- Parameters:
- node : ast.Call
The Call node being visited.
- visit_Constant(node)
- visit_Subscript(node)
Visit a Subscript node.
If the subscript is accessed on self, add the attribute name to the set of attributes.
- Parameters:
- node : ast.Subscript
The Subscript node being visited.
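The three collections that SelfAttrFinder maintains can be sketched as follows. The sketch makes a few assumptions:

- SelfAttrCollector is a hypothetical name, and it parses a source string rather than a method.
- The subscripted-attribute dict maps each attribute name to the set of string keys used with it. The documentation above only says it is a dict of subscripted attribute names.
- The sketch relies on the Python 3.9+ AST layout, where a constant subscript is an ast.Constant.

Method calls such as self._helper(...) are recorded as functions, not as attributes.

```python
import ast
import textwrap


def _is_self_attr(node):
    """True for expressions of the form self.<name>."""
    return (isinstance(node, ast.Attribute)
            and isinstance(node.value, ast.Name)
            and node.value.id == 'self')


class SelfAttrCollector(ast.NodeVisitor):
    """Hypothetical sketch: self attributes, self method calls, subscripted self attributes."""

    def __init__(self, src):
        self.attrs = set()  # names used as self.<name>
        self.funcs = set()  # names called as self.<name>(...)
        self.dcts = {}      # assumed layout: attribute name -> set of string keys
        self.visit(ast.parse(textwrap.dedent(src)))

    def visit_Call(self, node):
        if _is_self_attr(node.func):
            self.funcs.add(node.func.attr)
            # Visit only the arguments so the method name is not also counted as an attribute.
            for arg in node.args:
                self.visit(arg)
            for kw in node.keywords:
                self.visit(kw)
        else:
            self.generic_visit(node)

    def visit_Subscript(self, node):
        # Python 3.9+: a constant subscript such as self.options['n'] has an ast.Constant slice.
        if (_is_self_attr(node.value) and isinstance(node.slice, ast.Constant)
                and isinstance(node.slice.value, str)):
            self.dcts.setdefault(node.value.attr, set()).add(node.slice.value)
        self.generic_visit(node)

    def visit_Attribute(self, node):
        if _is_self_attr(node):
            self.attrs.add(node.attr)
        self.generic_visit(node)


src = """
def compute(self, inputs, outputs):
    n = self.options['n']
    outputs['y'] = self.scale * self._helper(inputs['x'], n)
"""
finder = SelfAttrCollector(src)
print(sorted(finder.attrs))  # ['options', 'scale']
print(sorted(finder.funcs))  # ['_helper']
print(finder.dcts)           # {'options': {'n'}}
```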
- openmdao.utils.jax_utils.apply_linear(inst, inputs, outputs, d_inputs, d_outputs, d_residuals, mode)
Compute jac-vector product. The model is assumed to be in an unscaled state.
- If mode is:
‘fwd’: (d_inputs, d_outputs) |-> d_residuals
‘rev’: d_residuals |-> (d_inputs, d_outputs)
- Parameters:
- inst : ImplicitComponent
The component instance.
- inputs : Vector
Unscaled, dimensional input variables read via inputs[key].
- outputs : Vector
Unscaled, dimensional output variables read via outputs[key].
- d_inputs : Vector
See inputs; product must be computed only if var_name in d_inputs.
- d_outputs : Vector
See outputs; product must be computed only if var_name in d_outputs.
- d_residuals : Vector
See outputs.
- mode : str
Either ‘fwd’ or ‘rev’.
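Below is a small numeric sketch of the two modes, with the partial derivatives held as explicit nested lists. It makes a few assumptions:

- The function name is hypothetical.
- dR_dx holds the derivatives of the residuals with respect to the inputs, and dR_dy holds them with respect to the outputs.
- Plain dense matrices are used for illustration, rather than whatever this module uses internally.

The sketch returns new lists rather than writing into Vector objects.

```python
def matvec(A, v):
    """Return A @ v for a matrix stored as a list of rows."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def rmatvec(A, w):
    """Return A^T @ w for a matrix stored as a list of rows."""
    return [sum(A[i][j] * w[i] for i in range(len(A))) for j in range(len(A[0]))]


def apply_linear_sketch(dR_dx, dR_dy, mode, d_inputs=None, d_outputs=None, d_residuals=None):
    """Hypothetical dense version of the fwd/rev products for an implicit component."""
    if mode == 'fwd':
        # (d_inputs, d_outputs) |-> d_residuals
        from_x = matvec(dR_dx, d_inputs)
        from_y = matvec(dR_dy, d_outputs)
        return [a + b for a, b in zip(from_x, from_y)]
    if mode == 'rev':
        # d_residuals |-> (d_inputs, d_outputs)
        return rmatvec(dR_dx, d_residuals), rmatvec(dR_dy, d_residuals)
    raise ValueError(f"mode must be 'fwd' or 'rev', not {mode!r}")


# Residuals R = y**2 - x (elementwise) linearized at y = (1, 2):
# dR/dx = -I and dR/dy = diag(2*y).
dR_dx = [[-1.0, 0.0], [0.0, -1.0]]
dR_dy = [[2.0, 0.0], [0.0, 4.0]]
print(apply_linear_sketch(dR_dx, dR_dy, 'fwd', d_inputs=[1.0, 0.0], d_outputs=[0.0, 1.0]))
# [-1.0, 4.0]
print(apply_linear_sketch(dR_dx, dR_dy, 'rev', d_residuals=[1.0, 1.0]))
# ([-1.0, -1.0], [2.0, 4.0])
```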
- openmdao.utils.jax_utils.benchmark_component(comp_class, methods=(None, 'cs', 'jax'), initial_vals=None, repeats=2, mode='auto', table_format='simple_grid', **kwargs)
Benchmark the performance of a Component using different methods for computing derivatives.
- Parameters:
- comp_class : class
The class of the Component to be benchmarked.
- methods : tuple of str
The methods to be benchmarked. Options are ‘cs’, ‘jax’, and None.
- initial_vals : dict or None
Initial values for the input variables.
- repeats : int
The number of times to run compute/compute_partials.
- mode : str
The preferred derivative direction for the Problem.
- table_format : str or None
If not None, the format of the table to be displayed.
- **kwargs : dict
Additional keyword arguments to be passed to the Component.
- Returns:
- dict
A dictionary containing the benchmark results.
- openmdao.utils.jax_utils.compute_jacvec_product(inst, inputs, d_inputs, d_outputs, mode, discrete_inputs=None)
Compute jac-vector product. The model is assumed to be in an unscaled state.
- If mode is:
‘fwd’: d_inputs |-> d_outputs
‘rev’: d_outputs |-> d_inputs
- Parameters:
- inst : ExplicitComponent
The component instance.
- inputs : Vector
Unscaled, dimensional input variables read via inputs[key].
- d_inputs : Vector
See inputs; product must be computed only if var_name in d_inputs.
- d_outputs : Vector
See outputs; product must be computed only if var_name in d_outputs.
- mode : str
Either ‘fwd’ or ‘rev’.
- discrete_inputs : dict or None
If not None, dict containing discrete input values.
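For an explicit component, the two modes are transposes of each other. A common consistency check is therefore the dot-product test: for any vectors v and w, dot(w, J v) must equal dot(J^T w, v). The sketch below applies that check to a dense Jacobian. The function names and the matrix are hypothetical; a real check would call the component's own fwd and rev products instead.

```python
import math
import random


def jvp(J, v):
    """fwd: d_inputs |-> d_outputs, i.e. J @ v."""
    return [sum(a * b for a, b in zip(row, v)) for row in J]


def vjp(J, w):
    """rev: d_outputs |-> d_inputs, i.e. J^T @ w."""
    return [sum(J[i][j] * w[i] for i in range(len(J))) for j in range(len(J[0]))]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def dot_product_test(J, n_in, n_out, seed=0):
    """Return True if the fwd and rev products agree on random vectors."""
    rng = random.Random(seed)
    v = [rng.uniform(-1.0, 1.0) for _ in range(n_in)]
    w = [rng.uniform(-1.0, 1.0) for _ in range(n_out)]
    return math.isclose(dot(w, jvp(J, v)), dot(vjp(J, w), v), rel_tol=1e-12, abs_tol=1e-12)


# Hypothetical 2-output, 3-input Jacobian.
J = [[1.0, 2.0, 0.5],
     [0.0, -3.0, 4.0]]
print(dot_product_test(J, n_in=3, n_out=2))  # True
```

A failure of this test indicates that the fwd and rev implementations disagree.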
- openmdao.utils.jax_utils.compute_partials(inst, inputs, partials, discrete_inputs=None)
Compute sub-jacobian parts. The model is assumed to be in an unscaled state.
- Parameters:
- inst : ExplicitComponent
The component instance.
- inputs : Vector
Unscaled, dimensional input variables read via inputs[key].
- partials : Jacobian
Sub-jac components written to partials[output_name, input_name].
- discrete_inputs : dict or None
If not None, dict containing discrete input values.
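The partials[output_name, input_name] layout can be illustrated by slicing one dense Jacobian into sub-Jacobians. The helper below is hypothetical and for illustration only. It assumes that the rows of J follow the output order and the columns follow the input order. Sizes are given in dicts whose insertion order matches that declaration order.

```python
def split_jacobian(J, out_sizes, in_sizes):
    """Slice a dense Jacobian (list of rows) into {(output_name, input_name): sub_jac}."""
    partials = {}
    row = 0
    for out_name, out_size in out_sizes.items():
        col = 0
        for in_name, in_size in in_sizes.items():
            # Rows belonging to this output, columns belonging to this input.
            partials[out_name, in_name] = [r[col:col + in_size] for r in J[row:row + out_size]]
            col += in_size
        row += out_size
    return partials


J = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
partials = split_jacobian(J, out_sizes={'y': 1, 'z': 2}, in_sizes={'a': 2, 'b': 1})
print(partials['z', 'a'])  # [[4, 5], [7, 8]]
print(partials['y', 'b'])  # [[3]]
```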
- openmdao.utils.jax_utils.dump_jaxpr(closed_jaxpr)
Print out the contents of a Jaxpr.
- Parameters:
- closed_jaxpr : jax.core.ClosedJaxpr
The Jaxpr to be examined.
- openmdao.utils.jax_utils.get_self_static_attrs(method)
Get the set of attribute names accessed on self in the given method.
- Parameters:
- method : method
The method to be analyzed.
- Returns:
- set
The set of attribute names accessed on self.
- dict
The set of attribute names accessed on self that are subscripted with a string.
- openmdao.utils.jax_utils.jax_deriv_shape(derivs)
Get the shape of the derivatives from a jax derivative calculation.
- Parameters:
- derivs : tuple
The tuple of derivatives.
- Returns:
- list
The shape of the derivatives.
- openmdao.utils.jax_utils.jit_stub(f, *args, **kwargs)
Provide a dummy jit decorator for use if jax is not available.
- Parameters:
- f : Callable
The function or method to be wrapped.
- *args : list
Positional arguments.
- **kwargs : dict
Keyword arguments.
- Returns:
- Callable
The decorated function.
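Below is a sketch of the fallback pattern that a dummy jit decorator enables. jit_stub_sketch is a hypothetical stand-in that accepts the function plus arbitrary positional and keyword arguments, matching the documented signature. It assumes the stub simply returns the function unchanged and ignores any jit options such as static_argnums.

```python
def jit_stub_sketch(f, *args, **kwargs):
    """Hypothetical no-op replacement for jax.jit: return f unchanged."""
    return f


try:
    from jax import jit  # use the real compiler when jax is installed
except ImportError:
    jit = jit_stub_sketch  # otherwise fall back to the no-op stub


@jit
def scaled_sum(a, b):
    return 2 * a + b


print(scaled_sum(3, 4))  # 10 (returned as a jax Array when jax is installed)
```

Because the stub accepts and ignores extra arguments, call sites that pass jit options keep working whether or not jax is present.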
- openmdao.utils.jax_utils.linearize(inst, inputs, outputs, partials, discrete_inputs=None, discrete_outputs=None)
Compute sub-jacobian parts and any applicable matrix factorizations.
The model is assumed to be in an unscaled state.
- Parameters:
- inst : ImplicitComponent
The component instance.
- inputs : Vector
Unscaled, dimensional input variables read via inputs[key].
- outputs : Vector
Unscaled, dimensional output variables read via outputs[key].
- partials : partial Jacobian
Sub-jac components written to partials[output_name, input_name].
- discrete_inputs : dict or None
If not None, dict containing discrete input values.
- discrete_outputs : dict or None
If not None, dict containing discrete output values.
- openmdao.utils.jax_utils.register_jax_component(comp_class)
Provide a class decorator that registers the given class as a pytree_node.
This allows jax to use jit compilation on the methods of this class if they reference attributes of the class itself, such as self.options.
Note that this decorator is not necessary if the given class does not reference self in any methods to which jax.jit is applied.
- Parameters:
- comp_class : class
The decorated class.
- Returns:
- object
The same class given as an argument.
- Raises:
- NotImplementedError
If this class does not define the _tree_flatten and _tree_unflatten methods.
- RuntimeError
If jax is not available.
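Below is a minimal sketch of registering a class as a jax pytree node with the two hooks named above. It makes a few assumptions:

- register_component_sketch and Scaler are hypothetical names.
- The sketch checks for the hooks before importing jax; the real decorator's ordering may differ.
- _tree_flatten returns (children, aux_data), and _tree_unflatten is a classmethod taking (aux_data, children). This matches the callables expected by jax.tree_util.register_pytree_node.

```python
def register_component_sketch(comp_class):
    """Hypothetical sketch of registering a class as a jax pytree node."""
    # Check for the required hooks first (the documented NotImplementedError case).
    if not (hasattr(comp_class, '_tree_flatten') and hasattr(comp_class, '_tree_unflatten')):
        raise NotImplementedError(
            f"{comp_class.__name__} must define _tree_flatten and _tree_unflatten.")
    try:
        from jax import tree_util
    except ImportError as err:
        # The documented RuntimeError case.
        raise RuntimeError("jax is not available.") from err
    # flatten: obj -> (children, aux_data); unflatten: (aux_data, children) -> obj
    tree_util.register_pytree_node(comp_class, comp_class._tree_flatten,
                                   comp_class._tree_unflatten)
    return comp_class


class Scaler:
    """Hypothetical class whose only state is a static (non-array) option."""

    def __init__(self, scale):
        self.scale = scale

    def _tree_flatten(self):
        # No array leaves; the static option travels as (hashable) aux_data.
        return (), (self.scale,)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)


children, aux_data = Scaler(2.5)._tree_flatten()
print(Scaler._tree_unflatten(aux_data, children).scale)  # 2.5
```

With jax installed, decorating Scaler with @register_component_sketch would let jax flatten and rebuild instances. That is what allows jit-compiled methods to reference attributes such as self.scale.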
- openmdao.utils.jax_utils.to_compute_primal(inst, outfile='stdout', verbose=False)
Convert the given Component’s compute method to a compute_primal method that works with jax.
- Parameters:
- inst : Component
The Component to be converted.
- outfile : str
The name of the file to write the converted class to. Defaults to ‘stdout’.
- verbose : bool
If True, print status information.