DataLoader
is a PyTorch class that wraps a dataset to create an iterator, making it easy to shuffle samples and form batches for training.
from torch.utils.data import DataLoader
coll = range(15)
dl = DataLoader(coll, batch_size=5, shuffle=True)
list(dl), next(iter(dl))
([tensor([ 8, 12, 4, 1, 7]),
tensor([ 0, 3, 5, 13, 9]),
tensor([11, 6, 10, 14, 2])],
tensor([ 1, 5, 9, 2, 12]))
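Two other DataLoader parameters are worth knowing: shuffle defaults to False, and drop_last (default False) controls whether a final, smaller batch is kept. A quick sketch of both:

```python
from torch.utils.data import DataLoader

# 15 items / batch size 6 -> two full batches; drop_last discards the leftover 3
ordered_dl = DataLoader(range(15), batch_size=6, shuffle=False, drop_last=True)
print(list(ordered_dl))  # two batches: 0-5 and 6-11
```

With drop_last=False (the default) a third batch of the remaining 3 items would also be yielded.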
Two kinds of datasets can be supplied to a PyTorch DataLoader
: an iterable-style dataset or a map-style dataset.
An iterable-style dataset supports iteration over a data collection. The easiest way to check is to call iter()
on the dataset. For example, a range object is iterable:
type(coll), iter(coll)
(range, <range_iterator at 0x7fc7f8111950>)
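A custom iterable-style dataset can subclass torch.utils.data.IterableDataset; a minimal sketch (SquaresDataset is a made-up name for illustration):

```python
from torch.utils.data import DataLoader, IterableDataset

class SquaresDataset(IterableDataset):
    """Iterable-style: defines __iter__ instead of __getitem__/__len__."""
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return (i * i for i in range(self.n))

squares_dl = DataLoader(SquaresDataset(15), batch_size=5)
print(next(iter(squares_dl)))  # first batch of squares: 0, 1, 4, 9, 16
```

Note that shuffle=True is not supported for iterable-style datasets; with no indices to permute, DataLoader raises a ValueError.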
A map-style dataset represents a map from (possibly non-integral) indices/keys to data samples, i.e. key-value pairs. It implements the __getitem__()
and __len__()
methods.
import string
class AlphabetDataset:
    def __init__(self, letters):
        self.letters = list(letters)

    def __len__(self):
        return len(self.letters)

    def __getitem__(self, idx):
        return self.letters[idx]
dataset = AlphabetDataset(string.ascii_lowercase)
dataloader = DataLoader(dataset, batch_size=6, shuffle=True)
next(iter(dataloader))
['o', 'v', 'k', 's', 't', 'p']
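Because each sample here is a single character, the default collate function returns the batch as a list of strings. A collate_fn can reshape the batch; a sketch using a plain list (which is itself map-style, since lists implement __getitem__ and __len__):

```python
import string
from torch.utils.data import DataLoader

letters = list(string.ascii_lowercase)  # a plain list is map-style too
joined_dl = DataLoader(letters, batch_size=6, shuffle=False,
                       collate_fn="".join)  # merge each batch into one string
print(next(iter(joined_dl)))  # 'abcdef'
```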
Or, more simply, use a dictionary:
dt = dict(enumerate(string.ascii_lowercase))
dt.__getitem__(2), dt.__len__()
('c', 26)
dicLoader = DataLoader(dt, batch_size=6, shuffle=True)
next(iter(dicLoader))
['w', 's', 't', 'm', 'c', 'd']
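The dictionary above works with the default random sampler only because its keys happen to be the integers 0 through 25. For genuinely non-integral keys, a custom Sampler that yields the keys themselves is needed; a sketch (KeySampler is a made-up name for illustration):

```python
import random
import string
from torch.utils.data import DataLoader, Sampler

data = {c: c.upper() for c in string.ascii_lowercase}  # string keys

class KeySampler(Sampler):
    """Yields the dict's own keys in random order."""
    def __init__(self, keys):
        self.keys = list(keys)

    def __iter__(self):
        return iter(random.sample(self.keys, len(self.keys)))

    def __len__(self):
        return len(self.keys)

keyed_dl = DataLoader(data, batch_size=6, sampler=KeySampler(data))
next(iter(keyed_dl))  # a batch of six uppercase letters
```

DataLoader passes each key the sampler yields to the dataset's __getitem__, so the dict is indexed directly by its string keys.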