Transfer Learning with ResNet in PyTorch

Gaurav Singhal

  • May 5, 2020
  • 12 Min read
  • 50,640 Views

Introduction

To solve complex image analysis problems with deep learning, network depth (stacking hundreds of layers) matters: deeper networks can extract critical features from training data and learn meaningful patterns. However, simply adding layers is computationally expensive and runs into trouble with gradients. In this guide, you will learn about the problems deep neural networks face, how ResNet helps, and how to use ResNet in transfer learning.

Important: I highly recommend that you understand the basics of CNN before reading further about ResNet and transfer learning. Read this Image Classification Using PyTorch guide for a detailed description of CNN.

The Problem

As the authors of this paper discovered, simply stacking layers in a deep neural network can produce unexpected results: training accuracy drops as more layers are added. The underlying culprit is the vanishing gradient problem.

During training, the gradient of the loss with respect to the parameters of the initial layers becomes extremely small, because it is multiplied through many layers on the way back. The gradient shrinks even further as training approaches a minimum. As a result, the weights in the initial layers update very slowly or remain unchanged, and the error rises.
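A quick way to see this effect is to measure the gradient that reaches the first layer of a plain stack of sigmoid layers. Here is a minimal sketch (the layer sizes and depths are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

def first_layer_grad_norm(depth):
    # A plain stack (no skip connections) of small sigmoid layers.
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(32, 32), nn.Sigmoid()]
    net = nn.Sequential(*layers)

    x = torch.randn(8, 32)
    net(x).sum().backward()
    # Gradient norm at the very first layer.
    return net[0].weight.grad.norm().item()

for depth in (5, 20, 50):
    print(depth, first_layer_grad_norm(depth))
python

The printed norms shrink sharply as depth grows: the deeper the stack, the less signal reaches the first layer's weights.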

Let's see how Residual Network (ResNet) flattens the curve.

ResNet

A residual network, or ResNet for short, is an artificial neural network that makes it possible to build deeper networks by using skip connections, or shortcuts, that jump over some layers. You'll see how skipping helps build deeper network layers without running into the problem of vanishing gradients.

There are different versions of ResNet, including ResNet-18, ResNet-34, ResNet-50, and so on. The number denotes the depth in layers; the underlying architecture is the same.
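All of these variants are available directly from TorchVision's model library. As a quick sketch (pretrained=True downloads ImageNet weights):

from torchvision import models

resnet18 = models.resnet18(pretrained=True)  # 18 layers, basic blocks
resnet34 = models.resnet34(pretrained=True)  # 34 layers, basic blocks
resnet50 = models.resnet50(pretrained=True)  # 50 layers, bottleneck blocks
python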

To create a residual block, add a shortcut to the main path in the plain neural network, as shown in the figure below.

[Figure: a plain network path with a shortcut added to form a residual block]

Why do skip connections work?

Follow the math below:

Consider a residual block whose shortcut skips two layers, with activation function $g$ (ReLU here):

$$a^{[l+2]} = g(z^{[l+2]} + a^{[l]}) = g(W^{[l+2]} a^{[l+1]} + b^{[l+2]} + a^{[l]})$$

If weight decay shrinks $W^{[l+2]}$ and $b^{[l+2]}$ toward zero, this reduces to

$$a^{[l+2]} = g(a^{[l]}) = a^{[l]}$$

since ReLU leaves the non-negative activations $a^{[l]}$ unchanged. The block can therefore fall back to the identity function at essentially no cost.

From the math above, we can conclude:

  • It is easy for a residual block to learn the identity function.
  • Skipping one, two, or even three layers is safe: the identity mapping passes the input through without hurting network performance, which ensures that the higher layers perform at least as well as the lower layers.

ResNet Blocks

There are two main types of blocks used in ResNet, depending mainly on whether the input and output dimensions are the same or different.

  • Identity Block: When the input and output activation dimensions are the same.
  • Convolution Block: When the input and output activation dimensions are different from each other.

For example, to reduce the activation dimensions (HxW) by a factor of 2, you can use a 1x1 convolution with a stride of 2.
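For instance, here is a quick sketch of that downsampling step (the shapes are illustrative):

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)  # N x C x H x W
# A 1x1 convolution with stride 2 halves H and W (and can change channels).
downsample = nn.Conv2d(64, 128, kernel_size=1, stride=2)
print(downsample(x).shape)  # torch.Size([1, 128, 28, 28])
python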

The figure below shows what these residual blocks look like and what is inside them.

[Figure: the internal structure of the identity block and the convolution block]
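To make the two block types concrete, below is a minimal sketch of a basic residual block in the spirit of the figure (a simplified version, not TorchVision's exact implementation). When the stride or channel count changes, a 1x1 projection shortcut stands in for the identity shortcut:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            # Convolution block: project the input so shapes match.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))
        else:
            # Identity block: the shortcut is a no-op.
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))  # add the skip connection
python

Here BasicResidualBlock(64, 64) acts as an identity block, while BasicResidualBlock(64, 128, stride=2) acts as a convolution block.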

Data Preparation

In this guide, you'll use the Fruits 360 dataset from Kaggle. You can download the dataset here. It's big (approximately 730 MB) and poses a multi-class classification problem with nearly 82,000 images of 120 kinds of fruits and vegetables.

Let's see the code in action. The first step is always to prepare your data.

Import the torch libraries, then transform and normalize the image data before feeding it into the network. Learn more about pre-processing data in this guide.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import time
import copy
import os

batch_size = 128
learning_rate = 1e-3

# Normalize with ImageNet statistics so the inputs match what the
# pre-trained ResNet expects (and what imshow un-normalizes below).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root='/input/fruits-360-dataset/fruits-360/Training', transform=transform)
test_dataset = datasets.ImageFolder(root='/input/fruits-360-dataset/fruits-360/Test', transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def imshow(inp, title=None):
    # Move to CPU, convert CHW -> HWC, and undo the normalization for display.
    inp = inp.cpu().numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)

    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

images, labels = next(iter(train_dataloader))
print("images-size:", images.shape)

out = torchvision.utils.make_grid(images)
print("out-size:", out.shape)

imshow(out, title=[train_dataset.classes[x] for x in labels])
python

[Output: printed tensor shapes, followed by a grid of sample training images labeled with their classes]

Transfer Learning with PyTorch

The main aim of transfer learning (TL) is to build an accurate model quickly. Instead of creating a deep neural network (DNN) from scratch, the model reuses features it has learned on a different dataset while performing a similar task. This reuse is also known as knowledge transfer.

[Figure: knowledge learned on a source task being transferred to a new target task. Source: James Le]
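One common transfer-learning recipe is to freeze the pre-trained backbone and train only the new classification head (this guide fine-tunes all layers instead). A minimal sketch of that variant:

from torchvision import models
import torch.nn as nn

model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # freeze the pre-trained backbone

# Only the new head's parameters will receive gradient updates.
model.fc = nn.Linear(model.fc.in_features, 128)
python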

The PyTorch API provides a pre-trained ResNet-18 through models.resnet18(pretrained=True), a function from TorchVision's model library. The ResNet-18 architecture is described below.

[Figure: the ResNet-18 architecture]

net = models.resnet18(pretrained=True)
net = net.to(device)
net
python

[Output: the layer-by-layer printout of the pre-trained ResNet-18 model]

def accuracy(out, labels):
    _, pred = torch.max(out, dim=1)
    return torch.sum(pred == labels).item()

# Replace the final fully-connected layer before building the optimizer,
# so that the new layer's parameters are included in the updates.
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 128)  # 128 outputs is enough to cover all classes
net.fc = net.fc.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
python

The code above replaces the final fully-connected layer for classification, specifying the number of input features and output units (FC 128), and only then creates the optimizer so that the new layer's parameters are trained.

n_epochs = 5
print_every = 10
valid_loss_min = np.inf
val_loss = []
val_acc = []
train_loss = []
train_acc = []
total_step = len(train_dataloader)
for epoch in range(1, n_epochs + 1):
    running_loss = 0.0
    correct = 0
    total = 0
    print(f'Epoch {epoch}\n')
    net.train()
    for batch_idx, (data_, target_) in enumerate(train_dataloader):
        data_, target_ = data_.to(device), target_.to(device)
        optimizer.zero_grad()

        outputs = net(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred == target_).item()
        total += target_.size(0)
        if batch_idx % 20 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch, n_epochs, batch_idx, total_step, loss.item()))
    train_acc.append(100 * correct / total)
    train_loss.append(running_loss / total_step)
    print(f'\ntrain-loss: {np.mean(train_loss):.4f}, train-acc: {(100 * correct / total):.4f}')

    # Evaluate on the test set with gradients disabled.
    batch_loss = 0
    total_t = 0
    correct_t = 0
    with torch.no_grad():
        net.eval()
        for data_t, target_t in test_dataloader:
            data_t, target_t = data_t.to(device), target_t.to(device)
            outputs_t = net(data_t)
            loss_t = criterion(outputs_t, target_t)
            batch_loss += loss_t.item()
            _, pred_t = torch.max(outputs_t, dim=1)
            correct_t += torch.sum(pred_t == target_t).item()
            total_t += target_t.size(0)
        val_acc.append(100 * correct_t / total_t)
        val_loss.append(batch_loss / len(test_dataloader))
        network_learned = batch_loss < valid_loss_min
        print(f'validation loss: {np.mean(val_loss):.4f}, validation acc: {(100 * correct_t / total_t):.4f}\n')

        # Save a checkpoint whenever the validation loss improves.
        if network_learned:
            valid_loss_min = batch_loss
            torch.save(net.state_dict(), 'resnet.pt')
            print('Improvement-Detected, save-model')
python

[Output: per-epoch training and validation loss and accuracy logs]

The accuracy will improve further if you increase the number of epochs.
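Since the loop saves the best-performing weights to resnet.pt, you can reload that checkpoint later for evaluation (a short sketch):

# Reload the best checkpoint saved during training.
net.load_state_dict(torch.load('resnet.pt', map_location=device))
net.eval()
python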

fig = plt.figure(figsize=(20, 10))
plt.title("Train-Validation Accuracy")
plt.plot(train_acc, label='train')
plt.plot(val_acc, label='validation')
plt.xlabel('num_epochs', fontsize=12)
plt.ylabel('accuracy', fontsize=12)
plt.legend(loc='best')
python

[Plot: training vs. validation accuracy across epochs]

def visualize_model(net, num_images=4):
    images_so_far = 0
    fig = plt.figure(figsize=(15, 10))

    # Run in evaluation mode with gradients disabled.
    net.eval()
    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            _, preds = torch.max(outputs, 1)
            preds = preds.cpu().numpy()
            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(2, num_images // 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(test_dataset.classes[preds[j]]))
                imshow(inputs[j])

                if images_so_far == num_images:
                    return

plt.ion()
visualize_model(net)
plt.ioff()
python

[Output: a grid of test images with their predicted class labels]

Conclusion

The model reaches an accuracy of about 97%, which is great, and it predicts the fruits correctly.

This guide gave a brief overview of the problems faced by deep neural networks, how ResNet helps to overcome them, and how ResNet can be used in transfer learning to speed up the development of CNNs. I highly recommend you learn more by going through the resources mentioned above, performing EDA, and getting to know your data better. Try customizing the model by freezing and unfreezing layers, increasing the number of ResNet layers, and adjusting the learning rate. Read this post for further mathematical background. If you still have any questions, feel free to contact me at CodeAlphabet.

Transfer learning adapts knowledge learned in one domain to new tasks. The ideas behind ResNet continue to open new research directions, making real-world problems ever more efficient to solve.