对Pytorch 中的contiguous理解说明
作者:gdymind 发布时间:2022-04-14 08:34:47
最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解。
在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。
这些操作是:
narrow(),view(),expand()和transpose()
举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。
转置的tensor和原tensor的内存是共享的!
为了证明这一点,我们来看下面的代码:
x = torch.randn(3, 2)
y = x.transpose(x, 0, 1)
x[0, 0] = 233
print(y[0, 0])
# print 233
可以看到,改变了y的元素的值的同时,x的元素的值也发生了变化。
也就是说,经过上述操作后得到的tensor,它内部数据的布局方式和从头开始创建一个这样的常规的tensor的布局方式是不一样的!于是…这就有contiguous()的用武之地了。
在上面的例子中,x是contiguous的,但y不是(因为内部数据不是通常的布局方式)。
注意不要被contiguous的字面意思“连续的”误解,tensor中数据还是在内存中一块区域里,只是布局的问题!
当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。
一般来说这一点不用太担心,如果你没在需要调用contiguous()的地方调用contiguous(),运行时会提示你:
RuntimeError: input is not contiguous
只要看到这个错误提示,加上contiguous()就好啦~
补充:pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax
gather
torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.
input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。
import torch
a=torch.tensor([[.1,.2,.3],
[1.1,1.2,1.3],
[2.1,2.2,2.3],
[3.1,3.2,3.3]])
b=torch.LongTensor([[1,2,1],
[2,2,2],
[2,2,2],
[1,1,0]])
b=b.view(4,3)
print(a.gather(1,b))
print(a.gather(0,b))
c=torch.LongTensor([1,2,0,1])
c=c.view(4,1)
print(a.gather(1,c))
输出:
tensor([[ 0.2000, 0.3000, 0.2000],
[ 1.3000, 1.3000, 1.3000],
[ 2.3000, 2.3000, 2.3000],
[ 3.2000, 3.2000, 3.1000]])
tensor([[ 1.1000, 2.2000, 1.3000],
[ 2.1000, 2.2000, 2.3000],
[ 2.1000, 2.2000, 2.3000],
[ 1.1000, 1.2000, 0.3000]])
tensor([[ 0.2000],
[ 1.3000],
[ 2.1000],
[ 3.2000]])
squeeze
将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)
import torch
a=torch.randn(2,1,1,3)
print(a)
print(a.squeeze())
输出:
tensor([[[[-0.2320, 0.9513, 1.1613]]],
[[[ 0.0901, 0.9613, -0.9344]]]])
tensor([[-0.2320, 0.9513, 1.1613],
[ 0.0901, 0.9613, -0.9344]])
expand
扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)
import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)
输出:
tensor([[[ 0.0608],
[ 2.2106]],
[[-1.9287],
[ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
[ 2.2106, 2.2106, 2.2106]],
[[-1.9287, -1.9287, -1.9287],
[ 0.8748, 0.8748, 0.8748]]])
sum
size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量
import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.sum())
print(a.sum(dim=1))
输出:
tensor(60)
tensor([[ 5, 10, 15],
[ 5, 10, 15]])
contiguous
返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。
可以通过is_contiguous查看张量内存是否连续。
import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous)
print(a.contiguous().view(4,3))
输出:
<built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
tensor([[ 1, 2, 3],
[ 4, 8, 12],
[ 1, 2, 3],
[ 4, 8, 12]])
softmax
假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。
import torch
import torch.nn.functional as F
a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
b=F.softmax(a,dim=1)
print(b)
输出:
tensor([[ 0.5000, 0.5000],
[ 0.7311, 0.2689],
[ 0.8808, 0.1192],
[ 0.2689, 0.7311],
[ 0.1192, 0.8808]])
max
返回最大值,或指定维度的最大值以及index
import torch
a=torch.tensor([[.1,.2,.3],
[1.1,1.2,1.3],
[2.1,2.2,2.3],
[3.1,3.2,3.3]])
print(a.max(dim=1))
print(a.max())
输出:
(tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
tensor(3.3000)
argmax
返回最大值的index
import torch
a=torch.tensor([[.1,.2,.3],
[1.1,1.2,1.3],
[2.1,2.2,2.3],
[3.1,3.2,3.3]])
print(a.argmax(dim=1))
print(a.argmax())
输出:
tensor([ 2, 2, 2, 2])
tensor(11)
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/gdymind/article/details/82662502
猜你喜欢
- 本文实例为大家分享了tensorflow如何批量读取图片的具体代码,供大家参考,具体内容如下代码:import tensorflow as
- 在数据库开发方面,通过单表所表现的实现,有时候需要组合查询来找到我们需要的记录集,这时候我们就会用到连接查询。连接查询主要包括以下几个方面:
- 可以自动轮换的页签 tabs with auto play fucntion<html><head><meta
- 这是由十几位视觉设计师设计的挂历,每个月份都是不同的风格,就像每个月都有不同温度和心情一样,思维跳跃性很大,可以作为挂历设计参考。当然,如果
- 首先,让我们介绍一下什么是pytorch,它是一个基于Python的开源深度学习框架,它提供了两个核心功能:张量计算和自动求导。张量计算张量
- 半透明效果有时候会给页面增加不少色彩,特别是Vista盛行之后,半透明效果更加受推崇。在诸多可用于Web浏览的图片格式中,只有PNG格式和G
- 本文代码来之《数据分析与挖掘实战》,在此基础上补充完善了一下~代码是基于SVM的分类器Python实现,原文章节题目和code关系不大,或者
- 直接上代码:<span style="font-family: arial,helvetica,sans-serif; fo
- 整个重装步骤大致分四个步骤进行,第一步,备份原mysql中的所有数据库。第二步,完全卸载mysql第三步,下载安装新版mysql第四步,导入
- Python语言简洁明了,可以用较少的代码实现同样的功能。这其中Python的四个内置数据类型功不可没,他们即是list, tuple, d
- 本文介绍了django反向解析URL和URL命名空间,分享给大家,具体如下:首先明确几个概念:1.在html页面上的内容特别是向用户展示的u
- 为了使一个MySQL系统安全,强烈要求你考虑下列建议……当你连接一个MySQL服务器时,你通常应
- python 消除序列的重复值,并保持原来顺序1、如果仅仅消除重复元素,可以简单的构造一个集合$ pythonPython 3.5.2 (d
- 基于微信开放的个人号接口python库itchat,实现对微信好友的获取,并对省份、性别、微信签名做数据分析。效果:直接上代码,建三个空文本
- 除了使用pycharm外,还可使用vscode来操作pyqt,方法如下:1. 在vscode中配置相关的pyqt的相关根据自己实际情况修改第
- 如何用组件实现自动发送电子邮件?我想做一个能够自动发送电子邮件的程序,该如何做? 这就要用到w3 upl
- pandas 中 inplace 参数在很多函数中都会有,它的作用是:是否在原对象基础上进行修改inplace = True:不创建新的对象
- InnoDB给MySQL提供了具有提交,回滚和崩溃恢复能力的事务安全(ACID兼容)存储引擎。InnoDB锁定在行级并且也在SELECT语句
- 如下所示:import cv2import mathimport numpy as npdef move(img): height, wid
- 这学期在学习编译原理,最近的上机作业就是做一个简单的词法分析器,在做的过程中,突然有个需求就是判断一个字符串是否为合法的标示符,因为我是用p