SGD

NeuralNetwork.* SGD(training_data, test_data, epochs, mini_batch_size, learning_rate)

確率的勾配降下法 (SGD: Stochastic Gradient Descent) でネットワークを学習し、エポックごとにテストデータで評価するジェネレータ関数です。処理の途中で 1、各エポックの終了時に 2、すべてのエポックの終了後に 0 を yield します。

引数:
  • training_data – トレーニングデータ (入力 X と出力 Y のペア)
  • test_data – テストデータ (入力 X と出力 Y のペア)
  • epochs (int) – エポック数
  • mini_batch_size (int) – ミニバッチのサイズ
  • learning_rate (double) – 学習率

ソース

* SGD(training_data, test_data, epochs, mini_batch_size, learning_rate) {
    this.learningRate = learning_rate;
    var last_layer = this.layers[this.layers.length - 1];
    last_layer.deltaY = new ArrayView(mini_batch_size, last_layer.unitSize);
    var exp_work = new Float32Array(last_layer.unitSize);
    var change_mini_batch_size = false;

    // 前回yieldでブラウザに制御を戻した時間
    var last_yield_time = undefined;

    for (this.epochIdx = 0; this.epochIdx < epochs; this.epochIdx++) {

        var start_epoch_time = new Date();
        var total_data_cnt = training_data.X.shape[0] + test_data.X.shape[0];
        var total_processed_data_cnt = 0

        // トレーニングとテストの処理
        for(var mode = 0; mode < 2; mode++){
            this.isTraining = (mode == 0);

            // 入力(X)と出力(Y)のペア
            var XY_data;

            // 正解数
            var ok_cnt = 0;

            var cost_sum = 0;

            if(this.isTraining){

                XY_data = training_data;
                miniBatchSize = mini_batch_size;
            }
            else{

                XY_data = test_data;
                miniBatchSize = (change_mini_batch_size ? 10 : 1) * mini_batch_size;
            }
            this.miniBatchSize = miniBatchSize;

            // ミニバッチ内のコスト
            var costs = new Float32Array(miniBatchSize);

            if(change_mini_batch_size || this.epochIdx == 0 && this.isTraining){

                this.layers.forEach(x => x.miniBatchSizeChanged());
            }

            var data_cnt = XY_data.X.shape[0];

            // 0からdata_cnt-1までの数をシャッフルした配列
            var idx_list = random.RandomSampling(data_cnt);

            DataCnt = data_cnt;
            this.miniBatchCnt = Math.floor(data_cnt / miniBatchSize);

            for (miniBatchIdx = 0; miniBatchIdx < this.miniBatchCnt; miniBatchIdx++) {

                var start_pos = miniBatchIdx * miniBatchSize;

                // idx_listのstart_posからminiBatchSize個のインデックスを使って、XY_dataから入力(X)と出力(Y)のデータを抜き出す。
                var X = this.ExtractArrayView(XY_data.X, idx_list, start_pos, miniBatchSize);
                var Y = this.ExtractArrayView(XY_data.Y, idx_list, start_pos, miniBatchSize);

                // 最初のレイヤー(入力層)の出力にXをセットする。
                this.layers[0].y_ = X;

                // 順伝播
                this.layers.forEach(x => x.forward());

                if(useSoftMax){

                    for (var batch_idx = 0; batch_idx < miniBatchSize; batch_idx++){

                        costs[batch_idx] = this.SoftMax(last_layer.deltaY.dt, last_layer.y_.dt, Y.dt, exp_work, last_layer.unitSize, batch_idx);
                    }

                    cost_sum += costs.reduce((x,y) => x + y) / miniBatchSize;
                }
                else{

                    this.LeastSquaresDelta(last_layer.deltaY.dt, last_layer.y_.dt, Y.dt);
                }

                if(this.isTraining){

                    if(useGradientCheck){

                        // 勾配の計算のチェック
                        this.netGradientCheck(Y.dt, exp_work, costs);
                    }
                    else{

                        // 誤差逆伝播
                        for (var i = this.layers.length - 1; 1 <= i; i--) {
                            this.layers[i].backpropagation();
                        }
                    }

                    // パラメータの更新
                    this.layers.forEach(x => x.updateParameter());
                }

                // 処理データ数
                this.processedDataCnt = (miniBatchIdx + 1) * miniBatchSize;

                // 正解数
                ok_cnt += this.CorrectCount(Y);

                // 正解率
                var accuracy = ok_cnt   / this.processedDataCnt;

                // コストの平均
                var avg_cost = cost_sum / this.processedDataCnt;

                if(this.isTraining){

                    this.trainingCost[this.epochIdx] = avg_cost;
                    this.trainingAccuracy[this.epochIdx] = accuracy;
                }
                else{

                    this.testCost[this.epochIdx] = avg_cost;
                    this.testAccuracy[this.epochIdx] = accuracy;
                }

                total_processed_data_cnt += miniBatchSize;

                if (last_yield_time == undefined || 10 * 1000 < new Date() - last_yield_time) {
                    // 最初か、10秒経過した場合

                    if(last_yield_time != undefined){

                        this.epochTime = Math.round( (new Date() - start_epoch_time) * total_data_cnt / (60 * 1000 * total_processed_data_cnt) );
                    }

                    // ミニバッチごとの処理時間 (レイヤー別)
                    this.processedTimeLayer = this.layers.slice(1).map(layer => layer.processedTime(miniBatchIdx + 1)).join("\n");

                    last_yield_time = new Date();

                    yield 1;
                }
            }

            var sum_last_layer_y = last_layer.y_.dt.reduce((x,y) => x + y);
            console.log("乱数の数:%d 出力の和:%f", MersenneTwisterIdx, sum_last_layer_y);

            if(change_mini_batch_size){
                this.layers.forEach(x => x.clear());
            }
        }

        console.log("Epoch %d  %.02f% %dmin", this.epochIdx, 100 * this.testAccuracy[this.epochIdx], this.epochTime);

        yield 2;
    }

    if(! change_mini_batch_size){
        this.layers.forEach(x => x.clear());
    }

    yield 0;
}