博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow初入门
阅读量:4078 次
发布时间:2019-05-25

本文共 2421 字,大约阅读时间需要 8 分钟。

终究还是回到了NN上面来,两年左右的时间没做了,尽管基础知识还是没忘记,但是很多这两年出来的很新的开发平台倒是不会用了,去美国之前还是主打theano,现在这么多的优秀平台,考虑到最近CPU资源的机器比较多,所以考虑了使用Tensorflow。今天算是真正第一天上手,就记录一下我的感受吧。

其实写法和theano差不多,但是已经很简洁了,尤其theano中需要自己定义模型,模型里面如何shuffle以及把误差加起来平均等都需要手工来写,但是在TF中很简洁了已经。这里来弄个做softmax的在mnist下的例子代码,然后着重讲解几句吧:

1   from tensorflow.examples.tutorials.mnist import input_data2   import tensorflow as tf3   mnist = input_data.read_data_sets('MNIST_data', one_hot=True)4   x = tf.placeholder(tf.float32, shape=[None, 784])5   y = tf.placeholder(tf.float32, shape=[None, 10])6   W = tf.Variable(tf.zeros([784, 10]))7   b = tf.Variable(tf.zeros([10]))8   sess = tf.InteractiveSession()9   sess.run(tf.global_variables_initializer())10  pred_y = tf.matmul(x, W) + b11  cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred_y))12  train_step =        tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)13  for _ in range(1000):14      batch = mnist.train.next_batch(100)15      train_step.run(feed_dict={x:batch[0], y:batch[1]})16  correct_pred = tf.equal(tf.argmax(pred_y, 1), tf.argmax(y, 1))17  accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))18  print(sess.run(accuracy, {x:mnist.test.images, y:mnist.test.labels}))

1-2行就不多说了吧,导入必要的包,以及当前例子中的数据包,这个数据包是首先看看当前文件夹下的MNIST_data下有没有数据,没有的话下载,我是自己下载放进来的,这个MNIST_data是运行一次它自己就创建的了

3行是加载进来mnist数据,其中one_hot指的是是否把标签换成矩阵的形式,如果做过matlab下的mnist的分类,大家肯定都知道这回事,为true则转换,否则不转换。我觉得转换不转都可以,如果不转换的话在计算correct_pred的时候就不需要argmax了。

4-7定义输入数据的placeholder,以及权重和偏置的variable,之所以定义成variable是因为TF需要在每一步迭代之后要自动更新它们的值,如果定义成constant则无法更新了,这个更新是优化过程中自动更新的。placeholder就是个占位符号,theano中也有的,写法不一样,意思就是占位置,具体使用的时候在传递进来值。

8-9就不多说了,TF的内部要求意味着创建session对象为了在构建grahic之后可以编译进去,也就是前段和后端的连接(这句话不理解也没事,可以忽略)

第10行计算输入数据的输出,也就是提供给softmax的输入,这里看到没有任何非线性的变换哈

第11行计算训练数据的交叉熵,也就是训练数据在softmax的输出和真实训练数据类标的误差

第12行构建训练对象,0.5指的是梯度下降的学习率吧

第13-15行,是训练1000次的mini_batch

第16行计算正确分类的个数,第17行把这个布尔类型的分类个数转换成浮点数的百分比

第18行传递进入测试数据信息,进行测试数据上的分类率的计算。

初次之外,我想说几点:

1:第15行,如果新手才开始,就会看到tutorial上写的是应该用sess.run(train_step, {x:…})这样的方式,但是此处用了直接train_step.run… 这是因为我们生命session的对象的方式是Interactive
Session,此处也可以直接用sess.run(..)的方式来运行。

2:其实次数有2个session的运行,除了刚才说的train_step,另外一个就是最后一句,计算accuracy,注意此处,其实重新运行了一次,但是用的是更新后的W和b,因为这2个地方的值在上面已经训练过了,其实tutorial的地方写的是accuracy.eval(feed_dict={x:…})的方式来打印的,但是也可以用sess的方式来运行的。如果从头开始看这个tutorial的话,直接看eval就不太理解,这个地方隐式的用了session对象,但是这种使用session对象的方式必须是InteractiveSession的方式来生命session才可以。

总的来说,TF算是简单。

转载地址:http://doini.baihongyu.com/

你可能感兴趣的文章
我发现无人机把Pixhawk2.4.8换成pixhawk4 不需要换很多东西,也不需要再焊什么东西
查看>>
**当你做东西卡住的时候,也是在提醒你需要花时间沉下来深入学东西了,不然做不出来的
查看>>
1-13 笔记
查看>>
pixhawk的offboard模式其实是和定高(AltHold) 定点(loiter) 这些模式平级的一个模式
查看>>
我发觉大家都在讲VINS
查看>>
OpenCV编程中的C++语法知识点回顾
查看>>
C++的模板化等等的确实比C用起来方便多了
查看>>
ROS是不是可以理解成一个虚拟机,就是操作系统之上的操作系统
查看>>
用STL algorithm轻松解决几道算法面试题
查看>>
ACfly之所以不怕炸机因为它觉得某个传感器数据不安全就立马不用了
查看>>
我发觉,不管是弄ROS OPENCV T265二次开发 SDK开发 caffe PX4 都是用的C++
查看>>
ROS的安装(包含文字和视频教程,我的ROS安装教程以这篇为准)
查看>>
国内有个码云,gitee
查看>>
我居然在GAAS的硬件清单上看到了权盛光流,又想起ZN无人机课程他们购买无人机配件也是在权盛
查看>>
原来我之前一直用的APM固件....现在很多东西明白了。
查看>>
GAAS提供的TX2镜像就给你装好了小觅SDK
查看>>
七月在线GAAS-2 ROS与OFFBOARD MODE 笔记
查看>>
我看了下GAAS里ROS里发布的pose 的 topic包含position和orientation,我觉得position是实际位置,orientation是期望位置。错了,是标准的里程计消息。
查看>>
realsense-ros里里程计相关代码
查看>>
transfer.py就是把vins的坐标系转为PX4的坐标系,其实也是个ROS功能包,包含代码分析。(最后发现是改成GAAS里pose的消息形式)
查看>>