Parallel Coloring for Multipoint or Fan-Out Problems

In many models, there is an opportunity to parallelize across multiple points (e.g. multiple load cases for a structural optimization, multiple flight conditions for an aerodynamic optimization). Executing the nonlinear solve for this model in parallel offers a large potential speed-up, but when computing total derivatives, achieving that same parallel performance may require the use of OpenMDAO’s parallel coloring algorithm.

Note

Parallel coloring is appropriate when your model has an inexpensive serial data path that runs before the parallel points. For more details on when a model calls for parallel coloring, see the Theory Manual entry on fan-out model structures.

Parallel coloring is specified via the parallel_deriv_color argument to the add_constraint() method. The color can be any hashable object (e.g., a string, int, or tuple). When two constraints point to variables from different components on different processors and are given the same parallel_deriv_color, they will be solved for in parallel with each other.

Usage Example

Here is a toy problem, intended to run on two or more processors, that shows how to use this feature. (The output below was produced with four MPI processes.)

Class definitions for a simple problem

import time

import numpy as np
import openmdao.api as om


class SumComp(om.ExplicitComponent):
    """
    Component whose scalar output 'y' is the sum of the entries of input 'x'.

    Parameters
    ----------
    size : int
        Number of entries in the vector input 'x'.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size

    def setup(self):
        self.add_input('x', val=np.zeros(self.size))
        self.add_output('y', val=0.0)

        self.declare_partials(of='*', wrt='*')

    def compute(self, inputs, outputs):
        # y = x[0] + x[1] + ... + x[size - 1]
        outputs['y'] = inputs['x'].sum()

    def compute_partials(self, inputs, partials):
        # Every entry of x contributes with unit weight: dy/dx_i = 1.
        n_entries = inputs['x'].size
        partials['y', 'x'] = np.ones(n_entries)


class SlowComp(om.ExplicitComponent):
    """
    Deliberately slow component computing y = mult * x.

    The scalar input 'x' is scaled by ``mult`` and broadcast into the vector
    output 'y' of length ``size``. Every call to ``_apply_linear`` first sleeps
    for ``delay`` seconds before delegating to the base-class implementation,
    which makes this component's linear operations artificially expensive.

    Parameters
    ----------
    delay : float
        Seconds passed to ``time.sleep`` on each ``_apply_linear`` call.
    size : int
        Length of the output vector 'y'.
    mult : float
        Factor applied to the input.
    """

    def __init__(self, delay=1.0, size=3, mult=2.0):
        super().__init__()
        self.delay, self.size, self.mult = delay, size, mult

    def setup(self):
        self.add_input('x', val=0.0)
        self.add_output('y', val=np.zeros(self.size))

        self.declare_partials(of='*', wrt='*')

    def compute(self, inputs, outputs):
        # Scalar input is broadcast across all `size` output entries.
        outputs['y'] = self.mult * inputs['x']

    def compute_partials(self, inputs, partials):
        # Each output entry has the same derivative w.r.t. x: the multiplier.
        partials['y', 'x'] = self.mult

    def _apply_linear(self, jac, vec_names, rel_systems, mode, scope_out=None, scope_in=None):
        # Artificial cost so that running the two instances in parallel is
        # presumably observable as a speed-up.
        time.sleep(self.delay)
        super()._apply_linear(jac, vec_names, rel_systems, mode, scope_out, scope_in)


class PartialDependGroup(om.Group):
    """
    Fan-out model with one cheap serial component feeding two slow components.

    'Comp1' (a SumComp) sums its 4-entry input. Its scalar output is connected
    to both 'ParallelGroup1.Con1' and 'ParallelGroup1.Con2' (SlowComps inside
    a ParallelGroup). The outputs of those two components are added as
    constraints with the same ``parallel_deriv_color``, so their total
    derivatives can be solved for in parallel.
    """

    def setup(self):
        size = 4        # length of Comp1.x
        delay = .1      # seconds each SlowComp sleeps in _apply_linear
        color = 'parcon'

        self.add_subsystem('Comp1', SumComp(size))
        pargroup = self.add_subsystem('ParallelGroup1', om.ParallelGroup())

        # Default value for the design variable: [1., 2., 3., 4.]
        self.set_input_defaults('Comp1.x', val=np.arange(size, dtype=float) + 1.0)

        # iprint = -1 silences solver output.
        self.linear_solver = om.LinearBlockGS()
        self.linear_solver.options['iprint'] = -1
        pargroup.linear_solver = om.LinearBlockGS()
        pargroup.linear_solver.options['iprint'] = -1

        pargroup.add_subsystem('Con1', SlowComp(delay=delay, size=2, mult=2.0))
        pargroup.add_subsystem('Con2', SlowComp(delay=delay, size=2, mult=-3.0))

        # Fan-out: the single serial output feeds both parallel components.
        self.connect('Comp1.y', 'ParallelGroup1.Con1.x')
        self.connect('Comp1.y', 'ParallelGroup1.Con2.x')

        self.add_design_var('Comp1.x')
        # Same color on both constraints => their derivatives are solved in parallel.
        self.add_constraint('ParallelGroup1.Con1.y', lower=0.0, parallel_deriv_color=color)
        self.add_constraint('ParallelGroup1.Con2.y', upper=0.0, parallel_deriv_color=color)

Run script

%%px

import numpy as np

import openmdao.api as om
from openmdao.core.tests.test_parallel_derivatives import PartialDependGroup

# Both constraints are declared with the same parallel_deriv_color in the model.
of = ['ParallelGroup1.Con1.y', 'ParallelGroup1.Con2.y']
wrt = ['Comp1.x']

p = om.Problem(model=PartialDependGroup())
# Reverse mode: derivatives are solved per constraint, which is where a shared
# parallel_deriv_color lets the two constraints be solved simultaneously.
p.setup(mode='rev')
p.run_model()

J = p.compute_totals(of, wrt, return_format='dict')

# Print the first row of each Jacobian block d(of)/d(Comp1.x); every rank prints.
for name in of:
    print(J[name]['Comp1.x'][0])
[stdout:0] 
[2. 2. 2. 2.]
[-3. -3. -3. -3.]
[stdout:1] 
[2. 2. 2. 2.]
[-3. -3. -3. -3.]
[stdout:2] 
[2. 2. 2. 2.]
[-3. -3. -3. -3.]
[stdout:3] 
[2. 2. 2. 2.]
[-3. -3. -3. -3.]