pytorch中使用LSTM详解
作者:qyhyzard 发布时间:2021-01-08 04:27:10
LSMT层
可以在troch.nn
模块中找到LSTM类
lstm = torch.nn.LSTM(*paramsters)
1、__init__方法
首先对nn.LSTM
类进行实例化,需要传入的参数如下图所示:
一般我们关注这4个:
input_size
表示输入的每个token的维度,也可以理解为一个word的embedding的维度。hidden_size
表示隐藏层也就是记忆单元C的维度,也可以理解为要将一个word的embedding维度转变成另一个大小的维度。除了C,在LSTM中输出的H的维度与C的维度是一致的。num_layers
表示有多少层LSTM,加深网络的深度,这个参数对LSTM的输出的维度是有影响的(后文会提到)。bidirectional
表示是否需要双向LSTM,这个参数也会对后面的输出有影响。
2、forward方法的输入
将数据input传入forward方法进行前向传播时有3个参数可以输入,见下图:
这里要注意的是
input
参数各个维度的意义,一般来说如果不在实例化时制定batch_first=True
,那么input
的第一个维度是输入句子的长度seq_len,第二个维度是批量的大小,第三个维度是输入句子的embedding维度也就是input_size,这个参数要与__init__
方法中的第一个参数对应。另外记忆细胞中的两个参数
h_0
和c_0
可以选择自己初始化传入也可以不传,系统默认是都初始化为0。传入的话注意维度[bidirectional * num_layers, batch_size, hidden_size]。
3、forward方法的输出
forward方法的输出如下图所示:
一般采用如下形式:
out,(h_n, c_n) = lstm(x)
out
表示在最后一层上,每一个时间步的输出,也就是句子有多长,这个out的输出就有多长;其维度为[seq_len, batch_size, hidden_size * bidirectional]。因为如果的双向LSTM,最后一层的输出会把正向的和反向的进行拼接,故需要hidden_size * bidirectional。h_n
表示的是每一层(双向算两层)在最后一个时间步上的输出;其维度为[bidirectional * num_layers, batch_size, hidden_size]
假设是双向的LSTM,且是3层LSTM,双向每个方向算一层,两个方向的组合起来叫一层LSTM,故共会有6层(3个正向,3个反向)。所以h_n是每层的输出,bidirectional * num_layers = 6。c_n
表示的是每一层(双向算两层)在最后一个时间步上的记忆单元,意义不同,但是其余均与 h_n
一样。
LSTMCell
可以在troch.nn
模块中找到LSTMCell类
lstm = torch.nn.LSTMCell(*paramsters)
它的__init__
方法的参数设置与LSTM类似,但是没有num_layers
参数,因为这就是一个细胞单元,谈不上多少层和是否双向。forward
的输入和输出与LSTM均有所不同:
其相比LSTM,输入没有了时间步的概念,因为只有一个Cell单元;输出 也没有out
参数,因为就一个Cell,out
就是h_1
,h_1
和c_1
也因为只有一个Cell单元,其没有层数上的意义,故只是一个Cell的输出的维度[batch_size, hidden_size].
代码演示如下:
rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
hx = torch.randn(3, 20) # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
# 从输入的第一个维度也就是seq_len上遍历,每循环一次,输入一个单词
for i in range(input.size()[0]):
# 更新细胞记忆单元
hx, cx = rnn(input[i], (hx, cx))
# 将每个word作为输入的输出存起来,相当于LSTM中的out
output.append(hx)
output = torch.stack(output, dim=0)
来源:https://blog.csdn.net/qq_42961603/article/details/119638341


猜你喜欢
- Python 代码库之Tuple如何append元素tuple不像array给我们提供了append函数,我们可以通过下面的方式添加t=[1
- 本文实例讲述了python使用 request 发送表单数据操作。分享给大家供大家参考,具体如下:# !/usr/bin/env pytho
- dim dr dr="2123123" dr1=Cint(dr) dr2=Clng(dr) 可参考如下函数说明: CIn
- 关系型数据库:关系型数据库的优势:保持数据的一致性(事务处理)由于以标准化为前提,数据更新的开销很小(相同的字段基本上都只有一处)可以进行J
- javascript 跨域问题以及解决办法什么是跨域问题?跨域这个问题是由于浏览器的同源策略引起的,请求的URL地址,必须与浏览器的URL是
- 前言python内置了一些非常巧妙而且强大的内置函数,对初学者来说,一般不怎么用到,我也是用了一段时间python之后才发现,哇还有这么好的
- 组合字面量组合字面量是最直接方式初始化Go对象,假设定义了Book类型,使用字面量初始化代码如下:type Book struct { &n
- MySQL 是完全网络化的跨平台关系型数据库系统,同时是具有客户机/服务器体系结构的分布式数据库管理系统。MySQL 是完全网络化
- 如下所示:# 选取等于某些值的行记录 用 == df.loc[df['column_name'] == some_value
- 1、如果之前已经安装我们先卸载一下yum -y remove php*2、由于linux的yum源不存在php7.x,所以我们要更改yum源
- 实现思路很多网站都有拼图验证码1.首先要了解拼图验证码的生成原理2.制定破解计划,考虑其可能性和成功率。3.编写脚本很多网站的拼图验证码都是
- 前言突然想起来之前讲SQL注入时忘记讲一下这个宽字节注入了,因为这个知识点还是挺重要的,所以本文就带大家了解一下宽字节注入的原理以及应用方法
- 点击获取后,返回2s后的鼠标位置,显示在文本框(需要用pip命令安装所需的的库)(pip install 模块名比如 安装pyautogui
- requests 是一个非常小巧全面的库,应用它可以很容易写出与服务器进行交互的程序,今天遇到了一个问题,与服务器交互时,url都是http
- 传入参数一个,为元素的id值或元素本身,返回为元素的真实背景色值(字符串)。 值得一提的是IE里面返回的是16进制的值,而Mozi
- <script language="javascript"> functio
- 原文链接:Histogram of Oriented Gradients(文中的图片均来自翻译原文)什么是特征描述子特征描述子一张图片或者一
- 前言最近遇到一个mysql在RR级别下的死锁问题,感觉有点意思,研究了一下,做个记录。涉及知识点:共享锁、排他锁、意向锁、间隙锁、插入意向锁
- 下面基础的解释一下这错误: 1:本质上的错误: object a;//a是Null对象 protected void Page_Load(o
- 前言在一个分布式环境中,每台机器上可能需要启动和停止多个进程,使用命令行方式一个一个手动启动和停止非常麻烦,而且查看每个进程的状态也很不方便