对Pytorch中nn.ModuleList 和 nn.Sequential详解
作者:ustc_lijia 发布时间:2023-07-04 06:54:46
简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。
需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:
nn.ModuleList(
[nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
range(nstack)])
其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!
摘录自
nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:
class LinearNet(nn.Module):
def __init__(self, input_size, num_layers, layers_size, output_size):
super(LinearNet, self).__init__()
self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
self.linears.append(nn.Linear(layers_size, output_size)
nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:
class Flatten(nn.Module):
def forward(self, x):
N, C, H, W = x.size() # read in N, C, H, W
return x.view(N, -1)
simple_cnn = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=7, stride=2),
nn.ReLU(inplace=True),
Flatten(),
nn.Linear(5408, 10),
)
In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module
On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.
来源:https://blog.csdn.net/xiaojiajia007/article/details/82118559
猜你喜欢
- 可以使用numpy中的linspace函数np.linspace(start, stop, num, endpoint, retstep,
- 启动sql server Net Start MSSqlServer 暂停sql server Net Pause MSSqlServer
- 介绍Zmail 使得在python3中发送和接受邮件变得更简单。你不需要手动添加服务器地址、端口以及适合的协议,zmail会帮你完成。此外,
- 本文实例讲述了Python实现迭代时使用索引的方法。分享给大家供大家参考,具体如下:索引迭代Python中,迭代永远是取出元素本身,而非元素
- 本文实例讲述了PHP解析xml格式数据工具类。分享给大家供大家参考,具体如下:class ome_xml { /**  
- 时间格式化函数,代码简单但较实用代码很简单,谁都能看懂Function fmstr(str, str1, Lens) Dim str2For
- 1、字典中的键存在时,可以通过字典名+下标的方式访问字典中改键对应的值,若键不存在则会抛出异常。如果想直接向字典中添加元素可以直接用字典名+
- python 字符串和日期之间转换 StringAndDate &nb
- 这片文章只对本地存储方法做介绍,若要查看本地存储组件使用方法的介绍请稍等。本地数据持久化(或者也叫做浏览器本地存储)是一种在浏览器中长久保存
- PDOStatement::setAttributePDOStatement::setAttribute — 设置一个语句属性(PHP 5
- eclipse安装Python插件之后,主要是为了方便Python代码就可以再Eclipse进行代码脚本,使用Eclipse开发Python
- 代码实现如下:import win32com.client,os,timedef word_encryption(path, passwor
- 处理办法,删除该文件,或清空该文件内容;我的处理是清空后,再设置该文件权限为Everyone拒绝访问。
- 例如:我们在百度中搜索 词典网,则网址后面的参数就是http://www.baidu.com/s?cl=3&wd=%B4%CA%B5
- 前言Python可以操作Excel的模块不止一种,我习惯使用的写入模块是xlwt(一般都是读写模块分开的)python中使用xlwt操作ex
- Web,全称为 World Wide Web,是 Internet 上最重要和最为人们所熟知的应用之一。Web 是指 Internet 上所
- 在本章中,您将详细了解Python中各种加密模块.加密模块它包含所有配方和基元,并在Python中提供高级编码接口.您可以使用以下命令安装加
- 一、python线程的模块1.1 thread和threading模块thread模块提供了基本的线程和锁的支持threading提供了更高
- 本文实例讲述了Python实现读取文件最后n行的方法。分享给大家供大家参考,具体如下:# -*- coding:utf8-*-import
- PHP addAttribute() 函数实例给根元素和 body 元素添加一个属性:<?php $note=<<<