จาก ep ที่แล้ว ที่เราประยุกต์ใช้ Callback กับ Training Loop ในการเทรน Machine Learning ด้วยอัลกอริทึม Gradient Descent สร้างเป็น Runner Class ที่มี Callback ในทุก ๆ Event ที่เป็นไปได้ ในการเทรน Deep Neural Network แล้วเราจะใช้ประโยชน์จาก ระบบ Callback อันแสนยืดหยุ่นนี้ อย่างไรได้บ้าง
ใน ep นี้เราจะมาดูตัวอย่างการสร้าง Callback แบบง่าย ๆ แต่มีประโยชน์ คือการ คำนวน Metrics และ บันทึกค่า Loss, Learning Rate
ในการเทรน Machine Learning ด้วยอัลกอริทึม Gradient Descent เราสามารถประยุกต์ใช้ Callback ทำงานได้หลากหลาย เนื่องจากเราจะรู้ถึง Event ต่าง ๆ ที่เกิดขึ้นภายใน Training Loop ของ Neural Network ทุกขั้นตอน และ สามารถเข้าถึง/ปรับเปลี่ยนข้อมูล Model Deep Neural Network, Runner, DataBunch, Loss Function, etc. ได้ทั้งหมด

Callback Limitation
Callback มีข้อดีที่ทำให้การเทรนมีความยืดหยุ่นมาก แต่ก็มีข้อจำกัดข้อใหญ่ คือ เราไม่สามารถเข้าถึง Layer, Weight, Gradient ต่าง ๆ หรือ Event ภายในโมเดล Deep Neural Network ได้
เนื่องจากโมเดลที่เราใช้เป็นการเรียกใช้ API ของ PyTorch อีกต่อหนึ่ง ในกรณีที่เราต้องการเข้าถึงภายในโมเดลจริง ๆ PyTorch ก็มี Callback ให้เราใช้เหมือนกัน เรียกว่า Hook จะอธิบายต่อไป
เรามาเริ่มกันเลยดีกว่า
เราสามารถประยุกต์ใช้ Callback ทำงานได้หลากหลาย เนื่องจากเราจะรู้ถึง Event ต่าง ๆ ที่เกิดขึ้นภายใน Training Loop ทุกขั้นตอน และ สามารถเข้าถึงข้อมูล Model, Runner, DataBunch, Loss Function, etc. ได้ทั้งหมด
แต่ Callback ก็มีข้อจำกัดอยู่ คือ เราไม่สามารถเข้าถึง Layer ต่าง ๆ หรือ Event ภายในโมเดลได้ เนื่องจากโมเดลที่เราใช้เป็นการเรียกใช้ API ของ PyTorch อีกต่อหนึ่ง ในกรณีที่เราต้องการเข้าถึงภายในโมเดล PyTorch ก็มี Callback ให้เราใช้เหมือนกัน เรียกว่า Hook จะอธิบายต่อไป
0. Magic¶
%load_ext autoreload
%autoreload 2
%matplotlib inline
1. Import¶
import torch
from torch import tensor
from torch.nn import *
import torch.nn.functional as F
from torch.utils.data import *
from fastai import datasets
from fastai.metrics import accuracy
import pickle, gzip, math, torch, re
from IPython.core.debugger import set_trace
import matplotlib.pyplot as plt
2. Data¶
class Dataset(Dataset):
def __init__(self, x, y):
self.x, self.y = x, y
def __len__(self):
return len(self.x)
def __getitem__(self, i):
return self.x[i], self.y[i]
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'
def get_data():
path = datasets.download_data(MNIST_URL, ext='.gz')
with gzip.open(path, 'rb') as f:
((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
return map(tensor, (x_train, y_train, x_valid, y_valid))
x_train, y_train, x_valid, y_valid = get_data()
def normalize(x, m, s):
return (x-m)/s
from typing import *
def listify(o):
if o is None: return []
if isinstance(o, list): return o
if isinstance(o, str): return [o]
if isinstance(o, Iterable): return list(o)
return [o]
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
s1 = re.sub(_camel_re1, r'\1_\2', name)
return re.sub(_camel_re2, r'\1_\2', s1).lower()
train_mean, train_std = x_train.mean(), x_train.std()
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)
nh, bs = 100, 32
n, m = x_train.shape
c = (y_train.max()+1).numpy()
loss_func = F.cross_entropy
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)
train_dl, valid_dl = DataLoader(train_ds, bs), DataLoader(valid_ds, bs)
3. DataBunch¶
class DataBunch():
def __init__(self, train_dl, valid_dl, c=None):
self.train_dl,self.valid_dl,self.c = train_dl,valid_dl,c
@property
def train_ds(self): return self.train_dl.dataset
@property
def valid_ds(self): return self.valid_dl.dataset
data = DataBunch(train_dl, valid_dl, c)
4. Model¶
lr = 0.03
epoch = 10
nh = 50
def get_model():
# loss function
loss_func = F.cross_entropy
model = Sequential(Linear(m, nh), ReLU(), Linear(nh,c))
return model, loss_func
class Learner():
def __init__(self, model, opt, loss_func, data):
self.model, self.opt, self.loss_func, self.data = model, opt, loss_func, data
5. Training Loop¶
Training Loop ที่รองรับ Callback
class Runner():
def __init__(self, cbs=None, cb_funcs=None):
cbs = listify(cbs)
for cbf in listify(cb_funcs):
cb = cbf()
setattr(self, cb.name, cb)
cbs.append(cb)
self.stop, self.cbs = False, [TrainEvalCallback()]+cbs
@property
def opt(self): return self.learn.opt
@property
def model(self): return self.learn.model
@property
def loss_func(self): return self.learn.loss_func
@property
def data(self): return self.learn.data
def one_batch(self, xb, yb):
try:
self.xb, self.yb = xb, yb
self('begin_batch')
self.pred = self.model(xb)
self('after_pred')
self.loss = self.loss_func(self.pred, yb)
self('after_loss')
if not self.in_train: return
self.loss.backward()
self('after_backward')
self.opt.step()
self('after_step')
self.opt.zero_grad()
except CancelBatchException: self('after_cancel_batch')
finally: self('after_batch')
def all_batches(self, dl):
self.iters = len(dl)
try:
for xb, yb in dl:
self.one_batch(xb, yb)
except CancelEpochException: self('after_cancel_epoch')
def fit(self, epochs, learn):
self.epochs, self.learn, self.loss = epochs, learn, tensor(0.)
try:
for cb in self.cbs: cb.set_runner(self)
self('begin_fit')
for epoch in range(epochs):
self.epoch = epoch
if not self('begin_epoch'): self.all_batches(self.data.train_dl)
with torch.no_grad():
if not self('begin_validate'): self.all_batches(self.data.valid_dl)
self('after_epoch')
except CancelTrainException: self('after_cancel_train')
finally:
self('after_fit')
self.train = None
def __call__(self, cb_name):
# return True = Cancel, return False = Continue (Default)
res = False
# check if at least one True return True
for cb in sorted(self.cbs, key=lambda x: x._order): res = res or cb(cb_name)
return res
class Callback():
_order = 0
def set_runner(self, run): self.run = run
def __getattr__(self, k): return getattr(self.run, k)
@property
def name(self):
name = re.sub(r'Callback$', '', self.__class__.__name__)
return camel2snake(name or 'callback')
def __call__(self, cb_name):
f = getattr(self, cb_name, None)
if f and f(): return True
return False
class TrainEvalCallback(Callback):
def begin_fit(self):
self.run.n_epochs = 0.
self.run.n_iter = 0
def begin_epoch(self):
self.run.n_epochs = self.epoch
self.model.train()
self.run.in_train=True
def after_batch(self):
if not self.in_train: return
self.run.n_epochs += 1./self.iters
self.run.n_iter += 1
def begin_validate(self):
self.model.eval()
self.run.in_train=False
class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass
6. Callbacks¶
6.1 Recorder Callback¶
เริ่มกันที่ Recorder Callback เป็น Callback ที่ทำหน้าที่เก็บ Loss และ Learning Rate ของแต่ละ Batch (after_batch) เอาไว้พล็อตกราฟ ใน ep หน้า เราจะเห็นการเปลี่ยนแปลงของ Learning Rate ในช่วงต่าง ๆ ของการเทรน เป็นการทำ Learning Rate Annealing จะอธิบายต่อไป
class Recorder(Callback):
def begin_fit(self): self.lrs, self.losses = [], []
def after_batch(self):
if not self.in_train: return
self.lrs.append(self.opt.param_groups[-1]['lr'])
self.losses.append(self.loss.detach().cpu())
def plot_lr(self): plt.plot(self.lrs)
def plot_loss(self): plt.plot(self.losses)
ลองเทรน แล้วพล็อตกราฟ
model, loss_func = get_model()
opt = torch.optim.SGD(model.parameters(), lr=lr)
learn = Learner(model, opt, loss_func, data)
run = Runner(cb_funcs=Recorder)
run.fit(2, learn)
เราสามารถสั่งพล็อตกราฟ จาก Recorder Callback ได้ดังด้านล่าง
run.recorder.plot_lr()
สังเกตว่า Learning Rate คงที่ตลอด, Loss ลดลงอย่างรวดเร็ว แต่กราฟเหวี่ยงมาก เนื่องจาก SGD Algorithm
run.recorder.plot_loss()
6.2 AvgStatsCallback¶
ต่อกันที่ AvgStatsCallback เราจะเขียน Callback ที่จะทำหน้าที่คำนวน Metrics หลังจากที่คำนวน Loss ทุกครั้ง (after_loss) สะสมไว้ แล้วหาค่าเฉลี่ย ตอนจบ Epoch (after_epoch) แล้ว Print ออกมา
class AvgStatsCallback(Callback):
def __init__(self, metrics):
self.train_stats, self.valid_stats = AvgStats(metrics, True), AvgStats(metrics, False)
def begin_epoch(self):
self.train_stats.reset()
self.valid_stats.reset()
def after_loss(self):
stats = self.train_stats if self.in_train else self.valid_stats
with torch.no_grad(): stats.accumulate(self.run)
def after_epoch(self):
print(self.train_stats)
print(self.valid_stats)
สร้าง Class ที่เป็น Logic คำนวน
class AvgStats():
def __init__(self, metrics, in_train):
self.metrics, self.in_train = listify(metrics), in_train
def reset(self):
self.tot_loss, self.count = 0., 0
self.tot_mets = [0.] * len(self.metrics)
@property
def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
@property
def avg_stats(self): return [o/self.count for o in self.all_stats]
def __repr__(self):
if not self.count: return ""
return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
def accumulate(self, run):
bn = run.xb.shape[0]
self.tot_loss += run.loss * bn
self.count += bn
for i, m in enumerate(self.metrics):
self.tot_mets[i] += m(run.pred, run.yb) * bn
ลองเทส AvgStatsCallback โดยใช้ Metric ตัวเดียว ในที่นี้คือ accuracy
model, loss_func = get_model()
opt = torch.optim.SGD(model.parameters(), lr=lr)
learn = Learner(model, opt, loss_func, data)
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)
run.fit(5, learn)
loss, acc = stats.valid_stats.avg_stats
loss, acc
6.3 ใช้ Partial ช่วยสร้าง Callback¶
from functools import partial
ถ้าเราใช้ ลองเทส AvgStatsCallback โดยใช้แค่ accuracy บ่อย ๆ เราสามารถใช้ Partial สร้าง Function/Class ใหม่จาก AvgStatsCallback + Accuracy เตรียมไว้ได้เลย
acc_cbf = partial(AvgStatsCallback, accuracy)
ลองเทรนโมเดล โดยพาส Callback ใหม่ที่สร้างจาก Partial และ Recorder ทั้ง 2 Callback ใส่ List ไปให้ cb_funcs=[acc_cbf, Recorder]
model, loss_func = get_model()
opt = torch.optim.SGD(model.parameters(), lr=lr)
learn = Learner(model, opt, loss_func, data)
run = Runner(cb_funcs=[acc_cbf, Recorder])
run.fit(5, learn)
acc_cbf Callback ก็จะ Print Training Loss+Accuracy และ Validation Loss+Accuracy และเราสามารถสั่งพล็อตกราฟ จาก Recorder Callback ได้ดังด้านล่าง
run.avg_stats.valid_stats.avg_stats
run.recorder.plot_lr()
run.recorder.plot_loss()
Credit¶