Distributed Variables#

When you need to perform a computation on large input arrays, you may want to split that computation across multiple processes so that each process operates on a subset of the input values. This may be done purely for performance reasons, or it may be necessary because the entire input will not fit in the memory of a single machine. In either case, this can be accomplished in OpenMDAO by declaring those inputs and outputs as distributed. By definition, a distributed variable is an input or output where each process contains only a part of the whole variable. Distributed variables are declared by setting the optional “distributed” argument to True when adding the variable to a component. A component that has at least one distributed variable can also be called a distributed component.

Any variable that is not distributed is called a non-distributed variable. When the model is run under MPI, every process contains a copy of the entire non-distributed variable. We also call these duplicated variables.

We’ve already seen that by using src_indices, we can connect an input to only a subset of an output variable. By giving different values for src_indices in each MPI process, we can distribute computations on a distributed output across the processes. All of the scenarios that involve connecting distributed and non-distributed variables are detailed in Connections involving distributed variables.
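As a rough sketch of that idea, the snippet below builds a small model in which each process requests a different slice of a non-distributed source via src_indices. The DoubleLocal component and the ‘indep’/‘comp’ names are illustrative placeholders, not part of the examples that follow, and the precise connection rules are covered in Connections involving distributed variables. It is intended to be run under MPI, e.g. with mpiexec -n 4.

import numpy as np

import openmdao.api as om
from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI


class DoubleLocal(om.ExplicitComponent):
    """Doubles whatever local slice of the input this process receives."""

    def initialize(self):
        self.options.declare('full_size', types=int)

    def setup(self):
        sizes, _ = evenly_distrib_idxs(self.comm.size, self.options['full_size'])
        nlocal = sizes[self.comm.rank]

        self.add_input('x', np.zeros(nlocal), distributed=True)
        self.add_output('y', np.zeros(nlocal), distributed=True)

    def compute(self, inputs, outputs):
        outputs['y'] = 2.0 * inputs['x']


full_size = 7
rank = MPI.COMM_WORLD.rank if MPI else 0
nprocs = MPI.COMM_WORLD.size if MPI else 1
sizes, offsets = evenly_distrib_idxs(nprocs, full_size)

prob = om.Problem()
prob.model.add_subsystem('indep', om.IndepVarComp('x', np.arange(full_size, dtype=float)))
prob.model.add_subsystem('comp', DoubleLocal(full_size=full_size))

# Each rank requests only its own slice of the non-distributed output 'indep.x'.
local_idxs = np.arange(offsets[rank], offsets[rank] + sizes[rank])
prob.model.connect('indep.x', 'comp.x', src_indices=local_idxs)

prob.setup()
prob.run_model()

# Each rank prints only its doubled local slice.
print(rank, prob.get_val('comp.y'))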

Note

This feature requires MPI, and may not be runnable on Colab or Binder.

Example: Simple Component with Distributed Input and Output#

The following example shows how to create a simple component, SimpleDistrib, that takes a distributed variable as an input and computes a distributed output. The calculation is divided across the available processes, but the details of that division are not contained in the component. In fact, the input is sized based on its connected source using the “shape_by_conn” argument.

%%px 

import numpy as np

import openmdao.api as om


class SimpleDistrib(om.ExplicitComponent):

    def setup(self):

        # Distributed Input
        self.add_input('in_dist', shape_by_conn=True, distributed=True)

        # Distributed Output
        self.add_output('out_dist', copy_shape='in_dist', distributed=True)

    def compute(self, inputs, outputs):
        x = inputs['in_dist']

        # "Computationally Intensive" operation that we wish to parallelize.
        f_x = x**2 - 2.0*x + 4.0

        outputs['out_dist'] = f_x

In the next part of the example, we take the SimpleDistrib component, place it into a model, and run it. Suppose the vector of data we want to process has 7 elements and we have 4 processors available for the computation. If we distribute the elements as evenly as we can, 3 processors handle 2 elements each and the 4th picks up the last one. OpenMDAO provides the evenly_distrib_idxs utility function, which computes the sizes and offsets for all ranks. The sizes determine how much of the array to allocate on each rank, and the offsets give the index where each rank’s local portion starts; in this example, the offsets are used to set the initial value properly. Here, the initial value of the full distributed input “in_dist” is a vector of 7 values between 3.0 and 9.0, and each processor holds a 1- or 2-element piece of it.
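Before diving into the full model, here is a quick standalone illustration (separate from the model below) of what evenly_distrib_idxs returns for 4 processes and 7 total elements; the expected values follow directly from the even split described above.

from openmdao.utils.array_utils import evenly_distrib_idxs

sizes, offsets = evenly_distrib_idxs(4, 7)
print(sizes)    # number of local elements on each rank: [2 2 2 1]
print(offsets)  # index where each rank's slice begins:  [0 2 4 6]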

%%px

from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI

size = 7

if MPI:
    comm = MPI.COMM_WORLD
    rank = comm.rank
    sizes, offsets = evenly_distrib_idxs(comm.size, size)
else:
    # When running without MPI, the entire variable is on one proc.
    rank = 0
    sizes = {rank : size}
    offsets = {rank : 0}

prob = om.Problem()
model = prob.model

# Create a distributed source for the distributed input.
ivc = om.IndepVarComp()
ivc.add_output('x_dist', np.zeros(sizes[rank]), distributed=True)

model.add_subsystem("indep", ivc)
model.add_subsystem("D1", SimpleDistrib())

model.connect('indep.x_dist', 'D1.in_dist')

prob.setup()

# Set initial values of distributed variable.
x_dist_init = 3.0 + np.arange(size)[offsets[rank]:offsets[rank] + sizes[rank]]
prob.set_val('indep.x_dist', x_dist_init)

prob.run_model()

# Values on each rank.
for var in ['indep.x_dist', 'D1.out_dist']:
    print(var, prob.get_val(var))
    
# Full gathered values.
for var in ['indep.x_dist', 'D1.out_dist']:
    print(var, prob.get_val(var, get_remote=True))
print('')
[stdout:2] indep.x_dist [7. 8.]
D1.out_dist [39. 52.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]
[stdout:3] indep.x_dist [9.]
D1.out_dist [67.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]
[stdout:1] indep.x_dist [5. 6.]
D1.out_dist [19. 28.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]
[stdout:0] indep.x_dist [3. 4.]
D1.out_dist [ 7. 12.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]

Note that we created a connection source ‘x_dist’ that passes its value to ‘D1.in_dist’. OpenMDAO requires a source for non-constant inputs, and usually creates one automatically as an output of a component referred to as an ‘Auto-IVC’. However, the automatic creation is not supported for distributed variables. We must manually create an IndepVarComp and connect it to our input.

When using distributed variables, OpenMDAO can’t always size the component inputs based on the shape of the connected source. In the following example, the component determines its own split using evenly_distrib_idxs. This requires that the component know the full vector size, which is passed in via the option ‘vec_size’.

%%px

import numpy as np

import openmdao.api as om
from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI

class SimpleDistrib(om.ExplicitComponent):

    def initialize(self):
        self.options.declare('vec_size', types=int, default=1,
                             desc="Total size of vector.")

    def setup(self):
        comm = self.comm
        rank = comm.rank

        size = self.options['vec_size']
        sizes, _ = evenly_distrib_idxs(comm.size, size)
        mysize = sizes[rank]

        # Distributed Input
        self.add_input('in_dist', np.ones(mysize, float), distributed=True)

        # Distributed Output
        self.add_output('out_dist', np.ones(mysize, float), distributed=True)

    def compute(self, inputs, outputs):
        x = inputs['in_dist']

        # "Computationally Intensive" operation that we wish to parallelize.
        f_x = x**2 - 2.0*x + 4.0

        outputs['out_dist'] = f_x


size = 7

if MPI:
    comm = MPI.COMM_WORLD
    rank = comm.rank
    sizes, offsets = evenly_distrib_idxs(comm.size, size)
else:
    # When running without MPI, the entire variable is on one proc.
    rank = 0
    sizes = {rank : size}
    offsets = {rank : 0}

prob = om.Problem()
model = prob.model

# Create a distributed source for the distributed input.
ivc = om.IndepVarComp()
ivc.add_output('x_dist', np.zeros(sizes[rank]), distributed=True)

model.add_subsystem("indep", ivc)
model.add_subsystem("D1", SimpleDistrib(vec_size=size))

model.connect('indep.x_dist', 'D1.in_dist')

prob.setup()

# Set initial values of distributed variable.
x_dist_init = 3.0 + np.arange(size)[offsets[rank]:offsets[rank] + sizes[rank]]
prob.set_val('indep.x_dist', x_dist_init)

prob.run_model()

# Values on each rank.
for var in ['indep.x_dist', 'D1.out_dist']:
    print(var, prob.get_val(var))

# Full gathered values.
for var in ['indep.x_dist', 'D1.out_dist']:
    print(var, prob.get_val(var, get_remote=True))
print('')
[stdout:2] indep.x_dist [7. 8.]
D1.out_dist [39. 52.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]
[stdout:1] indep.x_dist [5. 6.]
D1.out_dist [19. 28.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]
[stdout:0] indep.x_dist [3. 4.]
D1.out_dist [ 7. 12.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]
[stdout:3] indep.x_dist [9.]
D1.out_dist [67.]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
D1.out_dist [ 7. 12. 19. 28. 39. 52. 67.]

Example: Distributed I/O and a Non-Distributed Input#

OpenMDAO supports both non-distributed and distributed I/O on the same component, so in this example, we expand the problem to include a non-distributed input. In this case, the non-distributed input also has a length of 7, but its values are the same on every processor. This non-distributed input is included in the computation by taking the sum of the square roots of its elements and adding that sum to the distributed output.

%%px 

import numpy as np

import openmdao.api as om
from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI


class MixedDistrib1(om.ExplicitComponent):

    def setup(self):

        # Distributed Input
        self.add_input('in_dist', shape_by_conn=True, distributed=True)

        # Non-Distributed Input
        self.add_input('in_nd', shape_by_conn=True)

        # Distributed Output
        self.add_output('out_dist', copy_shape='in_dist', distributed=True)

    def compute(self, inputs, outputs):
        Id = inputs['in_dist']
        Ind = inputs['in_nd']

        # "Computationally Intensive" operation that we wish to parallelize.
        f_Id = Id**2 - 2.0*Id + 4.0

        # This operation is repeated on all procs.
        f_Ind = Ind ** 0.5
        
        outputs['out_dist'] = f_Id + np.sum(f_Ind)
        
size = 7

if MPI:
    comm = MPI.COMM_WORLD
    rank = comm.rank
    sizes, offsets = evenly_distrib_idxs(comm.size, size)
else:
    # When running without MPI, the entire variable is on one proc.
    rank = 0
    sizes = {rank : size}
    offsets = {rank : 0}

prob = om.Problem()
model = prob.model

# Create a distributed source for the distributed input.
ivc = om.IndepVarComp()
ivc.add_output('x_dist', np.zeros(sizes[rank]), distributed=True)
ivc.add_output('x_nd', np.zeros(size))

model.add_subsystem("indep", ivc)
model.add_subsystem("D1", MixedDistrib1())

model.connect('indep.x_dist', 'D1.in_dist')
model.connect('indep.x_nd', 'D1.in_nd')

prob.setup()

# Set initial values of distributed variable.
x_dist_init = 3.0 + np.arange(size)[offsets[rank]:offsets[rank] + sizes[rank]]
prob.set_val('indep.x_dist', x_dist_init)

# Set initial values of non-distributed variable.
x_nd_init = 1.0 + 2.0*np.arange(size)
prob.set_val('indep.x_nd', x_nd_init)

prob.run_model()

# Values on each rank.
for var in ['indep.x_dist', 'indep.x_nd', 'D1.out_dist']:
    print(var, prob.get_val(var))
    
# Full gathered values.
for var in ['indep.x_dist', 'indep.x_nd', 'D1.out_dist']:
    print(var, prob.get_val(var, get_remote=True))
print('')
[stdout:1] indep.x_dist [5. 6.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [36.53604616 45.53604616]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [24.53604616 29.53604616 36.53604616 45.53604616 56.53604616 69.53604616
 84.53604616]
[stdout:2] indep.x_dist [7. 8.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [56.53604616 69.53604616]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [24.53604616 29.53604616 36.53604616 45.53604616 56.53604616 69.53604616
 84.53604616]
[stdout:3] indep.x_dist [9.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [84.53604616]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [24.53604616 29.53604616 36.53604616 45.53604616 56.53604616 69.53604616
 84.53604616]
[stdout:0] indep.x_dist [3. 4.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [24.53604616 29.53604616]
indep.x_dist [3. 4. 5. 6. 7. 8. 9.]
indep.x_nd [ 1.  3.  5.  7.  9. 11. 13.]
D1.out_dist [24.53604616 29.53604616 36.53604616 45.53604616 56.53604616 69.53604616
 84.53604616]

Example: Distributed I/O and a Non-Distributed Output#

You can also create a component that has distributed inputs and outputs along with a non-distributed output. This situation tends to be trickier and usually requires you to perform some MPI operations in your component’s compute method. If the non-distributed output is only a function of the non-distributed inputs, then you can handle that variable just as you would on any other component. However, this example extends the previous component to include a non-distributed output that is a function of both the non-distributed and distributed inputs. In this case, it is a function of the sum of the square roots of the elements in the full distributed vector. Since that data is not all on any one processor, we use an MPI operation, in this case Allreduce, to sum across the distributed vector and make the result available on every processor. The MPI operation and your implementation will vary, but consider this to be a general example.
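As a standalone illustration of that reduction pattern (using mpi4py directly rather than inside an OpenMDAO component, with illustrative variable names), the script below sums the square roots of a 7-element vector that is split across the ranks, so that every rank ends up with the same total. Run it with, e.g., mpiexec -n 4.

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Split a 7-element vector across the available ranks; each rank keeps a slice.
full = 3.0 + np.arange(7)
local = np.array_split(full, comm.size)[comm.rank]

local_sum = np.sum(local ** 0.5)                   # partial sum on this process
total_sum = comm.allreduce(local_sum, op=MPI.SUM)  # identical total on every rank

print(comm.rank, total_sum)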

Note

In this example, we introduce a new component called an IndepVarComp. If you used OpenMDAO prior to version 3.2, then you are familiar with this component. It is used to define an independent variable.

You usually do not have to define these because OpenMDAO defines and uses them automatically for all unconnected inputs in your model. This automatically-created IndepVarComp is called an Auto-IVC.

However, when we define a distributed input, it sometimes isn’t possible to determine the full size of the corresponding independent variable, and the IndepVarComp cannot be created automatically. So, for unconnected inputs on a distributed component, you must manually create one, as we did in this example.

Derivatives with Distributed Variables#

In the following examples, we show how to add analytic derivatives to the distributed examples given above. In most cases it is straightforward, but there is a special case when a non-distributed output depends on a distributed input or vice versa:

When you have a distributed output that depends on a non-distributed input, or a non-distributed output that depends on a distributed input, the matrix-free format is required, meaning that you will need to define a compute_jacvec_product method if your component is an ExplicitComponent, or an apply_linear method if your component is an ImplicitComponent. The reasons for this are a bit subtle, but they boil down to the fact that OpenMDAO does not know what kind of distributed operations are being done in the compute, so it has no way to manage those details when propagating derivatives in reverse.
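For reference, here is a bare-bones sketch of the two matrix-free entry points just mentioned; the class names are placeholders and the bodies are stubs rather than a working implementation.

import openmdao.api as om


class MyExplicit(om.ExplicitComponent):

    def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
        if mode == 'fwd':
            pass  # accumulate contributions from d_inputs into d_outputs
        else:     # 'rev'
            pass  # accumulate contributions from d_outputs into d_inputs


class MyImplicit(om.ImplicitComponent):

    def apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
        pass  # fwd: fill d_residuals; rev: fill d_inputs and d_outputs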

An important thing to keep in mind when deciding whether you need some sort of reduce operation while computing derivatives with the matrix-free API and mixed distributed/non-distributed variables is that the derivatives of any non-distributed variable must be the same across all ranks.

Derivatives: Distributed I/O and a Non-Distributed Input#

In this example, we have a distributed input, a distributed output, and a non-distributed input. Because the distributed output ‘out_dist’ depends on the non-distributed input ‘in_nd’, we have to use the matrix-free API and define a compute_jacvec_product method for our component, which is described in the feature document for ExplicitComponent. Note that for this component we only have to do an allreduce when computing the Jacobian-vector product in reverse mode. That is because, in reverse mode, we have to sum up the contributions from the entire distributed output ‘out_dist’ to the non-distributed input ‘in_nd’ so that the derivatives for ‘in_nd’ are the same across all ranks. We verify that the derivatives are correct using check_totals with complex step, since our component is complex-safe.

%%px 

import numpy as np

import openmdao.api as om
from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI


class MixedDistrib1(om.ExplicitComponent):

    def setup(self):

        # Distributed Input
        self.add_input('in_dist', shape_by_conn=True, distributed=True)

        # Non-Distributed Input
        self.add_input('in_nd', shape_by_conn=True)

        # Distributed Output
        self.add_output('out_dist', copy_shape='in_dist', distributed=True)

    def compute(self, inputs, outputs):
        Id = inputs['in_dist']
        Ind = inputs['in_nd']

        # "Computationally Intensive" operation that we wish to parallelize.
        f_Id = Id**2 - 2.0*Id + 4.0

        # This operation is repeated on all procs.
        f_Ind = Ind ** 0.5

        outputs['out_dist'] = f_Id + np.sum(f_Ind)

    def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
        Id = inputs['in_dist']
        Ind = inputs['in_nd']

        df_dId = 2.0 * Id - 2.0
        df_dInd = 0.5 / Ind ** 0.5

        nId = len(Id)
        nInd = len(Ind)

        if mode == 'fwd':
            if 'out_dist' in d_outputs:
                if 'in_nd' in d_inputs:
                    d_outputs['out_dist'] += np.tile(df_dInd, nId).reshape((nId, nInd)).dot(d_inputs['in_nd'])
                if 'in_dist' in d_inputs:
                    d_outputs['out_dist'] += df_dId * d_inputs['in_dist']

        else:  # rev
            if 'out_dist' in d_outputs:
                if 'in_nd' in d_inputs:
                    d_inputs['in_nd'] += self.comm.allreduce(np.tile(df_dInd, nId).reshape((nId, nInd)).T.dot(d_outputs['out_dist']))

                if 'in_dist' in d_inputs:
                    d_inputs['in_dist'] += df_dId * d_outputs['out_dist']


size = 7

if MPI:
    comm = MPI.COMM_WORLD
    rank = comm.rank
    sizes, offsets = evenly_distrib_idxs(comm.size, size)
else:
    # When running without MPI, the entire variable is on one proc.
    rank = 0
    sizes = {rank : size}
    offsets = {rank : 0}

prob = om.Problem()
model = prob.model

# Create a distributed source for the distributed input.
ivc = om.IndepVarComp()
ivc.add_output('x_dist', np.zeros(sizes[rank]), distributed=True)
ivc.add_output('x_nd', np.zeros(size))

model.add_subsystem("indep", ivc)
model.add_subsystem("D1", MixedDistrib1())

model.connect('indep.x_dist', 'D1.in_dist')
model.connect('indep.x_nd', 'D1.in_nd')

model.add_design_var('indep.x_nd')
model.add_design_var('indep.x_dist')
model.add_objective('D1.out_dist')

prob.setup(force_alloc_complex=True)

# Set initial values of distributed variable.
x_dist_init = 3.0 + np.arange(size)[offsets[rank]:offsets[rank] + sizes[rank]]
prob.set_val('indep.x_dist', x_dist_init)

# Set initial values of non-distributed variable.
x_nd_init = 1.0 + 2.0*np.arange(size)
prob.set_val('indep.x_nd', x_nd_init)

prob.run_model()

if rank > 0:
    prob.check_totals(method='cs', out_stream=None)
else:
    prob.check_totals(method='cs')
[stderr:1] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib1>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.
[stderr:2] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib1>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.
[stderr:3] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib1>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.
[stdout:0] -----------------
Total Derivatives
-----------------

  Full Model: 'D1.out_dist' wrt 'indep.x_dist'
     Reverse Magnitude: 2.849561e+01
          Fd Magnitude: 2.849561e+01 (cs:None)

    Absolute Error (Jrev - Jfd) : 0.000000e+00

    Relative Error (Jrev - Jfd) / Jfd : 0.000000e+00

    MPI Rank 0

    Raw Reverse Derivative (Jrev)
    [[ 4.  0.  0.  0.  0.  0.  0.]
     [ 0.  6.  0.  0.  0.  0.  0.]
     [ 0.  0.  8.  0.  0.  0.  0.]
     [ 0.  0.  0. 10.  0.  0.  0.]
     [ 0.  0.  0.  0. 12.  0.  0.]
     [ 0.  0.  0.  0.  0. 14.  0.]
     [ 0.  0.  0.  0.  0.  0. 16.]]

    Raw CS Derivative (Jfd)
    [[ 4.  0.  0.  0.  0.  0.  0.]
     [ 0.  6.  0.  0.  0.  0.  0.]
     [ 0.  0.  8.  0.  0.  0.  0.]
     [ 0.  0.  0. 10.  0.  0.  0.]
     [ 0.  0.  0.  0. 12.  0.  0.]
     [ 0.  0.  0.  0.  0. 14.  0.]
     [ 0.  0.  0.  0.  0.  0. 16.]]

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  Full Model: 'D1.out_dist' wrt 'indep.x_nd'
     Reverse Magnitude: 1.849725e+00
          Fd Magnitude: 1.849725e+00 (cs:None)

    Absolute Error (Jrev - Jfd) : 1.642042e-16

    Relative Error (Jrev - Jfd) / Jfd : 8.877220e-17

    MPI Rank 0

    Raw Reverse Derivative (Jrev)
    [[0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]]

    Raw CS Derivative (Jfd)
    [[0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]]

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
[stderr:0] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib1>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.

Derivatives: Mixed Distributed and Non-Distributed I/O#

The following example shows how to implement derivatives on the MixedDistrib2 component described earlier, which is similar to MixedDistrib1 but adds a non-distributed output, ‘out_nd’, to the mix. In reverse mode, this component uses the allreduce operation to combine the contributions to its non-distributed input in_nd from its distributed output out_dist. In forward mode, it uses allreduce to combine the contributions to its non-distributed output out_nd from its distributed input in_dist.

%%px


import numpy as np

import openmdao.api as om
from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI


class MixedDistrib2(om.ExplicitComponent):

    def setup(self):

        # Distributed Input
        self.add_input('in_dist', shape_by_conn=True, distributed=True)

        # Non-Distributed Input
        self.add_input('in_nd', shape_by_conn=True)

        # Distributed Output
        self.add_output('out_dist', copy_shape='in_dist', distributed=True)

        # Non-Distributed Output
        self.add_output('out_nd', copy_shape='in_nd')

    def compute(self, inputs, outputs):
        Id = inputs['in_dist']
        Ind = inputs['in_nd']

        # "Computationally Intensive" operation that we wish to parallelize.
        f_Id = Id**2 - 2.0*Id + 4.0

        # These operations are repeated on all procs.
        f_Ind = Ind ** 0.5
        g_Ind = Ind**2 + 3.0*Ind - 5.0

        # Compute square root of our portion of the distributed input.
        g_Id = Id ** 0.5

        # Distributed output
        outputs['out_dist'] = f_Id + np.sum(f_Ind)

        # Non-Distributed output
        if MPI and self.comm.size > 1:

            # We need to gather the summed values to compute the total sum over all procs.
            local_sum = np.array(np.sum(g_Id))
            total_sum = local_sum.copy()
            self.comm.Allreduce(local_sum, total_sum, op=MPI.SUM)
            outputs['out_nd'] = g_Ind + total_sum
        else:
            # Recommended to make sure your code can run without MPI too, for testing.
            outputs['out_nd'] = g_Ind + np.sum(g_Id)

    def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
        Id = inputs['in_dist']
        Ind = inputs['in_nd']

        df_dId = 2.0 * Id - 2.0
        df_dInd = 0.5 / Ind ** 0.5
        dg_dId = 0.5 / Id ** 0.5
        dg_dInd = 2.0 * Ind + 3.0

        nId = len(Id)
        nInd = len(Ind)

        if mode == 'fwd':
            if 'out_dist' in d_outputs:
                if 'in_dist' in d_inputs:
                    d_outputs['out_dist'] += df_dId * d_inputs['in_dist']
                if 'in_nd' in d_inputs:
                    d_outputs['out_dist'] += np.tile(df_dInd, nId).reshape((nId, nInd)).dot(d_inputs['in_nd'])
            if 'out_nd' in d_outputs:
                if 'in_dist' in d_inputs:
                    d_outputs['out_nd'] += self.comm.allreduce(np.tile(dg_dId, nInd).reshape((nInd, nId)).dot(d_inputs['in_dist']))
                if 'in_nd' in d_inputs:
                    d_outputs['out_nd'] += dg_dInd * d_inputs['in_nd']

        else:  # rev
            if 'out_dist' in d_outputs:
                if 'in_dist' in d_inputs:
                    d_inputs['in_dist'] += df_dId * d_outputs['out_dist']
                if 'in_nd' in d_inputs:
                    d_inputs['in_nd'] += self.comm.allreduce(np.tile(df_dInd, nId).reshape((nId, nInd)).T.dot(d_outputs['out_dist']))

            if 'out_nd' in d_outputs:
                if 'in_dist' in d_inputs:
                    d_inputs['in_dist'] += np.tile(dg_dId, nInd).reshape((nInd, nId)).T.dot(d_outputs['out_nd'])
                if 'in_nd' in d_inputs:
                    d_inputs['in_nd'] += dg_dInd * d_outputs['out_nd']


size = 7

if MPI:
    comm = MPI.COMM_WORLD
    rank = comm.rank
    sizes, offsets = evenly_distrib_idxs(comm.size, size)
else:
    # When running without MPI, the entire variable is on one proc.
    rank = 0
    sizes = {rank : size}
    offsets = {rank : 0}

prob = om.Problem()
model = prob.model

# Create a distributed source for the distributed input.
ivc = om.IndepVarComp()
ivc.add_output('x_dist', np.zeros(sizes[rank]), distributed=True)
ivc.add_output('x_nd', np.zeros(size))

model.add_subsystem("indep", ivc)
model.add_subsystem("D1", MixedDistrib2())

model.connect('indep.x_dist', 'D1.in_dist')
model.connect('indep.x_nd', 'D1.in_nd')

model.add_design_var('indep.x_nd')
model.add_design_var('indep.x_dist')
model.add_constraint('D1.out_dist', lower=0.0)
model.add_constraint('D1.out_nd', lower=0.0)

prob.setup(force_alloc_complex=True)

# Set initial values of distributed variable.
x_dist_init = 3.0 + np.arange(size)[offsets[rank]:offsets[rank] + sizes[rank]]
prob.set_val('indep.x_dist', x_dist_init)

# Set initial values of non-distributed variable.
x_nd_init = 1.0 + 2.0*np.arange(size)
prob.set_val('indep.x_nd', x_nd_init)

prob.run_model()

if rank > 0:
    prob.check_totals(method='cs', out_stream=None)
else:
    prob.check_totals(method='cs')
[stderr:3] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib2>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.
[stderr:1] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib2>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.
[stderr:2] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib2>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.
[stdout:0] -----------------
Total Derivatives
-----------------

  Full Model: 'D1.out_dist' wrt 'indep.x_dist'
     Forward Magnitude: 2.849561e+01
          Fd Magnitude: 2.849561e+01 (cs:None)

    Absolute Error (Jfor - Jfd) : 0.000000e+00

    Relative Error (Jfor - Jfd) / Jfd : 0.000000e+00

    MPI Rank 0

    Raw Forward Derivative (Jfor)
    [[ 4. -0. -0. -0. -0. -0. -0.]
     [-0.  6. -0. -0. -0. -0. -0.]
     [-0. -0.  8. -0. -0. -0. -0.]
     [-0. -0. -0. 10. -0. -0. -0.]
     [-0. -0. -0. -0. 12. -0. -0.]
     [-0. -0. -0. -0. -0. 14. -0.]
     [-0. -0. -0. -0. -0. -0. 16.]]

    Raw CS Derivative (Jfd)
    [[ 4.  0.  0.  0.  0.  0.  0.]
     [ 0.  6.  0.  0.  0.  0.  0.]
     [ 0.  0.  8.  0.  0.  0.  0.]
     [ 0.  0.  0. 10.  0.  0.  0.]
     [ 0.  0.  0.  0. 12.  0.  0.]
     [ 0.  0.  0.  0.  0. 14.  0.]
     [ 0.  0.  0.  0.  0.  0. 16.]]

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  Full Model: 'D1.out_dist' wrt 'indep.x_nd'
     Forward Magnitude: 1.849725e+00
          Fd Magnitude: 1.849725e+00 (cs:None)

    Absolute Error (Jfor - Jfd) : 1.642042e-16

    Relative Error (Jfor - Jfd) / Jfd : 8.877220e-17

    MPI Rank 0

    Raw Forward Derivative (Jfor)
    [[0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]]

    Raw CS Derivative (Jfd)
    [[0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]
     [0.5        0.28867513 0.2236068  0.18898224 0.16666667 0.15075567 0.13867505]]

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  Full Model: 'D1.out_nd' wrt 'indep.x_dist'
     Forward Magnitude: 1.525023e+00
          Fd Magnitude: 1.525023e+00 (cs:None)

    Absolute Error (Jfor - Jfd) : 1.798767e-16

    Relative Error (Jfor - Jfd) / Jfd : 1.179502e-16

    MPI Rank 0

    Raw Forward Derivative (Jfor)
    [[0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]]

    Raw CS Derivative (Jfd)
    [[0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]
     [0.28867513 0.25       0.2236068  0.20412415 0.18898224 0.1767767  0.16666667]]

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  Full Model: 'D1.out_nd' wrt 'indep.x_nd'
     Forward Magnitude: 4.970915e+01
          Fd Magnitude: 4.970915e+01 (cs:None)

    Absolute Error (Jfor - Jfd) : 3.552714e-15

    Relative Error (Jfor - Jfd) / Jfd : 7.147001e-17

    MPI Rank 0

    Raw Forward Derivative (Jfor)
    [[ 5. -0. -0. -0. -0. -0. -0.]
     [-0.  9. -0. -0. -0. -0. -0.]
     [-0. -0. 13. -0. -0. -0. -0.]
     [-0. -0. -0. 17. -0. -0. -0.]
     [-0. -0. -0. -0. 21. -0. -0.]
     [-0. -0. -0. -0. -0. 25. -0.]
     [-0. -0. -0. -0. -0. -0. 29.]]

    Raw CS Derivative (Jfd)
    [[ 5.  0.  0.  0.  0.  0.  0.]
     [ 0.  9.  0.  0.  0.  0.  0.]
     [ 0.  0. 13.  0.  0.  0.  0.]
     [ 0.  0.  0. 17.  0.  0.  0.]
     [ 0.  0.  0.  0. 21.  0.  0.]
     [ 0.  0.  0.  0.  0. 25.  0.]
     [ 0.  0.  0.  0.  0.  0. 29.]]

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
[stderr:0] /usr/share/miniconda/envs/test/lib/python3.11/site-packages/openmdao/core/component.py:146: OMDeprecationWarning:'D1' <class MixedDistrib2>: It appears this component mixes distributed/non-distributed inputs and outputs, so it may break starting with OpenMDAO 3.25, where the convention used when passing data between distributed and non-distributed inputs and outputs within a matrix free component will change. See https://github.com/OpenMDAO/POEMs/blob/master/POEM_075.md for details.