pytorch 网络参数 weight bias 初始化详解
作者:Ibelievesunshine 发布时间:2023-08-12 07:43:57
权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。
在pytorch的使用过程中有几种权重初始化的方法供大家参考。
注意:第一种方法不推荐。尽量使用后两种方法。
# not recommend
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
# recommend
def initialize_weights(m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
# recommend
def weights_init(m):
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight.data)
nn.init.xavier_normal_(m.bias.data)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight,1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight,1)
nn.init.constant_(m.bias, 0)
编写好weights_init函数后,可以使用模型的apply方法对模型进行权重初始化。
net = Residual() # generate an instance network from the Net class
net.apply(weights_init) # apply weight init
补充知识:Pytorch权值初始化及参数分组
1. 模型参数初始化
# ————————————————— 利用model.apply(weights_init)实现初始化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif classname.find('Linear') != -1:
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data = torch.ones(m.bias.data.size())
# ————————————————— 直接放在__init__构造函数中实现初始化
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
# —————————————————
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.bias = Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zero_(self.bias)
nn.init.constant_(m, initm)
# nn.init.kaiming_uniform_()
# self.weight.data.normal_(std=0.001)
2. 模型参数分组weight_decay
def separate_bn_prelu_params(model, ignored_params=[]):
bn_prelu_params = []
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
ignored_params += list(map(id, m.parameters()))
bn_prelu_params += m.parameters()
if isinstance(m, nn.BatchNorm1d):
ignored_params += list(map(id, m.parameters()))
bn_prelu_params += m.parameters()
elif isinstance(m, nn.PReLU):
ignored_params += list(map(id, m.parameters()))
bn_prelu_params += m.parameters()
base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))
return base_params, bn_prelu_params, ignored_params
OPTIMIZER = optim.SGD([
{'params': base_params, 'weight_decay': WEIGHT_DECAY},
{'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},
{'params': bn_prelu_params, 'weight_decay': 0.0}
], lr=LR, momentum=MOMENTUM ) # , nesterov=True
Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.
Note 2: weight decay should not be used when learning a for good performance.
Note 3: The default number of a to learn is 1, the default initial value of a is 0.25.
3. 参数分组weight_decay–其他
第2节中的内容可以满足一般的参数分组需求,此部分可以满足更个性化的分组需求。参考:face_evoLVe_Pytorch-master
自定义schedule
def schedule_lr(optimizer):
for params in optimizer.param_groups:
params['lr'] /= 10.
print(optimizer)
方法一:利用model.modules()和obj.__class__ (更普适)
# model.modules()和model.children()的区别:model.modules()会迭代地遍历模型的所有子层,而model.children()只会遍历模型下的一层
# 下面的关键词if 'model',源于模型定义文件。如model_resnet.py中自定义的所有nn.Module子类,都会前缀'model_resnet',所以可通过这种方式一次性筛选出自定义的模块
def separate_irse_bn_paras(model):
paras_only_bn = []
paras_no_bn = []
for layer in model.modules():
if 'model' in str(layer.__class__): # eg. a=[1,2] type(a): <class 'list'> a.__class__: <class 'list'>
continue
if 'container' in str(layer.__class__): # 去掉Sequential型的模块
continue
else:
if 'batchnorm' in str(layer.__class__):
paras_only_bn.extend([*layer.parameters()])
else:
paras_no_bn.extend([*layer.parameters()]) # extend()用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)
return paras_only_bn, paras_no_bn
方法二:调用modules.parameters和named_parameters()
但是本质上,parameters()是根据named_parameters()获取,named_parameters()是根据modules()获取。使用此方法的前提是,须按下文1,2中的方式定义模型,或者利用Sequential+OrderedDict定义模型。
def separate_resnet_bn_paras(model):
all_parameters = model.parameters()
paras_only_bn = []
for pname, p in model.named_parameters():
if pname.find('bn') >= 0:
paras_only_bn.append(p)
paras_only_bn_id = list(map(id, paras_only_bn))
paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))
return paras_only_bn, paras_no_bn
两种方法的区别
参数分组的区别,其实对应了模型构造时的区别。举例:
1、构造ResNet的basic block,在__init__()函数中定义了
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = ReLU(inplace = True)
…
2、在forward()中定义
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
…
3、对ResNet取model.name_parameters()返回的pname形如:
‘layer1.0.conv1.weight'
‘layer1.0.bn1.weight'
‘layer1.0.bn1.bias'
# layer对应conv2_x, …, conv5_x; '0'对应各layer中的block索引,比如conv2_x有3个block,对应索引为layer1.0, …, layer1.2; 'conv1'就是__init__()中定义的self.conv1
4、若构造model时采用了Sequential(),则model.name_parameters()返回的pname形如:
‘body.3.res_layer.1.weight',此处的1.weight实际对应了BN的weight,无法通过pname.find(‘bn')找到该模块。
self.res_layer = Sequential(
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth),
ReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth)
)
5、针对4中的情况,两种解决办法:利用OrderedDict修饰Sequential,或利用方法一
downsample = Sequential( OrderedDict([
(‘conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),
(‘bn_ds', BatchNorm2d(planes * block.expansion)),
]))
# 如此,相应模块的pname将会带有'conv_ds',‘bn_ds'字样
来源:https://blog.csdn.net/Ibelievesunshine/article/details/99478182
猜你喜欢
- mysql数据库数据表和数据表关联--问题??用户数据表user 字词作品数据表article 短信 message外键ID 主键,之间的关
- win2000注册表程序 regedt32.exe下面是解决IIS出现Active Server Pages错误&
- 完成了UI,我们就需要对数据进行处理了。在开始“数据”的本地存储之前,我们先来了解一下client-side database storag
- 一些很实用且必用的js小脚本代码:脚本1:进入页面后自动播放音乐或其它声音文件<embed src="音乐地址&q
- RSS是 Really Simple Syndication的缩写(对rss2.0而言,是这三个词的缩写,对rss1.0而言则是RDF Si
- 在一次ASP程序中不能正常连接MSSQL出现出错信息如下:以下为引用的内容:HTTP/1.1 200 OK S
- oracle10g数据备份 1.用sql/plus developer,选中要备份的数据表,右击选择"Export data&qu
- 之前看到很多人一直都问CSS 中DIV垂直居中的问题,看来对此的需求还不少。现在就把我经验拿出来分享一下,希望大家鼓鼓掌。因为在 CSS 中
- 近段时间由于修改一个ASP程序(有SQL注入漏洞),在网上找了很多相关的一些防范办法,都不近人意,所以我将现在网上的一些方法综合改良了一下,
- 我相信站长们做网站的最终目的还是想要获得收入的,我想象站长们大部分的都做Google的联盟的,我相信站长中大部分的人都有考虑过做英文站的,但
- MySQL是一个小型关系型数据库管理系统,开发者为瑞典MySQLAB公司,在2008年1月16号被Sun公司收购。MySQL被广泛地应用在I
- 在实现TextStraem的时候,找到判断文件编码的代码是VBS的,但是在JScript中是没有ASC等函数的,也不能对二进制数据进行处理,
- 随着网络技术的不断发展,网络应用已经渗透到人类社会的各个角落。作为网络世界的支撑点的网站,更是人们关注的热点:政府利用网站宣传自己的施政纲领
- 在JavaScript前端开发工作中,由于浏览器兼容性等问题,我们会经常用到“停止事件冒泡”和“阻止浏览器默认行为”。1..停止事件冒泡//
- 当点了链接后,跳出的网页地址是https://www.aspxhome.com/ 或https://www.cidianwang.
- 一、偏好资源的积累利用DreamWeaver 4制作网页会应用到许多各种类型的要素,比如色彩、图片、模板、脚本等。利用站点资源面板将这些东东
- 昨天Steve的 讲座涉及了一个我从没考虑的领域,在没法优化后台服务器的时候,如何合理的放置网页的元件让她们在浏览器里显示得更加快。这里,我
- asp.net压缩文件夹调用示例:rar("e:/www.aspxhome/", "e:/www.aspxho
- 楔子随着自媒体时代,现在对视频的处理变得越来越常见。我们可以使用Adobe的一些专业工具,但是效率不高;如果只是对视频进行一些简单的处理的话
- 我看见朋友可以把数据库的记录输出到页面表格上去,觉得很有用。这是怎么做的啊?见下:dbtable.asp<html><he