From bb8733895a813c4ddbbdac72a235b5d32d3d1c74 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 15 Jun 2024 00:08:38 +0530 Subject: [PATCH] Fix UserWarning --- .../plot_gpr_prior_posterior.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/gaussian_process/plot_gpr_prior_posterior.py b/examples/gaussian_process/plot_gpr_prior_posterior.py index 78d3c1a3ff71d..6de06fd0dea2e 100644 --- a/examples/gaussian_process/plot_gpr_prior_posterior.py +++ b/examples/gaussian_process/plot_gpr_prior_posterior.py @@ -196,23 +196,30 @@ def plot_gpr_samples(gpr_model, n_samples, ax): # %% # Dot-product kernel # .................. +from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import ConstantKernel, DotProduct +from sklearn.preprocessing import StandardScaler +import matplotlib.pyplot as plt -kernel = ConstantKernel(0.1, (0.01, 10.0)) * ( - DotProduct(sigma_0=1.0, sigma_0_bounds=(0.1, 10.0)) ** 2 -) -gpr = GaussianProcessRegressor(kernel=kernel, random_state=0) +# Scale the data +scaler = StandardScaler() +X_train_scaled = scaler.fit_transform(X_train) -fig, axs = plt.subplots(nrows=2, sharex=True, sharey=True, figsize=(10, 8)) +# Define the kernel +kernel = ConstantKernel(0.1, (0.01, 10.0)) * (DotProduct(sigma_0=1.0, sigma_0_bounds=(0.1, 10.0)) ** 2) -# plot prior +# Increase the number of iterations +gpr = GaussianProcessRegressor(kernel=kernel, random_state=0, n_restarts_optimizer=10, optimizer='fmin_l_bfgs_b') + +# Plot prior +fig, axs = plt.subplots(nrows=2, sharex=True, sharey=True, figsize=(10, 8)) plot_gpr_samples(gpr, n_samples=n_samples, ax=axs[0]) axs[0].set_title("Samples from prior distribution") -# plot posterior -gpr.fit(X_train, y_train) +# Fit the model and plot posterior +gpr.fit(X_train_scaled, y_train) plot_gpr_samples(gpr, n_samples=n_samples, ax=axs[1]) -axs[1].scatter(X_train[:, 0], y_train, color="red", zorder=10, label="Observations") +axs[1].scatter(X_train_scaled[:, 0], y_train, color="red", zorder=10, label="Observations") axs[1].legend(bbox_to_anchor=(1.05, 1.5), loc="upper left") axs[1].set_title("Samples from posterior distribution")