DataLoader

This example shows how to re-implement the PyTorch DataLoader using SeqTools. The DataLoader is an iterable object that wraps an indexable dataset, adding shuffling, batching, prefetching and a few other operations relevant to a neural network training pipeline.

For the sake of clarity, this version does not support iterable datasets or infinite samplers, because that would require generating shuffling indices on the fly, which does not combine easily with prefetching. This limitation could be lifted by writing a different prefetching logic that follows a custom iterator-based index order; doing so would actually be fairly easy since the whole multi-process machinery lives in a separate class and can be reused.
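
The pipeline below only composes a handful of SeqTools primitives: smap for lazy element-wise transformations, gather for index-based reordering, batch for grouping samples into minibatches, and prefetch for background evaluation. As a quick warm-up, here is a minimal sketch of how these pieces fit together on a toy dataset (the values and parameters are made up purely for illustration):

import seqtools

data = list(range(10))  # a toy indexable dataset
reordered = seqtools.gather(data, [3, 1, 2, 0, 4, 5, 9, 8, 7, 6])  # lazy index-based reordering
squared = seqtools.smap(lambda x: x * x, reordered)  # lazy element-wise transformation
batched = seqtools.batch(squared, k=4, drop_last=True, collate_fn=list)  # group into minibatches
prefetched = seqtools.prefetch(batched, nworkers=2, method="thread", max_buffered=4)

print(list(prefetched))  # [[9, 1, 4, 0], [16, 25, 81, 64]]

The DataLoader below builds exactly this kind of pipeline, with multi-process prefetching, shared-memory transfers and memory pinning on top.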

"""An implementation of PyTorch's DataLoader to showcase seqtools."""

import copyreg
import numbers
from functools import partial

import torch

import seqtools


# override torch.Tensor pickling to benefit from zero-copy buffer transfers
def pickle_tensor(t: torch.Tensor):
    return torch.from_numpy, (t.contiguous().numpy(),)


def worker_init_fn_wrapper(user_fn, *kargs, **kwargs):
    copyreg.pickle(torch.Tensor, pickle_tensor)
    if user_fn is not None:
        user_fn(*kargs, **kwargs)


def pin_tensors_memory(value):
    """Pin memory of tensors inside an object."""
    if isinstance(value, (tuple, list)):
        return value.__class__(pin_tensors_memory(v) for v in value)
    elif isinstance(value, dict):
        return value.__class__((k, pin_tensors_memory(v)) for k, v in value.items())
    elif isinstance(value, torch.Tensor):
        return value.pin_memory()
    else:
        # leave values that are neither containers nor tensors untouched
        return value


def default_collate_fn(values):
    """Stack samples together into a minibatch."""
    if not isinstance(values, list):  # force evaluation if not done already
        values = list(values)

    sample = values[0]

    if isinstance(sample, torch.Tensor):
        return torch.stack(values)
    elif isinstance(sample, numbers.Integral):
        return torch.tensor(values)
    elif isinstance(sample, (tuple, list)):
        return sample.__class__(default_collate_fn(row) for row in zip(*values))
    elif isinstance(sample, dict):
        return sample.__class__(
            (k, default_collate_fn([v[k] for v in values])) for k in sample.keys()
        )
    else:
        raise TypeError(f"unsupported sample type: {type(sample)}")


def gather_items(a, items):
    return [a[i] for i in items]


class DataLoader:
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        worker_init_fn=None,
        prefetch_factor=2,
        shm_size=0,
    ):
        """Re-implementation of pytorch DataLoader using seqtools.

        Notable differences:

        - only datasets and samplers with a len() are supported, and shuffling
          indices are pre-computed before iterating.
        - shm_size specifies how much shared memory to allocate for zero-copy
          buffer transfers between the workers and the main process. That
          shared memory is divided into num_workers * prefetch_factor slots.
        - timeout is not implemented.
        """
        # sampling/batching
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.num_workers = num_workers
        self.collate_fn = collate_fn or default_collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.worker_init_fn = partial(worker_init_fn_wrapper, worker_init_fn)
        self.prefetch_factor = prefetch_factor
        self.shm_size = shm_size

    def __len__(self):
        if self.batch_sampler:
            return len(self.batch_sampler)

        dataset_size = len(self.sampler) if self.sampler else len(self.dataset)
        if self.batch_size is None:
            return dataset_size
        elif self.drop_last:
            return dataset_size // self.batch_size
        else:
            return (dataset_size + self.batch_size - 1) // self.batch_size

    def make_sequence(self):
        """Build a sequence that looks like a DataLoader when iterated over."""
        # shuffling
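        # batch_sampler takes precedence over sampler, which in turn takes precedence over shuffle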
        if self.batch_sampler:
            batch_indices = list(self.batch_sampler)
            out = seqtools.smap(partial(gather_items, self.dataset), batch_indices)
        elif self.sampler:
            shuffle_indices = list(self.sampler)
            out = seqtools.gather(self.dataset, shuffle_indices)
        elif self.shuffle:
            shuffle_indices = torch.randperm(len(self.dataset)).tolist()
            out = seqtools.gather(self.dataset, shuffle_indices)
        else:
            out = self.dataset

        # batch
        if not self.batch_sampler and self.batch_size is not None:
            out = seqtools.batch(
                out,
                k=self.batch_size,
                drop_last=self.drop_last,
                collate_fn=self.collate_fn,
            )
        elif self.batch_sampler:
            out = seqtools.smap(self.collate_fn, out)

        # prefetch
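        # evaluate upcoming batches in background worker processes; each worker
        # keeps roughly prefetch_factor batches in flight, and tensors travel
        # back through the shared memory pool when shm_size > 0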
        if self.num_workers > 0:
            out = seqtools.prefetch(
                out,
                max_buffered=max(4, self.num_workers * self.prefetch_factor),
                nworkers=self.num_workers,
                method="process",
                start_hook=self.worker_init_fn,
                shm_size=self.shm_size,
            )

        # pin memory
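        # page-locked (pinned) buffers enable faster, asynchronous host-to-GPU copies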
        if self.pin_memory:
            out = seqtools.smap(pin_tensors_memory, out)
            out = seqtools.prefetch(  # execute in background thread
                out, nworkers=1, method="thread", max_buffered=1
            )

        return out

    def __iter__(self):
        """Instantiate a new data pipeline and return an iterator over it."""
        return iter(self.make_sequence())

Sample usage:

import torch
from torchvision.datasets import FakeData
from torchvision import transforms as T

transform = T.Compose([
    T.Resize((256, 256)),
    T.ColorJitter(),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.ConvertImageDtype(torch.float),
])
dataset = FakeData(100, (320, 320), 10, transform=transform)

loader = DataLoader(
    dataset,
    num_workers=2,
    batch_size=8,
    shm_size=16777216,  # 16MB
)
for images, labels in loader:
    pass
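
When training on a GPU, one would typically also enable pin_memory so that batches are pinned in a background thread, and copy each batch to the device inside the loop. Here is a sketch of that variation (the non_blocking copies are standard PyTorch, and the snippet assumes a CUDA device is available):

loader = DataLoader(
    dataset,
    num_workers=2,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    shm_size=16777216,  # 16MB
)
for images, labels in loader:
    # pinned buffers make these copies asynchronous with respect to the host
    images = images.cuda(non_blocking=True)
    labels = labels.cuda(non_blocking=True)
    # ... forward/backward pass goes here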