Multiple-kernel ridge with scikit-learn API
This example demonstrates how to solve multiple kernel ridge regression using the scikit-learn API.
import numpy as np
import matplotlib.pyplot as plt
from himalaya.backend import set_backend
from himalaya.kernel_ridge import KernelRidgeCV
from himalaya.kernel_ridge import MultipleKernelRidgeCV
from himalaya.kernel_ridge import Kernelizer
from himalaya.kernel_ridge import ColumnKernelizer
from himalaya.utils import generate_multikernel_dataset
from sklearn.pipeline import make_pipeline
from sklearn import set_config
set_config(display='diagram')
In this example, we use the torch_cuda backend.

Torch can perform computations both on CPU and GPU. To run on CPU, use the "torch" backend; to run on GPU, use the "torch_cuda" backend.
backend = set_backend("torch_cuda", on_error="warn")
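With on_error="warn", set_backend only warns and keeps the current backend if CUDA is unavailable. A minimal sketch to check which backend is actually in use and fall back to CPU torch (assuming the returned backend module exposes a ``name`` attribute):

# Sketch (assumes the backend module has a ``name`` attribute): verify the
# selected backend and fall back to the CPU "torch" backend if needed.
if backend.name != "torch_cuda":
    backend = set_backend("torch", on_error="warn")
print("Backend in use:", backend.name)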
Generate a random dataset
X_train : array of shape (n_samples_train, n_features)
X_test : array of shape (n_samples_test, n_features)
Y_train : array of shape (n_samples_train, n_targets)
Y_test : array of shape (n_samples_test, n_targets)
(X_train, X_test, Y_train, Y_test, kernel_weights,
 n_features_list) = generate_multikernel_dataset(n_kernels=3, n_targets=50,
                                                 n_samples_train=600,
                                                 n_samples_test=300,
                                                 random_state=42)
feature_names = [f"Feature space {ii}" for ii in range(len(n_features_list))]
We could precompute the kernels by hand on Xs_train, as done in plot_mkr_random_search.py. Instead, here we use the ColumnKernelizer to make a scikit-learn Pipeline.
# Find the start and end of each feature space X in Xs
start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
slices = [
    slice(start, end)
    for start, end in zip(start_and_end[:-1], start_and_end[1:])
]
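For illustration, we can print the slices computed above. With hypothetical sizes n_features_list = [1000, 1000, 500], this would give slice(0, 1000), slice(1000, 2000), and slice(2000, 2500).

# For illustration only: print the per-feature-space slices.
print(n_features_list)
print(slices)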
Create a different Kernelizer for each feature space. Here we use a linear kernel for all feature spaces, but ColumnKernelizer accepts any Kernelizer, or a scikit-learn Pipeline ending with a Kernelizer.
kernelizers = [(name, Kernelizer(), slice_)
               for name, slice_ in zip(feature_names, slices)]
column_kernelizer = ColumnKernelizer(kernelizers)
# Note that ``ColumnKernelizer`` has a parameter ``n_jobs`` to parallelize each
# kernelizer, yet such parallelism does not work with GPU arrays.
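As a quick sanity check, we can apply the ColumnKernelizer outside of the pipeline. This is only a sketch: the exact layout of the stacked kernels is an implementation detail of himalaya, so we just inspect the output shape here.

# Sketch: kernelize the training features manually and inspect the result.
# The pipeline defined below performs this step automatically.
Ks_train = column_kernelizer.fit_transform(X_train)
print(type(Ks_train), Ks_train.shape)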
Define the model
The class takes a number of common parameters during initialization, such as kernels or solver. Since the solver parameters might be different depending on the solver, they can be passed in the solver_params parameter.
Here we use the “random_search” solver. We can check its specific parameters in the function docstring:
solver_function = MultipleKernelRidgeCV.ALL_SOLVERS["random_search"]
print("Docstring of the function %s:" % solver_function.__name__)
print(solver_function.__doc__)
Out:
Docstring of the function solve_multiple_kernel_ridge_random_search:
Solve multiple kernel ridge regression using random search.
Parameters
----------
Ks : array of shape (n_kernels, n_samples, n_samples)
Input kernels.
Y : array of shape (n_samples, n_targets)
Target data.
n_iter : int, or array of shape (n_iter, n_kernels)
Number of kernel weights combination to search.
If an array is given, the solver uses it as the list of kernel weights
to try, instead of sampling from a Dirichlet distribution. Examples:
- `n_iter=np.eye(n_kernels)` implement a winner-take-all strategy
over kernels.
- `n_iter=np.ones((1, n_kernels))/n_kernels` solves a (standard)
kernel ridge regression.
concentration : float, or list of float
Concentration parameters of the Dirichlet distribution.
If a list, iteratively cycle through the list.
Not used if n_iter is an array.
alphas : float or array of shape (n_alphas, )
Range of ridge regularization parameter.
score_func : callable
Function used to compute the score of predictions versus Y.
cv : int or scikit-learn splitter
Cross-validation splitter. If an int, KFold is used.
fit_intercept : boolean
Whether to fit an intercept. If False, Ks should be centered
(see KernelCenterer), and Y must be zero-mean over samples.
Only available if return_weights == 'dual'.
return_weights : None, 'primal', or 'dual'
Whether to refit on the entire dataset and return the weights.
Xs : array of shape (n_kernels, n_samples, n_features) or None
Necessary if return_weights == 'primal'.
local_alpha : bool
If True, alphas are selected per target, else shared over all targets.
jitter_alphas : bool
If True, alphas range is slightly jittered for each gamma.
random_state : int, or None
Random generator seed. Use an int for deterministic search.
n_targets_batch : int or None
Size of the batch for over targets during cross-validation.
Used for memory reasons. If None, uses all n_targets at once.
n_targets_batch_refit : int or None
Size of the batch for over targets during refit.
Used for memory reasons. If None, uses all n_targets at once.
n_alphas_batch : int or None
Size of the batch for over alphas. Used for memory reasons.
If None, uses all n_alphas at once.
progress_bar : bool
If True, display a progress bar over gammas.
Ks_in_cpu : bool
If True, keep Ks in CPU memory to limit GPU memory (slower).
This feature is not available through the scikit-learn API.
conservative : bool
If True, when selecting the hyperparameter alpha, take the largest one
that is less than one standard deviation away from the best.
If False, take the best.
Y_in_cpu : bool
If True, keep the target values ``Y`` in CPU memory (slower).
diagonalize_method : str in {"eigh", "svd"}
Method used to diagonalize the kernel.
return_alphas : bool
If True, return the best alpha value for each target.
Returns
-------
deltas : array of shape (n_kernels, n_targets)
Best log kernel weights for each target.
refit_weights : array or None
Refit regression weights on the entire dataset, using selected best
hyperparameters. Refit weights are always stored on CPU memory.
If return_weights == 'primal', shape is (n_features, n_targets),
if return_weights == 'dual', shape is (n_samples, n_targets),
else, None.
cv_scores : array of shape (n_iter, n_targets)
Cross-validation scores per iteration, averaged over splits, for the
best alpha. Cross-validation scores will always be on CPU memory.
best_alphas : array of shape (n_targets, )
Best alpha value per target. Only returned if return_alphas is True.
intercept : array of shape (n_targets,)
Intercept. Only returned when fit_intercept is True.
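As noted in the docstring above, ``n_iter`` can also be an array of candidate kernel weights instead of a number of random samples. A sketch of a winner-take-all search over the three kernels (not run here; the names with the ``_wta`` suffix are hypothetical):

# Sketch (not run): search only over the three "pure" kernel weight
# combinations, i.e. a winner-take-all strategy, as described above.
# solver_params_wta = dict(n_iter=np.eye(3), alphas=alphas)
# model_wta = MultipleKernelRidgeCV(kernels="precomputed",
#                                   solver="random_search",
#                                   solver_params=solver_params_wta)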
We use 100 iterations to keep this example reasonably fast (~40 sec). For better convergence, more iterations would probably be needed. Note that there is currently no stopping criterion in this method.
n_iter = 100
Grid of regularization parameters.
alphas = np.logspace(-10, 10, 41)
Batch parameters are used to reduce the necessary GPU memory. A larger value will be a bit faster, but the solver might crash if it runs out of memory. Optimal values depend on the size of your dataset.
n_targets_batch = 1000
n_alphas_batch = 20
n_targets_batch_refit = 200
solver_params = dict(n_iter=n_iter, alphas=alphas,
                     n_targets_batch=n_targets_batch,
                     n_alphas_batch=n_alphas_batch,
                     n_targets_batch_refit=n_targets_batch_refit,
                     jitter_alphas=True)
model = MultipleKernelRidgeCV(kernels="precomputed", solver="random_search",
                              solver_params=solver_params)
Define and fit the pipeline
pipe = make_pipeline(column_kernelizer, model)
pipe.fit(X_train, Y_train)
Out:
[........................................] 100% | 10.18 sec | 100 random sampling with cv |
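After fitting, the selected log kernel weights are stored on the fitted estimator. A short sketch to inspect them and compare with the ground-truth kernel_weights returned by the dataset generator (the attribute name ``deltas_`` is assumed here, matching the ``deltas`` return value documented above):

# Sketch: inspect the fitted log kernel weights, one column per target, and
# compare shapes with the ground-truth weights of the simulated dataset.
deltas = backend.to_numpy(pipe[1].deltas_)
print(deltas.shape)           # expected: (n_kernels, n_targets)
print(kernel_weights.shape)   # ground-truth weights from the simulation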
Plot the convergence curve
# ``cv_scores`` gives the cross-validation score for each sampled set of
# kernel weights. The convergence curve is thus the running maximum over
# iterations, for each target.
cv_scores = backend.to_numpy(pipe[1].cv_scores_)
current_max = np.maximum.accumulate(cv_scores, axis=0)
mean_current_max = np.mean(current_max, axis=1)
x_array = np.arange(1, len(mean_current_max) + 1)
plt.plot(x_array, mean_current_max, '-o')
plt.grid("on")
plt.xlabel("Number of kernel weights sampled")
plt.ylabel("L2 negative loss (higher is better)")
plt.title("Convergence curve, averaged over targets")
plt.tight_layout()
plt.show()
Compare to KernelRidgeCV
Compare to a baseline KernelRidgeCV model with all the concatenated features. The comparison is performed using prediction scores on the test set.
Fit the baseline model KernelRidgeCV
baseline = KernelRidgeCV(kernel="linear", alphas=alphas)
baseline.fit(X_train, Y_train)
Compute scores of both models
scores = pipe.score(X_test, Y_test)
scores = backend.to_numpy(scores)
scores_baseline = baseline.score(X_test, Y_test)
scores_baseline = backend.to_numpy(scores_baseline)
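Before plotting the full histograms, we can summarize the comparison with the average score over targets (a small sketch using the arrays computed above):

# Average R^2 generalization score over targets, for both models.
print("MultipleKernelRidgeCV: %.4f" % scores.mean())
print("KernelRidgeCV:         %.4f" % scores_baseline.mean())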
Plot histograms
bins = np.linspace(0, max(scores_baseline.max(), scores.max()), 50)
plt.hist(scores_baseline, bins, alpha=0.7, label="KernelRidgeCV")
plt.hist(scores, bins, alpha=0.7, label="MultipleKernelRidgeCV")
plt.xlabel(r"$R^2$ generalization score")
plt.title("Histogram over targets")
plt.legend()
plt.show()
Total running time of the script: ( 0 minutes 12.366 seconds)