.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_auto_examples/multiple_kernel_ridge/plot_mkr_0_random_search.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__auto_examples_multiple_kernel_ridge_plot_mkr_0_random_search.py:


Multiple-kernel ridge
=====================

This example demonstrates how to solve multiple-kernel ridge regression.
It uses random search and cross-validation to select the hyperparameters.

.. GENERATED FROM PYTHON SOURCE LINES 7-18

.. code-block:: default


    import numpy as np
    import matplotlib.pyplot as plt

    from himalaya.backend import set_backend
    from himalaya.kernel_ridge import solve_multiple_kernel_ridge_random_search
    from himalaya.kernel_ridge import predict_and_score_weighted_kernel_ridge
    from himalaya.utils import generate_multikernel_dataset
    from himalaya.scoring import r2_score_split
    from himalaya.viz import plot_alphas_diagnostic

.. GENERATED FROM PYTHON SOURCE LINES 20-21

In this example, we use the ``cupy`` backend to fit the model on GPU. With
``on_error="warn"``, a failure to load ``cupy`` only triggers a warning, and
the previously selected backend is kept.

.. GENERATED FROM PYTHON SOURCE LINES 21-24

.. code-block:: default


    backend = set_backend("cupy", on_error="warn")

.. GENERATED FROM PYTHON SOURCE LINES 25-32

Generate a random dataset
-------------------------

Each kernel corresponds to one feature space, i.e. one block of columns of
``X_train`` and ``X_test``.

- X_train : array of shape (n_samples_train, n_features)
- X_test : array of shape (n_samples_test, n_features)
- Y_train : array of shape (n_samples_train, n_targets)
- Y_test : array of shape (n_samples_test, n_targets)

.. GENERATED FROM PYTHON SOURCE LINES 32-53

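The internals of ``generate_multikernel_dataset`` are not shown in this example. As a rough illustration of such a dataset, here is a minimal NumPy sketch. The function ``toy_multikernel_dataset`` is hypothetical and is not ``himalaya``'s generator: it draws one random feature space per kernel, and mixes their linear contributions into the targets so that each feature space contributes a variance of roughly its kernel weight. It returns arrays with the shapes listed above.

.. code-block:: default

    import numpy as np


    def toy_multikernel_dataset(n_features_list, n_targets, n_samples_train,
                                n_samples_test, kernel_weights, noise=0.1,
                                random_state=0):
        """Hypothetical stand-in for a multiple-kernel dataset generator.

        kernel_weights has shape (n_targets, n_kernels); each row gives the
        relative contribution of each feature space to one target.
        """
        rng = np.random.default_rng(random_state)
        n_samples = n_samples_train + n_samples_test
        # One random feature space per kernel
        Xs = [rng.standard_normal((n_samples, p)) for p in n_features_list]
        Y = np.zeros((n_samples, n_targets))
        for X, w in zip(Xs, kernel_weights.T):
            # Random linear map with unit-variance output, scaled by sqrt(w)
            # so that this feature space contributes a variance of about w
            beta = rng.standard_normal((X.shape[1], n_targets))
            Y += np.sqrt(w) * (X @ beta) / np.sqrt(X.shape[1])
        Y += noise * rng.standard_normal(Y.shape)
        X = np.concatenate(Xs, axis=1)  # feature spaces side by side
        return (X[:n_samples_train], X[n_samples_train:],
                Y[:n_samples_train], Y[n_samples_train:])


    weights = np.tile(np.array([0.5, 0.3, 0.2])[None], (4, 1))
    X_tr, X_te, Y_tr, Y_te = toy_multikernel_dataset(
        [10, 20, 30], n_targets=4, n_samples_train=60, n_samples_test=30,
        kernel_weights=weights)
    print(X_tr.shape, X_te.shape, Y_tr.shape, Y_te.shape)
    # (60, 60) (30, 60) (60, 4) (30, 4)

The number of features is the sum of ``n_features_list``, which is why slicing the columns by the cumulative feature counts recovers each feature space.
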
.. code-block:: default


    n_kernels = 3
    n_targets = 50
    kernel_weights = np.tile(np.array([0.5, 0.3, 0.2])[None], (n_targets, 1))

    (X_train, X_test, Y_train, Y_test,
     kernel_weights, n_features_list) = generate_multikernel_dataset(
         n_kernels=n_kernels, n_targets=n_targets, n_samples_train=600,
         n_samples_test=300, kernel_weights=kernel_weights, random_state=42)

    feature_names = [f"Feature space {ii}" for ii in range(len(n_features_list))]

    # Find the start and end of each feature space X in Xs
    start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
    slices = [
        slice(start, end)
        for start, end in zip(start_and_end[:-1], start_and_end[1:])
    ]
    Xs_train = [X_train[:, slic] for slic in slices]
    Xs_test = [X_test[:, slic] for slic in slices]

.. GENERATED FROM PYTHON SOURCE LINES 54-57

Precompute the linear kernels
-----------------------------

We also cast the kernels and the targets to float32.

.. GENERATED FROM PYTHON SOURCE LINES 57-67

.. code-block:: default


    # Training kernels, shape (n_kernels, n_samples_train, n_samples_train)
    Ks_train = backend.stack([X @ X.T for X in Xs_train])
    Ks_train = backend.asarray(Ks_train, dtype=backend.float32)
    Y_train = backend.asarray(Y_train, dtype=backend.float32)

    # Test kernels, shape (n_kernels, n_samples_test, n_samples_train)
    Ks_test = backend.stack(
        [X_te @ X_tr.T for X_tr, X_te in zip(Xs_train, Xs_test)])
    Ks_test = backend.asarray(Ks_test, dtype=backend.float32)
    Y_test = backend.asarray(Y_test, dtype=backend.float32)

.. GENERATED FROM PYTHON SOURCE LINES 68-73

Run the solver, using random search
-----------------------------------

This method should work well for a small number of kernels (fewer than 20).
The larger the number of kernels, the more densely we need to sample the
hyperparameter space (i.e. the larger ``n_iter`` must be).

.. GENERATED FROM PYTHON SOURCE LINES 75-78

Here we use 100 iterations to keep the example reasonably fast (the solver
call took about 14 seconds when this page was generated). Better convergence
would probably require more iterations. Note that this method currently has
no stopping criterion: it always runs all ``n_iter`` iterations.

.. GENERATED FROM PYTHON SOURCE LINES 78-80

.. code-block:: default


    n_iter = 100

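To see why more kernels call for a larger ``n_iter``, the sketch below draws ``n_iter`` random kernel-weight vectors and measures how many land close to a fixed target weighting. It assumes a flat Dirichlet distribution (all concentration parameters equal to one), which may differ from the solver's actual sampling settings, and it uses an illustrative closeness criterion (every weight within 50% of its target value). The function ``fraction_near_target`` is hypothetical.

.. code-block:: default

    import numpy as np


    def fraction_near_target(target, n_iter=100, rel_tol=0.5, random_state=0):
        """Fraction of random kernel weightings close to a target weighting.

        Samples are drawn from a flat Dirichlet distribution (an assumption).
        A sample counts as close if every weight is within rel_tol * target[i]
        of target[i] (an illustrative criterion).
        """
        rng = np.random.default_rng(random_state)
        target = np.asarray(target, dtype=float)
        # Shape (n_iter, n_kernels); each row lies on the simplex
        samples = rng.dirichlet(np.ones(len(target)), size=n_iter)
        close = np.all(np.abs(samples - target) <= rel_tol * target, axis=1)
        return close.mean()


    frac_3 = fraction_near_target([0.5, 0.3, 0.2])   # 3 kernels
    frac_10 = fraction_near_target(np.full(10, 0.1))  # 10 kernels
    print(f"3 kernels: {frac_3:.2f}, 10 kernels: {frac_10:.2f}")

With these settings, the 3-kernel fraction is typically on the order of one in ten, while the 10-kernel fraction is typically zero or close to it. Each additional kernel adds one more condition that a random sample must meet, so the number of samples needed grows quickly with the number of kernels.
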
.. GENERATED FROM PYTHON SOURCE LINES 81-82

Define the grid of regularization parameters.

.. GENERATED FROM PYTHON SOURCE LINES 82-84

.. code-block:: default


    alphas = np.logspace(-10, 10, 21)

.. GENERATED FROM PYTHON SOURCE LINES 85-88

Batch parameters reduce the amount of GPU memory required. Larger values are
slightly faster, but the solver may crash if it runs out of memory. The best
values depend on the size of your dataset.

.. GENERATED FROM PYTHON SOURCE LINES 88-91

.. code-block:: default


    n_targets_batch = 1000
    n_alphas_batch = 20

.. GENERATED FROM PYTHON SOURCE LINES 92-96

If ``return_weights == "dual"``, the solver uses more memory. To mitigate
this, you can reduce the number of targets per batch during the refit with
``n_targets_batch_refit``. If you do not need the dual weights, use
``return_weights = None``.

.. GENERATED FROM PYTHON SOURCE LINES 96-99

.. code-block:: default


    return_weights = 'dual'
    n_targets_batch_refit = 200

.. GENERATED FROM PYTHON SOURCE LINES 100-111

Run the solver. For each iteration, it will:

- sample kernel weights from a Dirichlet distribution
- fit ``n_splits * n_alphas * n_targets`` ridge models
- compute the scores on the validation set of each split
- average the scores over splits
- take the maximum over alphas

After all iterations, it will:

- (only if you ask for the ridge weights) refit using the best alphas per
  target and the entire dataset
- return, for each target, the log kernel weights leading to the best CV
  score (and the best weights if requested)

.. GENERATED FROM PYTHON SOURCE LINES 111-123

.. code-block:: default


    results = solve_multiple_kernel_ridge_random_search(
        Ks=Ks_train,
        Y=Y_train,
        n_iter=n_iter,
        alphas=alphas,
        n_targets_batch=n_targets_batch,
        return_weights=return_weights,
        n_alphas_batch=n_alphas_batch,
        n_targets_batch_refit=n_targets_batch_refit,
        jitter_alphas=True,
    )


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    [........................................] 100% | 14.20 sec | 100 random sampling with cv |


.. GENERATED FROM PYTHON SOURCE LINES 124-126

As we used the ``cupy`` backend, the results are ``cupy`` arrays stored on
GPU. Here, we move the results back to CPU and convert them to ``numpy``
arrays.

.. GENERATED FROM PYTHON SOURCE LINES 126-130

.. code-block:: default


    # Log kernel weights, shape (n_kernels, n_targets)
    deltas = backend.to_numpy(results[0])
    # Dual weights, shape (n_samples_train, n_targets)
    dual_weights = backend.to_numpy(results[1])
    # Cross-validation scores of each iteration, shape (n_iter, n_targets)
    cv_scores = backend.to_numpy(results[2])

.. GENERATED FROM PYTHON SOURCE LINES 131-136

Plot the convergence curve
--------------------------

``cv_scores`` contains the cross-validation score of each sampled set of
kernel weights, for each target. The convergence curve is thus the running
maximum of these scores over iterations, computed for each target.

.. GENERATED FROM PYTHON SOURCE LINES 136-148

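For intuition about where ``cv_scores`` comes from, here is a simplified NumPy re-implementation of the random-search loop, for a small dataset held in memory. It is a sketch under several assumptions, not ``himalaya``'s implementation: kernel weights are drawn from a flat Dirichlet distribution, each candidate is scored with K-fold cross-validation using the negative mean squared error (chosen to match the plot's "L2 negative loss" label), and there is no batching, alpha jittering, or refit. The function ``toy_random_search`` is hypothetical. Consistently with the formula this example uses for its best-alphas plot, it encodes each target's best kernel weights and alpha as ``deltas = log(weights / alpha)``.

.. code-block:: default

    import numpy as np


    def toy_random_search(Ks, Y, alphas, n_iter=10, n_splits=3, random_state=0):
        """Hypothetical random search over kernel weights (no batching, no refit).

        Ks : (n_kernels, n_samples, n_samples) precomputed kernels
        Y : (n_samples, n_targets) targets
        Returns cv_scores, shape (n_iter, n_targets), and the best
        deltas = log(weights / alpha), shape (n_kernels, n_targets).
        """
        rng = np.random.default_rng(random_state)
        alphas = np.asarray(alphas, dtype=float)
        n_kernels, n_samples, _ = Ks.shape
        n_targets = Y.shape[1]
        folds = np.array_split(rng.permutation(n_samples), n_splits)
        cv_scores = np.zeros((n_iter, n_targets))
        best_scores = np.full(n_targets, -np.inf)
        deltas = np.zeros((n_kernels, n_targets))
        for it in range(n_iter):
            # Sample kernel weights on the simplex and combine the kernels
            weights = rng.dirichlet(np.ones(n_kernels))
            K = np.tensordot(weights, Ks, axes=1)
            scores = np.zeros((len(alphas), n_targets))
            for val in folds:
                train = np.setdiff1d(np.arange(n_samples), val)
                K_train = K[np.ix_(train, train)]
                K_val = K[np.ix_(val, train)]
                for ia, alpha in enumerate(alphas):
                    # Kernel ridge in dual form, all targets at once
                    dual = np.linalg.solve(K_train + alpha * np.eye(len(train)),
                                           Y[train])
                    error = ((Y[val] - K_val @ dual) ** 2).mean(axis=0)
                    # Average of the negative MSE over splits
                    scores[ia] -= error / n_splits
            # Best alpha for each target, then keep the best iteration so far
            best_ia = scores.argmax(axis=0)
            cv_scores[it] = scores.max(axis=0)
            improved = cv_scores[it] > best_scores
            best_scores[improved] = cv_scores[it, improved]
            deltas[:, improved] = (np.log(weights)[:, None]
                                   - np.log(alphas[best_ia[improved]])[None, :])
        return cv_scores, deltas


    rng = np.random.default_rng(1)
    Xs_toy = [rng.standard_normal((60, p)) for p in (5, 8, 3)]
    Ks_toy = np.stack([X @ X.T for X in Xs_toy])
    Y_toy = rng.standard_normal((60, 4))
    alpha_grid = np.logspace(-2, 4, 7)
    scores_toy, deltas_toy = toy_random_search(Ks_toy, Y_toy, alpha_grid)

    # Convergence curve: running maximum over iterations, for each target
    current_max_toy = np.maximum.accumulate(scores_toy, axis=0)
    # Best alpha for each target, recovered from the deltas
    best_alphas_toy = 1. / np.sum(np.exp(deltas_toy), axis=0)

Because the running maximum can only stay constant or increase, the convergence curve is non-decreasing by construction; a flat tail suggests that further sampling brings little improvement.
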
.. code-block:: default


    current_max = np.maximum.accumulate(cv_scores, axis=0)
    mean_current_max = np.mean(current_max, axis=1)
    x_array = np.arange(1, len(mean_current_max) + 1)
    plt.plot(x_array, mean_current_max, '-o')
    plt.grid(True)
    plt.xlabel("Number of kernel weights sampled")
    plt.ylabel("L2 negative loss (higher is better)")
    plt.title("Convergence curve, averaged over targets")
    plt.tight_layout()
    plt.show()



.. image:: /_auto_examples/multiple_kernel_ridge/images/sphx_glr_plot_mkr_0_random_search_001.png
    :alt: Convergence curve, averaged over targets
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 149-154

Plot the optimal alphas selected by the solver
----------------------------------------------

This plot helps refine the alpha grid if its range is too narrow or too wide.

.. GENERATED FROM PYTHON SOURCE LINES 154-160

.. code-block:: default


    # exp(deltas) are the kernel weights divided by alpha; since the sampled
    # kernel weights sum to one, summing over kernels gives 1 / alpha.
    best_alphas = 1. / np.sum(np.exp(deltas), axis=0)
    plot_alphas_diagnostic(best_alphas, alphas)
    plt.title("Best alphas selected by cross-validation")
    plt.show()



.. image:: /_auto_examples/multiple_kernel_ridge/images/sphx_glr_plot_mkr_0_random_search_002.png
    :alt: Best alphas selected by cross-validation
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 161-164

Compute the predictions on the test set
---------------------------------------

This step requires the dual weights (``return_weights = 'dual'``).

.. GENERATED FROM PYTHON SOURCE LINES 164-176

.. code-block:: default


    split = False
    scores = predict_and_score_weighted_kernel_ridge(
        Ks_test, dual_weights, deltas, Y_test, split=split,
        n_targets_batch=n_targets_batch, score_func=r2_score_split)
    scores = backend.to_numpy(scores)

    plt.hist(scores, np.linspace(0, 1, 50))
    plt.xlabel(r"$R^2$ generalization score")
    plt.title("Histogram over targets")
    plt.show()



.. image:: /_auto_examples/multiple_kernel_ridge/images/sphx_glr_plot_mkr_0_random_search_003.png
    :alt: Histogram over targets
    :class: sphx-glr-single-img

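The prediction step can be sketched in plain NumPy. The sketch assumes, consistently with the split description in the next section, that each kernel is scaled by its per-target weight ``exp(deltas[i])`` before the dual weights are applied, so that the full prediction is the sum of the per-kernel predictions. The helpers ``weighted_kernel_predict`` and ``r2_per_target`` are hypothetical, and ``r2_per_target`` computes a plain R\ :sup:`2` without the correlation correction of ``r2_score_split``.

.. code-block:: default

    import numpy as np


    def weighted_kernel_predict(Ks_test, dual_weights, deltas, split=False):
        """Hypothetical sketch of weighted multiple-kernel ridge predictions.

        Ks_test : (n_kernels, n_samples_test, n_samples_train)
        dual_weights : (n_samples_train, n_targets)
        deltas : (n_kernels, n_targets), log kernel weights
        """
        # per_kernel[k, :, t] = exp(deltas[k, t]) * Ks_test[k] @ dual_weights[:, t]
        per_kernel = np.einsum("kt,knm,mt->knt", np.exp(deltas), Ks_test,
                               dual_weights)
        return per_kernel if split else per_kernel.sum(axis=0)


    def r2_per_target(Y_true, Y_pred):
        """Plain R^2 for each target (no correction for correlations)."""
        ss_res = ((Y_true - Y_pred) ** 2).sum(axis=0)
        ss_tot = ((Y_true - Y_true.mean(axis=0)) ** 2).sum(axis=0)
        return 1 - ss_res / ss_tot


    rng = np.random.default_rng(0)
    Ks_te = rng.standard_normal((3, 20, 50))  # 3 kernels, 20 test, 50 train samples
    dual = rng.standard_normal((50, 4))        # 4 targets
    log_w = np.log(rng.dirichlet(np.ones(3), size=4).T)  # shape (3, 4)

    Y_pred = weighted_kernel_predict(Ks_te, dual, log_w)
    Y_pred_split = weighted_kernel_predict(Ks_te, dual, log_w, split=True)
    Y_noisy = Y_pred + 0.1 * rng.standard_normal(Y_pred.shape)
    print(r2_per_target(Y_noisy, Y_pred))

By construction, summing ``Y_pred_split`` over its first axis gives back ``Y_pred``; the per-kernel terms are what the split predictions of the next section isolate.
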
.. GENERATED FROM PYTHON SOURCE LINES 177-184

Compute the split predictions on the test set
---------------------------------------------

This step also requires the dual weights.

Here we apply the dual weights to each kernel separately
(``exp(deltas[i]) * kernel[i]``), and we compute the R\ :sup:`2` scores
(corrected for correlations) of each prediction.

.. GENERATED FROM PYTHON SOURCE LINES 184-197

.. code-block:: default


    split = True
    scores_split = predict_and_score_weighted_kernel_ridge(
        Ks_test, dual_weights, deltas, Y_test, split=split,
        n_targets_batch=n_targets_batch, score_func=r2_score_split)
    scores_split = backend.to_numpy(scores_split)

    for kk, score in enumerate(scores_split):
        plt.hist(score, np.linspace(0, np.max(scores_split), 50), alpha=0.7,
                 label="kernel %d" % kk)
    plt.title(r"Histogram of $R^2$ generalization score split between kernels")
    plt.legend()
    plt.show()



.. image:: /_auto_examples/multiple_kernel_ridge/images/sphx_glr_plot_mkr_0_random_search_004.png
    :alt: Histogram of $R^2$ generalization score split between kernels
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 14.633 seconds)


.. _sphx_glr_download__auto_examples_multiple_kernel_ridge_plot_mkr_0_random_search.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_mkr_0_random_search.py `


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_mkr_0_random_search.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_