Source code for himalaya.viz
import numpy as np
def plot_alphas_diagnostic(best_alphas, alphas, ax=None):
    """Plot a diagnostic plot for the selected alphas during cross-validation.

    For each candidate alpha, the plot shows how many targets selected it.
    This helps figure out whether to increase the range of alphas.

    Parameters
    ----------
    best_alphas : array of shape (n_targets, )
        Alphas selected during cross-validation for each target. Any
        array-like (e.g. a list) is accepted.
    alphas : array of shape (n_alphas, )
        Alphas used while fitting the model. They are displayed on a log10
        axis, so they are expected to be positive.
    ax : None or matplotlib axis
        Axis on which to draw. If None, a new figure with a single axis is
        created.

    Returns
    -------
    ax : matplotlib axis
        The axis containing the plot.
    """
    # Accept array-likes (lists, tuples) for both inputs, not only ndarrays.
    alphas = np.sort(np.asarray(alphas).ravel())
    best_alphas = np.asarray(best_alphas).ravel()

    # For each target, find the index of the closest candidate alpha (in the
    # sorted grid). Nearest-value matching tolerates tiny floating-point
    # differences between the selected and the candidate alphas.
    distances = np.abs(best_alphas[None, :] - alphas[:, None])
    indices = np.argmin(distances, axis=0)
    # minlength ensures alphas that no target selected still get a 0 count.
    counts = np.bincount(indices, minlength=alphas.size)

    if ax is None:
        # Import pyplot only when a figure actually has to be created, so
        # callers that pass their own axis do not need to load pyplot.
        import matplotlib.pyplot as plt
        _, ax = plt.subplots(1, 1)

    ax.plot(np.log10(alphas), counts, '.-', markersize=12)
    ax.set_ylabel('Number of targets')
    ax.set_xlabel('log10(alpha)')
    ax.grid(True)
    return ax