Multiple-kernel ridge with scikit-learn API

This example demonstrates how to solve multiple-kernel ridge regression using the scikit-learn API.

import numpy as np
import matplotlib.pyplot as plt

from himalaya.backend import set_backend
from himalaya.kernel_ridge import KernelRidgeCV
from himalaya.kernel_ridge import MultipleKernelRidgeCV
from himalaya.kernel_ridge import Kernelizer
from himalaya.kernel_ridge import ColumnKernelizer
from himalaya.utils import generate_multikernel_dataset

from sklearn.pipeline import make_pipeline
from sklearn import set_config
set_config(display='diagram')

In this example, we use the torch_cuda backend.

Torch can perform computations on both CPU and GPU. To run on the CPU, use the “torch” backend; to run on the GPU, use the “torch_cuda” backend.

backend = set_backend("torch_cuda", on_error="warn")
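
With ``on_error="warn"``, ``set_backend`` falls back to the current backend (numpy by default) if the GPU backend cannot be loaded. A minimal sanity check, assuming the returned backend module exposes a ``name`` attribute (``getattr`` keeps the check safe if it does not):

# Print which backend was actually selected.
print("Backend in use:", getattr(backend, "name", backend))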

Generate a random dataset

  • X_train : array of shape (n_samples_train, n_features)

  • X_test : array of shape (n_samples_test, n_features)

  • Y_train : array of shape (n_samples_train, n_targets)

  • Y_test : array of shape (n_samples_test, n_targets)

(X_train, X_test, Y_train, Y_test, kernel_weights,
 n_features_list) = generate_multikernel_dataset(n_kernels=3, n_targets=50,
                                                 n_samples_train=600,
                                                 n_samples_test=300,
                                                 random_state=42)

feature_names = [f"Feature space {ii}" for ii in range(len(n_features_list))]

We could precompute the kernels by hand on Xs_train, as done in plot_mkr_random_search.py. Instead, here we use the ColumnKernelizer to build a scikit-learn Pipeline.

# Find the start and end of each feature space X in Xs
start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
slices = [
    slice(start, end)
    for start, end in zip(start_and_end[:-1], start_and_end[1:])
]
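
For reference, here is a minimal sketch of the by-hand alternative mentioned above: one linear kernel per feature space, stacked into a single array. It is shown for illustration only and is not used in the rest of the example; we convert to numpy first in case the backend returned GPU arrays.

X_train_numpy = backend.to_numpy(X_train)
Ks_train_by_hand = np.stack([X_train_numpy[:, slice_] @ X_train_numpy[:, slice_].T
                             for slice_ in slices])
print(Ks_train_by_hand.shape)  # (n_kernels, n_samples_train, n_samples_train)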

Create a different Kernelizer for each feature space. Here we use a linear kernel for all feature spaces, but ColumnKernelizer accepts any Kernelizer, or a scikit-learn Pipeline ending with a Kernelizer.

kernelizers = [(name, Kernelizer(), slice_)
               for name, slice_ in zip(feature_names, slices)]
column_kernelizer = ColumnKernelizer(kernelizers)

# Note that ``ColumnKernelizer`` has a parameter ``n_jobs`` to parallelize each
# kernelizer, yet such parallelism does not work with GPU arrays.
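
As mentioned above, each entry can also be a scikit-learn Pipeline ending with a Kernelizer. A hypothetical variant that standardizes each feature space before computing the kernel (defined for illustration only, not used below):

from sklearn.preprocessing import StandardScaler

kernelizers_scaled = [(name, make_pipeline(StandardScaler(), Kernelizer()), slice_)
                      for name, slice_ in zip(feature_names, slices)]
column_kernelizer_scaled = ColumnKernelizer(kernelizers_scaled)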

Define the model

The class takes a number of common parameters during initialization, such as kernels or solver. Since the remaining parameters differ from one solver to another, they are passed through the solver_params parameter.

Here we use the “random_search” solver. We can check its specific parameters in the function docstring:

solver_function = MultipleKernelRidgeCV.ALL_SOLVERS["random_search"]
print("Docstring of the function %s:" % solver_function.__name__)
print(solver_function.__doc__)

Out:

Docstring of the function solve_multiple_kernel_ridge_random_search:
Solve multiple kernel ridge regression using random search.

    Parameters
    ----------
    Ks : array of shape (n_kernels, n_samples, n_samples)
        Input kernels.
    Y : array of shape (n_samples, n_targets)
        Target data.
    n_iter : int, or array of shape (n_iter, n_kernels)
        Number of kernel weights combination to search.
        If an array is given, the solver uses it as the list of kernel weights
        to try, instead of sampling from a Dirichlet distribution. Examples:
          - `n_iter=np.eye(n_kernels)` implement a winner-take-all strategy
            over kernels.
          - `n_iter=np.ones((1, n_kernels))/n_kernels` solves a (standard)
            kernel ridge regression.
    concentration : float, or list of float
        Concentration parameters of the Dirichlet distribution.
        If a list, iteratively cycle through the list.
        Not used if n_iter is an array.
    alphas : float or array of shape (n_alphas, )
        Range of ridge regularization parameter.
    score_func : callable
        Function used to compute the score of predictions versus Y.
    cv : int or scikit-learn splitter
        Cross-validation splitter. If an int, KFold is used.
    fit_intercept : boolean
        Whether to fit an intercept. If False, Ks should be centered
        (see KernelCenterer), and Y must be zero-mean over samples.
        Only available if return_weights == 'dual'.
    return_weights : None, 'primal', or 'dual'
        Whether to refit on the entire dataset and return the weights.
    Xs : array of shape (n_kernels, n_samples, n_features) or None
        Necessary if return_weights == 'primal'.
    local_alpha : bool
        If True, alphas are selected per target, else shared over all targets.
    jitter_alphas : bool
        If True, alphas range is slightly jittered for each gamma.
    random_state : int, or None
        Random generator seed. Use an int for deterministic search.
    n_targets_batch : int or None
        Size of the batch for over targets during cross-validation.
        Used for memory reasons. If None, uses all n_targets at once.
    n_targets_batch_refit : int or None
        Size of the batch for over targets during refit.
        Used for memory reasons. If None, uses all n_targets at once.
    n_alphas_batch : int or None
        Size of the batch for over alphas. Used for memory reasons.
        If None, uses all n_alphas at once.
    progress_bar : bool
        If True, display a progress bar over gammas.
    Ks_in_cpu : bool
        If True, keep Ks in CPU memory to limit GPU memory (slower).
        This feature is not available through the scikit-learn API.
    conservative : bool
        If True, when selecting the hyperparameter alpha, take the largest one
        that is less than one standard deviation away from the best.
        If False, take the best.
    Y_in_cpu : bool
        If True, keep the target values ``Y`` in CPU memory (slower).
    diagonalize_method : str in {"eigh", "svd"}
        Method used to diagonalize the kernel.
    return_alphas : bool
        If True, return the best alpha value for each target.

    Returns
    -------
    deltas : array of shape (n_kernels, n_targets)
        Best log kernel weights for each target.
    refit_weights : array or None
        Refit regression weights on the entire dataset, using selected best
        hyperparameters. Refit weights are always stored on CPU memory.
        If return_weights == 'primal', shape is (n_features, n_targets),
        if return_weights == 'dual', shape is (n_samples, n_targets),
        else, None.
    cv_scores : array of shape (n_iter, n_targets)
        Cross-validation scores per iteration, averaged over splits, for the
        best alpha. Cross-validation scores will always be on CPU memory.
    best_alphas : array of shape (n_targets, )
        Best alpha value per target. Only returned if return_alphas is True.
    intercept : array of shape (n_targets,)
        Intercept. Only returned when fit_intercept is True.

We use 100 iterations to keep this example reasonably fast (~40 sec). For better convergence, more iterations would likely be needed. Note that there is currently no stopping criterion in this method.

n_iter = 100
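
As described in the docstring above, ``n_iter`` can instead be an array of candidate kernel weights. For instance, the candidates below would implement a winner-take-all search over the three kernels; they are defined for illustration only, and we keep ``n_iter = 100`` in the rest of the example.

winner_take_all_candidates = np.eye(3)  # shape (n_candidates, n_kernels)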

Grid of regularization parameters.

alphas = np.logspace(-10, 10, 41)

Batch parameters are used to reduce the GPU memory required. Larger values are a bit faster, but the solver might crash if it runs out of memory. Optimal values depend on the size of your dataset.

n_targets_batch = 1000
n_alphas_batch = 20
n_targets_batch_refit = 200

solver_params = dict(n_iter=n_iter, alphas=alphas,
                     n_targets_batch=n_targets_batch,
                     n_alphas_batch=n_alphas_batch,
                     n_targets_batch_refit=n_targets_batch_refit,
                     jitter_alphas=True)

model = MultipleKernelRidgeCV(kernels="precomputed", solver="random_search",
                              solver_params=solver_params)

Define and fit the pipeline

pipe = make_pipeline(column_kernelizer, model)
pipe.fit(X_train, Y_train)

Out:

[                                        ] 0% | 0.00 sec | 100 random sampling with cv |
...
[........................................] 100% | 10.18 sec | 100 random sampling with cv |
Pipeline(steps=[('columnkernelizer',
                 ColumnKernelizer(transformers=[('Feature space 0',
                                                 Kernelizer(),
                                                 slice(0, 1000, None)),
                                                ('Feature space 1',
                                                 Kernelizer(),
                                                 slice(1000, 2000, None)),
                                                ('Feature space 2',
                                                 Kernelizer(),
                                                 slice(2000, 3000, None))])),
                ('multiplekernelridgecv',
                 MultipleKernelRidgeCV(kernels='precomputed',
                                       solver_params={'alphas': array([1.00000000e-10, 3.1622776...
       1.00000000e+02, 3.16227766e+02, 1.00000000e+03, 3.16227766e+03,
       1.00000000e+04, 3.16227766e+04, 1.00000000e+05, 3.16227766e+05,
       1.00000000e+06, 3.16227766e+06, 1.00000000e+07, 3.16227766e+07,
       1.00000000e+08, 3.16227766e+08, 1.00000000e+09, 3.16227766e+09,
       1.00000000e+10]),
                                                      'jitter_alphas': True,
                                                      'n_alphas_batch': 20,
                                                      'n_iter': 100,
                                                      'n_targets_batch': 1000,
                                                      'n_targets_batch_refit': 200}))])
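
After the fit, the selected kernel weights can be inspected for each target. The sketch below assumes the fitted estimator exposes the log kernel weights as a ``deltas_`` attribute, matching the ``deltas`` returned by the solver documented above; it is shown for illustration only.

deltas = backend.to_numpy(pipe[1].deltas_)
kernel_weights = np.exp(deltas)  # shape (n_kernels, n_targets)
print(kernel_weights.mean(axis=1))  # average weight of each kernel over targets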


Plot the convergence curve

# ``cv_scores_`` gives the scores for each sampled set of kernel weights.
# The convergence curve is thus the running maximum over iterations, for each target.
cv_scores = backend.to_numpy(pipe[1].cv_scores_)
current_max = np.maximum.accumulate(cv_scores, axis=0)
mean_current_max = np.mean(current_max, axis=1)

x_array = np.arange(1, len(mean_current_max) + 1)
plt.plot(x_array, mean_current_max, '-o')
plt.grid("on")
plt.xlabel("Number of kernel weights sampled")
plt.ylabel("L2 negative loss (higher is better)")
plt.title("Convergence curve, averaged over targets")
plt.tight_layout()
plt.show()
[Figure: Convergence curve, averaged over targets]

Compare to KernelRidgeCV

Compare with a baseline KernelRidgeCV model fit on all the concatenated features. The comparison is performed using the prediction scores on the test set.

Fit the baseline KernelRidgeCV model

baseline = KernelRidgeCV(kernel="linear", alphas=alphas)
baseline.fit(X_train, Y_train)
Out:

KernelRidgeCV(alphas=array([1.00000000e-10, 3.16227766e-10, 1.00000000e-09, 3.16227766e-09,
       1.00000000e-08, 3.16227766e-08, 1.00000000e-07, 3.16227766e-07,
       1.00000000e-06, 3.16227766e-06, 1.00000000e-05, 3.16227766e-05,
       1.00000000e-04, 3.16227766e-04, 1.00000000e-03, 3.16227766e-03,
       1.00000000e-02, 3.16227766e-02, 1.00000000e-01, 3.16227766e-01,
       1.00000000e+00, 3.16227766e+00, 1.00000000e+01, 3.16227766e+01,
       1.00000000e+02, 3.16227766e+02, 1.00000000e+03, 3.16227766e+03,
       1.00000000e+04, 3.16227766e+04, 1.00000000e+05, 3.16227766e+05,
       1.00000000e+06, 3.16227766e+06, 1.00000000e+07, 3.16227766e+07,
       1.00000000e+08, 3.16227766e+08, 1.00000000e+09, 3.16227766e+09,
       1.00000000e+10]))


Compute scores of both models

scores = pipe.score(X_test, Y_test)
scores = backend.to_numpy(scores)

scores_baseline = baseline.score(X_test, Y_test)
scores_baseline = backend.to_numpy(scores_baseline)
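
To summarize the comparison numerically before plotting, we can print the average test score of each model (a small addition for illustration):

print("Mean R2 score, MultipleKernelRidgeCV: %.4f" % scores.mean())
print("Mean R2 score, KernelRidgeCV baseline: %.4f" % scores_baseline.mean())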

Plot histograms

bins = np.linspace(0, max(scores_baseline.max(), scores.max()), 50)
plt.hist(scores_baseline, bins, alpha=0.7, label="KernelRidgeCV")
plt.hist(scores, bins, alpha=0.7, label="MultipleKernelRidgeCV")
plt.xlabel(r"$R^2$ generalization score")
plt.title("Histogram over targets")
plt.legend()
plt.show()
[Figure: Histogram over targets]

Total running time of the script: ( 0 minutes 12.366 seconds)
