最近看了看机器学习的内容,就是自己啃,看看B站的视频课还有CSDN的大佬博客,然后对着数学公式实现算法,找一些数据集慢慢理解,代码为python3编写,数据集存储于csv文件,需要的函数库为numpy,matplotlib,csv三个库,希望大佬看到能够指出错误,联系方式QQ 2214685813
第一个接触的算法是一元线性回归,其实就是找出最贴近数据的一次函数,h(x)=m*x+b这样的函数,x为自变量或者叫特征,y是因变量或者叫结果,我的初始函数并没有将m和b都设置为0,假设我们没有垃圾数据,或者说垃圾数据不在首尾这两端,所以我的想法是将第一条数据和最后一条数据连接,求出m和b作为初始函数的参数
m=(yn-y0)/(xn-x0),b=y0-m*x
接着设定损失函数,我选择常用的误差平方求均值的方法,因为平方保证我们得到误差肯定是正数,可以省去很多计算的麻烦
loss = 1/(n*2) * ∑(yi-m*xi-b)²
然后是最重要的梯度下降法降低损失,尽可能的拟合数据,然后降低误差,当然,误差只能尽可能减小而不可能消失,如果在训练集中消灭了误差,那只能说明过拟合了,反而带来了更大的麻烦,同时说一下我踩的坑,我一直以为梯度下降的求偏导是对我们的原函数求偏导,但实际是对我们的损失函数求偏导,分别对m和b求偏导
m的偏导 -1/n * ∑ (yi-m*xi-b) * xi
b的偏导 -1/n * ∑ (yi-m*xi-b)
然后我们设定一下学习率,别设的太大,否则会直接跳过误差最小点,如果设的太小,需要的计算次数就会很多,所以学习率的设定一定要合适,我设定的的学习率0.001,学习次数10000次接下来贴上我们的代码
# 一元线性回归
import numpy as np
import matplotlib.pyplot as plt
import csv
# 求预测函数,首数据与尾数据求m与b
# m为斜率 b为截距
def getStartFunction(x,y):
# 求预测函数的 m 和 b
m = (y[len(y) - 1] - y[0]) / (x[len(x) - 1] - x[0])
b = y[0] - m * x[0]
return m,b
# 读取数据集提取x和y
def readData():
# 读取数据
csvfile = open("data.csv", "r")
reader = csv.reader(csvfile)
# 转化为numpy数组对象 并将x和y分离出
reader = np.array(list(reader))
reader = reader.astype(np.float)
# 垂直分割,分成两列
x = np.hsplit(reader, 2)[0]
y = np.hsplit(reader, 2)[1]
return x,y
# 均方误差函数
def squareLoss(x,y,m,b):
loss = sum(np.square(y - (m * x + b))) / (len(x) * 2)
return loss
# 进行梯度下降
def downGradient(x,y,m,b,learnRate):
nowM = m
nowB = b
for i in range(10000):
nowM,nowB = getGradient(x,y,nowM,nowB,learnRate)
if (i+1) % 1000 == 0:
print("迭代次数{} 均方误差{}".format(i+1,squareLoss(x,y,nowM,nowB)))
return nowM,nowB
# 获取梯度值
def getGradient(x,y,m,b,learnRate):
# 分别对m和b求偏导得到梯度值
mGradient = sum(-(1 / len(x)) * x * (y - m * x - b))
bGradient = sum(-(1 / len(x)) * (y - m * x - b))
newM = m - (learnRate * mGradient)
newB = b - (learnRate * bGradient)
return newM,newB
# 展示回归图形
def showFunction(x,y,m,b):
plt.scatter(x,y)
line = np.arange(0, 13)
lineFunction = line * m + b
plt.plot(line,lineFunction)
plt.show()
if __name__ == "__main__":
x,y = readData()
m,b = getStartFunction(x,y)
newm,newb = downGradient(x,y,m,b,0.001)
showFunction(x,y,newm,newb)
print(newm, " ",newb)