tensorflow建立一个简单的神经网络的方法
作者:Mr丶Caleb 发布时间:2022-09-27 17:01:51
标签:tensorflow,神经网络
本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。
如何添加一层网络
代码如下:
def add_layer(inputs, in_size, out_size, activation_function=None):
# add one more layer and return the output of this layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。
创建一些数据
# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。
numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source]
Return evenly spaced numbers over a specified interval.
Returns num evenly spaced samples, calculated over the interval [start, stop].
noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。
numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。
添加占位符,用作输入
# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
添加隐藏层和输出层
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
计算误差,并用梯度下降使得误差最小
# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
完整代码如下:
from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def add_layer(inputs, in_size, out_size, activation_function=None):
# add one more layer and return the output of this layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()
for i in range(1000):
# training
sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
if i % 50 == 0:
# to visualize the result and improvement
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value = sess.run(prediction, feed_dict={xs: x_data})
# plot the prediction
lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
plt.pause(0.1)
运行结果:
来源:http://blog.csdn.net/qq_30159351/article/details/52639291


猜你喜欢
- Tkinter是python的GUI模块,内含各种窗口控件,利用其中messagbox可以制作各种信息弹出窗口。以下是制作信息提示框的代码:
- 在使用Python编写的应用的过程中,有时会遇到多个文件之间传递同一个全局变量的情况,此时通过配置文件定义全局变量是一个比较好的选择。首先配
- 在说到什么是回表查询的时候,有两个概念需要先解释清楚:分别是聚集索引(聚簇索引)和非聚集索引(非聚簇索引)聚集索引和非聚集索引MySQL规定
- 本文实例讲述了python实现自动更换ip的方法。分享给大家供大家参考。具体实现方法如下:#!/usr/bin/env python#-*-
- 本文实例讲述了golang简单获取上传文件大小的方法。分享给大家供大家参考,具体如下:package mainimport ( &
- Flask 环境配置你的应用程序可能需要大量的软件包才能正常的工作。如果都不需要 Flask 包的话,你有可能读错了教程。当应用程序运行的时
- 这篇文章主要介绍了Python多线程获取返回值代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- 在windows中罗技K380可以安装Logitech Options来实现这个Fn锁定功能。在linux中如何实现Logitech Opt
- 前言在这篇文章 Go Mutex:保护并发访问共享资源的利器 中,主要介绍了 Go 语言中互斥锁 Mutex&
- 学了python后,之前一些我们常用的方法,也可以换一种思路用python中的知识来解决。相信操作出来后,能收获一大批小粉丝们。就像我们没学
- tablewidgetpyqt5的tablewidget组件比较特殊,每个方格可以装载其他组件来搭配实现不同的效果,所以在qtdesigne
- 本文实例为大家分享了java模拟ATM功能的具体代码,供大家参考,具体内容如下有三个类:Test.java、Customer.java、Cu
- 项目中使用mp3格式进行音效播放,遇到一个mp3文件在程序中死活播不出声音,最后发现它是wav格式的文件,却以mp3结尾。要对资源进行mp3
- python配置matlab库1、确认配置版本matlab与python有相互对应的版本,需要两者版本兼容。如不兼容,需要调整matlab版
- 这个是今年年初写的一篇,拿出来温习下。指针让程序结构变得混乱,也让程序执行效率提高,因此在oo的语言中不提倡指针的使用,使得程序结构清晰易读
- [1]定义:正则又叫规则或模式,是一个强大的字符串匹配工具,在javascript中是一个对象[2]特性:[2.1]贪婪性,匹配最长的[2.
- 将数据库中的数据保存在excel文件中有很多种方法,这里主要介绍pyExcelerator的使用。一、前期准备(不详细介绍MySQL)pyt
- 1.删除序列相同元素并保持顺序如果仅仅就是想消除重复元素,通常可以简单的构造一个集合,利用集合之间元素互不相同的特性就可以消除重复,但是这种
- python提供了4种方式来满足进程间的数据通信1. 使用multiprocessing.Queue可以在进程间通信,但不能在Pool池创建
- mysql 5.7.19 winx64安装教程记录如下,分享给大家step1官方下载地址:https://dev.mysql.com/dow