0 引言
昨天我做(学者网上的教程)了一个线性回归的模型,可以参考这篇博客,用的TensorFlow框架,今天我继续学习,用TensorFlow框架对mnist数据集进行手写体识别。
1 准备数据
这里用到的是TensorFlow里面的placeholder
占位符,类似constant
,只不过先定义但是不赋值,用起来的时候再赋值。
- mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets(“./mnist_data”, one_hot=True) - y_true
y_true = tf.placeholder(dtype=tf.float32, shape=[None, 10], name=”y_true”) - label
X = tf.placeholder(dtype=tf.float32, shape=[None, 784], name=”X”)
2 构造模型
- 参数
weights = tf.Variable(initial_value=tf.random_normal(shape=[784,10]),name=”weight”)
bias = tf.Variable(initial_value=tf.random_normal([10]),name=”bias”) - 模型
y_predict = tf.matmul(X,weights) + bias
3 构造损失函数
- loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))
4 优化损失
- 梯度下降法
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
5 计算准确率
- 预测值和真实值进行比较
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1)) - 求平均
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
6 初始化变量
- 初始化
init = tf.global_variables_initializer() - 在会话中运行
sess.run(init)
7 开启会话
- 拉取mnist训练集
image, label = mnist_data.train.next_batch(batch_size) - 开始训练
_optimizer, loss_value, accuracy_value = sess.run([optimizer, loss, accuracy], feed_dict={X: image, y_true: label})
8 运行效果
9 源代码
1 | import time |
写在最后
欢迎大家关注鄙人的公众号【麦田里的守望者zhg】,让我们一起成长,谢谢。
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.
Comment