ปัญหาหลักอย่างนึงในการเทรน Deep Learning คือ Dataset ของเรามีข้อมูลตัวอย่างไม่เพียงพอ สมมติว่าเราปิ๊งสุดยอดไอเดีย ที่จะสร้าง App ใหม่ ที่ใช้ Machine Learning ขึ้นมา เราเปิดเว็บเพื่อ Search Google หาข้อมูลตัวอย่าง มาไว้เทรนโมเดล เรานั่ง Search Google Images หารูปภาพอยู่หลายชั่วโมง นั่งจัด นั่ง Clean ข้อมูลที่ไม่เกี่ยวข้องออกไป สุดท้ายเราได้ รูปมา 500 รูป ถ้าหาแบบนี้ 10 วัน ก็ 5,000 รูป แต่เรารู้มาว่าโมเดลที่ดัง ๆ ใช้ข้อมูลในการเทรน เกิน 1 ล้านรูปขึ้นไปทั้งนั้น แล้วเราจะทำอย่างไรดี

Transfer Learning คืออะไร

Transfer Learning คือ การนำโมเดลของผู้อื่น ที่เทรนกับข้อมูลอื่นเรียบร้อยแล้ว ตัดเอาบางส่วนเฉพาะส่วนที่เราสนใจ มา Reuse ใช้ต่อ มาประกอบในการสร้างโมเดลใหม่ของเรา เช่น เรานำโมเดล Convolution Neural Network (CNN) ที่มี 100 Layer มาตัด Layer สุดท้ายทิ้งไป เปลี่ยนเป็น Layer ที่เหมาะกับงานของเราแทน แล้วเวลาเทรน ก็เทรนแค่ Layer ใหม่ที่อยู่ท้ายสุดอย่างเดียว เพื่อลดเวลา และข้อมูลที่ใช้ในการเทรน เพราะมี 99 Layer ที่ทำงานได้ถูกต้องอยู่แล้ว ไม่ต้องเทรนตั้งแต่ต้น เรื่อง Transfer Learning จะอธิบายเพิ่มเติมต่อไป

Data Augmentation คืออะไร

แต่ถึงเราจะใช้ Transfer Learning ประกอบโมเดลของเราขึ้นมาจากโมเดล Convolution Neural Network (CNN) ชื่อดัง เช่น ResNet, VGGNet, Inception รูปแค่ 500 รูปของเราก็ไม่พอเทรนอยู่ดี แล้วเราจะแก้อย่างไร

คำตอบ คือ เราก็เอารูป 500 รูปมา Recycle ให้เป็น 1,000 รูปสิ เช่น สมมติว่าเรามีรูปดอกไม้ 500 รูป เราเอารูปทั้งหมดมา Flip กลับซ้ายขวา ก็จะได้รูปเพิ่มขึ้นมาอีก 500 รูป รวมเป็น 1,000 รูป แบบง่าย ๆ เรียกว่า Data Augmentation

Data Augmentation Normal Compare to Flip Horizontal
Data Augmentation Normal Compare to Flip Horizontal

Data Augmentation ที่เป็นที่นิยมสำหรับรูปภาพมีอีกหลายอย่าง ได้แก่ ย่อ/ขยาย, หมุน ซ้าย/ขวา, Flip ซ้าย/ขวา/บน/ล่าง, Crop มุม, ปรับสีเข้ม/จืด, ปรับแสง สว่าง/มืด, ปรับ Contrast, ปรับ Perspective, เพิ่ม/ลด Noise, เบลอภาพ, Etc. ทั้งนี้ขึ้นอยู่กับงานด้วย เช่น ดอกไม้จะ Flip กลับหัวได้ไหม

ถ้าทำทั้งหมดผสม ๆ กันมากน้อย แบบ Random ไปเรื่อย ๆ จากรูปดอกไม้ ตอนแรกเพียง 500 รูป ก็จะกลายเป็นหมื่นเป็นแสนรูปได้ไม่จำกัด แต่คุณภาพของโมเดลก็จะขึ้นอยู่กับ ข้อมูลเริ่มต้น และ Hyperparameter ที่เรากำหนดว่าจะผสมอย่างไร

เรามาเริ่มกันเลยดีกว่า

Open In Colab

เราจะมาลองเทรน 2 โมเดล เปรียบเทียบโมเดล ที่ใช้ Data Augmentation และไม่ใช้ ว่า Validation Loss จะต่างกันอย่างไร

0. Magic Commands

In [0]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

1. Import Library

In [0]:
from fastai import *
from fastai.vision import *
from fastai.metrics import accuracy

2. ข้อมูล

เราจะใช้ Dataset Oxford-IIIT Pet Dataset จำแนกพันธุ์หมาแมวเหมือนดิม

In [0]:
path = untar_data(URLs.PETS)
path_images = path/'images'
filenames = get_image_files(path_images)

ประกาศฟัง์ชัน สร้าง databunch และฟังก์ชันแสดงภาพ เราจะ Sample ข้อมูลมาแค่ 500 ตัวอย่าง

In [0]:
def get_databunch(transform):
    batchsize = 32
    sample = 5000
    np.random.seed(555)
    regex_pattern = r'/([^/]+)_\d+.jpg$'

    return ImageDataBunch.from_name_re(path_images, 
                                       random.sample(filenames, sample), 
                                       regex_pattern, 
                                       ds_tfms=transform, 
                                       size=224, bs=batchsize).normalize(imagenet_stats)

def get_ex(): return open_image(f'{path_images}/pug_147.jpg')

def plots_f(rows, cols, width, height, **kwargs):
    [get_ex().apply_tfms(transform[0], **kwargs).show(ax=ax) for i,ax in enumerate(plt.subplots(
        rows,cols,figsize=(width,height))[1].flatten())]

3. เตรียมข้อมูล

เราจะไปสร้าง DataBunch พร้อมสร้างโมเดลจะได้สะดวกในการเปรียบเทียบ

4. สร้างโมเดล

ในเคสนี้ เราจะใช้โมเดลที่ไม่ใหม่มาก ไม่มี Skip Connection อย่าง VGG และไม่ใช้ Dropout (ps=0.0), Weight Decay (wd=0.0) จะได้เปรียบเทียบได้ชัด ๆ

ไม่ใช้ Data Augmentation

ปิด Data Augmentaion ทุกอย่าง ด้วย Empty List 2 อัน คือ transform สำหรับ Training Set และ Validation Set

In [5]:
transform = ([], [])
databunch = get_databunch(transform)
learner = cnn_learner(databunch, models.vgg16_bn, ps=0.0, wd=0.0, 
                      metrics=accuracy, callback_fns=ShowGraph)#.to_fp16()
plots_f(3, 3, 9, 9, size=224)
Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/checkpoints/vgg16_bn-6c64b313.pth
100%|██████████| 553507836/553507836 [00:06<00:00, 84022164.53it/s]
In [6]:
learner.fit_one_cycle(1, max_lr=1e-2)
epochtrain_lossvalid_lossaccuracytime
00.6550790.4053140.87600000:59
In [7]:
learner.unfreeze()
learner.fit_one_cycle(8, max_lr=slice(3e-6, 3e-3))
epochtrain_lossvalid_lossaccuracytime
00.2364930.3146590.89600001:18
10.2057470.4268210.86800001:16
20.1735750.5292730.85400001:16
30.0916640.4219310.88200001:16
40.0349630.3505060.89800001:16
50.0234170.3387530.90900001:16
60.0072620.3306750.91300001:16
70.0030210.3309440.91200001:16

เคลียร์ Memory

In [8]:
learner = None
gc.collect()
Out[8]:
23212

ใช้ Data Augmentation

เปิด Data Augmentaion ทุกอย่าง

In [9]:
# transform = get_transform()
transform = get_transforms(do_flip=True, flip_vert=False, max_rotate=10.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0.75, p_lighting=0.75)
databunch = get_databunch(transform)
learner = cnn_learner(databunch, models.vgg16_bn, ps=0.0, wd=0.0, 
                      metrics=accuracy, callback_fns=ShowGraph)#.to_fp16()
plots_f(3, 3, 9, 9, size=224)
In [10]:
learner.fit_one_cycle(1, 1e-2)
epochtrain_lossvalid_lossaccuracytime
00.7540360.3981100.86900001:11
In [11]:
learner.unfreeze()
learner.fit_one_cycle(8, max_lr=slice(3e-6, 3e-3))
epochtrain_lossvalid_lossaccuracytime
00.3374700.3079220.89900001:19
10.3480620.4971710.84600001:18
20.3054660.3918890.88600001:17
30.1971340.3496130.89300001:18
40.1390960.3360360.90300001:18
50.0800440.2749230.92200001:17
60.0456400.2676920.92100001:18
70.0325430.2685800.92500001:17

5. สรุป

  1. โมเดลที่ไม่ได้ใช้ Data Augmentation เทรนไปหลาย Epoch แล้ว Training Loss ลดลงเรื่อย ๆ แต่ Validation Loss กลับไม่ลดลง และ Accuracy ก็ไม่ได้ดีขึ้น เป็นสัญญาณของ Overfit
  2. โมเดลที่ใช้ Data Augmentation เทรนไปด้วยจำนวน Epoch เท่ากัน Training Loss ลดลงเรื่อย ๆ พร้อมกับ Validation Loss และ Accuracy ก็ดีขึ้นเรื่อย ๆ ไม่ Overfit
  3. โมเดลสมัยใหม่ ออกแบบมาค่อนข้างดี ทำให้ Overfit ค่อนข้างยาก
In [0]:
 

Data Augmentation ใช้กันอย่างแพร่หลายในข้อมูลรูปภาพ แต่ก็ปัจจุบันมีการศึกษาเกี่ยวกับ Data Augmentation ในข้อมูลแบบอื่น ๆ เช่น ตาราง เสียงพูด และ ข้อความ NLP เช่น เปลี่ยนชื่อตัวละครจากชื่อนึงเป็นอีกชื่อนึง, เปลี่ยนสลับคำศัพท์ที่มีความหมายเหมือนกัน, Etc.

Regularization

Data Augmentation ถือว่าเป็นการ Regularization แบบหนึ่ง เพราะช่วยลด Overfit ช่วยให้โมเดลทำงานได้ Generalization ขึ้น ลดการจำข้อสอบ แต่ถือว่าเป็น Regularization ทางอ้อม เพราะไม่ได้ทำกับโมเดลโดยตรง แต่เป็นการทำกับข้อมูลแทน

แชร์ให้เพื่อน:

Keng Surapong on FacebookKeng Surapong on GithubKeng Surapong on Linkedin
Keng Surapong
Data Science Manager at Bua Labs
The ultimate test of your knowledge is your capacity to convey it to another.

Published by Keng Surapong

The ultimate test of your knowledge is your capacity to convey it to another.