网络编程
位置:首页>> 网络编程>> Python编程>> PyTorch加载预训练模型实例(pretrained)

PyTorch加载预训练模型实例(pretrained)

作者:Xie_learning  发布时间:2021-02-04 15:26:11 

标签:PyTorch,预训练,模型,pretrained

使用预训练模型的代码如下:


# 加载预训练模型
resNet50 = models.resnet50(pretrained=True)
ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)

# 读取参数
pretrained_dict = resNet50.state_dict()
model_dict = ResNet50.state_dict()

# 将pretained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# 更新现有的model_dict
model_dict.update(pretrained_dict)

# 加载真正需要的state_dict
ResNet50.load_state_dict(model_dict)

来源:https://blog.csdn.net/Xie_learning/article/details/89176636

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com