DataLoader 和 Iterator 的差異

在 trace AI 課程中的訓練代碼時, 看到 iterator 和 DataLoader 同場出現時有點暈, 我想說這兩個做的事情不是一樣嗎? 所以花點時間將這個疑惑釐清.

先講結論: Iterator 是 design pattern, DataLoader 是 PyTorch 的 class.

Iterator 做為 design pattern 的好處是: 不用管 dataset 真正長什麼樣子, 每次都都抓出一包 data, 還用不到的先不抓, 等下次抓, 以節省記憶體.

那一包要抓多大呢? 透過 DataLoader 這個 class 取得 dataloader 這個 object.

然後 data_iterator 是個 iterator, iter() 把 dataloader 轉換成可迭代的 object.

first_batch 的 type 取決 Dataset, 主要就是一包 data. 而第一次呼叫 next() 拿到的是第一包, 而不是第二包.

from torch.utils.data import DataLoader, Dataset
...
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
....
data_iterator = iter(dataloader)
first_batch = next(data_iterator)

在 Tensorflow 中, 沒有 DataLoader, 但有 tf.data.Dataset. 效果也是一次抓一包.

import tensorflow as tf

# Create a basic dataset
dataset = tf.data.Dataset.range(10)

# Apply transformations (e.g., batch, shuffle)
batched_dataset = dataset.batch(2)

# Iterate through the dataset
for batch in batched_dataset:
    print(batch)

發表留言