Menu
×
   ❮     
HTML CSS JAVASCRIPT SQL PYTHON JAVA PHP HOW TO W3.CSS C C++ C# BOOTSTRAP REACT MYSQL JQUERY EXCEL XML DJANGO NUMPY PANDAS NODEJS R TYPESCRIPT ANGULAR GIT POSTGRESQL MONGODB ASP AI GO KOTLIN SASS VUE DSA GEN AI SCIPY AWS CYBERSECURITY DATA SCIENCE
     ❯   

Python 教程

Python 主页 Python 简介 Python 入门 Python 语法 Python 注释 Python 变量 Python 数据类型 Python 数字 Python 类型转换 Python 字符串 Python 布尔值 Python 运算符 Python 列表 Python 元组 Python 集合 Python 字典 Python If...Else Python While 循环 Python For 循环 Python 函数 Python Lambda Python 数组 Python 类/对象 Python 继承 Python 迭代器 Python 多态 Python 作用域 Python 模块 Python 日期 Python 数学 Python JSON Python 正则表达式 Python PIP Python Try...Except Python 用户输入 Python 字符串格式化

文件处理

Python 文件处理 Python 读取文件 Python 写入/创建文件 Python 删除文件

Python 模块

NumPy 教程 Pandas 教程 SciPy 教程 Django 教程

Python Matplotlib

Matplotlib 简介 Matplotlib 入门 Matplotlib Pyplot Matplotlib 绘图 Matplotlib 标记 Matplotlib 线 Matplotlib 标签 Matplotlib 网格 Matplotlib 子图 Matplotlib 散点图 Matplotlib 条形图 Matplotlib 直方图 Matplotlib 饼图

机器学习

入门 平均数 中位数 众数 标准差 百分位数 数据分布 正态数据分布 散点图 线性回归 多项式回归 多元回归 缩放 训练/测试 决策树 混淆矩阵 层次聚类 逻辑回归 网格搜索 分类数据 K-均值 自助聚合 交叉验证 AUC - ROC 曲线 K-近邻

Python MySQL

MySQL 入门 MySQL 创建数据库 MySQL 创建表 MySQL 插入 MySQL 查询 MySQL Where MySQL 排序 MySQL 删除 MySQL 删除表 MySQL 更新 MySQL 限制 MySQL 连接

Python MongoDB

MongoDB 入门 MongoDB 创建数据库 MongoDB 集合 MongoDB 插入 MongoDB 查找 MongoDB 查询 MongoDB 排序 MongoDB 删除 MongoDB 删除集合 MongoDB 更新 MongoDB 限制

Python 参考

Python 概述 Python 内置函数 Python 字符串方法 Python 列表方法 Python 字典方法 Python 元组方法 Python 集合方法 Python 文件方法 Python 关键字 Python 异常 Python 词汇表

模块参考

随机模块 请求模块 统计模块 数学模块 cMath 模块

Python 如何

移除列表重复项 反转字符串 添加两个数字

Python 示例

Python 示例 Python 编译器 Python 练习 Python 问答 Python 服务器 Python 面试问答 Python 集训营 Python 证书

机器学习 - 多项式回归


多项式回归

如果您的数据点明显不适合线性回归(一条穿过所有数据点的直线),那么多项式回归可能是理想的选择。

多项式回归与线性回归类似,利用变量 x 和 y 之间的关联关系,找到最佳方式在数据点上绘制一条线。


它是如何工作的?

Python 具备用于查找数据点之间的关联关系并绘制多项式回归线的方法。我们将向您展示如何使用这些方法,而不是深入讲解数学公式。

在以下示例中,我们记录了 18 辆汽车通过某收费站时的状况。

我们记录了汽车的速度以及通过收费站的时间(小时)。

x 轴表示一天中的小时,y 轴表示速度

示例

首先绘制散点图

import matplotlib.pyplot as plt

x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]

plt.scatter(x, y)
plt.show()

结果

运行示例 »

示例

导入 numpy matplotlib,然后绘制多项式回归线

import numpy
import matplotlib.pyplot as plt

x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]

mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))

myline = numpy.linspace(1, 22, 100)

plt.scatter(x, y)
plt.plot(myline, mymodel(myline))
plt.show()

结果

运行示例 »

示例说明

导入所需的模块。

您可以在我们的 NumPy 教程 中了解 NumPy 模块。

您可以在我们的 SciPy 教程 中了解 SciPy 模块。

import numpy
import matplotlib.pyplot as plt

创建表示 x 轴和 y 轴值的数组

x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]

NumPy 有一个方法可以让我们创建多项式模型

mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))

然后指定直线的显示方式,我们从位置 1 开始,到位置 22 结束

myline = numpy.linspace(1, 22, 100)

绘制原始散点图

plt.scatter(x, y)

绘制多项式回归线

plt.plot(myline, mymodel(myline))

显示图表

plt.show()



R 平方

了解 x 轴和 y 轴值之间关系的好坏非常重要,如果它们之间没有关系,则无法使用多项式回归来预测任何事物。

关系是通过一个称为 r 平方值的指标来衡量的。

r 平方值的范围在 0 到 1 之间,其中 0 表示没有关系,1 表示 100% 相关。

Python 和 Sklearn 模块会为你计算这个值,你只需要提供 x 和 y 数组即可。

示例

我的数据在多项式回归中拟合得怎么样?

import numpy
from sklearn.metrics import r2_score

x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]

mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))

print(r2_score(y, mymodel(x)))
自己试一试 »

注意: 结果 0.94 表明存在非常好的关系,我们可以在未来的预测中使用多项式回归。


预测未来值

现在我们可以使用收集的信息来预测未来的值。

示例:让我们尝试预测一辆汽车在 17:00 左右经过收费站时的速度。

为此,我们需要来自上述示例的相同 mymodel 数组。

mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))

示例

预测在 17:00 经过的汽车速度

import numpy
from sklearn.metrics import r2_score

x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]

mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))

speed = mymodel(17)
print(speed)
运行示例 »

该示例预测速度为 88.87,我们也可以从图中读出。


拟合不好?

让我们创建一个示例,其中多项式回归不是预测未来值的最佳方法。

示例

这些 x 轴和 y 轴的值应该导致多项式回归拟合得很差。

import numpy
import matplotlib.pyplot as plt

x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]

mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))

myline = numpy.linspace(2, 95, 100)

plt.scatter(x, y)
plt.plot(myline, mymodel(myline))
plt.show()

结果

运行示例 »

那么 r 平方值呢?

示例

你应该得到一个非常低的 r 平方值。

import numpy
from sklearn.metrics import r2_score

x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]

mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))

print(r2_score(y, mymodel(x)))
自己试一试 »

结果:0.00995 表示关系很差,告诉我们该数据集不适合多项式回归。


×

Contact Sales

If you want to use W3Schools services as an educational institution, team or enterprise, send us an e-mail:
[email protected]

Report Error

If you want to report an error, or if you want to make a suggestion, send us an e-mail:
[email protected]

W3Schools is optimized for learning and training. Examples might be simplified to improve reading and learning. Tutorials, references, and examples are constantly reviewed to avoid errors, but we cannot warrant full correctness of all content. While using W3Schools, you agree to have read and accepted our terms of use, cookie and privacy policy.

Copyright 1999-2024 by Refsnes Data. All Rights Reserved. W3Schools is Powered by W3.CSS.