出たとこデータサイエンス

アラサーでデータサイエンティストになったエンジニアが、覚えたことを書きなぐるためのブログ

TensorBoardでMNIST分類器の見える化

TensorBoardを初めて触ったので備忘録として。
TensorBoardは、TensorFlowのモデルの構造や精度等を可視化してくれる活かしたツールです。

目標

TensorFlowではじめるDeepLearning実装入門の第3章に従い、TensorboardでMNIST分類器の精度や重み等を表示してみる。

方針

  • モデルは前回の記事と同じ、隠れ層1つのシンプルなもの
  • TensorBoardに下記を表示する
    • 入力画像
    • 入力層から隠れ層への重みのヒストグラム
    • 学習途中のaccuracyとlossの推移
    • ネットワークのグラフ構造

ポイント

  • TensorBoardはTensorFlowの吐き出したログを元に可視化を行うため、まずログの設定をする必要がある
  • tf.summaryモジュールに、ログとして吐き出したいものを指定するメソッドがあるので、それらを各所に埋め込む
  • tf.summary.mergeメソッドで、埋め込んだ各種ログを1つのテキストとして取得
  • tf.summary.FileWriterオブジェクトで、ファイルの書き出し
  • TensorFlowの実行後、TensorBoardをコマンドラインから実行

コード

import tensorflow as tf
import tensorflow.examples.tutorials.mnist as mnist

# MNISTを保存
mnist_data = mnist.input_data.read_data_sets("mnist/", one_hot=True)

# 入力
with tf.name_scope("input"):
    
    # 入力の定義
    x = tf.placeholder(tf.float32, 
                       [None, 784], 
                       name="x")
    
    # ログに10枚ずつ入力画像を出力
    img = tf.reshape(x,
                     [-1, 28, 28, 1])
    summary_input_data = tf.summary.image("input_data",
                                          img,
                                          10)

# 隠れ層
with tf.name_scope("hidden"):
    
    # 隠れ層の計算
    w_1 = tf.Variable(tf.truncated_normal([784, 64],
                                          stddev=0.1), 
                      name="w1")
    b_1 = tf.Variable(tf.zeros([64]),
                      name="b1")
    h_1 = tf.nn.relu(tf.matmul(x, w_1) + b_1)
    
    # ログにw_1のヒストグラムを出力
    summary_w_1 = tf.summary.histogram("w1", w_1)

# 出力層
with tf.name_scope("output"):
    
    # 出力層の計算
    w_2 = tf.Variable(tf.truncated_normal([64, 10],
                                          stddev=0.1),
                      name="w2")
    b_2 = tf.Variable(tf.zeros([10]),
                      name="b2")
    out = tf.nn.softmax(tf.matmul(h_1, w_2) + b_2)

# 正解ラベル
with tf.name_scope("label"):

    # 正解ラベルの定義
    y = tf.placeholder(tf.float32,
                       [None, 10],
                       name="y")

# 誤差
with tf.name_scope("loss"):
    
    # 誤差の計算
    loss = tf.reduce_mean(tf.square(y - out))
    
    # 誤差をログに出力
    summary_loss = tf.summary.scalar("loss", loss)
    
# 訓練
with tf.name_scope("train"):

    # 訓練の設定
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

# 正解率
with tf.name_scope("accuracy"):
    
    # 正解率の計算
    correct = tf.equal(tf.argmax(out, 1),
                       tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct,
                                      tf.float32))
    
    # 正解率をログに出力
    summary_accuracy = tf.summary.scalar("accuracy", accuracy)
    
# 初期化
init = tf.global_variables_initializer()

# 実行
with tf.Session() as sess:
    
    # ログをひとまとめにする設定
    summary_op = tf.summary.merge([summary_input_data,
                                   summary_w_1,
                                   summary_accuracy,
                                   summary_loss])
    
    # ログの保管場所
    summary_writer = tf.summary.FileWriter("logs/mnist", sess.graph)
    
    # 初期化
    sess.run(init)
    
    # テストデータ
    test_images, test_labels = mnist_data.test.images, mnist_data.test.labels
    
    # 学習
    for i in range(1000):
        
        # 訓練データ
        train_images, train_labels = mnist_data.train.next_batch(50)
        
        # ミニバッチ学習
        sess.run(train_step,
                 feed_dict={x: train_images,
                            y: train_labels})
        
        # 定期的にログ書き出し
        step = i + 1
        if step % 10 == 0:
            
            # ログのテキスト取得
            summary_text = sess.run(summary_op,
                                    feed_dict={x: test_images,
                                               y: test_labels})
            
            # ログに書き出し
            summary_writer.add_summary(summary_text,
                                       step)

出力

tensorboard --logdir logs/mnistを実行し、ブラウザからlocalhost:6006にアクセスすると、可視化結果が見られる。

入力画像

f:id:mizuwan:20180624160559p:plain

入力層から隠れ層への重みのヒストグラム

f:id:mizuwan:20180624160612p:plain
だいたい0の周りに分布していることがわかる。

学習途中のaccuracyとlossの推移

f:id:mizuwan:20180624160247p:plain
こういうグラフを見ると、深層学習やってる感が凄く出る。

ネットワークのグラフ構造

f:id:mizuwan:20180624160636p:plain
上記は画像のほんの一部で、実際にはこの何十倍もの巨大な画像に。。。
おそらく何かを盛大に間違っている気がするのだが、詳しい方教えてください。