himalaya.kernel_ridge.solve_weighted_kernel_ridge_gradient_descent
- himalaya.kernel_ridge.solve_weighted_kernel_ridge_gradient_descent(Ks, Y, deltas, alpha=1.0, fit_intercept=False, step_sizes=None, lipschitz_Ks=None, initial_dual_weights=None, max_iter=100, tol=0.001, double_K=False, random_state=None, debug=False, n_targets_batch=None)
Solve weighted kernel ridge regression using gradient descent.
Solve the kernel ridge regression:
w* = argmin_w ||K @ w - Y||^2 + alpha (w.T @ K @ w)
where the kernel K is a weighted sum of multiple kernels:
K = sum_i exp(deltas[i]) Ks[i]
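The objective above can be illustrated with a minimal NumPy sketch. It is not himalaya's implementation. It assumes `deltas` has shape (n_kernels,), no intercept, and a fixed step size of 1 / L, where L is the exact largest eigenvalue of K + alpha I (the real function instead accepts `step_sizes` or `lipschitz_Ks`). The full gradient of the objective is 2 K ((K + alpha I) w - Y); the sketch drops the leading 2 K factor, which appears to be what the `double_K=False` default refers to. The names `weighted_kernel` and `solve_kernel_ridge_gd_sketch` are hypothetical.

```python
import numpy as np


def weighted_kernel(Ks, deltas):
    """Return K = sum_i exp(deltas[i]) * Ks[i] for deltas of shape (n_kernels,)."""
    return np.tensordot(np.exp(deltas), Ks, axes=1)


def solve_kernel_ridge_gd_sketch(Ks, Y, deltas, alpha=1.0, max_iter=5000,
                                 tol=1e-8):
    """Hypothetical gradient-descent solver for the weighted kernel ridge problem."""
    K = weighted_kernel(Ks, deltas)
    A = K + alpha * np.eye(K.shape[0])
    # The full gradient of ||K w - Y||^2 + alpha w^T K w is 2 K (A w - Y).
    # Dropping the leading 2 K gives the better-conditioned direction A w - Y,
    # whose Lipschitz constant is the largest eigenvalue of A.
    step_size = 1.0 / np.linalg.eigvalsh(A)[-1]
    dual_weights = np.zeros(Y.shape)
    y_norm = np.linalg.norm(Y)
    for _ in range(max_iter):
        grad = A @ dual_weights - Y
        # Stop once the residual is small relative to the targets.
        if np.linalg.norm(grad) <= tol * y_norm:
            break
        dual_weights -= step_size * grad
    return dual_weights


# Usage on random positive semi-definite kernels.
rng = np.random.default_rng(0)
n_samples, n_targets = 20, 2
Ks = np.stack([X @ X.T / 5 for X in rng.standard_normal((3, n_samples, 5))])
Y = rng.standard_normal((n_samples, n_targets))
deltas = np.log(np.array([0.5, 1.0, 2.0]))

dual_weights = solve_kernel_ridge_gd_sketch(Ks, Y, deltas, alpha=1.0)
K = weighted_kernel(Ks, deltas)
exact = np.linalg.solve(K + np.eye(n_samples), Y)
print(np.max(np.abs(dual_weights - exact)))  # close to zero
```

Iterating on A w - Y converges to (K + alpha I)^-1 Y, which zeroes the full gradient 2 K ((K + alpha I) w - Y) and is therefore a minimizer of the objective even when K is singular.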
- Parameters
- Ks : array of shape (n_kernels, n_samples, n_samples)
Input kernels.
- Y : array of shape (n_samples, n_targets)
Target data.
- deltas : array of shape (n_kernels, ) or (n_kernels, n_targets)
Kernel weights.
- alpha : float, or array of shape (n_targets, )
Regularization parameter.
- fit_intercept : boolean
Whether to fit an intercept. If False, Ks should be centered (see KernelCenterer), and Y must be zero-mean over samples.
- step_sizes : float, or array of shape (n_targets, ), or None
Step sizes. If None, computes a step size based on the Lipschitz constants.
- lipschitz_Ks : float, or array of shape (n_kernels, ), or None
Lipschitz constants. Used only if step_sizes is None. If None, Lipschitz constants are estimated with power iteration on Ks.
- initial_dual_weights : array of shape (n_samples, n_targets)
Initial kernel ridge coefficients.
- max_iter : int
Maximum number of gradient steps.
- tol : float > 0 or None
Tolerance for the stopping criterion.
- double_K : bool
If True, multiply the gradient by the kernel to obtain the true gradients, which are less well conditioned.
- random_state : int, or None
Random generator seed. Use an int for deterministic search.
- debug : bool
If True, check some intermediate computations.
- n_targets_batch : int or None
Size of the batch over targets. Used for memory reasons. If None, uses all n_targets at once.
- Returns
- dual_weights : array of shape (n_samples, n_targets)
Kernel ridge coefficients.
- intercept : array of shape (n_targets,)
Intercept. Only returned when fit_intercept is True.
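The lipschitz_Ks and step_sizes parameters suggest a two-stage computation: estimate each kernel's largest eigenvalue with power iteration, then turn those estimates into a step size. The following is a minimal NumPy sketch under stated assumptions, not himalaya's code. The function names are hypothetical, and the combination rule 1 / (sum_i exp(deltas[i]) L_i + alpha) is an assumption: it relies on lambda_max(sum_i c_i K_i) <= sum_i c_i lambda_max(K_i) for PSD kernels and c_i >= 0, and on the map w -> (K + alpha I) w - Y having Lipschitz constant lambda_max(K) + alpha.

```python
import numpy as np


def power_iteration_lipschitz(Ks, n_iter=100, random_state=None):
    """Estimate the largest eigenvalue of each PSD kernel in Ks (n_kernels, n, n)."""
    rng = np.random.default_rng(random_state)
    n_kernels, n_samples, _ = Ks.shape
    v = rng.standard_normal((n_kernels, n_samples))
    for _ in range(n_iter):
        # One matrix-vector product per kernel, then renormalize each vector.
        v = np.einsum("kij,kj->ki", Ks, v)
        v /= np.linalg.norm(v, axis=1, keepdims=True)
    # Rayleigh quotient v^T K v of the unit-norm vector estimates the top eigenvalue.
    return np.einsum("ki,kij,kj->k", v, Ks, v)


def step_size_from_lipschitz(lipschitz_Ks, deltas, alpha):
    """Hypothetical rule: 1 / (sum_i exp(deltas[i]) * L_i + alpha)."""
    # For PSD kernels and nonnegative weights, the largest eigenvalue of the
    # weighted sum is at most the weighted sum of the largest eigenvalues.
    return 1.0 / (np.exp(deltas) @ lipschitz_Ks + alpha)


# Usage: two kernels sharing eigenvectors, with top eigenvalues 10 and 4.
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((10, 10)))
spectra = np.array([[10.0] + [1.0] * 9, [4.0] + [0.5] * 9])
Ks = np.stack([Q @ np.diag(s) @ Q.T for s in spectra])

lipschitz_Ks = power_iteration_lipschitz(Ks, random_state=0)
step_size = step_size_from_lipschitz(lipschitz_Ks, deltas=np.zeros(2), alpha=1.0)
```

Because both kernels here share their top eigenvector, the bound is tight and the step size equals 1 / (10 + 4 + 1). With kernels whose top eigenvectors differ, the rule stays safe but becomes conservative.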