# Source code for openmdao.test_suite.scripts.jaxprof

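"""
Compare the performance of the JAX sparsity comp against the plain sparsity comp.

Options are passed as bare words on the command line, in any order, for example
``python jaxprof.py jax color prof``. Each word below enables one option:

color: use coloring
prof: use profiler (the output file is named after the active options)
rev: size the problem for reverse mode (100 x 1000 instead of 1000 x 100)
check: check partials
jax: use JAX sparsity comp
sparse: use sparse partials
fd: use finite difference partials
show: show sparsity (only takes effect together with color)

"""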

import time
import sys

import jax.numpy as jnp
import numpy as np

import openmdao.api as om
from openmdao.devtools.debug import profiling
from openmdao.utils.assert_utils import assert_check_partials
from openmdao.test_suite.components.sparsity_comp import SparsityComp, JaxSparsityComp
from openmdao.utils.array_utils import rand_sparsity
from openmdao.utils.general_utils import do_nothing_context

# Each option is a bare word on the command line; its presence switches it on.
args = sys.argv[1:]

use_coloring = 'color' in args
use_prof = 'prof' in args
rev = 'rev' in args
check = 'check' in args
use_jax = 'jax' in args
use_sparse = 'sparse' in args
use_fd = 'fd' in args
show = 'show' in args

# 'rev' swaps the sparsity-pattern shape: 100 x 1000 rather than 1000 x 100.
nrows, ncols = (100, 1000) if rev else (1000, 100)


def main():
    """
    Time setup, run_model and repeated linearization of one sparsity component.

    The component class (``JaxSparsityComp`` or ``SparsityComp``), the shape
    ``(nrows, ncols)``, sparse vs. dense sparsity input, coloring, finite
    differencing, partials checking and profiling are all selected by the
    module-level flags parsed from the command line. Each timing is printed
    to stdout.

    When ``check`` is set, the output of ``check_partials`` is passed to
    ``assert_check_partials``, which may raise (presumably ``AssertionError``)
    if the partials disagree.
    """
    # Fixed seed so the random sparsity pattern is identical across runs.
    rng = np.random.default_rng(66)

    p = om.Problem()
    klass = JaxSparsityComp if use_jax else SparsityComp

    # NOTE(review): 0.01 is presumably the fraction of nonzeros; the result
    # supports .toarray(), so it looks like a scipy sparse matrix.
    sparsity = rand_sparsity((nrows, ncols), 0.01, rng=rng)
    if not use_sparse:
        sparsity = sparsity.toarray()

    comp = p.model.add_subsystem('comp', klass(sparsity=sparsity))
    if use_coloring:
        comp.declare_coloring(show_summary=True, show_sparsity=show)
    if use_fd:
        comp.options['derivs_method'] = 'fd'

    print("Performance for args: ", args)

    t0 = time.perf_counter()
    p.setup()
    setup_time = time.perf_counter() - t0
    print(f'setup time: {setup_time}')

    t0 = time.perf_counter()
    p.run_model()
    run_time = time.perf_counter() - t0
    print(f'run_model time: {run_time}')

    if check:
        t0 = time.perf_counter()
        assert_check_partials(comp.check_partials(method='fd',
                                                  show_only_incorrect=True))
        check_time = time.perf_counter() - t0
        print(f'check_partials time: {check_time}')

    if use_prof:
        # Name the profile after the active options, e.g. 'jax_color_rev.prof'.
        profname = 'color' if use_coloring else 'nocolor'
        if use_jax:
            profname = 'jax_' + profname
        if rev:
            profname = profname + '_rev'
        if use_fd:
            profname = profname + '_fd'
        if not use_sparse:
            profname = profname + '_dense'
        ctx = profiling(profname + '.prof')
    else:
        ctx = do_nothing_context()

    reps = 1000
    t0 = time.perf_counter()
    with ctx:
        for _ in range(reps):
            comp._linearize()  # force coloring to be computed
    t1 = time.perf_counter()
    print(f'linearize time: {t1 - t0} for {reps} reps')
if __name__ == '__main__':
    # Run the benchmark only when executed as a script, not when imported.
    main()