高斯过程回归(GPR)在Python中的实现
2024.04.02 11:35浏览量:25简介:高斯过程回归(GPR)是一种强大的非参数贝叶斯方法,适用于回归问题。本文介绍了使用sklearn.gaussian_process模块在Python中实现GPR的步骤,包括模型训练、预测以及参数调整等。
文心大模型4.5及X1 正式发布
百度智能云千帆全面支持文心大模型4.5 API调用,文心大模型X1即将上线
引言
高斯过程回归(Gaussian Process Regression,简称GPR)是一种非参数贝叶斯方法,用于解决回归问题。它基于高斯过程(Gaussian Process,简称GP)的先验分布,通过观测数据来逐步修正这个先验分布,最终得到预测分布。GPR在机器学习和统计学中得到了广泛应用,特别是在处理小数据集、具有噪声或不确定性的数据时表现出色。
在Python中,我们可以使用sklearn.gaussian_process
模块来实现GPR。下面将介绍如何使用这个模块进行GPR的建模、预测和参数调整。
1. 导入必要的库
首先,我们需要导入所需的库。在这个例子中,我们将使用numpy
进行数值计算,matplotlib
进行可视化,以及sklearn.gaussian_process
进行高斯过程回归。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
2. 准备数据
为了演示GPR,我们需要一些训练数据。在这个例子中,我们将生成一个简单的正弦波数据。
# 生成训练数据
X_train = np.atleast_2d(np.linspace(0, 10, 1000)).T
y_train = np.sin(X_train).ravel()
# 可视化训练数据
plt.figure(figsize=(10, 5))
plt.plot(X_train, y_train, 'r.', markersize=10, label='Observations')
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.legend(loc='upper left')
plt.show()
3. 定义高斯过程模型
接下来,我们需要定义高斯过程模型。在sklearn.gaussian_process
中,我们使用GaussianProcessRegressor
类来实现GPR。这个类需要一个内核(kernel)来定义输入空间中的协方差结构。在这个例子中,我们使用径向基函数(Radial Basis Function,简称RBF)内核。
# 定义内核
kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
# 创建高斯过程回归模型
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)
在上面的代码中,C(1.0, (1e-3, 1e3))
表示常数项内核,RBF(10, (1e-2, 1e2))
表示径向基函数内核。这些内核参数可以通过交叉验证等方法进行调整。n_restarts_optimizer
参数指定了优化器在寻找最佳参数时重启的次数,这里设置为9。
4. 训练模型
现在我们可以使用训练数据来训练高斯过程回归模型了。在GaussianProcessRegressor
中,我们使用fit
方法来进行训练。
# 训练模型
gp.fit(X_train, y_train)
5. 进行预测
训练完成后,我们可以使用predict
方法进行预测。这个方法返回预测值以及预测值的标准差。
```python
定义测试数据
X_test = np.atleast_2d(np.linspace(0, 10, 100)).T
进行预测
y_pred, sigma = gp.predict(X_test, return_std=True)
可视化预测结果
plt.figure(figsize=(10, 5))
plt.plot(X_train, y_train, ‘r.’, markersize=10, label=’Observations’)
plt.plot(X_test, y_pred, ‘b-‘, label=’Prediction’)
plt.fill(np.concatenate([X_test, X_test[::-1]]),
np.concatenate([y_pred - 1.96 sigma,
(y_pred + 1.96 sigma)[::-1]]),
alpha=.2, fc=’b’, ec=’None’, label=’95% conf. interval’)
plt.xlabel(‘$x$’)
plt.ylabel(‘$y$’)
plt.legend(loc=’upper left’)

发表评论
登录后可评论,请前往 登录 或 注册