
Further Optimizing a Neural Network with TensorFlow

In an earlier post on this site, 《TensorFlow实现简单神经网络》 (Implementing a Simple Neural Network with TensorFlow), we used TensorFlow to classify the MNIST handwritten digit dataset and reached an accuracy of 91%. In this post we optimize that network and push the accuracy above 98%.

1 Optimization Ideas

When optimizing a neural network, the main approaches to consider are:

  • Choosing an appropriate loss function
  • Choosing an appropriate activation function
  • Choosing an appropriate optimizer
  • Adjusting the number of layers in the network
  • Tuning the learning rate
  • Handling overfitting
  • Increasing the amount of training data and the number of training epochs

In this example, the cross-entropy function is a better fit than the quadratic cost as the loss function; tanh() is used as the activation function, and Adam is chosen as the optimizer. A comparison of the two cost functions in code follows below.
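
For intuition, here is a minimal TF 1.x sketch of how the two cost functions differ in code. The `y` and `prediction` placeholders are hypothetical stand-ins for the label and logit tensors defined in the full listing below:

import tensorflow as tf

# Hypothetical stand-ins for the one-hot labels and raw logits
y = tf.placeholder(tf.float32, [None, 10])
prediction = tf.placeholder(tf.float32, [None, 10])

# Quadratic (MSE) cost: its gradient carries the derivative of the output
# activation, which vanishes as the output saturates, so learning slows
# down precisely on confidently-wrong examples
quadratic_loss = tf.reduce_mean(tf.square(y - tf.nn.softmax(prediction)))

# Cross-entropy cost: the activation derivative cancels in the gradient,
# so the update stays proportional to the prediction error itself
cross_entropy_loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))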

More layers is not always better (an overly complex network applied to a problem with relatively little data is very prone to overfitting); this example uses two hidden layers.

When setting the learning rate, a rate that is too large makes the parameter values swing back and forth without converging to a minimum, while a rate that is too small greatly slows optimization. A good compromise is to start with a relatively large learning rate to quickly reach a reasonably good solution, and then gradually shrink it as the iterations continue, so the model is more stable in the later stages of training.
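
The code below implements this by manually reassigning `lr` to 0.001 * (0.95 ** epoch) at the start of each epoch. An equivalent built-in approach is sketched here, assuming TF 1.x's tf.train.exponential_decay and the batches-per-epoch count n_batch from the listing below:

import tensorflow as tf

n_batch = 550  # hypothetical: 55000 MNIST training images / batch size 100
global_step = tf.Variable(0, trainable=False)
# Start at 0.001 and multiply by 0.95 once per epoch; staircase=True keeps
# the rate constant within an epoch instead of decaying it smoothly
lr = tf.train.exponential_decay(0.001, global_step,
                                decay_steps=n_batch, decay_rate=0.95,
                                staircase=True)
# Passing global_step to minimize() increments it on every training step:
# train_step = tf.train.AdamOptimizer(lr).minimize(loss, global_step=global_step)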

To guard against overfitting, this example uses the dropout mechanism, illustrated in the sketch below.
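
Dropout randomly zeroes a fraction of the activations during training and is switched off at evaluation time. In TF 1.x this is controlled by feeding a keep_prob placeholder, exactly as the training loop below does (0.7 while training, 1.0 while evaluating). A self-contained sketch of the mechanism, with h as a hypothetical activation tensor:

import tensorflow as tf

keep_prob = tf.placeholder(tf.float32)
h = tf.ones([4, 8])                   # hypothetical layer activations
h_drop = tf.nn.dropout(h, keep_prob)  # kept units are scaled by 1/keep_prob

with tf.Session() as sess:
    print(sess.run(h_drop, feed_dict={keep_prob: 0.7}))  # training: ~30% of units zeroed
    print(sess.run(h_drop, feed_dict={keep_prob: 1.0}))  # evaluation: identity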

In deep learning, enlarging the training set can make many problems melt away, but that is not an option here, since this example already uses all of MNIST's training data. What we can do is train for more epochs: this example raises the 21 epochs of the earlier post to 51.

Alright, let's type in the code and see how well it works~

2 Code and Explanation

import tensorflow as tf
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
# Load the dataset
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each batch
batch_size = 100
# Compute the total number of batches
n_batch = mnist.train.num_examples // batch_size

# Define the placeholders
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Placeholder for the dropout keep probability
keep_prob = tf.placeholder(tf.float32)
# Define a mutable learning-rate variable
lr = tf.Variable(0.001, dtype=tf.float32)

# Build the neural network
# First hidden layer with 1000 nodes
W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
b1 = tf.Variable(tf.zeros([1000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

# Second hidden layer with 500 nodes
W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
b2 = tf.Variable(tf.zeros([500]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

# Output layer
W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.matmul(L2_drop, W3) + b3

# Cross-entropy cost function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
# Train with the Adam optimizer
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

# Initialize the variables
init = tf.global_variables_initializer()

# Store the results in a list of booleans
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))  # argmax returns the index of the largest value in a 1-D tensor
# Compute the accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(51):
        # Decay the learning rate at the start of each epoch
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})

        # Compute the accuracy on the test data
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        # Compute the accuracy on the training data
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        # Print the epoch number, test accuracy, and training accuracy
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))

3 Results

Iter 0,Testing Accuracy 0.956,Training Accuracy 0.95829093
Iter 1,Testing Accuracy 0.9665,Training Accuracy 0.97316366
Iter 2,Testing Accuracy 0.9704,Training Accuracy 0.9800182
Iter 3,Testing Accuracy 0.9741,Training Accuracy 0.98243636
Iter 4,Testing Accuracy 0.975,Training Accuracy 0.98547274
Iter 5,Testing Accuracy 0.9754,Training Accuracy 0.9882
Iter 6,Testing Accuracy 0.9773,Training Accuracy 0.9896909
Iter 7,Testing Accuracy 0.977,Training Accuracy 0.9915091
Iter 8,Testing Accuracy 0.9786,Training Accuracy 0.9924727
Iter 9,Testing Accuracy 0.9775,Training Accuracy 0.9947636
Iter 10,Testing Accuracy 0.979,Training Accuracy 0.9940909
Iter 11,Testing Accuracy 0.9797,Training Accuracy 0.99587274
Iter 12,Testing Accuracy 0.9803,Training Accuracy 0.99643636
Iter 13,Testing Accuracy 0.9818,Training Accuracy 0.9971273
Iter 14,Testing Accuracy 0.9801,Training Accuracy 0.99756366
Iter 15,Testing Accuracy 0.9821,Training Accuracy 0.9982727
Iter 16,Testing Accuracy 0.9816,Training Accuracy 0.9984546
Iter 17,Testing Accuracy 0.9818,Training Accuracy 0.99874544
Iter 18,Testing Accuracy 0.9814,Training Accuracy 0.99883634
Iter 19,Testing Accuracy 0.9828,Training Accuracy 0.9993273
Iter 20,Testing Accuracy 0.9816,Training Accuracy 0.9992545
Iter 21,Testing Accuracy 0.9838,Training Accuracy 0.9993273
Iter 22,Testing Accuracy 0.9824,Training Accuracy 0.99965453
Iter 23,Testing Accuracy 0.9829,Training Accuracy 0.9997454
Iter 24,Testing Accuracy 0.9836,Training Accuracy 0.99965453
Iter 25,Testing Accuracy 0.9828,Training Accuracy 0.9996727
Iter 26,Testing Accuracy 0.9841,Training Accuracy 0.99987274
Iter 27,Testing Accuracy 0.9823,Training Accuracy 0.9999091
Iter 28,Testing Accuracy 0.9837,Training Accuracy 0.9998909
Iter 29,Testing Accuracy 0.9846,Training Accuracy 0.99987274
Iter 30,Testing Accuracy 0.9835,Training Accuracy 0.9999273
Iter 31,Testing Accuracy 0.9829,Training Accuracy 0.9998909
Iter 32,Testing Accuracy 0.9835,Training Accuracy 0.9999818
Iter 33,Testing Accuracy 0.9843,Training Accuracy 0.9999818
Iter 34,Testing Accuracy 0.9836,Training Accuracy 1.0
Iter 35,Testing Accuracy 0.9831,Training Accuracy 1.0
Iter 36,Testing Accuracy 0.9841,Training Accuracy 0.9999818
Iter 37,Testing Accuracy 0.9835,Training Accuracy 1.0
Iter 38,Testing Accuracy 0.9847,Training Accuracy 1.0
Iter 39,Testing Accuracy 0.9836,Training Accuracy 1.0
Iter 40,Testing Accuracy 0.9844,Training Accuracy 1.0
Iter 41,Testing Accuracy 0.9844,Training Accuracy 1.0
Iter 42,Testing Accuracy 0.9843,Training Accuracy 1.0
Iter 43,Testing Accuracy 0.9847,Training Accuracy 1.0
Iter 44,Testing Accuracy 0.984,Training Accuracy 1.0
Iter 45,Testing Accuracy 0.9836,Training Accuracy 1.0
Iter 46,Testing Accuracy 0.9839,Training Accuracy 1.0
Iter 47,Testing Accuracy 0.9839,Training Accuracy 1.0
Iter 48,Testing Accuracy 0.9834,Training Accuracy 1.0
Iter 49,Testing Accuracy 0.9835,Training Accuracy 1.0
Iter 50,Testing Accuracy 0.9843,Training Accuracy 1.0

As the output shows, after 51 training epochs the accuracy on the test data has reached 98.4%, while the accuracy on the training data has reached 100% (the remaining gap between the two suggests a little overfitting is still present).

