LSTM example

As shown in the figure below, an LSTM unit contains a memory cell and three gates: an input gate, a forget gate, and an output gate.

\( x_{t}^{i} \): input, \( u_{t}^{j} \): weighted sum of the inputs, \( s_{t}^{j} \): memory cell, \( y_{t}^{j} \): output
\( uI_{t}^{j} \): input gate, \( uF_{t}^{j} \): forget gate, \( uO_{t}^{j} \): output gate


The forward-propagation equations of the LSTM are as follows.

$$
\begin{aligned}
y_{t}^{j} &= \sigma(uO_{t}^{j}) \cdot \sigma(s_{t}^{j}) \\
s_{t}^{j} &= \sigma(uF_{t}^{j}) \cdot s_{t - 1}^{j} + \sigma(uI_{t}^{j}) \cdot \sigma(u_{t}^{j}) \\
uO_{t}^{j} &= \sum_{i}^{X} wOin_{j}^{i} \cdot x_{t}^{i} + \sum_{i}^{Y} wOr_{j}^{i} \cdot y_{t - 1}^{i} + wO_{j} \cdot s_{t}^{j} + bO_{j} \\
uF_{t}^{j} &= \sum_{i}^{X} wFin_{j}^{i} \cdot x_{t}^{i} + \sum_{i}^{Y} wFr_{j}^{i} \cdot y_{t - 1}^{i} + wF_{j} \cdot s_{t - 1}^{j} + bF_{j} \\
uI_{t}^{j} &= \sum_{i}^{X} wIin_{j}^{i} \cdot x_{t}^{i} + \sum_{i}^{Y} wIr_{j}^{i} \cdot y_{t - 1}^{i} + wI_{j} \cdot s_{t - 1}^{j} + bI_{j} \\
u_{t}^{j} &= \sum_{i}^{X} win_{j}^{i} \cdot x_{t}^{i} + \sum_{i}^{Y} wr_{j}^{i} \cdot y_{t - 1}^{i} + b_{j}
\end{aligned}
$$

Variables whose names begin with w, such as \( wOin_{j}^{i} \), are weight parameters, and variables whose names begin with b, such as \( bO_{j} \), are bias parameters. Note that the output gate \( uO_{t}^{j} \) reads the current cell state \( s_{t}^{j} \), whereas the forget and input gates read the previous cell state \( s_{t - 1}^{j} \).

In MkFn, these equations are written as the following C# code.

public class LSTMLayer : Layer {
    public int T;
    public int X;
    public int Y;

    public double[,] x;
    public double[,] y;

    public double[,] wIin;
    public double[,] wFin;
    public double[,] wOin;
    public double[,] win;

    public double[,] wIr;
    public double[,] wFr;
    public double[,] wOr;
    public double[,] wr;

    public double[] wI;
    public double[] wF;
    public double[] wO;

    public double[] bO;
    public double[] bF;
    public double[] bI;
    public double[] b;

    public double[,] u;
    public double[,] s;

    public double[,] uI;
    public double[,] uF;
    public double[,] uO;

    public LSTMLayer(int t_size, int x_size, int y_size) {
        T = t_size;
        X = x_size;
        Y = y_size;
        x = new double[T, X];
        y = new double[T, Y];

        wIin = new double[Y, X];
        wFin = new double[Y, X];
        wOin = new double[Y, X];
        win = new double[Y, X];

        wIr = new double[Y, Y];
        wFr = new double[Y, Y];
        wOr = new double[Y, Y];
        wr = new double[Y, Y];

        wI = new double[Y];
        wF = new double[Y];
        wO = new double[Y];
        bO = new double[Y];
        bF = new double[Y];
        bI = new double[Y];
        b = new double[Y];

        u = new double[T, Y];
        s = new double[T, Y];
        uI = new double[T, Y];
        uF = new double[T, Y];
        uO = new double[T, Y];
    }

    public override void Forward() {
        foreach (int t in Range(T)) {
            foreach (int j in Range(Y)) {
                y[t, j] = σ(uO[t, j]) * σ(s[t, j]);
                s[t, j] = σ(uF[t, j]) * s[t - 1, j] + σ(uI[t, j]) * σ(u[t, j]);
                uO[t, j] = (from i in Range(X) select wOin[j, i] * x[t, i]).Sum() + (from i in Range(Y) select wOr[j, i] * y[t - 1, i]).Sum() + wO[j] * s[t, j] + bO[j];
                uF[t, j] = (from i in Range(X) select wFin[j, i] * x[t, i]).Sum() + (from i in Range(Y) select wFr[j, i] * y[t - 1, i]).Sum() + wF[j] * s[t - 1, j] + bF[j];
                uI[t, j] = (from i in Range(X) select wIin[j, i] * x[t, i]).Sum() + (from i in Range(Y) select wIr[j, i] * y[t - 1, i]).Sum() + wI[j] * s[t - 1, j] + bI[j];
                u[t, j] = (from i in Range(X) select win[j, i] * x[t, i]).Sum() + (from i in Range(Y) select wr[j, i] * y[t - 1, i]).Sum() + b[j];
            }
        }
    }
}
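
The `Forward` method above lists the assignments in the same order as the formulas, so `y[t, j]` appears before the `s[t, j]` and `uO[t, j]` it depends on, and `t - 1` is out of range at `t = 0`. For readers who want to evaluate the equations directly, the following is a minimal standalone sketch, not part of MkFn. It makes several assumptions that the source does not state: σ is the logistic sigmoid, the state before the first step is zero (\( y_{-1} = s_{-1} = 0 \)), every weight is a single constant `w`, every bias is 0, and the peephole weights are shared across units. The names `LstmForwardSketch`, `Sigmoid`, `Dot`, and `Const` are hypothetical helpers.

```csharp
using System;

public static class LstmForwardSketch
{
    // Assumption: σ is the logistic sigmoid; the source does not define it.
    static double Sigmoid(double v) => 1.0 / (1.0 + Math.Exp(-v));

    // Sum over i of w[j, i] * v[t, i]. Returns 0 for t < 0, which models
    // the assumed zero initial output y[-1] = 0.
    static double Dot(double[,] w, int j, double[,] v, int t)
    {
        if (t < 0) return 0.0;
        double sum = 0.0;
        for (int i = 0; i < w.GetLength(1); i++) sum += w[j, i] * v[t, i];
        return sum;
    }

    // Hypothetical initializer: a rows x cols matrix filled with one value.
    static double[,] Const(int rows, int cols, double value)
    {
        var m = new double[rows, cols];
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++) m[r, c] = value;
        return m;
    }

    // Runs the forward pass over all T rows of x and returns y[T, Y].
    // Every weight equals w and every bias is 0, purely for illustration.
    public static double[,] Forward(double[,] x, int Y, double w)
    {
        int T = x.GetLength(0), X = x.GetLength(1);
        double[,] win = Const(Y, X, w), wIin = Const(Y, X, w),
                  wFin = Const(Y, X, w), wOin = Const(Y, X, w);
        double[,] wr = Const(Y, Y, w), wIr = Const(Y, Y, w),
                  wFr = Const(Y, Y, w), wOr = Const(Y, Y, w);
        double wI = w, wF = w, wO = w;             // peephole weights (shared here)
        double b = 0.0, bI = 0.0, bF = 0.0, bO = 0.0;

        var y = new double[T, Y];
        var s = new double[T, Y];
        for (int t = 0; t < T; t++)
        {
            for (int j = 0; j < Y; j++)
            {
                double sPrev = t > 0 ? s[t - 1, j] : 0.0;   // assumed s[-1] = 0
                // Dependency order: u, uI, uF first, then s, then uO, then y.
                double u  = Dot(win,  j, x, t) + Dot(wr,  j, y, t - 1) + b;
                double uI = Dot(wIin, j, x, t) + Dot(wIr, j, y, t - 1) + wI * sPrev + bI;
                double uF = Dot(wFin, j, x, t) + Dot(wFr, j, y, t - 1) + wF * sPrev + bF;
                s[t, j] = Sigmoid(uF) * sPrev + Sigmoid(uI) * Sigmoid(u);
                double uO = Dot(wOin, j, x, t) + Dot(wOr, j, y, t - 1) + wO * s[t, j] + bO;
                y[t, j] = Sigmoid(uO) * Sigmoid(s[t, j]);
            }
        }
        return y;
    }

    public static void Main()
    {
        var x = new double[,] { { 1.0, 0.5 }, { -0.5, 0.25 }, { 0.0, 1.0 } };
        double[,] y = Forward(x, 2, 0.1);
        for (int t = 0; t < y.GetLength(0); t++)
            Console.WriteLine($"t={t}: y = [{y[t, 0]:F4}, {y[t, 1]:F4}]");
    }
}
```

Because each output is a product of two sigmoid values, every `y[t, j]` in this sketch lies strictly between 0 and 1, which makes a quick sanity check on the evaluation order.
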
MkFn parses the C# source code, then differentiates and simplifies the expressions to generate backpropagation equations such as the following.

$$ \frac{ \partial E }{ \partial s_{t}^{j} } = \delta y_{t}^{j} \cdot \sigma(uO_{t}^{j}) \cdot \sigma'(s_{t}^{j}) + \delta s_{t + 1}^{j} \cdot \sigma(uF_{t + 1}^{j}) + \delta uO_{t}^{j} \cdot wO_{j} + \delta uF_{t + 1}^{j} \cdot wF_{j} + \delta uI_{t + 1}^{j} \cdot wI_{j} $$
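
This result can be checked by hand with the chain rule. The sketch below assumes that \( \delta v \) is shorthand for \( \partial E / \partial v \), which the source does not state explicitly. In the forward equations, \( s_{t}^{j} \) appears directly in five places: \( y_{t}^{j} \), \( s_{t + 1}^{j} \), \( uO_{t}^{j} \), \( uF_{t + 1}^{j} \), and \( uI_{t + 1}^{j} \). Summing one term per occurrence gives

$$
\frac{ \partial E }{ \partial s_{t}^{j} } = \delta y_{t}^{j} \cdot \frac{ \partial y_{t}^{j} }{ \partial s_{t}^{j} } + \delta s_{t + 1}^{j} \cdot \frac{ \partial s_{t + 1}^{j} }{ \partial s_{t}^{j} } + \delta uO_{t}^{j} \cdot \frac{ \partial uO_{t}^{j} }{ \partial s_{t}^{j} } + \delta uF_{t + 1}^{j} \cdot \frac{ \partial uF_{t + 1}^{j} }{ \partial s_{t}^{j} } + \delta uI_{t + 1}^{j} \cdot \frac{ \partial uI_{t + 1}^{j} }{ \partial s_{t}^{j} }
$$

where each direct partial derivative is read off the corresponding forward equation:

$$
\begin{aligned}
\frac{ \partial y_{t}^{j} }{ \partial s_{t}^{j} } &= \sigma(uO_{t}^{j}) \cdot \sigma'(s_{t}^{j}), &
\frac{ \partial s_{t + 1}^{j} }{ \partial s_{t}^{j} } &= \sigma(uF_{t + 1}^{j}), \\
\frac{ \partial uO_{t}^{j} }{ \partial s_{t}^{j} } &= wO_{j}, &
\frac{ \partial uF_{t + 1}^{j} }{ \partial s_{t}^{j} } &= wF_{j}, \qquad
\frac{ \partial uI_{t + 1}^{j} }{ \partial s_{t}^{j} } = wI_{j}
\end{aligned}
$$

Substituting these five factors reproduces the equation above term by term. The three terms with index \( t + 1 \) are the reason the backward pass has to run from the last time step to the first.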