tensorflow实现KNN识别MNIST

yipeiwu_com6年前Python基础

KNN算法算是最简单的机器学习算法之一了,这个算法最大的特点是没有训练过程,是一种懒惰学习,这种结构也可以在tensorflow实现。

KNN的最核心就是距离度量方式,官方例程给出的是L1范数的例子,我这里改成了L2范数,也就是我们常说的欧几里得距离度量,另外,虽然是叫KNN,意思是选取k个最接近的元素来投票产生分类,但是这里只是用了最近的那个数据的标签作为预测值了。

__author__ = 'freedom' 
import tensorflow as tf 
import numpy as np 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
def KNN(mnist): 
 train_x,train_y = mnist.train.next_batch(5000) 
 test_x,test_y = mnist.train.next_batch(200) 
 
 xtr = tf.placeholder(tf.float32,[None,784]) 
 xte = tf.placeholder(tf.float32,[784]) 
 distance = tf.sqrt(tf.reduce_sum(tf.pow(tf.add(xtr,tf.neg(xte)),2),reduction_indices=1)) 
 
 pred = tf.argmin(distance,0) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 
 right = 0 
 for i in range(200): 
  ansIndex = sess.run(pred,{xtr:train_x,xte:test_x[i,:]}) 
  print 'prediction is ',np.argmax(train_y[ansIndex]) 
  print 'true value is ',np.argmax(test_y[i]) 
  if np.argmax(test_y[i]) == np.argmax(train_y[ansIndex]): 
   right += 1.0 
 accracy = right/200.0 
 print accracy 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 KNN(mnist) 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Flask的图形化管理界面搭建框架Flask-Admin的使用教程

Flask-Admin是Flask框架的一个扩展,用它能够快速创建Web管理界面,它实现了比如用户、文件的增删改查等常用的管理功能;如果对它的默认界面不喜欢,可以通过修改模板文件来定制;...

pycharm设置鼠标悬停查看方法设置

pycharm设置鼠标悬停查看方法设置

我们使用pycharm的时候,有时遇到了不认识的方法习惯于将鼠标悬停在方法上查看方法介绍。那么如何设置呢?下面小编给大家分享一下。 首先假如我们要查看下图所示的方法,鼠标放上去并没有显示...

Flask框架学习笔记(一)安装篇(windows安装与centos安装)

Flask 依赖于两个外部库: Werkzeug  和  Jinja2  。 Werkzeug 是一个 WSGI (在 web 应用和多种服务器之间开发和部...

Python如何使用函数做字典的值

这篇文章主要介绍了Python如何使用函数做字典的值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 当需要用到3个及以上的if...e...

从训练好的tensorflow模型中打印训练变量实例

从训练好的tensorflow模型中打印训练变量实例

从tensorflow 训练后保存的模型中打印训变量:使用tf.train.NewCheckpointReader() import tensorflow as tf reader...