Backward Propagation Through Time

MM 0428/2018


Figure 1. RNN architecture. $\bar{x}^{(t)}$ and $\bar{y}^{(t)}$ are the input and output at time $t$.

Consider the network architecture in Figure 1, where $\bar{x}^{(t)}$ and $\bar{y}^{(t)}$ are the input and output at time $t$.

The hidden layer is computed as

$$\bar{c}^{(t)}=f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c \bar{c}^{(t-1)}+ \bar{b})$$

The output $\bar{y}$ is computed as

$$\bar{y}^{(t)}=\textrm{softmax}(\bar{u}^{(t)})=\textrm{softmax}(\bar{\bar{W}}_o \bar{c}^{(t)}), \qquad y_j = \frac{e^{u_j}}{\displaystyle \sum_{k} e^{u_k}},$$

where the sum in the denominator runs over all elements of $\bar{u}$.
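The two formulas above can be sketched as a forward pass. This is a minimal illustration only: the text does not specify the activation $f$, so `tanh` is assumed here, and the toy weights, the zero initial state and the helper names (`matvec`, `rnn_step`) are all hypothetical.

```python
import math

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def softmax(u):
    """Softmax of u; max(u) is subtracted first for numerical stability."""
    m = max(u)
    exps = [math.exp(x - m) for x in u]
    total = sum(exps)
    return [e / total for e in exps]

def rnn_step(W_x, W_c, W_o, b, x_t, c_prev):
    """One time step: c_t = f(W_x x_t + W_c c_prev + b), y_t = softmax(W_o c_t)."""
    a = [p + q + r for p, q, r in zip(matvec(W_x, x_t), matvec(W_c, c_prev), b)]
    c_t = [math.tanh(v) for v in a]      # f = tanh is an assumption
    y_t = softmax(matvec(W_o, c_t))
    return c_t, y_t

# Hypothetical toy parameters: 2 inputs, 2 hidden units, 3 output classes.
W_x = [[0.5, -0.3], [0.1, 0.8]]
W_c = [[0.2, 0.4], [-0.6, 0.1]]
W_o = [[1.0, -1.0], [0.5, 0.5], [-0.7, 0.3]]
b = [0.0, 0.1]

c = [0.0, 0.0]                           # initial state c^(0), assumed zero
for x_t in [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]:
    c, y = rnn_step(W_x, W_c, W_o, b, x_t, c)
    print([round(p, 4) for p in y])      # each output sums to 1
```

The same hidden state `c` is fed back at every step, which is what lets earlier inputs reach later outputs.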

Note that the current and earlier inputs, $\bar{x}^{(t)},\bar{x}^{(t-1)},\bar{x}^{(t-2)},\cdots$, all influence $\bar{y}^{(t)}$, as the following expansion shows:

$$
\begin{aligned}
\bar{y}^{(t)} &= \textrm{softmax}(\bar{u}^{(t)}) = \textrm{softmax}(\bar{\bar{W}}_o \bar{c}^{(t)}) = \textrm{softmax}(\bar{\bar{W}}_o f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c \bar{c}^{(t-1)}+ \bar{b})) \\
&= \textrm{softmax}(\bar{\bar{W}}_o f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c f(\bar{\bar{W}}_x \bar{x}^{(t-1)} + \bar{\bar{W}}_c \bar{c}^{(t-2)}+ \bar{b})+ \bar{b})) \\
&= \textrm{softmax}(\bar{\bar{W}}_o f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c f(\bar{\bar{W}}_x \bar{x}^{(t-1)} + \bar{\bar{W}}_c f(\bar{\bar{W}}_x \bar{x}^{(t-2)} + \bar{\bar{W}}_c \bar{c}^{(t-3)}+ \bar{b})+ \bar{b})+ \bar{b}))
\end{aligned}
$$

since

$$\bar{c}^{(t-1)}=f(\bar{\bar{W}}_x \bar{x}^{(t-1)} + \bar{\bar{W}}_c \bar{c}^{(t-2)}+ \bar{b})$$

$$\bar{c}^{(t-2)}=f(\bar{\bar{W}}_x \bar{x}^{(t-2)} + \bar{\bar{W}}_c \bar{c}^{(t-3)}+ \bar{b})$$

Figure 2. The current and earlier inputs, $\bar{x}^{(t)},\bar{x}^{(t-1)},\bar{x}^{(t-2)},\cdots$, all influence $\bar{y}^{(t)}$.

Figure 2 shows how $\bar{x}^{(t)},\bar{x}^{(t-1)},\bar{x}^{(t-2)},\cdots$ affect $\bar{y}^{(t)}$.
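The dependence on earlier inputs can be seen in a small experiment. The sketch below assumes a scalar hidden state, $f=\tanh$, a zero initial state and hypothetical weights; it changes only the earliest input and compares the final hidden state. Since $\bar{y}^{(t)}$ is a function of $\bar{c}^{(t)}$, a change in the final state carries over to the output.

```python
import math

def final_state(xs, w_x=0.8, w_c=0.5, b=0.0):
    """Run c_t = tanh(w_x * x_t + w_c * c_{t-1} + b) over xs, starting from c_0 = 0."""
    c = 0.0
    for x in xs:
        c = math.tanh(w_x * x + w_c * c + b)
    return c

base = final_state([1.0, 0.5, -0.3])
changed = final_state([0.0, 0.5, -0.3])   # only the earliest input differs
print(base, changed)                      # the two final states differ
```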


The loss function is defined as

$$L=\sum_{t=1}^T L_t=-\sum_{t=1}^T \log y^{(t)}_{\textrm{target}},$$

where $y^{(t)}_{\textrm{target}}$ is the element of $\bar{y}^{(t)}$ that corresponds to the target at time $t$.
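As a small worked example of this loss, the sketch below sums $-\log y^{(t)}_{\textrm{target}}$ over a three-step sequence. The probability vectors and target indices are made-up values, and `sequence_loss` is a hypothetical helper name.

```python
import math

def sequence_loss(ys, targets):
    """L = -sum_t log y^(t)[target_t], summed over all time steps."""
    return -sum(math.log(y[k]) for y, k in zip(ys, targets))

# Hypothetical softmax outputs for T = 3 steps and their target indices.
ys = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]
targets = [0, 1, 2]

L = sequence_loss(ys, targets)
print(L)   # equals -(log 0.7 + log 0.8 + log 0.4)
```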


Consider $\partial L/ \partial \bar{\bar{W}}_o$:

$$
\begin{aligned}
\frac{\partial L}{\partial \bar{\bar{W}}_o} &= \sum_{t=1}^T \frac{\partial L_t}{\partial \bar{\bar{W}}_o} = \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_\textrm{target}} \frac{\partial y^{(t)}_\textrm{target}}{\partial u^{(t)}_\textrm{target}} \frac{\partial u^{(t)}_\textrm{target}}{\partial \bar{\bar{W}}_o} \\
&= \frac{\partial L_1}{\partial y^{(1)}_\textrm{target}} \frac{\partial y^{(1)}_\textrm{target}}{\partial u^{(1)}_\textrm{target}} \frac{\partial u^{(1)}_\textrm{target}}{\partial \bar{\bar{W}}_o} + \frac{\partial L_2}{\partial y^{(2)}_\textrm{target}} \frac{\partial y^{(2)}_\textrm{target}}{\partial u^{(2)}_\textrm{target}} \frac{\partial u^{(2)}_\textrm{target}}{\partial \bar{\bar{W}}_o} + \sum_{t=3}^T \frac{\partial L_t}{\partial y^{(t)}_\textrm{target}} \frac{\partial y^{(t)}_\textrm{target}}{\partial u^{(t)}_\textrm{target}} \frac{\partial u^{(t)}_\textrm{target}}{\partial \bar{\bar{W}}_o}
\end{aligned}
$$

Here, $\displaystyle \frac{\partial L_1}{\partial y^{(1)}_\textrm{target}} \frac{\partial y^{(1)}_\textrm{target}}{\partial u^{(1)}_\textrm{target}} \frac{\partial u^{(1)}_\textrm{target}}{\partial \bar{\bar{W}}_o}$ is the correction contributed by the loss of the first word ($L_1$), and $\displaystyle \frac{\partial L_2}{\partial y^{(2)}_\textrm{target}} \frac{\partial y^{(2)}_\textrm{target}}{\partial u^{(2)}_\textrm{target}} \frac{\partial u^{(2)}_\textrm{target}}{\partial \bar{\bar{W}}_o}$ is the correction contributed by the loss of the second word ($L_2$).
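A single term $\partial L_t/\partial \bar{\bar{W}}_o$ of this sum can be checked numerically. The sketch below is an illustration under assumptions, not the author's code: it uses the standard softmax cross-entropy result $\partial L_t/\partial u_j = y_j - 1[j=\textrm{target}]$, which also accounts for the dependence of $y_{\textrm{target}}$ on the other elements of $\bar{u}$ through the softmax denominator, and compares it with a finite-difference estimate. The weights, hidden state and helper names are hypothetical.

```python
import math

def softmax(u):
    m = max(u)
    e = [math.exp(v - m) for v in u]
    s = sum(e)
    return [v / s for v in e]

def loss_t(W_o, c, target):
    """L_t = -log softmax(W_o c)[target] for one time step."""
    u = [sum(w * x for w, x in zip(row, c)) for row in W_o]
    return -math.log(softmax(u)[target])

def grad_W_o(W_o, c, target):
    """dL_t/dW_o[j][n] = (y_j - 1[j == target]) * c_n."""
    u = [sum(w * x for w, x in zip(row, c)) for row in W_o]
    y = softmax(u)
    return [[(y[j] - (1.0 if j == target else 0.0)) * c_n for c_n in c]
            for j in range(len(W_o))]

# Hypothetical values: 3 output classes, 2 hidden units.
W_o = [[1.0, -1.0], [0.5, 0.5], [-0.7, 0.3]]
c = [0.3, -0.6]
target = 1
analytic = grad_W_o(W_o, c, target)

# Central finite differences, one weight at a time.
eps = 1e-6
max_err = 0.0
for j in range(len(W_o)):
    for n in range(len(c)):
        W_plus = [row[:] for row in W_o]
        W_minus = [row[:] for row in W_o]
        W_plus[j][n] += eps
        W_minus[j][n] -= eps
        numeric = (loss_t(W_plus, c, target) - loss_t(W_minus, c, target)) / (2 * eps)
        max_err = max(max_err, abs(numeric - analytic[j][n]))
print(max_err)   # should be close to zero
```

The full gradient $\partial L/\partial \bar{\bar{W}}_o$ is the sum of such terms over $t = 1, \dots, T$.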

Note that since $y_{\textrm{target}} = \displaystyle \frac{e^{u_{\textrm{target}}}}{\sum_{j} e^{u_j}}$, the derivative $\displaystyle \frac{\partial y_{\textrm{target}}}{\partial u_{\textrm{target}}}$ can be obtained as

$$\frac{\partial y_{\textrm{target}}}{\partial u_{\textrm{target}}} = \frac{e^{u_{\textrm{target}}}}{\sum_{j} e^{u_j}} - \frac{e^{u_{\textrm{target}}} \times e^{u_{\textrm{target}}} }{(\sum_{j} e^{u_j})^2}=y_{\textrm{target}} - y_{\textrm{target}}^2.$$
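The identity $\partial y_{\textrm{target}}/\partial u_{\textrm{target}} = y_{\textrm{target}} - y_{\textrm{target}}^2$ is easy to confirm with a finite-difference check. The vector `u` and the target index below are arbitrary illustrative values.

```python
import math

def softmax(u):
    m = max(u)
    e = [math.exp(v - m) for v in u]
    s = sum(e)
    return [v / s for v in e]

u = [0.2, -1.0, 0.7]      # arbitrary pre-softmax values
target = 2

y = softmax(u)
analytic = y[target] - y[target] ** 2

# Central difference: perturb only u[target].
eps = 1e-6
u_plus, u_minus = u[:], u[:]
u_plus[target] += eps
u_minus[target] -= eps
numeric = (softmax(u_plus)[target] - softmax(u_minus)[target]) / (2 * eps)

print(analytic, numeric)  # the two values should agree closely
```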


Consider $\partial L/ \partial \bar{\bar{W}}_c$:

Figure 3. Same as Figure 1, repeated here to help explain the chain rule.

$$\frac{\partial L}{\partial \bar{\bar{W}}_c}= \frac{\partial }{\partial \bar{\bar{W}}_c} \sum_{t=1}^T L_t = \sum_{t=1}^T \frac{\partial L_t}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{\bar{W}}_c}$$

From the chain rule,

$$
\begin{aligned}
\frac{\partial L}{\partial \bar{\bar{W}}_c} &= \sum_{t=1}^T \sum_{k=1}^t \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial u_{\textrm{target}}^{(t)}} \frac{\partial u_{\textrm{target}}^{(t)}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c} \\
&= \frac{\partial L_1}{\partial y^{(1)}_{\textrm{target}}} \frac{\partial y^{(1)}_{\textrm{target}}}{\partial u_{\textrm{target}}^{(1)}} \frac{\partial u_{\textrm{target}}^{(1)}}{\partial \bar{c}^{(1)}} \frac{\partial \bar{c}^{(1)}}{\partial \bar{\bar{W}}_c} + \sum_{k=1}^2 \frac{\partial L_2}{\partial y^{(2)}_{\textrm{target}}} \frac{\partial y^{(2)}_{\textrm{target}}}{\partial u_{\textrm{target}}^{(2)}} \frac{\partial u_{\textrm{target}}^{(2)}}{\partial \bar{c}^{(2)}} \frac{\partial \bar{c}^{(2)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c} \\
&\quad + \sum_{t=3}^T \sum_{k=1}^t \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial u_{\textrm{target}}^{(t)}} \frac{\partial u_{\textrm{target}}^{(t)}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c}
\end{aligned}
$$

Here, $\displaystyle \frac{\partial L_1}{\partial y^{(1)}_{\textrm{target}}} \frac{\partial y^{(1)}_{\textrm{target}}}{\partial u_{\textrm{target}}^{(1)}} \frac{\partial u_{\textrm{target}}^{(1)}}{\partial \bar{c}^{(1)}} \frac{\partial \bar{c}^{(1)}}{\partial \bar{\bar{W}}_c}$ is the correction contributed by the loss of the first word ($L_1$). Moreover,

$$\frac{\partial u_{\textrm{target}}^{(1)}}{\partial \bar{\bar{W}}_c} = \sum_{n_c=1}^{N_c} \frac{\partial u_{\textrm{target}}^{(1)}}{\partial \bar{c}_{n_c}^{(1)}} \frac{\partial \bar{c}_{n_c}^{(1)}}{\partial \bar{\bar{W}}_c},$$

where the sum runs over the $N_c$ components of $\bar{c}^{(1)}$.

Likewise, $\displaystyle \sum_{k=1}^2 \frac{\partial L_2}{\partial y^{(2)}_{\textrm{target}}} \frac{\partial y^{(2)}_{\textrm{target}}}{\partial u_{\textrm{target}}^{(2)}} \frac{\partial u_{\textrm{target}}^{(2)}}{\partial \bar{c}^{(2)}} \frac{\partial \bar{c}^{(2)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c}$ is the correction contributed by the loss of the second word ($L_2$).

From $\bar{c}^{(t)}=f(\bar{\bar{W}}_c \bar{c}^{(t-1)} + \bar{\bar{W}}_x \bar{x}^{(t)} + \bar{b})$, we obtain

$$\frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} = \prod_{i=k+1}^t \frac{\partial \bar{c}^{(i)}}{\partial \bar{c}^{(i-1)}} = \prod_{i=k+1}^t \bar{\bar{W}}_c^{\top} \textrm{ diag}[f'(\bar{c}^{(i-1)})],$$

where the superscript $\top$ denotes the matrix transpose (not the time index $t$) and $\textrm{diag}[\cdot]$ places the elementwise derivative of $f$ on the diagonal of a matrix.
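The product structure of $\partial \bar{c}^{(t)}/\partial \bar{c}^{(k)}$ can be checked in the simplest setting. The sketch below assumes a scalar hidden state (so the transpose and `diag` play no role), $f=\tanh$ and hypothetical weights. For the recurrence as written in this text, each factor in the code is $w_c\, f'(a^{(i)})$, with $f'$ evaluated at the pre-activation $a^{(i)} = w_x x^{(i)} + w_c c^{(i-1)} + b$ of step $i$; the product of these factors is compared with a finite-difference estimate.

```python
import math

w_x, w_c, b = 0.8, 0.5, 0.1        # hypothetical scalar parameters
xs = [1.0, 0.5, -0.3, 0.9]         # inputs x^(1), ..., x^(4)

def run(c_k, k, t):
    """Advance c^(k) to c^(t); return c^(t) and the pre-activations a^(k+1..t)."""
    c, pre = c_k, []
    for i in range(k + 1, t + 1):
        a = w_x * xs[i - 1] + w_c * c + b
        pre.append(a)
        c = math.tanh(a)
    return c, pre

k, t = 1, 4
c_k = math.tanh(w_x * xs[0] + b)   # c^(1), starting from c^(0) = 0
_, pre = run(c_k, k, t)

# Product over i = k+1..t of w_c * f'(a^(i)), using tanh'(a) = 1 - tanh(a)^2.
analytic = 1.0
for a in pre:
    analytic *= w_c * (1.0 - math.tanh(a) ** 2)

# Central difference with respect to c^(k).
eps = 1e-6
numeric = (run(c_k + eps, k, t)[0] - run(c_k - eps, k, t)[0]) / (2 * eps)

print(analytic, numeric)           # the two values should agree closely
```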

Therefore, $\displaystyle \frac{\partial L}{\partial \bar{\bar{W}}_c}$ is computed as

$$\frac{\partial L}{\partial \bar{\bar{W}}_c}=\sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}} }{\partial \bar{c}^{(t)}} \sum_{k=1}^t \Big( \prod_{i=k+1}^t \bar{\bar{W}}_c^{\top} \textrm{ diag}[f'(\bar{c}^{(i-1)})] \Big) \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c}$$
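The double sum over $t$ and $k$ can be written out directly for a toy model. The following sketch is a hedged illustration, not a production implementation: it assumes a scalar hidden state, $f=\tanh$, two output classes with $u_j = w_{o,j}\, c$, a zero initial state and hypothetical parameter values. Here $\partial \bar{c}^{(k)}/\partial \bar{\bar{W}}_c$ is taken as the immediate derivative with $c^{(k-1)}$ held fixed, and the result is compared with a finite-difference estimate of $\partial L/\partial w_c$.

```python
import math

w_x, w_c, b = 0.8, 0.5, 0.1        # hypothetical scalar parameters
w_o = [1.2, -0.7]                  # u_j = w_o[j] * c, two output classes
xs = [1.0, 0.5, -0.3]
targets = [0, 1, 0]

def forward(w_c):
    """Return the total loss, the states c^(0..T), and dL_t/dc^(t) for each t."""
    cs, dL_dc, loss = [0.0], [], 0.0
    for x, tgt in zip(xs, targets):
        c = math.tanh(w_x * x + w_c * cs[-1] + b)
        cs.append(c)
        u = [w * c for w in w_o]
        m = max(u)
        e = [math.exp(v - m) for v in u]
        s = sum(e)
        y = [v / s for v in e]
        loss -= math.log(y[tgt])
        # dL_t/dc^(t) = sum_j (y_j - 1[j == target]) * w_o[j]
        dL_dc.append(sum((y[j] - (1.0 if j == tgt else 0.0)) * w_o[j]
                         for j in range(len(w_o))))
    return loss, cs, dL_dc

loss, cs, dL_dc = forward(w_c)
T = len(xs)
# tanh'(a^(i)) = 1 - c^(i)^2; index 0 is unused padding.
fprime = [None] + [1.0 - cs[i] ** 2 for i in range(1, T + 1)]

grad = 0.0
for t in range(1, T + 1):
    for k in range(1, t + 1):
        prod = 1.0
        for i in range(k + 1, t + 1):
            prod *= w_c * fprime[i]          # dc^(i)/dc^(i-1)
        immediate = fprime[k] * cs[k - 1]    # dc^(k)/dw_c with c^(k-1) held fixed
        grad += dL_dc[t - 1] * prod * immediate

eps = 1e-6
numeric = (forward(w_c + eps)[0] - forward(w_c - eps)[0]) / (2 * eps)
print(grad, numeric)               # the two values should agree closely
```

The triple loop mirrors the formula term by term and is quadratic in $T$; practical implementations usually accumulate the same quantity in a single backward sweep over time.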


Consider $\partial L/ \partial \bar{\bar{W}}_x$:

$$
\begin{aligned}
\frac{\partial L}{\partial \bar{\bar{W}}_x} &= \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{\bar{W}}_x} \\
&= \sum_{t=1}^T \sum_{k=1}^t \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_x} \\
&= \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \sum_{k=1}^t \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_x} \\
&= \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \sum_{k=1}^t \Big(\prod_{i=k+1}^t \bar{\bar{W}}_c^{\top} \textrm{ diag}[f'(\bar{c}^{(i-1)})]\Big) \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_x}
\end{aligned}
$$

Here the term $\displaystyle \Big(\prod_{i=k+1}^t \bar{\bar{W}}_c^{\top} \textrm{ diag}[f'(\bar{c}^{(i-1)})]\Big)$ is a repeated product of $t-k$ factors. Roughly speaking, if the elements of $\bar{\bar{W}}_c$ have absolute values greater than 1, the product can become very large (gradient exploding); if their absolute values are less than 1, the product becomes very small (gradient vanishing). The effect grows with the number of factors, that is, with the distance $t-k$ between the two time steps.
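The effect of the repeated product is easiest to see in a deliberately simplified case. The sketch below assumes a scalar weight $w_c$ and a linear activation (so $f'=1$ and every factor equals $w_c$); both are assumptions made only for illustration, and `jacobian_product` is a hypothetical helper name.

```python
def jacobian_product(w_c, steps):
    """Product of `steps` identical factors w_c * f'(.) with f' = 1 (linear f assumed)."""
    prod = 1.0
    for _ in range(steps):
        prod *= w_c
    return prod

small = jacobian_product(0.5, 50)   # |w_c| < 1: the product shrinks toward zero
large = jacobian_product(1.5, 50)   # |w_c| > 1: the product grows rapidly
print(small, large)
```

With a saturating activation such as `tanh`, the factors $f'(\cdot)$ are at most 1, which pushes the product further toward zero.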
