Multiple-kernel ridge path between two kernels

This example demonstrates the path of all possible kernel-weight ratios between two kernels in a multiple-kernel ridge regression model. For each ratio along the path, the kernels are scaled by their respective weights and summed, and a joint model is fit on the resulting kernel. The explained variance on a test set is then computed and decomposed over both kernels.
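To make the mechanics concrete, here is a minimal NumPy sketch of the idea, independent of himalaya (the shapes, weights, and regularization below are illustrative assumptions, not the library's implementation): the kernels are scaled by weights, summed, a single kernel ridge model is fit on the sum, and the predictions split additively over kernels, which is what makes the variance decomposition possible.

import numpy as np

rng = np.random.RandomState(0)
X0, X1 = rng.randn(100, 20), rng.randn(100, 30)  # two feature spaces
Y = rng.randn(100, 3)

# One linear kernel per feature space.
K0, K1 = X0 @ X0.T, X1 @ X1.T

# Weighted sum of kernels, e.g. gamma = (0.6, 0.4).
gamma = np.array([0.6, 0.4])
K = gamma[0] * K0 + gamma[1] * K1

# Kernel ridge solution on the combined kernel (illustrative alpha = 1).
alpha = 1.0
dual_coef = np.linalg.solve(K + alpha * np.eye(len(K)), Y)

# Predictions decompose additively, one term per kernel; this is what
# allows splitting the explained variance over both kernels.
pred_0 = gamma[0] * K0 @ dual_coef
pred_1 = gamma[1] * K1 @ dual_coef
assert np.allclose(pred_0 + pred_1, K @ dual_coef)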

from functools import partial

import numpy as np
import matplotlib.pyplot as plt

from himalaya.backend import set_backend
from himalaya.kernel_ridge import MultipleKernelRidgeCV
from himalaya.kernel_ridge import Kernelizer
from himalaya.kernel_ridge import ColumnKernelizer
from himalaya.progress_bar import bar
from himalaya.utils import generate_multikernel_dataset

from sklearn.pipeline import make_pipeline
from sklearn import set_config
set_config(display='diagram')

In this example, we use the cupy backend. If CuPy is not installed, set_backend falls back to the numpy backend, as shown in the warning below.

backend = set_backend("cupy", on_error="warn")
/home/runner/work/himalaya/himalaya/himalaya/backend/_utils.py:55: UserWarning: Setting backend to cupy failed: Cupy not installed..Falling back to numpy backend.
  warnings.warn(f"Setting backend to {backend} failed: {str(error)}."
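set_backend returns the active backend module, which we use later to move arrays back to NumPy with backend.to_numpy. A quick way to confirm which backend is actually in use (checking the module's name attribute, which we assume here to be part of the backend API):

# Confirm the active backend (numpy here, since CuPy is not installed).
print(backend.name)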

Generate a random dataset

  • X_train : array of shape (n_samples_train, n_features)

  • X_test : array of shape (n_samples_test, n_features)

  • Y_train : array of shape (n_samples_train, n_targets)

  • Y_test : array of shape (n_samples_test, n_targets)

n_targets = 50
kernel_weights = np.tile(np.array([0.6, 0.4])[None], (n_targets, 1))

(X_train, X_test, Y_train, Y_test,
 kernel_weights, n_features_list) = generate_multikernel_dataset(
     n_kernels=2, n_targets=n_targets, n_samples_train=600,
     n_samples_test=300, random_state=42, noise=0.31,
     kernel_weights=kernel_weights)

feature_names = [f"Feature space {ii}" for ii in range(len(n_features_list))]
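As a quick sanity check (an added snippet, not part of the original script), the generated arrays match the shapes listed above; the two feature spaces are concatenated along the feature axis, with 1000 features each, as the pipeline representation below also shows.

# Two feature spaces of 1000 features each, concatenated in X.
print(X_train.shape, Y_train.shape)  # (600, 2000) (600, 50)
print(X_test.shape, Y_test.shape)    # (300, 2000) (300, 50)
print(n_features_list)               # [1000, 1000]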

Create a MultipleKernelRidgeCV model; see plot_mkr_sklearn_api.py for more details.

# Find the start and end of each feature space X in Xs.
start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
slices = [
    slice(start, end)
    for start, end in zip(start_and_end[:-1], start_and_end[1:])
]

# Create a different ``Kernelizer`` for each feature space.
kernelizers = [(name, Kernelizer(), slice_)
               for name, slice_ in zip(feature_names, slices)]
column_kernelizer = ColumnKernelizer(kernelizers)

# Create a MultipleKernelRidgeCV model.
solver_params = dict(alphas=np.logspace(-5, 5, 41), progress_bar=False)
model = MultipleKernelRidgeCV(kernels="precomputed", solver="random_search",
                              solver_params=solver_params,
                              random_state=42)
pipe = make_pipeline(column_kernelizer, model)
pipe
Pipeline(steps=[('columnkernelizer',
                 ColumnKernelizer(transformers=[('Feature space 0',
                                                 Kernelizer(),
                                                 slice(np.int64(0), np.int64(1000), None)),
                                                ('Feature space 1',
                                                 Kernelizer(),
                                                 slice(np.int64(1000), np.int64(2000), None))])),
                ('multiplekernelridgecv',
                 MultipleKernelRidgeCV(kernels='precomputed', random_state=42,
                                       solver_params={'alphas': array([1.00000000e-05, 1.7782...
       1.00000000e-01, 1.77827941e-01, 3.16227766e-01, 5.62341325e-01,
       1.00000000e+00, 1.77827941e+00, 3.16227766e+00, 5.62341325e+00,
       1.00000000e+01, 1.77827941e+01, 3.16227766e+01, 5.62341325e+01,
       1.00000000e+02, 1.77827941e+02, 3.16227766e+02, 5.62341325e+02,
       1.00000000e+03, 1.77827941e+03, 3.16227766e+03, 5.62341325e+03,
       1.00000000e+04, 1.77827941e+04, 3.16227766e+04, 5.62341325e+04,
       1.00000000e+05]),
                                                      'progress_bar': False}))])
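Before fitting, it can help to see what the ColumnKernelizer produces. Each Kernelizer computes one (n_samples, n_samples) kernel, and since scikit-learn transformers must return 2D arrays, the kernels are concatenated horizontally; the exact layout below is our reading of the API, so treat it as an assumption.

# Fit the kernelizers and compute the kernels on the training set.
Ks = column_kernelizer.fit_transform(X_train)
print(Ks.shape)  # expected (600, 1200): two 600 x 600 kernels, side by side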


Then, we manually perform a hyperparameter grid search for the kernel weights.

# Make the score method use `split=True` by default.
model.score = partial(model.score, split=True)

# Define the hyperparameter grid search.
ratios = np.logspace(-4, 4, 41)
candidates = np.array([1 - ratios / (1 + ratios), ratios / (1 + ratios)]).T

# Loop over hyperparameter candidates
split_r2_scores = []
for candidate in bar(candidates, "Hyperparameter candidates"):
    # test one hyperparameter candidate at a time
    pipe[-1].solver_params["n_iter"] = candidate[None]
    pipe.fit(X_train, Y_train)

    # split the R2 score between both kernels
    scores = pipe.score(X_test, Y_test)
    split_r2_scores.append(backend.to_numpy(scores))

# average scores over targets for plotting
split_r2_scores_avg = np.array(split_r2_scores).mean(axis=2)
[..............................] 100% | 59.76 sec | Hyperparameter candidates | 0.69 it/s
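Note that each ratio r maps to the weight pair (1 / (1 + r), r / (1 + r)), since 1 - r / (1 + r) = 1 / (1 + r); every candidate is thus a convex combination of the two kernels. A one-line sanity check (an added illustration, not in the original script):

# All candidate weight pairs sum to one.
assert np.allclose(candidates.sum(axis=1), 1.0)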

Plot the variance decomposition for all the hyperparameter ratios.

For a ratio of 1e-3, feature space 1 is almost not used, since its kernel weight r / (1 + r) is close to zero. For a ratio of 1e3, feature space 0 is almost not used. The best ratio is here around 1, because both feature spaces contribute with similar scales in the simulated dataset (the true kernel weights are 0.6 and 0.4).

fig, ax = plt.subplots(figsize=(5, 4))

# Stack the per-kernel R^2 splits on top of each other.
accumulator = np.zeros_like(ratios)
for split in split_r2_scores_avg.T:
    ax.fill_between(ratios, accumulator, accumulator + split, alpha=0.7)
    accumulator += split

ax.set(xscale='log',
       xlabel=r"Ratio of kernel weights ($\gamma_1 / \gamma_0$)",
       ylabel=r"$R^2$ score (test set)",
       title=r"$R^2$ score decomposition")
ax.legend(feature_names, loc="upper left")
ax.grid()
fig.tight_layout()
plt.show()
[Figure: $R^2$ score decomposition, stacked over both feature spaces along the ratio path]
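As a hypothetical follow-up (not part of the original example), the best ratio can also be read off numerically, by summing both splits into a total R^2 per candidate:

# Total R^2 per ratio is the sum of both kernel splits.
total_r2 = split_r2_scores_avg.sum(axis=1)
print(ratios[np.argmax(total_r2)])  # expected near 1 on this dataset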

Total running time of the script: (1 minute 0.190 seconds)

Gallery generated by Sphinx-Gallery