Model descriptions¶
This page describes the models implemented in this package.
Ridge¶
Let \(X\in \mathbb{R}^{n\times p}\) be a feature matrix with \(n\) samples and \(p\) features, \(y\in \mathbb{R}^n\) a target vector, and \(\alpha > 0\) a fixed regularization hyperparameter. Ridge regression [1] defines the weight vector \(b^*\in \mathbb{R}^p\) as:

\[b^* = \arg\min_b \|Xb - y\|_2^2 + \alpha \|b\|_2^2.\]

This problem has a closed-form solution \(b^* = M y\), where \(M = (X^\top X + \alpha I_p)^{-1}X^\top \in \mathbb{R}^{p \times n}\).
This model is implemented in:

- `Ridge` (scikit-learn-compatible estimator)
- `solve_ridge_svd()` (function)
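For illustration, here is a minimal NumPy sketch of this closed-form solution. It is not the package implementation: as its name suggests, `solve_ridge_svd()` solves the problem through an SVD of \(X\) instead, which avoids forming \(X^\top X\) explicitly.

```python
import numpy as np

def ridge_closed_form(X, y, alpha=1.0):
    """Closed-form ridge solution b* = (X^T X + alpha I_p)^{-1} X^T y."""
    n_features = X.shape[1]
    A = X.T @ X + alpha * np.eye(n_features)
    return np.linalg.solve(A, X.T @ y)

# Toy example with random data.
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 10))
y = rng.standard_normal(100)
b = ridge_closed_form(X, y, alpha=1.0)
print(b.shape)  # (10,)
```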
KernelRidge¶
By the Woodbury matrix identity, \(b^*\) can be written as \(b^* = X^\top(XX^\top + \alpha I_n)^{-1}y\), that is, \(b^* = X^\top w^*\) for some \(w^*\in \mathbb{R}^n\). Writing the linear kernel as \(K = X X^\top \in \mathbb{R}^{n\times n}\), this leads to the equivalent formulation:

\[w^* = (K + \alpha I_n)^{-1} y.\]
This model can be extended to arbitrary positive semidefinite kernels \(K\), leading to the more general kernel ridge regression [2].
This model is implemented in:

- `KernelRidge` (scikit-learn-compatible estimator)
- `solve_kernel_ridge_eigenvalues()` (function)
- `solve_kernel_ridge_gradient_descent()` (function)
- `solve_kernel_ridge_conjugate_gradient()` (function)
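A short NumPy sketch (for illustration only) checking that the primal and dual solutions coincide for a linear kernel:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, alpha = 50, 10, 1.0
X = rng.standard_normal((n_samples, n_features))
y = rng.standard_normal(n_samples)

# Primal solution: b* = (X^T X + alpha I_p)^{-1} X^T y.
b_primal = np.linalg.solve(X.T @ X + alpha * np.eye(n_features), X.T @ y)

# Dual solution: w* = (K + alpha I_n)^{-1} y, then b* = X^T w*.
K = X @ X.T
w = np.linalg.solve(K + alpha * np.eye(n_samples), y)
b_dual = X.T @ w

print(np.allclose(b_primal, b_dual))  # True, by the Woodbury identity
```

The dual formulation is useful when \(n < p\), since it requires solving an \(n \times n\) linear system instead of a \(p \times p\) one.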
RidgeCV and KernelRidgeCV¶
In practice, because the ridge regression and kernel ridge regression hyperparameter \(\alpha\) is unknown, it is typically selected through a grid search with cross-validation. In cross-validation, we split the dataset into a training set \((X_{train}, y_{train})\) and a validation set \((X_{val}, y_{val})\). We then train the model on the training set, and evaluate its generalization performance on the validation set. We repeat this process for multiple hyperparameter candidates \(\alpha\), typically defined over a grid of log-spaced values. Finally, we keep the candidate leading to the best generalization performance, as measured by the validation loss averaged over all cross-validation splits.
These models are implemented in:

- `RidgeCV` (scikit-learn-compatible estimator)
- `solve_ridge_cv_svd()` (function)
- `KernelRidgeCV` (scikit-learn-compatible estimator)
- `solve_kernel_ridge_cv_eigenvalues()` (function)
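For instance, here is a minimal usage sketch with `RidgeCV`, assuming (as in the package examples) that the grid of candidates is passed via `alphas` and the selected value for each target is exposed as `best_alphas_`:

```python
import numpy as np
from himalaya.ridge import RidgeCV

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 10))
Y = rng.standard_normal((100, 2))  # two targets

# Grid of log-spaced hyperparameter candidates.
alphas = np.logspace(-4, 4, 9)

model = RidgeCV(alphas=alphas, cv=5)
model.fit(X, Y)
print(model.best_alphas_)  # best candidate per target
```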
GroupRidgeCV / BandedRidgeCV¶
In some applications, features are naturally organized into groups (or feature spaces). To adapt the regularization level to each feature space, ridge regression can be extended to group-regularized ridge regression (also known as banded ridge regression [3]). In this model, a separate hyperparameter \(\alpha_i\) is optimized for each feature space \(X_i\):

\[b^* = \arg\min_b \Big\|\sum_i X_i b_i - y\Big\|_2^2 + \sum_i \alpha_i \|b_i\|_2^2.\]
This is equivalent to solving a ridge regression:

\[b^* = \arg\min_b \Big\|\sum_i Z_i b_i - y\Big\|_2^2 + \|b\|_2^2,\]

where each feature space \(X_i\) is scaled by a group scaling \(Z_i = e^{\delta_i} X_i\). The hyperparameters \(\delta_i = - \log(\alpha_i)\) are then learned over cross-validation [4].
This model is implemented in:

- `GroupRidgeCV` (scikit-learn-compatible estimator)
- `solve_group_ridge_random_search()` (function)
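A minimal NumPy sketch of this rescaling trick, with two hypothetical feature spaces and fixed group scalings (for illustration only; `GroupRidgeCV` learns the scalings over cross-validation):

```python
import numpy as np

rng = np.random.default_rng(0)
X1 = rng.standard_normal((100, 5))   # first feature space
X2 = rng.standard_normal((100, 20))  # second feature space
y = rng.standard_normal(100)
deltas = np.array([0.5, -1.0])       # fixed group scalings (log scale)

# Scale each feature space by e^{delta_i}, then concatenate.
Z = np.hstack([np.exp(deltas[0]) * X1, np.exp(deltas[1]) * X2])

# A single ridge regression on Z then adapts the regularization per group.
b = np.linalg.solve(Z.T @ Z + np.eye(Z.shape[1]), Z.T @ y)
```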
See also multiple-kernel ridge regression, which is equivalent to group-regularized ridge regression when using one linear kernel per group of features:

- `MultipleKernelRidgeCV` (scikit-learn-compatible estimator)
- `solve_multiple_kernel_ridge_random_search()` (function)
- `solve_multiple_kernel_ridge_hyper_gradient()` (function)
Note: “Group ridge regression” is also sometimes called “banded ridge regression”.
WeightedKernelRidge¶
To extend kernel ridge regression to grouped features, we can compute the kernel as a weighted sum of multiple kernels, \(K = \sum_{i=1}^m e^{\delta_i} K_i\), using for instance \(K_i = X_i X_i^\top\) for each group of features \(X_i\). The model becomes:

\[w^* = \Big(\sum_{i=1}^m e^{\delta_i} K_i + \alpha I_n\Big)^{-1} y.\]

This model is called weighted kernel ridge regression. Here, the log-kernel-weights \(\delta_i\) are fixed. When all targets use the same log-kernel-weights, a single weighted kernel can be precomputed and used in a standard kernel ridge regression. However, when the log-kernel-weights differ across targets, the kernel sum cannot be precomputed, and the model requires specific algorithms to be fit.
This model is implemented in:

- `WeightedKernelRidge` (scikit-learn-compatible estimator)
- `solve_weighted_kernel_ridge_neumann_series()` (function)
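When the log-kernel-weights are shared across targets, the precomputation is straightforward. A minimal NumPy sketch with two hypothetical feature spaces:

```python
import numpy as np

rng = np.random.default_rng(0)
X1 = rng.standard_normal((80, 5))
X2 = rng.standard_normal((80, 30))
y = rng.standard_normal(80)
deltas = np.array([0.0, -2.0])  # fixed log-kernel-weights
alpha = 1.0

# One linear kernel per feature space, combined with weights e^{delta_i}.
K = np.exp(deltas[0]) * (X1 @ X1.T) + np.exp(deltas[1]) * (X2 @ X2.T)

# Standard kernel ridge solve on the weighted kernel sum.
w = np.linalg.solve(K + alpha * np.eye(80), y)
```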
MultipleKernelRidgeCV¶
In weighted kernel ridge regression, when the log-kernel-weights \(\delta_i\) are unknown, we can learn them over cross-validation. This model is called multiple-kernel ridge regression. When the kernels are defined by \(K_i = X_i X_i^\top\) for different groups of features \(X_i\), multiple-kernel ridge regression is equivalent to group-regularized ridge regression (aka banded ridge regression).
This model is implemented in:

- `MultipleKernelRidgeCV` (scikit-learn-compatible estimator)
- `solve_multiple_kernel_ridge_hyper_gradient()` (function)
- `solve_multiple_kernel_ridge_random_search()` (function)
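A hedged usage sketch with precomputed kernels, following the package examples: the input is assumed to be a stack of kernels of shape `(n_kernels, n_samples, n_samples)`, the learned log-kernel-weights are exposed as `deltas_`, and `n_iter` is a parameter of the random-search solver:

```python
import numpy as np
from himalaya.kernel_ridge import MultipleKernelRidgeCV

rng = np.random.default_rng(0)
X1 = rng.standard_normal((100, 5))
X2 = rng.standard_normal((100, 30))
Y = rng.standard_normal((100, 3))  # three targets

# Stack one linear kernel per feature space.
Ks = np.stack([X1 @ X1.T, X2 @ X2.T])

model = MultipleKernelRidgeCV(kernels="precomputed",
                              solver="random_search",
                              solver_params=dict(n_iter=10))
model.fit(Ks, Y)
print(model.deltas_.shape)  # (n_kernels, n_targets)
```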
Model flowchart¶
The following flowchart can be used as a guide to select the right estimator.
graph TD;
A(How many feature spaces?)
O(Data size?)
M(Data size?)
OR(Hyperparameters?)
OK(Hyperparameters?)
MR(Hyperparameters?)
MK(Hyperparameters?)
A-- one-->O;
A--multiple-->M;
O--more samples-->OR;
O--more features-->OK;
M--more samples-->MR;
M--more features-->MK;
OK--known-->OKH[KernelRidge];
OK--unknown-->OKCV[KernelRidgeCV];
OR--known-->ORH[Ridge];
OR--unknown-->ORCV[RidgeCV];
MK--known-->MKH[WeightedKernelRidge];
MK--unknown-->MKCV[MultipleKernelRidgeCV];
MR--unknown-->MRCV[BandedRidgeCV];
MR--known-->MKH;
classDef fork fill:#FFDC97
class A,O,M,OR,OK,MR,MK fork;
classDef leaf fill:#ABBBE1
class ORH,OKH,MKH leaf;
class ORCV,OKCV,MRCV,MKCV leaf;
click ORH "https://gallantlab.github.io/himalaya/_generated/himalaya.ridge.Ridge.html"
click ORCV "https://gallantlab.github.io/himalaya/_generated/himalaya.ridge.RidgeCV.html"
click MRCV "https://gallantlab.github.io/himalaya/_generated/himalaya.ridge.BandedRidgeCV.html"
click OKH "https://gallantlab.github.io/himalaya/_generated/himalaya.kernel_ridge.KernelRidge.html"
click OKCV "https://gallantlab.github.io/himalaya/_generated/himalaya.kernel_ridge.KernelRidgeCV.html"
click MKH "https://gallantlab.github.io/himalaya/_generated/himalaya.kernel_ridge.WeightedKernelRidge.html"
click MKCV "https://gallantlab.github.io/himalaya/_generated/himalaya.kernel_ridge.MultipleKernelRidgeCV.html"