jax_utils.py

Utilities for the use of jax in combination with OpenMDAO.

class openmdao.utils.jax_utils.CompJaxifyBase(comp, funcname, verbose=False)[source]

Bases: NodeTransformer

An ast.NodeTransformer that transforms a function definition to jax compatible form.

So original func becomes compute_primal(self, arg1, arg2, …).

If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.

Parameters:
comp : Component

The Component whose function is to be transformed. This NodeTransformer may only be used after the Component has had its _setup_var_data method called, because that determines the ordering of the inputs and outputs.

funcname : str

The name of the function to be transformed.

verbose : bool

If True, the transformed function will be printed to stdout.

Attributes:
_comp : weakref.ref

A weak reference to the Component whose function is being transformed.

_funcname : str

The name of the function being transformed.

compute_primal : function

The compute_primal function created from the original function.

_orig_args : list

The original argument names of the original function.

_new_ast : ast node

The new ast node created from the original function.

get_self_statics : function

A function that returns the static args for the Component as a single tuple.

__init__(comp, funcname, verbose=False)[source]
generic_visit(node)

Called if no explicit visitor function exists for a node.

get_class_src()[source]

Return the source code of the class containing the transformed function.

Returns:
str

The source code of the class containing the transformed function.

get_compute_primal_src()[source]

Return the source code of the transformed function.

Returns:
str

The source code of the transformed function.

visit(node)

Visit a node.

visit_Assign(node)[source]

Translate an Assign node into an Assign node with the subscript replaced with the name.

Parameters:
node : ast.Assign

The Assign node being visited.

Returns:
ast.Any

The transformed node.

visit_Attribute(node)[source]

Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.

Parameters:
node : ast.Attribute

The Attribute node being visited.

Returns:
ast.Any

The transformed node.
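The module rename described above can be illustrated with a minimal sketch that uses only the standard library ast module. NumpyToJnp is a hypothetical name and not OpenMDAO's implementation. Unlike the method documented here, this sketch renames every np/numpy attribute access and makes no attempt to distinguish static uses.

```python
import ast


class NumpyToJnp(ast.NodeTransformer):
    """Hypothetical sketch: rewrite np.<attr> and numpy.<attr> as jnp.<attr>."""

    def visit_Attribute(self, node):
        # Transform nested attributes first (e.g. numpy.linalg.norm).
        self.generic_visit(node)
        if isinstance(node.value, ast.Name) and node.value.id in ("np", "numpy"):
            new_name = ast.Name(id="jnp", ctx=ast.Load())
            node.value = ast.copy_location(new_name, node.value)
        return node


src = "y = np.sin(x) + numpy.cos(x)"
tree = NumpyToJnp().visit(ast.parse(src))
ast.fix_missing_locations(tree)
new_src = ast.unparse(tree)  # requires Python 3.9+
print(new_src)
```

The rewritten source only runs if the surrounding module also provides a jnp name, typically via import jax.numpy as jnp.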

visit_Constant(node)
visit_FunctionDef(node)[source]

Transform the compute function definition.

The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs['foo'] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.

Parameters:
node : ast.FunctionDef

The FunctionDef node being visited.

Returns:
ast.FunctionDef

The transformed node.

visit_Subscript(node)[source]

Translate a Subscript node into a Name node with the name of the subscript variable.

Parameters:
node : ast.Subscript

The Subscript node being visited.

Returns:
ast.Any

The transformed node.
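The subscript-to-name rewrite performed by visit_Subscript and visit_Assign can be sketched with the standard library ast module alone. SubscriptToName and its dict_names argument are hypothetical and are not OpenMDAO's implementation. The sketch assumes every key is a valid Python identifier, so it makes no attempt to handle variable names containing characters such as ':' or '.'.

```python
import ast


class SubscriptToName(ast.NodeTransformer):
    """Hypothetical sketch: replace inputs['key'] / outputs['key'] with the bare name key."""

    def __init__(self, dict_names=("inputs", "outputs")):
        self.dict_names = set(dict_names)

    def visit_Subscript(self, node):
        self.generic_visit(node)
        is_target_dict = (isinstance(node.value, ast.Name)
                          and node.value.id in self.dict_names)
        # On Python 3.9+ the slice of d['key'] is an ast.Constant directly.
        is_str_key = (isinstance(node.slice, ast.Constant)
                      and isinstance(node.slice.value, str))
        if is_target_dict and is_str_key:
            # Keep the original Load/Store context so assignment targets stay valid.
            new_node = ast.Name(id=node.slice.value, ctx=node.ctx)
            return ast.copy_location(new_node, node)
        return node


src = "outputs['y'] = 2.0 * inputs['x'] + inputs['b']"
tree = SubscriptToName().visit(ast.parse(src))
ast.fix_missing_locations(tree)
new_src = ast.unparse(tree)  # requires Python 3.9+
print(new_src)
```

Subscripts on any other object, such as self.options['n'], are left unchanged because their value is not a Name listed in dict_names.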

class openmdao.utils.jax_utils.ExplicitCompJaxify(comp, verbose=False)[source]

Bases: CompJaxifyBase

An ast.NodeTransformer that transforms a compute function definition to jax compatible form.

So compute(self, inputs, outputs) becomes compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. The new function will return a tuple of the output values in the order they are stored in outputs.

If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.

Parameters:
comp : ExplicitComponent

The Component whose compute function is to be transformed. This NodeTransformer may only be used after the Component has had its _setup_var_data method called, because that determines the ordering of the inputs and outputs.

verbose : bool

If True, the transformed function will be printed to stdout.
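To make the change of calling convention concrete, the following hand-written sketch places a dict-style compute next to the kind of compute_primal the description above implies. Doubler is a hypothetical plain Python class, not an OpenMDAO component, and the compute_primal shown was written by hand rather than produced by ExplicitCompJaxify.

```python
class Doubler:
    """Hypothetical stand-in for an explicit component."""

    def compute(self, inputs, outputs):
        # Original form: read from the inputs dict, write into the outputs dict.
        outputs['y'] = 2.0 * inputs['x'] + inputs['b']
        outputs['z'] = inputs['x'] - inputs['b']

    def compute_primal(self, x, b):
        # Transformed form: one positional argument per input, in input order,
        # and a tuple of output values in output order.
        y = 2.0 * x + b
        z = x - b
        return (y, z)


comp = Doubler()
outputs = {}
comp.compute({'x': 3.0, 'b': 1.0}, outputs)
primal = comp.compute_primal(3.0, 1.0)
print(outputs, primal)
```

Because compute_primal takes and returns plain values instead of mutating a dict, it has the functional shape that jax transformations expect.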

__init__(comp, verbose=False)[source]
generic_visit(node)

Called if no explicit visitor function exists for a node.

get_class_src()

Return the source code of the class containing the transformed function.

Returns:
str

The source code of the class containing the transformed function.

get_compute_primal_src()

Return the source code of the transformed function.

Returns:
str

The source code of the transformed function.

visit(node)

Visit a node.

visit_Assign(node)

Translate an Assign node into an Assign node with the subscript replaced with the name.

Parameters:
node : ast.Assign

The Assign node being visited.

Returns:
ast.Any

The transformed node.

visit_Attribute(node)

Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.

Parameters:
node : ast.Attribute

The Attribute node being visited.

Returns:
ast.Any

The transformed node.

visit_Constant(node)
visit_FunctionDef(node)

Transform the compute function definition.

The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs['foo'] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.

Parameters:
node : ast.FunctionDef

The FunctionDef node being visited.

Returns:
ast.FunctionDef

The transformed node.

visit_Subscript(node)

Translate a Subscript node into a Name node with the name of the subscript variable.

Parameters:
node : ast.Subscript

The Subscript node being visited.

Returns:
ast.Any

The transformed node.

class openmdao.utils.jax_utils.ImplicitCompJaxify(comp, verbose=False)[source]

Bases: CompJaxifyBase

A NodeTransformer that transforms an apply_nonlinear function definition to jax compatible form.

So apply_nonlinear(self, inputs, outputs, residuals) becomes compute_primal(self, arg1, arg2, …) where args are the input and output values in the order they are stored in their respective Vectors. The new function will return a tuple of the residual values in the order they are stored in the residuals Vector.

If the component has discrete inputs, they will be passed individually into compute_primal after the continuous inputs. If the component has discrete outputs, they will be assigned to local variables of the same name within the function and set back into the discrete outputs dict just prior to the return from the function.

Parameters:
comp : ImplicitComponent

The Component whose apply_nonlinear function is to be transformed. This NodeTransformer may only be used after the Component has had its _setup_var_data method called, because that determines the ordering of the inputs, outputs, and residuals.

verbose : bool

If True, the transformed function will be printed to stdout.
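For the implicit case, the hand-written sketch below shows how inputs and outputs both become arguments while the residuals become the return value. LinearResidual is a hypothetical plain Python class, not an OpenMDAO component, and the local name x_resid is an assumption made for readability; the names actually generated by ImplicitCompJaxify may differ.

```python
class LinearResidual:
    """Hypothetical stand-in for an implicit component with residual a*x - b."""

    def apply_nonlinear(self, inputs, outputs, residuals):
        # Original form: residuals are written into a dict-like object.
        residuals['x'] = inputs['a'] * outputs['x'] - inputs['b']

    def compute_primal(self, a, b, x):
        # Transformed form: inputs (a, b) first, then outputs (x);
        # the residual values are returned as a tuple.
        x_resid = a * x - b
        return (x_resid,)


comp = LinearResidual()
residuals = {}
comp.apply_nonlinear({'a': 2.0, 'b': 6.0}, {'x': 3.0}, residuals)
primal = comp.compute_primal(2.0, 6.0, 3.0)
print(residuals, primal)
```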

__init__(comp, verbose=False)[source]
generic_visit(node)

Called if no explicit visitor function exists for a node.

get_class_src()

Return the source code of the class containing the transformed function.

Returns:
str

The source code of the class containing the transformed function.

get_compute_primal_src()

Return the source code of the transformed function.

Returns:
str

The source code of the transformed function.

visit(node)

Visit a node.

visit_Assign(node)

Translate an Assign node into an Assign node with the subscript replaced with the name.

Parameters:
node : ast.Assign

The Assign node being visited.

Returns:
ast.Any

The transformed node.

visit_Attribute(node)

Translate any non-static use of ‘numpy’ or ‘np’ to ‘jnp’.

Parameters:
node : ast.Attribute

The Attribute node being visited.

Returns:
ast.Any

The transformed node.

visit_Constant(node)
visit_FunctionDef(node)

Transform the compute function definition.

The function will be transformed from compute(self, inputs, outputs, …) or apply_nonlinear(self, …) to compute_primal(self, arg1, arg2, …) where args are the input values in the order they are stored in inputs. All subscript accesses into the input args will be replaced with the name of the key being accessed, e.g., inputs['foo'] becomes foo. The new function will return a tuple of the output values in the order they are stored in outputs. If compute has the additional args discrete_inputs and discrete_outputs, they will be handled similarly.

Parameters:
node : ast.FunctionDef

The FunctionDef node being visited.

Returns:
ast.FunctionDef

The transformed node.

visit_Subscript(node)

Translate a Subscript node into a Name node with the name of the subscript variable.

Parameters:
node : ast.Subscript

The Subscript node being visited.

Returns:
ast.Any

The transformed node.

class openmdao.utils.jax_utils.ReturnChecker(method)[source]

Bases: NodeVisitor

An ast.NodeVisitor that determines if a method returns a tuple or not.

Parameters:
method : method

The method to be analyzed.

Attributes:
_returns : list

The list of boolean values indicating whether or not the method returns a tuple. One entry for each return statement in the method.

__init__(method)[source]
generic_visit(node)

Called if no explicit visitor function exists for a node.

returns_tuple()[source]

Return whether or not the method returns a tuple.

Returns:
bool

True if the method returns a tuple, False otherwise.

visit(node)

Visit a node.

visit_Constant(node)
visit_Return(node)[source]

Visit a Return node.

Parameters:
node : ASTnode

The return node being visited.
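The idea behind ReturnChecker can be sketched with a small ast.NodeVisitor from the standard library. TupleReturnChecker is a hypothetical class that takes source text rather than a method object, and its rule for combining several return statements (all of them must return a tuple) is an assumption, not necessarily what ReturnChecker does.

```python
import ast
import textwrap


class TupleReturnChecker(ast.NodeVisitor):
    """Hypothetical sketch: record, per return statement, whether it returns a tuple."""

    def __init__(self, source):
        self.returns = []
        self.visit(ast.parse(textwrap.dedent(source)))

    def visit_Return(self, node):
        # ast.Tuple covers both `return a, b` and `return (a,)`.
        self.returns.append(isinstance(node.value, ast.Tuple))

    def returns_tuple(self):
        # Assumption: tuple-returning only if every return statement is a tuple.
        return bool(self.returns) and all(self.returns)


tuple_src = """
def compute_primal(self, x, b):
    return x + b, x - b
"""
scalar_src = """
def compute_primal(self, x, b):
    return x + b
"""
tuple_result = TupleReturnChecker(tuple_src).returns_tuple()
scalar_result = TupleReturnChecker(scalar_src).returns_tuple()
print(tuple_result, scalar_result)
```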

class openmdao.utils.jax_utils.SelfAttrFinder(method)[source]

Bases: NodeVisitor

An ast.NodeVisitor that collects all attribute names that are accessed on self.

Parameters:
method : method

The method to be analyzed.

Attributes:
_attrs : set

The set of attribute names accessed on self.

_funcs : set

The set of method names accessed on self.

_dcts : dict

The set of attribute names accessed on self that are subscripted.

__init__(method)[source]
generic_visit(node)

Called if no explicit visitor function exists for a node.

visit(node)

Visit a node.

visit_Attribute(node)[source]

Visit an Attribute node.

If the attribute is accessed on self, add the attribute name to the set of attributes.

Parameters:
node : ast.Attribute

The Attribute node being visited.

visit_Call(node)[source]

Visit a Call node.

If the function is accessed on self, add the function name to the set of functions.

Parameters:
node : ast.Call

The Call node being visited.

visit_Constant(node)
visit_Subscript(node)[source]

Visit a Subscript node.

If the subscript is accessed on self, add the attribute name to the set of attributes.

Parameters:
node : ast.Subscript

The Subscript node being visited.
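A minimal sketch of collecting names accessed on self, again using only the standard library ast module. SelfAttrCollector is hypothetical and works on source text. It separates plain attributes from called methods but, unlike the class documented above, does not track subscripted attributes.

```python
import ast
import textwrap


def _is_self_attr(node):
    """True if node has the form self.<name>."""
    return (isinstance(node, ast.Attribute)
            and isinstance(node.value, ast.Name)
            and node.value.id == "self")


class SelfAttrCollector(ast.NodeVisitor):
    """Hypothetical sketch: collect attribute and method names accessed on self."""

    def __init__(self):
        self.attrs = set()
        self.funcs = set()

    def visit_Attribute(self, node):
        if _is_self_attr(node):
            self.attrs.add(node.attr)
        self.generic_visit(node)

    def visit_Call(self, node):
        if _is_self_attr(node.func):
            self.funcs.add(node.func.attr)
            # Visit only the arguments so the method name is not also
            # recorded as a plain attribute.
            for arg in node.args:
                self.visit(arg)
            for keyword in node.keywords:
                self.visit(keyword.value)
        else:
            self.generic_visit(node)


src = textwrap.dedent("""
    def compute(self, inputs, outputs):
        scale = self.scale
        outputs['y'] = self.helper(inputs['x'] * scale) + self.offset
""")
collector = SelfAttrCollector()
collector.visit(ast.parse(src))
print(sorted(collector.attrs), sorted(collector.funcs))
```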

openmdao.utils.jax_utils.benchmark_component(comp_class, methods=(None, 'cs', 'jax'), initial_vals=None, repeats=2, mode='auto', table_format='simple_grid', **kwargs)[source]

Benchmark the performance of a Component using different methods for computing derivatives.

Parameters:
comp_class : class

The class of the Component to be benchmarked.

methods : tuple of str

The methods to be benchmarked. Options are ‘cs’, ‘jax’, and None.

initial_vals : dict or None

Initial values for the input variables.

repeats : int

The number of times to run compute/compute_partials.

mode : str

The preferred derivative direction for the Problem.

table_format : str or None

If not None, the format of the table to be displayed.

**kwargs : dict

Additional keyword arguments to be passed to the Component.

Returns:
dict

A dictionary containing the benchmark results.

openmdao.utils.jax_utils.dump_jaxpr(closed_jaxpr)[source]

Print out the contents of a Jaxpr.

Parameters:
closed_jaxpr : jax.core.ClosedJaxpr

The Jaxpr to be examined.

openmdao.utils.jax_utils.get_self_static_attrs(method)[source]

Get the set of attribute names accessed on self in the given method.

Parameters:
method : method

The method to be analyzed.

Returns:
set

The set of attribute names accessed on self.

dict

The set of attribute names accessed on self that are subscripted with a string.

openmdao.utils.jax_utils.get_vmap_tangents(vals, direction, fill=1.0, coloring=None)[source]

Return a tuple of tangent values for use with vmap.

The batching dimension is the last axis of each tangent.

Parameters:
vals : list

List of function input or output values.

direction : str

The direction to compute the sparsity in. It must be ‘fwd’ or ‘rev’.

fill : float

The value to fill nonzero entries in the tangent with.

coloring : Coloring or None

A Coloring object that contains coloring information including nonzero indices.

Returns:
tuple of ndarray or ndarray

The tangent values to be passed to vmap.

openmdao.utils.jax_utils.jit_stub(f, *args, **kwargs)[source]

Provide a dummy jit decorator for use if jax is not available.

Parameters:
f : Callable

The function or method to be wrapped.

*args : list

Positional arguments.

**kwargs : dict

Keyword arguments.

Returns:
Callable

The decorated function.
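The usual way such a stub is put to work is an import-time fallback, sketched below. jit_fallback is a hypothetical stand-in written for this example, and it assumes the stub simply returns the function unchanged and ignores any extra arguments.

```python
def jit_fallback(f, *args, **kwargs):
    """Hypothetical stand-in for jax.jit: ignore all options and return f unchanged."""
    return f


try:
    from jax import jit  # the real decorator when jax is installed
except ImportError:
    jit = jit_fallback   # no-op fallback otherwise


@jit
def square(x):
    return x * x


# float() makes the result comparable whether or not jax compiled the function.
result = float(square(3.0))
print(result)
```

With this pattern, decorated code still imports and runs without jax; it just runs uncompiled.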

openmdao.utils.jax_utils.register_jax_component(comp_class)[source]

Provide a class decorator that registers the given class as a pytree_node.

This allows jax to use jit compilation on the methods of this class if they reference attributes of the class itself, such as self.options.

Note that this decorator is not necessary if the given class does not reference self in any methods to which jax.jit is applied.

Parameters:
comp_class : class

The decorated class.

Returns:
object

The same class given as an argument.

Raises:
NotImplementedError

If this class does not define the _tree_flatten and _tree_unflatten methods.

RuntimeError

If jax is not available.
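The sketch below illustrates what the two hooks might look like and how a registering decorator could apply the documented checks. Scaler, NoHooks, and register_sketch are hypothetical names, and the order of the checks is an assumption. The only jax call used is jax.tree_util.register_pytree_node, whose flatten function returns (children, aux_data) and whose unflatten function receives (aux_data, children).

```python
class Scaler:
    """Hypothetical plain class standing in for a component class."""

    def __init__(self, scale):
        self.scale = scale

    def _tree_flatten(self):
        # (children, aux_data): no traced children here; scale is treated
        # as static data and kept in a hashable tuple.
        return (), (self.scale,)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)


def register_sketch(cls):
    """Hypothetical decorator mirroring the documented checks."""
    for name in ("_tree_flatten", "_tree_unflatten"):
        if not hasattr(cls, name):
            raise NotImplementedError(f"{cls.__name__} must define {name}")
    try:
        import jax
    except ImportError as err:
        raise RuntimeError("jax is not available") from err
    jax.tree_util.register_pytree_node(cls, cls._tree_flatten, cls._tree_unflatten)
    return cls


# Round trip through the two hooks; this part works without jax installed.
children, aux_data = Scaler(3.0)._tree_flatten()
rebuilt = Scaler._tree_unflatten(aux_data, children)


class NoHooks:
    pass


try:
    register_sketch(NoHooks)
    missing_hooks_error = None
except NotImplementedError as err:
    missing_hooks_error = str(err)

print(rebuilt.scale, missing_hooks_error)
```

Keeping the static data hashable matters because jax compares and hashes the auxiliary data when deciding whether a jitted function needs to be retraced.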

openmdao.utils.jax_utils.to_compute_primal(inst, outfile='stdout', verbose=False)[source]

Convert the given Component’s compute method to a compute_primal method that works with jax.

Parameters:
inst : Component

The Component to be converted.

outfile : str

The name of the file to write the converted class to. Defaults to ‘stdout’.

verbose : bool

If True, print status information.