Multiple-kernel ridge with scikit-learn API
This example demonstrates how to solve multiple-kernel ridge regression using the scikit-learn API.
import numpy as np
import matplotlib.pyplot as plt
from himalaya.backend import set_backend
from himalaya.kernel_ridge import KernelRidgeCV
from himalaya.kernel_ridge import MultipleKernelRidgeCV
from himalaya.kernel_ridge import Kernelizer
from himalaya.kernel_ridge import ColumnKernelizer
from himalaya.utils import generate_multikernel_dataset
from sklearn.pipeline import make_pipeline
from sklearn import set_config
set_config(display='diagram')
In this example, we use the torch_cuda backend.
Torch can perform computations on both CPU and GPU. To run on CPU, use the “torch” backend; to run on GPU, use the “torch_cuda” backend.
backend = set_backend("torch_cuda", on_error="warn")
/home/runner/work/himalaya/himalaya/himalaya/backend/_utils.py:55: UserWarning: Setting backend to torch_cuda failed: PyTorch with CUDA is not available..Falling back to numpy backend.
warnings.warn(f"Setting backend to {backend} failed: {str(error)}."
Generate a random dataset
X_train : array of shape (n_samples_train, n_features)
X_test : array of shape (n_samples_test, n_features)
Y_train : array of shape (n_samples_train, n_targets)
Y_test : array of shape (n_samples_test, n_targets)
(X_train, X_test, Y_train, Y_test, kernel_weights,
 n_features_list) = generate_multikernel_dataset(n_kernels=3, n_targets=50,
                                                 n_samples_train=600,
                                                 n_samples_test=300,
                                                 random_state=42)
feature_names = [f"Feature space {ii}" for ii in range(len(n_features_list))]
We could precompute the kernels by hand on Xs_train, as done in
plot_mkr_random_search.py. Instead, here we use the ColumnKernelizer
to make a scikit-learn Pipeline.
# Find the start and end of each feature space X in Xs
start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
slices = [
    slice(start, end)
    for start, end in zip(start_and_end[:-1], start_and_end[1:])
]
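As an aside, the manual alternative mentioned above (precomputing the kernels by hand, as in plot_mkr_random_search.py) could look roughly like this sketch, using the slices just defined. It is not used further in this example.
# Sketch only: one linear kernel per feature space, stacked into an array of
# shape (n_kernels, n_samples_train, n_samples_train).
Xs_train = [X_train[:, slice_] for slice_ in slices]
Ks_train = np.stack([X @ X.T for X in Xs_train])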
Create a different Kernelizer for each feature space. Here we use a
linear kernel for all feature spaces, but ColumnKernelizer accepts any
Kernelizer, or scikit-learn Pipeline ending with a
Kernelizer.
kernelizers = [(name, Kernelizer(), slice_)
               for name, slice_ in zip(feature_names, slices)]
column_kernelizer = ColumnKernelizer(kernelizers)
# Note that ``ColumnKernelizer`` has a parameter ``n_jobs`` to parallelize each
# kernelizer, yet such parallelism does not work with GPU arrays.
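As mentioned above, each entry of ColumnKernelizer can also be a scikit-learn Pipeline ending with a Kernelizer. A hedged sketch, adding a StandardScaler before the linear kernel (this variant is not used in the rest of the example):
from sklearn.preprocessing import StandardScaler

scaled_kernelizers = [(name, make_pipeline(StandardScaler(), Kernelizer()), slice_)
                      for name, slice_ in zip(feature_names, slices)]
column_kernelizer_scaled = ColumnKernelizer(scaled_kernelizers)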
Define the model
The class takes a number of common parameters during initialization, such as kernels or solver. Since the solver parameters may differ depending on the chosen solver, they are passed through the solver_params parameter.
Here we use the “random_search” solver. We can check its specific parameters in the function docstring:
solver_function = MultipleKernelRidgeCV.ALL_SOLVERS["random_search"]
print("Docstring of the function %s:" % solver_function.__name__)
print(solver_function.__doc__)
Docstring of the function solve_multiple_kernel_ridge_random_search:
Solve multiple kernel ridge regression using random search.
Parameters
----------
Ks : array of shape (n_kernels, n_samples, n_samples)
Input kernels.
Y : array of shape (n_samples, n_targets)
Target data.
n_iter : int, or array of shape (n_iter, n_kernels)
Number of kernel weight combinations to search.
If an array is given, the solver uses it as the list of kernel weights
to try, instead of sampling from a Dirichlet distribution. Examples:
- `n_iter=np.eye(n_kernels)` implements a winner-take-all strategy
over kernels.
- `n_iter=np.ones((1, n_kernels))/n_kernels` solves a (standard)
kernel ridge regression.
concentration : float, or list of float
Concentration parameters of the Dirichlet distribution.
If a list, iteratively cycle through the list.
Not used if n_iter is an array.
alphas : float or array of shape (n_alphas, )
Range of ridge regularization parameter.
score_func : callable
Function used to compute the score of predictions versus Y.
cv : int or scikit-learn splitter
Cross-validation splitter. If an int, KFold is used.
fit_intercept : boolean
Whether to fit an intercept. If False, Ks should be centered
(see KernelCenterer), and Y must be zero-mean over samples.
Only available if return_weights == 'dual'.
return_weights : None, 'primal', or 'dual'
Whether to refit on the entire dataset and return the weights.
Xs : array of shape (n_kernels, n_samples, n_features) or None
Necessary if return_weights == 'primal'.
local_alpha : bool
If True, alphas are selected per target, else shared over all targets.
jitter_alphas : bool
If True, alphas range is slightly jittered for each gamma.
random_state : int, or None
Random generator seed. Use an int for deterministic search.
n_targets_batch : int or None
Size of the batch over targets during cross-validation.
Used for memory reasons. If None, uses all n_targets at once.
n_targets_batch_refit : int or None
Size of the batch over targets during refit.
Used for memory reasons. If None, uses all n_targets at once.
n_alphas_batch : int or None
Size of the batch over alphas. Used for memory reasons.
If None, uses all n_alphas at once.
progress_bar : bool
If True, display a progress bar over gammas.
Ks_in_cpu : bool
If True, keep Ks in CPU memory to limit GPU memory (slower).
This feature is not available through the scikit-learn API.
conservative : bool
If True, when selecting the hyperparameter alpha, take the largest one
that is less than one standard deviation away from the best.
If False, take the best.
Y_in_cpu : bool
If True, keep the target values ``Y`` in CPU memory (slower).
diagonalize_method : str in {"eigh", "svd"}
Method used to diagonalize the kernel.
return_alphas : bool
If True, return the best alpha value for each target.
Returns
-------
deltas : array of shape (n_kernels, n_targets)
Best log kernel weights for each target.
refit_weights : array or None
Refit regression weights on the entire dataset, using selected best
hyperparameters. Refit weights are always stored on CPU memory.
If return_weights == 'primal', shape is (n_features, n_targets),
if return_weights == 'dual', shape is (n_samples, n_targets),
else, None.
cv_scores : array of shape (n_iter, n_targets)
Cross-validation scores per iteration, averaged over splits, for the
best alpha. Cross-validation scores will always be on CPU memory.
best_alphas : array of shape (n_targets, )
Best alpha value per target. Only returned if return_alphas is True.
intercept : array of shape (n_targets,)
Intercept. Only returned when fit_intercept is True.
We use 100 iterations to keep this example reasonably fast (~40 sec). For better convergence, more iterations would probably be needed. Note that there is currently no stopping criterion in this method.
n_iter = 100
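As described in the docstring above, n_iter can also be an array of kernel weights to try, instead of a number of random samples. For instance (a sketch, not used in this example):
# n_iter = np.eye(3)             # winner-take-all strategy over the 3 kernels
# n_iter = np.ones((1, 3)) / 3   # equivalent to a standard kernel ridge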
Grid of regularization parameters.
alphas = np.logspace(-10, 10, 41)
Batch parameters are used to reduce the necessary GPU memory. A larger value will be a bit faster, but the solver might crash if it runs out of memory. Optimal values depend on the size of your dataset.
n_targets_batch = 1000
n_alphas_batch = 20
n_targets_batch_refit = 200
solver_params = dict(n_iter=n_iter, alphas=alphas,
                     n_targets_batch=n_targets_batch,
                     n_alphas_batch=n_alphas_batch,
                     n_targets_batch_refit=n_targets_batch_refit,
                     jitter_alphas=True)
model = MultipleKernelRidgeCV(kernels="precomputed", solver="random_search",
solver_params=solver_params)
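The cross-validation scheme can also be customized through the cv parameter, which accepts an integer or a scikit-learn splitter as described in the solver docstring above. A hedged sketch (not used here):
# from sklearn.model_selection import KFold
# model = MultipleKernelRidgeCV(kernels="precomputed", solver="random_search",
#                               solver_params=solver_params, cv=KFold(5))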
Define and fit the pipeline
pipe = make_pipeline(column_kernelizer, model)
pipe.fit(X_train, Y_train)
[                              ]   0% | 0.00 sec | 100 random sampling with cv |
...
[..............................] 100% | 126.07 sec | 100 random sampling with cv | 0.79 it/s, ETA: 00:00:00
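Before plotting, the fitted kernel weights can be inspected. A small sketch, assuming the fitted model stores the selected log kernel weights in a deltas_ attribute (matching the deltas output described in the solver docstring above):
deltas = backend.to_numpy(pipe[1].deltas_)  # assumed attribute, shape (n_kernels, n_targets)
fitted_kernel_weights = np.exp(deltas)
print(fitted_kernel_weights.shape)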
Plot the convergence curve
# ``cv_scores`` gives the scores for each sampled kernel weights.
# The convergence curve is thus the current maximum for each target.
cv_scores = backend.to_numpy(pipe[1].cv_scores_)
current_max = np.maximum.accumulate(cv_scores, axis=0)
mean_current_max = np.mean(current_max, axis=1)
x_array = np.arange(1, len(mean_current_max) + 1)
plt.plot(x_array, mean_current_max, '-o')
plt.grid("on")
plt.xlabel("Number of kernel weights sampled")
plt.ylabel("L2 negative loss (higher is better)")
plt.title("Convergence curve, averaged over targets")
plt.tight_layout()
plt.show()

Compare to KernelRidgeCV
Compare to a baseline KernelRidgeCV model fit on all the features concatenated. The comparison uses the prediction scores on the test set.
Fit the baseline model KernelRidgeCV
baseline = KernelRidgeCV(kernel="linear", alphas=alphas)
baseline.fit(X_train, Y_train)
Compute scores of both models
scores = pipe.score(X_test, Y_test)
scores = backend.to_numpy(scores)
scores_baseline = baseline.score(X_test, Y_test)
scores_baseline = backend.to_numpy(scores_baseline)
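Before plotting the histograms, a quick numeric summary of the comparison (a small addition, not part of the original script):
print("Mean R^2 over targets (MultipleKernelRidgeCV): %.3f" % scores.mean())
print("Mean R^2 over targets (KernelRidgeCV baseline): %.3f" % scores_baseline.mean())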
Plot histograms
bins = np.linspace(0, max(scores_baseline.max(), scores.max()), 50)
plt.hist(scores_baseline, bins, alpha=0.7, label="KernelRidgeCV")
plt.hist(scores, bins, alpha=0.7, label="MultipleKernelRidgeCV")
plt.xlabel(r"$R^2$ generalization score")
plt.title("Histogram over targets")
plt.legend()
plt.show()

Total running time of the script: (2 minutes 9.284 seconds)