Pytorch中如何调用forward()函数
作者:good 发布时间:2023-06-14 21:00:24
Pytorch调用forward()函数
Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型。
下面继承Module类构造本节开头提到的多层感知机。
这里定义的MLP类重载了Module类的__init__函数和forward函数。
它们分别用于创建模型参数和定义前向计算。
前向计算也即正向传播。
import torch
from torch import nn
class MLP(nn.Module):
# 声明带有模型参数的层,这里声明了两个全连接层
def __init__(self, **kwargs):
# 调用MLP父类Module的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
# 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params
super(MLP, self).__init__(**kwargs)
self.hidden = nn.Linear(784, 256) # 隐藏层
self.act = nn.ReLU()
self.output = nn.Linear(256, 10) # 输出层
# 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)
输出:
MLP( (hidden): Linear(in_features=784, out_features=256, bias=True) (act): ReLU() (output): Linear(in_features=256, out_features=10, bias=True) ) tensor([[-0.1798, -0.2253, 0.0206, -0.1067, -0.0889, 0.1818, -0.1474, 0.1845, -0.1870, 0.1970], [-0.1843, -0.1562, -0.0090, 0.0351, -0.1538, 0.0992, -0.0883, 0.0911, -0.2293, 0.2360]], grad_fn=<ThAddmmBackward>)
为什么会调用forward()呢,是因为Module中定义了__call__()函数,该函数调用了forward()函数,当执行net(x)的时候,会自动调用__call__()函数
Pytorch函数调用的问题和源码解读
最近用到 softmax 函数,但是发现 softmax 的写法五花八门,记录如下
# torch._C._VariableFunctions
torch.softmax(x, dim=-1)
# class
softmax = torch.nn.Softmax(dim=-1)
x=softmax(x)
# function
x = torch.nn.functional.softmax(x, dim=-1)
简单测试了一下,用 torch.nn.Softmax 类是最慢的,另外两个差不多
torch.nn.Softmax 源码如下,可以看到这是个类,而他这里的 return F.softmax(input, self.dim, _stacklevel=5) 调用的是 torch.nn.functional.softmax
class Softmax(Module):
r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
lie in the range [0,1] and sum to 1.
Softmax is defined as:
.. math::
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
When the input Tensor is a sparse tensor then the unspecifed
values are treated as ``-inf``.
Shape:
- Input: :math:`(*)` where `*` means, any number of additional
dimensions
- Output: :math:`(*)`, same shape as the input
Returns:
a Tensor of the same dimension and shape as the input with
values in the range [0, 1]
Args:
dim (int): A dimension along which Softmax will be computed (so every slice
along dim will sum to 1).
.. note::
This module doesn't work directly with NLLLoss,
which expects the Log to be computed between the Softmax and itself.
Use `LogSoftmax` instead (it's faster and has better numerical properties).
Examples::
>>> m = nn.Softmax(dim=1)
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
__constants__ = ['dim']
dim: Optional[int]
def __init__(self, dim: Optional[int] = None) -> None:
super(Softmax, self).__init__()
self.dim = dim
def __setstate__(self, state):
self.__dict__.update(state)
if not hasattr(self, 'dim'):
self.dim = None
def forward(self, input: Tensor) -> Tensor:
return F.softmax(input, self.dim, _stacklevel=5)
def extra_repr(self) -> str:
return 'dim={dim}'.format(dim=self.dim)
torch.nn.functional.softmax 函数源码如下,可以看到 ret = input.softmax(dim) 实际上调用了 torch._C._VariableFunctions 中的 softmax 函数
def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor:
r"""Applies a softmax function.
Softmax is defined as:
:math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
It is applied to all slices along dim, and will re-scale them so that the elements
lie in the range `[0, 1]` and sum to 1.
See :class:`~torch.nn.Softmax` for more details.
Args:
input (Tensor): input
dim (int): A dimension along which softmax will be computed.
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
If specified, the input tensor is casted to :attr:`dtype` before the operation
is performed. This is useful for preventing data type overflows. Default: None.
.. note::
This function doesn't work directly with NLLLoss,
which expects the Log to be computed between the Softmax and itself.
Use log_softmax instead (it's faster and has better numerical properties).
"""
if has_torch_function_unary(input):
return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
if dim is None:
dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
if dtype is None:
ret = input.softmax(dim)
else:
ret = input.softmax(dim, dtype=dtype)
return ret
那么不如直接调用 built-in C 的函数?
但是有个博客 A selective excursion into the internals of PyTorch 里说
Note: That bilinear is exported as torch.bilinear is somewhat accidental. Do use the documented interfaces, here torch.nn.functional.bilinear whenever you can!
意思是说 built-in C 能被 torch.xxx 直接调用是意外的,强烈建议使用 torch.nn.functional.xxx 这样的接口
看到最新的 transformer 官方代码里也用的是 torch.nn.functional.softmax,还是和他们一致更好(虽然他们之前用的是类。。。)
来源:https://blog.csdn.net/weixin_39454351/article/details/106419293


猜你喜欢
- 1005:创建表失败1006:创建数据库失败1007:数据库已存在,创建数据库失败1008:数据库不存在,删除数据库失败1009:不能删除数
- Python 文件处理注意事项总结文件处理在编程中是常见的操作,文件的打开,关闭,重命名,删除,追加,复制,随机读写非常容易理解和使用。需要
- Notes怀疑模型梯度 * ,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.Ten
- 前言因为项目需要,需要批处理很多Matlab的.m文件,从每个文件中提取结果合并到一个文件中。 很明显,如果手工统计,几百个文件会累死的。
- 一些命令行工具的使用能够大大简化代码脚本的维护成本,提升复用性,今天主要是借助于python提供的几种主流的参数解析工具来实现简单的功能,主
- defer用于资源的释放,会在函数返回之前进行调用。如果有多个defer表达式,调用顺序类似于栈,越后面的defer表达式越先被调用。def
- 简称oop复习面向对象编程,简称oop [object oriented programming] 是一种python的编程思想面向过程:按
- 1、Dreamweaver中的复制我在网页中复制的文字,粘贴到Dreamweaver中时,它总是带有原来网页的格式,请问如何只复制其中的文本
- 为了熟悉Python基础语法,学习了一个经典的案例:飞机大战,最后实现效果如下:实现步骤:①下载64位对应python版本的pygame:p
- 一、官方解释:1.v-if 是“真正”的条件渲染,因为它会确保在切换过程中条件块内的事件 * 和子组件适当地被销毁和重建。2.v-if 也是
- SPI是一种JDK提供的加载插件的灵活机制,分离了接口与实现,就拿常用的数据库驱动来说,我们只需要在spring系统中引入对应的数据库依赖包
- 本文主要内容python MySQLdb数据库批量插入insert,更新update的:1.python MySQLdb的使用,写了一个基类
- 1.概述python中的日志库logging使用起来有点像log4j,但配置通常比较复杂,构建日志服务器时也不是方便。标准库logging的
- python 里面与时间有关的模块主要是 time 和 datetime如果想获取系统当前时间戳:time.time(),是一个float型
- MySQL 与 Elasticsearch 数据不对称问题解决办法jdbc-input-plugin 只能实现数据库的追加,对于 elast
- 一、效果展示在介绍代码之前,先来看下本文的实现效果。可以参考下面步骤把Python文件转化成exe,发给未安装Python的他/她。Pins
- 本文实例为大家分享了python多线程分块读取文件的具体代码,供大家参考,具体内容如下# _*_coding:utf-8_*_import
- 本文介绍Python实现端口复用实例如下所示:#coding=utf-8import socketimport sysimport sele
- 直接搭建网络必须与torchvision自带的网络的权重也就是pth文件的结构、尺寸和变量命名完全一致,否则无法加载权重文件。此时可比较2个
- SQL Server有几个版本都在使用中——4.2, 6.0, 6.5, 7.0, 2000,以及2