使用pytorch完成kaggle猫狗图像识别方式
作者:charsenhz 发布时间:2023-04-10 08:39:01
kaggle是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,在这上面有非常多的好项目、好资源可供机器学习、深度学习爱好者学习之用。
碰巧最近入门了一门非常的深度学习框架:pytorch,所以今天我和大家一起用pytorch实现一个图像识别领域的入门项目:猫狗图像识别。
深度学习的基础就是数据,咱们先从数据谈起。此次使用的猫狗分类图像一共25000张,猫狗分别有12500张,我们先来简单的瞅瞅都是一些什么图片。
我们从下载文件里可以看到有两个文件夹:train和test,分别用于训练和测试。以train为例,打开文件夹可以看到非常多的小猫图片,图片名字从0.jpg一直编码到9999.jpg,一共有10000张图片用于训练。
而test中的小猫只有2500张。仔细看小猫,可以发现它们姿态不一,有的站着,有的眯着眼睛,有的甚至和其他可识别物体比如桶、人混在一起。
同时,小猫们的图片尺寸也不一致,有的是竖放的长方形,有的是横放的长方形,但我们最终需要是合理尺寸的正方形。小狗的图片也类似,在这里就不重复了。
紧接着我们了解一下特别适用于图像识别领域的神经网络:卷积神经网络。学习过神经网络的同学可能或多或少地听说过卷积神经网络。这是一种典型的多层神经网络,擅长处理图像特别是大图像的相关机器学习问题。
卷积神经网络通过一系列的方法,成功地将大数据量的图像识别问题不断降维,最终使其能够被训练。CNN最早由Yann LeCun提出并应用在手写体识别上。
一个典型的CNN网络架构如下:
这是一个典型的CNN架构,由卷基层、池化层、全连接层组合而成。其中卷基层与池化层配合,组成多个卷积组,逐层提取特征,最终完成分类。
听到上述一连串的术语如果你有点蒙了,也别怕,因为这些复杂、抽象的技术都已经在pytorch中一一实现,我们要做的不过是正确的调用相关函数,
我在粘贴代码后都会做更详细、易懂的解释。
import os
import shutil
import torch
import collections
from torchvision import transforms,datasets
from __future__ import print_function, division
import os
import torch
import pylab
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
一个正常的CNN项目所需要的库还是蛮多的。
import math
from PIL import Image
class Resize(object):
"""Resize the input PIL Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, size, interpolation=Image.BILINEAR):
# assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation
def __call__(self, img):
w,h = img.size
min_edge = min(img.size)
rate = min_edge / self.size
new_w = math.ceil(w / rate)
new_h = math.ceil(h / rate)
return img.resize((new_w,new_h))
这个称为Resize的库用于给图像进行缩放操作,本来是不需要亲自定义的,因为transforms.Resize已经实现这个功能了,但是由于目前还未知的原因,我的库里没有提供这个函数,所以我需要亲自实现用来代替transforms.Resize。
如果你的torch里面已经有了这个Resize函数就不用像我这样了。
data_transform = transforms.Compose([
Resize(84),
transforms.CenterCrop(84),
transforms.ToTensor(),
transforms.Normalize(mean = [0.5,0.5,0.5],std = [0.5,0.5,0.5])
])
train_dataset = datasets.ImageFolder(root = 'train/',transform = data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 4,shuffle = True,num_workers = 4)
test_dataset = datasets.ImageFolder(root = 'test/',transform = data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = 4,shuffle = True,num_workers = 4)
transforms是一个提供针对数据(这里指的是图像)进行转化的操作库,Resize就是上上段代码提供的那个类,主要用于把一张图片缩放到某个尺寸,在这里我们把需求暂定为要把图像缩放到84 x 84这个级别,这个就是可供调整的参数,大家为部署好项目以后可以试着修改这个参数,比如改成200 x 200,你就发现你可以去玩一盘游戏了~_~。
CenterCrop用于从中心裁剪图片,目标是一个长宽都为84的正方形,方便后续的计算。
ToTenser()就比较重要了,这个函数的目的就是读取图片像素并且转化为0-1的数字。
Normalize作为垫底的一步也很关键,主要用于把图片数据集的数值转化为标准差和均值都为0.5的数据集,这样数据值就从原来的0到1转变为-1到1。
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16 * 18 * 18,800)
self.fc2 = nn.Linear(800,120)
self.fc3 = nn.Linear(120,2)
def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16 * 18 * 18)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
来源:https://blog.csdn.net/xcszjjh1991/article/details/79256576
猜你喜欢
- 一、无组件上传的原理我还是一点一点用一个实例来说明的吧,客户端HTML如下。要浏览上传附件,我们通过<input type="
- PDO::getAttributePDO::getAttribute — 取回一个数据库连接的属性(PHP 5 >= 5.1.0, P
- 功能描述目标完成多账号微信小程序每天自动签到输出签到成功则向微信群发送签到成功的信息否则提示用户签到失败,需手动签到包管理requestsi
- 显然,效果很实用。对于这个效果,我们并不解释如何去使用效果库,而是讲解如何创建类似的效果,并保持他的可用性,分离式(unobtrusive)
- Python使用QRCode模块生成二维码QRCode官网https://pypi.python.org/pypi/qrcode/5.1简介
- 这次我们讨论的是,区分有单选框的选项和普通的选项~~乍听起来,可能不太理解我说了什么,下面举个例子先~~1、标签的单选~~例如QQ秀的支付流
- 前言:在开发中经常会与时间打交道,如:获取事件戳,时间戳的格式化等,这里简要记录一下python操作时间的方法。python中常见的处理时间
- 本文实例讲述了python双向链表原理与实现方法。分享给大家供大家参考,具体如下:双向链表一种更复杂的链表是“双向链表”或“双面链表”。每个
- 关于mysql数据库在Linux下的应用一直以来都是我认为比较棘手的,这次通过搭建Linux学习环境顺便研究和学习Mysql数据库在Linu
- 简单低级的爬虫速度快,伪装度低,如果没有反爬机制,它们可以很快的抓取大量数据,甚至因为请求过多,造成服务器不能正常工作。而伪装度高的爬虫爬取
- 删除列表中元素的方法有三种:1. del命令使用del命令能够删除列表中指定位置上的元素,也可以删除整个列表。2. pop( )方法使用列表
- 上一篇我们写了Django基于类如何增删改数据的方法,方法虽然简单,但新手可能对其原理不是很清楚,那么我们这次就用Django提供的Mode
- 原理1.使用python中的mtplotlib库。2.立体爱心面公式点画法(实心)代码import matplotlib.pyplot as
- 要随机生成字符串代码如下:在MySQL中定义一个随机串的方法,然后再SQL语句中调用此方法。随机串函数定义方法:CREATE DEFINER
- 代码如下#encoding:utf-8import requestsfrom lxml import etreeimport xlwtimp
- 1.在浏览器上搜索PyCharmhttps://www.jetbrains.com/pycharm/download/#section=wi
- python注释方法方式1单行注释:shift + #(在代码的最前面输入,非选中代码进行注释)多行注释:同单行一样在每一行的前面输入shi
- 新闻系统、blog系统等都可能用到将动态页面生成静态页面的技巧来提高页面的访问速度,从而减轻服务器的压力,本文为大家搜集整理了ASP编程中常
- 这篇文章主要介绍了Python使用configparser库读取配置文件,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考
- 实时画图import matplotlib.pyplot as pltax = [] # 定义一个 x 轴的空列表用来接收动态