From ImageNet to X-Ray

Case Study of Transfer Learning

Transfer Learning is a powerful tool that lets us take huge pre-trained models and tune them for our own task. In this post, we review a case of transferring the weights of a CNN architecture trained on ImageNet to detect pneumonia in chest X-ray images.


Prerequisites

  • Artificial Neural Networks
  • Deep Learning
  • Convolutional Neural Network
  • Keras/Tensorflow
  • Python

Transfer Learning in a Nutshell

Transfer learning allows us to leverage huge, general data collections to train a model, which in the next step is fine-tuned for our specific task. A typical case in Computer Vision is to train a deep learning model on the ImageNet collection of more than 14M images organized in more than 20k categories. Next, such a model is tuned, entirely or only partially, on our own data collection, say 5k images in 4 classes. In Natural Language Processing, a typical case is deriving word embeddings and language models from huge text corpora such as Wikipedia.

We do not have to train a base model for transfer learning ourselves. Pre-trained weights for many architectures are available to download. Good places to start are:

  • Keras Applications, containing architectures with pre-trained weights,
  • TensorFlow Hub, a library for the publication, discovery, and consumption of reusable parts of machine learning models,
  • TensorFlow Model Garden, a repository with a number of different implementations of state-of-the-art models and modeling solutions for TensorFlow.

Why Would I Care About Transfer Learning?

There are a few situations in which Transfer Learning is especially useful.

  • Small dataset. This is the most straightforward case: we do not have enough data to train a deep neural network from scratch. Leveraging a pre-trained model lets us start from a point where the model can already recognize some shapes, contexts, or sounds. All we need to do is fine-tune it for our particular case.
  • Limited time. Training an entire neural network requires not only a lot of data but also a lot of time, in some cases days or weeks on modern hardware. Transfer learning often lets us achieve comparable results overnight. This is crucial for quick prototyping and for preparing proofs-of-concept in a commercial environment.
  • Limited computation power. As mentioned above, training a deep neural network requires modern hardware, preferably a GPU with a significant amount of memory. When you do not have such hardware, you can still develop a decent DL model by leveraging Transfer Learning.

Transfer Learning Strategies

Typically, when we use Transfer Learning in Computer Vision, we have to replace the last layer with one that fits our task (e.g., a softmax with 4 outputs for a 4-class classification). There are two main strategies for the remaining layers:

  • Tune them all. In this strategy, we use the weights of the pre-trained model as our initial weights and train the entire model on our dataset. This can give better results, provided we have enough data. The downside is that training takes longer and, perhaps even more importantly, consumes much more resources; in some cases, we won't be able to train such a model at all.
  • Tune only the last layers. This strategy boils down to freezing most of the model and training only a limited number of final layers. In some cases, one can train only the fully connected layers or just the very last block of CNN layers.

In some cases, it is also beneficial to simplify the fully connected layers. We can reduce the number of FC layers as well as the number of units in each layer. This decreases the capacity of the model and speeds up training. Such a reduction can only work if our task is much simpler than the original task used to train the model; for example, we deal with only two classes, whilst the original model was trained on 1000 classes.
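To put numbers on this: a Dense layer with n_in inputs and n_out units has n_in * n_out weights plus n_out biases. The sketch below compares the original VGG16 top (two 4096-unit layers and a 1000-way output, as built by Keras Applications with include_top=True) with the much smaller head used later in this post. The 25088 inputs come from flattening VGG16's 7x7x512 convolutional output for a 224x224 image.

```python
from typing import Sequence


def dense_params(n_in: int, n_out: int) -> int:
    """Parameters of a fully connected layer: n_in * n_out weights plus n_out biases."""
    return n_in * n_out + n_out


def head_params(n_in: int, units: Sequence[int]) -> int:
    """Total parameters of a stack of Dense layers applied to n_in input features."""
    total = 0
    for n_out in units:
        total += dense_params(n_in, n_out)
        n_in = n_out  # the next layer consumes this layer's output
    return total


flattened = 7 * 7 * 512  # VGG16 conv output for a 224x224 input, flattened: 25088

original_top = head_params(flattened, [4096, 4096, 1000])  # ImageNet classifier
small_head = head_params(flattened, [512, 64, 1])          # head used in this post
print(f"{original_top:,} vs {small_head:,}")  # 123,642,856 vs 12,878,465
```

The smaller head has roughly a tenth of the parameters of the original classifier, and its three terms (12,845,568 + 32,832 + 65) correspond to the Dense rows in the model summary further down.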

The decision should be made with time and resource limitations in mind, and it should be validated empirically with an appropriate metric.

Problem Statement

We aim to detect whether an X-ray image is NORMAL or shows PNEUMONIA. This boils down to binary image classification.

Dataset

For this experiment we use a publicly available dataset.

“Large Dataset of Labeled Optical Coherence Tomography (OCT) and Chest X-Ray Images”
Kermany, Daniel; Zhang, Kang; Goldbaum, Michael (2018)
Mendeley Data, v3

The dataset consists of 5856 images in two categories, NORMAL and PNEUMONIA: 1583 images are labeled as NORMAL and 4273 as PNEUMONIA. We split the dataset into three subsets: train, validation, and test.

        NORMAL   PNEUMONIA
train     1283        3473
val        150         400
test       150         400
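flow_from_directory, used below, expects one subdirectory per class inside each split directory (for example train/NORMAL and train/PNEUMONIA). The post does not show how the split was made, so the following is only a hypothetical sketch: the helper name split_dataset, the assumed source layout src_dir/<CLASS>/ and the seed are our own choices. It shuffles each class's files reproducibly, copies n_val and n_test files per class into val/ and test/, and puts the rest into train/.

```python
import random
import shutil
from pathlib import Path


def split_dataset(src_dir, dst_dir, n_val, n_test, seed=42):
    """Hypothetical helper: copy images from src_dir/<CLASS>/ into
    dst_dir/{train,val,test}/<CLASS>/, taking n_val and n_test files per class
    for validation and test and leaving the rest for training."""
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    counts = {}
    for class_dir in sorted(p for p in src_dir.iterdir() if p.is_dir()):
        files = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(files)
        splits = {
            "val": files[:n_val],
            "test": files[n_val:n_val + n_test],
            "train": files[n_val + n_test:],
        }
        for split, split_files in splits.items():
            out_dir = dst_dir / split / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, out_dir / f.name)
            counts[(split, class_dir.name)] = len(split_files)
    return counts
```

With n_val=150 and n_test=150, the remaining 1283 NORMAL and 3473 PNEUMONIA images end up in train/, which matches the table above; pointing dst_dir at /tf/work/data would produce the train/val/test paths used in the generator code.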

Data Preprocessing and Augmentation

Training a deep learning model on a larger volume of images generally improves its performance and also helps it generalize better to new images. To increase the volume and variety of our training images, we employ data augmentation, which manipulates images in multiple ways. We employ the following:

  • rotation
  • width shift
  • height shift
  • shear
  • zoom
  • horizontal flip

We also preprocess the images the same way the inputs of the source model were preprocessed, using the preprocess_input function defined in tf.keras.applications.vgg16.
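For VGG16, preprocess_input follows the "caffe" convention: it converts images from RGB to BGR and subtracts the per-channel ImageNet mean, without scaling pixels to [0, 1]. The NumPy function below (vgg16_style_preprocess is our own illustrative name) only mirrors that behavior to show what happens to the pixels; in practice, call the library function so the preprocessing stays in sync with the weights.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by the "caffe" preprocessing mode
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def vgg16_style_preprocess(images: np.ndarray) -> np.ndarray:
    """Illustrative re-implementation: RGB -> BGR, then subtract channel means.

    `images` has shape (..., height, width, 3) with RGB values in [0, 255].
    """
    bgr = images[..., ::-1].astype(np.float32)  # reverse the channel axis
    return bgr - IMAGENET_MEAN_BGR
```

Because no rescale argument is passed to ImageDataGenerator, the images reach preprocess_input with raw values in [0, 255], which is the range this convention expects.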

With tf.keras.preprocessing.image.ImageDataGenerator we create two data generators:

  • train_datagen - for the training process, where we employ all the aforementioned data augmentation techniques,
  • test_datagen - for the validation and test sets; here we only preprocess the images, because we want validation and test images to be as close to real images as possible and avoid any unnecessary manipulation.

from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# training data: VGG16 preprocessing plus random augmentation
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=40,
    width_shift_range=0.3,
    height_shift_range=0.4,
    shear_range=0.2,
    zoom_range=0.4,
    horizontal_flip=True)

# validation and test data: VGG16 preprocessing only
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
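ImageDataGenerator is deprecated in recent TensorFlow releases in favor of Keras preprocessing layers. A roughly equivalent augmentation pipeline, assuming TensorFlow 2.6 or newer, could look like the sketch below. The factors are our own approximate translation of the settings above (RandomRotation takes a fraction of a full turn, so 40 degrees becomes 40/360), fill_mode is set to "nearest" to match ImageDataGenerator's default, and shear is left out because a built-in shear layer is not available in every version. These layers only transform images when called with training=True.

```python
import tensorflow as tf

# Approximate translation of the train_datagen settings (shear omitted).
augment = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(40 / 360, fill_mode="nearest"),  # about +/- 40 degrees
    tf.keras.layers.RandomTranslation(height_factor=0.4, width_factor=0.3,
                                      fill_mode="nearest"),
    tf.keras.layers.RandomZoom(0.4, fill_mode="nearest"),
    tf.keras.layers.RandomFlip("horizontal"),
])

# Example: augment a dummy batch of two 224x224 RGB images with values in [0, 255).
images = tf.random.uniform((2, 224, 224, 3), maxval=255.0)
augmented = augment(images, training=True)
```

Such a block would be applied to training images only, for example in a tf.data map step, while validation and test images keep just the VGG16 preprocessing.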

Having configured the data generators, we create iterators that generate batches of augmented and preprocessed images. We create three such iterators:

  • train_generator with the train dataset and the data generator that augments and preprocesses the images (train_datagen),
  • val_generator with the validation dataset and the data generator that only preprocesses the images (test_datagen),
  • test_generator with the test dataset and the data generator that only preprocesses the images (test_datagen).

The CNN architecture we use takes RGB images of size 224x224 as input, so we make sure all images are resized accordingly with target_size. We also define batch_size for the model. Batch size is a parameter we can adjust to the resources at hand; in this case, we use 62 images per batch, as we are limited by GPU memory. For validation and test, we take one image at a time, so batch_size is set to 1. Since we have only two classes in the dataset, the target of the model can be defined as a 1D NumPy array of binary labels. For this, we set class_mode to binary.

batch_size = 62
target_size = (224,224)

train_dir = '/tf/work/data/train'
train_generator = train_datagen.flow_from_directory(train_dir,
                                                   batch_size=batch_size,
                                                   class_mode='binary',
                                                   target_size=target_size)

val_dir = '/tf/work/data/val'
val_generator = test_datagen.flow_from_directory(val_dir,
                                                   batch_size=1,
                                                   class_mode='binary',
                                                   target_size=target_size)

test_dir = '/tf/work/data/test'
test_generator = test_datagen.flow_from_directory(test_dir,
                                                   batch_size=1,
                                                   class_mode='binary',
                                                   target_size=target_size, 
                                                   shuffle=False)
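Two details of these iterators matter later. First, flow_from_directory assigns class indices by sorting the subdirectory names, so NORMAL becomes 0 and PNEUMONIA becomes 1; PNEUMONIA is therefore the positive class for the binary metrics. Second, the training code below computes steps_per_epoch with int(), which rounds down. A plain-Python illustration using the training-set sizes from the table above:

```python
import math

# flow_from_directory sorts class subdirectory names to assign indices
class_names = sorted(["PNEUMONIA", "NORMAL"])
class_indices = {name: i for i, name in enumerate(class_names)}
print(class_indices)  # {'NORMAL': 0, 'PNEUMONIA': 1}

train_samples = 1283 + 3473  # NORMAL + PNEUMONIA training images
batch_size = 62

steps_floor = int(train_samples / batch_size)       # as in the training code: 76
steps_ceil = math.ceil(train_samples / batch_size)  # enough batches for every image: 77
unseen_per_epoch = train_samples - steps_floor * batch_size  # 44 images
```

With 76 steps of 62 images, an epoch covers 4,712 of the 4,756 training images; len(train_generator) returns the rounded-up count of 77 if each epoch should cover the full set.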

The Model

Having prepared the data, we can move to our CNN architecture. We use a simple yet powerful architecture named VGG16.

For a detailed description of the architecture please refer to the original paper:

Very Deep Convolutional Networks for Large-Scale Image Recognition
Karen Simonyan and Andrew Zisserman (2014)
arXiv 1409.1556

The advantage of high-level frameworks and libraries is that at this step, all we need to do is import VGG16 from tensorflow.keras.applications with the correct parameters. We want the model pre-trained on ImageNet, so we set weights to “imagenet”. We want to simplify the architecture and change the output size, so we do not want to include the 3 fully-connected layers at the top of the network; we set include_top to False. Finally, with input_shape we define the size of the input images.

from tensorflow.keras.applications import VGG16

vgg_pretrained_model = VGG16(weights="imagenet",
                             include_top=False,
                             input_shape=(224, 224, 3))

Let’s see the summary of the model.

vgg_pretrained_model.summary()

The VGG16 architecture contains 5 blocks of Conv2D and MaxPooling2D layers with all parameters set as trainable.

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
_________________________________________________________________
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________
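The parameter counts in this summary follow from a simple formula: a Conv2D layer with a k x k kernel, c_in input channels and c_out filters has k*k*c_in*c_out weights plus c_out biases, and the pooling layers have none. A short sketch reproducing the total and the split between blocks 1–4 and block 5 that the freezing strategy below relies on (the block table is transcribed from the summary above):

```python
# (block, number of 3x3 conv layers, filters) for the VGG16 convolutional base
VGG16_BLOCKS = [(1, 2, 64), (2, 2, 128), (3, 3, 256), (4, 3, 512), (5, 3, 512)]


def conv_params(kernel: int, c_in: int, c_out: int) -> int:
    """Weights (kernel x kernel x c_in x c_out) plus one bias per filter."""
    return kernel * kernel * c_in * c_out + c_out


def params_per_block(in_channels: int = 3) -> dict:
    """Parameter count of each VGG16 block; max-pooling layers add none."""
    counts = {}
    c_in = in_channels
    for block, n_convs, filters in VGG16_BLOCKS:
        total = 0
        for _ in range(n_convs):
            total += conv_params(3, c_in, filters)
            c_in = filters  # the next conv layer sees this layer's filters
        counts[block] = total
    return counts


blocks = params_per_block()
total = sum(blocks.values())                   # 14,714,688
frozen = sum(blocks[b] for b in (1, 2, 3, 4))  # 7,635,264
block5 = blocks[5]                             # 7,079,424
```

The frozen count is the Non-trainable params figure of the final model shown later, and block 5 plus the three Dense layers listed there (12,845,568 + 32,832 + 65) gives its 19,957,889 trainable parameters.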

The goal of this experiment is to leverage transfer learning. The strategy we employ is to freeze blocks 1–4 of the VGG16 architecture and train only block 5. This means that for blocks 1–4 we keep the weights trained on ImageNet as they are. For block 5, these weights are used as initialization, and we tune this part of the architecture.

# layers[0] is the input layer; layers[1:15] are blocks 1-4 (up to block4_pool)
for layer in vgg_pretrained_model.layers[:15]:
    layer.trainable = False
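Slicing with [:15] relies on the layer order (index 0 is the input layer, indices 1 to 14 are blocks 1–4 including their pooling layers). Selecting layers by name is less fragile. The sketch below assumes the block5_... naming shown in the summary above; it builds VGG16 with weights=None only so that it runs without downloading the ImageNet weights, and the freezing logic is the same for weights="imagenet".

```python
import numpy as np
from tensorflow.keras.applications import VGG16

# weights=None only to avoid the download in this sketch; use "imagenet" in practice
base = VGG16(weights=None, include_top=False, input_shape=(224, 224, 3))

for layer in base.layers:
    # train block 5 only; the input layer and blocks 1-4 stay frozen
    layer.trainable = layer.name.startswith("block5")

n_trainable = sum(int(np.prod(tuple(w.shape))) for w in base.trainable_weights)
n_frozen = sum(int(np.prod(tuple(w.shape))) for w in base.non_trainable_weights)
print(f"trainable: {n_trainable:,}, non-trainable: {n_frozen:,}")
```

Printing the counts after freezing is a cheap check that the intended layers are frozen: the non-trainable figure should equal the 7,635,264 parameters of blocks 1–4.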

Let’s now define a new model using the 5 blocks of pre-trained CNN layers and add new fully connected layers on top. Our classification problem is much simpler than the one posed by ImageNet (two vs. a thousand classes), so it seems reasonable to use much simpler FC layers. We use layers from tf.keras.layers: after flattening the convolutional output, we add two Dense layers with 512 and 64 units respectively, with Dropout layers in between. Finally, we define the output layer as a Dense layer with 1 unit and a sigmoid activation function.

from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.models import Model

# new classification head on top of the VGG16 convolutional base
x = vgg_pretrained_model.output
x = Flatten(name="flatten")(x)
x = Dense(512, activation="relu")(x)
x = Dropout(0.4)(x)
x = Dense(64, activation="relu")(x)
x = Dropout(0.4)(x)
output = Dense(1, activation="sigmoid")(x)  # probability of PNEUMONIA
model = Model(inputs=vgg_pretrained_model.input, outputs=output)

The final architecture we train has the same CNN layers as VGG16, but only block 5 is trainable. On top of them sit 3 trainable Dense layers.

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
_________________________________________________________________
----------------------------- VGG16 ----------------------------- 
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 512)               12845568  
_________________________________________________________________
dense_1 (Dense)              (None, 64)                32832     
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 65        
_________________________________________________________________
Total params: 27,593,153
Trainable params: 19,957,889
Non-trainable params: 7,635,264
_________________________________________________________________

The main metric we use to evaluate our model is the Area Under the ROC Curve (AUC). However, we also monitor additional metrics, using the built-in metrics from tf.keras.metrics.

from tensorflow.keras.metrics import (AUC, BinaryAccuracy, FalseNegatives,
                                      FalsePositives, Precision, Recall,
                                      TrueNegatives, TruePositives)

metrics = [
    AUC(name='auc'),
    TruePositives(name='tp'),
    FalsePositives(name='fp'),
    TrueNegatives(name='tn'),
    FalseNegatives(name='fn'),
    BinaryAccuracy(name='binary_accuracy'),
    Precision(name='precision'),
    Recall(name='recall'),
]
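AUC, the main metric here, can be read as the probability that a randomly chosen positive image (PNEUMONIA) receives a higher score than a randomly chosen negative one (NORMAL). Unlike precision and recall, which the Keras metrics compute at a 0.5 threshold by default, it does not depend on a single decision threshold. The exact pairwise computation below (pairwise_auc is our own illustrative helper) only makes that definition concrete; tf.keras.metrics.AUC instead approximates the curve over a fixed set of thresholds (200 by default).

```python
def pairwise_auc(labels, scores):
    """Exact ROC AUC: fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. O(P * N), so for illustration only.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    correct = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos) * len(neg))


print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

In the example, three of the four positive/negative pairs are ordered correctly (0.35 is scored below the negative at 0.4), giving an AUC of 0.75.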

Finally, we can compile the new model with the Adam optimizer, a binary cross-entropy loss function, and the metrics defined above.

from tensorflow.keras.optimizers import Adam

model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='binary_crossentropy',
              metrics=metrics)

Training

Having defined the model based on the VGG16 architecture, we can now train it. We add two more things at this point:

  • EarlyStopping, to avoid training the model longer than necessary: we can set a high number of epochs, and training stops once the validation loss has not improved for a given number of epochs (patience),
  • class_weight, to tell the model to “pay more attention” to samples from the under-represented class (NORMAL in this case).

Training of this model takes 19 epochs before it is stopped by EarlyStopping. Each epoch takes around 80 seconds on our hardware.

from tensorflow.keras.callbacks import EarlyStopping

# stop when val_loss has not improved for 10 epochs and keep the best weights
es = EarlyStopping(monitor='val_loss', patience=10,
                   verbose=1,
                   restore_best_weights=True)

# class 1 = PNEUMONIA, class 0 = NORMAL (indices assigned by flow_from_directory)
positive = sum(train_generator.classes)
total = len(train_generator.classes)
negative = total - positive
weight_for_negative = (1 / negative) * total / 2.0
weight_for_positive = (1 / positive) * total / 2.0
class_weight = {0: weight_for_negative, 1: weight_for_positive}

# model.fit accepts generators directly; fit_generator is deprecated
history = model.fit(
    train_generator,
    validation_data=val_generator,
    steps_per_epoch=int(train_generator.samples / batch_size),
    epochs=10000,
    validation_steps=int(val_generator.samples / 1),
    verbose=1,
    callbacks=[es],
    class_weight=class_weight)
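With the training counts from the dataset table (1,283 NORMAL images as class 0 and 3,473 PNEUMONIA images as class 1), these formulas give NORMAL a weight of about 1.853 and PNEUMONIA about 0.685. The same arithmetic as a standalone function (balanced_class_weights is our own name), with the counts passed in explicitly instead of read from train_generator.classes:

```python
def balanced_class_weights(n_negative: int, n_positive: int) -> dict:
    """weight = total / (2 * count), the same formula as in the training code,
    so that both classes contribute equally to the loss in total."""
    total = n_negative + n_positive
    return {
        0: total / (2.0 * n_negative),
        1: total / (2.0 * n_positive),
    }


weights = balanced_class_weights(n_negative=1283, n_positive=3473)
print({k: round(v, 3) for k, v in weights.items()})  # {0: 1.853, 1: 0.685}
```

Multiplying each weight by its class count gives 2,378 for both classes, half of the 4,756 training images, which is exactly the balancing effect the weights are meant to have.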

Evaluation

Finally, let’s evaluate the model on the test dataset.

# model.evaluate accepts generators directly; evaluate_generator is deprecated
evaluate_results = model.evaluate(
    test_generator,
    steps=int(test_generator.samples / 1),
    verbose=1)

Metric            Train    Validation    Test
AUC               0.9856   0.9948        0.9895
Precision         0.9836   0.9973        0.9891
Recall            0.9430   0.9275        0.9050
True Positive     3240     371           362
True Negative     1203     150           146
False Positive    54       1             4
False Negative    196      29            38
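Precision and recall follow directly from the confusion counts, so recomputing them for the test column is a quick consistency check. Accuracy and specificity are not listed in the table; the values below are derived from the counts here rather than reported by the original run, and confusion_metrics is our own illustrative helper.

```python
def confusion_metrics(tp: int, fp: int, tn: int, fn: int) -> dict:
    """Precision, recall (sensitivity), specificity and accuracy from raw counts."""
    return {
        "precision": tp / (tp + fp),
        "recall": tp / (tp + fn),
        "specificity": tn / (tn + fp),
        "accuracy": (tp + tn) / (tp + fp + tn + fn),
    }


test_metrics = confusion_metrics(tp=362, fp=4, tn=146, fn=38)
print({k: round(v, 4) for k, v in test_metrics.items()})
# {'precision': 0.9891, 'recall': 0.905, 'specificity': 0.9733, 'accuracy': 0.9236}
```

The recall of 0.905 means that 38 of the 400 PNEUMONIA test images are missed, while only 4 of the 150 NORMAL images are flagged as PNEUMONIA.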

How about no Transfer Learning?

To load VGG16 without the pre-trained weights, we have to set weights to None and keep all layers trainable.

vgg_pretrained_model = VGG16(weights=None, 
                             include_top= False,
                             input_shape=(224,224,3) )

To be able to train the model on the hardware at hand, we have to decrease batch_size to 16. In this setup, the model trains roughly three times longer to achieve the same AUC (~0.99).
