ใน ep ก่อน ๆ เราได้เรียนรู้ Tensor การจัดการมิติ การเลือกข้อมูลด้วย indexing, slicing กันไปแล้ว ใน ep นี้ เราจะมาเรียนรู้การเลือกข้อมูล Tensor ที่ซับซ้อนยิ่งขึ้น ด้วย gather อ่านเอกสารแล้วอาจจะยังงง เรามาดูตัวอย่างกันเลยดีกว่า

ในเคสตัวอย่างเช่น Sequence Model ข้อมูลตัวอย่าง มี 3 มิติ เช่น BATCH_SIZE x MAX_SEQ_LEN x HIDDEN_STATE สมมติเราต้องการเลือก Hidden State ทั้งหมดของทั้ง Batch ที่ ณ Sequence ตัวสุดท้าย (Sequence แต่ละตัวไม่เท่ากัน แต่มี Padding ให้เท่ากับ MAX_SEQ_LEN) แต่ในเคสนี้เราไม่สามารถใช้การ indexing แบบปกติได้ เพราะตำแหน่งของ Sequence ตัวสุดท้ายไม่เท่ากัน โดยเราจะมีลิสต์ความยาวของแต่ละ Sequence ใน Batch (ลิสต์ของตำแหน่งของ Sequence ตัวสุดท้าย)

แล้วเราจะใช้วิธีไหนในการดึงข้อมูลใน Tensor ที่มีหลายมิติ ในเงื่อนไขที่ซับซ้อนแบบนี้ คำตอบ คือ gather

gather คืออะไร

tf Gather Gather slices from params axis axis according to indices. Credit https://www.tensorflow.org/api_docs/python/tf/gather
tf Gather Gather slices from params axis axis according to indices. Credit https://www.tensorflow.org/api_docs/python/tf/gather

gather คือ ฟังก์ชันที่ใช้ในการดึงค่าตามแกน (dim มิติ) ด้วย index ที่กำหนด

Gathers values along an axis specified by dim.

torch.gather(inputdimindexout=Nonesparse_grad=False) → Tensor

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather
  • out (Tensor, optional) – the destination tensor
  • sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.

อ่าน Help แล้วก็ยังงง และตัวอย่างก็น้อยไปนิดนึง เรามาดูตัวอย่างเพิ่มเติมทีละ Step จะเข้าใจง่ายกว่า

เรามาเริ่มกันเลยดีกว่า

แชร์ให้เพื่อน:

Surapong Kanoktipsatharporn on FacebookSurapong Kanoktipsatharporn on LinkedinSurapong Kanoktipsatharporn on Rss
Surapong Kanoktipsatharporn
Solutions Architect at Bua Labs
The ultimate test of your knowledge is your capacity to convey it to another.

Published by Surapong Kanoktipsatharporn

The ultimate test of your knowledge is your capacity to convey it to another.