人工知能エンジニア修行日記

主に機械学習、深層学習、Python、数学について覚え書きを記します

4章ニューラルネットワークの学習 データ駆動アプローチ、損失関数、ミニバッチ『ゼロから作るDeep Learning』

3章ではニューラルネットワークの「推論」を実装したが、4章からニューラルネットワークの「学習」を実装する。

「学習」とは:訓練データから最適な重みパラメータ値を自動で獲得すること

パラメータの数は、実際数千〜数億にも及ぶため、手動で調整することは不可能

「データ駆動アプローチ」:いままでの「人」を中心としたアプローチではなく、「データ」を中心としたアプローチ

ニューラルネットワークディープラーニングでは、従来の機械学習以上に、属人性を排している

・機械学習(ML)以前 入力データ → 人力処理 → 出力データ
↓
・ML 入力データ → 人力特徴量 → ML自動処理 → 出力データ
↓
・NNやDL 入力データ → NN/DL自動処理 → 出力データ

機械学習におけるデータの取扱について

2種類のデータ:「訓練(教師)データ」と「テストデータ」

まず「訓練(教師)データ」で学習し、最適なパラメータを探索

つぎに「テストデータ」で汎化能力を評価し、一部のデータセットだけ過度に対応した「過学習overfitting」を避ける

「損失関数loss function」:ニューラルネットワークの学習で用いられる指標で、主に「二乗和誤差」「交差エントロピー誤差」が用いられる

「二乗和誤差」について。

# 二乗和誤差
def mean_squared_error(y, t): #y: ニューラルネットワークの出力、t:訓練データ
  return np.sum( (y - t)**2 ) / 2

# 実行例
# cat mean_squared_error.py
#!/usr/bin/env python

import numpy as np
import my_module as my

# 訓練データ(「2」を正解とするone-hot表現
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

# 例1: '2'の確率が最も高い場合
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(np.sum(y)) #=> 1.0

y = my.mean_squared_error(np.array(y), np.array(t))
print(y) #=> 0.0975

# 例2: '7'の確率が最も高い場合
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(np.sum(y)) #=> 1.0

y = my.mean_squared_error(np.array(y), np.array(t))
print(y) #=> 0.5975

「交差エントロピー誤差」について。

#交差エントロピー誤差
def cross_entropy_error(y, t):
  delta = 1e-7 # マイナス無限大を回避するための微小値
  return - np.sum( t * np.log( y + delta ) )


# 実行例
# cat cross_entropy_error.py
#!/usr/bin/env python

import numpy as np
import my_module as my

# 訓練データ(「2」を正解とするone-hot表現
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

# 例1: '2'の確率が最も高い場合
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(np.sum(y)) #=> 1.0

y = my.cross_entropy_error(np.array(y), np.array(t))
print(y) #=> 0.510825457099

# 例2: '7'の確率が最も高い場合
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(np.sum(y)) #=> 1.0

y = my.cross_entropy_error(np.array(y), np.array(t))
print(y) #=> 2.30258409299

続いて、大量にある全データのうち一部を選出して近似とする「ミニバッチ学習」について。

# 0から指定された数字(60000)未満までの数字をランダムに指定個数(10)選択する処理
>>> import numpy as np
>>> np.random.choice(60000, 10)
array([54904, 15528, 35786, 44249, 25077, 37764, 46607,   552, 33463, 12885])
>>> np.random.choice(60000, 10)
array([38745,  8181,  8602, 37811, 24747, 18214, 50371, 13052, 13100, 36289])
>>> np.random.choice(60000, 10)

MNISTデータセットで動作確認

# cat mini_batch.py
# coding: utf-8
import sys, os
sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist

(x_train, t_train), (x_test, t_test) = \
  load_mnist(normalize=True, one_hot_label=True)

print(x_train.shape) #=>(60000, 784)
print(t_train.shape) #=>(60000, 10)

# mini_batch
train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]

print(x_batch) #=>ランダム結果
'''
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ...,
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
'''

print(t_batch) #=>ランダム結果
'''
[[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]]
'''

交差エントロピー誤差関数の改変

#交差エントロピー誤差(バッチ対応版)
def cross_entropy_error_batch(y, t):
  if y.ndim == 1:
    t = t.reshape(1, t.size)
    y = y.reshape(1, y.size)

  batch_size = y.shape[0]
  return - ( np.sum( t * np.log( y ) ) / batch_size )


#交差エントロピー誤差(バッチ対応、教師データラベル版)
#one-hot表現ではなく'2'などのラベル
def cross_entropy_error_batch_label(y, t):
  if y.ndim == 1:
    t = t.reshape(1, t.size)
    y = y.reshape(1, y.size)

  batch_size = y.shape[0]
  return - ( np.sum( t * np.log( y[np.arange(batch_size), t] ) ) / batch_size )

ニューラルネットワークの学習では、

認識精度を「指標」にしてはいけない。

なぜなら、認識精度を指標にすると、

パラメータの微分がほとんどの場所で0

になってしまうから。

だから、損失関数が必要なのである。