Convolutional neural network classification of brain MRI in patients with schizophrenia

The scope of this post is to classify patients diagnosed with schizophrenia from their brain MRI scans using a convolutional neural network in Python.

It is possible that certain brain areas develop differently in people with this mental disorder. Some studies report that, in patients with schizophrenia, MR imaging shows a smaller total brain volume and enlarged ventricles.


Import necessary libraries


import os
import zipfile
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

The code was executed on Google Colab. We upload the files to two different folders on Google Drive and read them from there.


from google.colab import drive
drive.mount('/content/drive')

The following code reads, normalizes, and resizes a 3D medical image. It first defines three functions:

read_nifti_file(): This function reads a 3D medical image from a NIfTI file and returns the image data.
normalize(): This function clips the image intensities to a fixed range and rescales them to values between 0 and 1.
resize_volume(): This function resizes the image data to a fixed width, height, and depth.
The code then defines a fourth function, process_scan(), which chains the three steps: it reads, normalizes, and resizes a 3D medical image.

import nibabel as nib

from scipy import ndimage

def read_nifti_file(filepath):
    """Read and load volume"""
    # Read file
    scan = nib.load(filepath)
    # Get raw data
    scan = scan.get_fdata()
    return scan

def normalize(volume):
    """Clip the intensity range and scale the volume to [0, 1]"""
    min_val = -1000
    max_val = 400
    # Clip intensities to a fixed window, then scale to [0, 1]
    volume[volume < min_val] = min_val
    volume[volume > max_val] = max_val
    volume = (volume - min_val) / (max_val - min_val)
    volume = volume.astype("float32")
    return volume

def resize_volume(img):
    """Resize across z-axis"""
    # Set the desired depth
    desired_depth = 95
    desired_width = 79
    desired_height = 79
    # Get current depth
    current_depth = img.shape[-1]
    current_width = img.shape[0]
    current_height = img.shape[1]
    # Compute depth factor
    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height
    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height
    # Rotate
    img = ndimage.rotate(img, 90, reshape=False)
    # Resize across width, height and depth
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img

def process_scan(path):
    """Read and resize volume"""
    # Read scan
    volume = read_nifti_file(path)
    # Normalize
    volume = normalize(volume)
    # Resize width, height and depth
    volume = resize_volume(volume)
    return volume

Here are some additional things you can do to process 3D medical images:

Apply a contrast enhancement technique to improve the visibility of the image features.
Segment the image into different regions.
Calculate quantitative measures of the image, such as the volume of a tumor (a small sketch follows this list).
Register the image to another image, such as a reference image.
Apply a machine learning algorithm to the image to classify or diagnose a medical condition.
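
As an illustration of the quantitative-measures point above, here is a minimal, hypothetical sketch that estimates the volume of a thresholded region in a NIfTI scan. The helper name estimate_region_volume_ml and the threshold value are assumptions for demonstration only; this is not part of the classification pipeline in this post.

import numpy as np
import nibabel as nib
from scipy import ndimage

def estimate_region_volume_ml(filepath, threshold=0.5):
    """Rough volume (in millilitres) of the voxels above a threshold."""
    scan = nib.load(filepath)
    data = scan.get_fdata()
    # Physical voxel dimensions (in mm) are stored in the NIfTI header.
    voxel_volume_mm3 = np.prod(scan.header.get_zooms()[:3])
    # Simple min-max scaling to [0, 1] before thresholding.
    data = (data - data.min()) / (data.max() - data.min() + 1e-8)
    mask = data > threshold
    # Remove small speckles from the binary mask.
    mask = ndimage.binary_opening(mask)
    return mask.sum() * voxel_volume_mm3 / 1000.0  # mm^3 to mL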

Folder "OCF_" consist of MRI From normal patients, while "1CF_" will have all the MRI images of the patients diagnosed with schizophrenia


normal_dir = "/content/drive/MyDrive/Colab Notebooks/0CF_"
abnormal_dir = "/content/drive/MyDrive/Colab Notebooks/1CF_"
# Build the full path to every scan in each class folder.
normal_scan_paths = [os.path.join(normal_dir, x) for x in os.listdir(normal_dir)]
abnormal_scan_paths = [os.path.join(abnormal_dir, x) for x in os.listdir(abnormal_dir)]
print("MRI for healthy patient: " + str(len(normal_scan_paths)))
print("MRI for non-healthy patient: " + str(len(abnormal_scan_paths)))

Output:
MRI for healthy patient: 360
MRI for non-healthy patient: 314


Read and process the scans. Assign the label 1 to scans from patients with schizophrenia and 0 to scans from normal patients.


abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])


Split the data into training and validation sets: the first 70 scans from each class go to training, and the remaining scans go to validation.


x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print(
    "Number of samples in train and validation are %d and %d."
    % (x_train.shape[0], x_val.shape[0])
)

Output: Number of samples in train and validation are 140 and 534.


We augment the training images with random rotations to make the classification more robust. The code below defines three functions: rotate(), train_preprocessing(), and validation_preprocessing().

The rotate() function rotates a volume by a few degrees. It takes a volume as input and returns a rotated volume. The rotation angle is randomly chosen from a list of angles.

The train_preprocessing() function processes the training data. It first rotates the volume using the rotate() function. Then, it adds a channel to the volume. This is done to make the volume compatible with the convolutional neural network that will be used to train the model.

The validation_preprocessing() function processes the validation data. It only adds a channel to the volume; the validation data is not augmented, so that evaluation reflects the original scans.

The rotate() function is decorated with @tf.function, which traces it into a TensorFlow graph so it can be used inside the tf.data input pipeline. The actual rotation is delegated to SciPy through tf.numpy_function.


import random

from scipy import ndimage

@tf.function
def rotate(volume):
    """Rotate the volume by a few degrees"""

    def scipy_rotate(volume):
        # define some rotation angles
        angles = [-20, -10, -5, 5, 10, 20]
        # pick angles at random
        angle = random.choice(angles)
        # rotate volume
        volume = ndimage.rotate(volume, angle, reshape=False)
        volume[volume < 0] = 0
        volume[volume > 1] = 1
        return volume

    augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
    return augmented_volume

def train_preprocessing(volume, label):
    """Process training data by rotating and adding a channel."""
    # Rotate volume
    volume = rotate(volume)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

def validation_preprocessing(volume, label):
    """Process validation data by only adding a channel."""
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

The code below defines two data loaders, train_loader and validation_loader. These will be used to feed the training and validation data to the model.

tf.data.Dataset.from_tensor_slices() creates a tf.data.Dataset object from two tensors: the image volumes and their class labels. A dataset is built this way for both the training and the validation data.

The batch size is the number of examples processed at a time. Here it is set to 2, because each 3D volume is large and memory is limited.

The training pipeline shuffles the data, applies train_preprocessing() on the fly (the random rotation plus the added channel dimension), batches the examples, and prefetches batches in the background. Data augmentation like this artificially increases the variety of the training data by applying random transformations to the images.

The validation pipeline only applies validation_preprocessing(), which adds the channel dimension; no augmentation is applied, so evaluation reflects the original scans.

Prefetching prepares the next batches in the background while the current batch is being processed, which improves the throughput of both data loaders.


# Define data loaders.
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))

batch_size = 2
# Augment on the fly during training.
train_dataset = (
    train_loader.shuffle(len(x_train))
    .map(train_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
# Only add a channel dimension (no augmentation).
validation_dataset = (
    validation_loader.shuffle(len(x_val))
    .map(validation_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)

The following code defines a function called get_model() that builds a 3D convolutional neural network (CNN) model. The model has the following layers:

A convolutional layer with 64 filters and kernel size 3.
A max pooling layer with pool size 2.
A batch normalization layer.
Another convolutional layer with 64 filters and kernel size 3.
Another max pooling layer with pool size 2.
A batch normalization layer.
Another convolutional layer with 128 filters and kernel size 3.
Another max pooling layer with pool size 2.
A batch normalization layer.
Another convolutional layer with 256 filters and kernel size 3.
Another max pooling layer with pool size 2.
A batch normalization layer.
A global average pooling layer.
A dense layer with 512 units and activation function relu.
A dropout layer with rate 0.3.
A dense layer with 1 unit and activation function sigmoid.
The model is named 3dcnn and takes a single input tensor whose shape is the width, height, and depth of the MRI scan plus one channel.

The code also builds the model and prints a summary of the model. The summary shows the number of layers in the model, the number of parameters in each layer, and the overall number of parameters in the model.

The model has roughly 1.35 million parameters (a rough count is sketched after the model summary). This is a large number relative to the size of the dataset, which means the model has plenty of capacity to learn complex patterns in the data.

However, it is important to note that the number of parameters is not the only factor that determines the performance of a model. The architecture of the model, the choice of optimizer, and the amount of training data also play important roles.

Overall, the code defines a well-structured 3D CNN model for medical image classification. The model has a large number of parameters, which suggests that it is capable of learning complex patterns in the data. However, it is important to train the model on a large dataset of MRI scans to ensure that it generalizes well to unseen data.


def get_model(width=79, height=79, depth=95):
    """Build a 3D convolutional neural network model."""

    inputs = keras.Input((width, height, depth, 1))

    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
    x = layers.MaxPool3D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)

    outputs = layers.Dense(units=1, activation="sigmoid")(x)

    # Define the model.
    model = keras.Model(inputs, outputs, name="3dcnn")
    return model


# Build model.
model = get_model(width=79, height=79, depth=95)
model.summary()

model.summary() prints the architecture: each layer's output shape and parameter count, along with the total number of parameters.
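
As a sanity check on the parameter count: a Conv3D layer has kernel_size^3 * input_channels * filters weights plus one bias per filter, and a BatchNormalization layer stores four values per channel. A quick back-of-the-envelope count for the architecture above:

# Rough parameter count for the layers described above (biases included).
def conv3d_params(kernel, channels_in, filters):
    return kernel ** 3 * channels_in * filters + filters

total = (
    conv3d_params(3, 1, 64) + 4 * 64        # Conv3D(64) + BatchNormalization
    + conv3d_params(3, 64, 64) + 4 * 64     # Conv3D(64) + BatchNormalization
    + conv3d_params(3, 64, 128) + 4 * 128   # Conv3D(128) + BatchNormalization
    + conv3d_params(3, 128, 256) + 4 * 256  # Conv3D(256) + BatchNormalization
    + 256 * 512 + 512                       # Dense(512) after global average pooling
    + 512 * 1 + 1                           # Dense(1) output layer
)
print(total)  # 1352897, i.e. about 1.35 million parameters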



The following code first compiles the model using the following arguments:

loss: The loss function. Binary crossentropy is used, which is a standard choice for binary classification problems.
optimizer: The Adam optimizer with an exponentially decaying learning rate schedule, starting at 0.0001.
metrics: The metrics to track during training and validation. In this case, accuracy is used.

The code then defines two callbacks:

checkpoint_cb: This callback saves the model to a file whenever the monitored validation metric improves (save_best_only=True). This is useful for resuming training or for loading the best model for inference.
early_stopping_cb: This callback stops training early if the validation accuracy does not improve for 15 consecutive epochs. This helps to prevent overfitting.
Finally, the code trains the model for up to the specified number of epochs. The training data is passed to the model in batches, and the validation data is used to evaluate the model's performance at the end of each epoch.

The code is a good starting point for training a 3D CNN model for medical image classification. However, it is important to experiment with different hyperparameters and training settings to get the best results.


# Compile model.
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=["acc"],
)

# Define callbacks.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification.h5", save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)

# Train the model, doing validation at the end of each epoch
epochs = 100
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    shuffle=True,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)
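
As mentioned above, the saved checkpoint can later be reloaded for inference. Below is a minimal sketch, assuming the 3d_image_classification.h5 file and the x_val array defined earlier; the variable names are illustrative only.

# Reload the best checkpoint and score a single validation scan.
best_model = keras.models.load_model("3d_image_classification.h5")

# Add the channel and batch dimensions, as validation_preprocessing() does.
sample = np.expand_dims(x_val[0], axis=(0, -1))  # shape (1, 79, 79, 95, 1)
probability = best_model.predict(sample)[0][0]
print("Predicted probability of schizophrenia: %.3f" % probability)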

Here are some additional things you can do to improve the performance of the model:

Use a larger dataset of MRI scans.
Use a different optimizer, such as SGD with momentum or RMSprop.
Try different learning rate schedules.
Use additional data augmentation techniques, such as random cropping and flipping (a flip example is sketched after this list).
Tune the dropout rate to better control overfitting.
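
As an example of the augmentation suggestion above, here is a minimal sketch that adds a random left-right flip to the training pipeline. It reuses the rotate() function defined earlier; the helper names flip_lr and train_preprocessing_with_flip are hypothetical, and note that flipping brain scans may not be appropriate if left-right asymmetries carry diagnostic information.

def flip_lr(volume):
    """Randomly flip the volume along the width axis with probability 0.5."""
    do_flip = tf.random.uniform(()) > 0.5
    return tf.cond(do_flip, lambda: tf.reverse(volume, axis=[0]), lambda: volume)

def train_preprocessing_with_flip(volume, label):
    """Rotate, randomly flip, then add a channel dimension."""
    volume = rotate(volume)
    volume = flip_lr(volume)
    volume = tf.expand_dims(volume, axis=3)
    return volume, label

# Swap it in when building the training dataset:
# train_dataset = (
#     train_loader.shuffle(len(x_train))
#     .map(train_preprocessing_with_flip)
#     .batch(batch_size)
#     .prefetch(2)
# )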

Training the model takes time, so a GPU runtime in Colab speeds up the process considerably. After training has finished, we plot the training and validation curves for the model's accuracy and loss metrics.

The code first creates a figure and two subplots. The subplots are arranged in a single row.

The code then loops over the two metrics, acc and loss. For each metric, it plots the training curve and the validation curve on the same axes and labels them in the legend.

The code also sets the titles and labels for the two subplots. The first subplot shows the training and validation curves for the accuracy metric. The second subplot shows the training and validation curves for the loss metric.

The curves show that the model's accuracy and loss both improve over the course of training. In this run, the validation accuracy reaches a plateau after about 50 epochs while the training loss continues to decrease, which suggests that the model starts to overfit around that point.


import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()

for i, metric in enumerate(["acc", "loss"]):
    ax[i].plot(model.history.history[metric])
    ax[i].plot(model.history.history["val_" + metric])
    ax[i].set_title("Model {}".format(metric))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(metric)
    ax[i].legend(["train", "val"])



Google's experimental Bard assistant helped with the code explanations in this post.

