Backpropagation Through Time (BPTT)
MM 0428/2018
Figure 1. RNN architecture; $\bar{x}^{(t)}$ and $\bar{y}^{(t)}$ are the input and output at time $t$.
Consider the network in Figure 1, where $\bar{x}^{(t)}$ and $\bar{y}^{(t)}$ are the input and output at time $t$.
The hidden layer is computed as
$$\bar{c}^{(t)} = f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c \bar{c}^{(t-1)} + \bar{b})$$
where $f$ is the activation function.
The output $\bar{y}$ is computed as
$$\bar{y}^{(t)} = \textrm{softmax}(\bar{u}^{(t)}) = \textrm{softmax}(\bar{\bar{W}}_o \bar{c}^{(t)}), \qquad y_j = \frac{e^{u_j}}{\sum_{k=1}^{K} e^{u_k}},$$
where $K$ is the number of elements of $\bar{u}$.
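A minimal NumPy sketch of the forward pass defined by the two equations above. It assumes $f = \tanh$ (the text leaves $f$ generic), vector inputs, and a zero initial state $\bar{c}^{(0)}$; the helper names `softmax` and `rnn_forward` and all sizes are illustrative choices, not part of the original notes.

```python
import numpy as np

def softmax(u):
    # Subtracting max(u) avoids overflow in exp and leaves the result unchanged.
    e = np.exp(u - np.max(u))
    return e / e.sum()

def rnn_forward(xs, W_x, W_c, W_o, b, c0, f=np.tanh):
    """Return hidden states c(1..T) and outputs y(1..T).

    c(t) = f(W_x x(t) + W_c c(t-1) + b),  y(t) = softmax(W_o c(t)).
    f defaults to tanh (an assumption; the text leaves f generic).
    """
    c = c0
    cs, ys = [], []
    for x in xs:
        c = f(W_x @ x + W_c @ c + b)
        cs.append(c)
        ys.append(softmax(W_o @ c))
    return cs, ys

# Arbitrary example sizes: 3-dim inputs, 4 hidden units, 5 output classes, T = 6
rng = np.random.default_rng(0)
n_x, n_c, n_y, T = 3, 4, 5, 6
W_x = rng.normal(scale=0.5, size=(n_c, n_x))
W_c = rng.normal(scale=0.5, size=(n_c, n_c))
W_o = rng.normal(scale=0.5, size=(n_y, n_c))
b = np.zeros(n_c)
xs = [rng.normal(size=n_x) for _ in range(T)]

cs, ys = rnn_forward(xs, W_x, W_c, W_o, b, c0=np.zeros(n_c))
print([round(float(y.sum()), 6) for y in ys])  # each y(t) is a probability vector
```

The same hidden state `c` is carried from one step to the next, which is exactly what makes every earlier input reach later outputs.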
Note that all inputs up to time $t$, $\bar{x}^{(t)}, \bar{x}^{(t-1)}, \bar{x}^{(t-2)}, \cdots$, affect $\bar{y}^{(t)}$:
$$\begin{aligned}
\bar{y}^{(t)} &= \textrm{softmax}(\bar{u}^{(t)}) = \textrm{softmax}(\bar{\bar{W}}_o \bar{c}^{(t)}) = \textrm{softmax}\big(\bar{\bar{W}}_o f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c \bar{c}^{(t-1)} + \bar{b})\big) \\
&= \textrm{softmax}\big(\bar{\bar{W}}_o f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c f(\bar{\bar{W}}_x \bar{x}^{(t-1)} + \bar{\bar{W}}_c \bar{c}^{(t-2)} + \bar{b}) + \bar{b})\big) \\
&= \textrm{softmax}\big(\bar{\bar{W}}_o f(\bar{\bar{W}}_x \bar{x}^{(t)} + \bar{\bar{W}}_c f(\bar{\bar{W}}_x \bar{x}^{(t-1)} + \bar{\bar{W}}_c f(\bar{\bar{W}}_x \bar{x}^{(t-2)} + \bar{\bar{W}}_c \bar{c}^{(t-3)} + \bar{b}) + \bar{b}) + \bar{b})\big)
\end{aligned}$$
since
$$\begin{aligned}
\bar{c}^{(t-1)} &= f(\bar{\bar{W}}_x \bar{x}^{(t-1)} + \bar{\bar{W}}_c \bar{c}^{(t-2)} + \bar{b}) \\
\bar{c}^{(t-2)} &= f(\bar{\bar{W}}_x \bar{x}^{(t-2)} + \bar{\bar{W}}_c \bar{c}^{(t-3)} + \bar{b})
\end{aligned}$$
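The nested expansion can be checked numerically: perturbing only $\bar{x}^{(1)}$ changes every later output, while perturbing only the last input leaves all earlier outputs untouched. This sketch assumes $f = \tanh$, random weights and NumPy; `outputs` is a hypothetical helper name.

```python
import numpy as np

def softmax(u):
    e = np.exp(u - np.max(u))
    return e / e.sum()

def outputs(xs, W_x, W_c, W_o, b, c0):
    """y(1..T) for c(t) = tanh(W_x x(t) + W_c c(t-1) + b), y(t) = softmax(W_o c(t))."""
    c, ys = c0, []
    for x in xs:
        c = np.tanh(W_x @ x + W_c @ c + b)
        ys.append(softmax(W_o @ c))
    return ys

rng = np.random.default_rng(1)
n_x, n_c, n_y, T = 3, 4, 5, 4
W_x = rng.normal(scale=0.5, size=(n_c, n_x))
W_c = rng.normal(scale=0.5, size=(n_c, n_c))
W_o = rng.normal(scale=0.5, size=(n_y, n_c))
b, c0 = np.zeros(n_c), np.zeros(n_c)
xs = [rng.normal(size=n_x) for _ in range(T)]
base = outputs(xs, W_x, W_c, W_o, b, c0)

# Perturb only the first input x(1): the change reaches y(1), ..., y(T)
xs_first = [x.copy() for x in xs]
xs_first[0] += 1.0
changes = [float(np.max(np.abs(p - q)))
           for p, q in zip(base, outputs(xs_first, W_x, W_c, W_o, b, c0))]
print(changes)

# Perturb only the last input x(T): y(1), ..., y(T-1) are bit-for-bit unchanged
xs_last = [x.copy() for x in xs]
xs_last[-1] += 1.0
after_last = outputs(xs_last, W_x, W_c, W_o, b, c0)
print([np.array_equal(p, q) for p, q in zip(base, after_last)])
```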
Figure 2. All inputs up to time $t$, $\bar{x}^{(t)}, \bar{x}^{(t-1)}, \bar{x}^{(t-2)}, \cdots$, affect $\bar{y}^{(t)}$.
Figure 2 shows that $\bar{x}^{(t)}, \bar{x}^{(t-1)}, \bar{x}^{(t-2)}, \cdots$ all influence $\bar{y}^{(t)}$.
The loss function is defined as
$$L = \sum_{t=1}^T L_t = -\sum_{t=1}^T \log y^{(t)}_{\textrm{target}}$$
where $y^{(t)}_{\textrm{target}}$ is the element of $\bar{y}^{(t)}$ that corresponds to the target (correct) word at time $t$.
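In code, the loss only needs the probability each output assigns to its target word. A short sketch, assuming the outputs are NumPy probability vectors and the targets are integer class indices (`sequence_loss` is a hypothetical name):

```python
import numpy as np

def sequence_loss(ys, targets):
    """L = sum_t L_t, with L_t = -log y(t)[target_t]."""
    return -sum(np.log(y[k]) for y, k in zip(ys, targets))

ys = [np.array([0.5, 0.3, 0.2]), np.array([0.1, 0.1, 0.8])]
targets = [0, 2]
L = sequence_loss(ys, targets)
print(round(float(L), 4))  # 0.9163 = -(log 0.5 + log 0.8)
```

Each $L_t$ is zero only when $y^{(t)}_{\textrm{target}} = 1$ and grows without bound as $y^{(t)}_{\textrm{target}}$ approaches 0.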
Consider $\partial L / \partial \bar{\bar{W}}_o$. Because the softmax makes $y^{(t)}_{\textrm{target}}$ depend on every element $u^{(t)}_j$ of $\bar{u}^{(t)}$, the chain rule sums over $j$:
$$\begin{aligned}
\frac{\partial L}{\partial \bar{\bar{W}}_o} &= \sum_{t=1}^T \frac{\partial L_t}{\partial \bar{\bar{W}}_o} = \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \sum_{j} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial u^{(t)}_j} \frac{\partial u^{(t)}_j}{\partial \bar{\bar{W}}_o} \\
&= \frac{\partial L_1}{\partial y^{(1)}_{\textrm{target}}} \sum_{j} \frac{\partial y^{(1)}_{\textrm{target}}}{\partial u^{(1)}_j} \frac{\partial u^{(1)}_j}{\partial \bar{\bar{W}}_o} + \frac{\partial L_2}{\partial y^{(2)}_{\textrm{target}}} \sum_{j} \frac{\partial y^{(2)}_{\textrm{target}}}{\partial u^{(2)}_j} \frac{\partial u^{(2)}_j}{\partial \bar{\bar{W}}_o} + \sum_{t=3}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \sum_{j} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial u^{(t)}_j} \frac{\partial u^{(t)}_j}{\partial \bar{\bar{W}}_o}
\end{aligned}$$
Here, the first term is the correction contributed by the loss of the first word ($L_1$), and the second term is the correction contributed by the loss of the second word ($L_2$).
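Each word's contribution can be computed directly. Using the standard softmax/cross-entropy result $\partial L_t/\partial u^{(t)}_j = y^{(t)}_j - \delta_{j,\textrm{target}}$ and $\partial u^{(t)}_j/\partial (\bar{\bar{W}}_o)_{jq} = c^{(t)}_q$, word $t$ contributes the outer product $(\bar{y}^{(t)} - \bar{e}_{\textrm{target}})\,\bar{c}^{(t)\top}$, where $\bar{e}_{\textrm{target}}$ is the one-hot vector of the target. The sketch below assumes NumPy, $f = \tanh$ and hypothetical helper names, and compares the result with a central finite-difference estimate.

```python
import numpy as np

def softmax(u):
    e = np.exp(u - np.max(u))
    return e / e.sum()

def forward(xs, W_x, W_c, W_o, b, c0):
    """Hidden states c(1..T) and outputs y(1..T), with f = tanh (assumed)."""
    c, cs, ys = c0, [], []
    for x in xs:
        c = np.tanh(W_x @ x + W_c @ c + b)
        cs.append(c)
        ys.append(softmax(W_o @ c))
    return cs, ys

def loss(xs, targets, W_x, W_c, W_o, b, c0):
    _, ys = forward(xs, W_x, W_c, W_o, b, c0)
    return -sum(np.log(y[k]) for y, k in zip(ys, targets))

def grad_W_o(xs, targets, W_x, W_c, W_o, b, c0):
    """dL/dW_o = sum_t (y(t) - onehot(target_t)) c(t)^T."""
    cs, ys = forward(xs, W_x, W_c, W_o, b, c0)
    g = np.zeros_like(W_o)
    for c, y, k in zip(cs, ys, targets):
        du = y.copy()
        du[k] -= 1.0            # dL_t/du(t) = y(t) - onehot(target)
        g += np.outer(du, c)    # contribution of the t-th word's loss
    return g

def numerical_grad(fn, W, eps=1e-6):
    """Central differences of scalar fn() w.r.t. each entry of W (modified in place, then restored)."""
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        old = W[idx]
        W[idx] = old + eps
        f_plus = fn()
        W[idx] = old - eps
        f_minus = fn()
        W[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

rng = np.random.default_rng(0)
n_x, n_c, n_y, T = 3, 4, 5, 6
W_x = rng.normal(scale=0.5, size=(n_c, n_x))
W_c = rng.normal(scale=0.5, size=(n_c, n_c))
W_o = rng.normal(scale=0.5, size=(n_y, n_c))
b, c0 = np.zeros(n_c), np.zeros(n_c)
xs = [rng.normal(size=n_x) for _ in range(T)]
targets = rng.integers(0, n_y, size=T)

analytic = grad_W_o(xs, targets, W_x, W_c, W_o, b, c0)
numeric = numerical_grad(lambda: loss(xs, targets, W_x, W_c, W_o, b, c0), W_o)
print(np.max(np.abs(analytic - numeric)))  # only finite-difference error remains
```

Since $\bar{c}^{(t)}$ does not depend on $\bar{\bar{W}}_o$, no sum over earlier steps is needed here, unlike for $\bar{\bar{W}}_c$ and $\bar{\bar{W}}_x$.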
Note that $y_{\textrm{target}} = e^{u_{\textrm{target}}} / \sum_{j} e^{u_j}$, so $\partial y_{\textrm{target}}/\partial u_{\textrm{target}}$ can be obtained as
$$\frac{\partial y_{\textrm{target}}}{\partial u_{\textrm{target}}} = \frac{e^{u_{\textrm{target}}}}{\sum_{j} e^{u_j}} - \frac{e^{u_{\textrm{target}}} \times e^{u_{\textrm{target}}}}{\big(\sum_{j} e^{u_j}\big)^2} = y_{\textrm{target}} - y_{\textrm{target}}^2.$$
For $j \neq \textrm{target}$ only the denominator depends on $u_j$, so $\partial y_{\textrm{target}}/\partial u_j = -y_{\textrm{target}}\, y_j$. Combining both cases with $\partial L_t/\partial y^{(t)}_{\textrm{target}} = -1/y^{(t)}_{\textrm{target}}$ gives $\partial L_t/\partial u^{(t)}_j = y^{(t)}_j - \delta_{j,\textrm{target}}$.
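The derivative can be checked numerically. This sketch (NumPy; the score vector `u` and the helper `softmax_jacobian` are illustrative assumptions) builds the full softmax Jacobian $\partial y_i/\partial u_j = y_i(\delta_{ij} - y_j)$, compares its target row with finite differences, and confirms that multiplying by $\partial L_t/\partial y_{\textrm{target}} = -1/y_{\textrm{target}}$ yields $\bar{y} - \bar{e}_{\textrm{target}}$.

```python
import numpy as np

def softmax(u):
    e = np.exp(u - np.max(u))
    return e / e.sum()

def softmax_jacobian(u):
    """J[i, j] = dy_i/du_j = y_i (delta_ij - y_j)."""
    y = softmax(u)
    return np.diag(y) - np.outer(y, y)

u = np.array([1.0, -0.5, 2.0, 0.3])   # arbitrary example scores
target = 2
y = softmax(u)
J = softmax_jacobian(u)

# Diagonal entry: dy_target/du_target = y_target - y_target^2
print(J[target, target], y[target] - y[target] ** 2)

# Finite-difference estimate of the whole row dy_target/du_j
eps = 1e-6
numeric_row = np.array([
    (softmax(u + eps * e)[target] - softmax(u - eps * e)[target]) / (2 * eps)
    for e in np.eye(len(u))
])

# With dL_t/dy_target = -1/y_target, the gradient w.r.t. u is y - onehot(target)
grad_u = -J[target] / y[target]
onehot = np.eye(len(u))[target]
print(grad_u, y - onehot)
```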
Consider $\partial L / \partial \bar{\bar{W}}_c$.
Figure 3. Same as Figure 1, used to illustrate the chain rule.
$$\frac{\partial L}{\partial \bar{\bar{W}}_c} = \frac{\partial}{\partial \bar{\bar{W}}_c} \sum_{t=1}^T L_t = \sum_{t=1}^T \frac{\partial L_t}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{\bar{W}}_c}$$
Because $\bar{c}^{(t)}$ depends on $\bar{\bar{W}}_c$ both directly and through every earlier state $\bar{c}^{(k)}$ ($k < t$), the chain rule gives
$$\begin{aligned}
\frac{\partial L}{\partial \bar{\bar{W}}_c} &= \sum_{t=1}^T \sum_{k=1}^t \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c} \\
&= \frac{\partial L_1}{\partial y^{(1)}_{\textrm{target}}} \frac{\partial y^{(1)}_{\textrm{target}}}{\partial \bar{c}^{(1)}} \frac{\partial \bar{c}^{(1)}}{\partial \bar{\bar{W}}_c} + \sum_{k=1}^2 \frac{\partial L_2}{\partial y^{(2)}_{\textrm{target}}} \frac{\partial y^{(2)}_{\textrm{target}}}{\partial \bar{c}^{(2)}} \frac{\partial \bar{c}^{(2)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c} \\
&\quad + \sum_{t=3}^T \sum_{k=1}^t \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c}
\end{aligned}$$
Inside the double sum, $\partial \bar{c}^{(k)}/\partial \bar{\bar{W}}_c$ denotes the immediate derivative at step $k$, treating $\bar{c}^{(k-1)}$ as a constant, and $\partial \bar{c}^{(t)}/\partial \bar{c}^{(t)}$ is the identity.
Here, the $t=1$ term is the correction contributed by the loss of the first word ($L_1$). Its dependence on $\bar{\bar{W}}_c$ passes through every one of the $N_c$ hidden units:
$$\frac{\partial u^{(1)}_{\textrm{target}}}{\partial \bar{\bar{W}}_c} = \sum_{n_c=1}^{N_c} \frac{\partial u^{(1)}_{\textrm{target}}}{\partial \bar{c}^{(1)}_{n_c}} \frac{\partial \bar{c}^{(1)}_{n_c}}{\partial \bar{\bar{W}}_c}$$
Likewise, the $t=2$ term (the sum over $k = 1, 2$) is the correction contributed by the loss of the second word ($L_2$).
From $\bar{c}^{(t)} = f(\bar{\bar{W}}_c \bar{c}^{(t-1)} + \bar{\bar{W}}_x \bar{x}^{(t)} + \bar{b})$, and writing $\bar{a}^{(i)} = \bar{\bar{W}}_x \bar{x}^{(i)} + \bar{\bar{W}}_c \bar{c}^{(i-1)} + \bar{b}$ for the pre-activation at step $i$, we obtain
$$\frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} = \prod_{i=k+1}^t \frac{\partial \bar{c}^{(i)}}{\partial \bar{c}^{(i-1)}} = \prod_{i=k+1}^t \textrm{diag}[f'(\bar{a}^{(i)})]\, \bar{\bar{W}}_c$$
where the factors are ordered with $i = t$ on the left and $i = k+1$ on the right, and the product is the identity when $k = t$.
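The product formula can be verified against a finite-difference Jacobian. This sketch assumes NumPy and $f = \tanh$, for which $f'(\bar{a}^{(i)}) = 1 - (\bar{c}^{(i)})^2$; the helpers `step`, `run_from` and `jacobian_product`, the sizes, and the choice $k = 2$, $t = T = 5$ are all illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
n_x, n_c, T = 3, 4, 5
W_x = rng.normal(scale=0.5, size=(n_c, n_x))
W_c = rng.normal(scale=0.5, size=(n_c, n_c))
b = rng.normal(scale=0.1, size=n_c)
xs = [rng.normal(size=n_x) for _ in range(T)]   # xs[i - 1] holds x(i)

def step(c_prev, i):
    """c(i) = tanh(a(i)), a(i) = W_x x(i) + W_c c(i-1) + b  (f = tanh assumed)."""
    return np.tanh(W_x @ xs[i - 1] + W_c @ c_prev + b)

def run_from(c_k, k, t):
    """Start from c(k) and apply steps k+1, ..., t; returns c(t)."""
    c = c_k
    for i in range(k + 1, t + 1):
        c = step(c, i)
    return c

def jacobian_product(cs, t, k):
    """dc(t)/dc(k) = diag[f'(a(t))] W_c ... diag[f'(a(k+1))] W_c."""
    J = np.eye(n_c)                      # empty product when k == t
    for i in range(k + 1, t + 1):
        fprime = 1.0 - cs[i] ** 2        # tanh'(a(i)) = 1 - c(i)^2
        J = (fprime[:, None] * W_c) @ J  # newer steps multiply on the left
    return J

cs = [np.zeros(n_c)]                     # cs[i] = c(i), starting from c(0) = 0
for i in range(1, T + 1):
    cs.append(step(cs[-1], i))

k = 2
analytic = jacobian_product(cs, T, k)

# Finite differences: column j is dc(T)/dc(k)_j
eps = 1e-6
numeric = np.column_stack([
    (run_from(cs[k] + eps * e, k, T) - run_from(cs[k] - eps * e, k, T)) / (2 * eps)
    for e in np.eye(n_c)
])
print(np.max(np.abs(analytic - numeric)))
```

Note that `fprime[:, None] * W_c` is the same as `np.diag(fprime) @ W_c`, just without building the diagonal matrix.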
Therefore, $\partial L / \partial \bar{\bar{W}}_c$ is computed as
$$\frac{\partial L}{\partial \bar{\bar{W}}_c} = \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \sum_{k=1}^t \Big(\prod_{i=k+1}^t \textrm{diag}[f'(\bar{a}^{(i)})]\, \bar{\bar{W}}_c\Big) \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_c}$$
where $\bar{a}^{(i)} = \bar{\bar{W}}_x \bar{x}^{(i)} + \bar{\bar{W}}_c \bar{c}^{(i-1)} + \bar{b}$ is the pre-activation at step $i$.
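A literal, deliberately unoptimized translation of this double sum, checked against finite differences. It assumes NumPy and $f = \tanh$; for each $t$ it walks $k$ from $t$ down to 1, extending the Jacobian product one factor at a time, and adds the immediate derivative of step $k$, whose $(p, q)$ entry is $f'(a^{(k)}_p)\, c^{(k-1)}_q$. Helper names are hypothetical.

```python
import numpy as np

def softmax(u):
    e = np.exp(u - np.max(u))
    return e / e.sum()

def forward(xs, W_x, W_c, W_o, b, c0):
    cs, ys = [c0], []                    # cs[t] = c(t), with cs[0] = c(0)
    for x in xs:
        cs.append(np.tanh(W_x @ x + W_c @ cs[-1] + b))
        ys.append(softmax(W_o @ cs[-1]))
    return cs, ys

def loss(xs, targets, W_x, W_c, W_o, b, c0):
    _, ys = forward(xs, W_x, W_c, W_o, b, c0)
    return -sum(np.log(y[k]) for y, k in zip(ys, targets))

def grad_W_c_double_sum(xs, targets, W_x, W_c, W_o, b, c0):
    """O(T^2) evaluation of the double sum over t and k, with f = tanh."""
    cs, ys = forward(xs, W_x, W_c, W_o, b, c0)
    total = np.zeros_like(W_c)
    for t in range(1, len(xs) + 1):
        du = ys[t - 1].copy()
        du[targets[t - 1]] -= 1.0        # dL_t/du(t) = y(t) - onehot(target)
        g = W_o.T @ du                   # dL_t/dc(t), i.e. through dy_target/dc(t)
        J = np.eye(W_c.shape[0])         # dc(t)/dc(k), starting at k = t
        for k in range(t, 0, -1):
            v = J.T @ g                  # dL_t/dc(k) along the path c(k) -> c(t)
            fprime = 1.0 - cs[k] ** 2    # tanh'(a(k))
            # immediate dc(k)/dW_c: entry (p, q) is f'(a_p(k)) * c_q(k-1)
            total += np.outer(v * fprime, cs[k - 1])
            J = J @ (fprime[:, None] * W_c)  # extend to dc(t)/dc(k-1)
    return total

def numerical_grad(fn, W, eps=1e-6):
    """Central differences of scalar fn() w.r.t. each entry of W (modified in place, then restored)."""
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        old = W[idx]
        W[idx] = old + eps
        f_plus = fn()
        W[idx] = old - eps
        f_minus = fn()
        W[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

rng = np.random.default_rng(0)
n_x, n_c, n_y, T = 3, 4, 5, 6
W_x = rng.normal(scale=0.5, size=(n_c, n_x))
W_c = rng.normal(scale=0.5, size=(n_c, n_c))
W_o = rng.normal(scale=0.5, size=(n_y, n_c))
b = rng.normal(scale=0.1, size=n_c)
c0 = np.zeros(n_c)
xs = [rng.normal(size=n_x) for _ in range(T)]
targets = rng.integers(0, n_y, size=T)

analytic = grad_W_c_double_sum(xs, targets, W_x, W_c, W_o, b, c0)
numeric = numerical_grad(lambda: loss(xs, targets, W_x, W_c, W_o, b, c0), W_c)
print(np.max(np.abs(analytic - numeric)))
```

The cost grows as $O(T^2)$ because every $t$ rebuilds its own chain of products; this version only mirrors the formula term by term.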
Consider $\partial L / \partial \bar{\bar{W}}_x$:
$$\begin{aligned}
\frac{\partial L}{\partial \bar{\bar{W}}_x} &= \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{\bar{W}}_x} \\
&= \sum_{t=1}^T \sum_{k=1}^t \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_x} \\
&= \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \sum_{k=1}^t \frac{\partial \bar{c}^{(t)}}{\partial \bar{c}^{(k)}} \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_x} \\
&= \sum_{t=1}^T \frac{\partial L_t}{\partial y^{(t)}_{\textrm{target}}} \frac{\partial y^{(t)}_{\textrm{target}}}{\partial \bar{c}^{(t)}} \sum_{k=1}^t \Big(\prod_{i=k+1}^t \textrm{diag}[f'(\bar{a}^{(i)})]\, \bar{\bar{W}}_c\Big) \frac{\partial \bar{c}^{(k)}}{\partial \bar{\bar{W}}_x}
\end{aligned}$$
where $\bar{a}^{(i)} = \bar{\bar{W}}_x \bar{x}^{(i)} + \bar{\bar{W}}_c \bar{c}^{(i-1)} + \bar{b}$ is the pre-activation at step $i$, and $\partial \bar{c}^{(k)}/\partial \bar{\bar{W}}_x$ inside the sum over $k$ is the immediate derivative at step $k$, holding $\bar{c}^{(k-1)}$ fixed.
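In practice the double sum is not evaluated term by term. Sweeping backwards from $t = T$ and carrying the gradient that reaches $\bar{c}^{(t)}$ from later steps computes the same sums in $O(T)$ work, because the products for different $t$ share their factors instead of being rebuilt. A sketch under the same assumptions (NumPy, $f = \tanh$, hypothetical helper names), checking both $\partial L/\partial \bar{\bar{W}}_x$ and $\partial L/\partial \bar{\bar{W}}_c$ against finite differences:

```python
import numpy as np

def softmax(u):
    e = np.exp(u - np.max(u))
    return e / e.sum()

def forward(xs, W_x, W_c, W_o, b, c0):
    cs, ys = [c0], []                    # cs[t] = c(t), with cs[0] = c(0)
    for x in xs:
        cs.append(np.tanh(W_x @ x + W_c @ cs[-1] + b))
        ys.append(softmax(W_o @ cs[-1]))
    return cs, ys

def loss(xs, targets, W_x, W_c, W_o, b, c0):
    _, ys = forward(xs, W_x, W_c, W_o, b, c0)
    return -sum(np.log(y[k]) for y, k in zip(ys, targets))

def bptt(xs, targets, W_x, W_c, W_o, b, c0):
    """Return (dL/dW_x, dL/dW_c) from one backward sweep, with f = tanh."""
    cs, ys = forward(xs, W_x, W_c, W_o, b, c0)
    gW_x, gW_c = np.zeros_like(W_x), np.zeros_like(W_c)
    from_future = np.zeros_like(c0)      # gradient reaching c(t) from steps t+1..T
    for t in range(len(xs), 0, -1):
        du = ys[t - 1].copy()
        du[targets[t - 1]] -= 1.0        # dL_t/du(t)
        dc = W_o.T @ du + from_future    # total dL/dc(t)
        da = dc * (1.0 - cs[t] ** 2)     # dL/da(t), using tanh'(a(t)) = 1 - c(t)^2
        gW_x += np.outer(da, xs[t - 1])  # immediate dc(t)/dW_x involves x(t)
        gW_c += np.outer(da, cs[t - 1])  # immediate dc(t)/dW_c involves c(t-1)
        from_future = W_c.T @ da         # pass the gradient back to c(t-1)
    return gW_x, gW_c

def numerical_grad(fn, W, eps=1e-6):
    """Central differences of scalar fn() w.r.t. each entry of W (modified in place, then restored)."""
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        old = W[idx]
        W[idx] = old + eps
        f_plus = fn()
        W[idx] = old - eps
        f_minus = fn()
        W[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

rng = np.random.default_rng(3)
n_x, n_c, n_y, T = 3, 4, 5, 6
W_x = rng.normal(scale=0.5, size=(n_c, n_x))
W_c = rng.normal(scale=0.5, size=(n_c, n_c))
W_o = rng.normal(scale=0.5, size=(n_y, n_c))
b = rng.normal(scale=0.1, size=n_c)
c0 = np.zeros(n_c)
xs = [rng.normal(size=n_x) for _ in range(T)]
targets = rng.integers(0, n_y, size=T)

gW_x, gW_c = bptt(xs, targets, W_x, W_c, W_o, b, c0)
num_x = numerical_grad(lambda: loss(xs, targets, W_x, W_c, W_o, b, c0), W_x)
num_c = numerical_grad(lambda: loss(xs, targets, W_x, W_c, W_o, b, c0), W_c)
print(np.max(np.abs(gW_x - num_x)), np.max(np.abs(gW_c - num_c)))
```

The variable `from_future` plays the role of the partial sums of the Jacobian products: multiplying by `W_c.T` after scaling by $f'$ is one factor $\textrm{diag}[f'(\bar{a}^{(i)})]\,\bar{\bar{W}}_c$, applied once per step rather than once per $(t, k)$ pair.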
The factor $\prod_{i=k+1}^{t} \partial \bar{c}^{(i)}/\partial \bar{c}^{(i-1)}$ in this expression is a product of $t-k$ Jacobians, each containing $\bar{\bar{W}}_c$, so it behaves roughly like the $(t-k)$-th power of $\bar{\bar{W}}_c$ scaled by $f'$. What matters is the size of $\bar{\bar{W}}_c$ as a matrix, measured by its largest singular value, rather than the size of its individual elements. If that singular value times $\max|f'|$ is below 1, the product shrinks exponentially as $t-k$ grows, so the contribution of early steps $k$ to the gradient of $L_t$ all but disappears (gradient vanishing); if it is above 1, the product can grow exponentially (gradient explosion).
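A small numerical illustration. To isolate the role of $\bar{\bar{W}}_c$, this sketch assumes $f' = 1$ (a linear recurrence), so the product reduces to $\bar{\bar{W}}_c^{\,t-k}$; with $f = \tanh$, $|f'| \le 1$, so $\|\bar{\bar{W}}_c\|^{t-k}$ remains an upper bound. It uses a scaled random orthogonal matrix $\gamma \bar{\bar{Q}}$, whose powers have spectral norm exactly $\gamma^n$, and a matrix whose entries are all 0.6: every element is below 1, yet its largest singular value is 2.4, so its powers still blow up. NumPy and the helper `product_norms` are assumptions of the sketch.

```python
import numpy as np

def product_norms(W_c, steps):
    """Spectral norms of W_c^n for n = 1..steps (the f' = 1 case of the product)."""
    J = np.eye(W_c.shape[0])
    norms = []
    for _ in range(steps):
        J = W_c @ J
        norms.append(np.linalg.norm(J, 2))   # ord=2: largest singular value
    return norms

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(4, 4)))  # random orthogonal matrix: all singular values 1

for gamma in (0.9, 1.0, 1.1):
    print(gamma, product_norms(gamma * Q, 50)[-1])  # gamma ** 50 up to rounding

# Every entry is 0.6 (< 1), but the largest singular value is 4 * 0.6 = 2.4
all_small = np.full((4, 4), 0.6)
print(np.linalg.norm(all_small, 2), product_norms(all_small, 10)[-1])  # about 2.4 and 2.4 ** 10
```

After 50 steps the factor is about 0.005 for $\gamma = 0.9$ and about 117 for $\gamma = 1.1$, even though the two matrices differ by only about 20%.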