[numpy] np.where

np.where

以下のコードを実行するには,まず numpy を import する.

import numpy as np

numpy.where(condition[, x, y])

numpy.where を使うと,Bool 型の ndarray を渡すと,True の位置を返してくれる.
True である要素を軸ごとにインデックス一覧として返す.

>>> z = np.array([[True, False, False],
                  [False, True, False],
                  [False, True, True]])
>>> np.where(z)
(array([0, 1, 2, 2], dtype=int64), array([0, 1, 1, 2], dtype=int64))
>>> z = np.array([[False, False, False],
                 [False, False, False],
                 [False, False, False]])
>>> np.where(z)
(array([], dtype=int64), array([], dtype=int64))

さらに,引数 x, y も与えた場合は,各要素を condition が True の部分は x の値,False の場合は y の値にした ndarray を作成する.

>>> conditions = np.array([[False, False, False],
                           [False, False, False],
                           [False, False, False]])
>>> x = np.array([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
>>> y = np.array([[-1, -2, -3],
                  [-4, -5, -6],
                  [-7, -8, -9]])
>>> np.where(conditions, x, y)
array([[ 1, -2, -3],
       [-4,  5, -6],
       [-7,  8,  9]])

コメントを残す

メールアドレスが公開されることはありません。