読者です 読者をやめる 読者になる 読者になる

人工知能エンジニア修行日記

人工知能エンジニアを目指して修行します

手書き数字認識とバッチ処理『ゼロから作るDeep Learning』

『ゼロから作るDeep Learning』3章最後

いよいよ手書き数字認識に入る。ここでは、「学習」フェーズは完了している前提で、「推論」フェーズのみ順方向伝播方式で実施。

# MNISTという手書き数字画像セットを準備
# git clone https://github.com/oreilly-japan/deep-learning-from-scratch.git
# cd deep-learning-from-scratch/ch03
# ls ../dataset/mnist.py
# ../dataset/mnist.py
# cat mnist_check.py

import sys, os
sys.path.append(os.pardir) #親ディレクトリのファイルをインポート
from dataset.mnist import load_mnist

# load_minst((訓練画像、訓練ラベル), (テスト画像、テストラベル)
# 初回呼び出しはネットDL、2回目以降はローカルpickleファイル読込
# 引数
# normalize: 0.0..1.0に正規化
# flatten: 1次元配列化
# one_hot_label: 正解ラベルのみ1でそれ以外は0にするone-hot表現として格納するか
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

print(x_train.shape) # (60000, 784)
print(t_train.shape) # (60000,)
print(x_test.shape) # (10000, 784)
print(t_test.shape) # (10000,)

続いて、MNIST画像の表示スクリプトmnist_show.pyを参考に、保存処理mnist_save.pyを作って確認

# cat mnist_save.py
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image

def img_save(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.save('mnist_save_sample.png')

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

img = x_train[0]
label = t_train[0]
print(label)  # 5

print(img.shape)  # (784,)
img = img.reshape(28, 28)  # 形状を元の画像サイズに変形
print(img.shape)  # (28, 28)

img_save(img)

# py mnist_save.py
5
(784,)
(28, 28)

出力された画像

f:id:kaeken:20161107013732p:plain

続いて、ニューラルネットワークの推論処理を確認

# cat neuralnet_mnist.py
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
  # 前処理pre-processingとして、正規化normalizationを実施
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test

# pickleファイルに保存された学習済重みパラメータの読込
def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p= np.argmax(y) # 最も確率の高い要素のインデックスを取得
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
#=>Accuracy:0.9352で、
# 93.52%正しく分類できた

続いてバッチ処理

バッチbatch:ひとまとまりの入力データ束

バッチ処理によって1枚あたりの処理時間を短縮できる

# cat neuralnet_mnist_batch.py
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 親ディレクトリのファイルをインポートするための設定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()

batch_size = 100 # バッチの数
accuracy_cnt = 0

for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

実際に比較してみると、確かに高速化していることが確認できた

# バッチなし
# time py neuralnet_mnist.py
Accuracy:0.9352

real    0m2.220s
user    0m2.859s
sys 0m0.521s

# バッチあり
# time py neuralnet_mnist_batch.py
Accuracy:0.9352

real    0m0.991s
user    0m0.857s
sys 0m0.312s