ใน ep ก่อน ๆ เราได้เรียนรู้ Tensor การจัดการมิติ การเลือกข้อมูลด้วย indexing, slicing กันไปแล้ว ใน ep นี้ เราจะมาเรียนรู้การเลือกข้อมูล Tensor ที่ซับซ้อนยิ่งขึ้น ด้วย gather อ่านเอกสารแล้วอาจจะยังงง เรามาดูตัวอย่างกันเลยดีกว่า
ในเคสตัวอย่างเช่น Sequence Model ข้อมูลตัวอย่าง มี 3 มิติ เช่น BATCH_SIZE x MAX_SEQ_LEN x HIDDEN_STATE สมมติเราต้องการเลือก Hidden State ทั้งหมดของทั้ง Batch ที่ ณ Sequence ตัวสุดท้าย (Sequence แต่ละตัวไม่เท่ากัน แต่มี Padding ให้เท่ากับ MAX_SEQ_LEN) แต่ในเคสนี้เราไม่สามารถใช้การ indexing แบบปกติได้ เพราะตำแหน่งของ Sequence ตัวสุดท้ายไม่เท่ากัน โดยเราจะมีลิสต์ความยาวของแต่ละ Sequence ใน Batch (ลิสต์ของตำแหน่งของ Sequence ตัวสุดท้าย)
แล้วเราจะใช้วิธีไหนในการดึงข้อมูลใน Tensor ที่มีหลายมิติ ในเงื่อนไขที่ซับซ้อนแบบนี้ คำตอบ คือ gather
gather คืออะไร

gather คือ ฟังก์ชันที่ใช้ในการดึงค่าตามแกน (dim มิติ) ด้วย index ที่กำหนด
Gathers values along an axis specified by dim.
torch.
gather
(input, dim, index, out=None, sparse_grad=False) → Tensor
- input (Tensor) – the source tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
- out (Tensor, optional) – the destination tensor
- sparse_grad (bool,optional) – If
True
, gradient w.r.t.input
will be a sparse tensor.
อ่าน Help แล้วก็ยังงง และตัวอย่างก็น้อยไปนิดนึง เรามาดูตัวอย่างเพิ่มเติมทีละ Step จะเข้าใจง่ายกว่า