Source code for openmdao.test_suite.components.sparsity_comp


import numpy as np
from scipy.sparse import coo_matrix, issparse

try:
    import jax
    jax.config.update("jax_enable_x64", True)  # jax by default uses 32 bit floats
    import jax.numpy as jnp
    from jax.experimental import sparse as jsparse
except ImportError:
    jax = None
    jnp = np

from openmdao.core.problem import Problem
from openmdao.components.jax_explicit_comp import JaxExplicitComponent
from openmdao.core.explicitcomponent import ExplicitComponent


[docs]class SparsityComp(ExplicitComponent):
[docs] def __init__(self, sparsity, **kwargs): super(SparsityComp, self).__init__(**kwargs) if isinstance(sparsity, np.ndarray): self.use_sparse = False self.sparsity = sparsity self.nzrows, self.nzcols = np.nonzero(self.sparsity) else: self.use_sparse = True self.sparsity = sparsity.tocoo() self.nzrows, self.nzcols = self.sparsity.row, self.sparsity.col
[docs] def setup(self): self.add_input('x', shape=self.sparsity.shape[1]) self.add_output('y', shape=self.sparsity.shape[0])
[docs] def setup_partials(self): if self.use_sparse: self.declare_partials('y', 'x', rows=self.nzrows, cols=self.nzcols) else: self.declare_partials('y', 'x')
[docs] def compute(self, inputs, outputs): outputs['y'] = self.sparsity @ inputs['x']
[docs] def compute_partials(self, inputs, partials): if self.use_sparse: partials['y', 'x'] = self.sparsity.data else: partials['y', 'x'] = self.sparsity # [self.nzrows, self.nzcols]
[docs]class JaxSparsityComp(JaxExplicitComponent): """ A simple component that multiplies a sparse matrix by an input vector. The sparsity structure is defined by the 'sparsity' argument, and the data values are just the (index + 1) of the nonzeros in the sparsity structure. This component is used to test coloring. The jacobian can be computed with and without coloring and the two results should be the same. This version does not declare partials so sparsity must be computed in order to determine rows/cols for the subjacobian declarations. Parameters ---------- sparsity : ndarray or coo_matrix Sparsity structure to be tested. Attributes ---------- sparsity : coo_matrix Dense or sparse version of the sparsity structure. """
[docs] def __init__(self, sparsity, declare_partials=False, **kwargs): super(JaxSparsityComp, self).__init__(**kwargs) self.declare = declare_partials if isinstance(sparsity, (np.ndarray, jnp.ndarray)): self.sparsity = jnp.array(sparsity) self.nzrows, self.nzcols = jnp.nonzero(self.sparsity) elif issparse(sparsity): sparsity = sparsity.tocoo() self.nzrows, self.nzcols = jnp.array(sparsity.row), jnp.array(sparsity.col) indices = jnp.array(list(zip(self.nzrows, self.nzcols)), dtype=jnp.int32) self.sparsity = jsparse.BCOO((jnp.array(sparsity.data), indices), shape=sparsity.shape) else: raise ValueError(f"{self.msginfo}: no support for sparse type of {type(sparsity)}")
[docs] def setup(self): self.add_input('x', shape=self.sparsity.shape[1]) self.add_output('y', shape=self.sparsity.shape[0])
[docs] def setup_partials(self): if self.declare: self.declare_partials('y', 'x', rows=self.nzrows, cols=self.nzcols)
[docs] def compute_primal(self, x): return self.sparsity @ x
if __name__ == '__main__': from openmdao.test_suite.comp_tester import ComponentTester sparsity = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) sparsity_coo = coo_matrix(sparsity) jax_sparsity = jnp.array(sparsity) ComponentTester(SparsityComp, (sparsity,)).run() ComponentTester(SparsityComp, (sparsity_coo,)).run() ComponentTester(JaxSparsityComp, (jax_sparsity,)).run() ComponentTester(JaxSparsityComp, (jax_sparsity,), {'declare_partials': True}).run()