2022-10-10AI00
请注意,本文编写于 53 天前,最后修改于 53 天前,其中某些信息可能已经过时。

简单的梯度下降算法实现

二维梯度下降

之前使用的算法是一次性计算出结果,在样本数量比较少的情况下还好,在样本较多的情况下,机器的算力可能不足以快速的计算出结果,这时候就需要使用梯度下降算法,我们先从简单的二维开始,即拟合y=wx。

首先依旧是获取散点,绘制坐标系,预设一个w值为0.1

import dataset
import matplotlib.pyplot as plt
import numpy as np
xs, ys = dataset.get_beans(100)


print(xs)
print(ys)


plt.title("STF", fontsize=12)
plt.xlabel("B")
plt.ylabel("T")

plt.scatter(xs, ys)

w = 0.1
y_pre = w*xs
plt.plot(xs, y_pre)
plt.show()

这次写的是随机梯度下降算法,它的原理简单来说就是计算代价函数(w,e)在某一点的导数,用先前的w值减去学习率*斜率。因为在最低点右边时,斜率大于零,减去后向最低点靠拢,在左边时同理。学习率alpha的功能也是控制震荡幅度。这就是梯度下降。而随机指的是每次随机取样本中的一个数据验证拟合度,以避免在大量样本数时算力不足的情况

它的代码实现很简单,我在这里使用plt.clf()函数和plt.pause()函数相结合来实现动态图像,使拟合过程更加生动形象。

for _ in range(100):
	for i in range(100):
		x = xs[i]
		y = ys[i]
		k = 2*(x**2)*w + (-2*x*y)
		alpha = 0.05
		w = w - alpha*k
		plt.clf()
		plt.scatter(xs, ys)
		y_pre = w*xs
		plt.xlim(0,1)
		plt.ylim(0,1.2)
		plt.plot(xs, y_pre)
		plt.pause(0.01)#暂停0.01秒

图像.gif

这样就实现了一个二维的随机梯度下降

三维梯度下降

前面实现的二维梯度下降的算法比简单,但是不是所有的图像都会经过坐标原点,这时候w与e的函数便不再适用。完全的一次函数y=wx+b需要我们绘制w,e,b的三维代价函数,来求这个三维图像的最低点,这时候,我们便需要使用三维梯度下降。

首先我们先来绘制三维的代价函数图像,可以使用matplotlib中的Axed3D来实现。

import dataset
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

m=100
xs, ys = dataset.get_beans(m)

plt.title("STF", fontsize=12)
plt.xlabel("B")
plt.ylabel("T")
plt.xlim(0,1)
plt.ylim(0,1.5)

plt.scatter(xs, ys)

w = 0.1
b = 0.1
y_pre = w*xs + b

plt.plot(xs,y_pre)


plt.show()

fig = plt.figure()
ax = Axes3D(fig)
ax.set_zlim(0,2)

ws = np.arange(-1,2,0.1)
bs = np.arange(-2,2,0.04)

for b in bs:
	es = []
	for w in ws:
		y_pre = w*xs + b
		e = np.sum((ys-y_pre)**2)*(1/m)
		es.append(e)
	ax.plot(ws,es,b,zdir='y')
plt.show()

屏幕截图 2022-01-03 135457.png

我们可以清晰地看到它有一个最低点,接下来我们就使用梯度下降算法得到它。

		dw = 2*(x**2)*w + 2*x*b - 2*x*y
		db = 2*b + 2*x*w -2*y
		alpha = 0.05
		w = w - alpha*dw
		b = b - alpha*db

分别求得w和b方向上的斜率,把它们合到一起便实现了一次三维梯度下降

接下来就再用动态图像来观察三位梯度下降的过程,完整代码如下:

import dataset
import matplotlib.pyplot as plt
import numpy as np
xs, ys = dataset.get_beans(100)


print(xs)
print(ys)


plt.title("STF", fontsize=12)
plt.xlabel("B")
plt.ylabel("T")

plt.scatter(xs, ys)

w = 0.1
b = 0.1
y_pre = w*xs + b
plt.plot(xs, y_pre)
plt.show()

for _ in range(500):
	for i in range(100):
		x = xs[i]
		y = ys[i]
		dw = 2*(x**2)*w + 2*x*b - 2*x*y
		db = 2*b + 2*x*w -2*y
		alpha = 0.01
		w = w - alpha*dw
		b = b - alpha*db


	plt.clf()
	plt.scatter(xs, ys)
	y_pre = w*xs + b
	plt.xlim(0,1)
	plt.ylim(0,1.2)
	plt.plot(xs, y_pre)
	plt.pause(0.01)#暂停0.01秒

最后能得到这样的效果

GIF.gif

本文作者:Ch1nfo

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!