Chapter 6: Techniques for Training (『ゼロから作るDeep Learning』 / Deep Learning from Scratch)
Next up: Batch Normalization.
What Batch Normalization (Batch Norm) is: a technique that normalizes the activations per mini-batch (to zero mean and unit variance).
Benefits:
- Training can proceed faster
- Less dependence on the initial weight values (robust to initialization)
- Suppresses overfitting
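Concretely, the per-mini-batch normalization can be sketched in a few lines of NumPy. This is a minimal illustration of the training-time forward pass only; the function name, the eps value, and the shapes are illustrative assumptions, not the implementation inside the book's MultiLayerNetExtend:

import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-7):
    # x: mini-batch of shape (N, D); gamma, beta: learnable scale/shift of shape (D,)
    mu = x.mean(axis=0)                    # per-feature mean over the mini-batch
    var = x.var(axis=0)                    # per-feature variance over the mini-batch
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize; eps avoids division by zero
    return gamma * x_hat + beta            # scale and shift

# Quick check on a toy mini-batch: 100 samples, 5 features
x = np.random.randn(100, 5) * 3.0 + 10.0
y = batch_norm_forward(x, gamma=np.ones(5), beta=np.zeros(5))
print(y.mean(axis=0))  # ~0 for every feature
print(y.std(axis=0))   # ~1 for every feature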
Below is a sample script that compares training with and without Batch Norm across a range of weight initializations.
# cat batch_norm_test_save.py
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # so that files in the parent directory can be imported
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')  # write the figure to a file instead of opening a window
from dataset.mnist import load_mnist
from common.multi_layer_net_extend import MultiLayerNetExtend
from common.optimizer import SGD, Adam

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# Reduce the training data
x_train = x_train[:1000]
t_train = t_train[:1000]

max_epochs = 20
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.01


def __train(weight_init_std):
    # Two otherwise identical 5-hidden-layer networks: with and without Batch Norm
    bn_network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100],
                                     output_size=10, weight_init_std=weight_init_std,
                                     use_batchnorm=True)
    network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100],
                                  output_size=10, weight_init_std=weight_init_std)
    optimizer = SGD(lr=learning_rate)

    train_acc_list = []
    bn_train_acc_list = []

    iter_per_epoch = max(train_size / batch_size, 1)
    epoch_cnt = 0

    for i in range(1000000000):
        batch_mask = np.random.choice(train_size, batch_size)
        x_batch = x_train[batch_mask]
        t_batch = t_train[batch_mask]

        # Update both networks on the same mini-batch
        for _network in (bn_network, network):
            grads = _network.gradient(x_batch, t_batch)
            optimizer.update(_network.params, grads)

        if i % iter_per_epoch == 0:
            train_acc = network.accuracy(x_train, t_train)
            bn_train_acc = bn_network.accuracy(x_train, t_train)
            train_acc_list.append(train_acc)
            bn_train_acc_list.append(bn_train_acc)
            print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - " + str(bn_train_acc))

            epoch_cnt += 1
            if epoch_cnt >= max_epochs:
                break

    return train_acc_list, bn_train_acc_list


# 3. Draw the graphs ==========
weight_scale_list = np.logspace(0, -4, num=16)
x = np.arange(max_epochs)

for i, w in enumerate(weight_scale_list):
    print("============== " + str(i+1) + "/16 ==============")
    train_acc_list, bn_train_acc_list = __train(w)

    plt.subplot(4, 4, i+1)
    plt.title("W:" + str(w))
    if i == 15:
        plt.plot(x, bn_train_acc_list, label='Batch Normalization', markevery=2)
        plt.plot(x, train_acc_list, linestyle="--", label='Normal(without BatchNorm)', markevery=2)
    else:
        plt.plot(x, bn_train_acc_list, markevery=2)
        plt.plot(x, train_acc_list, linestyle="--", markevery=2)

    plt.ylim(0, 1.0)
    if i % 4:
        plt.yticks([])
    else:
        plt.ylabel("accuracy")
    if i < 12:
        plt.xticks([])
    else:
        plt.xlabel("epochs")
    plt.legend(loc='lower right')

# plt.show()
plt.savefig('batch_norm_test_save.png')
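To reproduce the figure, the script assumes the directory layout of the book's sample-code repository, with dataset/ and common/ in the parent directory (that is what sys.path.append(os.pardir) is for). Run it in place:

# python batch_norm_test_save.py

Because the backend is switched to 'agg', the 4x4 grid of accuracy curves is written to batch_norm_test_save.png rather than displayed in a window.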
The plot's layout came out a little broken, but sure enough, in most cases training with Batch Norm progresses faster.