1 准备工作
利用TensorFlow做一个线性回归的案例,我们需要知道我们要干一些什么?我们要的干的活:
- 1 准备数据
- 2 构造模型
- weights
- bias
- 3 构造损失函数
- 4 优化损失
- 5 运行代码(session)
2 开始写代码
- 1 准备数据
X = tf.random_normal(shape=[100, 1],mean=0.0,stddev=1.0,name=”feature”) # 用高斯分布生成随机值,
y_true = tf.matmul( X, [[0.8]]) + 0.7 # 用高斯分布生成随机值, 默认均值是0 方差是1 - 2 构造模型
- weights
weights = tf.Variable(initial_value=tf.random_normal(shape=[1,1]),name=”weights”) - bias
bias = tf.Variable(initial_value=tf.random_normal(shape=[1, 1]),name=”bias”) - 构造模型
y_predict = tf.matmul(X,weights) + bias
- weights
- 3 构造损失函数
loss = tf.reduce_mean(tf.square(y_true - y_predict)) - 4 优化损失
optimaizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss) - 5 运行代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21# 显式的初始化变量
init = tf.global_variables_initializer()
# 开启会话
with tf.Session() as sess:
sess.run(init)
# 训练前
print("训练前参数:")
print("weights:{0}|bias:{1}|loss:{2}".format(weights.eval(),bias.eval(),loss.eval()))
print("-"*10 + "我是分割线" + "-"*10)
# 开始训练
count = 0 # 训练次数
# 这里我们认为误差小于e^-9时,为合格
while loss.eval() > np.exp(-9):
count += 1
sess.run(optimaizer)
# 训练后
# print()
print("第{3}次训练后参数:weights:{0}|bias:{1}|loss:{2}".format(weights.eval(),bias.eval(),loss.eval(),count))
3 搜集变量并显示tensorboard
步骤:
- 1 创建事件
file_writer = tf.summary.FileWriter(graph=sess.graph,logdir=”./temp”) - 2 收集变量
1
2
3tf.summary.scalar("loss",loss) # 标量
tf.summary.histogram("weights",weights)
tf.summary.histogram("bias",bias) # 收集高维变量 - 3 合并变量
merged = tf.summary.merge_all() - 4 运行合并后的变量
summary = sess.run(merged)
file_writer.add_summary(summary,count)
4 运行效果
4.1 迭代回归
4.2 显示TensorBoard
进入到TensorBoard的目录,输入下面的命令tensorboard --logdir="./"
,如下:
然后在浏览器输入http://127.0.0.1:6006访问,如下所示:
如损失函数的图像吐如下:
5 增加命名空间
5.1 怎么增加命名空间?
给我们需要添加命名空间的地方加上如下语句:
1 | # 这里以准备数据为例 |
增加完命名空间之后的TensorBoard显示:
7 模型的保存与加载
7.1 为什么要是用这个功能?
这里使用的线性回归只有几个参数,假如我们使用深度学习算法对模型进行训练时,可能会用到很多参数,假如中间有意外发生,比如说断电等,那么我们的训练就相当于白训练了。所以这里引入模型的保存与加载功能。
7.2 如何使用
首先定义一个保存器saver
:
1 | # 定义保存模型的保存器 |
然后在训练时保存模型,当断电发生时,我们在加载模型。
- 保存模型
1
2if count % 10 == 0:
saver.save(sess,"./temp/model.ckpt") - 加载模型
1
2
3if os.path.exists("./temp/checkpoint"):
saver.restore(sess,"./temp/model.ckpt")
print("第{3}次训练后参数:weights:{0}|bias:{1}|loss:{2}".format(weights.eval(), bias.eval(), loss.eval(), count))
8 添加命令行参数
8.1 定义命令参数
1 | ## 定义命令行参数 |
8.2 使用命令行参数
8.3 tf.app.run
的使用
首先定义个main
函数,该main
函数必须带参数argv
,否则会报错
如下:
1 | def main(argv): |
然后在if __name__ == '__main__':
中使用即可:
1 | if __name__ == '__main__': |
9 完整代码
1 | import tensorflow as tf |
写在最后
欢迎大家关注鄙人的公众号【麦田里的守望者zhg】,让我们一起成长,谢谢。
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.
Comment