pytorch如何定义新的自动求导函数
作者:l8947943 发布时间:2021-02-10 20:14:49
标签:pytorch,自动求导,函数
pytorch定义新的自动求导函数
在pytorch中想自定义求导函数,通过实现torch.autograd.Function并重写forward和backward函数,来定义自己的自动求导运算。参考官网上的demo:传送门
直接上代码,定义一个ReLu来实现自动求导
import torch
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
# 我们使用ctx上下文对象来缓存,以便在反向传播中使用,ctx存储时候只能存tensor
# 在正向传播中,我们接收一个上下文对象ctx和一个包含输入的张量input;
# 我们必须返回一个包含输出的张量,
# input.clamp(min = 0)表示讲输入中所有值范围规定到0到正无穷,如input=[-1,-2,3]则被转换成input=[0,0,3]
ctx.save_for_backward(input)
# 返回几个值,backward接受参数则包含ctx和这几个值
return input.clamp(min = 0)
@staticmethod
def backward(ctx, grad_output):
# 把ctx中存储的input张量读取出来
input, = ctx.saved_tensors
# grad_output存放反向传播过程中的梯度
grad_input = grad_output.clone()
# 这儿就是ReLu的规则,表示原始数据小于0,则relu为0,因此对应索引的梯度都置为0
grad_input[input < 0] = 0
return grad_input
进行输入数据并测试
dtype = torch.float
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 使用torch的generator定义随机数,注意产生的是cpu随机数还是gpu随机数
generator=torch.Generator(device).manual_seed(42)
# N是Batch, H is hidden dimension,
# D_in is input dimension;D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype,generator=generator)
y = torch.randn(N, D_out, device=device, dtype=dtype, generator=generator)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True, generator=generator)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True, generator=generator)
learning_rate = 1e-6
for t in range(500):
relu = MyRelu.apply
# 使用函数传入参数运算
y_pred = relu(x.mm(w1)).mm(w2)
# 计算损失
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# 传播
loss.backward()
with torch.no_grad():
w1 -= learning_rate * w1.grad
w2 -= learning_rate * w2.grad
w1.grad.zero_()
w2.grad.zero_()
pytorch自动求导与逻辑回归
自动求导
retain_graph设为True,可以进行两次反向传播
逻辑回归
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)
#========生成数据=============
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums,2)
x0 = torch.normal(mean_value*n_data,1)+bias#类别0数据
y0 = torch.zeros(sample_nums)#类别0标签
x1 = torch.normal(-mean_value*n_data,1)+bias#类别1数据
y1 = torch.ones(sample_nums)#类别1标签
train_x = torch.cat((x0,x1),0)
train_y = torch.cat((y0,y1),0)
#==========选择模型===========
class LR(nn.Module):
def __init__(self):
super(LR,self).__init__()
self.features = nn.Linear(2,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features(x)
x = self.sigmoid(x)
return x
lr_net = LR()#实例化逻辑回归模型
#==============选择损失函数===============
loss_fn = nn.BCELoss()
#==============选择优化器=================
lr = 0.01
optimizer = torch.optim.SGD(lr_net.parameters(),lr = lr,momentum=0.9)
#===============模型训练==================
for iteration in range(1000):
#前向传播
y_pred = lr_net(train_x)#模型的输出
#计算loss
loss = loss_fn(y_pred.squeeze(),train_y)
#反向传播
loss.backward()
#更新参数
optimizer.step()
#绘图
if iteration % 20 == 0:
mask = y_pred.ge(0.5).float().squeeze() #以0.5分类
correct = (mask==train_y).sum()#正确预测样本数
acc = correct.item()/train_y.size(0)#分类准确率
plt.scatter(x0.data.numpy()[:,0],x0.data.numpy()[:,1],c='r',label='class0')
plt.scatter(x1.data.numpy()[:,0],x1.data.numpy()[:,1],c='b',label='class1')
w0,w1 = lr_net.features.weight[0]
w0,w1 = float(w0.item()),float(w1.item())
plot_b = float(lr_net.features.bias[0].item())
plot_x = np.arange(-6,6,0.1)
plot_y = (-w0*plot_x-plot_b)/w1
plt.xlim(-5,7)
plt.ylim(-7,7)
plt.plot(plot_x,plot_y)
plt.text(-5,5,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})
plt.title('Iteration:{}\nw0:{:.2f} w1:{:.2f} b{:.2f} accuracy:{:2%}'.format(iteration,w0,w1,plot_b,acc))
plt.legend()
plt.show()
plt.pause(0.5)
if acc > 0.99:
break
来源:https://blog.csdn.net/l8947943/article/details/105633826
0
投稿
猜你喜欢
- 我就废话不多说了,大家还是直接看代码吧~import numpy as np kernel = np.array([1, 1, 1, 2])
- 前三篇文章中,明确了栅格系统的设计细节和适用范围。这一篇将集中讨论960栅格系统的技术实现。Blueprint的实现Blueprint是一个
- 前言本文主要介绍了关于Python中TCP socket的写法,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧。一、 服务器
- 上个星期,大佬分享了一个验证身份证号合法性的库:id_validator,没空去试着用一下看看,今天有点时间,来试着用下这个库。1、首先,要
- 前几天写了一个ajax的,总感觉代码比较多,今天晚上又得写了一下,感觉代码还是比较多,但还好的是,比较通用。谁有办法优化一下当然好。&nbs
- 看了很多介绍javascript面向对象技术的文章,很晕.为什么?不是因为写得不好,而是因为太深奥.javascript中的对象还没解释清楚
- 本文介绍Python实现端口复用实例如下所示:#coding=utf-8import socketimport sysimport sele
- 举个例子吧Django最佳实践与部署:Nginx + Gunicorn + Supervisor(Ubuntu和CentOS)http://
- max(iterable, *[, key, default])max(arg1, arg2, *args[, key])函数功能为取传入的
- python 里面与时间有关的模块主要是 time 和 datetime如果想获取系统当前时间戳:time.time(),是一个float型
- 朋友去面试。对方问他:说说你之前做的那个站,有什么地方好的?朋友就说:用户体验比别的站好。对方又问:你怎么知道用户体验比别的好?朋友于是又磕
- SQL Server通常都运行在多处理器的服务器上,这一点在现在尤为普遍。原因是多内核的处理器越来越普及。那么,在多处理器环境下,Windo
- 假设在搜索框search中输入:“asp 编程” 先得到输入框中的内容:search=request("search")
- 前言哈希 又称作 “散列”,它接收任何一组任意长度的输入信息,通过 哈希 算法变换成固定长度的数据指
- web2.0的标志是Ajax的异步通信的发掘,给我们带来像google map,google suggest 这样令人惊叹的东西。而Ajax
- 1. CBV加装饰器CBV加装饰器有三种方法,案例:要求登录(不管get请求还是post请求)后才可以访问HTML代码index.html&
- 在腾讯的微信公众平台开发者文档,网页授权获取用户基本信息这一节中写道”在微信公众号请求用户网页授权之前,开发者需要先到公众平台网站的我的服务
- IE的有条件注释是一种专有的(因此是非标准的)、对常规(X)HTML注释的Miscrosoft扩展。顾名思义,有条件注释使你能够根据条件(比
- 原型扩展:>> String.prototype :String对象原型扩展 --------------
- 从而达到方便快捷的目的,但是它在存储信息的时候往往会有一些敏感的东西,这些东西可能成为被攻击的目标,如银行的账号、信用卡事务或档案记录等。这