# Source code for smt_optim.surrogate_models.base
from abc import ABC, abstractmethod
import numpy as np
class Surrogate(ABC):
    """
    Abstract class for surrogate models.

    Concrete surrogates must implement ``train``, ``predict_values`` and
    ``predict_variances``. Because all three are declared abstract,
    ``Surrogate`` itself cannot be instantiated (Python raises ``TypeError``).
    The base implementations only raise ``NotImplementedError``. This is what
    a subclass reaches if it delegates to them through ``super()``.
    """

    def __init__(self):
        # No shared state yet. The method exists so that subclasses can
        # safely call ``super().__init__()``.
        pass

    @abstractmethod
    def train(self, xt: list[np.ndarray], yt: list[np.ndarray], **kwargs) -> None:
        """
        Fit the surrogate to training data.

        Parameters
        ----------
        xt : list[np.ndarray]
            Training inputs, given as a list of arrays. NOTE(review): this is
            presumably one array per data source / fidelity level; confirm
            against the concrete subclasses.
        yt : list[np.ndarray]
            Training outputs, given as a list of arrays. This is presumably
            parallel to ``xt``; confirm with the implementations.
        **kwargs
            Model-specific options. Their interpretation is left to subclasses.

        Raises
        ------
        NotImplementedError
            If this base implementation is called.
        """
        raise NotImplementedError("train() method not implemented.")

    @abstractmethod
    def predict_values(self, x_pred: np.ndarray) -> np.ndarray:
        """
        Predict the surrogate's output values at the query points ``x_pred``.

        Parameters
        ----------
        x_pred : np.ndarray
            Query points. The expected layout (e.g. one point per row) is
            defined by the concrete subclass.

        Returns
        -------
        np.ndarray
            Predicted values at ``x_pred``.

        Raises
        ------
        NotImplementedError
            If this base implementation is called.
        """
        # The message names the actual method (the original said "predict_value()").
        raise NotImplementedError("predict_values() method not implemented.")

    @abstractmethod
    def predict_variances(self, x_pred: np.ndarray) -> np.ndarray:
        """
        Predict the variance of the surrogate's prediction at ``x_pred``.

        Parameters
        ----------
        x_pred : np.ndarray
            Query points. The expected layout is defined by the concrete
            subclass.

        Returns
        -------
        np.ndarray
            Predicted variances at ``x_pred``.

        Raises
        ------
        NotImplementedError
            If this base implementation is called.
        """
        # The message names the actual method (the original said "predict_variance()").
        raise NotImplementedError("predict_variances() method not implemented.")