จากใน ep ที่แล้วเราได้เรียนรู้การใช้งาน PyTorch Hook ใน ep นี้เราจะมา Refactor โค้ดสร้าง Class ขึ้นมาจัดการ Hook และใช้ Hook สถิติ ที่ลึกมากขึ้น

เราจะวิเคราะห์กราฟ Mean, Std และ Histogram จะเห็นว่าค่อนข้าง Converge เร็ว ไม่เกิด Vanishing Gradient เนื่องจาก PyTorch ได้แก้ปัญหาไปแล้วด้วย Kaiming Initialization แต่ก็ยังมีปัญหาอื่น ๆ อยู่ในช่วงแรก ๆ แล้วเราจะแก้ปัญหานี้อย่างไร

Mean, Standard Deviation of Activation Map of Convolutional Neural Network
Mean, Standard Deviation of Activation Map of Convolutional Neural Network

GeneralReLU คืออะไร

เราจะแก้ปัญหานี้ ด้วย ReLU แบบใหม่ ที่ลดความไม่ Balance เนื่องจากค่าลบหายไปหมด, มีการ Cap Activation ไม่ให้เกินค่าที่กำหนด และมีการใช้ Leaky ReLU ในส่วนที่ติดลบด้วย

รวมเรียกว่า GeneralReLU เราจะเปลี่ยน ReLU Activation Function ใน Convolutaional Neural Network ของเราเป็น GeneralReLU ทั้งหมด เพื่อให้เทรนโมเดลให้ Converge ได้เร็วขึ้น เราจะเริ่มต้นที่หัวข้อ 6.2 Hook

เรามาเริ่มกันเลยดีกว่า

Open In Colab

จากใน ep ที่แล้วเราได้เรียนรู้การใช้งาน Hook ใน ep นี้เราจะมา Refactor โค้ดสร้าง Class ขึ้นมาจัดการ Hook และใช้ Hook สถิติ ที่ลึกมากขึ้น วิเคราะห์ แล้วแก้ปัญหาเพื่อให้เทรนให้ Converge ได้เร็วขึ้น

เราจะเริ่มต้นที่หัวข้อ 6.2 Hook

0. Magic

In [0]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

1. Import

In [0]:
import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import *
import fastai
from fastai import datasets
from fastai.metrics import accuracy
from fastai.basic_data import *
from fastai.basic_train import *
import pickle, gzip, math, torch, re
from IPython.core.debugger import set_trace
import matplotlib.pyplot as plt
from functools import partial

2. Data

In [0]:
class Dataset(Dataset):
    def __init__(self, x, y, c):
        self.x, self.y, self.c = x, y, c
    def __len__(self):
        return len(self.x)
    def __getitem__(self, i):
        return self.x[i], self.y[i]
In [0]:
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'
In [0]:
def get_data():
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(tensor, (x_train, y_train, x_valid, y_valid))
In [0]:
x_train, y_train, x_valid, y_valid = get_data()

3. Data Preprocessing

In [0]:
def normalize(x, m, s): 
    return (x-m)/s
In [0]:
from typing import *

def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
In [0]:
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()   
In [0]:
train_mean, train_std = x_train.mean(), x_train.std()
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)
In [0]:
bs = 256
n, m = x_train.shape
c = (y_train.max()+1).numpy()
loss_func = F.cross_entropy
In [0]:
train_ds, valid_ds = Dataset(x_train, y_train, c), Dataset(x_valid, y_valid, c)
In [0]:
data = DataBunch.create(train_ds, valid_ds, bs=bs, num_workers=8)

4. Model

In [0]:
def get_cnn_model(data):
    return nn.Sequential(
        # Lambda(mnist_resize),  # use BatchTransformXCallback instead.
        nn.Conv2d(  1,  8, 5, padding=2, stride=2), nn.ReLU(), #14
        nn.Conv2d(  8, 16, 3, padding=1, stride=2), nn.ReLU(), # 7
        nn.Conv2d( 16, 32, 3, padding=1, stride=2), nn.ReLU(), # 4
        nn.Conv2d( 32, 32, 3, padding=1, stride=2), nn.ReLU(), # 2
        nn.AdaptiveAvgPool2d(1), 
        Lambda(flatten), 
        nn.Linear(32, data.c)
    )
In [0]:
def init_cnn(m, uniform=False):
    f = init.kaiming_uniform_ if uniform else init.kaiming_normal_
    for l in m:
        if isinstance(l, nn.Sequential):
            f(l[0].weight, a=0.1)
            l[0].bias.data.zero_()

วิธีการเทรน Callback, Training Loop เหมือนปกติ เริ่มเทรนใน หัวข้อ 7. Train

In [0]:
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func
    def forward(self, x): return self.func(x)
In [0]:
def flatten(x): return x.view(x.shape[0], -1)
# def mnist_resize(x): return x.view(-1, 1, 28, 28) # use BatchTransformXCallback instead.

5. Training Loop

แล้วเราจะเทรนตามปกติ

In [0]:
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.stop, self.cbs = False, [TrainEvalCallback()]+cbs

    @property
    def opt_func(self):     return self.learn.opt_func
    @property
    def model(self):        return self.learn.model
    @property
    def loss_func(self):    return self.learn.loss_func
    @property
    def data(self):         return self.learn.data

    def one_batch(self, xb, yb):
        try: 
            self.xb, self.yb = xb, yb
            self('begin_batch')
            self.pred = self.model(self.xb)
            self('after_pred')
            self.loss = self.loss_func(self.pred, self.yb)
            self('after_loss')
            if not self.in_train: return
            self.loss.backward()
            self('after_backward')
            self.opt_func.step()
            self('after_step')
            self.opt_func.zero_grad()
        except CancelBatchException: self('after_cancel_batch')
        finally: self('after_batch')
    
    def all_batches(self, dl):
        self.iters = len(dl)
        try:
            for xb, yb in dl:
                self.one_batch(xb, yb)
        except CancelEpochException: self('after_cancel_epoch')
    
    def fit(self, epochs, learn):
        self.epochs, self.learn, self.loss = epochs, learn, tensor(0.)

        try:
            for cb in self.cbs: cb.set_runner(self)
            self('begin_fit')
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                self('after_epoch')
        except CancelTrainException: self('after_cancel_train')
        finally: 
            self('after_fit')
            self.train = None

    def __call__(self, cb_name):
        # return True = Cancel, return False = Continue (Default)
        res = False
        # check if at least one True return True
        for cb in sorted(self.cbs, key=lambda x: x._order): res = res or cb(cb_name)
        return res        

6. Callback and Hook

6.1 Callback

In [0]:
class Callback():
    _order = 0
    def set_runner(self, run): self.run = run
    def __getattr__(self, k): return getattr(self.run, k)

    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
    
    def __call__(self, cb_name):
        f = getattr(self, cb_name, None)
        if f and f(): return True
        return False

class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs = 0.
        self.run.n_iter = 0
    
    def begin_epoch(self):
        self.run.n_epochs = self.epoch  
        self.model.train()
        self.run.in_train=True

    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters
        self.run.n_iter += 1

    def begin_validate(self):
        self.model.eval()
        self.run.in_train=False    
           
class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass
In [0]:
class Recorder(Callback):
    def begin_fit(self): 
        self.lrs = [[] for _ in self.opt_func.param_groups]
        self.losses = []

    def after_batch(self):
        if not self.in_train: return
        for pg, lr in zip(self.opt_func.param_groups, self.lrs): lr.append(pg['lr'])
        self.losses.append(self.loss.detach().cpu())
    
    def plot_lr(self, pgid=-1): plt.plot(self.lrs[pgid])
    def plot_loss(self, skip_last=0): plt.plot(self.losses[:len(self.losses)-skip_last])
    def plot(self, skip_last=0, pgid=-1):
        losses = [o.item() for o in self.losses]
        lrs = self.lrs[pgid]
        n = len(losses)-skip_last
        plt.xscale('log')
        plt.plot(lrs[:n], losses[:n])
In [0]:
class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats, self.valid_stats = AvgStats(metrics, True), AvgStats(metrics, False)
    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
    
    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)

    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)

class AvgStats():
    def __init__(self, metrics, in_train): 
        self.metrics, self.in_train = listify(metrics), in_train
    
    def reset(self):
        self.tot_loss, self.count = 0., 0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
    @property
    def avg_stats(self): return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        
        for i, m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn
In [0]:
class ParamScheduler(Callback):
    _order = 1
    def __init__(self, pname, sched_funcs): self.pname, self.sched_funcs = pname, sched_funcs

    def begin_fit(self): 
        if not isinstance(self.sched_funcs, (list, tuple)): 
            self.sched_funcs = [self.sched_funcs] * len(self.opt_func.param_groups)

    def set_param(self):
        assert len(self.opt_func.param_groups) == len(self.sched_funcs)
        for pg, f in zip(self.opt_func.param_groups, self.sched_funcs):
            pg[self.pname] = f(self.n_epochs/self.epochs)

    def begin_batch(self):
        if self.in_train: self.set_param()
In [0]:
def annealer(f):
    def _inner(start, end): return partial(f, start, end)
    return _inner 
In [0]:
@annealer
def sched_lin(start, end, pos): return start + pos * (end - start)
@annealer
def sched_cos(start, end, pos): return start + (1 + math.cos(math.pi*(1-pos))) * (end-start) / 2
In [0]:
torch.Tensor.ndim = property(lambda x: len(x.shape))
In [0]:
def combine_scheds(pcts, scheds):
    assert sum(pcts) == 1.
    pcts = tensor([0] + listify(pcts))
    assert torch.all(pcts >= 0)
    pcts = torch.cumsum(pcts, 0)

    def _inner(pos):
        idx = (pos >= pcts).nonzero().max()
        actual_pos = (pos-pcts[idx]) / (pcts[idx+1]-pcts[idx])
        return scheds[idx](actual_pos)
    return _inner    
In [0]:
max_lr = 3e-1
sched = combine_scheds([0.3, 0.7], [sched_cos(3e-3, max_lr), sched_cos(max_lr, 3e-4)])
In [0]:
class CudaCallback(Callback):
    def begin_fit(self): self.model.cuda()
    def begin_batch(self): self.run.xb, self.run.yb = self.xb.cuda(), self.yb.cuda()

เพื่อความยืดหยุ่น เราจะใช้ Callback ในการแปลง MNIST ที่มาในรูปแบบ Matrix 2 มิติ ให้เป็น 4 มิติตามที่โมเดลต้องการ แทนที่จะ Fix เป็น Layer ภายในโมเดล

In [0]:
class BatchTransformXCallback(Callback):
    _order = 2
    def __init__(self, tfm): self.tfm = tfm
    def begin_batch(self): 
        # set_trace()
        self.run.xb = self.tfm(self.xb)

def view_tfm(*size):
    # set_trace()
    def _inner(x): return x.view(*(-1, )+size)
    return _inner
In [0]:
mnist_view = view_tfm(1, 28, 28)

ในเคสนี้ เราจะไม่ใช้ ParamScheduler จะได้เห็นชัด ๆ

In [0]:
# cbfs = [Recorder, CudaCallback, partial(ParamScheduler, 'lr', sched), partial(BatchTransformXCallback, mnist_view), partial(AvgStatsCallback, accuracy)]
cbfs = [Recorder, CudaCallback, partial(BatchTransformXCallback, mnist_view), partial(AvgStatsCallback, accuracy)]

6.2 Hook

เราจะสร้าง Class มาห่อหุ้ม Hook เอาไว้ แทนที่ฟังก์ชันแบบเดิม จะได้ Maintain ง่ายขึ้น มีการลบทิ้งเมื่อใช้เสร็จ และเพิ่มการเก็บสถิติ Histogram ของ Activation Map อีก 1 ตัว

In [0]:
def children(m): list(m.children())

class Hook():
    # m = module, f = function
    def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self))
    def remove(self): self.hook.remove()
    def __del__(self): self.remove()

def append_stats(hook, mod, inp, outp):
    if not hasattr(hook, 'stats'): hook.stats = ([], [], [])
    means, stds, hists = hook.stats 
    means.append(outp.data.mean().cpu())
    stds.append(outp.data.std().cpu())
    # Tensor.histc(bins, min, max)
    hists.append(outp.data.cpu().histc(40, 0, 10))

เราจะสร้าง Class ใหม่ขึ้นมาใช้แทน Numpy Array แต่ใส่คุณสมบัติเฉพาะที่เราต้องการเข้าไปด้วย

In [0]:
class ListContainer():
    def __init__(self, items): self.items = listify(items)
    def __getitem__(self, idx): 
        if isinstance(idx, (int, slice)): return self.items[idx]
        if isinstance(idx[0], bool):
            assert len(idx) == len(self) # boolean mask
            return [o for m, o in zip(idx, self.items) if m]
        return [self.items[i] for i in idx]
    def __len__(self): return len(self.items)
    def __iter__(self): return iter(self.items)
    def __setitem__(self, i, o): self.items[i] = o
    def __delitem__(self, i): del(self.items[i])
    def __repr__(self):
        res = f'{self.__class__.__name__} ({len(self)} items)\n{self.items[:10]}'
        if len(self) > 10: res = res[:-1] + '...]'
        return res

ลองเทสการใช้งาน ListContainer

In [34]:
ListContainer(range(10))
Out[34]:
ListContainer (10 items)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
In [35]:
ListContainer(range(1000))
Out[35]:
ListContainer (1000 items)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9...]
In [36]:
t = ListContainer(range(10))
t[[2, 3]], t[[False] * 8 + [True, False]]
Out[36]:
([2, 3], [8])

ประกาศ Class สำหรับเก็บ Hook ทั้งหมดของโมเดล สังเกต enter, exit dunder สำหรับใช้ใน with statement เพื่อที่จะ Ensure ว่า Hook จะถูก Remove หลังจากออกจาก Statement ที่กำหนด

In [0]:
class Hooks(ListContainer):
    def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms])
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
    def __del__(self): self.remove()

    def __delitem__(self, i):
        self[i].remove()
        super().__delitem__(i)
    
    def remove(self):
        for h in self: h.remove()

7. Train

In [0]:
model = get_cnn_model(data)
init_cnn(model)
opt = torch.optim.SGD(model.parameters(), lr=max_lr)
learn = Learner(data, model, opt, loss_func=loss_func)

run = Runner(cb_funcs=cbfs)

8. Interpret

8.1 Mean และ Standard Deviation

เทรนไป 1 Epoch ก็พอ เนื่องจากเราต้องการจะดู Mean, Standard Deviation ของ Activation Map ของแต่ละ Layer เท่านั้น

In [39]:
with Hooks(model, append_stats) as hooks: 

    run.fit(1, learn)

    # Plot first 10 iterations.
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
    for h in hooks:
        # ms = means, ss = stds, _ = histogram
        ms, ss, _ = h.stats
        ax0.plot(ms[:10])
        ax1.plot(ss[:10])
    plt.legend(range(6))

    # Plot all
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
    for h in hooks:
        ms, ss, _ = h.stats
        ax0.plot(ms)
        ax1.plot(ss)
    plt.legend(range(6))
        
train: [1.4220355694110578, tensor(0.4962, device='cuda:0')]
valid: [0.265903759765625, tensor(0.9224, device='cuda:0')]

กราฟด้านบน แสดง Mean, Std Zoom 10 Iteration แรก กราฟล่างแสดงทั้งหมด

8.2 Histogram

เราจะมาวิเคราะห์ Histogram กันต่อว่า Activaiton Map 4 Layer แรก มีค่าในช่วงไหนกันบ้าง

In [0]:
def get_hist(h): return torch.stack(h.stats[2]).t().float().log1p()
In [41]:
fig, axes = plt.subplots(2, 2, figsize=(15, 6))
for ax, h in zip(axes.flatten(), hooks[:4]):
    ax.imshow(get_hist(h), origin='lower')
    ax.axis('off')
plt.tight_layout()

จะเห็นว่าค่อนข้าง Converge เร็ว ไม่เกิด Vanishing Gradient เนื่องจาก PyTorch ได้แก้ปัญหาไปแล้วด้วย Kaiming Initialization แต่ก็ยังมีปัญหาอื่น ๆ อยู่ในช่วงแรก ๆ แล้วเราจะแก้ปัญหานี้อย่างไร

8.3 ค่า Min ของ Activation

เราจะมาดู 2 Bin แรก ที่มีค่าน้อยที่สุดของ Histogram ว่ามีค่าเฉลี่ยเป็นสัดส่วนเท่าไรเทียบกับ Activation ทั้งหมด และโมเดลใช้เวลาในการเทรน กี่ Iteration ถึงจะเข้าที่ มีค่า Activation Map มากขึ้น (กราฟ 2 Bin แรกนี้ลดลง)

In [0]:
def get_min(h):
    # stats = mean, std, histogram, stats[2] = hists
    h1 = torch.stack(h.stats[2]).t().float()
    return h1[:2].sum(0)/h1.sum(0)

เราจะดูเฉพาะ Hook ของ 4 Layer แรก

In [43]:
fig, axes = plt.subplots(2, 2, figsize=(15, 6))
for ax, h in zip(axes.flatten(), hooks[:4]):
    ax.plot(get_min(h))
    ax.set_ylim(0, 1)
plt.tight_layout()

มี Activation Map ที่อยู่ 2 Bin แรก มีสัดส่วนมากเกินไป และเทรนไปเท่าไรก็ยังมากเหมือนเดิม

4/2. Model

จาก ep ก่อน ๆ ที่เราพบว่า ReLU มีปัญหา ทำให้ Activation Map ไม่ Balance เพราะทำให้ค่าติดลบหายไปหมด เราจะสร้าง Class ใหม่มาแก้ปัญหานี้ ชื่อว่า GeneralReLU ที่รับ Parameter leak = ค่าในช่วงติดลบ, sub = จะลบ Output ด้วยค่าเท่าไรเพื่อเลื่อน Mean, maxv = จะ Cap Output มากที่สุดไม่เกินเท่าไร

In [0]:
class GeneralRelu(nn.Module): 
    def __init__(self, leak=None, sub=None, maxv=None):
        super().__init__()
        self.leak, self.sub, self.maxv = leak, sub, maxv
    def forward(self, x):
        x = F.leaky_relu(x, self.leak) if self.leak is not None else F.relu(x)
        if self.sub is not None: x.sub_(self.sub)
        if self.maxv is not None: x.clamp_max_(self.maxv)
        return x

เนื่องจาก GeneralReLU ทำให้ Activation Map ของเรามีค่าเป็นลบได้ เราจึงต้องเปลี่ยน Histogram จาก 0-10 ให้เป็นช่วง -10-10

In [0]:
def append_stats(hook, mod, inp, outp):
    if not hasattr(hook, 'stats'): hook.stats = ([], [], [])
    means, stds, hists = hook.stats 
    means.append(outp.data.mean().cpu())
    stds.append(outp.data.std().cpu())
    # Tensor.histc(bins, min, max)
    hists.append(outp.data.cpu().histc(80, -10, 10))
In [0]:
def get_cnn_model2(data, leak, sub, maxv):
    return nn.Sequential(
        # Lambda(mnist_resize),  # use BatchTransformXCallback instead.
        nn.Conv2d(  1,  8, 5, padding=2, stride=2), GeneralRelu(leak, sub, maxv), #14
        nn.Conv2d(  8, 16, 3, padding=1, stride=2), GeneralRelu(leak, sub, maxv), # 7
        nn.Conv2d( 16, 32, 3, padding=1, stride=2), GeneralRelu(leak, sub, maxv), # 4
        nn.Conv2d( 32, 32, 3, padding=1, stride=2), GeneralRelu(leak, sub, maxv), # 2
        nn.AdaptiveAvgPool2d(1), 
        Lambda(flatten), 
        nn.Linear(32, data.c)
    )

7/2 Train

สร้างโมเดล Convolutional Neural Network ที่ใช้ ReLU พิเศษ ที่เราสร้างขึ้น โดย leak = 0.1, sub=0.4, maxv ไม่เกิน 6.0

In [0]:
model = get_cnn_model2(data, leak=0.1, sub=0.4, maxv=6.)
init_cnn(model)
opt = torch.optim.SGD(model.parameters(), lr=max_lr)
learn = Learner(data, model, opt, loss_func=loss_func)

run = Runner(cb_funcs=cbfs)

8/2 Intepret new General ReLU

เทรนโมเดล และพล็อตกราฟ นำมาวิเคราะห์ ตีความ

In [48]:
with Hooks(model, append_stats) as hooks: 

    run.fit(1, learn)

    # Plot first 10 iterations.
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
    for h in hooks:
        # ms = means, ss = stds, _ = histogram
        ms, ss, _ = h.stats
        ax0.plot(ms[:10])
        ax1.plot(ss[:10])
    plt.legend(range(6))

    # Plot all
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
    for h in hooks:
        ms, ss, _ = h.stats
        ax0.plot(ms)
        ax1.plot(ss)
    plt.legend(range(6))
        
train: [1.0515426244491186, tensor(0.6456, device='cuda:0')]
valid: [0.2699013671875, tensor(0.9193, device='cuda:0')]

Histogram สามารถอยู่ในช่วงลบได้แล้ว เนื่องจาก GeneralReLU เราให้ sub=0.4

In [49]:
fig, axes = plt.subplots(2, 2, figsize=(15, 6))
for ax, h in zip(axes.flatten(), hooks[0:4]):
    ax.imshow(get_hist(h), origin='lower')
    ax.axis('off')
plt.tight_layout()

เนื่องจาก Histogram กว้างขึ้น เราต้องปรับ min เป็น Bin ที่ 38-42 ถึงจะเป็นค่าที่เข้าใกล้ 0 หมายถึงค่าที่น้อยที่สุด

In [0]:
def get_min(h):
    # stats = mean, std, histogram, stats[2] = hists
    h1 = torch.stack(h.stats[2]).t().float()
    return h1[38:42].sum(0)/h1.sum(0)
In [51]:
fig, axes = plt.subplots(2, 2, figsize=(15, 6))
for ax, h in zip(axes.flatten(), hooks[0:4]):
    ax.plot(get_min(h))
    ax.set_ylim(0, 1)
plt.tight_layout()

จะเห็นว่า GeneralReLU ลดสัดส่วน Activation ที่มีค่าน้อย ลดลง ช่วยให้ โมเดล Converge เร็วขึ้น ลด Activation Map แปลก ๆ ช่วยแรกให้น้อยลง

9. สรุป

  • เราได้สร้าง Class ขึ้นมาจัดการสร้าง Hook เมื่อต้องการใช้ และทำลาย เมื่อเราใช้เสร็จ
  • เราใช้ Hook เก็บสถิติ Mean, Std, Histogram ของ Activation Map ภายในโมเดล ขณะ Feedforward
  • การวิเคราะห์ Activation Map ในแง่มุมต่าง ๆ ช่วยให้เราเข้าใจการทำงานภายในโมเดล และมองเห็นปัญหาที่เกิดขึ้นระหว่างการเทรนได้ดีขึ้น

Credit

แชร์ให้เพื่อน:

Keng Surapong on FacebookKeng Surapong on GithubKeng Surapong on Linkedin
Keng Surapong
Project Manager at Bua Labs
The ultimate test of your knowledge is your capacity to convey it to another.

Published by Keng Surapong

The ultimate test of your knowledge is your capacity to convey it to another.

Enable Notifications    OK No thanks