「ゼロから作るDeep Learning 2 ―自然言語処理編」のRNNコードと数式の関係
前回、RNNのバックプロパゲーション(誤差逆伝搬)の式(活性化関数が「tanh()」関数の場合)
を導きました。
この関係をみれば、「1つ上の(l+1)層」と「1つ先の時間、(t+1)層」から
をもらってくれば「(l)層のdL/du」を求めることができるということが分かりますね。
(いままで同様バイアスに関しては省略させてもらってます)
つまり、この式に(l)層の活性化関数の微分を掛けて(今の場合「tanh()」関数の微分、「(1-h)^2」を掛けて)
となって
が得られます。(これをそのまんま、以下、コードで確認します)
DNNの時と、同じで
「数式とコードの関係をきちんと理解したい!」
となりまして、、、、、、
まず、メインルーチンを確認しますと、、、、
「class SimpleRnnlm」を「model」として、インスタンス化して、「model.backward()」でバックプロパゲーション(誤差逆伝播)の計算をしているとなっています。
それではということで、「class SimpleRnnlm」を確認すると、、、、、、
レイヤーが、
「TimeEmbedding(embed_W)」
「TimeRNN(rnn_Wx, rnn_Wh, rnn_b, stateful=True)」
「TimeAffine(affine_W, affine_b)」
「TimeSoftmaxWithLoss()」
で構成されていて、バックプロパゲーション(誤差逆伝播)の計算を、「この順番をひっくり返して行う(reversed)」というコードになっています。
ここからさらに、個々のレイヤーのバックプロパゲーション(誤差逆伝播)つまり「backward」を確認しますと、、、、、
さらに対応する、それぞれのレイヤの「backward」部分だけを書き出すと、
この4つのレイヤーを逆に並べて、「backward」からバックプロパゲーション(誤差逆伝播)を行っているというわけですね。
ただし、ちょっとだけ、注意がいるのが、
「class TimeRNN」
は、内部で
「class RNN」
を呼び出しているということですね。
ここまでで、コードの構造を確認したので、RNNコードと数式の関係を確認できそうです。
確認したいのは、バックプロパゲーション(誤差逆伝播)の式
が、「コードでどのように対応しているか」です。
まず、出力層の「TimeSoftmaxWithLoss()」は、Softmax関数を交差エントロピー誤差で計算しています。その場合、交差エントロピー誤差の勾配(微分)は(y - t)になりますので。
(y - t) は 「dL/du(l+1)層」
に対応するということですね。というわけで
これを「dx」としてreturnします。
(ここ!私、分かったように書いていますが、「TimeSoftmaxWithLoss().backward」のコードはDNNの時と一緒で、よく分かってないです。読みにくいコードは、ホントに困ります、はあ~難し!)
で、「dx」は「dout」として「TimeAffine(affine_W, affine_b).backward」レイヤーに入りますので、
また、同じようにこれを「dx」としてreturnします。ここまでは(l + 1)層の分です。
(というか、正確には出力層の分)
で、次は(l)層に移って「dx」を「dhs」として「TimeRNN.backward」レイヤーに渡します。下の図で言うと、一番上の赤四角の中の、「dhs」です。
その後は、RNN時間ループ分くるくるバックプロパゲーション計算します。
ここで、大事なことは、RNNのバックプロパゲーションの場合は、(l+1)層(dx(dhs))と(t-1)層(dh)の両方から誤差勾配をもらうということですね。下の図で言うと、「
赤枠の中」、
上の「TimeRNN.backward」レイヤーのコードではここ。
数式で言うとここ。
「RNN.backward」コードを見てみると、以下のように計算しています。
つまり、
「(l+1、t)層」「(l、t+1)層」
の誤差勾配をもらって
「(l、t)層」
の誤差勾配を求めて
「(lー1、t)層」「(l、tー1)層」
のそれぞれに、誤差勾配を渡す。
と、なるのですね。
ここで、RNNのバックプロパゲーションの動きを再度確認するのに下の図を見てください。
データは「層」方向(下図で縦のライン、緑色、黄色、青色)と「時間」方向(下図で横のライン、黄色)の2つがあります
「層」と「時間」の両方で誤差勾配をやり取りしますが、時間でのやり取りは「Forループ」を使うので、黄色の「時間」方向は、リストに保存しませんが、「層」方向(緑色、黄色、青色)の誤差勾配のデータは一旦リストに保存する必要があります。(このコードだとRNNの中間層が1層しかないので分かり難いですが、、、、、、)
もう一度コードで確認すると
ここですね。コードでは「dx」の項は、計算が(l)層から(l-1)層に移らないと使わないので、一旦「dxs」にリストとして保存しているわけですね。
ということで整理するため全体の流れを図だけでみてみると
となるかと思います。①~⑪を順番に追っていくと、コードの流れが分かるでしょうか、、、、、
以上です。
コメント
コメントを投稿