Menu
×
   ❮     
HTML CSS JAVASCRIPT SQL PYTHON JAVA PHP HOW TO W3.CSS C C++ C# BOOTSTRAP REACT MYSQL JQUERY EXCEL XML DJANGO NUMPY PANDAS NODEJS R TYPESCRIPT ANGULAR GIT POSTGRESQL MONGODB ASP AI GO KOTLIN SASS VUE DSA GEN AI SCIPY AWS CYBERSECURITY DATA SCIENCE
     ❯   

机器学习

学习是循环

ML 模型通过循环数据多次来训练

在每次迭代中,权重值都会被调整。

当迭代无法降低成本时,训练完成。

训练我找到最佳拟合线

自己试试 »


梯度下降

梯度下降是一种解决 AI 问题的流行算法。

一个简单的线性回归模型可以用来演示梯度下降。

线性回归的目标是将一条线性图拟合到一组 (x,y) 点上。这可以用一个数学公式来解决。但是,一个机器学习算法也可以解决这个问题。

这就是上面示例所做的。

它从一个散点图和一个线性模型 (y = wx + b) 开始。

然后它训练模型找到一条拟合图的线。这通过改变线的权重(斜率)和偏差(截距)来完成。

下面是训练器对象的代码,它可以解决这个问题(以及许多其他问题)。


一个训练器对象

创建一个训练器对象,它可以在两个数组(xArr、yArr)中接收任意数量的 (x,y) 值。

将权重设置为零,并将偏差设置为 1。

必须设置学习常数 (learnc),并且必须定义成本变量

示例

function Trainer(xArray, yArray) {
  this.xArr = xArray;
  this.yArr = yArray;
  this.points = this.xArr.length;
  this.learnc = 0.00001;
  this.weight = 0;
  this.bias = 1;
  this.cost;


成本函数

解决回归问题的一种标准方法是使用“成本函数”,该函数衡量解决方案的优劣。

该函数使用模型中的权重和偏差(y = wx + b),并返回一个误差,该误差基于线拟合图的程度。

计算此误差的方法是循环遍历图中的所有 (x,y) 点,并将每个点的 y 值与直线的平方距离之和。

最常规的方法是将距离平方(以确保为正值)并使误差函数可微。

this.costError = function() {
  total = 0;
  for (let i = 0; i < this.points; i++) {
    total += (this.yArr[i] - (this.weight * this.xArr[i] + this.bias)) **2;
  }
  return total / this.points;
}

成本函数的另一个名称是误差函数

函数中使用的公式实际上是

Formula
  • E 是误差(成本)
  • N 是观测值的总数(点)
  • y 是每个观测值的价值(标签)
  • x 是每个观测值的价值(特征)
  • m 是斜率(权重)
  • b 是截距(偏差)
  • mx + b 是预测值
  • 1/N * N∑1 是平方均值

训练函数

我们现在将运行梯度下降。

梯度下降算法应该将成本函数朝最佳线移动。

每次迭代都应该将 m 和 b 更新到成本(误差)较低的直线。

为此,我们添加一个训练函数,该函数多次循环遍历所有数据

this.train = function(iter) {
  for (let i = 0; i < iter; i++) {
    this.updateWeights();
  }
  this.cost = this.costError();
}

更新权重函数

上面的训练函数应该在每次迭代中更新权重和偏差。

移动方向使用两个偏导数计算

this.updateWeights = function() {
  let wx;
  let w_deriv = 0;
  let b_deriv = 0;
  for (let i = 0; i < this.points; i++) {
    wx = this.yArr[i] - (this.weight * this.xArr[i] + this.bias);
    w_deriv += -2 * wx * this.xArr[i];
    b_deriv += -2 * wx;
  }
  this.weight -= (w_deriv / this.points) * this.learnc;
  this.bias -= (b_deriv / this.points) * this.learnc;
}

创建自己的库

库代码

function Trainer(xArray, yArray) {
  this.xArr = xArray;
  this.yArr = yArray;
  this.points = this.xArr.length;
  this.learnc = 0.00001;
  this.weight = 0;
  this.bias = 1;
  this.cost;

// 成本函数
this.costError = function() {
  total = 0;
  for (let i = 0; i < this.points; i++) {
    total += (this.yArr[i] - (this.weight * this.xArr[i] + this.bias)) **2;
  }
  return total / this.points;
}

// 训练函数
this.train = function(iter) {
  for (let i = 0; i < iter; i++) {
    this.updateWeights();
  }
  this.cost = this.costError();
}

// 更新权重函数
this.updateWeights = function() {
  let wx;
  let w_deriv = 0;
  let b_deriv = 0;
  for (let i = 0; i < this.points; i++) {
    wx = this.yArr[i] - (this.weight * this.xArr[i] + this.bias);
    w_deriv += -2 * wx * this.xArr[i];
    b_deriv += -2 * wx;
  }
  this.weight -= (w_deriv / this.points) * this.learnc;
  this.bias -= (b_deriv / this.points) * this.learnc;
}

} // 结束训练器对象

现在您可以在 HTML 中包含该库

<script src="myailib.js"></script>

自己试试 »


×

Contact Sales

If you want to use W3Schools services as an educational institution, team or enterprise, send us an e-mail:
[email protected]

Report Error

If you want to report an error, or if you want to make a suggestion, send us an e-mail:
[email protected]

W3Schools is optimized for learning and training. Examples might be simplified to improve reading and learning. Tutorials, references, and examples are constantly reviewed to avoid errors, but we cannot warrant full correctness of all content. While using W3Schools, you agree to have read and accepted our terms of use, cookie and privacy policy.

Copyright 1999-2024 by Refsnes Data. All Rights Reserved. W3Schools is Powered by W3.CSS.