利用Pytorch实现简单的线性回归算法
作者:carmanzzz 发布时间:2022-09-08 00:00:09
标签:Pytorch,线性,回归
最近听了张江老师的深度学习课程,用Pytorch实现神经网络预测,之前做Titanic生存率预测的时候稍微了解过Tensorflow,听说Tensorflow能做的Pyorch都可以做,而且更方便快捷,自己尝试了一下代码的逻辑确实比较简单。
Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量),对于这些概念我也是一知半解,tensor和向量,矩阵等概念都有交叉的部分,下次有时间好好补一下数学的基础知识,不过现阶段的任务主要是应用,学习掌握思维和方法即可,就不再深究了。tensor和ndarray可以相互转换,python的numpy库中的命令也都基本适用。
一些基本的代码:
import torch #导入torch包
x = torch.rand(5, 3) #产生一个5*3的tensor,在 [0,1) 之间随机取值
y = torch.ones(5, 3) #产生一个5*3的Tensor,元素都是1
#和numpy的命令一致
#tensor的运算
z = x + y #两个tensor可以直接相加
q = x.mm(y.transpose(0, 1)) #x乘以y的转置 mm为矩阵的乘法,矩阵相乘必须某一个矩阵的行与另一个矩阵的列相等
##Tensor与numpy.ndarray之间的转换
import numpy as np #导入numpy包
a = np.ones([5, 3]) #建立一个5*3全是1的二维数组(矩阵)
b = torch.from_numpy(a) #利用from_numpy将其转换为tensor
c = torch.FloatTensor(a) #另外一种转换为tensor的方法,类型为FloatTensor,还可以使LongTensor,整型数据类型
b.numpy() #从一个tensor转化为numpy的多维数组
from torch.autograd import Variable #导入自动梯度的运算包,主要用Variable这个类
x = Variable(torch.ones(2, 2), requires_grad=True) #创建一个Variable,包裹了一个2*2张量,将需要计算梯度属性置为True
下面用pytorch做一个简单的线性关系预测
线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解。生活中的场景往往会比较复杂,需要考虑多元线性关系和非线性关系,用其他的回归分析方法求解。
这里po上代码:
#人为生成一些样本点作为原始数据
x = Variable(torch.linspace(0, 100).type(torch.FloatTensor))
rand = Variable(torch.randn(100)) * 10 #随机生成100个满足标准正态分布的随机数,均值为0,方差为1.将这个数字乘以10,标准方差变为10
y = x + rand #将x和rand相加,得到伪造的标签数据y。所以(x,y)应能近似地落在y=x这条直线上
import matplotlib.pyplot as plt #导入画图的程序包
plt.figure(figsize=(10,8)) #设定绘制窗口大小为10*8 inch
plt.plot(x.data.numpy(), y.data.numpy(), 'o') #绘制数据,考虑到x和y都是Variable,需要用data获取它们包裹的Tensor,并专成numpy
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
图示:
训练模型:
#a,b就是要构建的线性函数的系数
a = Variable(torch.rand(1), requires_grad = True) #创建a变量,并随机赋值初始化
b = Variable(torch.rand(1), requires_grad = True) #创建b变量,并随机赋值初始化
print('Initial parameters:', [a, b])
learning_rate = 0.0001 #设置学习率
for i in range(1000):
### 增加了这部分代码,清空存储在变量a,b中的梯度信息,以免在backward的过程中会反复不停地累加
if (a.grad is not None) and (b.grad is not None):
a.grad.data.zero_()
b.grad.data.zero_()
predictions = a.expand_as(x) * x+ b.expand_as(x) #计算在当前a、b条件下的模型预测数值
loss = torch.mean((predictions - y) ** 2) #通过与标签数据y比较,计算误差
print('loss:', loss)
loss.backward() #对损失函数进行梯度反传,backward的方向传播算法
a.data.add_(- learning_rate * a.grad.data) #利用上一步计算中得到的a的梯度信息更新a中的data数值
b.data.add_(- learning_rate * b.grad.data) #利用上一步计算中得到的b的梯度信息更新b中的data数值
##拟合
x_data = x.data.numpy()
plt.figure(figsize = (10, 7))
xplot = plt.plot(x_data, y.data.numpy(), 'o') # 绘制原始数据
yplot = plt.plot(x_data, a.data.numpy() * x_data + b.data.numpy()) #绘制拟合数据
plt.xlabel('X')
plt.ylabel('Y')
str1 = str(a.data.numpy()[0]) + 'x +' + str(b.data.numpy()[0]) #图例信息
plt.legend([xplot, yplot],['Data', str1]) #绘制图例
plt.show()
图示:
测试:
x_test = Variable(torch.FloatTensor([1, 2, 10, 100, 1000])) #随便选择一些点1,2,……,1000
predictions = a.expand_as(x_test) * x_test + b.expand_as(x_test) #计算模型的预测结果
predictions #输出
ok,大功告成,可以看到用pytorch做机器学习确实无论是准确度还是方便性都有优势,继续探索学习。
来源:https://www.jianshu.com/p/42eaf6e4c59a
0
投稿
猜你喜欢
- 本文实例讲述了Python SVM(支持向量机)实现方法。分享给大家供大家参考,具体如下:运行环境Pyhton3numpy(科学计算包)ma
- 本文实例为大家分享了python实现自动打卡小程序的具体代码,供大家参考,具体内容如下"""湖南大学疫情防控每
- Douglas Crockford是JavaScript开发社区最知名的权威,是JSON、JSLint、JSMin和ADSafe之父,是《J
- 万维网联盟(W3C)发布了HTML 5规格说明书的草稿 ,这是自HTML 4在十多年前发布以来的第一个主要的修订版.在这期间,随着开发者逐渐
- 本文实例讲述了PHP实现将MySQL重复ID二维数组重组为三维数组的方法。分享给大家供大家参考,具体如下:应用场景MYSQL在使用关联查询时
- 1、说明tqdm是一个方便且易于扩展的Python进度条,可以在python执行长循环时在命令行界面实时地显示一个进度提示信息,包括执行进度
- 1、Python函数函数是Python为了代码最大程度的重用和最小化代码冗余而提供的基本程序结构,用于将相关功能打包并参数化Python中可
- 1.包: package PaintBrush; /** * * @author lucifer */ public class Paint
- 一 导言设计一个好的用户系统往往不是那么容易,Django提供的用户系统可以快速实现基本的功能,并可以在此基础上继续扩展以满足我们的需求。先
- Python中的缩进(Indentation)决定了代码的作用域范围。这一点和传统的c/c++有很大的不同(传统的c/c++使用花括号{}符
- 最近有网友在留言板里问到jRaiser和jQuery的冲突问题,特此写一篇文章进行解释。冲突的根源众所周知,jQuery是通过一个全局变量$
- 阅读上一篇:javascript 45种缓动效果(一)这部分对原先的缓动函数进行抽象化,并结合缓动公式进行强化。成品的效果非常惊人逆天。走过
- 大家好,今天在写代码的时候,遇到了这样一种情况。我有如下所示的几个类用来存放程序配置(其实当做命名空间来用,同时感觉能够继承方便一点),im
- 一、IE透明度问题在IE的高度超过某一阀值时,会产生透明度不时失效的问题,这现象比较奇怪,(会有的时候全黑,有的时候全白)你有可能无法复现。
- var a= new Array(new Array(1,2),new Array('b','c')); d
- 大家平时见到google的广告太多了,但有没有兴趣知道一下它的运行过程呢?下面我们一起来看看这个广告代码的执行过程,以及其中的一些精彩内容。
- 直接看代码: 代码如下:Class GoogleTranslator sub Class_Initialize
- 今天彬Go要向大家推荐9款很棒的可在网页中绘制图表的JavaScript脚本,这些有趣的JS脚本可以帮助你快速方便的绘制图表(线、面、饼、条
- 这篇论坛文章主要介绍了Oracle数据库到SQL Server数据库主键的迁移过程,具体内容请参考下文。由于项目需要要将以前Oracle的数
- 当系统出现故障时,只要存在数据日志那么就可以利用它来恢复数据解决数据库故障。作为SQL Server数据库管理员,了解数据日志文件的作用,以