shwld.io

doc2vecとtensorflowで、livedoor newsコーパスを自動分類してみた


ゼロから作るDeep Learning」 を読み終わったので、復習しつつ自動分類をやってみた。

まだまだ理解できていないことが多いと感じるが、それなりに動くところまでいけたので(色々間違いはあるかもしれませんが)公開してみる。

ソースは shwld/ana-livedoor-news-classification にあります。jupyter notebookでコメントもたくさん書いているのでそちらもよければ見てみてください。

[[toc]]

全体の流れ

  1. mecabやdoc2vecを利用できる環境を用意する

  2. livedoor newsコーパスをダウンロードする

  3. mecabを使いlivedoor newsコーパスを分かち書きしてDoc2Vecのトレーニングデータを作る

  4. Doc2Vecで記事毎にベクトル化する

  5. ベクトル化したデータをtensorflow専用の形式に変換して保存する

  6. tfrファイルを読み込む

  7. 確率的勾配降下法を用いて学習させる

  8. 学習結果を確認する

1. mecabやdoc2vecを利用できる環境を用意する

dockerでとりあえず動くimageを作成してみたのでそれを使う

これです shwld/mecab-python

2. livedoor newsコーパス をダウンロードする

livedoor newsコーパス

こちらもdockerでダウンロードするようにした。 こうしておけばgit管理から外せる上に、環境が変わっても復元できるので結構使いやすい。

FROM shwld/mecab-python

...

RUN wget http://www.rondhuit.com/download/ldcc-20140209.tar.gz \
    && tar xvfz ldcc-20140209.tar.gz

...

3. mecabを使いlivedoor newsコーパスを分かち書きしてDoc2Vecのトレーニングデータを作る

ドキュメントを順番に読み込んで、分かち書きした後、Doc2Vecのトレーニングデータを作るためLabeledSentenceを実行する。

LabeledSentenceについてはあまり詳しく調べられていないが、tagsは一意なものにしておかないと、うまくmost_simularメソッドが動かなかった。 tagsは複数指定できるので、今回はやっていないが、それを使えば自動分類自体はDoc2Vecだけでもできる?かもしれない。

training_docs = []
for idx, (doc) in enumerate(docs):
    text = ''
    for line in open(doc[0], 'r'):
        if (line is ''):
            continue

        text += mecab.parse(line)

    # doc2vecのタグは一意なもの(ファイル名)にしておく
    training_docs.append(LabeledSentence(words=text, tags=[dir_doc[0]]))

4. Doc2Vecで記事毎にベクトル化する

少し時間がかかる

model = Doc2Vec(documents=training_docs, min_count=1, dm=0, docvecs_mapfile="../input/mapfile.txt")

5. ベクトル化したデータをtensorflow専用の形式に変換して保存する

今回一番やりたかったこととして、自分で作ったデータをtensorflowに食わせてみたい。というのがあったので、tensorflowに読める形式に変換して保存します。

import tensorflow as tf

step = 1
with tf.python_io.TFRecordWriter('../input/train.tfr') as x_writer, tf.python_io.TFRecordWriter('../input/test.tfr') as t_writer:
    for doc in dir_docs:
        if (doc[0] not in model.docvecs):
            continue
        example = tf.train.Example(features=tf.train.Features(feature={
            'id': tf.train.Feature(int64_list=tf.train.Int64List(value=[doc[1]])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[doc[3]])),
            'feature': tf.train.Feature(float_list=tf.train.FloatList(value=model.docvecs[doc[0]]))
        }))
        
        # 9割のデータを学習用に使う
        if (random.randint(1,100) < 90):
            x_writer.write(example.SerializeToString())
        else:
            t_writer.write(example.SerializeToString())
        
        step += 1

6. tfrファイルを読み込む

TFRecordReaderを使うと、レコードをキューで読み込み、バッチで処理することができるので、それを使った。

テストデータも同様に読み込める。

reader = tf.TFRecordReader()
min_after_dequeue = 5000  # 5000個以上キューが貯まるまで待ってそこからランダムに取得をするような感じだと思われる
capacity = min_after_dequeue + 3 * BATCH_SIZE


# トレーニングデータの準備
x_filename_queue = tf.train.string_input_producer(['../input/train.tfr'])
_, x_serialized_example = reader.read(x_filename_queue)
x_inputs = tf.parse_single_example(x_serialized_example, features={
    'id': tf.FixedLenFeature([1], tf.int64),
    'label': tf.FixedLenFeature([1], tf.int64),
    'feature': tf.FixedLenFeature([BATCH_SIZE], tf.float32),
})
x_batch = tf.train.shuffle_batch(x_inputs, batch_size=BATCH_SIZE, capacity=capacity, min_after_dequeue=min_after_dequeue)

7. 確率的勾配降下法を用いて学習させる

このあたりは、MNIST For ML Beginnersを参考に組んでいます。 tf.placeholderの使い方がわかれば結構捗ります。

x = tf.placeholder(tf.float32, [None, TRAIN_DATA_SIZE])
W = tf.Variable(tf.zeros([TRAIN_DATA_SIZE, 9]))
b = tf.Variable(tf.zeros([9]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
one_hot = tf.placeholder(tf.int32, [None])
y_ = tf.one_hot(one_hot, depth=9, on_value = 5.0, off_value = 0.0, dtype=tf.float32)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess = tf.Session()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
    for i in range(1000):
        x_train = sess.run(x_batch)
        sess.run(train_step, feed_dict={x: x_train['feature'], one_hot: x_train['label'].reshape((BATCH_SIZE))})
        
        # 10件ごとに予測と精度の計算をする
        if i%10 == 0:
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
            train_accuracy = sess.run(accuracy, feed_dict={x: x_train['feature'], one_hot: x_train['label'].reshape((BATCH_SIZE))})
            print("step %d, training accuracy %g"%(i, train_accuracy))
        fp.value = i
except tf.errors.OutOfRangeError:
    print('Done training for %d steps.' % (step))
finally:
    coord.request_stop()
    
coord.join(threads)

printしているので、こんな感じでトレーニング状態が表示される。

... step 830, training accuracy 0.7 step 840, training accuracy 0.73 step 850, training accuracy 0.69 step 860, training accuracy 0.88 step 870, training accuracy 0.9 step 880, training accuracy 0.87 step 890, training accuracy 0.59 step 900, training accuracy 0.74 ...

8. 学習結果を確認する

分けて作っておいた方のデータを使ってテストしてみた。

t_train = sess.run(t_batch)

print("結果")

result = sess.run(y, feed_dict={x: t_train['feature']})

t_correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
t_accuracy = tf.reduce_mean(tf.cast(t_correct_prediction, "float"))
test_accuracy = sess.run(t_accuracy, feed_dict={x: t_train['feature'], one_hot: t_train['label'].reshape((BATCH_SIZE))})
print("test accuracy %g"%(test_accuracy))

print("\n1件だけ見てみる")

print("id      : {}".format(t_train['id'][0]))
print("label : {}".format(t_train['label'][0]))

print(result[0])
print("label : {}".format(np.argmax(result[0])))

こんな感じで結果が出る

結果 test accuracy 0.85

1件だけ見てみる id : [290] label : [0] [ 9.82768357e-01 3.50690279e-05 1.37721945e-05 1.44424368e-04 2.19040285e-05 2.19040285e-05 2.12749350e-04 5.95606631e-04 1.61861740e-02] label : 0

85%くらいでした。 高いような気がしてしまいましたが、5回に1回くらい失敗すると考えるとそんなに実用的ではないですね。

チューニングできるところやアルゴリズム変えるとかいろいろ試してひとまず90%超を目指してみたいところです。

Line Notifyサービス終了。移行先どうしよう
shwld
2024/10/08
リモートワーク下のビデオ通話における音声デバイスの変遷ログ
shwld
2024/05/10
NestJSのGraphQLサーバにApplication Insightsのトレースを仕込む
shwld
2024/02/19
GitHub Copilotを効率よく使うために
shwld
2024/01/29