使用Keras构造简单的CNN网络实例
作者:tina_ttl 发布时间:2023-08-23 04:38:21
1. 导入各种模块
基本形式为:
import 模块名
from 某个文件 import 某个模块
2. 导入数据(以两类分类问题为例,即numClass = 2)
训练集数据data
可以看到,data是一个四维的ndarray
训练集的标签
3. 将导入的数据转化我keras可以接受的数据格式
keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数
label = np_utils.to_categorical(label, numClass
此时的label变为了如下形式
(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)
4. 建立CNN模型
以下图所示的CNN网络为例
#生成一个model
model = Sequential()
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
# layer5-fully connect
model.add(Dense(numClass, init='normal'))
model.add(Activation('softmax'))
#
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")
5. 开始训练model
利用model.train_on_batch或者model.fit
补充知识:keras 多分类一些函数参数设置
用Lenet-5 识别Mnist数据集为例子:
采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。
首先import所需要的模块
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K
定义图像数据信息及训练参数
img_width, img_height = 28, 28
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000
nb_validation_samples = 10000
epochs = 50
batch_size = 32
判断使用的后台
if K.image_dim_ordering() == 'th':
input_shape = (3, img_width, img_height)
else:
input_shape = (img_width, img_height, 3)
网络模型定义
主要注意最后的输出层定义
比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:
x = Dense(10,activation='softmax')(x)
此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid
img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)
x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)
x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)
x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)
model.compile(loss='binary_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
model.summary()
利用ImageDataGenerator传入图像数据集
注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。
train_datagen = ImageDataGenerator(
rescale=1. / 255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical', #多分类问题设为'categorical'
classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
)
validation_generator = test_datagen.flow_from_directory(
validation_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical'
)
训练和保存模型及权值
model.fit_generator(
train_generator,
samples_per_epoch=nb_train_samples,
nb_epoch=epochs,
validation_data=validation_generator,
nb_val_samples=nb_validation_samples
)
model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')
至此训练结束
图片预测
注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。
此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
img_addr=input('Please input your image address:')
if img_addr=="exit":
break
else:
img = load_img(img_addr, False, target_size=(28, 28))
x = img_to_array(img) / 255.0
x = np.expand_dims(x, axis=0)
result = model.predict(x)
ind=np.argmax(result,1)
print('this is a ', classes[ind])
来源:https://blog.csdn.net/tina_ttl/article/details/51034821
猜你喜欢
- 1、将python程序打包成单文件(使用 -F 参数)后,尝试运行外部文件却提示找不到的问题当你将python程序打包成单文件(使用 -F
- 在这之前我们先回顾以前用php导出excel,我直接写成方法在这里:public static function phpExcelList(
- 动态展示这是一个动态图哦导读兄弟们可以收藏一下哦!情人节可以送出去,肥学找了几朵python写的花给封装好送给大家。不是多炫酷但是有足够的用
- 从ResNet到DenseNet上图中,左边是ResNet,右边是DenseNet,它们在跨层上的主要区别是:使用相加和使用连结。最后,将这
- 作为六大python可视化库,基本上学会都是可以通吃任何领域的存在,本章要给大家介绍的Altair就是其中之一的可视化库,能够将数据转化为非
- Python是一种非常富有表现力的语言。它为我们提供了一个庞大的标准库和许多内置模块,帮助我们快速完成工作。然而,许多人可能会迷失在它提供的
- ASCII(str) 返回字符串str的第一个字符的ASCII值(str是空串时返回0)mysql> select ASCII(
- 一开始都是先去《英雄联盟》官网找到英雄及皮肤图片的网址:URL = r'https://lol.qq.com/data/info-h
- 终于开始做用户部分了,先做注册一用户 1.1用户注册 首先在Models里添加用户注册模型类UserRegister 继
- 下面是最终代码 (windows下实现的) # -*- coding: cp936 -*- import os path = 'D:
- 按照惯例,年底的淘宝的确是到了“需要改版的时候”。这次新版的淘宝首页上线,乍看并没有多少夺人眼球的地方,但仔细揣摩其中的细节,还是发现了不少
- 这篇博客将介绍如何使用OpenCV制作Mask图像掩码。使用位运算和图像掩码允许我们只关注图像中感兴趣的部分,截取出任意区域的ROIs。应用
- 本文实例讲述了Python回文字符串及回文数字判定功能。分享给大家供大家参考,具体如下:所谓回文字符串,就是一个字符串,从左到右读和从右到左
- Thrift 是一种接口描述语言和二进制通信协议。以前也没接触过,最近有个项目需要建立自动化测试,这个项目之间的微服务都是通过 Thrift
- 基于OpenCV2.4.8和 python 2.7实现简单的手势识别。以下为基本步骤 1.去除背景,提取手的轮廓2. RGB->YUV
- 在网页设计发展到一定阶段的时候就必然会和其他学科或领域只是产生交汇和共鸣,在阅读《超越CSS:web设计艺术精髓》这本书的时候,发现原来we
- Expression定义 IE5及其以后版本支持在CSS中使用expression,用来把CSS属性和Javascript表达式关联起来,这
- 页面跳转页面跳转的url中必须在最后会自动添加【\】,所以在urls.py的路由表中需要对应添加【\】from django.shortcu
- 自定义查询对象 - objects①声明一个类EntryManager,继承自models.Manager,并添加自定义函数②使用创建的自定
- System.Data.OleDb.OleDbDataAdapter与System.Data.OleDb.OleDbDataReader的区