AMP Tensor Cores节省内存PyTorch模型详解
作者:ronghuaiyang 发布时间:2021-07-08 01:30:32
导读
只需要添加几行代码,就可以得到更快速,更省显存的PyTorch模型。
你知道吗,在1986年Geoffrey Hinton就在Nature论文中给出了反向传播算法?
此外,卷积网络最早是由Yann le cun在1998年提出的,用于数字分类,他使用了一个卷积层。但是直到2012年晚些时候,Alexnet才通过使用多个卷积层来实现最先进的imagenet。
那么,是什么让他们现在如此出名,而不是之前呢?
只有在我们拥有大量计算资源的情况下,我们才能够在最近的过去试验和充分利用深度学习的潜力。
但是,我们是否已经足够好地使用了我们的计算资源呢?我们能做得更好吗?
这篇文章的主要内容是关于如何利用Tensor Cores和自动混合精度更快地训练深度学习网络。
什么是Tensor Cores?
根据NVIDIA的网站:
NVIDIA Turing和Volta GPUs都是由Tensor Cores驱动的,这是一项突破性的技术,提供了突破性的AI性能。Tensor Cores可以加速AI核心的大矩阵运算,在一次运算中就可以完成混合精度的矩阵乘法和累加运算。在一个NVIDIA GPU上有数百个Tensor Cores并行运行,这大大提高了吞吐量和效率。
简单地说,它们是专门的cores,非常适合特定类型的矩阵操作。
我们可以将两个FP16矩阵相乘,并将其添加到一个FP16/FP32矩阵中,从而得到一个FP16/FP32矩阵。Tensor cores支持混合精度数学,即以半精度(FP16)进行输入,以全精度(FP32)进行输出。上述类型的操作对许多深度学习任务具有内在价值,而Tensor cores为这种操作提供了专门的硬件。
现在,使用FP16和FP32主要有两个好处。
FP16需要更少的内存,因此更容易训练和部署大型神经网络。它还只需要较少的数据移动。
数学运算在降低精度的Tensor cores运行得更快。NVIDIA给出的Volta GPU的确切数字是:FP16的125 TFlops vs FP32的15.7 TFlops(8倍加速)。
但也有缺点。当我们从FP32转到FP16时,我们需要降低精度。
FP32 vs FP16: FP32 有8个指数位和23个分数位,而FP16有5个指数位和10个分数位。
但是FP32真的有必要吗?
实际上,FP16可以很好地表示大多数权重和梯度。所以存储和使用FP32是很浪费的。
那么,我们如何使用Tensor Cores?
我检查了一下我的Titan RTX GPU有576个tensor cores和4608个NVIDIA CUDA核心。但是我如何使用这些tensor cores呢?
坦白地说,NVIDIA用几行代码就能提供自动混合精度,因此使用tensor cores很简单。我们需要在代码中做两件事:
需要用到FP32的运算比如Softmax之类的就分配用FP32,而Conv之类的操作可以用FP16的则被自动分配用FP16。
使用损失缩放 为了保留小的梯度值。梯度值可能落在FP16的范围之外。在这种情况下,梯度值被缩放,使它们落在FP16范围内。
如果你还不了解背景细节也没关系,代码实现相对简单。
使用PyTorch进行混合精度训练:
让我们从PyTorch中的一个基本网络开始。
N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device="cuda")
y = torch.randn(N, D_out, device="cuda")
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for to in range(500):
y_pred = model(x)
loss = torch.nn.functional.mse_loss(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
为了充分利用自动混合精度训练的优势,我们首先需要安装apex库。只需在终端中运行以下命令。
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
然后,我们只需向神经网络代码中添加几行代码,就可以利用自动混合精度(AMP)。
from apex import amp
N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device="cuda")
y = torch.randn(N, D_out, device="cuda")
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
for to in range(500):
y_pred = model(x)
loss = torch.nn.functional.mse_loss(y_pred, y)
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
在这里你可以看到我们用amp.initialize
初始化了我们的模型。我们还使用amp.scale_loss
来指定损失缩放。
基准测试
git clone https://github.com/MLWhiz/data_science_blogs
cd data_science_blogs/amp/pytorch-apex-experiment/
python run_benchmark.py
python make_plot.py --GPU 'RTX' --method 'FP32' 'FP16' 'amp' --batch 128 256 512 1024 2048
这会在home目录中生成下面的图:
在这里,我使用不同的精度和批大小设置训练了同一个模型的多个实例。我们可以看到,从FP32到amp,内存需求减少,而精度保持大致相同。时间也会减少,但不会减少那么多。这可能是由于数据集或模型太简单。
根据NVIDIA给出的基准测试,AMP比标准的FP32快3倍左右,如下图所示。
在单精度和自动混合精度两种精度下,加速比为固定周期训练的时间比。
来源:https://juejin.cn/post/7157663977437134884
猜你喜欢
- 知识点爬虫基本流程re正则表达式简单使用requestsjson数据解析方法视频数据保存开发环境Python 3.8Pycharm爬虫基本思
- Python的命名空间是Python程序猿必须了解的内容,对Python命名空间的学习,将使我们在本质上掌握一些Python中的琐碎的规则。
- 新建label与button,并设置位置(grid)import tkinter as tkroot = tk.Tk()label = tk
- 本文实例为大家分享了php微信公众号开发之快递查询的具体代码,供大家参考,具体内容如下快递查询数组用法foreach查询接口是:爱快递:ht
- 1、XML 是什么?XML仅仅是一种数据存放格式,这种格式是一种文本(虽然XML规范中也提供了存放二进制数据的解决方案)。事实上有很多文本格
- 使用python进行websocket的客户端压力测试,这个代码是从github上 找到。然后简单修改了下。大神运用了进程池,以及线程池的内
- 编写一个名为 collatz()的函数,它有一个名为 number 的参数。如果参数是偶数,那么 collatz()就打印出 number
- 1、Introduction之前写过2篇文章,分别是:Mysql主从同步的原理 Myql主从同步实战 基于此,我们再实
- Turtle库是Python语言中一个很流行的绘制图像的函数库,想象一个小乌龟,在一个横轴为x、纵轴为y的坐标系原点,(0,0)位置开始,它
- 前言:在Python里面,只要类型对象实现了__iter__,那么它的实例对象就被称为可迭代对象(Iterable),比如字符串、元组、列表
- 准备工作:python:https://www.python.org/downloads/Dev-C++:https://sourcefor
- subplot(arg1, arg2, arg3)arg1: 在垂直方向同时画几张图arg2: 在水平方向同时画几张图arg3: 当前命令修
- If...Then...Else 语句的一种变形,即添加任意多个 ElseIf 子句以扩充 If...Then...Else 语句的功能,允
- 目录1、请求模块:urllib.requestdata参数:post请求urlopen()中的参数timeout:设置请求超时时间:响应类型
- Beautiful Soup使用时,一般可以通过指定对应的name和attrs去搜索,特定的名字和属性,以找到所需要的部分的html代码。但
- JScript 具有全范围的运算符,包括算术、逻辑、位、赋值以及其他某些运算符。算术运算符描述 符号 负值 - 递增 ++ 递减 ? 乘法
- 前言本文主要给大家介绍了Go语言中函数new与make的使用和区别,关于Go语言中new和make是内建的两个函数,主要用来创建分配类型内存
- 本文实例讲述了php使用pthreads v3多线程实现抓取新浪新闻信息。分享给大家供大家参考,具体如下:我们使用pthreads,来写一个
- 执行环境会负责管理代码执行过程中使用的内存,编写JavaScript程序时,所需内存的分配以及无用内存的回收完全实现自动管理。原理:找出那些
- 1.在OpenCV中我们经常会遇到一个名字:Mask(掩膜)。很多函数都使用到它,那么这个Mask到底什么呢?2.如果我们想要裁剪图像中任意