himalaya.kernel_ridge.solve_weighted_kernel_ridge_neumann_series

himalaya.kernel_ridge.solve_weighted_kernel_ridge_neumann_series(Ks, Y, deltas, alpha=1.0, fit_intercept=False, max_iter=10, factor=0.0001, n_targets_batch=None, tol=None, random_state=None, debug=False)

Solve weighted kernel ridge regression using Neumann series.

Solve the kernel ridge regression:

w* = argmin_w ||K @ w - Y||^2 + alpha (w.T @ K @ w)

where the kernel K is a weighted sum of multiple kernels:

K = sum_i exp(deltas[i]) Ks[i]
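As an illustration of the weighted sum above, here is a minimal NumPy sketch. The helper name `weighted_kernel_sum` is hypothetical and is not part of himalaya; the sketch only covers the case where deltas has shape (n_kernels, ), i.e. the same weights for all targets.

```python
import numpy as np


def weighted_kernel_sum(Ks, deltas):
    """Return K = sum_i exp(deltas[i]) * Ks[i].

    Hypothetical helper for illustration. Assumes Ks has shape
    (n_kernels, n_samples, n_samples) and deltas has shape (n_kernels, ).
    """
    weights = np.exp(deltas)  # kernel weights are exp(deltas)
    # Contract the kernel axis: sum_i weights[i] * Ks[i].
    return np.einsum("i,ijk->jk", weights, Ks)


# Two toy 2x2 kernels: the identity and a matrix of ones.
Ks = np.stack([np.eye(2), np.ones((2, 2))])
# Log-weights: weight 1 for the first kernel, weight 2 for the second.
deltas = np.array([0.0, np.log(2.0)])
K = weighted_kernel_sum(Ks, deltas)
```

Because the weights are exponentials of deltas, they stay positive for any real value of deltas.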

The Neumann series approximates the inverse of K as K^-1 = sum_j (Id - K)^j. It is a poor approximation, so this solver should NOT be used to solve ridge regression. It is however useful during hyper-parameter gradient descent, where we do not need precise results, only the direction of the gradient.

See [Lorraine, Vicol, & Duvenaud (2019). Optimizing Millions of Hyperparameters by Implicit Differentiation. arXiv:1911.02590].
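The truncated Neumann series described above can be sketched in plain NumPy. This is an illustrative sketch, not himalaya's implementation: the helper name `neumann_inverse` and the toy matrix are assumptions. The series only converges when the spectral radius of (Id - A) is below one, which holds for the small matrix chosen here.

```python
import numpy as np


def neumann_inverse(A, n_terms):
    """Approximate A^-1 by sum_{j=0}^{n_terms-1} (Id - A)^j.

    Hypothetical helper for illustration only.
    """
    identity = np.eye(A.shape[0])
    term = identity.copy()    # current power, starting at (Id - A)^0
    approx = identity.copy()  # running sum of the series
    for _ in range(n_terms - 1):
        term = term @ (identity - A)
        approx += term
    return approx


# Toy symmetric matrix close to the identity, so the series converges.
A = np.array([[0.9, 0.1], [0.1, 0.8]])
exact = np.linalg.inv(A)

# Maximum absolute error with few terms and with many terms.
err_few = np.abs(neumann_inverse(A, 5) - exact).max()
err_many = np.abs(neumann_inverse(A, 200) - exact).max()
```

With only a few terms the result is a rough approximation of the inverse, and the error shrinks as more terms are added.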

Parameters
Ks : array of shape (n_kernels, n_samples, n_samples)

Input kernels.

Y : torch.Tensor of shape (n_samples, n_targets)

Target data.

deltas : array of shape (n_kernels, ) or (n_kernels, n_targets)

Kernel log-weights. Each kernel Ks[i] is weighted by exp(deltas[i]).

alpha : float, or array of shape (n_targets, )

Regularization parameter.

fit_intercept : boolean

Whether to fit an intercept. If False, Ks should be centered (see KernelCenterer), and Y must be zero-mean over samples.

max_iter : int

Number of terms in the Neumann series.

factor : float, or array of shape (n_targets, )

Factor used to allow convergence of the series. We actually invert (factor * K) instead of K, then multiply the result by factor.

n_targets_batch : int or None

Size of the batch over targets during cross-validation. Used to limit memory usage. If None, uses all n_targets at once.

tol : None

Not used.

random_state : int, or None

Random generator seed. Not used.

debug : bool

If True, check some intermediate computations.

Returns
dual_weights : array of shape (n_samples, n_targets)

Kernel ridge coefficients.

intercept : array of shape (n_targets,)

Intercept. Only returned when fit_intercept is True.
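
To make the roles of max_iter and factor concrete, the following NumPy sketch applies a scaled, truncated Neumann series to the ridge system (K + alpha * Id) @ w = Y, which the objective above leads to. It is a sketch under stated assumptions, not himalaya's implementation: the helper name `neumann_ridge_solve`, the toy data, and the choice of factor from the largest eigenvalue are all hypothetical, and the intercept and target batching are left out.

```python
import numpy as np


def neumann_ridge_solve(K, Y, alpha=1.0, max_iter=10, factor=1e-4):
    """Approximate (K + alpha * Id)^-1 @ Y with max_iter Neumann terms.

    Hypothetical helper for illustration, not himalaya's implementation.
    """
    term = Y.copy()   # (Id - factor * A)^0 @ Y, with A = K + alpha * Id
    total = Y.copy()  # running sum of the series applied to Y
    for _ in range(max_iter - 1):
        # Multiply the previous term by (Id - factor * A) without forming A.
        term = term - factor * (K @ term + alpha * term)
        total += term
    # Undo the scaling: A^-1 = factor * (factor * A)^-1.
    return factor * total


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 5))
K = X @ X.T  # linear kernel, positive semi-definite
Y = rng.standard_normal((20, 3))
alpha = 1.0

exact = np.linalg.solve(K + alpha * np.eye(20), Y)

# Assumed choice: scale so that the eigenvalues of factor * A lie in (0, 1].
factor = 1.0 / (np.linalg.eigvalsh(K).max() + alpha)

w_short = neumann_ridge_solve(K, Y, alpha, max_iter=10, factor=factor)
w_long = neumann_ridge_solve(K, Y, alpha, max_iter=5000, factor=factor)

err_short = np.abs(w_short - exact).max()
err_long = np.abs(w_long - exact).max()

# Cosine similarity per target between the rough and the exact solutions.
cosine = (w_short * exact).sum(0) / (
    np.linalg.norm(w_short, axis=0) * np.linalg.norm(exact, axis=0)
)
```

With ten terms the weights are far from the exact solution, yet their cosine similarity with it stays positive for every target, which illustrates why a rough solve can still indicate a usable direction. Reaching a small error takes many more terms, so a direct solver remains the better choice when accurate ridge coefficients are needed.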