Kernel ridge with cross-validation

This example shows how to fit a kernel ridge regression model while selecting its regularization parameter alpha by cross-validation, using himalaya’s estimator KernelRidgeCV.
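For reference, with a linear kernel the fit has a closed form. The dual coefficients C solve (K + alpha I) C = Y, where K = X X^T is the Gram matrix of the training samples, and predictions on new samples are K_test C, with one column of C per target. The sketch below checks this against scikit-learn's KernelRidge, which fits no intercept. The random data, the held-out samples and the value alpha = 1.0 are arbitrary choices made only for illustration.

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge

rng = np.random.RandomState(0)
X_train = rng.randn(10, 20)   # illustrative data, not from the example
Y_train = rng.randn(10, 4)
X_test = rng.randn(3, 20)
alpha = 1.0                   # arbitrary regularization strength

# Linear kernel: inner products between samples.
K_train = X_train @ X_train.T  # shape (n_train, n_train)
K_test = X_test @ X_train.T    # shape (n_test, n_train)

# Dual coefficients solve (K + alpha * I) C = Y, one column per target.
dual_coef = np.linalg.solve(K_train + alpha * np.eye(len(K_train)), Y_train)
Y_pred = K_test @ dual_coef

# Compare with scikit-learn (agreement up to floating-point tolerance).
ref = KernelRidge(kernel="linear", alpha=alpha).fit(X_train, Y_train)
print(np.allclose(dual_coef, ref.dual_coef_))
print(np.allclose(Y_pred, ref.predict(X_test)))
```

Both checks compare the hand-written solution with scikit-learn's; the point is that a single linear solve handles all targets at once for a given alpha.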

Create a random dataset

import numpy as np
np.random.seed(0)
n_samples, n_features, n_targets = 10, 20, 4
X = np.random.randn(n_samples, n_features)
Y = np.random.randn(n_samples, n_targets)

Limit of GridSearchCV

In scikit-learn, GridSearchCV can be used to select hyperparameters by cross-validation.

import sklearn.model_selection
import sklearn.kernel_ridge

estimator = sklearn.kernel_ridge.KernelRidge(kernel="linear")
gscv = sklearn.model_selection.GridSearchCV(
    estimator=estimator,
    param_grid=dict(alpha=np.logspace(-2, 2, 5)),
)
gscv.fit(X, Y)
GridSearchCV(estimator=KernelRidge(),
             param_grid={'alpha': array([1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02])})


However, GridSearchCV maximizes the score averaged over all targets (the default R² score of a regressor averages uniformly across outputs), so it returns a single alpha shared by every target.

gscv.best_params_
{'alpha': np.float64(100.0)}
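To make the difference concrete, the sketch below reproduces both selection rules by hand with scikit-learn only. It records the cross-validated R² of every alpha on every target, then picks either one alpha maximizing the target-averaged score (the GridSearchCV rule) or one alpha per target. It assumes the same unshuffled 5-fold KFold split that GridSearchCV uses by default for regressors. It is an illustration of per-target selection, not himalaya's implementation, whose scoring function and solver may differ.

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold

# Same random data as in the example above.
np.random.seed(0)
X = np.random.randn(10, 20)
Y = np.random.randn(10, 4)
alphas = np.logspace(-2, 2, 5)
cv = KFold(n_splits=5)  # assumed to match GridSearchCV's default split

# scores[i, j]: R^2 of alphas[i] on target j, averaged over folds.
scores = np.zeros((len(alphas), Y.shape[1]))
for train, val in cv.split(X):
    for i, alpha in enumerate(alphas):
        fold_model = KernelRidge(kernel="linear", alpha=alpha)
        fold_model.fit(X[train], Y[train])
        scores[i] += r2_score(Y[val], fold_model.predict(X[val]),
                              multioutput="raw_values")
scores /= cv.get_n_splits()

# GridSearchCV-style rule: one alpha for the score averaged over targets.
shared_alpha = alphas[np.argmax(scores.mean(axis=1))]
# Per-target rule: each target keeps the alpha that scores best for it.
per_target_alphas = alphas[np.argmax(scores, axis=0)]
print(shared_alpha, per_target_alphas)
```

Because the folds and scoring are the same, shared_alpha should coincide with gscv.best_params_; per_target_alphas is not guaranteed to equal himalaya's best_alphas_, since the scoring details may differ.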

KernelRidgeCV

To select alpha independently for each target, himalaya implements KernelRidgeCV, which accepts any cross-validation scheme compatible with scikit-learn.

import himalaya.kernel_ridge

model = himalaya.kernel_ridge.KernelRidgeCV(kernel="linear",
                                            alphas=np.logspace(-2, 2, 5))
model.fit(X, Y)
KernelRidgeCV(alphas=array([1.e-02, 1.e-01, 1.e+00, 1.e+01, 1.e+02]))


KernelRidgeCV selects a separate best alpha for each target and stores them in the best_alphas_ attribute.

model.best_alphas_
array([1.e+00, 1.e-02, 1.e+02, 1.e+02])
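Scoring every alpha requires refitting once per alpha and per fold, although each fit covers all targets in a single solve. A standard shortcut, shown here as our own illustration rather than as himalaya's code, is to eigendecompose the kernel once: if K = V diag(λ) V^T, then (K + alpha I)^{-1} Y = V diag(1 / (λ + alpha)) V^T Y, so each additional alpha costs only matrix products. We assume, without quoting himalaya's source, that its default solver relies on a similar decomposition. Inside a cross-validation loop, the decomposition would be computed once per training fold.

```python
import numpy as np

# Same random data and alpha grid as in the example above.
np.random.seed(0)
X = np.random.randn(10, 20)
Y = np.random.randn(10, 4)
alphas = np.logspace(-2, 2, 5)

K = X @ X.T                     # linear kernel, symmetric positive semi-definite
eigvals, V = np.linalg.eigh(K)  # K = V @ diag(eigvals) @ V.T
VtY = V.T @ Y                   # project the targets once

# 1 / (eigvals + alpha) for every alpha: shape (n_alphas, n_samples).
inv = 1.0 / (eigvals[None, :] + alphas[:, None])
# dual_coefs[a] = V @ diag(inv[a]) @ VtY, shape (n_alphas, n_samples, n_targets).
dual_coefs = np.einsum("ij,aj,jt->ait", V, inv, VtY)

# Sanity check against one direct solve per alpha.
for alpha, C in zip(alphas, dual_coefs):
    assert np.allclose(C, np.linalg.solve(K + alpha * np.eye(len(K)), Y))
```

The eigendecomposition costs about as much as one direct solve, so the saving grows with the number of alphas in the grid.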
