PyTorch | Data Augmentation

Python libraries for data augmentation

Data augmentation is a technique for increasing the number of training examples by creating modified copies of the ones you already have. The additional examples should ideally have the same, or a "close", data distribution as the original data.

Imagine your initial data is 100 images. You can augment it by creating 50 more images that are similar to the original 100, for example flipped or cropped versions of them.

With more data, you have a better chance of training a model that performs well.
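A minimal NumPy sketch of the 100 to 150 idea, assuming random arrays stand in for real images and using a horizontal flip as the augmentation (just one possible choice):

```python
import numpy as np

# Hypothetical stand-in for a small image dataset: 100 RGB images of 32x32 pixels,
# stored as one array of shape (N, H, W, C)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)

# Pick 50 of the originals at random and flip them horizontally by
# reversing the width axis; a mirrored photo is usually still a valid example
idx = rng.choice(len(images), size=50, replace=False)
flipped = images[idx][:, :, ::-1, :]

# Original + augmented images form the larger training set
augmented = np.concatenate([images, flipped], axis=0)
print(augmented.shape)  # (150, 32, 32, 3)
```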

There are several Python image libraries we may use for image augmentation:

  • OpenCV (cv2)
  • PIL
  • matplotlib
  • scikit-image
  • pgmagick
  • numpy
  • SciPy

Reading images

Example: Reading the image using CV2 and displaying using matplotlib

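%matplotlib inline
from matplotlib import pyplot as plt
import cv2
img = cv2.imread(r"C:\Users\dj\Desktop\img.jpg") # numpy.ndarray, channels in B, G, R order
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # matplotlib expects R, G, B
_ = plt.imshow(img)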


Example: Reading the image using Pillow and displaying using matplotlib

%matplotlib inline
from PIL import Image
from matplotlib import pyplot as plt
img = Image.open(r"C:\Users\dj\Desktop\img.jpg") # raw string, so the backslashes are not escapes
_ = plt.imshow(img)


Example: Reading and displaying images using matplotlib

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
img = mpimg.imread(r"C:\Users\dj\Desktop\img.jpg")
_ = plt.imshow(img)
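Whichever reader you use, plt.imshow() receives an image laid out as height × width × channels. A small sketch, assuming an in-memory PIL image instead of the file on disk, shows how a PIL image maps to that NumPy layout:

```python
import numpy as np
from PIL import Image

# Hypothetical in-memory image standing in for img.jpg: 64 pixels wide, 48 pixels high
pil_img = Image.new("RGB", (64, 48), color=(255, 0, 0))

# Converting to a NumPy array gives the (height, width, channels) layout imshow expects
arr = np.asarray(pil_img)
print(arr.shape, arr.dtype)  # (48, 64, 3) uint8

# PIL reports size as (width, height), the reverse of the array's first two axes
print(pil_img.size)  # (64, 48)
```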


Images as Tensors

Example: Converting an image to a tensor and back

%matplotlib inline
from io import BytesIO

import matplotlib.pyplot as plt
import requests
import torchvision
from PIL import Image

url = ""

response = requests.get(url)
img = Image.open(BytesIO(response.content)) # decode the downloaded bytes as an image
display(img) # IPython/Jupyter: display a Python object in all frontends

to_tensor = torchvision.transforms.ToTensor() # PIL image -> tensor of shape (C, H, W)
from_tensor = torchvision.transforms.ToPILImage() # tensor -> PIL image
t = from_tensor(to_tensor(img))
plt.imshow(t) # matplotlib show


Cropping images

Cropping uses nothing more than regular NumPy indexing. Likewise, im = im[:,:,::-1] reverses the non-spatial (channel) dimension, turning OpenCV's B, G, R order into R, G, B.

Example: Cropping the image using CV2

import cv2
from matplotlib import pyplot as plt

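im = cv2.imread(r"C:\Users\dj\Desktop\img.jpg") # numpy.ndarray, channels in B, G, R order
im = im[:,:,::-1] # reverse the channel axis: BGR -> RGB
h, w, c = im.shape
y, x = h//4, w//4 # top-left corner of the crop, chosen here so the crop is centered
crop_im = im[y:y+h//2, x:x+w//2] # keep half the height and half the width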
_ = plt.imshow(crop_im)
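The same indexing can be checked without a file on disk. A sketch assuming a tiny synthetic array in place of img.jpg, filled the way cv2.imread would store a pure blue image:

```python
import numpy as np

# Hypothetical 4x6 image (height 4, width 6) in OpenCV's B, G, R channel order
im = np.zeros((4, 6, 3), dtype=np.uint8)
im[..., 0] = 255                     # channel 0 is blue in BGR order

h, w, c = im.shape
y, x = h // 4, w // 4                # top-left corner so the crop is centered
crop_im = im[y:y + h // 2, x:x + w // 2]
print(crop_im.shape)                 # (2, 3, 3): half the height, half the width

rgb = im[:, :, ::-1]                 # reverse the channel axis: BGR -> RGB
print(rgb[0, 0])                     # [  0   0 255]: blue is now the last channel
```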


Working with channels

Example: Split and merge image channels using CV2

%matplotlib inline
import cv2
from matplotlib import pyplot as plt

p = r"C:\Users\dj\Desktop\img.jpg"
img = cv2.imread(p) # OpenCV loads the channels in B, G, R order
b, g, r = cv2.split(img) # one 2-D array per channel
img = cv2.merge((r, g, b)) # merge back in R, G, B order for matplotlib
_ = plt.imshow(img)
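Splitting and merging in a different order is just a channel permutation. A NumPy sketch of the same idea, assuming a tiny synthetic BGR array instead of a file (the NumPy slicing and np.stack stand in for cv2.split and cv2.merge):

```python
import numpy as np

# Hypothetical 2x2 image in B, G, R order with a distinct value per channel
img = np.zeros((2, 2, 3), dtype=np.uint8)
img[..., 0], img[..., 1], img[..., 2] = 10, 20, 30   # B, G, R

# Like cv2.split: take each channel as its own 2-D array
b, g, r = img[..., 0], img[..., 1], img[..., 2]

# Like cv2.merge((r, g, b)): stack the channels back along the last axis
rgb = np.stack((r, g, b), axis=-1)

# Merging in R, G, B order gives the same result as reversing the channel axis
print(np.array_equal(rgb, img[:, :, ::-1]))  # True
```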



Example: Two images into one using CV2

%matplotlib inline
import cv2
from matplotlib import pyplot as plt

img1 = cv2.imread(r'C:\Users\dj\Desktop\img.jpg')
img1 = img1[:,:,::-1]

img2 = cv2.imread(r'C:\Users\dj\Desktop\img2.jpg')
img2 = img2[:,:,::-1]

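# addWeighted needs both images to have the same size and type
cimg1 = cv2.resize(img1,(450,450))
cimg2 = cv2.resize(img2,(450,450))

# dst = 0.5*cimg1 + 0.5*cimg2 + 0: an even blend of the two images
dst = cv2.addWeighted(cimg1,0.5,cimg2,0.5,0)
_ = plt.imshow(dst)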
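cv2.addWeighted computes a per-pixel weighted sum, dst = src1·alpha + src2·beta + gamma, saturated to the uint8 range. The sketch below approximates that formula in NumPy; the helper add_weighted is hypothetical (not part of OpenCV) and the images are synthetic constant arrays:

```python
import numpy as np

def add_weighted(src1, alpha, src2, beta, gamma):
    """Hypothetical NumPy approximation of cv2.addWeighted for uint8 images:
    src1*alpha + src2*beta + gamma, rounded and clipped to 0..255."""
    dst = src1.astype(np.float64) * alpha + src2.astype(np.float64) * beta + gamma
    return np.clip(np.rint(dst), 0, 255).astype(np.uint8)

# Two hypothetical images of the same size (the reason the example resizes both)
img1 = np.full((4, 4, 3), 100, dtype=np.uint8)
img2 = np.full((4, 4, 3), 200, dtype=np.uint8)

dst = add_weighted(img1, 0.5, img2, 0.5, 0)
print(dst[0, 0])  # [150 150 150]: halfway between the two images
```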


Using permute()

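The notebook also shows how to display an image tensor with the permute() method: matplotlib's pyplot.imshow() expects the channel dimension last (H, W, C), while PyTorch stores images channels first (C, H, W).

# imgs: a batch of image tensors of shape (N, C, H, W), e.g. from a DataLoader
print("img shape: ", imgs[0].shape)
img = imgs[0]
img = img.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
_ = plt.imshow(img)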
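A runnable sketch of the permute() step, assuming a random tensor batch in place of real imgs:

```python
import torch

# Hypothetical batch of 4 RGB images, 32x32, in PyTorch's (N, C, H, W) layout
imgs = torch.rand(4, 3, 32, 32)

img = imgs[0]                 # (C, H, W) = (3, 32, 32)
hwc = img.permute(1, 2, 0)    # reorder to (H, W, C) for plt.imshow
print(hwc.shape)              # torch.Size([32, 32, 3])

# permute only reorders dimensions; the pixel values are unchanged
print(torch.equal(hwc[5, 7], img[:, 5, 7]))  # True
```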

Saving images

Example: Saving PIL image

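The notebook also shows the PIL Image save() method for saving images. The code below saves every image of the dataset ds into a folder named after its class.

from pathlib import Path

# ds: a dataset yielding (PIL image, class index) pairs, e.g. CIFAR-10 without a transform

# create the imgs folder
Path('/content/imgs').mkdir(parents=True, exist_ok=True)

# create folders for 10 classes
for c in range(10):
    Path(f'/content/imgs/{c}').mkdir(parents=True, exist_ok=True)

# save all the images to the folder of their class
for idx, (img, cls) in enumerate(ds):
    img.save(f'/content/imgs/{cls}/{idx}.png')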
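The same folder-per-class layout can be tried on a small scale. This sketch assumes a hypothetical list of synthetic images with 3 classes in place of ds, and a temporary directory instead of /content/imgs:

```python
import tempfile
from pathlib import Path
from PIL import Image

# Hypothetical stand-in for ds: (PIL image, class index) pairs for 3 classes
ds = [(Image.new("RGB", (8, 8), color=(40 * i, 0, 0)), i % 3) for i in range(6)]

root = Path(tempfile.mkdtemp()) / "imgs"       # temporary folder instead of /content/imgs
for c in range(3):
    (root / str(c)).mkdir(parents=True, exist_ok=True)

for idx, (img, cls) in enumerate(ds):
    img.save(root / str(cls) / f"{idx}.png")    # format inferred from the .png extension

# Each class folder now holds the images of that class
print(sorted(p.name for p in (root / "0").iterdir()))  # ['0.png', '3.png']
```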

Image augmentation via transforms

PyTorch provides the torchvision.transforms module. Its transforms modify the original images, and applying random ones while training is known as image augmentation.

You apply the transforms by passing them to the transform parameter when you create the Dataset.

We already showcased this example:

import torch
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms
import PIL.Image as Image

# Normalize works on tensors, so ToTensor must come before it
t = transforms.Compose([transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

Once you have the composed transformation t, pass it to the dataset:

ds = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t)

Here is the list of common transforms:

Resizing Images

Images in a training dataset often differ in size. For instance, not all images in the ImageNet dataset are 224×224.

There are two options:

torchvision.transforms.Resize(size, interpolation=2)

  • If size is a pair of integers (H, W), the image is resized to exactly that size. Unfortunately, the aspect ratio is not preserved.
  • If size is a single integer (224 in the case of ImageNet), the smaller of the two edges is resized to that size and the aspect ratio is kept. The bigger edge is then cut down to the same size with CenterCrop.

The final output is 224×224 in both cases (with size=(224, 224) in the first).


A folder of classes

Classification datasets are often organized as multiple folders, one for each class. Let’s create such an example by downloading the CIFAR-10 dataset from a URL and extracting it.

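!tar xzf cifar.tgz # extract the downloaded archive
import glob
f = glob.glob('cifar*/*') # assumed pattern: list what the archive extracted; adjust to your layout
print(len(f), f)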

Alternatively, you can create a DatasetFolder from the folders of class images you already have.

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

ds = torchvision.datasets.DatasetFolder('/content/classes', loader=pil_loader,
                                        extensions=('.png', '.jpg')) # extensions assumed

This assumes each subfolder of /content/classes holds the images of one class.

Loading images from disk

Next, we create a dataset with torchvision.datasets.DatasetFolder(). You can use DatasetFolder whenever you have multiple subdirectories, each holding a different class.


import torchvision
from torchvision import transforms
from PIL import Image

t = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# this function will be used to load images from disk
def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

# with DatasetFolder we can set the extensions
ds = torchvision.datasets.DatasetFolder('/content/imgs', loader=pil_loader,
                                        extensions=('.png',), transform=t)
print(ds)



Dataset DatasetFolder
    Number of datapoints: 60000
    Root location: /content/imgs
Transform: Compose(
