ปัญหาหลักอย่างนึงในการเทรน Deep Learning คือ Dataset ของเรามีข้อมูลตัวอย่างไม่เพียงพอ สมมติว่าเราปิ๊งสุดยอดไอเดีย ที่จะสร้าง App ใหม่ ที่ใช้ Machine Learning ขึ้นมา เราเปิดเว็บเพื่อ Search Google หาข้อมูลตัวอย่าง มาไว้เทรนโมเดล เรานั่ง Search Google Images หารูปภาพอยู่หลายชั่วโมง นั่งจัด นั่ง Clean ข้อมูลที่ไม่เกี่ยวข้องออกไป สุดท้ายเราได้ รูปมา 500 รูป ถ้าหาแบบนี้ 10 วัน ก็ 5,000 รูป แต่เรารู้มาว่าโมเดลที่ดัง ๆ ใช้ข้อมูลในการเทรน เกิน 1 ล้านรูปขึ้นไปทั้งนั้น แล้วเราจะทำอย่างไรดี
Transfer Learning คืออะไร
Transfer Learning คือ การนำโมเดลของผู้อื่น ที่เทรนกับข้อมูลอื่นเรียบร้อยแล้ว ตัดเอาบางส่วนเฉพาะส่วนที่เราสนใจ มา Reuse ใช้ต่อ มาประกอบในการสร้างโมเดลใหม่ของเรา เช่น เรานำโมเดล Convolution Neural Network (CNN) ที่มี 100 Layer มาตัด Layer สุดท้ายทิ้งไป เปลี่ยนเป็น Layer ที่เหมาะกับงานของเราแทน แล้วเวลาเทรน ก็เทรนแค่ Layer ใหม่ที่อยู่ท้ายสุดอย่างเดียว เพื่อลดเวลา และข้อมูลที่ใช้ในการเทรน เพราะมี 99 Layer ที่ทำงานได้ถูกต้องอยู่แล้ว ไม่ต้องเทรนตั้งแต่ต้น เรื่อง Transfer Learning จะอธิบายเพิ่มเติมต่อไป
Data Augmentation คืออะไร
แต่ถึงเราจะใช้ Transfer Learning ประกอบโมเดลของเราขึ้นมาจากโมเดล Convolution Neural Network (CNN) ชื่อดัง เช่น ResNet, VGGNet, Inception รูปแค่ 500 รูปของเราก็ไม่พอเทรนอยู่ดี แล้วเราจะแก้อย่างไร
คำตอบ คือ เราก็เอารูป 500 รูปมา Recycle ให้เป็น 1,000 รูปสิ เช่น สมมติว่าเรามีรูปดอกไม้ 500 รูป เราเอารูปทั้งหมดมา Flip กลับซ้ายขวา ก็จะได้รูปเพิ่มขึ้นมาอีก 500 รูป รวมเป็น 1,000 รูป แบบง่าย ๆ เรียกว่า Data Augmentation

Data Augmentation ที่เป็นที่นิยมสำหรับรูปภาพมีอีกหลายอย่าง ได้แก่ ย่อ/ขยาย, หมุน ซ้าย/ขวา, Flip ซ้าย/ขวา/บน/ล่าง, Crop มุม, ปรับสีเข้ม/จืด, ปรับแสง สว่าง/มืด, ปรับ Contrast, ปรับ Perspective, เพิ่ม/ลด Noise, เบลอภาพ, Etc. ทั้งนี้ขึ้นอยู่กับงานด้วย เช่น ดอกไม้จะ Flip กลับหัวได้ไหม
Data Augmentation Zoom In Data Augmentation Zoom Out
Data Augmentation Rotate Right 15 Degree Data Augmentation Rotate Left 15 Degree
Data Augmentation Flip Horizontal Data Augmentation Flip Vertical
Data Augmentation Crop Bottom Right Data Augmentation Crop Top Left
Data Augmentation Decrease Saturation Data Augmentation Increase Saturation
Data Augmentation Brighten Data Augmentation Darken
Data Augmentation Skew Data Augmentation Noise
ถ้าทำทั้งหมดผสม ๆ กันมากน้อย แบบ Random ไปเรื่อย ๆ จากรูปดอกไม้ ตอนแรกเพียง 500 รูป ก็จะกลายเป็นหมื่นเป็นแสนรูปได้ไม่จำกัด แต่คุณภาพของโมเดลก็จะขึ้นอยู่กับ ข้อมูลเริ่มต้น และ Hyperparameter ที่เรากำหนดว่าจะผสมอย่างไร
เรามาเริ่มกันเลยดีกว่า
เราจะมาลองเทรน 2 โมเดล เปรียบเทียบโมเดล ที่ใช้ Data Augmentation และไม่ใช้ ว่า Validation Loss จะต่างกันอย่างไร
0. Magic Commands¶
%reload_ext autoreload
%autoreload 2
%matplotlib inline
1. Import Library¶
from fastai import *
from fastai.vision import *
from fastai.metrics import accuracy
2. ข้อมูล¶
เราจะใช้ Dataset Oxford-IIIT Pet Dataset จำแนกพันธุ์หมาแมวเหมือนดิม
path = untar_data(URLs.PETS)
path_images = path/'images'
filenames = get_image_files(path_images)
ประกาศฟัง์ชัน สร้าง databunch และฟังก์ชันแสดงภาพ เราจะ Sample ข้อมูลมาแค่ 500 ตัวอย่าง
def get_databunch(transform):
batchsize = 32
sample = 5000
np.random.seed(555)
regex_pattern = r'/([^/]+)_\d+.jpg$'
return ImageDataBunch.from_name_re(path_images,
random.sample(filenames, sample),
regex_pattern,
ds_tfms=transform,
size=224, bs=batchsize).normalize(imagenet_stats)
def get_ex(): return open_image(f'{path_images}/pug_147.jpg')
def plots_f(rows, cols, width, height, **kwargs):
[get_ex().apply_tfms(transform[0], **kwargs).show(ax=ax) for i,ax in enumerate(plt.subplots(
rows,cols,figsize=(width,height))[1].flatten())]
3. เตรียมข้อมูล¶
เราจะไปสร้าง DataBunch พร้อมสร้างโมเดลจะได้สะดวกในการเปรียบเทียบ
4. สร้างโมเดล¶
ในเคสนี้ เราจะใช้โมเดลที่ไม่ใหม่มาก ไม่มี Skip Connection อย่าง VGG และไม่ใช้ Dropout (ps=0.0), Weight Decay (wd=0.0) จะได้เปรียบเทียบได้ชัด ๆ
ไม่ใช้ Data Augmentation¶
ปิด Data Augmentaion ทุกอย่าง ด้วย Empty List 2 อัน คือ transform สำหรับ Training Set และ Validation Set
transform = ([], [])
databunch = get_databunch(transform)
learner = cnn_learner(databunch, models.vgg16_bn, ps=0.0, wd=0.0,
metrics=accuracy, callback_fns=ShowGraph)#.to_fp16()
plots_f(3, 3, 9, 9, size=224)
learner.fit_one_cycle(1, max_lr=1e-2)
learner.unfreeze()
learner.fit_one_cycle(8, max_lr=slice(3e-6, 3e-3))
เคลียร์ Memory
learner = None
gc.collect()
ใช้ Data Augmentation¶
เปิด Data Augmentaion ทุกอย่าง
# transform = get_transform()
transform = get_transforms(do_flip=True, flip_vert=False, max_rotate=10.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0.75, p_lighting=0.75)
databunch = get_databunch(transform)
learner = cnn_learner(databunch, models.vgg16_bn, ps=0.0, wd=0.0,
metrics=accuracy, callback_fns=ShowGraph)#.to_fp16()
plots_f(3, 3, 9, 9, size=224)
learner.fit_one_cycle(1, 1e-2)
learner.unfreeze()
learner.fit_one_cycle(8, max_lr=slice(3e-6, 3e-3))
5. สรุป¶
- โมเดลที่ไม่ได้ใช้ Data Augmentation เทรนไปหลาย Epoch แล้ว Training Loss ลดลงเรื่อย ๆ แต่ Validation Loss กลับไม่ลดลง และ Accuracy ก็ไม่ได้ดีขึ้น เป็นสัญญาณของ Overfit
- โมเดลที่ใช้ Data Augmentation เทรนไปด้วยจำนวน Epoch เท่ากัน Training Loss ลดลงเรื่อย ๆ พร้อมกับ Validation Loss และ Accuracy ก็ดีขึ้นเรื่อย ๆ ไม่ Overfit
- โมเดลสมัยใหม่ ออกแบบมาค่อนข้างดี ทำให้ Overfit ค่อนข้างยาก
Data Augmentation ใช้กันอย่างแพร่หลายในข้อมูลรูปภาพ แต่ก็ปัจจุบันมีการศึกษาเกี่ยวกับ Data Augmentation ในข้อมูลแบบอื่น ๆ เช่น ตาราง เสียงพูด และ ข้อความ NLP เช่น เปลี่ยนชื่อตัวละครจากชื่อนึงเป็นอีกชื่อนึง, เปลี่ยนสลับคำศัพท์ที่มีความหมายเหมือนกัน, Etc.
Regularization
Data Augmentation ถือว่าเป็นการ Regularization แบบหนึ่ง เพราะช่วยลด Overfit ช่วยให้โมเดลทำงานได้ Generalization ขึ้น ลดการจำข้อสอบ แต่ถือว่าเป็น Regularization ทางอ้อม เพราะไม่ได้ทำกับโมเดลโดยตรง แต่เป็นการทำกับข้อมูลแทน