tensorflow实现softma识别MNIST

yipeiwu_com5年前Python基础

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
 
def softmax(mnist,rate=0.01,batchSize=50,epoch=20): 
 n = 784 # 向量的维度数目 
 m = None # 样本数,这里可以获取,也可以不获取 
 c = 10 # 类别数目 
 
 x = tf.placeholder(tf.float32,[m,n]) 
 y = tf.placeholder(tf.float32,[m,c]) 
 
 w = tf.Variable(tf.zeros([n,c])) 
 b = tf.Variable(tf.zeros([c])) 
 
 pred= tf.nn.softmax(tf.matmul(x,w)+b) 
 loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
 opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 for index in range(epoch): 
  avgLoss = 0 
  batchNum = int(mnist.train.num_examples/batchSize) 
  for batch in range(batchNum): 
   batch_x,batch_y = mnist.train.next_batch(batchSize) 
   _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) 
   avgLoss += Loss 
  avgLoss /= batchNum 
  print 'every epoch average loss is ',avgLoss 
 
 right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 
 accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) 
 print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) 
 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 softmax(mnist) 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python执行子进程实现进程间通信的方法

本文实例讲述了python执行子进程实现进程间通信的方法。分享给大家供大家参考。具体实现方法如下: a.py: import subprocess, time subproc = s...

pyqt5的QWebEngineView 使用模板的方法

说明1:关于QWebEngineView pyqt5 已经抛弃 QtWebKit和QtWebKitWidgets,而使用最新的QtWebEngineWidgets。 QtWebEng...

python实现最长公共子序列

python实现最长公共子序列

最长公共子序列python实现,最长公共子序列是动态规划基本题目,下面按照动态规划基本步骤解出来。 1.找出最优解的性质,并刻划其结构特征 序列a共有m个元素,序列b共有n个元素,如果a...

使用Python通过win32 COM打开Excel并添加Sheet的方法

使用Python通过win32 COM打开Excel并添加Sheet的方法

对win32 COM不是很熟悉,不知道一个程序究竟有多少属性或者方法可以操作。仅仅是一个Sheet页的添加就费了我好长时间,因为这种成功来自于试探。 编辑代码如下: #!/usr/b...

深入了解如何基于Python读写Kafka

这篇文章主要介绍了深入了解如何基于Python读写Kafka,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 本篇会给出如何使用pyth...