[TensorFlow] マルチスレッド処理

Coordinator

Coordinator は日本語で調整者を意味します。
tf.train.Coordinator クラスはマルチスレッドで処理を行う場合に同期などスレッド間の調整を行う機能を提供します。

同期待ちを行う

複数のスレッドを実行し、すべてのスレッドの処理が完了するのを待つには join() を使います。
tf.train.Coordinator.join(threads=None)
threads に実行する threading.Thread インスタンスのリストを渡します。

# -*- coding: utf-8 -*-
import threading
import tensorflow as tf

# スレッドで処理したい処理を関数で記述する。
def MyLoop(coord):
    処理を記述する。

# Coordinator クラスを作成する。
coord = tf.train.Coordinator()

# threading.Thread クラスのインスタンスを10個作成する。
threads = [threading.Thread(target=MyLoop, args=(coord)) for i in range(10)]

# スレッドを実行する。
for t in threads:
    t.start()

# すべてのスレッドが終了するまで待つ。
coord.join(threads)
print("All threads are terminated.")

全スレッドの停止をリクエストする

実行中の全スレッドを停止するには request_stop() を使用します。
このとき、should_stop() が True を返すように変更されます。
ですので、各スレッドでは定期的に coord.should_stop() が True でないか確認し、True ならスレッドを終了するようにします。
request_stop() はスレッド内または本体側のどちらからも呼び出せます。

# -*- coding: utf-8 -*-
import threading
import tensorflow as tf

# スレッドで処理したい処理を関数で記述する。
def MyLoop(coord):
    # coord.should_stop() が False である限り、処理を継続する。
    while not coord.should_stop():
        処理を記述する。

        if 停止をリクエストする条件:
            coord.request_stop()

# Coordinator クラスを作成する。
coord = tf.train.Coordinator()

# threading.Thread クラスのインスタンスを10個作成する。
threads = [threading.Thread(target=MyLoop, args=(coord)) for i in range(10)]

# スレッドを実行する。
for t in threads:
    t.start()

# すべてのスレッドが終了するまで待つ。
coord.join(threads)
print("All threads are terminated.")

例外が発生したら、スレッドの停止をリクエストする

stop_on_exception() を使用すると、あるスレッド内の処理中に例外が発生したら、Coordinator に停止をリクエストするという処理が簡潔に書ける。

def MyLoop(coord):
    with coord.stop_on_exception():
        処理を記述する。

これは以下のコードと同じである。

def MyLoop(coord):
    try:
        while not coord.should_stop():
            処理を記述する。
    except Exception as e:
        coord.request_stop(e)

1コメント

コメントを残す

メールアドレスが公開されることはありません。