Menu
×
   ❮     
HTML CSS JAVASCRIPT SQL PYTHON JAVA PHP HOW TO W3.CSS C C++ C# BOOTSTRAP REACT MYSQL JQUERY EXCEL XML DJANGO NUMPY PANDAS NODEJS R TYPESCRIPT ANGULAR GIT POSTGRESQL MONGODB ASP AI GO KOTLIN SASS VUE DSA GEN AI SCIPY AWS CYBERSECURITY DATA SCIENCE
     ❯   

Python 教程

Python 主页 Python 简介 Python 入门 Python 语法 Python 注释 Python 变量 Python 数据类型 Python 数字 Python 类型转换 Python 字符串 Python 布尔值 Python 运算符 Python 列表 Python 元组 Python 集合 Python 字典 Python If...Else Python While 循环 Python For 循环 Python 函数 Python Lambda Python 数组 Python 类/对象 Python 继承 Python 迭代器 Python 多态 Python 范围 Python 模块 Python 日期 Python 数学 Python JSON Python 正则表达式 Python PIP Python Try...Except Python 用户输入 Python 字符串格式化

文件处理

Python 文件处理 Python 读取文件 Python 写入/创建文件 Python 删除文件

Python 模块

NumPy 教程 Pandas 教程 SciPy 教程 Django 教程

Python Matplotlib

Matplotlib 简介 Matplotlib 入门 Matplotlib Pyplot Matplotlib 绘图 Matplotlib 标记 Matplotlib 线 Matplotlib 标签 Matplotlib 网格 Matplotlib 子图 Matplotlib 散点图 Matplotlib 条形图 Matplotlib 直方图 Matplotlib 饼图

机器学习

入门 平均值 中位数 众数 标准差 百分位数 数据分布 正态数据分布 散点图 线性回归 多项式回归 多元回归 缩放 训练/测试 决策树 混淆矩阵 层次聚类 逻辑回归 网格搜索 分类数据 K-means 自助聚合 交叉验证 AUC - ROC 曲线 K 最近邻

Python MySQL

MySQL 入门 MySQL 创建数据库 MySQL 创建表 MySQL 插入 MySQL 选择 MySQL Where MySQL 排序 MySQL 删除 MySQL 删除表 MySQL 更新 MySQL 限制 MySQL 连接

Python MongoDB

MongoDB 入门 MongoDB 创建数据库 MongoDB 集合 MongoDB 插入 MongoDB 查找 MongoDB 查询 MongoDB 排序 MongoDB 删除 MongoDB 删除集合 MongoDB 更新 MongoDB 限制

Python 参考

Python 概述 Python 内置函数 Python 字符串方法 Python 列表方法 Python 字典方法 Python 元组方法 Python 集合方法 Python 文件方法 Python 关键字 Python 异常 Python 词汇表

模块参考

Random 模块 Requests 模块 Statistics 模块 Math 模块 cMath 模块

Python 如何

删除列表重复项 反转字符串 添加两个数字

Python 例子

Python 例子 Python 编译器 Python 练习 Python 测验 Python 服务器 Python 面试问答 Python 集训营 Python 证书

机器学习 - 交叉验证


在本页,W3schools.com 与 纽约数据科学学院 合作,为我们的学生提供数字培训内容。


交叉验证

在调整模型时,我们的目标是提高模型在未见过的数据上的整体性能。超参数调整可以显著提高测试集上的性能。然而,对测试集进行参数优化可能会导致信息泄露,从而导致模型在未见过的数据上的性能更差。为了纠正这一点,我们可以进行交叉验证。

为了更好地理解 CV,我们将对鸢尾花数据集执行不同的方法。首先让我们加载并分离数据。

from sklearn import datasets

X, y = datasets.load_iris(return_X_y=True)

交叉验证有许多方法,我们将从 k 折交叉验证开始。


K 折交叉验证

用于模型训练的训练数据被分成 k 个较小的子集,用于验证模型。然后在 k-1 个训练集折叠上训练模型。剩余的折叠作为验证集用于评估模型。

由于我们将尝试对不同种类的鸢尾花进行分类,因此我们需要导入一个分类器模型,在本练习中,我们将使用 DecisionTreeClassifier。我们还需要从 sklearn 导入 CV 模块。

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold, cross_val_score

数据加载完成后,我们可以创建并拟合一个模型进行评估。

clf = DecisionTreeClassifier(random_state=42)

现在让我们评估我们的模型,看看它在每个k折上的表现。

k_folds = KFold(n_splits = 5)

scores = cross_val_score(clf, X, y, cv = k_folds)

将所有折的得分平均起来查看CV的整体表现也是一个好习惯。

示例

运行k折交叉验证

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import KFold, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

k_folds = KFold(n_splits = 5)

scores = cross_val_score(clf, X, y, cv = k_folds)

print("交叉验证得分: ", scores)
print("平均CV得分: ", scores.mean())
print("用于平均的CV得分数量: ", len(scores))
运行示例 »

广告


分层k折交叉验证

在类不平衡的情况下,我们需要一种方法来解决训练集和验证集中的不平衡问题。为此,我们可以对目标类进行分层,这意味着两组将包含所有类的相等比例。

示例

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

sk_folds = StratifiedKFold(n_splits = 5)

scores = cross_val_score(clf, X, y, cv = sk_folds)

print("交叉验证得分: ", scores)
print("平均CV得分: ", scores.mean())
print("用于平均的CV得分数量: ", len(scores))
运行示例 »

虽然折数相同,但确保分层类后,基本k折的平均CV会增加。


留一法 (LOO)

与在训练数据集中选择拆分数量的k折交叉验证不同,留一法使用1个观察值进行验证,并使用n-1个观察值进行训练。这种方法是一种穷举技术。

示例

运行留一法交叉验证

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import LeaveOneOut, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

loo = LeaveOneOut()

scores = cross_val_score(clf, X, y, cv = loo)

print("交叉验证得分: ", scores)
print("平均CV得分: ", scores.mean())
print("用于平均的CV得分数量: ", len(scores))
运行示例 »

我们可以观察到,执行的交叉验证得分数量等于数据集中观察值的数量。在本例中,虹膜数据集中有150个观察值。

平均CV得分是94%。


留p法 (LPO)

留p法只是留一法的细微差异,我们可以选择验证集中使用的p的数量。

示例

运行留p法交叉验证

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import LeavePOut, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

lpo = LeavePOut(p=2)

scores = cross_val_score(clf, X, y, cv = lpo)

print("交叉验证得分: ", scores)
print("平均CV得分: ", scores.mean())
print("用于平均的CV得分数量: ", len(scores))
运行示例 »

正如我们所见,这是一种穷举方法,我们计算的得分比留一法要多得多,即使p=2,但它仍然获得了大致相同的平均CV得分。


随机拆分

KFold不同,ShuffleSplit会省略一部分数据,这些数据不会用于训练集或验证集。为此,我们必须决定训练集和测试集的大小,以及拆分数量。

示例

运行随机拆分交叉验证

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import ShuffleSplit, cross_val_score

X, y = datasets.load_iris(return_X_y=True)

clf = DecisionTreeClassifier(random_state=42)

ss = ShuffleSplit(train_size=0.6, test_size=0.3, n_splits = 5)

scores = cross_val_score(clf, X, y, cv = ss)

print("交叉验证得分: ", scores)
print("平均CV得分: ", scores.mean())
print("用于平均的CV得分数量: ", len(scores))
运行示例 »

结束说明

这些只是可以应用于模型的几种CV方法。还有很多其他的交叉验证类,大多数模型都有自己的类。查看sklearn的交叉验证以了解更多CV选项。


×

Contact Sales

If you want to use W3Schools services as an educational institution, team or enterprise, send us an e-mail:
[email protected]

Report Error

If you want to report an error, or if you want to make a suggestion, send us an e-mail:
[email protected]

W3Schools is optimized for learning and training. Examples might be simplified to improve reading and learning. Tutorials, references, and examples are constantly reviewed to avoid errors, but we cannot warrant full correctness of all content. While using W3Schools, you agree to have read and accepted our terms of use, cookie and privacy policy.

Copyright 1999-2024 by Refsnes Data. All Rights Reserved. W3Schools is Powered by W3.CSS.