Multiple-kernel ridge with scikit-learn API

This example demonstrates how to solve multiple-kernel ridge regression using the scikit-learn API.

import numpy as np
import matplotlib.pyplot as plt

from himalaya.backend import set_backend
from himalaya.kernel_ridge import KernelRidgeCV
from himalaya.kernel_ridge import MultipleKernelRidgeCV
from himalaya.kernel_ridge import Kernelizer
from himalaya.kernel_ridge import ColumnKernelizer
from himalaya.utils import generate_multikernel_dataset

from sklearn.pipeline import make_pipeline
from sklearn import set_config
set_config(display='diagram')

In this example, we use the torch_cuda backend.

Torch can perform computations both on CPU and GPU. To use the CPU, use the “torch” backend; to use a GPU, use the “torch_cuda” backend.

backend = set_backend("torch_cuda", on_error="warn")
/home/runner/work/himalaya/himalaya/himalaya/backend/_utils.py:55: UserWarning: Setting backend to torch_cuda failed: PyTorch with CUDA is not available..Falling back to numpy backend.
  warnings.warn(f"Setting backend to {backend} failed: {str(error)}."
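
Since the GPU backend may be unavailable (as in the warning above), it can be useful to check which backend is actually active before running anything expensive. A minimal sketch, assuming himalaya's get_backend helper and the backend's name attribute:

from himalaya.backend import get_backend

# The active backend after set_backend, e.g. "numpy" after the fallback above.
print(get_backend().name)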

Generate a random dataset

  • X_train : array of shape (n_samples_train, n_features)
  • X_test : array of shape (n_samples_test, n_features)
  • Y_train : array of shape (n_samples_train, n_targets)
  • Y_test : array of shape (n_samples_test, n_targets)

(X_train, X_test, Y_train, Y_test, kernel_weights,
 n_features_list) = generate_multikernel_dataset(n_kernels=3, n_targets=50,
                                                 n_samples_train=600,
                                                 n_samples_test=300,
                                                 random_state=42)

feature_names = [f"Feature space {ii}" for ii in range(len(n_features_list))]
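
To make the data layout concrete, here is a quick (illustrative) check of the generated shapes; the values follow from the parameters above, with 1000 features per kernel in this dataset:

print(X_train.shape)    # (n_samples_train, sum(n_features_list)) = (600, 3000)
print(Y_train.shape)    # (n_samples_train, n_targets) = (600, 50)
print(n_features_list)  # number of features in each of the 3 feature spaces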

We could precompute the kernels by hand on Xs_train, as done in plot_mkr_random_search.py. Instead, here we use a ColumnKernelizer to build a scikit-learn Pipeline.

# Find the start and end of each feature space X in Xs
start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
slices = [
    slice(start, end)
    for start, end in zip(start_and_end[:-1], start_and_end[1:])
]
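
These slices must tile the columns of X_train exactly, one slice per feature space. A small sanity check makes this explicit:

# Illustrative check: the slices cover all columns of X_train exactly once.
assert sum(sl.stop - sl.start for sl in slices) == X_train.shape[1]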

Create a different Kernelizer for each feature space. Here we use a linear kernel for all feature spaces, but ColumnKernelizer accepts any Kernelizer, or any scikit-learn Pipeline ending with a Kernelizer (see the sketch below).

kernelizers = [(name, Kernelizer(), slice_)
               for name, slice_ in zip(feature_names, slices)]
column_kernelizer = ColumnKernelizer(kernelizers)

# Note that ``ColumnKernelizer`` has a parameter ``n_jobs`` to parallelize each
# kernelizer, yet such parallelism does not work with GPU arrays.
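
For illustration, here is what a non-linear variant could look like: standardize each feature space, then compute a cosine kernel. This is only a sketch, assuming the kernel parameter of Kernelizer as documented in himalaya.kernel_ridge:

from sklearn.preprocessing import StandardScaler

# Hypothetical alternative: scale each feature space, then use a cosine kernel.
kernelizers_cosine = [
    (name, make_pipeline(StandardScaler(), Kernelizer(kernel="cosine")), slice_)
    for name, slice_ in zip(feature_names, slices)
]
# column_kernelizer = ColumnKernelizer(kernelizers_cosine)  # drop-in replacement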

Define the model

The class takes a number of common parameters at initialization, such as kernels or solver. Since the remaining parameters differ from solver to solver, they are passed as a dictionary through the solver_params parameter.

Here we use the “random_search” solver. We can check its specific parameters in the function docstring:

solver_function = MultipleKernelRidgeCV.ALL_SOLVERS["random_search"]
print("Docstring of the function %s:" % solver_function.__name__)
print(solver_function.__doc__)
Docstring of the function solve_multiple_kernel_ridge_random_search:
Solve multiple kernel ridge regression using random search.

    Parameters
    ----------
    Ks : array of shape (n_kernels, n_samples, n_samples)
        Input kernels.
    Y : array of shape (n_samples, n_targets)
        Target data.
    n_iter : int, or array of shape (n_iter, n_kernels)
        Number of kernel weights combinations to search.
        If an array is given, the solver uses it as the list of kernel weights
        to try, instead of sampling from a Dirichlet distribution. Examples:
          - `n_iter=np.eye(n_kernels)` implements a winner-take-all strategy
            over kernels.
          - `n_iter=np.ones((1, n_kernels))/n_kernels` solves a (standard)
            kernel ridge regression.
    concentration : float, or list of float
        Concentration parameters of the Dirichlet distribution.
        If a list, iteratively cycle through the list.
        Not used if n_iter is an array.
    alphas : float or array of shape (n_alphas, )
        Range of ridge regularization parameter.
    score_func : callable
        Function used to compute the score of predictions versus Y.
    cv : int or scikit-learn splitter
        Cross-validation splitter. If an int, KFold is used.
    fit_intercept : boolean
        Whether to fit an intercept. If False, Ks should be centered
        (see KernelCenterer), and Y must be zero-mean over samples.
        Only available if return_weights == 'dual'.
    return_weights : None, 'primal', or 'dual'
        Whether to refit on the entire dataset and return the weights.
    Xs : array of shape (n_kernels, n_samples, n_features) or None
        Necessary if return_weights == 'primal'.
    local_alpha : bool
        If True, alphas are selected per target, else shared over all targets.
    jitter_alphas : bool
        If True, alphas range is slightly jittered for each gamma.
    random_state : int, or None
        Random generator seed. Use an int for deterministic search.
    n_targets_batch : int or None
        Size of the batch over targets during cross-validation.
        Used for memory reasons. If None, uses all n_targets at once.
    n_targets_batch_refit : int or None
        Size of the batch over targets during refit.
        Used for memory reasons. If None, uses all n_targets at once.
    n_alphas_batch : int or None
        Size of the batch for over alphas. Used for memory reasons.
        If None, uses all n_alphas at once.
    progress_bar : bool
        If True, display a progress bar over gammas.
    Ks_in_cpu : bool
        If True, keep Ks in CPU memory to limit GPU memory (slower).
        This feature is not available through the scikit-learn API.
    conservative : bool
        If True, when selecting the hyperparameter alpha, take the largest one
        that is less than one standard deviation away from the best.
        If False, take the best.
    Y_in_cpu : bool
        If True, keep the target values ``Y`` in CPU memory (slower).
    diagonalize_method : str in {"eigh", "svd"}
        Method used to diagonalize the kernel.
    return_alphas : bool
        If True, return the best alpha value for each target.

    Returns
    -------
    deltas : array of shape (n_kernels, n_targets)
        Best log kernel weights for each target.
    refit_weights : array or None
        Refit regression weights on the entire dataset, using selected best
        hyperparameters. Refit weights are always stored on CPU memory.
        If return_weights == 'primal', shape is (n_features, n_targets),
        if return_weights == 'dual', shape is (n_samples, n_targets),
        else, None.
    cv_scores : array of shape (n_iter, n_targets)
        Cross-validation scores per iteration, averaged over splits, for the
        best alpha. Cross-validation scores will always be on CPU memory.
    best_alphas : array of shape (n_targets, )
        Best alpha value per target. Only returned if return_alphas is True.
    intercept : array of shape (n_targets,)
        Intercept. Only returned when fit_intercept is True.

We use 100 iterations to keep the example reasonably fast (about 40 seconds on a GPU; the CPU fallback run shown here takes about two minutes). For better convergence, more iterations would likely be needed. Note that there is currently no stopping criterion in this method.

n_iter = 100
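
As the docstring above notes, n_iter can instead be an explicit array of kernel-weight candidates. A sketch of the two special cases mentioned there, with the 3 kernels of this dataset (kept commented out to leave the example unchanged):

# n_iter = np.eye(3)            # winner-take-all: try each kernel alone
# n_iter = np.ones((1, 3)) / 3  # single uniform combination: standard ridge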

Grid of regularization parameters.

alphas = np.logspace(-10, 10, 41)

Batch parameters are used to reduce the necessary GPU memory. A larger value will be a bit faster, but the solver might crash if it runs out of memory. Optimal values depend on the size of your dataset.

n_targets_batch = 1000
n_alphas_batch = 20
n_targets_batch_refit = 200

solver_params = dict(n_iter=n_iter, alphas=alphas,
                     n_targets_batch=n_targets_batch,
                     n_alphas_batch=n_alphas_batch,
                     n_targets_batch_refit=n_targets_batch_refit,
                     jitter_alphas=True)

model = MultipleKernelRidgeCV(kernels="precomputed", solver="random_search",
                              solver_params=solver_params)

Define and fit the pipeline

pipe = make_pipeline(column_kernelizer, model)
pipe.fit(X_train, Y_train)
[..............................] 100% | 126.07 sec | 100 random sampling with cv | 0.79 it/s, ETA: 00:00:00
Pipeline(steps=[('columnkernelizer',
                 ColumnKernelizer(transformers=[('Feature space 0',
                                                 Kernelizer(),
                                                 slice(np.int64(0), np.int64(1000), None)),
                                                ('Feature space 1',
                                                 Kernelizer(),
                                                 slice(np.int64(1000), np.int64(2000), None)),
                                                ('Feature space 2',
                                                 Kernelizer(),
                                                 slice(np.int64(2000), np.int64(3000), None))])),
                ('multiplekernelridgecv',
                 MultipleKernelRidgeCV(kernels='precompu...
       1.00000000e+02, 3.16227766e+02, 1.00000000e+03, 3.16227766e+03,
       1.00000000e+04, 3.16227766e+04, 1.00000000e+05, 3.16227766e+05,
       1.00000000e+06, 3.16227766e+06, 1.00000000e+07, 3.16227766e+07,
       1.00000000e+08, 3.16227766e+08, 1.00000000e+09, 3.16227766e+09,
       1.00000000e+10]),
                                                      'jitter_alphas': True,
                                                      'n_alphas_batch': 20,
                                                      'n_iter': 100,
                                                      'n_targets_batch': 1000,
                                                      'n_targets_batch_refit': 200}))])
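
After fitting, the selected hyperparameters can be inspected on the model. A minimal sketch, assuming the deltas_ attribute of MultipleKernelRidgeCV (deltas are log kernel weights, as in the solver docstring above):

# Per-target kernel weights, shape (n_kernels, n_targets).
deltas = backend.to_numpy(pipe[1].deltas_)
kernel_weights_hat = np.exp(deltas)
print(kernel_weights_hat.shape)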


Plot the convergence curve

# ``cv_scores`` gives the scores for each sampled kernel weights.
# The convergence curve is thus the current maximum for each target.
cv_scores = backend.to_numpy(pipe[1].cv_scores_)
current_max = np.maximum.accumulate(cv_scores, axis=0)
mean_current_max = np.mean(current_max, axis=1)

x_array = np.arange(1, len(mean_current_max) + 1)
plt.plot(x_array, mean_current_max, '-o')
plt.grid(True)
plt.xlabel("Number of kernel weights sampled")
plt.ylabel("L2 negative loss (higher is better)")
plt.title("Convergence curve, averaged over targets")
plt.tight_layout()
plt.show()
[Figure: convergence curve, averaged over targets]

Compare to KernelRidgeCV

We compare the model to a baseline KernelRidgeCV fit on all the features concatenated. The comparison uses the prediction scores on the test set.

Fit the baseline model KernelRidgeCV

baseline = KernelRidgeCV(kernel="linear", alphas=alphas)
baseline.fit(X_train, Y_train)
KernelRidgeCV(alphas=array([1.00000000e-10, 3.16227766e-10, 1.00000000e-09, 3.16227766e-09,
       1.00000000e-08, 3.16227766e-08, 1.00000000e-07, 3.16227766e-07,
       1.00000000e-06, 3.16227766e-06, 1.00000000e-05, 3.16227766e-05,
       1.00000000e-04, 3.16227766e-04, 1.00000000e-03, 3.16227766e-03,
       1.00000000e-02, 3.16227766e-02, 1.00000000e-01, 3.16227766e-01,
       1.00000000e+00, 3.16227766e+00, 1.00000000e+01, 3.16227766e+01,
       1.00000000e+02, 3.16227766e+02, 1.00000000e+03, 3.16227766e+03,
       1.00000000e+04, 3.16227766e+04, 1.00000000e+05, 3.16227766e+05,
       1.00000000e+06, 3.16227766e+06, 1.00000000e+07, 3.16227766e+07,
       1.00000000e+08, 3.16227766e+08, 1.00000000e+09, 3.16227766e+09,
       1.00000000e+10]))


Compute scores of both models

scores = pipe.score(X_test, Y_test)
scores = backend.to_numpy(scores)

scores_baseline = baseline.score(X_test, Y_test)
scores_baseline = backend.to_numpy(scores_baseline)
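
Before plotting, a one-line summary of the comparison is the mean R² over targets:

print("KernelRidgeCV mean R^2:         %.4f" % scores_baseline.mean())
print("MultipleKernelRidgeCV mean R^2: %.4f" % scores.mean())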

Plot histograms

bins = np.linspace(0, max(scores_baseline.max(), scores.max()), 50)
plt.hist(scores_baseline, bins, alpha=0.7, label="KernelRidgeCV")
plt.hist(scores, bins, alpha=0.7, label="MultipleKernelRidgeCV")
plt.xlabel(r"$R^2$ generalization score")
plt.title("Histogram over targets")
plt.legend()
plt.show()
[Figure: histogram over targets of the R² generalization scores]

Total running time of the script: (2 minutes 9.284 seconds)
