[データセット] MNIST

MNIST

MNIST は手書き数字のデータセットです。
手軽に扱うことができるので、機械学習の Hello World 的存在になっています。

データの特徴

28×28ピクセル、グレースケールの手書き数字画像
学習データ6万枚、テストデータ1万枚

ダウンロード方法

Yann LeCun’s のページで配布されていたのですが、現在ダウンロードできなくなっているので web.archive.org から入手します。

  • train-images-idx3-ubyte.gz : 学習用画像データ
  • train-labels-idx1-ubyte.gz : 学習用ラベルデータ
  • t10k-images-idx3-ubyte.gz : テスト用画像データ
  • t10k-labels-idx1-ubyte.gz : テスト用ラベルデータ

データの構造

.gz を解凍すると、以下のようなバイナリファイルが入っています。

  • train-images-idx3-ubyte : 学習用画像データ
  • train-labels-idx1-ubyte : 学習用ラベルデータ
  • t10k-images-idx3-ubyte : テスト用画像データ
  • t10k-labels-idx1-ubyte : テスト用ラベルデータ

MNISTの画像データは28×28ピクセルの画像が枚数分含まれています。
先頭16バイトがヘッダでそのあとに画像が28×28=728バイト区切りで続いています。
ヘッダは4バイト区切りでマジックナンバー、画像枚数、画像は高さ、画像の幅の情報が入っています。
画像部分は1バイトがグレースケールの1ピクセルを表し、1行目、2行目、…、28行目というように入っています。

MNISTのラベルデータは画像に対応したラベルが含まれています。
先頭8バイトがヘッダでそのあとにラベルが1バイト区切りで続いています。
ヘッダは4バイト区切りでマジックナンバー、画像枚数の情報が入っています。
マジックナンバーはなんの意味があるのか不明です。ファイルの整合性を確認したりするのでしょうか。

Pythonコード

Pythonコードで書くと以下のようになります。
gzip を解凍するコードは含めてないので、手動で解凍し、MNISTデータと同じ階層で実行してください。

# -*- coding: utf-8 -*-
import struct

import cv2
import numpy as np

def get_data(img_file, label_file):
    with open(img_file, 'rb') as file:
        # ヘッダ16バイトを読みこみ、4つの uint8 として解釈する。
        magic, num, rows, cols = struct.unpack(">4I", file.read(16))
        print("magic={}, num={}, rows={}, cols={}".format(magic, num, rows, cols))

        # 残り全部を読み込み、1次元配列を作成した後、num x rows x cols に変形する。
        imgs = np.fromfile(file, dtype=np.uint8).reshape(num, rows, cols)

    with open(label_file, 'rb') as file:
        # ヘッダ8バイトを読みこみ、2つの uint8 として解釈する。
        magic, num = struct.unpack(">2I", file.read(8))
        print("magic={}, num={}".format(magic, num))

        # 残り全部を読み込み、1次元配列を作成する。
        labels = np.fromfile(file, dtype=np.uint8)
    return [{"img": img, "label": label} for img, label in zip(imgs, labels)]


if __name__ == '__main__':
    train_data = get_data("train-images-idx3-ubyte", "train-labels-idx1-ubyte")
    test_data = get_data("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte")

    # 表示する
    for data in train_data:
        window_name = "label {}".format(data["label"])
        cv2.imshow(window_name, data["img"])
        cv2.waitKey(0)
        cv2.destroyWindow(window_name)

3コメント

コメントを残す

メールアドレスが公開されることはありません。