u.lekhoni.com

Data Loading and Processing

PyTorch provides utilities for loading and processing data, which are especially useful when dealing with large datasets or when using GPUs for training.

Import Necessary Libraries

First, we need to import PyTorch, and also the torchvision package, which includes data loaders for standard datasets such as ImageNet, CIFAR10, MNIST, etc., and data transformers for images.


import torch
import torchvision
import torchvision.transforms as transforms

Define the Transformations

Before loading the data, we define the transformations we want to apply. For the sake of simplicity, we’ll just convert the data to tensors.

transform = transforms.Compose(
    [transforms.ToTensor()])

Load the Dataset

We’ll use the CIFAR10 dataset, which contains 60,000 32x32 color images in 10 different classes. The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

# Load CIFAR10 training data
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Load CIFAR10 testing data
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

Create a Data Loader

A data loader provides an iterator over the dataset. Here, we’ll create a data loader for the training data.

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

Similarly, for the testing data:

testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

Iterate Over the Data

We can now use the data loader in a loop. Here’s a simple example that prints the size of the images and labels.

for i, data in enumerate(trainloader, 0):
    # Get the inputs
    inputs, labels = data

    print(f'Input size: {inputs.size()}')
    print(f'Labels size: {labels.size()}')
    
    if i == 3:  # print sizes for first 4 batches and break
        break

In the loop, inputs are the input images, and labels are the labels for these images. We are printing the size of the inputs and labels for the first 4 batches only, just to avoid flooding your screen with output.

And that’s it! We have loaded the data, applied transformations to it, and iterated over it. This data is now ready to be used for training your models. In a real-world scenario, we would typically also include data augmentation transforms to improve the model’s performance.

Remember to experiment with these steps, try different transforms, datasets and data loaders to get a good understanding of these concepts.