ใน ep ที่แล้วเราได้เรียนรู้ถึงปัญหา Vanishing Gradient Problem และวิธีแก้ไขกันไปแล้ว ใน ep นี้เราจะเจาะลึกลงไปถึงสาเหตุ ดูตัวอย่างของ Neural Network ว่าเมื่อเกิดปัญหา Vanishing Gradient Problem และ Exploding Gradient Problem จะมีอาการอย่างไร และเราจะแก้ไขอย่างไรให้โมเดลสามารถเทรนได้ต่อ

Vanishing Gradient from Tensorboard
Vanishing Gradient from Tensorboard

เรามาเริ่มกันเลยดีกว่า

Open In Colab

ในการเทรน Neural Network แล้ว Operation ส่วนใหญ่ก็คือ การคุณเมตริกซ์ Activation ของ Layer ก่อนหน้า กับ Weight ของ Layer นั้น

เราได้ Normalize ข้อมูล Input ให้มีขนาด mean = 0, std = 1 เรียบร้อยแล้ว แล้ว Weight เราควร Initialize อย่างไร ให้ไม่เกิดปัญหา Vanishing Gradient และ Exploding Gradient

0. Import

In [1]:
import torch, math

1. Vanishing Gradient

สมมติ Weight เรา Inital ไว้น้อยเกินไป จะทำให้เกิด Vanishing Gradient คือ Gradient น้อยลงจนโมเดลเทรนไม่ไปไหน

In [2]:
x = torch.randn(100, 100)
a = torch.randn(100, 100) * 0.01
In [3]:
x.mean(), x.std()
Out[3]:
(tensor(0.0084), tensor(1.0040))
In [4]:
a.mean(), a.std()
Out[4]:
(tensor(3.0860e-05), tensor(0.0099))
In [5]:
for i in range(50):
    x = x @ a 
    print(f'{i}, {x.mean()}, {x.std()}')
0, -0.00012308883015066385, 0.10081212222576141
1, 6.062877218937501e-05, 0.009958631359040737
2, -5.713999598810915e-06, 0.0009827547473832965
3, -2.6385391720396e-06, 9.923477045958862e-05
4, 1.8872927398660977e-07, 9.949823834176641e-06
5, 6.353400916481178e-09, 1.009312995847722e-06
6, -2.987314395852536e-09, 1.0363645230881957e-07
7, 2.2307683478217655e-10, 1.0457650745365754e-08
8, 4.485252447228305e-12, 1.082509304417556e-09
9, -2.21616366750943e-12, 1.1044364034429321e-10
10, -9.716613093551513e-14, 1.132850757645798e-11
11, -1.1563880143164104e-14, 1.1714003248300409e-12
12, 2.821596555103241e-15, 1.2018770039531196e-13
13, 4.366935600467934e-17, 1.2278859806908512e-14
14, -4.1292419427283714e-17, 1.2419247894619387e-15
15, 5.672071035082906e-19, 1.2641872770052847e-16
16, 3.116762522866051e-19, 1.2977833499330239e-17
17, -1.6383184448665015e-20, 1.334556346488074e-18
18, -1.7409013616233646e-21, 1.3774881059541302e-19
19, -6.945157284203479e-23, 1.4424008783420303e-20
20, 2.6377699023674058e-23, 1.4907215345141426e-21
21, 4.8852504182816955e-25, 1.5494043545172308e-22
22, -3.251581606365266e-25, 1.626404783485928e-23
23, 3.722433159465772e-27, 1.724377166394625e-24
24, 4.58402950533849e-27, 1.806224985881572e-25
25, -1.7937884001548034e-28, 1.8902267469293942e-26
26, -3.12939541838621e-29, 1.995416533815755e-27
27, -9.38149722210471e-31, 2.1155981734641644e-28
28, 5.380486848515984e-31, 2.2131926270954216e-29
29, 1.5259121555382885e-32, 2.3212095741167758e-30
30, -4.713272801974598e-33, 2.4525643439934797e-31
31, -2.0899534210664372e-34, 2.6006014005586353e-32
32, 6.862043552252367e-35, 2.726543534697949e-33
33, -2.001314063843063e-37, 2.8853367502822124e-34
34, -3.5140834668234994e-37, 3.07595704938684e-35
35, -1.7158627843755306e-38, 3.296727860235098e-36
36, 6.278009098064793e-39, 3.487663273473242e-37
37, 8.119543691837288e-41, 3.720382122273952e-38
38, -4.1370534562261574e-41, 3.993968271332415e-39
39, -2.948331968939415e-42, 4.298609156178166e-40
40, 9.262582849187041e-43, 4.5652902669238215e-41
41, 2.802596928649634e-45, 4.8681108650644145e-42
42, -5.605193857299268e-45, 5.226843271931568e-43
43, -0.0, 5.605193857299268e-44
44, 0.0, 5.605193857299268e-45
45, -0.0, 0.0
46, 0.0, 0.0
47, 0.0, 0.0
48, 0.0, 0.0
49, 0.0, 0.0

ผลลัพธ์น้อยลง ๆ ๆ จนหายไปหมด กลายเป็น 0 เรียกว่า Vanishing Gradient

2. Exploding Gradient

สมมติ Weight เรา Inital ไว้มากเกินไป จะทำให้เกิด Exploding Gradient คือ Gradient มากจนโมเดล Error

In [6]:
x = torch.randn(100, 100)
a = torch.randn(100, 100) 
In [7]:
x.mean(), x.std()
Out[7]:
(tensor(0.0052), tensor(1.0035))
In [8]:
a.mean(), a.std()
Out[8]:
(tensor(0.0036), tensor(1.0041))
In [9]:
for i in range(50):
    x = x @ a 
    print(f'{i}, {x.mean()}, {x.std()}')
0, 0.06944113224744797, 10.123376846313477
1, 0.046876031905412674, 102.4043960571289
2, 0.909347653388977, 1025.0296630859375
3, 135.59620666503906, 10326.9931640625
4, -453.7614440917969, 103001.0625
5, -21130.669921875, 1014882.625
6, -71946.9296875, 9843137.0
7, -24773.82421875, 98388896.0
8, 4196430.0, 977391168.0
9, -90994256.0, 9854563328.0
10, -93759464.0, 99021996032.0
11, 10142417920.0, 979634028544.0
12, -3620067328.0, 9672285224960.0
13, 297260023808.0, 94938733740032.0
14, -6461432791040.0, 953410731900928.0
15, -61434331398144.0, 9274541641564160.0
16, -407568740515840.0, 9.12780451339305e+16
17, -1256197269225472.0, 9.051564548921754e+17
18, -6.097315962028032e+16, 8.988005080254906e+18
19, -5.3255419498961306e+17, 9.200909454103963e+19
20, -5.971085498809057e+17, 9.268659953232637e+20
21, 5.46994170532515e+18, 9.409074589789678e+21
22, -2.1198788252879395e+20, 9.562251695855747e+22
23, -4.3926077108665083e+20, 9.528083425914842e+23
24, -3.3611557472967265e+22, 9.645982331897443e+24
25, -6.202185189164896e+23, 9.829271948025432e+25
26, -7.887051273284852e+24, 9.933428536958582e+26
27, -3.222442814323646e+25, 1.0294414199902577e+28
28, -4.646402052904346e+26, 1.0380755115255476e+29
29, 1.9213426175161542e+27, 1.0611449140362075e+30
30, 6.055577904763679e+28, 1.095920390592914e+31
31, 5.816875407233441e+29, 1.107372085490312e+32
32, 1.353650296268542e+28, 1.1467829318100571e+33
33, -4.350869864916332e+31, 1.1785441713764834e+34
34, -6.343091799269704e+32, 1.2036507837956287e+35
35, -5.184847244863418e+33, 1.2621429752828878e+36
36, nan, 1.2930884180579493e+37
37, nan, nan
38, nan, nan
39, nan, nan
40, nan, nan
41, nan, nan
42, nan, nan
43, nan, nan
44, nan, nan
45, nan, nan
46, nan, nan
47, nan, nan
48, nan, nan
49, nan, nan

ผลลัพธ์มากขึ้น ๆ ๆ ๆ จนเกินค่ามากที่สุด ที่ระบบรับไหว เหมือนกับระเบิดออก กลายเป็น Infinity (inf) หรือ Not a number (nan) เรียกว่า Exploding Gradient

3. Kaiming Initialization

การ Initialize Weight ที่เหมาะสม ด้วย Kaiming Initialization จะช่วยให้เราเทรนโมเดล เร็วขึ้น และเทรนได้นานขึ้นตามที่เราต้องการโดยไม่ Error ไปเสียก่อน

In [10]:
x = torch.randn(100, 100)
a = torch.randn(100, 100) * math.sqrt(1./100.)

ในเคสนี้ เป็น Kaiming Initialization เวอร์ชัน สำหรับไม่มี Activation Function

In [11]:
x.mean(), x.std()
Out[11]:
(tensor(0.0068), tensor(0.9893))
In [12]:
a.mean(), a.std()
Out[12]:
(tensor(-0.0003), tensor(0.0997))
In [13]:
for i in range(50):
    x = x @ a 
    print(f'{i}, {x.mean()}, {x.std()}')
0, -0.0002966029569506645, 0.9761020541191101
1, -0.0003545041545294225, 0.9636297225952148
2, 0.008720957674086094, 0.9397486448287964
3, 0.0005875895731151104, 0.9026142954826355
4, 0.009837822057306767, 0.8824164271354675
5, -0.004870842210948467, 0.8637325167655945
6, 0.009631485678255558, 0.8482977151870728
7, -0.00699358806014061, 0.8444557189941406
8, 0.007382232695817947, 0.8386370539665222
9, 0.004917025100439787, 0.825736403465271
10, 0.009946335107088089, 0.8209970593452454
11, -0.014773333445191383, 0.8316162824630737
12, -0.010095912963151932, 0.830187976360321
13, -0.00988482404500246, 0.828865647315979
14, 0.014411764219403267, 0.8282093405723572
15, -0.004298543091863394, 0.8409703969955444
16, -0.008712172508239746, 0.8465444445610046
17, -0.015099233947694302, 0.8531993627548218
18, 0.013269342482089996, 0.871336042881012
19, 0.004237722605466843, 0.8973748087882996
20, 0.013615096919238567, 0.9357888698577881
21, -0.006991087459027767, 0.9703547954559326
22, 0.004890932235866785, 1.0143762826919556
23, -0.004372886847704649, 1.0621992349624634
24, 0.013723906129598618, 1.1149945259094238
25, -0.0068347700871527195, 1.1737383604049683
26, 0.005121869500726461, 1.2397925853729248
27, -0.012573646381497383, 1.3081283569335938
28, 0.006252205464988947, 1.3783568143844604
29, -0.015094763599336147, 1.4558509588241577
30, 0.007662947755306959, 1.544228434562683
31, -0.007653167005628347, 1.6396034955978394
32, 0.01473329309374094, 1.7396129369735718
33, -0.01577235572040081, 1.8459067344665527
34, 0.011018366552889347, 1.9617066383361816
35, -0.010309331119060516, 2.088635206222534
36, 0.023677419871091843, 2.222790002822876
37, -0.010005255229771137, 2.364104986190796
38, 0.016244009137153625, 2.5151500701904297
39, -0.02267889305949211, 2.6774721145629883
40, 0.019860919564962387, 2.852034091949463
41, -0.01604018174111843, 3.036756992340088
42, 0.023031268268823624, 3.2309606075286865
43, -0.028827959671616554, 3.4372637271881104
44, 0.019172750413417816, 3.6594107151031494
45, -0.028283527120947838, 3.8974359035491943
46, 0.02871464192867279, 4.149108409881592
47, -0.03014484792947769, 4.416177749633789
48, 0.03198430314660072, 4.701683044433594
49, -0.03420396149158478, 5.006956100463867

คูณเมตริกซ์ยังไง ก็ยังใกล้เคียง mean = 0, std = 1 สามารถเทรนไปได้อีกยาว ๆ ไม่มี Vanishing Gradient และ Exploding Gradient

4. Kaiming Initialization and ReLU Activation Function

การ Initialize Weight ที่เหมาะสม ด้วย Kaiming Initialization จะช่วยให้เราเทรนโมเดล เร็วขึ้น และเทรนได้นานขึ้นตามที่เราต้องการโดยไม่ Error ไปเสียก่อน

In [14]:
gain = math.sqrt(2.)
In [15]:
x = torch.randn(100, 100)
a = torch.randn(100, 100) * math.sqrt(1./100.) * gain

ในเคสนี้ เป็น Kaiming Initialization เวอร์ชัน รองรับ ReLU Activation Function

In [16]:
x.mean(), x.std()
Out[16]:
(tensor(-0.0034), tensor(1.0043))
In [17]:
a.mean(), a.std()
Out[17]:
(tensor(0.0003), tensor(0.1395))

เพิ่ม ReLU Activation Function หลังจากที่ คุณเมตริกซ์

In [18]:
def relu(x):
    return x.clamp_(0.).sub_(0.5) # -0.5 for move mean
In [19]:
for i in range(50):
    x = x @ a 
    relu(x)
    print(f'{i}, {x.mean()}, {x.std()}')
0, 0.07064563035964966, 0.827333927154541
1, -0.046542372554540634, 0.6736540198326111
2, -0.13649706542491913, 0.5447353720664978
3, -0.19656209647655487, 0.4531916677951813
4, -0.21645238995552063, 0.4117547869682312
5, -0.21927060186862946, 0.3951576054096222
6, -0.2151007354259491, 0.3940693140029907
7, -0.2132544368505478, 0.39335212111473083
8, -0.20003201067447662, 0.4055188000202179
9, -0.19418643414974213, 0.4183146059513092
10, -0.18435138463974, 0.4282020032405853
11, -0.181589737534523, 0.433869868516922
12, -0.17619332671165466, 0.43589404225349426
13, -0.17180456221103668, 0.4391709268093109
14, -0.169425830245018, 0.4414549171924591
15, -0.1683892160654068, 0.4369339048862457
16, -0.16912926733493805, 0.4300256371498108
17, -0.17198164761066437, 0.424317330121994
18, -0.17612454295158386, 0.4219188094139099
19, -0.17541316151618958, 0.41755664348602295
20, -0.18087436258792877, 0.4102037250995636
21, -0.1807420402765274, 0.4083867073059082
22, -0.1825791299343109, 0.40876492857933044
23, -0.18176256120204926, 0.4084380865097046
24, -0.18169115483760834, 0.40591299533843994
25, -0.18313050270080566, 0.403063029050827
26, -0.1827109456062317, 0.40572530031204224
27, -0.18196377158164978, 0.4076358377933502
28, -0.18078096210956573, 0.40915733575820923
29, -0.17984361946582794, 0.4081428050994873
30, -0.18115629255771637, 0.410800039768219
31, -0.17896030843257904, 0.41564321517944336
32, -0.17737989127635956, 0.4136631190776825
33, -0.1789243221282959, 0.41373905539512634
34, -0.17796248197555542, 0.41324225068092346
35, -0.1783604472875595, 0.41560474038124084
36, -0.17756642401218414, 0.41592589020729065
37, -0.1775074452161789, 0.4147123098373413
38, -0.17679576575756073, 0.41498681902885437
39, -0.1766795516014099, 0.41608095169067383
40, -0.17610087990760803, 0.4162077307701111
41, -0.17626897990703583, 0.4145180284976959
42, -0.17690247297286987, 0.41068997979164124
43, -0.1787371039390564, 0.4113275110721588
44, -0.1764233112335205, 0.41124045848846436
45, -0.1784019023180008, 0.4086558520793915
46, -0.1782366931438446, 0.40653035044670105
47, -0.17908145487308502, 0.40583592653274536
48, -0.17923395335674286, 0.4067550301551819
49, -0.17963126301765442, 0.4055445194244385

คูณเมตริกซ์ยังไง ก็ยังใกล้เคียง mean = 0, std = 1 สามารถเทรนไปได้อีกยาว ๆ ไม่มี Vanishing Gradient และ Exploding Gradient

4. สรุป

  1. Initialization เป็นวิธีง่าย ๆ ที่คนมองข้ามไป ที่จะมาช่วยแก้ปัญหา Vanishing Gradient และ Exploding Gradient
  2. อันนี้เป็นตัวอย่างง่าย ๆ ให้พอเห็นภาพ แต่ Neural Network จริง ๆ จะซับซ้อนกว่านี้ และมี Activation Function มาคั่น ทำให้พฤติกรรมของ Gradient เปลี่ยนไปอีก
  3. ยังมีอีกหลายเทคนิค ที่มาช่วยคุมให้ไม่เกิดการ Vanishing Gradient และ Exploding Gradient เช่น เปลี่ยนจาก Sigmoid Activation Function เป็น ReLU Activation Function, Batch Normalization, LSTM, Residual Neural Network, etc.

Credit

In [ ]:
 

แชร์ให้เพื่อน:

Keng Surapong on FacebookKeng Surapong on GithubKeng Surapong on Linkedin
Keng Surapong
Project Manager at Bua Labs
The ultimate test of your knowledge is your capacity to convey it to another.

Published by Keng Surapong

The ultimate test of your knowledge is your capacity to convey it to another.

Enable Notifications    OK No thanks