ใน ep นี้เราจะมาเรียนรู้กันว่า Learning Rate คืออะไร Learning Rate สำคัญอย่างไรกับการเทรน Machine Learning โมเดล Neural Network / Deep Learning เราจะปรับ Learning Rate อย่างไรให้เหมาะสม เราสามารถเทรนไปปรับไปได้ไหม หรือต้องใช้ค่าคงที่ตลอด และโมเดลที่ Transfer Learning กับโมเดลที่เทรนใหม่เลย ต้องการ Learning Rate, จำนวน Epoch ต่างกันอย่างไร

ในการเทรนโมเดล Deep Learning เราต้องเข้าใจก่อนว่าอัลกอริทึมที่ใช้ในการเทรน ที่ชื่อว่า Gradient Descent ทำงานอย่างไร และ Hyperparameter หลักในการควบคุมการทำงานของ Gradient Descent ก็คือ Learning Rate

Learning Rate คืออะไร

Learning Rate คือ Hyperparameter ตัวหนึ่งที่ควบคุมว่าในหนึ่ง Step ของการเทรน เราจะปรับ Weight ของ Neural Network มากน้อยแค่ไหน

  • ถ้า Learning Rate มีค่าน้อย Weight ของโมเดลก็จะเปลี่ยนแปลงน้อย การทำงานของโมเดลก็จะเปลี่ยนไปน้อย Loss ก็ไม่ค่อยเปลี่ยนเท่าไร
  • ถ้า Learning Rate มีค่ามาก Weight ของโมเดลก็จะเปลี่ยนแปลงมาก การทำงานของโมเดลก็จะเปลี่ยนไปมาก Loss ก็จะเปลี่ยนแปลงมาก
Gradient descent with small (top) and large (bottom) learning rates. Source: Andrew Ng’s Machine Learning course on Coursera
Gradient descent ด้วย Learning Rate น้อยเกิน (รูปบน) และมากเกิน (รูปล่าง) เครดิต: Andrew Ng Machine Learning course on Coursera

เรามาเริ่มกันเลยดีกว่า

Open In Colab

0. Magic Commands

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

1. Import Library

In [2]:
from fastai import *
from fastai.vision import *
from fastai.metrics import accuracy

2. ข้อมูล

ใช้ชุดข้อมูล Dataset MNIST ตัวเลขอารบิคเขียนด้วยลายมือ

In [3]:
path = untar_data(URLs.MNIST_SAMPLE)

3. เตรียมข้อมูล

In [4]:
batchsize = 64
np.random.seed(55)
transform = get_transforms(do_flip=False)

databunch = ImageDataBunch.from_folder(path, train='training', 
                                       valid_pct=0.2, size=28, 
                                       ds_tfms=transform, bs=batchsize, 
                                       num_workers=8).normalize()

4. สร้างโมเดล

เพื่อความสะดวก เราจะไปสร้างก่อน fit แต่ละแบบ

  • เคสนี้เราจะใช้โมเดล models.vgg16_bn คือ โมเดลชื่อว่า VGGNet มี 16 Layers จากปี 2014 ซึ่งเป็นโมเดลที่เก่ากว่า เก่งน้อยกว่า Resnet จะได้เห็นกราฟ Loss ได้ชัดขึ้น
  • pretrained=False หมายถึง เราจะเริ่มเทรนตั้งแต่ต้น ไม่เอา Weight ที่เคยเทรนกับ ImageNet มาแล้วมาใช้
In [5]:
# learner = cnn_learner(databunch, models.vgg16_bn, 
#                       pretrained=False, 
#                       metrics=accuracy, callback_fns=ShowGraph)

5. เริ่มต้นเทรนโมเดล

ปกติ Learning Rate จะมีค่า Default ประมาณ 3e-3 หรือ 0.003 เราจะลองเทรนด้วย Learning Rate สูงมาก และต่ำมาก เปรียบเทียบกับ Image Classification ep.3 ดูว่าจะเป็นอย่างไร

ลองเทรนด้วย Learning Rate สูง ๆ เช่น 0.1

In [6]:
learner = cnn_learner(databunch, models.vgg16_bn, 
                      pretrained=False, 
                      metrics=accuracy, callback_fns=ShowGraph)
learner.fit(10, lr=0.1)
Total time: 00:59

epochtrain_lossvalid_lossaccuracytime
00.1721080.0351460.98752600:06
10.1082550.5916710.84442100:05
20.4489670.1037650.98475400:05
30.0417990.0070840.99757500:05
40.3340000.0821080.99688100:05
50.3181170.1644950.99168400:05
60.2086530.1352040.99341600:05
70.3356870.1724720.99237700:05
80.3271840.1913360.99411000:05
90.4125420.2681390.99618900:05

Weight ของ Model ถูกลบด้วยค่าที่ใหญ่เกินไป ซ้ำ ๆ เหมือนถูกระเบิดออก ทำให้ Loss กระเด้งไปมา ไม่ลงสู่จุดต่ำสุด

บางที Weight ระเบิดมากเกินกว่าที่ตัวเลขในคอมพิวเตอร์ จะรับได้ จะกลายเป็นค่า NaN หรือ Not a Number ต้อง Restart และเทรนใหม่ตั้งแต่ต้น

ลองเทรนด้วย Learning Rate ต่ำ ๆ เช่น 0.000001

In [7]:
learner = cnn_learner(databunch, models.vgg16_bn, 
                      pretrained=False, 
                      metrics=accuracy, callback_fns=ShowGraph)
learner.fit(8, lr=0.000001)
Total time: 00:46

epochtrain_lossvalid_lossaccuracytime
00.7978040.6917740.62127500:05
10.6258220.5716630.71483000:05
20.5196030.4950000.74809400:05
30.4290420.3969080.81150400:05
40.3701130.3153370.86070700:05
50.3051420.2766490.88322900:05
60.2694670.2446620.90055400:05
70.2321830.2107870.91684000:05

Weight ของ Model ถูกลบด้วยค่าที่เล็กเกินไป ทำให้ Loss ไม่ขยับไปไหน

ลองเทรนด้วย Epoch น้อยเกินไป หรือ เทรนสั้นเกินไป

In [8]:
learner = cnn_learner(databunch, models.vgg16_bn, 
                      pretrained=False, 
                      metrics=accuracy, callback_fns=ShowGraph)
learner.fit(1, lr=0.0003)
Total time: 00:06

epochtrain_lossvalid_lossaccuracytime
00.0359750.0399680.98406100:05

เทรนสั้นไป ยังไม่ได้อัพเดท Weight สักเท่าไร ทำให้ Loss ไม่ขยับไปไหน

ลองเทรนด้วย Epoch เยอะเกินไป หรือ เทรนนานเกินไป

In [9]:
learner = cnn_learner(databunch, models.vgg16_bn, 
                      pretrained=False, 
                      metrics=accuracy, callback_fns=ShowGraph)
learner.fit(30, lr=0.0003)
Total time: 02:56

epochtrain_lossvalid_lossaccuracytime
00.0494240.0197330.99376300:05
10.0258350.0601890.98024900:05
20.0254990.0087830.99826700:05
30.0211860.0029400.99896000:05
40.0220880.0132140.99549500:06
50.0137800.0035070.99930700:05
60.0173770.0092160.99826700:05
70.0174070.0058060.99826700:05
80.0118180.0037910.99896000:05
90.0078150.0030560.99965400:05
100.0147300.0043810.99792100:05
110.0072520.0072810.99792100:05
120.0072060.0057670.99826700:05
130.0044530.0001851.00000000:05
140.0074190.0024670.99930700:05
150.0038480.0059950.99861400:05
160.0033350.0052000.99861400:05
170.0042490.0056450.99896000:05
180.0065000.0093490.99618900:05
190.0036240.0013290.99930700:05
200.0020340.0045740.99861400:05
210.0053650.0044450.99826700:05
220.0043020.0086510.99896000:05
230.0014780.0039470.99896000:05
240.0027430.0035960.99896000:05
250.0085570.0136260.99411000:05
260.0017660.0026700.99896000:05
270.0051950.0048490.99896000:05
280.0026100.0022950.99930700:05
290.0027280.0016050.99930700:05

เทรนนานเกินไป ผลลัพธ์ก็ไม่ได้ดีขึ้น บางทีอาจจะทำให้ Overfit ซึ่งไว้เราจะอธิบายต่อไป

In [ ]:
 
In [ ]:
 

หมายเหตุ

  • ในกรณีนี้เราถือว่า Bias คือ Weight ตัวที่ 0 ที่คุณกับ x0 (ที่ค่าเป็น 1 เสมอ) จึงถือว่า Bias เป็น Weight เหมือนกัน เราจะไม่ได้เขียนแยก ว่า Weight และ Bias

Learning Rate แบบไม่คงที่

Learning Rate ไม่จำเป็นต้องคงทีตลอดการเทรน ปัจจุบันมี Paper มากมาย นำเสนอไอเดียการเพิ่มลด Learning Rate ระหว่างเทรน เช่น

  • ค่อย ๆ ลด Learning Rate ไปเรื่อย ๆ ทุก Epoch
  • แบ่งการเทรนเป็น N Cycle แล้วเพิ่ม Learning Rate จนสุด Max ลดจนถึง Min ทุก Cycle เป็นฟันปลา
  • รวม 2 แบบบนเข้าด้วยกัน เทรน N Cycle โดยลดค่า Max Learning Rate ทุก Epoch
  • Stochastic Gradient Descent with Warm Restarts ค่อย ๆ เพิ่ม/ลด Learning Rate ด้วย Cosine Function โค้ง ๆ แทนที่จะเป็นฟันปลา
  • etc.

เชื่อว่า การเพิ่ม/ลด Learning Rate นี้จะช่วยให้ เราเทรนโมเดลได้รวดเร็วขึ้น โมเดลสามารถกระโดดข้ามภูเขา และกระโดดออกมาจาก หลุม Local Minima ได้ ซึ่งเราจะอธิบายต่อไป

รูปเปรียบเทียบ Learning Rate แบบคงที่ กับ Learning Rate แบบ Cycle เครดิต https://arxiv.org/abs/1704.00109
รูปเปรียบเทียบ Learning Rate แบบคงที่ กับ Learning Rate แบบ Cycle เครดิต https://arxiv.org/abs/1704.00109

Learning Rate สำหรับ Pretrained Model

จากบทความเรื่อง Image Classification เรามีการใช้ Transfer Learning จากโมเดล Resnet34 ตัด Layer สุดท้ายทิ้ง เทรนเฉพาะ Layer สุดท้าย ด้วย Learning Rate ที่สูง unfreeze แล้วเทรนทั้งโมเดล ด้วย Learning Rate ที่ต่ำลง และแบ่งเป็นช่วง ๆ สาเหตุเพราะว่า Layer ต่างกัน ต้องการความเปลี่ยนแปลง ของ Weight หรือ Learning Rate ที่ต่างกัน

เช่น ใน Convolution Neural Network (CNN) Layer แรก จะเป็นเรื่องพื้นฐาน เช่น เส้นตั้ง เส้นนอน เส้นทแยง Layer ต่อมาจะเป็นมุม เป็นเส้นโค้ง เป็นเรื่องที่ไม่ว่ารูปแบบไหนก็ต้องประกอบขึ้นมาจากส่วนประกอบพื้นฐานเหล่านี้ Layer เหล่านี้ต้องการ Learning Rate ที่ต่ำมาก ตรงข้ามกับ Layer หลัง ๆ ซึ่งเกี่ยวกับ Object, ดวงตา, ใบหน้า, พื้นผิว Texture, ตัวหนังสือ ที่เราต้องการ Tune ให้เข้ากับงานที่เรา ซึ่งต้องการ Learning Rate ที่สูงกว่า

ใน learner.fit_one_cycle เราจึงมีการกำหนด Maximum Learning Rate (max_lr) ด้วย split(3e-6, 3e-3) เพื่อให้ Layer แรก ๆ ได้ค่า Learning Rate น้อย ๆ คือ 3e-6 ไล่ไปจนถึง Layer สุดท้าย ได้ค่า Learning Rate มากที่สุด คือ 3-e3

Credit

แชร์ให้เพื่อน:

Keng Surapong on FacebookKeng Surapong on GithubKeng Surapong on Linkedin
Keng Surapong
Project Manager at Bua Labs
The ultimate test of your knowledge is your capacity to convey it to another.

Published by Keng Surapong

The ultimate test of your knowledge is your capacity to convey it to another.

Enable Notifications    OK No thanks