出たとこデータサイエンス

アラサーでデータサイエンティストになったエンジニアが、覚えたことを書きなぐるためのブログ

TensorFlowでRNN(LSTM)実装

初めてRNN(LSTM)を実装したので備忘録として。

目標

※ LSTMの理論的説明はこちらを御覧ください。

方針

  • MNISTの各画像を、上から1行ずつスキャンし、時系列データとしてLSTMに入力
  • LSTMの最後の中間層の次の全結合層を出力層とする

コード

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

#mnistデータを格納したオブジェクトを呼び出す
mnist = input_data.read_data_sets("../data/mnist",
                                  one_hot=True)

def seq_mnist():
    
    # ログのリスト
    summaries = []
    
    # データのサイズ
    num_image = 784 # 画像のピクセル数
    num_class = 10  # 正解のクラス数
    num_input = 28  # RNNの入力長
    num_seq = 28    # RNNの時間長

    # データの定義
    with tf.name_scope("data"):
        
        # 入力画像
        x = tf.placeholder(tf.float32,
                           [None, num_image])
        
        # 正解ラベル
        y = tf.placeholder(tf.float32,
                           [None, num_class])

    # LSTM
    with tf.name_scope("lstm"):
        
        # 画像を1行ずつ読み込みよう整形
        input = tf.reshape(x,
                           [-1, num_seq, num_input])
        
        # ユニット数128個のLSTMセル
        stacked_cells = [tf.nn.rnn_cell.LSTMCell(num_units=128) for _ in range(2)]
        cell = tf.nn.rnn_cell.MultiRNNCell(cells=stacked_cells)
        out_lstm, states = tf.nn.dynamic_rnn(cell=cell,
                                             inputs=input,
                                             dtype=tf.float32)
        
        # 最後の層を取得
        out_lstm_list = tf.unstack(out_lstm,
                                   axis=1)
        lstm_last = out_lstm_list[-1]
        
    # 出力層
    with tf.name_scope("out"):
        
        # 出力層定義
        w = tf.Variable(tf.truncated_normal([128, 10],
                                            stddev=0.1))
        b = tf.Variable(tf.zeros([10]))
        out = tf.nn.softmax(tf.matmul(lstm_last, w) + b)
        
    # 訓練
    with tf.name_scope("train"):
        
        # 誤差
        loss = tf.reduce_mean(-tf.reduce_sum(y * tf.log(out),
                                             axis=[1]))
        
        # ログ出力
        summaries.append(tf.summary.scalar("loss",
                                           loss))        
        
        # 訓練
        train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
        
        # 評価
    with tf.name_scope("evaluation"):

        # 正解率
        correct = tf.equal(tf.argmax(out, 1),
                           tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct,
                                          tf.float32))
        
        # ログ出力
        summaries.append(tf.summary.scalar("accuracy",
                                           accuracy))        
        
    # 初期化
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        
        # 初期化
        sess.run(init)
        
        # ログをひとまとめにする設定
        summary_op = tf.summary.merge(summaries)

        # ログの保管場所
        summary_writer = tf.summary.FileWriter("../logs",
                                               sess.graph)        
        
        # テストデータ
        test_images, test_labels = mnist.test.images, mnist.test.labels
        
        # 学習
        for i in range(1000):
            step = i + 1
            train_images, train_labels = mnist.train.next_batch(50)
            sess.run(train_step,
                     feed_dict={x: train_images,
                                y: train_labels})
            
            # 定期的に正解率確認
            if step % 100 == 0:
                
                # ログのテキスト取得
                summary_text, acc_val = sess.run([summary_op, accuracy],
                                                 feed_dict={x: test_images,
                                                           y: test_labels})

                # ログに書き出し
                summary_writer.add_summary(summary_text,
                                           step)                
                
                # 正解率表示
                acc_val = sess.run(accuracy,
                                   feed_dict={x: train_images,
                                              y: train_labels})
                print("step: {} acc: {:.3}".format(step,
                                                   acc_val))
                
seq_mnist()

TensorBoardの出力

Accuracy

f:id:mizuwan:20180702221759p:plain

Loss

f:id:mizuwan:20180702221808p:plain

グラフ

f:id:mizuwan:20180702221825p:plain

TensorFlowの関数を使うと非常に短いコードで実装できるが、中で何が起きているかはよく分からない。