Optimization


Machine Learning Model

Fig.1 Schematic of the machine learning model, label data and loss function $L$.

Fig.1 shows the schematic of the machine learning model, the label data $\hat{y}$, and the loss function $L$.

The machine learning model is characterized by the weights $\bar{\bar{w}}$ and biases $\bar{b}$.

$\bar{x}$ is the input data, and $\bar{y}$ is the output data. $\bar{y}$ can also be seen as the model prediction.

$\hat{y}$ is the label data. The loss function measures the distance between the model prediction $\bar{y}$ and the label data $\hat{y}$.

The goal of machine learning is to find the model (usually characterized by $\bar{\bar{w}}$ and $\bar{b}$) whose prediction $\bar{y}$ is close to $\hat{y}$.

In other words, the goal is to minimize the loss function value $L$.

Fig.2 Schematic of the machine learning framework as an optimization problem.

Fig.2 shows the schematic of the machine learning framework as an optimization problem.

$\bar{x}^{(1)},\bar{x}^{(2)},\cdots,\bar{x}^{(k)},\cdots,\bar{x}^{(K)}$ are the $K$ input data, and $\bar{y}^{(1)},\bar{y}^{(2)},\cdots,\bar{y}^{(k)},\cdots,\bar{y}^{(K)}$ are the corresponding model predictions.

$\hat{y}^{(1)},\hat{y}^{(2)},\cdots,\hat{y}^{(k)},\cdots,\hat{y}^{(K)}$ are the corresponding label data.

$L_k$ refers to the distance between the $k$th model prediction $\bar{y}^{(k)}$ and label data $\hat{y}^{(k)}$.

The loss function $L$ is defined as the sum over all $K$ data, $L=\displaystyle\sum_{k=1}^{K} L_k$.
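As an illustration, here is a minimal Python sketch of this summed loss, assuming a squared-error per-sample loss $L_k$ (the text only requires $L_k$ to measure a distance, so this choice is an assumption):

```python
import numpy as np

# Hypothetical per-sample loss L_k: squared error between prediction y^(k) and label yhat^(k).
# The text only requires L_k to measure a distance; squared error is one common choice.
def sample_loss(y_k, yhat_k):
    return (y_k - yhat_k) ** 2

# Total loss L = sum_{k=1}^{K} L_k over all K data points.
def total_loss(y, yhat):
    return sum(sample_loss(y_k, yhat_k) for y_k, yhat_k in zip(y, yhat))

y    = np.array([1.0, 2.0, 3.0])   # model predictions y^(k)
yhat = np.array([1.5, 1.8, 3.2])   # label data yhat^(k)
print(total_loss(y, yhat))         # ~0.33
```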

Machine learning can thus be seen as the following optimization problem:

Find $\bar{\bar{w}}$ and $\bar{b}$ that minimize the loss function value $L$,

$\bar{\bar{w}},\bar{b} = \displaystyle\arg\min_{\bar{\bar{w}},\bar{b}} L$

Gradient descent is applied to solve this optimization problem.

Gradient descent

The update rule based on gradient descent can be expressed as

$w^{(t)}=w^{(t-1)} - \eta \frac{\partial L}{\partial w}$,

where $\eta$ is the learning rate.
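As a sketch, this update rule maps directly to a small loop; `grad_fn`, `eta`, and `steps` are illustrative names, not from the text:

```python
def gradient_descent(w0, grad_fn, eta, steps):
    """Iterate w^(t) = w^(t-1) - eta * dL/dw for a fixed number of steps."""
    w = w0
    for _ in range(steps):
        w = w - eta * grad_fn(w)   # one gradient descent update
    return w
```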

One-dimensional example

Assuming the loss function is $L(w)=w^2$, we can derive $\frac{\partial L}{\partial w}=2w$.
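This derivative can be sanity-checked numerically; the following finite-difference sketch is illustrative only:

```python
# Numerical check of dL/dw = 2w for L(w) = w**2 (central finite difference).
def L(w):
    return w ** 2

w, h = 5.0, 1e-6
numerical = (L(w + h) - L(w - h)) / (2 * h)
print(numerical, 2 * w)   # both are approximately 10.0
```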

Fig.3 Gradient descent example 1 with $\eta=0.1$.

If $w$ is initialized as 5, $w^{(t=0)}=5$, we derive $\frac{\partial L}{\partial w}=2w^{(0)}=2\times 5=10$.

Assuming the learning rate $\eta=0.1$, $w$ at time 1, $t=1$, is updated as

$w^{(1)}=w^{(0)}-0.1 \times 2 \times 5 = 5-1 = 4$

$w$ at time 2, $t=2$, is updated as

$w^{(2)}=w^{(1)}-0.1 \times 2 \times 4 = 4-0.8 = 3.2$

$w$ at time 3, $t=3$, is updated as

$w^{(3)}=w^{(2)}-0.1 \times 2 \times 3.2 = 3.2-0.64 = 2.56$

As time goes to infinity, $w^{(t \to \infty)}=0$.
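These iterations can be reproduced with a few lines of Python (a sketch under the same assumptions, $L(w)=w^2$, $w^{(0)}=5$, $\eta=0.1$):

```python
def grad(w):
    return 2 * w           # dL/dw for L(w) = w**2

w, eta = 5.0, 0.1
for t in range(1, 4):
    w = w - eta * grad(w)  # w^(t) = w^(t-1) - eta * dL/dw
    print(t, w)
# 1 4.0
# 2 3.2
# 3 2.56  -- w keeps shrinking toward 0
```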

Considering a larger learning rate

Fig.4 Gradient descent example 1 with $\eta=0.6$.

Considering a larger learning rate $\eta=0.6$,

$w$ at time 1, $t=1$, is updated as

$w^{(1)}=w^{(0)}-0.6 \times 2 \times 5 = 5-6 = -1$

$w$ at time 2, $t=2$, is updated as

$w^{(2)}=w^{(1)}-0.6 \times 2 \times (-1) = (-1)-(-1.2) = 0.2$

$w$ at time 3, $t=3$, is updated as

$w^{(3)}=w^{(2)}-0.6 \times 2 \times 0.2 = 0.2-0.24 = -0.04$

As time goes to infinity, $w^{(t \to \infty)}=0$.

The convergence is faster with $\eta=0.6$ than with $\eta=0.1$ in this illustrative example.
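Rerunning the same sketch with $\eta=0.6$ reproduces the oscillating but still converging sequence:

```python
w, eta = 5.0, 0.6
for t in range(1, 4):
    w = w - eta * 2 * w   # same update rule, larger learning rate
    print(t, w)
# t=1: -1.0, t=2: ~0.2, t=3: ~-0.04  -- the sign flips each step, but |w| shrinks toward 0
```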

Considering an even larger learning rate

Fig.5 Gradient descent example 1 with $\eta=1.2$.

Considering an even larger learning rate $\eta=1.2$,

$w$ at time 1, $t=1$, is updated as

$w^{(1)}=w^{(0)}-1.2 \times 2 \times 5 = 5-12 = -7$

$w$ at time 2, $t=2$, is updated as

$w^{(2)}=w^{(1)}-1.2 \times 2 \times (-7) = (-7)-(-16.8) = 9.8$

$w$ at time 3, $t=3$, is updated as

$w^{(3)}=w^{(2)}-1.2 \times 2 \times 9.8 = 9.8-23.52 = -13.72$

As time goes to infinity, $w^{(t)}$ will diverge.
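With $\eta=1.2$ the same sketch shows the magnitude of $w$ growing at every step instead of shrinking:

```python
w, eta = 5.0, 1.2
for t in range(1, 4):
    w = w - eta * 2 * w   # same update rule, learning rate too large
    print(t, w)
# t=1: -7.0, t=2: ~9.8, t=3: ~-13.72  -- |w| grows each step, so the iteration diverges
```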
