关于tf.reverse_sequence()简述
作者:Cerisier 发布时间:2022-05-23 03:05:08
tf.reverse_sequence()简述
在看bidirectional_dynamic_rnn()的源码的时候,看到了代码中有调用 reverse_sequence()这一方法,于是又回去看了下这个函数的用法,发现还是有点意思的。根据名字就可以能看得出,这个方法主要是用来翻转序列的,就像双线LSTM中在反向传播那里需要从下文往上文处理一样,需要对序列做一个镜像的翻转处理。
先来看一下这个方法的定义:
reverse_sequence(
input,
seq_lengths,
seq_axis=None,
batch_axis=None,
name=None,
seq_dim=None,
batch_dim=None)
其中input是输入的需要翻转的目标张量,seq_lengths是一个张量;
其元素是input中每一处需要翻转时翻转的长度,在双向LSTM中这个值统一被设为输入语句的长度,代表着整句话都需要被翻转,而实际上张量中的元素值可以是不同的,下面的例子中就可以看出;
seq_axis和seq_dim的关系,在源码中做了如下操作:
seq_axis = deprecation.deprecated_argument_lookup("seq_axis", seq_axis,
"seq_dim", seq_dim)
返回中return gen_array_ops.reverse_sequence(..., seq_dim=seq_axis,...),同理,对于batch_axis和batch_dim也是相同的处理。意义上来说,按照官方给出的解释,“此操作首先沿着维度batch_axis对input进行分割,并且对于每个切片 i,将前 seq_lengths 元素沿维度 seq_axis 反转”。实际上通俗来理解,就是对于张量input中的第batch_axis维中的每一个子张量,在这个子张量的第seq_axis维上进行翻转,翻转的长度为 seq_lengths 张量中对应的数值。
举个例子,如果 batch_axis=0,seq_axis=1,则代表我希望每一行为单位分开处理,对于每一行中的每一列进行翻转。相反的,如果 batch_axis=1,seq_axis=0,则是以列为单位,对于每一列的张量,进行相应行的翻转。回头去看双向RNN的源码,就可以理解当time_major这一属性不同时,time_dim 和 batch_dim 这一对组合的取值为什么恰好是相反的了。
写一个简单的测试代码:
a = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
l = tf.constant([1,2,3],tf.int64) # 每一次翻转长度分别为1,2,3.由于a是(3,3)维的,所以l中数值最大只能是3
x = tf.reverse_sequence(a,seq_lengths=l,seq_axis = 0,batch_axis= 1) # 以列为单位进行翻转,翻转的是每一行的元素
y = tf.reverse_sequence(a,seq_lengths=l,seq_axis = 1,batch_axis= 0) # 以行为单位进行翻转,翻转的是每一列的元素
with tf.Session() as sess:
print(sess.run(x))
print(sess.run(y))
结果如下:
# 每一列上的元素种类没有发生变化,但是从每一行来看,行的顺序分别翻转了前1,前2,前3个元素
[[1 5 9]
[4 2 6]
[7 8 3]]
# 每一行上的元素种类没有发生变化,但是从每一列来看,列的顺序分别翻转了前1,前2,前3个元素
[[1 2 3]
[5 4 6]
[9 8 7]]
来源:https://blog.csdn.net/cerisier/article/details/80118611
猜你喜欢
- 本文介绍如何利用带进度条的ASP无组件实现断点续传下载大文件。<%@LANGUAGE="VBSCRIPT"&nbs
- 一、前言前几天在Python钻石交流群分享了一个Python基础的问题,这里拿出来给大家分享下,一起学习下。编写程序,输入若干整数(整数之间
- # -*- encoding: utf8 -*-import osimport sysimport ftplibclass FTPSync(
- 对于部署在百度应用引擎BAE上的项目,使用百度云存储BCS(Baidu Cloud Storage)是不错的存储方案。百度云存储已有Pyth
- arange()类似于内置函数range(),通过指定开始值、终值和步长创建表示等差数列的一维数组,注意得到的结果数组不包含终值。linsp
- 前言在讲解如何解决migrate报错原因前,我们先要了解migrate做了什么事情,migrate:将新生成的迁移脚本。映射到数据库中。创建
- 导语:哈喽,哈喽~大家有没有遇到过这种情况,手机用着用着没有内存了,一到设置里面一看。微信和 QQ 10G!啊这。。。。。就离谱!好说,好说
- [PHP] ; PHP还是一个不断发展的工具,其功能还在不断地删减 ; 而php.ini的设置更改可以反映出相当的变化,
- 一.图像加法运算1.Numpy库加法其运算方法是:目标图像 = 图像1 + 图像2,运算结果进行取模运算。当像素值<=255时,结果为
- 如下所示:import cv2#循环灰度图片并保存def grayImg(): for x in range(1,38): #读
- 服务器的CentOS 7中自带的python版本是python-2.7.5,需要再安装一个 python-3.8.1一、查看版本安
- sorted函数sorted(iterable,key,reverse)iterable 待排序的可迭代对象key 对应的是个函数, 该函数
- 最近需要训练一个生成对抗网络模型,然后开发接口,不得不在一台有显卡的远程linux服务器上进行,所以,趁着这个机会研究了下怎么使用vscod
- 需求最近公司干活,收到一个需求,说是让手动将数据库查出来的信息复制粘贴到excel中,在用excel中写好的公式将指定的两列数据用updat
- 前段时间做视频时需要演示电脑端的操作,因此要用到屏幕录制,下载了个迅捷屏幕录制,但是没有vip录制的视频有水印且只能录制二分钟,于是鄙人想了
- pillowPillow是PIL的一个派生分支,但如今已经发展成为比PIL本身更具活力的图像处理库。pillow可以说已经取代了PIL,将其
- 很长时间以来,一直想将自己的一些零碎的想法总结下,给自己一个完整的思维,也算是做个存档。一家之言,绝不敢说对别人会有什么帮助,对外人的层面上
- 特征选择时困难耗时的,也需要对需求的理解和专业知识的掌握。在机器学习的应用开发中,最基础的是特征工程。——吴恩达1.数据预处理数据预处理需要
- 使用Python获取电脑的磁盘信息需要借助于第三方的模块psutil,这个模块需要自己安装,纯粹的CPython下面不具备这个功能。在iPy
- 1、filter_uniquefrom collections import Counterdef filter_unique(lst):