Source code for gordo.server.model_io

# -*- coding: utf-8 -*-

import logging

import numpy as np

from sklearn.pipeline import Pipeline

"""
Collection of functions to work with the input and output of a scikit.base.BaseEstimator
"""

logger = logging.getLogger(__name__)


[docs]def get_model_output(model: Pipeline, X: np.ndarray) -> np.ndarray: """ Get the raw output from the current model given X. Will try to `predict` and then `transform`, raising an error if both fail. Parameters ---------- X: np.ndarray 2d array of sample(s) Returns ------- np.ndarray The raw output of the model in numpy array form. """ try: return model.predict(X) # type: ignore # Model may only be a transformer except AttributeError: try: return model.transform(X) # type: ignore except Exception as exc: logger.error(f"Failed to predict or transform; error: {exc}") raise