Source code for openmdao.jax.ks
"""
jax implementations of the Kreisselmeier-Steinhauser (KS) function for the min and max values in an array.
"""
try:
import jax
from jax import jit
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
except (ImportError, ModuleNotFoundError):
jax = None
from openmdao.utils.jax_utils import jit_stub as jit
# BibTeX entry for Martins and Poon (2005), "On Structural Optimization Using Constraint
# Aggregation", kept as a module-level string.
# NOTE(review): presumably collected by a citation-reporting facility elsewhere in the
# project -- confirm against the consumer before changing the text or its formatting.
CITATIONS = """
@conference {Martins:2005:SOU,
title = {On Structural Optimization Using Constraint Aggregation},
booktitle = {Proceedings of the 6th World Congress on Structural and Multidisciplinary
Optimization},
year = {2005},
month = {May},
address = {Rio de Janeiro, Brazil},
author = {Joaquim R. R. A. Martins and Nicholas M. K. Poon}
}
"""
@jit
def ks_max(x, rho=100.0):
    """
    Compute a differentiable maximum value in an array.

    Given some array of values `x`, compute a differentiable, _conservative_ maximum using the
    Kreisselmeier-Steinhauser function. For positive `rho` the result is never smaller than
    the true maximum, because the summed exponentials always include exp(0) == 1 for the
    largest entry, making the logarithmic correction non-negative.

    Parameters
    ----------
    x : ndarray
        Array of values.
    rho : float
        Aggregation Factor. Larger values of rho more closely match the true maximum value.

    Returns
    -------
    float
        A conservative approximation to the maximum value in x.
    """
    # Shift by the true maximum so every exponent is <= 0 (for rho > 0); this keeps exp()
    # from overflowing for large rho or large entries of x.
    x_max = jnp.max(x)
    x_diff = x - x_max
    exponents = jnp.exp(rho * x_diff)
    summation = jnp.sum(exponents)
    # Add the shift back and apply the log-sum-exp correction scaled by 1/rho.
    return x_max + 1.0 / rho * jnp.log(summation)
@jit
def ks_min(x, rho=100.0):
    """
    Compute a smooth, differentiable stand-in for the smallest value in an array.

    The Kreisselmeier-Steinhauser function is used to aggregate the entries of `x` into a
    single _conservative_ estimate of the minimum: for positive `rho` the result is never
    larger than the true minimum.

    Parameters
    ----------
    x : ndarray
        Array of values.
    rho : float
        Aggregation Factor. Larger values of rho more closely match the true minimum value.

    Returns
    -------
    float
        A conservative approximation to the minimum value in x.
    """
    lowest = jnp.min(x)
    # Offsets from the true minimum are all <= 0, so (for rho > 0) exp() cannot overflow.
    scaled_offsets = rho * (lowest - x)
    log_total = jnp.log(jnp.sum(jnp.exp(scaled_offsets)))
    return lowest - 1.0 / rho * log_total