python深度学习之多标签分类器及pytorch实现源码
作者:鬼道2022 发布时间:2022-09-26 01:09:12
标签:多标签,分类器,pytorch,源码
多标签分类器
多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:
类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云
如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。
多标签分类器损失函数
代码实现
针对图像的多标签分类器pytorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别1的多标签可以为[1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.Sq1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2), # (16, 28, 28) # output: (16, 28, 28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2), # (16, 14, 14)
)
self.Sq2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), # (32, 14, 14)
nn.ReLU(),
nn.MaxPool2d(2), # (32, 7, 7)
)
self.out = nn.Linear(32 * 7 * 7, 100)
def forward(self, x):
x = self.Sq1(x)
x = self.Sq2(x)
x = x.view(x.size(0), -1)
x = self.out(x)
## Sigmoid activation
output = F.sigmoid(x) # 1/(1+e**(-x))
return output
def loss_fn(pred, target):
return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum()
def multilabel_generate(label):
Y1 = F.one_hot(label, num_classes = 100)
Y2 = F.one_hot(label+10, num_classes = 100)
Y3 = F.one_hot(label+50, num_classes = 100)
multilabel = Y1+Y2+Y3
return multilabel
# def multilabel_generate(label):
# multilabel_dict = {}
# multi_list = []
# for i in range(label.shape[0]):
# multi_list.append(multilabel_dict[label[i].item()])
# multilabel_tensor = torch.tensor(multi_list)
# return multilabel
def train():
epoches = 10
mnist_net = CNN()
mnist_net.train()
opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002)
mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True)
for epoch in range(epoches):
loss = 0
for batch_X, batch_Y in train_loader:
opitimizer.zero_grad()
outputs = mnist_net(batch_X)
loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0]
loss.backward()
opitimizer.step()
print(loss)
if __name__ == '__main__':
train()
来源:https://guidao.blog.csdn.net/article/details/122085474
0
投稿
猜你喜欢
- 以下实例用于判断一个数字是否为奇数或偶数:# -*- coding: UTF-8 -*-# Filename : test.py# Pyth
- 很多年前,我们就可以轻易的从很多国营商场、火车车厢、饭馆旅馆中看到墙上挂的那个小本本-意见薄,作为经营方与顾客沟通的
- MySQL 客户端连接成功后,通过 show [session|global]status 命令 可以提供服务器状态信息,也可以在操作系统上
- 本文实例讲述了Python实现查询某个目录下修改时间最新的文件。分享给大家供大家参考,具体如下:通过Python脚本,查询出某个目录下修改时
- 引言今天给大家推荐的是web应用安全防护方面的一个包:csrf。该包为Go web应用中常见的跨站请求伪造(CSRF)攻击提供预防功能。cs
- PHP registerXPathNamespace() 函数实例为下一个 XPath 查询创建命名空间上下文:<?php $xml=
- 本文实例讲述了Flask框架通过Flask_login实现用户登录功能。分享给大家供大家参考,具体如下:通过Flask_Login实现用户验
- 出现invalid syntax报错的几种原因这篇文章旨为刚接触python不久的朋友,提供一点帮助,请在检查代码没有主要问题时再看是否存在
- 作者:Robert Lair and Jason Lefebvr Intensity Software, Inc.
- 引言从他人的错误中学习,通过本指南避免常见陷阱和坏习惯,提高你的 Go 编程技巧在 Go 语言中,就像在任何编程语言中一样,了解常见陷阱和坏
- 测试语法如下:powered by jb51.netexec GetRecordFromPage news,newsid,10,100000
- 本文实例讲述了PHP排序二叉树基本功能实现方法。分享给大家供大家参考,具体如下:这里演示了排序二叉树节点的插入,中序遍历,极值的查找和特定值
- import osimport sysimport ftplibimport socket#########################
- Python Scrapy爬虫,听说妹子图挺火,我整站爬取了,上周一共搞了大概8000多张图片。和大家分享一下。核心爬虫代码# -*- co
- Python 如何转换string到float?简单几步,让你轻松解决。打开软件,新建python项目,如图所示右键菜单中创建.py文件,如
- 1.数据是什么?在 Python 以及其他所有面向对象编程语言中,类都是对数据的构成(状态)以及数据 能做什么(行为)的描述。由于类的使用者
- 用程序来求积分的方法有很多,这篇文章主要是有关牛顿-科特斯公式。学过插值算法的同学最容易想到的就是用插值函数代替被积分函数来求积分,但实际上
- 计算表达式:1 - 2 * ( (60-30 +(-40/5) * (9-2*5/3 + 7 /3*99/4*2998 +10 * 568/
- 线性回归是一种常见的机器学习算法,也是人工智能中常用的算法。它是一种用于预测数值型输出变量与一个或多个自变量之间线性关系的方法。例如,你可以
- 一个封装好的JavaScript拖动类,使用方便:<div id="idDrag" style="bor