tensorflow实现KNN识别MNIST

yipeiwu_com5年前Python基础

KNN算法算是最简单的机器学习算法之一了,这个算法最大的特点是没有训练过程,是一种懒惰学习,这种结构也可以在tensorflow实现。

KNN的最核心就是距离度量方式,官方例程给出的是L1范数的例子,我这里改成了L2范数,也就是我们常说的欧几里得距离度量,另外,虽然是叫KNN,意思是选取k个最接近的元素来投票产生分类,但是这里只是用了最近的那个数据的标签作为预测值了。

__author__ = 'freedom' 
import tensorflow as tf 
import numpy as np 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
def KNN(mnist): 
 train_x,train_y = mnist.train.next_batch(5000) 
 test_x,test_y = mnist.train.next_batch(200) 
 
 xtr = tf.placeholder(tf.float32,[None,784]) 
 xte = tf.placeholder(tf.float32,[784]) 
 distance = tf.sqrt(tf.reduce_sum(tf.pow(tf.add(xtr,tf.neg(xte)),2),reduction_indices=1)) 
 
 pred = tf.argmin(distance,0) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 
 right = 0 
 for i in range(200): 
  ansIndex = sess.run(pred,{xtr:train_x,xte:test_x[i,:]}) 
  print 'prediction is ',np.argmax(train_y[ansIndex]) 
  print 'true value is ',np.argmax(test_y[i]) 
  if np.argmax(test_y[i]) == np.argmax(train_y[ansIndex]): 
   right += 1.0 
 accracy = right/200.0 
 print accracy 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 KNN(mnist) 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python命名空间的本质和加载顺序

Python的命名空间是Python程序猿必须了解的内容,对Python命名空间的学习,将使我们在本质上掌握一些Python中的琐碎的规则。 接下来我将分四部分揭示Python命名空间的...

Python常见排序操作示例【字典、列表、指定元素等】

本文实例讲述了Python常见排序操作。分享给大家供大家参考,具体如下: 字典排序 按value排序 d1 = {"name":"python","bank":"icbc",...

Python线程指南分享

Python线程指南分享

本文介绍了Python对于线程的支持,包括“学会”多线程编程需要掌握的基础以及Python两个线程标准库的完整介绍及使用示例。 注意:本文基于Python2.4完成,;如果看到不明白的词...

Matplotlib scatter绘制散点图的方法实现

Matplotlib scatter绘制散点图的方法实现

前言 考虑到很多同学可能还没有安装matplotlib包,这里给大家提供我常用的安装方法。首先Win键 + R,输入命令cmd打开命令行工具,再次在命令行工具中输入pip install...

python 利用for循环 保存多个图像或者文件的实例

在实际应用中,会遇到保存多个文件或者图像的操作,利用for循环可以实现基本要求: for i in range(50): plt.savefig("%d.jpg"%(i+1))...