Source code for smt_optim.utils.plot_2d

import numpy as np
import matplotlib.pyplot as plt

from typing import Callable

[docs] def get_plot2d_data(func: Callable, bounds: np.ndarray, num_points: int = 101) -> tuple: X = np.linspace(bounds[0, 0], bounds[0, 1], num_points) Y = np.linspace(bounds[1, 0], bounds[1, 1], num_points) XX, YY = np.meshgrid(X, Y) data = np.vstack((XX.ravel(), YY.ravel())).T z = np.empty(data.shape[0]) for i in range(data.shape[0]): z[i] = func(data[i, :]) Z = z.reshape(XX.shape) return XX, YY, Z