出たとこデータサイエンス

アラサーでデータサイエンティストになったエンジニアが、覚えたことを書きなぐるためのブログ

TensorFlowでMNIST

TensorFlowでMNISTの分類器を実装した記事はいくらでもあるのだが、備忘録として。

目標

TensorFlowではじめるDeepLearning実装入門の第2章に従い、MNISTデータの分類器を構築する。

方針

  • 隠れ層1つの単純なモデル
  • 隠れ層の活性化関数はRelu、出力層の活性化関数はSoftmax
  • 損失関数は平均二乗誤差
  • 100回ミニバッチを学習するたびに、正解率を表示

コード

import tensorflow.examples.tutorials.mnist as mnist

# MNISTを保存
mnist_data = mnist.input_data.read_data_sets("mnist/", one_hot=True)

# 入力
x = tf.placeholder(tf.float32, [None, 784])

# 隠れ層
w_1 = tf.Variable(tf.truncated_normal([784, 64], stddev=0.1), name="w1")
b_1 = tf.Variable(tf.zeros([64]), name="b1")
h_1 = tf.nn.relu(tf.matmul(x, w_1) + b_1)

# 出力層
w_2 = tf.Variable(tf.truncated_normal([64, 10], stddev=0.1), name="w2")
b_2 = tf.Variable(tf.zeros([10]), name="b2")
out = tf.nn.softmax(tf.matmul(h_1, w_2) + b_2)

# 正解
y = tf.placeholder(tf.float32, [None, 10])

# 最適化
loss = tf.reduce_mean(tf.square(y - out))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

# 評価
correct = tf.equal(tf.argmax(out, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

# 初期化
init = tf.global_variables_initializer()

# 実行
with tf.Session() as sess:
    
    # 初期化
    sess.run(init)
    
    # テストデータ
    test_images, test_labels = mnist_data.test.images, mnist_data.test.labels
    
    # 学習
    for i in range(1000):
        
        # 訓練データ
        train_images, train_labels = mnist_data.train.next_batch(50)
        
        # ミニバッチ学習
        sess.run(train_step, feed_dict={x: train_images, y: train_labels})
        
        # 定期的に評価
        if (i + 1) % 100 == 0:
            
            # 正解率表示
            acc_val = sess.run(accuracy, feed_dict={x: test_images, y: test_labels})
            print(f"acc_val: {acc_val: .3}")

出力

Extracting mnist/train-images-idx3-ubyte.gz
Extracting mnist/train-labels-idx1-ubyte.gz
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz
acc_val:  0.419
acc_val:  0.653
acc_val:  0.797
acc_val:  0.834
acc_val:  0.85
acc_val:  0.865
acc_val:  0.875
acc_val:  0.882
acc_val:  0.889
acc_val:  0.887

深層学習の教科書で読んだ各部品(重みや活性化関数等)を、レゴブロックのような感覚で組み合わせられるのは面白い。