网络编程
位置:首页>> 网络编程>> Python编程>> pytorch DataLoader的num_workers参数与设置大小详解

pytorch DataLoader的num_workers参数与设置大小详解

作者:龙南希  发布时间:2022-12-22 12:15:58 

标签:pytorch,DataLoader,num,workers

Q:在给Dataloader设置worker数量(num_worker)时,到底设置多少合适?这个worker到底怎么工作的?


   train_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=4)

参数详解:

1、每次dataloader加载数据时:dataloader一次性创建num_worker个worker,(也可以说dataloader一次性创建num_worker个工作进程,worker也是普通的工作进程),并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。

然后,dataloader从RAM中找本轮迭代要用的batch,如果找到了,就使用。如果没找到,就要num_worker个worker继续加载batch到内存,直到dataloader在RAM中找到目标batch。一般情况下都是能找到的,因为batch_sampler指定batch时当然优先指定本轮要用的batch。

2、num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮...迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些。

3、如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度更慢。

设置大小建议:

1、Dataloader的num_worker设置多少才合适,这个问题是很难有一个推荐的值。有以下几个建议:

2、num_workers=0表示只有主进程去加载batch数据,这个可能会是一个瓶颈。

3、num_workers = 1表示只有一个worker进程用来加载batch数据,而主进程是不参与数据加载的。这样速度也会很慢。

num_workers>0 表示只有指定数量的worker进程去加载数据,主进程不参与。增加num_works也同时会增加cpu内存的消耗。所以num_workers的值依赖于 batch size和机器性能。

4、一般开始是将num_workers设置为等于计算机上的CPU数量

5、最好的办法是缓慢增加num_workers,直到训练速度不再提高,就停止增加num_workers的值。

补充:pytorch中Dataloader()中的num_workers设置问题

如果num_workers的值大于0,要在运行的部分放进__main__()函数里,才不会有错:


import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional
import matplotlib.pyplot as plt
import torch.utils.data as Data

BATCH_SIZE=5

x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
torch_dataset=Data.TensorDataset(x,y)
loader=Data.DataLoader(
   dataset=torch_dataset,
   batch_size=BATCH_SIZE,
   shuffle=True,
   num_workers=2,
)

def main():
   for epoch in range(3):
       for step,(batch_x,batch_y) in enumerate(loader):
           # training....
           print('Epoch:',epoch,'| step:',step,'| batch x:',batch_x.numpy(),
                 '| batch y:',batch_y.numpy())

if __name__=="__main__":
   main()

'''
# 下面这样直接运行会报错:
for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):
        # training....
         print('Epoch:',epoch,'| step:',step,'| batch x:',batch_x.numpy(),
                 '| batch y:',batch_y.numpy()
'''

来源:https://blog.csdn.net/qq_28057379/article/details/115427052

0
投稿

猜你喜欢

  • 在实际使用python时,我们会将一些公共的东西写到一些基础模块中,供其他模块去调用,这时会去import自定义的一些基础模块,然后来导入。
  • 有时候我们不希望浏览器使用缓存加快网页的显示,尤其是那些论坛之类的频繁更新内容的网页,在网上有说可以使用下面的方法来屏蔽缓存,但是我试了效果
  • 如下所示:a = [1,1,1,2,3,45,1,2,1]a.remove(1) result: [1,1,2,3,45,1,2,1]whi
  • OK,今天我们来学习一下 python 中的日志模块,日志模块也是我们日后的开发工作中使用率很高的模块之一,接下来们就看一看今天具体要学习日
  • 引言使用python进行接口测试时常常需要接口用例测试数据、断言接口功能、验证接口响应状态等,如果大量的接口测试用例脚本都将接口测试用例数据
  • 这篇文章主要介绍了原生Java操作mysql数据库过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
  • 经过了上个星期的努力学习,对处理html又有了新的发现感觉真的很不错可以说js的威力在处理html代码方面我又有所领悟了1、截取特定长度字符
  • 何为共线性:共线性问题指的是输入的自变量之间存在较高的线性相关度。共线性问题会导致回归模型的稳定性和准确性大大降低,另外,过多无关的维度计算
  • 一年一度的元宵节刚刚过去,由于时间关系,在元宵节当天晚上11点多才完成本文灯笼的绘制。这两天又在忙着别的事情,所以现在才跟大家分享。一、效果
  • import requestsimport reimport jsonimport ossession = requests.session
  • 问:如何给导入文件加上时间戳标记?答:请参考下文中介绍的两种方法:1.在DOS下从系统获得时间戳利用Dos命令取得时间戳:C:\>ec
  • 以下列出了两种数据库的方法:ASP+Access20001.要获取的ID值字段属性必须设为:自动编号(我们假设字段名为recordID)2.
  • 一、制作思路由于注册的时候常常会用到注册码来防止机器恶意注册,这里我发表一个产生png图片验证码的基本图像,简单的思路分析:1、产生一张pn
  • 请问如何处理Oracle中较大的文本数据?我们可在ASP中予以解决,如在Oracle8i中文版中,建立数据表:CREATE TABLE SY
  • 引入:通常,钓鱼网站本质是本质搭建一个跟正常网站一模一样的页面,用户在该页面上完成转账功能转账的请求确实是朝着正常网站的服务端提交,唯一不同
  • PDOStatement::bindParamPDOStatement::bindParam — 绑定一个参数到指定的变量名(PHP 5 &
  • 我就废话不多说了,直接上代码吧:package mainimport ("flag""fmt"&qu
  • 经常需要读取某个文件夹下所有的图像文件。我使用python写了个简单的代码,读取某个文件夹下某个后缀的文件,将文件名生成为文本(csv格式)
  • 先看效果,实现一个图片左右摇动,在一般的H5宣传页,商家活动页面我们会看到这样的动画,小程序的动画效果不同于css3动画效果,是通过js来完
  • 前言相关性分析算是很多算法以及建模的基础知识之一了,十分经典。关于许多特征关联关系以及相关趋势都可以利用相关性分析计算表达。其中常见的相关性
手机版 网络编程 asp之家 www.aspxhome.com