"""Source code for ``smt_optim.surrogate_models.base``: abstract base class for surrogate models."""

from abc import ABC, abstractmethod

import numpy as np


class Surrogate(ABC):
    """
    Abstract class for surrogate models.

    Concrete surrogates must override :meth:`train`, :meth:`predict_values`
    and :meth:`predict_variances`. All three are declared with
    ``@abstractmethod``, so a subclass that leaves any of them undefined
    cannot be instantiated.
    """

    def __init__(self):
        # No shared state is created here; the constructor exists so that
        # subclasses can call ``super().__init__()`` unconditionally.
        pass

    @abstractmethod
    def train(self, xt: list[np.ndarray], yt: list[np.ndarray], **kwargs) -> None:
        """
        Fit the surrogate model to training data.

        Parameters
        ----------
        xt : list[np.ndarray]
            Training inputs, given as a list of arrays. Presumably one array
            per data source / fidelity level; confirm against the concrete
            subclasses.
        yt : list[np.ndarray]
            Training outputs, presumably paired element-wise with ``xt``.
        **kwargs
            Additional options interpreted by the concrete implementation.

        Raises
        ------
        NotImplementedError
            If this base implementation is reached (e.g. through ``super()``).
            It subclasses ``Exception``, so broad handlers still catch it.
        """
        raise NotImplementedError("train() method not implemented.")

    @abstractmethod
    def predict_values(self, x_pred: np.ndarray) -> np.ndarray:
        """
        Predict the surrogate's output values at the given points.

        Parameters
        ----------
        x_pred : np.ndarray
            Points at which to evaluate the surrogate.

        Returns
        -------
        np.ndarray
            Predicted values; the exact shape is defined by the subclass.

        Raises
        ------
        NotImplementedError
            If this base implementation is reached (e.g. through ``super()``).
        """
        raise NotImplementedError("predict_values() method not implemented.")

    @abstractmethod
    def predict_variances(self, x_pred: np.ndarray) -> np.ndarray:
        """
        Predict the surrogate's prediction variances at the given points.

        Parameters
        ----------
        x_pred : np.ndarray
            Points at which to evaluate the variance.

        Returns
        -------
        np.ndarray
            Predicted variances; the exact shape is defined by the subclass.

        Raises
        ------
        NotImplementedError
            If this base implementation is reached (e.g. through ``super()``).
        """
        raise NotImplementedError("predict_variances() method not implemented.")