Coordinator

Coordinator は日本語で調整者を意味します。
tf.train.Coordinator クラスはマルチスレッドで処理を行う場合に同期などスレッド間の調整を行う機能を提供します。

同期待ちを行う

複数のスレッドを実行し、すべてのスレッドの処理が完了するのを待つには join() を使います。
tf.train.Coordinator.join(threads=None)
threads に実行する threading.Thread インスタンスのリストを渡します。

# -*- coding: utf-8 -*-
import threading
import tensorflow as tf

# スレッドで処理したい処理を関数で記述する。
def MyLoop(coord):
    処理を記述する。

# Coordinator クラスを作成する。
coord = tf.train.Coordinator()

# threading.Thread クラスのインスタンスを10個作成する。
threads = [threading.Thread(target=MyLoop, args=(coord)) for i in range(10)]

# スレッドを実行する。
for t in threads:
    t.start()

# すべてのスレッドが終了するまで待つ。
coord.join(threads)
print("All threads are terminated.")

全スレッドの停止をリクエストする

実行中の全スレッドを停止するには request_stop() を使用します。
このとき、should_stop() が True を返すように変更されます。
ですので、各スレッドでは定期的に coord.should_stop() が True でないか確認し、True ならスレッドを終了するようにします。
request_stop() はスレッド内または本体側のどちらからも呼び出せます。

# -*- coding: utf-8 -*-
import threading
import tensorflow as tf

# スレッドで処理したい処理を関数で記述する。
def MyLoop(coord):
    # coord.should_stop() が False である限り、処理を継続する。
    while not coord.should_stop():
        処理を記述する。

        if 停止をリクエストする条件:
            coord.request_stop()

# Coordinator クラスを作成する。
coord = tf.train.Coordinator()

# threading.Thread クラスのインスタンスを10個作成する。
threads = [threading.Thread(target=MyLoop, args=(coord)) for i in range(10)]

# スレッドを実行する。
for t in threads:
    t.start()

# すべてのスレッドが終了するまで待つ。
coord.join(threads)
print("All threads are terminated.")

例外が発生したら、スレッドの停止をリクエストする

stop_on_exception() を使用すると、あるスレッド内の処理中に例外が発生したら、Coordinator に停止をリクエストするという処理が簡潔に書ける。

def MyLoop(coord):
    with coord.stop_on_exception():
        処理を記述する。

これは以下のコードと同じである。

def MyLoop(coord):
    try:
        while not coord.should_stop():
            処理を記述する。
    except Exception as e:
        coord.request_stop(e)

tf.flags

公式ドキュメント: tf.flags

argparse をラップしたコマンドライン引数を解析するためのモジュール。
公式のサンプルなどで利用されているのを見かける。

引数の型によって、以下の4種類がある。
tf.flags.DEFINE_string(flag_name, default_value, docstring)
tf.flags.DEFINE_integer(flag_name, default_value, docstring)
tf.flags.DEFINE_boolean(flag_name, default_value, docstring)
DEFINE_float(flag_name, default_value, docstring)

例:

import tensorflow as tf
tf.flags.DEFINE_string("msg", "hello", "Message to print.")
tf.flags.DEFINE_string("num", "0", "Number to print.")
FLAGS = tf.flags.FLAGS

print(FLAGS.msg)
print(FLAGS.num)
$ python test.py --msg hoge
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcudnn.so.5 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcufft.so.8.0 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcurand.so.8.0 locally
hoge
0

x軸、y軸のラベル及びタイトルを設定する

import matplotlib.pyplot as plt
import numpy as np

values = 33.73 * np.random.randn(10000) + 170.69
fig, ax = plt.subplots(1, 1)
ax.hist(values, bins=50)

plt.xlabel("height (cm)")
plt.ylabel("frequency")
plt.title("stature distribution")

plt.show()

テキスト

px などの尺度ではなく、グラフ上の座標で位置を指定します。

import matplotlib.pyplot as plt
import numpy as np

values = 33.73 * np.random.randn(10000) + 170.69
fig, ax = plt.subplots(1, 1)
ax.hist(values, bins=50)

plt.text(50, 300, "Text!")

plt.show()

アノテーション

テキストでアノテーションする

annotate() はグラフのある場所をテキストでアノテーションするのに使われる。
s に文字列、xy にアノテーションをする座標を指定します。

import matplotlib.pyplot as plt
import numpy as np


xs = np.arange(-5.0, 5.0, 0.1)
ys = np.sin(xs)

fig, ax = plt.subplots(1, 1)
ax.plot(xs, ys)
ax.set_ylim(-2,2)

ax.annotate('local max', xy=(2, 1))

plt.show()

テキストと矢印でアノテーションする

まず xytext でテキストの位置、xy でアノテーション対象の位置を指定します。
そして arrowprops 引数を与えることで、始点が xytext で終点が xy である矢印が描写されます。

shrink

xytext ~ xy のうち、矢印を描写する割合

import matplotlib.pyplot as plt
import numpy as np

xs = np.arange(-5.0, 5.0, 0.1)
ys = np.sin(xs)

fig, [ax1, ax2, ax3] = plt.subplots(3, 1, sharey=True)
ax1.set_ylim(-2,2)

ax1.plot(xs, ys)
ax1.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(shrink=0.1))

ax2.plot(xs, ys)
ax2.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(shrink=0.5))

ax3.plot(xs, ys)
ax3.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(shrink=0.9))

plt.show()

width

矢印の太さを pt で指定する。

import matplotlib.pyplot as plt
import numpy as np

xs = np.arange(-5.0, 5.0, 0.1)
ys = np.sin(xs)

fig, [ax1, ax2] = plt.subplots(2, 1, sharey=True)
ax1.set_ylim(-2,2)

ax1.plot(xs, ys)
ax1.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(width=10))

ax2.plot(xs, ys)
ax2.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(width=30))

plt.show()

headwidth

矢印の先の太さを pt で指定する。

import matplotlib.pyplot as plt
import numpy as np

xs = np.arange(-5.0, 5.0, 0.1)
ys = np.sin(xs)

fig, [ax1, ax2] = plt.subplots(2, 1, sharey=True)
ax1.set_ylim(-2,2)

ax1.plot(xs, ys)
ax1.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(headwidth=10))

ax2.plot(xs, ys)
ax2.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(headwidth=30))

plt.show()

headlength

矢印の先の長さを pt で指定する。

import matplotlib.pyplot as plt
import numpy as np

xs = np.arange(-5.0, 5.0, 0.1)
ys = np.sin(xs)

fig, [ax1, ax2] = plt.subplots(2, 1, sharey=True)
ax1.set_ylim(-2,2)

ax1.plot(xs, ys)
ax1.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(headlength=10))

ax2.plot(xs, ys)
ax2.annotate('local max', xy=(2, 1), xytext=(3, 1.5),
            arrowprops=dict(headlength=30))

plt.show()

matplotlib の線のプロパティ

sin をプロット対象に使います。

>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> xs = np.arange(-2.0, 2.0, 0.1)
>>> ys = np.sin(xs)

[alpha] 透過度を設定する

fig, [ax1, ax2, ax3] = plt.subplots(3, 1)
ax1.plot(xs, ys, alpha=0.3)
ax2.plot(xs, ys, alpha=0.6)
ax3.plot(xs, ys, alpha=0.9)
plt.show()

[color] 色を設定する

<

色をANSIカラーで指定する

ANSI で規定されている8色(青、緑、赤、シアン、マゼンダ、黄、黒、白)は名前で指定できます。
ANSI_escape_code Colors (Wikipedia)

colors = ["blue", "green", "red", "cyan",
          "magenta", "yellow", "black", "white"]
fig, axes = plt.subplots(len(colors), 1)

for ax, color in zip(axes, colors):
    ax.plot(xs, ys, color=color)
plt.show()
# 一番下は背景も白なので同化して見えない

名前は頭文字で指定することもできます

colors = ["b", "g", "r", "c", "m", "y", "b", "w"]
fig, axes = plt.subplots(len(colors), 1)

for ax, color in zip(axes, colors):
    ax.plot(xs, ys, color=color)
plt.show()

輝度値で指定する

color=<輝度値> (輝度値は str で “0.0” ~ “1.0”) と指定した場合は色はグレースケールになります。

fig, [ax1, ax2, ax3] = plt.subplots(3, 1)
ax1.plot(xs, ys, color="0.3")
ax2.plot(xs, ys, color="0.6")
ax3.plot(xs, ys, color="0.9")
plt.show()

HTMLカラーコードで指定する

HTMLカラーコードはこちらを参照。

fig, [ax1, ax2, ax3] = plt.subplots(3, 1)
ax1.plot(xs, ys, color="#40E0D0")
ax2.plot(xs, ys, color="#FFD700")
ax3.plot(xs, ys, color="#F6546A")
plt.show()

HTMLカラーネームで指定する

HTMLカラーネームはこちらを参照。

fig, [ax1, ax2, ax3] = plt.subplots(3, 1)
ax1.plot(xs, ys, color="deeppink")
ax2.plot(xs, ys, color="orange")
ax3.plot(xs, ys, color="darkgreen")
plt.show()

RGB値で指定する

color=[R, G, B], (R, G, B は float で 0.0 ~ 1.0) で指定することもできます。

fig, [ax1, ax2, ax3] = plt.subplots(3, 1)
ax1.plot(xs, ys, color=[0.2, 0.8, 0.8])
ax2.plot(xs, ys, color=[0.5, 0.7, 0.2])
ax3.plot(xs, ys, color=[0.8, 0.8, 0.5])
plt.show()

[linestyle] 線の種類を設定する

fig, [ax1, ax2, ax3, ax4, ax5] = plt.subplots(5, 1)
ax1.plot(xs, ys, linestyle="solid")
ax2.plot(xs, ys, linestyle="dashed")
ax3.plot(xs, ys, linestyle="dashdot")
ax4.plot(xs, ys, linestyle="dotted")
ax5.plot(xs, ys, linestyle="None")
plt.show()

[linewidth] 線の太さを設定する

fig, [ax1, ax2, ax3] = plt.subplots(3, 1)
ax1.plot(xs, ys, linewidth=0.5)
ax2.plot(xs, ys, linewidth=1.0)
ax3.plot(xs, ys, linewidth=2.0)
plt.show()

[fillstyle] fillstyle を設定する

fig, [ax1, ax2, ax3, ax4, ax5, ax6] = plt.subplots(6, 1)
ax1.plot(xs, ys, fillstyle="full")
ax2.plot(xs, ys, fillstyle="left")
ax3.plot(xs, ys, fillstyle="right")
ax4.plot(xs, ys, fillstyle="bottom")
ax5.plot(xs, ys, fillstyle="top")
ax6.plot(xs, ys, fillstyle="none")
plt.show()

[drawstyle] 線の補完方式を設定する

与えられた点と点の間をどのように補完するかの設定

fig, [ax1, ax2, ax3, ax4, ax5] = plt.subplots(5, 1)
ax1.plot(xs, ys, drawstyle="default")
ax2.plot(xs, ys, drawstyle="steps")
ax3.plot(xs, ys, drawstyle="steps-pre")
ax4.plot(xs, ys, drawstyle="steps-mid")
ax5.plot(xs, ys, drawstyle="steps-post")
plt.show()