
Gaurav Singhal

Introduction to DenseNet with TensorFlow

  • May 6, 2020
  • 10 Min read
  • 37,737 Views

Introduction

DenseNet is one of the more recent architectures for visual object recognition. It is quite similar to ResNet but has some fundamental differences. ResNet uses an additive method (+) that merges the previous layer (identity) with the next layer, whereas DenseNet concatenates the outputs of the previous layers with the next layer along the channel axis. Get in-depth knowledge of ResNet in this guide.
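The difference between the two merge operations is easiest to see in the tensor shapes. The minimal numpy sketch below uses toy all-ones arrays in channels-last layout (batch, height, width, channels); the arrays `x` and `fx` are hypothetical stand-ins for an earlier layer's output and the current layer's output.

```python
import numpy as np

# Toy feature maps in channels-last layout: (batch, height, width, channels).
x = np.ones((1, 4, 4, 8))   # output of an earlier layer (the "identity")
fx = np.ones((1, 4, 4, 8))  # hypothetical output of the current layer

# ResNet-style merge: element-wise addition, so the shapes must match exactly.
res_out = x + fx

# DenseNet-style merge: concatenation along the channel axis.
dense_out = np.concatenate([x, fx], axis=-1)

print(res_out.shape)    # (1, 4, 4, 8)
print(dense_out.shape)  # (1, 4, 4, 16)
```

Addition keeps the channel count fixed. Concatenation grows it, so later layers can still read the earlier feature maps unchanged.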

Why Do We Need DenseNet?

DenseNet was developed specifically to address the drop in accuracy caused by vanishing gradients in very deep neural networks. In simpler terms, because the path between the input layer and the output layer is long, the information (and the gradient flowing back) fades away before it reaches its destination.
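A rough numeric illustration of why long paths hurt follows. It assumes sigmoid activations, whose derivative never exceeds 0.25, and it ignores the weight terms of the chain rule, so it only shows an upper bound on how quickly the activation-derivative part of the gradient can shrink with depth. It is a simplified sketch, not a model of DenseNet or of any specific network.

```python
# Sigmoid derivative s(z) * (1 - s(z)) peaks at 0.25 (at z = 0).
MAX_SIGMOID_GRAD = 0.25

def gradient_upper_bound(depth):
    """Upper bound on the product of sigmoid derivatives across `depth` layers,
    ignoring the weight terms of the chain rule (a simplifying assumption)."""
    return MAX_SIGMOID_GRAD ** depth

for depth in (5, 10, 20):
    print(depth, gradient_upper_bound(depth))
```

Each extra layer can multiply the signal by at most 0.25, so the bound decays exponentially. Short paths, such as the direct connections in DenseNet, avoid multiplying through so many factors.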

The primary purpose of this guide is to give insight into DenseNet and to implement DenseNet121 using TensorFlow 2.0 (TF 2.0) and Keras.

In this guide, you will work with a data set called Natural Images that can be downloaded from Kaggle.

DenseNet Architecture

DenseNet Structure

DenseNet falls in the category of classic networks.

This image shows a 5-layer dense block with a growth rate of k = 4 and the standard ResNet structure.

Sources: DenseNet structure - G. Huang, Z. Liu and L. van der Maaten, “Densely Connected Convolutional Networks,” 2018; ResNet structure - Missinglink.ai

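Within a dense block, the feature maps of all preceding layers are concatenated and passed as input to the next layer, which transforms them with a composite function $H_\ell$:

$$x_\ell = H_\ell([x_0, x_1, \ldots, x_{\ell-1}])$$

Here $[x_0, x_1, \ldots, x_{\ell-1}]$ is the concatenation of the feature maps produced by layers $0$ to $\ell-1$. In the original paper, this composite function consists of batch normalization, a ReLU activation, and a 3×3 convolution; pooling is handled separately by the transition layers.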
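The concatenation rule can be sketched in a few lines of numpy. This is a simplified sketch under stated assumptions: the batch normalization has no learned scale or shift, a 1×1 convolution (a per-pixel matrix multiply with random weights) stands in for the 3×3 convolution, and the names `composite` and `dense_block` are hypothetical, not Keras APIs.

```python
import numpy as np

rng = np.random.default_rng(0)

def composite(x, k):
    """Simplified H_l: batch norm -> ReLU -> 1x1 convolution with k filters.

    Assumptions: batch norm has no learned scale/shift, and a 1x1 convolution
    (a per-pixel matrix multiply with random weights) stands in for the 3x3 one.
    """
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x = (x - mean) / np.sqrt(var + 1e-5)               # batch normalization
    x = np.maximum(x, 0.0)                              # ReLU
    kernel = rng.standard_normal((x.shape[-1], k)) * 0.1
    return x @ kernel                                   # (N, H, W, C) -> (N, H, W, k)

def dense_block(x0, num_layers, k):
    """Each layer receives the concatenation [x_0, x_1, ..., x_{l-1}]."""
    features = [x0]
    for _ in range(num_layers):
        inputs = np.concatenate(features, axis=-1)
        features.append(composite(inputs, k))
    # The block output keeps every feature map, including the block input.
    return np.concatenate(features, axis=-1)

x0 = rng.standard_normal((2, 8, 8, 16))   # 16 input channels
out = dense_block(x0, num_layers=5, k=4)  # 5 layers, growth rate k = 4
print(out.shape)                          # (2, 8, 8, 36)
```

Each of the 5 layers appends k = 4 channels, so the block output has 16 + 5 × 4 = 36 channels, and the first 16 channels are the untouched block input.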

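Because every layer is connected to every subsequent layer, a network (or dense block) with L layers has L(L+1)/2 direct connections, instead of the L connections of a traditional feed-forward network, where each layer connects only to the next one.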
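The L(L+1)/2 count can be checked by enumerating the connections directly. In this small sketch, node 0 is the block input and nodes 1 to L are the layers; each layer ℓ receives one connection from each of the ℓ nodes before it, so the total is 1 + 2 + … + L. The function name is hypothetical.

```python
def dense_connections(num_layers):
    """Count direct connections in a densely connected stack.

    Node 0 is the input; nodes 1..num_layers are the layers. Every node
    feeds every later layer.
    """
    edges = [(src, dst)
             for dst in range(1, num_layers + 1)
             for src in range(dst)]
    return len(edges)

for L in (1, 5, 12):
    print(L, dense_connections(L), L * (L + 1) // 2)
```

For the 5-layer dense block in the figure, this gives 15 direct connections.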

DenseNet comes in several versions, such as DenseNet-121, DenseNet-169, and DenseNet-201. The number denotes the number of layers with trainable weights (convolutional and fully connected layers) in the network. For DenseNet-121, it is computed as follows:

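$$121 = 1 + 2 \times (6 + 12 + 24 + 16) + 3 + 1$$

That is: one initial 7×7 convolution, two convolutions (1×1 and 3×3) for each of the 6 + 12 + 24 + 16 layers in the four dense blocks, one 1×1 convolution in each of the three transition layers, and the final fully connected layer.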
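The same arithmetic generalizes to the other variants. The sketch below assumes the block configurations used by the Keras DenseNet implementations (layers per dense block) and counts one stem convolution, two convolutions per dense layer, one convolution per transition layer, and one classifier layer. The names `BLOCK_CONFIGS` and `weighted_layer_count` are hypothetical.

```python
# Layers per dense block (assumed to match the Keras DenseNet121/169/201 configs).
BLOCK_CONFIGS = {
    "DenseNet-121": (6, 12, 24, 16),
    "DenseNet-169": (6, 12, 32, 32),
    "DenseNet-201": (6, 12, 48, 32),
}

def weighted_layer_count(blocks):
    initial_conv = 1                     # 7x7 convolution in the stem
    dense_convs = 2 * sum(blocks)        # each dense layer = 1x1 conv + 3x3 conv
    transition_convs = len(blocks) - 1   # one 1x1 conv per transition layer
    classifier = 1                       # final fully connected layer
    return initial_conv + dense_convs + transition_convs + classifier

for name, blocks in BLOCK_CONFIGS.items():
    print(name, weighted_layer_count(blocks))
```

Only the number of layers in the third and fourth dense blocks changes between these variants; everything else in the count stays fixed.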

DenseBlocks and Layers

Whether the layers are added or concatenated, combining them this way is only possible if their feature maps have the same spatial dimensions. What if the dimensions differ? DenseNet is divided into DenseBlocks: the number of filters differs from block to block, but the feature-map dimensions within a block stay the same. Between blocks, a transition layer downsamples the feature maps; in the original paper, it consists of batch normalization, a 1×1 convolution, and 2×2 average pooling. Downsampling is an essential step in a CNN.
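A transition layer can be sketched in numpy in the same simplified style as the dense block: batch normalization without learned parameters, a ReLU, a 1×1 convolution with random weights, and 2×2 average pooling with stride 2. The compression factor of 0.5 (halving the channels) is an assumption based on the DenseNet-BC variant. The function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def avg_pool_2x2(x):
    """2x2 average pooling with stride 2 on an (N, H, W, C) array."""
    n, h, w, c = x.shape
    x = x[:, : h // 2 * 2, : w // 2 * 2, :]            # drop an odd last row/column
    return x.reshape(n, h // 2, 2, w // 2, 2, c).mean(axis=(2, 4))

def transition_layer(x, compression=0.5):
    """Sketch: batch norm -> ReLU -> 1x1 conv (fewer channels) -> 2x2 average pool."""
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x = np.maximum((x - mean) / np.sqrt(var + 1e-5), 0.0)
    out_channels = int(x.shape[-1] * compression)
    kernel = rng.standard_normal((x.shape[-1], out_channels)) * 0.1
    return avg_pool_2x2(x @ kernel)

x = rng.standard_normal((1, 32, 32, 256))
print(transition_layer(x).shape)   # (1, 16, 16, 128)
```

The 1×1 convolution shrinks the channel axis, and the pooling halves height and width, so the next dense block starts from a smaller, cheaper input.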

Let's see what's inside a DenseBlock and a transition layer:

Source: G. Huang, Z. Liu and L. van der Maaten, “Densely Connected Convolutional Networks,” 2018.

This is the full architecture in abstract form:


Source: Pablo R

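The number of channels changes from one DenseBlock to the next, and it also grows within each block, because every layer appends its own feature maps to the input it received. The growth rate (k) sets the size of this increase: if a block's input has k₀ channels, the ℓ-th layer in the block receives k₀ + k × (ℓ − 1) channels. In other words, k controls how much new information each layer adds.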

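The growth-rate formula can be traced through the whole network. The sketch below assumes the DenseNet-121 settings (growth rate k = 32, 64 channels coming out of the stem, and transition layers that halve the channel count). The function name `channel_progression` is hypothetical.

```python
def channel_progression(blocks, k=32, initial_channels=64, compression=0.5):
    """Track the channel count at the output of each dense block."""
    channels = initial_channels
    block_outputs = []
    for i, num_layers in enumerate(blocks):
        # The l-th layer receives channels + k * (l - 1) inputs, and each layer
        # appends k channels, so the block adds k * num_layers channels in total.
        channels += k * num_layers
        block_outputs.append(channels)
        if i < len(blocks) - 1:              # no transition after the last block
            channels = int(channels * compression)
    return block_outputs

print(channel_progression((6, 12, 24, 16)))   # [256, 512, 1024, 1024]
```

The last value, 1024, is the number of feature maps that the global average pooling layer receives in the model built below.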

Implementing the Code

Before starting, import all the relevant libraries. The main drivers here are tensorflow.keras.applications, which provides DenseNet121, and tensorflow.keras.layers, which provides the layers used to build the network's classification head.

import tensorflow

import pandas as pd
import numpy as np
import os
import random
import cv2
import math
import seaborn as sns

from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

# Layers for the custom classification head
from tensorflow.keras.layers import Dense,GlobalAveragePooling2D,Convolution2D,BatchNormalization
from tensorflow.keras.layers import Flatten,MaxPooling2D,Dropout

# Pretrained DenseNet121 backbone
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.applications.densenet import preprocess_input

# Augmentation and image-to-array helpers (tf.keras only; no standalone keras import)
from tensorflow.keras.preprocessing.image import ImageDataGenerator,img_to_array

from tensorflow.keras.models import Model

from tensorflow.keras.optimizers import Adam

from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

import warnings
warnings.filterwarnings("ignore")
python
print("Tensorflow-version:", tensorflow.__version__)
python

Output: Tensorflow-version: 2.0.0

# Pretrained DenseNet121 without its ImageNet classifier
model_d=DenseNet121(weights='imagenet',include_top=False, input_shape=(128, 128, 3))

x=model_d.output

# New classification head
x= GlobalAveragePooling2D()(x)
x= BatchNormalization()(x)
x= Dropout(0.5)(x)
x= Dense(1024,activation='relu')(x)
x= Dense(512,activation='relu')(x)
x= BatchNormalization()(x)
x= Dropout(0.5)(x)

preds=Dense(8,activation='softmax')(x) #FC-layer: one output per class
python
# Connect the backbone's input to the new head's output
model=Model(inputs=model_d.input,outputs=preds)
model.summary()
python


To avoid overfitting, don't train the entire network. Setting layer.trainable=False freezes all the pretrained DenseNet121 layers, which already detect generic features such as edges and blobs, so only the last eight layers (the new classification head) are trained. Once the model fits well, it can be fine-tuned by setting layer.trainable=True on more layers and recompiling the model.

# Freeze the pretrained backbone
for layer in model.layers[:-8]:
    layer.trainable=False

# Train only the new 8-layer head
for layer in model.layers[-8:]:
    layer.trainable=True
python
model.compile(optimizer='Adam',loss='categorical_crossentropy',metrics=['accuracy'])
model.summary()
python

Notice the drop in the number of trainable parameters.

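The trainable parameters that remain come only from the new head, and they can be counted by hand. This worked example assumes the head defined above (global average pooling over DenseNet121's 1024 feature maps, then BatchNormalization, Dense(1024), Dense(512), BatchNormalization, Dense(8)). Pooling and dropout layers have no parameters, and each batch-normalization layer trains only its scale (gamma) and shift (beta). The helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    return n_in * n_out + n_out          # weights + biases

def batchnorm_trainable(n):
    return 2 * n                         # gamma and beta; moving mean/variance are not trained

backbone_channels = 1024                 # DenseNet121 feature maps after global average pooling

head = [
    batchnorm_trainable(backbone_channels),
    dense_params(backbone_channels, 1024),
    dense_params(1024, 512),
    batchnorm_trainable(512),
    dense_params(512, 8),
]
print(sum(head))
```

Almost all of these parameters sit in the two large Dense layers; the frozen backbone contributes none.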

data=[]
labels=[]
random.seed(42)
data_dir = "../input/natural-images/"
imagePaths = sorted(list(os.listdir(data_dir)))  # one folder per class
random.shuffle(imagePaths)
print(imagePaths)

for img in imagePaths:
    path=sorted(list(os.listdir(os.path.join(data_dir, img))))
    for i in path:
        image = cv2.imread(os.path.join(data_dir, img, i))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model and matplotlib expect RGB
        image = cv2.resize(image, (128,128))
        image = img_to_array(image)
        data.append(image)
        labels.append(img)  # the folder name is the class label
python


data = np.array(data, dtype="float32") / 255.0  # scale pixels to [0, 1]
labels = np.array(labels)
mlb = LabelBinarizer()
labels = mlb.fit_transform(labels)  # one-hot; column order is mlb.classes_
print(labels[0])
python

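It helps to know which column of the one-hot vector belongs to which class. The numpy sketch below mimics what LabelBinarizer does for three or more classes: the class names are sorted, and each label becomes a row with a single 1 in its class's column. It is an illustration rather than scikit-learn's implementation, and `one_hot_sorted` is a hypothetical name. Because the columns follow the sorted order, class names for plotting should be taken from `mlb.classes_`, not from the shuffled folder list.

```python
import numpy as np

def one_hot_sorted(labels):
    """Mimic LabelBinarizer for 3+ classes: sorted classes, one column each."""
    labels = np.asarray(labels)
    classes = np.unique(labels)                         # sorted unique class names
    onehot = (labels[:, None] == classes[None, :]).astype(int)
    return classes, onehot

classes, onehot = one_hot_sorted(["dog", "cat", "car", "dog"])
print(classes)     # ['car' 'cat' 'dog']
print(onehot[0])   # [0 0 1]
```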

(xtrain,xtest,ytrain,ytest)=train_test_split(data,labels,test_size=0.4,random_state=42)
print(xtrain.shape, xtest.shape)
python


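If the monitored metric (here, validation accuracy) does not improve for `patience` epochs, the ReduceLROnPlateau callback reduces the learning rate, which often helps the model converge. ModelCheckpoint saves the best model seen so far (by validation loss, its default monitor) to model.h5. ImageDataGenerator performs real-time data augmentation, producing randomly zoomed, flipped, and sheared batches of image tensors as training proceeds.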

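# Halve the learning rate if validation accuracy has not improved for 5 epochs.
# min_lr must be below Adam's default learning rate (1e-3), or no reduction can happen.
reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=5, verbose=1, min_lr=1e-6)
# Save the best model (lowest validation loss by default)
checkpoint = ModelCheckpoint('model.h5', verbose=1, save_best_only=True)

datagen = ImageDataGenerator(zoom_range = 0.2, horizontal_flip=True, shear_range=0.2)

datagen.fit(xtrain)
# Fit the model on augmented batches; validate on the held-out split, not the training data
history = model.fit(datagen.flow(xtrain, ytrain, batch_size=128),
                    steps_per_epoch=xtrain.shape[0] // 128,
                    epochs=50,
                    verbose=2,
                    callbacks=[reduce_lr, checkpoint],
                    validation_data=(xtest, ytest))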
python

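The plateau rule is easy to simulate in plain Python. This is a simplified sketch of the idea behind ReduceLROnPlateau, not Keras's implementation: it ignores `min_delta` and `cooldown`, and `simulate_reduce_lr` is a hypothetical name. It also shows why `min_lr` must stay below the starting learning rate: with Adam's default of 1e-3 and `min_lr=1e-3`, the rate can never drop.

```python
def simulate_reduce_lr(metric_history, lr, factor=0.5, patience=5, min_lr=1e-6, mode="max"):
    """Return the learning rate used after each epoch (simplified plateau rule)."""
    best = None
    wait = 0
    lrs = []
    for value in metric_history:
        improved = best is None or (value > best if mode == "max" else value < best)
        if improved:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:                 # no improvement for `patience` epochs
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs

flat_accuracy = [0.9] * 12                       # hypothetical: validation accuracy stalls
print(simulate_reduce_lr(flat_accuracy, lr=1e-3, min_lr=1e-3))
print(simulate_reduce_lr(flat_accuracy, lr=1e-3, min_lr=1e-6))
```

The first call keeps the rate at 1e-3 for every epoch; the second halves it after each run of five epochs without improvement.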

ypred = model.predict(xtest)

total = 0
accurate = 0
accurateindex = []
wrongindex = []

# Compare predicted and true class indices for every test image
for i in range(len(ypred)):
    if np.argmax(ypred[i]) == np.argmax(ytest[i]):
        accurate += 1
        accurateindex.append(i)
    else:
        wrongindex.append(i)

    total += 1

print('Total-test-data:', total, '\taccurately-predicted-data:', accurate, '\t wrongly-predicted-data: ', total - accurate)
print('Accuracy:', round(accurate/total*100, 3), '%')
python

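The evaluation loop can also be written in a vectorized way with numpy. The sketch below computes the same accuracy and the same lists of correct and incorrect indices. It runs on a small hypothetical 3-class example instead of the real predictions, and `accuracy_and_errors` is a hypothetical name.

```python
import numpy as np

def accuracy_and_errors(ypred, ytrue):
    """Compare argmax of predicted probabilities with argmax of one-hot labels."""
    pred_idx = np.argmax(ypred, axis=1)
    true_idx = np.argmax(ytrue, axis=1)
    correct = pred_idx == true_idx
    return correct.mean() * 100, np.flatnonzero(correct), np.flatnonzero(~correct)

# Hypothetical 3-class example with 4 samples
ypred = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.3, 0.3, 0.4], [0.6, 0.3, 0.1]])
ytrue = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0]])
acc, accurate_idx, wrong_idx = accuracy_and_errors(ypred, ytrue)
print(acc, accurate_idx, wrong_idx)   # 75.0 [0 1 3] [2]
```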

# Class names in the same order as the one-hot columns (LabelBinarizer sorts them)
label = list(mlb.classes_)
imidx = random.sample(accurateindex, k=9)  # replace with 'wrongindex' to inspect errors

nrows = 3
ncols = 3
fig, ax = plt.subplots(nrows,ncols,sharex=True,sharey=True,figsize=(15, 12))

n = 0
for row in range(nrows):
    for col in range(ncols):
        ax[row,col].imshow(xtest[imidx[n]])
        ax[row,col].set_title("Predicted label :{}\nTrue label :{}".format(label[np.argmax(ypred[imidx[n]])], label[np.argmax(ytest[imidx[n]])]))
        n += 1

plt.show()
python


Ypred = model.predict(xtest)

# Convert probabilities and one-hot labels to class indices
Ypred = np.argmax(Ypred, axis=1)
Ytrue = np.argmax(ytest, axis=1)

cm = confusion_matrix(Ytrue, Ypred)  # rows: actual class, columns: predicted class
plt.figure(figsize=(12, 12))
ax = sns.heatmap(cm, cmap="rocket_r", fmt="d",annot_kws={'size':16}, annot=True, square=True, xticklabels=label, yticklabels=label)
ax.set_ylabel('Actual', fontsize=20)
ax.set_xlabel('Predicted', fontsize=20)
plt.show()
python

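To read the heat map correctly, it helps to know how a confusion matrix is filled. In the minimal numpy sketch below, which follows the same convention as scikit-learn's confusion_matrix, cell [i, j] counts the samples whose actual class is i and whose predicted class is j. A count such as "dogs misclassified as cats" is therefore read from the dog row and the cat column. The function name and the toy labels are hypothetical.

```python
import numpy as np

def confusion_counts(y_true, y_pred, num_classes):
    """Rows index the actual class, columns the predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)   # unbuffered add counts repeated index pairs
    return cm

y_true = np.array([0, 0, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 2, 0, 2])
print(confusion_counts(y_true, y_pred, 3))
```

The diagonal holds the correct predictions; everything off the diagonal is a misclassification.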

Conclusion

You have built a DenseNet model with ~98% accuracy. DenseNet alleviates the vanishing gradient problem and requires fewer parameters to train. Its dense connectivity strengthens feature propagation and encourages feature reuse, which keeps information flowing smoothly through the network.

This guide covered the basics of building DenseNet-121: its architecture, its advantages, and how it differs from ResNet. From the heat map, we can see that 44 dogs are misclassified as cats, possibly because those dog pictures have traits similar to cats. Results can be improved by fine-tuning the model. Try adding or removing dense blocks and layers, checking how many images each class contains, and augmenting the images further.

Deep neural networks are a vast field, and ongoing research aims to make them simpler to learn and better at solving complex real-world problems. If you need any help with your deep learning projects, contact me at CodeAlphabet.

References

G. Huang, Z. Liu and L. van der Maaten, “Densely Connected Convolutional Networks,” 2018.