jax_utils.py


Utilities for the use of jax in combination with OpenMDAO.

class openmdao.utils.jax_utils.CompJaxifyBase(comp, funcname, verbose=False)

Bases: NodeTransformer

An ast.NodeTransformer that transforms a function definition to jax-compatible form.

The original function becomes compute_primal(self, arg1, arg2, …).

If the component has discrete inputs, they will be passed individually into compute_primal before the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.
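To illustrate the calling convention described above, here is a hypothetical class written by hand in both forms. Its compute method reads from and writes to dict-like vectors, and its equivalent compute_primal receives the input values as positional arguments and returns the outputs as a tuple. Plain dicts stand in for OpenMDAO Vectors, the class and variable names are invented, and discrete variables are left out. This is not actual output of the transformer.

```python
class HypotheticalComp:
    """Hypothetical stand-in for a Component with inputs 'x', 'y' and outputs 'f', 'g'."""

    def compute(self, inputs, outputs):
        # Original style: values are read from and written to dict-like vectors.
        outputs['f'] = inputs['x'] * inputs['y']
        outputs['g'] = inputs['x'] + 2.0 * inputs['y']

    def compute_primal(self, x, y):
        # Transformed style: input values arrive positionally, in input order,
        # and the output values are returned as a tuple, in output order.
        f = x * y
        g = x + 2.0 * y
        return (f, g)


comp = HypotheticalComp()
outputs = {}
comp.compute({'x': 2.0, 'y': 3.0}, outputs)
print(outputs, comp.compute_primal(2.0, 3.0))  # {'f': 6.0, 'g': 8.0} (6.0, 8.0)
```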

Parameters:
comp : Component

The Component whose function is to be transformed. This NodeTransformer may only be used after the Component's _setup_var_data method has been called, because that method determines the ordering of the inputs and outputs.

funcname : str

The name of the function to be transformed.

verbose : bool

If True, the transformed function will be printed to stdout.

Attributes:
_comp : weakref.ref

A weak reference to the Component whose function is being transformed.

_funcname : str

The name of the function being transformed.

compute_primal : function

The compute_primal function created from the original function.

_orig_args : list

The argument names of the original function.

_new_ast : ast node

The new ast node created from the original function.

get_self_statics : function

A function that returns the static args for the Component as a single tuple.

__init__(comp, funcname, verbose=False)
generic_visit(node)

Called if no explicit visitor function exists for a node.

get_class_src()

Return the source code of the class containing the transformed function.

Returns:
str

The source code of the class containing the transformed function.

get_compute_primal_src()

Return the source code of the transformed function.

Returns:
str

The source code of the transformed function.

visit(node)

Visit a node.

visit_Assign(node)

Translate an Assign node into an Assign node with the subscript replaced with the name.

Parameters:
node : ast.Assign

The Assign node being visited.

Returns:
ast.Any

The transformed node.

visit_Attribute(node)

Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.

Parameters:
node : ast.Attribute

The Attribute node being visited.

Returns:
ast.Any

The transformed node.
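Here is a minimal sketch of the attribute rewrite using ast.NodeTransformer. It assumes Python 3.9+ for ast.unparse. The documented method translates only non-static uses, but this sketch makes no attempt to separate static from non-static uses and rewrites every np./numpy. attribute base. The class name NumpyToJnp is hypothetical.

```python
import ast


class NumpyToJnp(ast.NodeTransformer):
    """Hypothetical transformer: rewrite np.<attr> and numpy.<attr> as jnp.<attr>."""

    def visit_Attribute(self, node):
        # Visit children first so nested attributes such as numpy.linalg.norm are handled.
        self.generic_visit(node)
        if isinstance(node.value, ast.Name) and node.value.id in ('np', 'numpy'):
            node.value = ast.Name(id='jnp', ctx=ast.Load())
        return node


tree = ast.parse("y = np.sin(x) + numpy.linalg.norm(x)")
tree = ast.fix_missing_locations(NumpyToJnp().visit(tree))
print(ast.unparse(tree))  # y = jnp.sin(x) + jnp.linalg.norm(x)
```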

visit_Constant(node)
visit_FunctionDef(node)

Transform the compute function definition.

The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …), where the args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs[‘foo’] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. For apply_nonlinear, the args are the input and output values, and the returned tuple holds the residual values in the order they are stored in residuals. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.

Parameters:
node : ast.FunctionDef

The FunctionDef node being visited.

Returns:
ast.FunctionDef

The transformed node.
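Here is a condensed sketch of the rewrite described above, built on ast.NodeTransformer. It is not OpenMDAO's implementation. The helper name to_primal_src is an assumption. So are the explicit input_names/output_names lists, which stand in for the ordering taken from the component's vectors, and the restriction to string-keyed subscripts on names called inputs and outputs. It requires Python 3.9+ for ast.unparse and for the ast.Constant subscript slice.

```python
import ast
import textwrap


def to_primal_src(src, input_names, output_names):
    """Hypothetical helper: rewrite compute(self, inputs, outputs) as compute_primal.

    input_names and output_names stand in for the argument and return ordering
    that OpenMDAO takes from the component's input and output vectors.
    """

    class _Rewriter(ast.NodeTransformer):
        def visit_Subscript(self, node):
            self.generic_visit(node)
            # inputs['x'] / outputs['f'] -> x / f, keeping the Load or Store context
            if (isinstance(node.value, ast.Name)
                    and node.value.id in ('inputs', 'outputs')
                    and isinstance(node.slice, ast.Constant)
                    and isinstance(node.slice.value, str)):
                return ast.Name(id=node.slice.value, ctx=node.ctx)
            return node

        def visit_FunctionDef(self, node):
            self.generic_visit(node)  # rewrite the body first
            node.name = 'compute_primal'
            node.args.args = [ast.arg(arg=name, annotation=None)
                              for name in ['self', *input_names]]
            # Return the output values as a tuple, in output order.
            result = ast.Tuple(elts=[ast.Name(id=name, ctx=ast.Load()) for name in output_names],
                               ctx=ast.Load())
            node.body.append(ast.Return(value=result))
            return node

    tree = ast.parse(textwrap.dedent(src))
    tree = ast.fix_missing_locations(_Rewriter().visit(tree))
    return ast.unparse(tree)


src = '''
def compute(self, inputs, outputs):
    outputs['f'] = inputs['x'] * inputs['y']
    outputs['g'] = inputs['x'] + 2.0 * inputs['y']
'''
new_src = to_primal_src(src, ['x', 'y'], ['f', 'g'])
print(new_src)

namespace = {}
exec(new_src, namespace)  # compile the generated function to check that it runs
print(namespace['compute_primal'](None, 2.0, 3.0))  # (6.0, 8.0)
```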

visit_Subscript(node)

Translate a Subscript node into a Name node with the name of the subscript variable.

Parameters:
node : ast.Subscript

The Subscript node being visited.

Returns:
ast.Any

The transformed node.

class openmdao.utils.jax_utils.ExplicitCompJaxify(comp, verbose=False)

Bases: CompJaxifyBase

An ast.NodeTransformer that transforms a compute function definition to jax-compatible form.

That is, compute(self, inputs, outputs) becomes compute_primal(self, arg1, arg2, …), where the args are the input values in the order they are stored in inputs. The new function will return a tuple of the output values in the order they are stored in outputs.

If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just before the function returns.

Parameters:
comp : ExplicitComponent

The Component whose compute function is to be transformed. This NodeTransformer may only be used after the Component's _setup_var_data method has been called, because that method determines the ordering of the inputs and outputs.

verbose : bool

If True, the transformed function will be printed to stdout.

__init__(comp, verbose=False)
generic_visit(node)

Called if no explicit visitor function exists for a node.

get_class_src()

Return the source code of the class containing the transformed function.

Returns:
str

The source code of the class containing the transformed function.

get_compute_primal_src()

Return the source code of the transformed function.

Returns:
str

The source code of the transformed function.

visit(node)

Visit a node.

visit_Assign(node)

Translate an Assign node into an Assign node with the subscript replaced with the name.

Parameters:
node : ast.Assign

The Assign node being visited.

Returns:
ast.Any

The transformed node.

visit_Attribute(node)

Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.

Parameters:
nodeast.Attribute

The Attribute node being visited.

Returns:
ast.Any

The transformed node.

visit_Constant(node)
visit_FunctionDef(node)

Transform the compute function definition.

The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …), where the args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs[‘foo’] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. For apply_nonlinear, the args are the input and output values, and the returned tuple holds the residual values in the order they are stored in residuals. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.

Parameters:
node : ast.FunctionDef

The FunctionDef node being visited.

Returns:
ast.FunctionDef

The transformed node.

visit_Subscript(node)

Translate a Subscript node into a Name node with the name of the subscript variable.

Parameters:
nodeast.Subscript

The Subscript node being visited.

Returns:
ast.Any

The transformed node.

class openmdao.utils.jax_utils.ImplicitCompJaxify(comp, verbose=False)

Bases: CompJaxifyBase

An ast.NodeTransformer that transforms an apply_nonlinear function definition to jax-compatible form.

That is, apply_nonlinear(self, inputs, outputs, residuals) becomes compute_primal(self, arg1, arg2, …), where the args are the input and output values in the order they are stored in their respective Vectors. The new function will return a tuple of the residual values in the order they are stored in the residuals Vector.

If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.
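Here is the same before/after idea for an implicit component, written by hand. It is a hypothetical class with one input 'a', one output 'x', and the residual R(a, x) = a·x − 1. Plain dicts stand in for the inputs, outputs, and residuals Vectors, and the names are invented. The sketch only shows the argument and return convention, not what the transformer actually emits.

```python
class HypotheticalImplicit:
    """Hypothetical stand-in: input 'a', output 'x', residual R(a, x) = a * x - 1."""

    def apply_nonlinear(self, inputs, outputs, residuals):
        # Original style: read inputs/outputs, write residuals, all via dict-like vectors.
        residuals['x'] = inputs['a'] * outputs['x'] - 1.0

    def compute_primal(self, a, x):
        # Transformed style: input and output values arrive as arguments,
        # and the residual values come back as a tuple (here, a 1-tuple).
        return (a * x - 1.0,)


comp = HypotheticalImplicit()
residuals = {}
comp.apply_nonlinear({'a': 4.0}, {'x': 1.0}, residuals)
print(residuals['x'], comp.compute_primal(4.0, 1.0))  # 3.0 (3.0,)
```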

Parameters:
comp : ImplicitComponent

The Component whose apply_nonlinear function is to be transformed. This NodeTransformer may only be used after the Component's _setup_var_data method has been called, because that method determines the ordering of the inputs, outputs, and residuals.

verbose : bool

If True, the transformed function will be printed to stdout.

__init__(comp, verbose=False)
generic_visit(node)

Called if no explicit visitor function exists for a node.

get_class_src()

Return the source code of the class containing the transformed function.

Returns:
str

The source code of the class containing the transformed function.

get_compute_primal_src()

Return the source code of the transformed function.

Returns:
str

The source code of the transformed function.

visit(node)

Visit a node.

visit_Assign(node)

Translate an Assign node into an Assign node with the subscript replaced with the name.

Parameters:
node : ast.Assign

The Assign node being visited.

Returns:
ast.Any

The transformed node.

visit_Attribute(node)

Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.

Parameters:
nodeast.Attribute

The Attribute node being visited.

Returns:
ast.Any

The transformed node.

visit_Constant(node)
visit_FunctionDef(node)

Transform the compute function definition.

The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …), where the args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs[‘foo’] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. For apply_nonlinear, the args are the input and output values, and the returned tuple holds the residual values in the order they are stored in residuals. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.

Parameters:
node : ast.FunctionDef

The FunctionDef node being visited.

Returns:
ast.FunctionDef

The transformed node.

visit_Subscript(node)

Translate a Subscript node into a Name node with the name of the subscript variable.

Parameters:
nodeast.Subscript

The Subscript node being visited.

Returns:
ast.Any

The transformed node.

class openmdao.utils.jax_utils.ReturnChecker(method)

Bases: NodeVisitor

An ast.NodeVisitor that determines whether a method returns a tuple.

Parameters:
method : method

The method to be analyzed.

Attributes:
_returns : list

A list of boolean values indicating whether each return statement in the method returns a tuple, with one entry per return statement.

__init__(method)
generic_visit(node)

Called if no explicit visitor function exists for a node.

returns_tuple()

Return whether or not the method returns a tuple.

Returns:
bool

True if the method returns a tuple, False otherwise.

visit(node)

Visit a node.

visit_Constant(node)
visit_Return(node)

Visit a Return node.

Parameters:
node : AST node

The return node being visited.
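Here is a minimal sketch of a ReturnChecker-like visitor. Several details are assumptions:

- It parses a source string, whereas the documented class takes a method.
- It records one boolean per return statement.
- It counts only a tuple display (return a, b or return (a, b)) as a tuple, because a name that happens to hold a tuple cannot be recognized syntactically.
- The rule that every return must be a tuple is its own policy.

Returns inside nested functions would also be counted.

```python
import ast
import textwrap


class TupleReturnChecker(ast.NodeVisitor):
    """Hypothetical visitor: record, per return statement, whether it returns a tuple display."""

    def __init__(self, src):
        self._returns = []
        self.visit(ast.parse(textwrap.dedent(src)))

    def visit_Return(self, node):
        # ``return a, b`` and ``return (a, b)`` both parse to an ast.Tuple value.
        self._returns.append(isinstance(node.value, ast.Tuple))
        self.generic_visit(node)

    def returns_tuple(self):
        # Assumed policy: True only if there is at least one return and every one is a tuple.
        return bool(self._returns) and all(self._returns)


src = '''
def compute_primal(self, x, y):
    if x > y:
        return x, y
    return (y, x)
'''
print(TupleReturnChecker(src).returns_tuple())  # True
print(TupleReturnChecker("def f(x):\n    return x\n").returns_tuple())  # False
```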

class openmdao.utils.jax_utils.SelfAttrFinder(method)

Bases: NodeVisitor

An ast.NodeVisitor that collects all attribute names that are accessed on self.

Parameters:
method : method

The method to be analyzed.

Attributes:
_attrs : set

The set of attribute names accessed on self.

_funcs : set

The set of method names accessed on self.

_dcts : dict

The attribute names accessed on self that are subscripted.

__init__(method)
generic_visit(node)

Called if no explicit visitor function exists for a node.

visit(node)

Visit a node.

visit_Attribute(node)

Visit an Attribute node.

If the attribute is accessed on self, add the attribute name to the set of attributes.

Parameters:
node : ast.Attribute

The Attribute node being visited.

visit_Call(node)

Visit a Call node.

If the function being called is accessed on self, add the function name to the set of functions.

Parameters:
node : ast.Call

The Call node being visited.

visit_Constant(node)
visit_Subscript(node)

Visit a Subscript node.

If the subscript is accessed on self, add the attribute name to the set of attributes.

Parameters:
node : ast.Subscript

The Subscript node being visited.
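Here is a minimal sketch of the kind of analysis SelfAttrFinder performs, written with ast.NodeVisitor. Several choices are assumptions for illustration, and OpenMDAO's own bookkeeping may differ:

- The class name SelfAttrCollector is invented.
- It parses a source string rather than taking a method object.
- Each subscripted attribute's string keys are recorded in a dict of sets.
- A call such as self.helper(...) is recorded only as a function, not also as an attribute.

```python
import ast


class SelfAttrCollector(ast.NodeVisitor):
    """Hypothetical visitor: collect the names a piece of source accesses on ``self``."""

    def __init__(self, src):
        self.attrs = set()  # plain attribute accesses, e.g. self.scale
        self.funcs = set()  # method calls, e.g. self.helper(...)
        self.dcts = {}      # subscripted attributes -> string keys, e.g. self.options['n']
        self.visit(ast.parse(src))

    @staticmethod
    def _self_attr(node):
        """Return <name> if node is ``self.<name>``, else None."""
        if (isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name)
                and node.value.id == 'self'):
            return node.attr
        return None

    def visit_Attribute(self, node):
        name = self._self_attr(node)
        if name is not None:
            self.attrs.add(name)
        self.generic_visit(node)

    def visit_Call(self, node):
        name = self._self_attr(node.func)
        if name is None:
            self.generic_visit(node)
            return
        self.funcs.add(name)
        # Visit only the arguments, so the method name is not also recorded in attrs.
        for child in node.args + node.keywords:
            self.visit(child)

    def visit_Subscript(self, node):
        name = self._self_attr(node.value)
        if (name is not None and isinstance(node.slice, ast.Constant)
                and isinstance(node.slice.value, str)):
            self.dcts.setdefault(name, set()).add(node.slice.value)
        else:
            self.generic_visit(node)


src = '''
def compute_primal(self, x):
    n = self.options['n']
    return self.scale * self.helper(x, k=self.k) * n
'''
finder = SelfAttrCollector(src)
print(sorted(finder.attrs), sorted(finder.funcs), finder.dcts)
# ['k', 'scale'] ['helper'] {'options': {'n'}}
```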

openmdao.utils.jax_utils.apply_linear(inst, inputs, outputs, d_inputs, d_outputs, d_residuals, mode)

Compute jac-vector product. The model is assumed to be in an unscaled state.

If mode is:

‘fwd’: (d_inputs, d_outputs) |-> d_residuals

‘rev’: d_residuals |-> (d_inputs, d_outputs)
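Both modes can be written in terms of the partial-derivative matrices of the residuals with respect to the inputs (dR_dx) and the outputs (dR_dy). In fwd mode the result is dR_dx·d_inputs + dR_dy·d_outputs. In rev mode it is the transposed products dR_dxᵀ·d_residuals and dR_dyᵀ·d_residuals. The dense, pure-Python sketch below shows only that linear algebra. The function names and the flat-list vectors are hypothetical, since OpenMDAO works with named Vectors. Accumulating into the results with += is an assumption of the sketch.

```python
def matvec(m, v):
    """Return m @ v for a matrix m given as a list of rows."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]


def rmatvec(m, w):
    """Return m.T @ w for a matrix m given as a list of rows."""
    return [sum(m[i][j] * w[i] for i in range(len(m))) for j in range(len(m[0]))]


def toy_apply_linear(dR_dx, dR_dy, d_inputs, d_outputs, d_residuals, mode):
    """Hypothetical dense version of the apply_linear mapping (accumulates in place)."""
    if mode == 'fwd':
        # (d_inputs, d_outputs) |-> d_residuals
        fx, fy = matvec(dR_dx, d_inputs), matvec(dR_dy, d_outputs)
        for i in range(len(d_residuals)):
            d_residuals[i] += fx[i] + fy[i]
    elif mode == 'rev':
        # d_residuals |-> (d_inputs, d_outputs)
        rx, ry = rmatvec(dR_dx, d_residuals), rmatvec(dR_dy, d_residuals)
        for j in range(len(d_inputs)):
            d_inputs[j] += rx[j]
        for j in range(len(d_outputs)):
            d_outputs[j] += ry[j]
    else:
        raise ValueError(f"mode must be 'fwd' or 'rev', got {mode!r}")


dR_dx = [[1.0], [2.0]]            # 2 residuals x 1 input
dR_dy = [[3.0, 0.0], [0.0, 4.0]]  # 2 residuals x 2 outputs

d_res = [0.0, 0.0]
toy_apply_linear(dR_dx, dR_dy, [1.0], [1.0, 1.0], d_res, 'fwd')
print(d_res)  # [4.0, 6.0]

d_in, d_out = [0.0], [0.0, 0.0]
toy_apply_linear(dR_dx, dR_dy, d_in, d_out, [1.0, 1.0], 'rev')
print(d_in, d_out)  # [3.0] [3.0, 4.0]
```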

Parameters:
inst : ImplicitComponent

The component instance.

inputs : Vector

Unscaled, dimensional input variables read via inputs[key].

outputs : Vector

Unscaled, dimensional output variables read via outputs[key].

d_inputs : Vector

See inputs; product must be computed only if var_name in d_inputs.

d_outputs : Vector

See outputs; product must be computed only if var_name in d_outputs.

d_residuals : Vector

See outputs.

mode : str

Either ‘fwd’ or ‘rev’.

openmdao.utils.jax_utils.benchmark_component(comp_class, methods=(None, 'cs', 'jax'), initial_vals=None, repeats=2, mode='auto', table_format='simple_grid', **kwargs)

Benchmark the performance of a Component using different methods for computing derivatives.

Parameters:
comp_class : class

The class of the Component to be benchmarked.

methods : tuple of str

The methods to be benchmarked. Options are ‘cs’, ‘jax’, and None.

initial_vals : dict or None

Initial values for the input variables.

repeats : int

The number of times to run compute/compute_partials.

mode : str

The preferred derivative direction for the Problem.

table_format : str or None

If not None, the format of the table to be displayed.

**kwargs : dict

Additional keyword arguments to be passed to the Component.

Returns:
dict

A dictionary containing the benchmark results.

openmdao.utils.jax_utils.compute_jacvec_product(inst, inputs, d_inputs, d_outputs, mode, discrete_inputs=None)

Compute jac-vector product. The model is assumed to be in an unscaled state.

If mode is:

‘fwd’: d_inputs |-> d_outputs

‘rev’: d_outputs |-> d_inputs
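For an explicit component, the two modes are J·v (fwd) and Jᵀ·w (rev), where J is the Jacobian of the outputs with respect to the inputs. A standard way to check that fwd and rev implementations agree is the dot-product test (J v) · w = v · (Jᵀ w). The dense, pure-Python sketch below uses hypothetical helper names and flat lists rather than OpenMDAO Vectors. Its small exact values keep the floating-point comparison exact.

```python
def jvp(jac, v):
    """fwd: d_inputs |-> d_outputs, i.e. J @ v (jac given as a list of rows)."""
    return [sum(a * b for a, b in zip(row, v)) for row in jac]


def vjp(jac, w):
    """rev: d_outputs |-> d_inputs, i.e. J.T @ w."""
    return [sum(jac[i][j] * w[i] for i in range(len(jac))) for j in range(len(jac[0]))]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


J = [[1.0, 2.0, 0.0],
     [0.0, 3.0, 4.0]]  # 2 outputs x 3 inputs
v = [1.0, -1.0, 2.0]   # a d_inputs seed
w = [2.0, 5.0]         # a d_outputs seed

print(jvp(J, v), vjp(J, w))                  # [-1.0, 5.0] [2.0, 19.0, 20.0]
print(dot(jvp(J, v), w), dot(v, vjp(J, w)))  # 23.0 23.0
```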

Parameters:
inst : ExplicitComponent

The component instance.

inputs : Vector

Unscaled, dimensional input variables read via inputs[key].

d_inputs : Vector

See inputs; product must be computed only if var_name in d_inputs.

d_outputs : Vector

See outputs; product must be computed only if var_name in d_outputs.

mode : str

Either ‘fwd’ or ‘rev’.

discrete_inputs : dict or None

If not None, dict containing discrete input values.

openmdao.utils.jax_utils.compute_partials(inst, inputs, partials, discrete_inputs=None)

Compute sub-jacobian parts. The model is assumed to be in an unscaled state.

Parameters:
inst : ExplicitComponent

The component instance.

inputs : Vector

Unscaled, dimensional input variables read via inputs[key].

partials : Jacobian

Sub-jac components written to partials[output_name, input_name].

discrete_inputs : dict or None

If not None, dict containing discrete input values.

openmdao.utils.jax_utils.dump_jaxpr(closed_jaxpr)

Print out the contents of a Jaxpr.

Parameters:
closed_jaxpr : jax.core.ClosedJaxpr

The Jaxpr to be examined.

openmdao.utils.jax_utils.get_self_static_attrs(method)

Get the set of attribute names accessed on self in the given method.

Parameters:
method : method

The method to be analyzed.

Returns:
set

The set of attribute names accessed on self.

dict

The attribute names accessed on self that are subscripted with a string.

openmdao.utils.jax_utils.jax_deriv_shape(derivs)

Get the shape of the derivatives from a jax derivative calculation.

Parameters:
derivs : tuple

The tuple of derivatives.

Returns:
list

The shape of the derivatives.

openmdao.utils.jax_utils.jit_stub(f, *args, **kwargs)

Provide a dummy jit decorator for use if jax is not available.

Parameters:
f : Callable

The function or method to be wrapped.

*args : list

Positional arguments.

**kwargs : dict

Keyword arguments.

Returns:
Callable

The decorated function.

openmdao.utils.jax_utils.linearize(inst, inputs, outputs, partials, discrete_inputs=None, discrete_outputs=None)

Compute sub-jacobian parts and any applicable matrix factorizations.

The model is assumed to be in an unscaled state.

Parameters:
inst : ImplicitComponent

The component instance.

inputs : Vector

Unscaled, dimensional input variables read via inputs[key].

outputs : Vector

Unscaled, dimensional output variables read via outputs[key].

partials : partial Jacobian

Sub-jac components written to partials[output_name, input_name].

discrete_inputs : dict or None

If not None, dict containing discrete input values.

discrete_outputs : dict or None

If not None, dict containing discrete output values.

openmdao.utils.jax_utils.register_jax_component(comp_class)

Provide a class decorator that registers the given class as a pytree node.

This allows jax to use jit compilation on the methods of this class if they reference attributes of the instance itself, such as self.options.

Note that this decorator is not necessary if the given class does not reference self in any methods to which jax.jit is applied.

Parameters:
comp_class : class

The class to be decorated.

Returns:
object

The same class given as an argument.

Raises:
NotImplementedError

If this class does not define the _tree_flatten and _tree_unflatten methods.

RuntimeError

If jax is not available.
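register_jax_component requires the decorated class to define _tree_flatten and _tree_unflatten, but the exact signatures OpenMDAO expects are not documented here. The pure-Python sketch below assumes the common jax pytree convention. Flattening returns a (children, aux_data) pair of dynamic leaves and static metadata, and unflattening is a classmethod that rebuilds an instance from (aux_data, children). The class and its attributes are hypothetical. aux_data is kept as a tuple so that it is hashable, and jax itself is not needed to run the round trip.

```python
class HypotheticalJaxComp:
    """Hypothetical class holding one dynamic value and one static option."""

    def __init__(self, scale, n):
        self.scale = scale  # dynamic: would be traced by jax
        self.n = n          # static: hashable metadata

    def _tree_flatten(self):
        # Assumed convention: return (children, aux_data).
        children = (self.scale,)
        aux_data = (self.n,)
        return children, aux_data

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # Rebuild an equivalent instance from the dynamic and static parts.
        return cls(*children, *aux_data)


comp = HypotheticalJaxComp(2.5, 3)
children, aux = comp._tree_flatten()
clone = HypotheticalJaxComp._tree_unflatten(aux, children)
print(clone.scale, clone.n)  # 2.5 3
```

With jax installed, one way to hand such a pair to jax is jax.tree_util.register_pytree_node(HypotheticalJaxComp, HypotheticalJaxComp._tree_flatten, HypotheticalJaxComp._tree_unflatten). Whether the decorator does exactly this is not stated above.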

openmdao.utils.jax_utils.to_compute_primal(inst, outfile='stdout', verbose=False)

Convert the given Component’s compute method to a compute_primal method that works with jax.

Parameters:
inst : Component

The Component to be converted.

outfile : str

The name of the file to write the converted class to. Defaults to ‘stdout’.

verbose : bool

If True, print status information.