Multiple-kernel ridge
This example demonstrates how to solve multiple-kernel ridge regression, using random search and cross-validation to select optimal hyperparameters.
import numpy as np
import matplotlib.pyplot as plt
from himalaya.backend import set_backend
from himalaya.kernel_ridge import solve_multiple_kernel_ridge_random_search
from himalaya.kernel_ridge import predict_and_score_weighted_kernel_ridge
from himalaya.utils import generate_multikernel_dataset
from himalaya.scoring import r2_score_split
from himalaya.viz import plot_alphas_diagnostic
In this example, we use the cupy backend and fit the model on GPU.
backend = set_backend("cupy", on_error="warn")
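If cupy or a GPU is not available, on_error="warn" makes set_backend fall back to the previously active backend and only emit a warning. To force a CPU run explicitly, you could instead select the numpy backend (shown commented out, as an alternative to the line above):
# backend = set_backend("numpy")  # explicit CPU fallback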
Generate a random dataset
X_train : array of shape (n_samples_train, n_features)
X_test : array of shape (n_samples_test, n_features)
Y_train : array of shape (n_samples_train, n_targets)
Y_test : array of shape (n_samples_test, n_targets)
n_kernels = 3
n_targets = 50
kernel_weights = np.tile(np.array([0.5, 0.3, 0.2])[None], (n_targets, 1))
(X_train, X_test, Y_train, Y_test,
kernel_weights, n_features_list) = generate_multikernel_dataset(
n_kernels=n_kernels, n_targets=n_targets, n_samples_train=600,
n_samples_test=300, kernel_weights=kernel_weights, random_state=42)
feature_names = [f"Feature space {ii}" for ii in range(len(n_features_list))]
# Find the start and end of each feature space X in Xs
start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
slices = [
slice(start, end)
for start, end in zip(start_and_end[:-1], start_and_end[1:])
]
Xs_train = [X_train[:, slic] for slic in slices]
Xs_test = [X_test[:, slic] for slic in slices]
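As a quick sanity check (not part of the original pipeline), we can verify that the feature-space slices partition the columns of X_train:
# Optional check: the per-space blocks should cover all columns of X_train.
for name, Xi in zip(feature_names, Xs_train):
    print(name, Xi.shape)
assert sum(Xi.shape[1] for Xi in Xs_train) == X_train.shape[1]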
Precompute the linear kernels
We also cast them to float32.
Ks_train = backend.stack([X_train @ X_train.T for X_train in Xs_train])
Ks_train = backend.asarray(Ks_train, dtype=backend.float32)
Y_train = backend.asarray(Y_train, dtype=backend.float32)
Ks_test = backend.stack(
[X_test @ X_train.T for X_train, X_test in zip(Xs_train, Xs_test)])
Ks_test = backend.asarray(Ks_test, dtype=backend.float32)
Y_test = backend.asarray(Y_test, dtype=backend.float32)
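The stacked arrays now hold one Gram matrix per feature space. A quick shape check, for illustration only:
# Expected shapes: (n_kernels, n_samples_train, n_samples_train) for Ks_train,
# (n_kernels, n_samples_test, n_samples_train) for Ks_test.
print(Ks_train.shape, Ks_test.shape)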
Run the solver, using random search
This method should work fine for a small number of kernels (< 20). The larger the number of kernels, the more densely we need to sample the hyperparameter space (i.e. the larger n_iter needs to be).
Here we use 100 iterations to keep the example reasonably fast (~40 sec). For better convergence, more iterations would probably be needed. Note that there is currently no stopping criterion in this method.
n_iter = 100
Grid of regularization parameters.
alphas = np.logspace(-10, 10, 21)
Batch parameters are used to reduce the necessary GPU memory. A larger value will be a bit faster, but the solver might crash if it runs out of memory. Optimal values depend on the size of your dataset.
n_targets_batch = 1000
n_alphas_batch = 20
If return_weights == "dual"
, the solver will use more memory.
To mitigate this, you can reduce n_targets_batch
in the refit
using `n_targets_batch_refit
.
If you don’t need the dual weights, use return_weights = None
.
return_weights = 'dual'
n_targets_batch_refit = 200
Run the solver. For each iteration, it will:
- sample kernel weights from a Dirichlet distribution
- fit (n_splits * n_alphas * n_targets) ridge models
- compute the scores on the validation set of each split
- average the scores over splits
- take the maximum over alphas
- (only if you ask for the ridge weights) refit using the best alphas per target and the entire dataset
- return for each target the log kernel weights leading to the best CV score (and the best weights if necessary)
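To make the procedure concrete, here is a minimal NumPy sketch of a single random-search iteration. It only illustrates the idea and reuses the variables defined above; it is not himalaya's actual implementation:
# Conceptual sketch of one iteration (illustration only, computed on CPU).
rng = np.random.RandomState(0)
gammas = rng.dirichlet(np.ones(n_kernels))  # sampled kernel weights, summing to 1
K_combined = sum(g * K for g, K in zip(gammas, backend.to_numpy(Ks_train)))
# For each alpha in the grid, kernel ridge would then be solved on each
# cross-validation split of K_combined, scored on the held-out samples,
# and the best alpha kept independently for each target.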
results = solve_multiple_kernel_ridge_random_search(
Ks=Ks_train,
Y=Y_train,
n_iter=n_iter,
alphas=alphas,
n_targets_batch=n_targets_batch,
return_weights=return_weights,
n_alphas_batch=n_alphas_batch,
n_targets_batch_refit=n_targets_batch_refit,
jitter_alphas=True,
)
Out:
[                                        ] 0% | 0.00 sec | 100 random sampling with cv |
...
[........................................] 100% | 14.20 sec | 100 random sampling with cv |
As we used the cupy backend, the results are cupy arrays, which are on GPU. Here, we cast the results back to CPU, and to numpy arrays.
deltas = backend.to_numpy(results[0])
dual_weights = backend.to_numpy(results[1])
cv_scores = backend.to_numpy(results[2])
Plot the convergence curve
cv_scores gives the cross-validation scores for each set of sampled kernel weights. The convergence curve is thus the running maximum over iterations, for each target.
current_max = np.maximum.accumulate(cv_scores, axis=0)
mean_current_max = np.mean(current_max, axis=1)
x_array = np.arange(1, len(mean_current_max) + 1)
plt.plot(x_array, mean_current_max, '-o')
plt.grid("on")
plt.xlabel("Number of kernel weights sampled")
plt.ylabel("L2 negative loss (higher is better)")
plt.title("Convergence curve, averaged over targets")
plt.tight_layout()
plt.show()
Plot the optimal alphas selected by the solver
This plot is helpful to refine the alpha grid if the range is too small or too large. Since the kernel weights exp(deltas) absorb the regularization, the effective alpha for each target is the inverse of their sum.
best_alphas = 1. / np.sum(np.exp(deltas), axis=0)
plot_alphas_diagnostic(best_alphas, alphas)
plt.title("Best alphas selected by cross-validation")
plt.show()
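If the diagnostic shows many best alphas sitting at the boundary of the grid, the grid can be widened before rerunning the solver; the bounds below are only an arbitrary illustration:
# Hypothetical wider grid (commented out); only useful if the diagnostic
# suggests the current range is too narrow.
# alphas = np.logspace(-15, 15, 31)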
Compute the predictions on the test set
(requires the dual weights)
split = False
scores = predict_and_score_weighted_kernel_ridge(
Ks_test, dual_weights, deltas, Y_test, split=split,
n_targets_batch=n_targets_batch, score_func=r2_score_split)
scores = backend.to_numpy(scores)
plt.hist(scores, np.linspace(0, 1, 50))
plt.xlabel(r"$R^2$ generalization score")
plt.title("Histogram over targets")
plt.show()
Compute the split predictions on the test set
(requires the dual weights)
Here we apply the dual weights on each kernel separately (exp(deltas[i]) * kernel[i]), and we compute the R2 scores (corrected for correlations) of each prediction.
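For illustration, under the parameterization above, the split prediction for a single kernel can be written directly from the dual weights and deltas (a sketch, not a call into the library):
# Sketch: split prediction of kernel kk, scaled by the exponentiated kernel
# weights of each target. Resulting shape: (n_samples_test, n_targets).
kk = 0
Y_hat_kk = backend.to_numpy(Ks_test[kk]) @ dual_weights * np.exp(deltas[kk])[None]
print(Y_hat_kk.shape)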
split = True
scores_split = predict_and_score_weighted_kernel_ridge(
Ks_test, dual_weights, deltas, Y_test, split=split,
n_targets_batch=n_targets_batch, score_func=r2_score_split)
scores_split = backend.to_numpy(scores_split)
for kk, score in enumerate(scores_split):
plt.hist(score, np.linspace(0, np.max(scores_split), 50), alpha=0.7,
label="kernel %d" % kk)
plt.title(r"Histogram of $R^2$ generalization score split between kernels")
plt.legend()
plt.show()
Total running time of the script: (0 minutes 14.633 seconds)