forward

MaxPoolingLayer.forward()

順伝播

ソース

forward() {
    var lap = new Lap(this.forwardTime);

    var prev_layer = this.prevLayer;

    var prev_y_dt = prev_layer.y_.dt;

    // 出力先
    var output_idx = 0;

    // バッチ内のデータに対し
    for (var batch_idx = 0; batch_idx < miniBatchSize; batch_idx++) {

        // すべての特徴マップに対し
        for (var channel_idx = 0; channel_idx < this.numChannels; channel_idx++) {

            // 出力の行に対し
            for (var r1 = 0; r1 < this.numRows; r1++) {
                var r0 = r1 * this.filterSize;

                // 出力の列に対し
                for (var c1 = 0; c1 < this.numCols; c1++) {
                    var c0 = c1 * this.filterSize;

                    if(inGradientCheck){

                        var r2 = this.maxRow[output_idx];
                        var c2 = this.maxCol[output_idx];
                        var prev_y_idx = batch_idx * prev_layer.unitSize + (channel_idx * prev_layer.numRows + (r0 + r2)) * prev_layer.numCols + (c0 + c2);
                        this.y_.dt[output_idx] = prev_y_dt[prev_y_idx];
                    }
                    else{

                        var max_val = -10000;
                        var max_row, max_col;

                        // フィルターの行に対し
                        for (var r2 = 0; r2 < this.filterSize; r2++) {

                            // フィルターの列に対し
                            for (var c2 = 0; c2 < this.filterSize; c2++) {

                                var prev_y_idx = batch_idx * prev_layer.unitSize + (channel_idx * prev_layer.numRows + (r0 + r2)) * prev_layer.numCols + (c0 + c2);
                                var val = prev_y_dt[prev_y_idx];
                                if (max_val < val) {

                                    max_val = val;
                                    max_row = r2;
                                    max_col = c2;
                                }
                            }
                        }

                        this.y_.dt[output_idx] = max_val;
                        this.maxRow[output_idx] = max_row;
                        this.maxCol[output_idx] = max_col;
                    }

                    output_idx++;
                }
            }
        }
    }

    Assert(output_idx == this.y_.dt.length);
    lap.Time();
}