GAN Neural Networks for Unpaired Image-to-Image Translation on MRI Images
Generating Synthetic T1/T2-Weighted MRI Images Using CycleGAN
Misdiagnosis in the medical field is a very serious issue with adverse implications for the patient. With skilled radiologists overwhelmed by the volume of work, there is a critical need for Artificial Intelligence to help them make the right decisions.
Magnetic Resonance Imaging (MRI) is a key imaging technology that offers superb soft-tissue contrast through different contrast mechanisms, such as T1 weighting and T2 weighting. A radiologist often needs images with multiple contrasts to arrive at the right decision, but acquiring all of them is often cost-prohibitive.
Deep learning holds great promise for generating synthetic MRI images with different contrasts from existing MRI scans. Style transfer using CycleGAN is one approach suited to this application.
This article briefly describes an implementation of unpaired image-to-image translation using CycleGAN to translate T1-weighted MRI images into T2-weighted images and vice versa.
Introduction to CycleGAN
A Generative Adversarial Network (GAN) is a generative model, trained in an unsupervised manner, that automatically discovers and learns the regularities or patterns in its input data so that it can generate new, synthetic data.
GANs are built from two submodules, a Generator and a Discriminator, which work against each other. The Generator produces new images, and the Discriminator tries to classify the images it receives as either real or fake. The two models are trained together in an adversarial, zero-sum game until the Discriminator is regularly fooled, meaning the Generator is able to produce realistic outputs.
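The zero-sum game described above is usually written as the standard GAN minimax objective, shown here for reference (it is not part of the original article). D(x) is the Discriminator's estimated probability that x is real, and G(z) is the Generator's output for a random noise vector z. The Discriminator tries to maximize this value while the Generator tries to minimize it:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

CycleGAN keeps this adversarial idea, but each generator receives an image from the source domain instead of the noise vector z.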
Conditional GANs such as Pix2Pix can perform image-to-image translation on paired datasets, for example generating color pictures from the corresponding black-and-white images.
CycleGAN, on the other hand, uses unpaired images to translate images from one domain to another. One example, discussed in detail in this article, is generating T2-weighted MRI images from T1-weighted images and vice versa.
The implementation is organized into the following steps:
1. Input Pipeline and Normalization
2. Model Building
   - Generator Model Design
   - Discriminator Model Design
   - Loss Functions
3. Training using the available T1 and T2 images over multiple epochs
4. Predictions using the Final Model
1. Input Pipeline
The reference T1 and T2 images are loaded in grayscale mode with an image size of (256, 256) into a TensorFlow BatchDataset.
The original CycleGAN paper recommends applying image augmentation techniques such as random jittering and random mirroring (flipping) to avoid overfitting. TensorFlow provides built-in functions for these augmentation methods.
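The augmentation step can be sketched in plain NumPy as follows. This is a minimal illustration, not the article's TensorFlow code. It assumes the common "random jitter" recipe of upscaling to 286x286 and randomly cropping back to 256x256 (values borrowed from the TensorFlow CycleGAN tutorial, not stated in the article), and it uses a nearest-neighbour resize only to stay self-contained. The helpers `resize_nearest` and `random_jitter` are hypothetical names.

```python
import numpy as np


def resize_nearest(img, size):
    """Nearest-neighbour resize of a 2-D grayscale image to (size, size)."""
    h, w = img.shape
    rows = np.arange(size) * h // size   # source row index for each output row
    cols = np.arange(size) * w // size   # source column index for each output column
    return img[rows[:, None], cols[None, :]]


def random_jitter(img, rng, upscale=286, crop=256):
    """Random jitter + mirroring: upscale, random crop back to `crop`, random left-right flip.
    The 286/256 sizes are assumptions taken from common CycleGAN pipelines."""
    big = resize_nearest(img, upscale)
    top = rng.integers(0, upscale - crop + 1)
    left = rng.integers(0, upscale - crop + 1)
    out = big[top:top + crop, left:left + crop]
    if rng.random() < 0.5:
        out = out[:, ::-1]               # horizontal flip (mirroring)
    return out


rng = np.random.default_rng(0)
slice_t1 = rng.integers(0, 256, size=(256, 256)).astype(np.float32)  # stand-in for a T1 slice
augmented = random_jitter(slice_t1, rng)
print(augmented.shape)  # (256, 256)
```

In a real pipeline the same operations would normally be applied per element of the dataset with a bilinear resize, so that every epoch sees a slightly different crop and orientation of each slice.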
Normalization of input data: large integer pixel values can slow down training, so it is important to normalize the data to the [0, 1] or [-1, 1] range.
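A minimal sketch of the [-1, 1] normalization, assuming 8-bit pixel values in the range 0 to 255. Raw MRI data is often stored with a larger bit depth, in which case `max_value` would be the maximum intensity of the scan (an assumption, as the article does not state its input range). The helper names `normalize` and `denormalize` are hypothetical.

```python
import numpy as np


def normalize(img, max_value=255.0):
    """Map pixel values from [0, max_value] to [-1, 1]."""
    return img.astype(np.float32) / (max_value / 2.0) - 1.0


def denormalize(img, max_value=255.0):
    """Inverse of normalize: map [-1, 1] back to [0, max_value] for display."""
    return (img + 1.0) * (max_value / 2.0)


pixels = np.array([0, 127.5, 255])
print(normalize(pixels))  # [-1.  0.  1.]
```

The [-1, 1] range is a natural choice when the generator ends in a tanh activation, as pix2pix-style generators commonly do, because its outputs then live in the same range as the normalized inputs.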
2. Model Building
The CycleGAN model consists of three important modules:
- Generator
- Discriminator
- Loss functions
Generator
The Generator is the module that actually produces the synthetic image. Unlike a vanilla GAN, whose generator starts from a random noise signal, each CycleGAN generator takes an image from the source domain (for example, a T1-weighted image) as its input.
A modified U-Net architecture is used in the generator module. U-Net is predominantly used for biomedical image segmentation.
In a typical convolutional neural network for classification, the network takes in an image and assigns it a single label. In biomedical imaging, however, it is important not just to classify whether a disease is present but also to localize the area of abnormality.
U-Net solves this problem: it can localize and distinguish borders by classifying every pixel, with the input and output images having the same size.
Below is the architecture of U-Net from its original paper
Briefly, the U-Net architecture has two paths, an encoder and a decoder, joined by skip connections:
Encoder: the left side of the network, where regular convolutions and max pooling are applied. Here the spatial size of the image gradually decreases while the depth gradually increases (256x256x1 → 1x1x512). The encoder learns “WHAT” is in the image but loses the “WHERE” information.
Decoder: the right side of the network, where transposed convolutions are applied. Here the spatial size increases again (1x1x512 → 256x256x1). The decoder recovers the “WHERE” information by gradually up-sampling.
Skip Connections: to get precise locations, the feature maps from each encoder level are concatenated with the output of the transposed convolutional layer at the corresponding decoder level.
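The following sketch traces tensor shapes through the encoder, the decoder and the skip connections. It makes the 256x256x1 → 1x1x512 → 256x256x1 path concrete. It is a hypothetical configuration modelled on pix2pix-style generators, which downsample with stride-2 convolutions instead of max pooling. The filter counts in `ENCODER_FILTERS` and the helper `unet_shapes` are assumptions, not the article's code.

```python
# Hypothetical filter counts per encoder level, modelled on a pix2pix-style
# generator with eight stride-2 downsampling blocks (not taken from the article's code).
ENCODER_FILTERS = [64, 128, 256, 512, 512, 512, 512, 512]


def unet_shapes(size=256, out_channels=1):
    """Trace (height, width, channels) through a U-Net whose blocks halve
    (encoder) or double (decoder) the spatial size with stride 2 and 'same' padding."""
    encoder = []
    h = size
    for filters in ENCODER_FILTERS:
        h //= 2                                  # stride-2 convolution halves H and W
        encoder.append((h, h, filters))

    decoder = []
    x = encoder[-1]                              # bottleneck, e.g. (1, 1, 512)
    skips = encoder[:-1][::-1]                   # remaining encoder maps, deepest first
    up_filters = ENCODER_FILTERS[::-1][1:]       # 512, 512, 512, 512, 256, 128, 64
    for skip, filters in zip(skips, up_filters):
        h = x[0] * 2                             # stride-2 transposed conv doubles H and W
        assert skip[0] == h                      # skip map must match spatially
        x = (h, h, filters + skip[2])            # skip connection: concatenate channels
        decoder.append(x)

    output = (x[0] * 2, x[0] * 2, out_channels)  # final transposed conv restores input size
    return encoder, decoder, output


encoder, decoder, output = unet_shapes()
print(encoder[0], encoder[-1])   # (128, 128, 64) (1, 1, 512)
print(decoder[-1], output)       # (128, 128, 128) (256, 256, 1)
```

The channel counts in the decoder double after each concatenation. This is the cost of the skip connections: they hand the decoder the high-resolution “WHERE” information that the encoder discarded.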
Two generators are defined in this implementation:
- Generator_g learns to transform images from T1 to T2
- Generator_f learns to transform images from T2 to T1
Below are the initial images produced by passing a reference image through the generator models.
Discriminator
The main task of the discriminator module is to classify the synthetic image from the generator as real or fake.
PatchGAN is used as the discriminator in this example. PatchGAN is a type of discriminator that penalizes structure at the scale of local image patches: it tries to classify whether each N×N patch in an image is real or fake. The discriminator is run convolutionally across the image, and all the responses are averaged to provide the final output.
In this example, the output of the last layer of the discriminator has shape 30x30. Each of the 30x30 output values classifies one patch of the input image (the image produced by the generator) as real or fake.
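The 30x30 output size follows from the layer configuration. The sketch below assumes a hypothetical layer stack modelled on the common pix2pix-style PatchGAN: three stride-2 4x4 convolutions, followed by two zero-padded stride-1 4x4 convolutions. Under that assumption it computes both the output size and the receptive field, which is the side length of the input patch that each output value sees. `LAYERS`, `output_size` and `receptive_field` are illustrative names, not the article's code.

```python
import math

# Hypothetical layer stack modelled on the pix2pix-style PatchGAN:
# (kernel, stride, padding), where padding is "same" or an explicit zero-pad per side
# followed by a "valid" convolution. Not taken from the article's code.
LAYERS = [
    (4, 2, "same"),  # 256 -> 128
    (4, 2, "same"),  # 128 -> 64
    (4, 2, "same"),  # 64 -> 32
    (4, 1, 1),       # zero-pad to 34, valid conv -> 31
    (4, 1, 1),       # zero-pad to 33, valid conv -> 30 (one logit per patch)
]


def output_size(n, layers=LAYERS):
    """Spatial size of the discriminator output for an n x n input."""
    for kernel, stride, pad in layers:
        if pad == "same":
            n = math.ceil(n / stride)
        else:
            n = (n + 2 * pad - kernel) // stride + 1
    return n


def receptive_field(layers=LAYERS):
    """Side length of the input patch that influences one output value."""
    r = 1
    for kernel, stride, _ in reversed(layers):
        r = r * stride + (kernel - stride)
    return r


print(output_size(256), receptive_field())  # 30 70
```

Under these assumptions each of the 30x30 values judges a 70x70 patch of the input, which is why this configuration is often called the 70x70 PatchGAN. The total stride is only 8, so neighbouring patches overlap heavily.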
Two discriminators are used in this example:
- Discriminator_X learns to differentiate between real images from domain X (T1) and generated X images (produced by Generator_f).
- Discriminator_Y learns to differentiate between real images from domain Y (T2) and generated Y images (produced by Generator_g).
Below is the initial output produced by passing a synthetic image from the generator through the discriminator.
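Loss Functions
Binary cross-entropy is used as the adversarial loss in this example.
A cycle-consistency weight (lambda) of 10 is used, as recommended in the original paper. Lambda is not a learning rate: it sets how strongly the cycle-consistency loss is weighted relative to the adversarial losses.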
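For reference, the full objective in the CycleGAN paper combines the two adversarial losses with the cycle-consistency term weighted by λ. In the notation below, X is the T1 domain and Y is the T2 domain. G corresponds to Generator_g (T1 → T2) and F to Generator_f (T2 → T1). The identity loss described below is an additional term and is not part of this equation.

```latex
\begin{aligned}
\mathcal{L}(G, F, D_X, D_Y) &= \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y)
  + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X)
  + \lambda \, \mathcal{L}_{\text{cyc}}(G, F), \qquad \lambda = 10 \\
\mathcal{L}_{\text{cyc}}(G, F) &= \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
  + \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]
\end{aligned}
```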
Several loss functions are defined:
- Generator loss: the binary cross-entropy between the discriminator's output for the generated images and an array of ones (the generator wants its images to be classified as real).
- Discriminator loss: consists of two terms:
  - real_loss is the binary cross-entropy between the discriminator's output for the real images and an array of ones (since these are the real images).
  - generated_loss is the binary (sigmoid) cross-entropy between the discriminator's output for the generated images and an array of zeros (since these are the fake images).
  - total_loss is the sum of real_loss and generated_loss.
- Cycle consistency loss: cycle consistency means that translating an image to the other domain and back should reproduce the original input. If a T1 image is translated to a T2 image and then translated back from T2 to T1, the resulting image should be the same as the original image. In the original paper this difference is measured with the L1 (mean absolute) distance.
- Identity loss: if an image that already belongs to a generator's target domain is fed to that generator (for example, a T2 image fed to Generator_g), the output should be the same image or something close to it.
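The loss definitions above can be sketched in NumPy as follows. This is a minimal illustration under two assumptions. First, the discriminator outputs raw logits, so binary cross-entropy is computed "from logits" with the numerically stable formula. Second, the identity loss carries a weight of 0.5 × lambda, which follows the TensorFlow CycleGAN tutorial and is not a value stated in the article. All function names are illustrative.

```python
import numpy as np

LAMBDA = 10.0  # cycle-consistency weight used in the article


def bce_with_logits(labels, logits):
    """Numerically stable binary cross-entropy on raw (pre-sigmoid) logits, averaged."""
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))


def generator_loss(disc_fake_logits):
    # The generator wants the discriminator to label its output as real (ones).
    return bce_with_logits(np.ones_like(disc_fake_logits), disc_fake_logits)


def discriminator_loss(disc_real_logits, disc_fake_logits):
    real_loss = bce_with_logits(np.ones_like(disc_real_logits), disc_real_logits)
    generated_loss = bce_with_logits(np.zeros_like(disc_fake_logits), disc_fake_logits)
    return real_loss + generated_loss           # total_loss, as described above


def cycle_loss(real_image, cycled_image):
    # Mean absolute (L1) error between the original and round-tripped image.
    return LAMBDA * np.mean(np.abs(real_image - cycled_image))


def identity_loss(real_image, same_image):
    # 0.5 * LAMBDA weighting is an assumption (TensorFlow tutorial convention).
    return LAMBDA * 0.5 * np.mean(np.abs(real_image - same_image))


patch_logits = np.zeros((1, 30, 30, 1))          # an undecided 30x30 PatchGAN output
print(generator_loss(patch_logits))              # about 0.693 (ln 2)
print(cycle_loss(np.zeros(4), np.ones(4)))       # 10.0
```

Working on logits rather than sigmoid probabilities avoids computing log(0) when the discriminator becomes very confident.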
Optimizer: the Adam optimizer is used for both generators and both discriminators.
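To show what the optimizer does with the gradients, here is a minimal NumPy sketch of a single Adam update for one parameter array. The defaults `lr=2e-4` and `beta_1=0.5` are assumptions taken from common CycleGAN setups: the paper's learning rate of 0.0002 and the beta_1 used in the TensorFlow tutorial. In practice the framework's built-in Adam optimizer would be used instead of this hypothetical class.

```python
import numpy as np


class Adam:
    """Minimal Adam optimizer for a single NumPy parameter array."""

    def __init__(self, lr=2e-4, beta_1=0.5, beta_2=0.999, eps=1e-7):
        self.lr, self.b1, self.b2, self.eps = lr, beta_1, beta_2, eps
        self.m = None   # running mean of the gradient (first moment)
        self.v = None   # running mean of the squared gradient (second moment)
        self.t = 0      # step counter used for bias correction

    def step(self, param, grad):
        """Return the updated parameter after one Adam step."""
        if self.m is None:
            self.m = np.zeros_like(param)
            self.v = np.zeros_like(param)
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad
        self.v = self.b2 * self.v + (1 - self.b2) * grad ** 2
        m_hat = self.m / (1 - self.b1 ** self.t)   # bias-corrected moments
        v_hat = self.v / (1 - self.b2 ** self.t)
        return param - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


# Usage: minimize (w - 3)^2 with a larger learning rate so it converges quickly.
opt = Adam(lr=0.1)
w = np.array([0.0])
for _ in range(1000):
    w = opt.step(w, 2 * (w - 3))   # gradient of (w - 3)^2
print(w)                           # close to 3
```

Because Adam divides by the running root-mean-square of the gradient, each early step moves a parameter by roughly the learning rate. This makes the step size largely independent of the raw scale of each loss.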
3. Training
Training consists of four broad steps:
- Get the predictions: the generator and discriminator modules produce their outputs for the current batch
- Calculate the loss: the loss functions defined above are used to compute the losses
- Calculate the gradients using backpropagation
- Apply the gradients: the Adam optimizer, as used in the original paper, updates the weights of each model
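Steps 1 and 2 of a training step can be sketched without a deep-learning framework to show which images flow through which model. This is a minimal sketch modelled on the structure of the TensorFlow CycleGAN tutorial's training step, and the article's own code may differ. The models are plain callables, and the 0.5 × lambda identity weight is an assumption. Steps 3 and 4 would be handled by the framework. In TensorFlow, for example, the four losses would be differentiated with `tf.GradientTape` and applied with each optimizer's `apply_gradients`.

```python
import numpy as np


def bce_with_logits(labels, logits):
    """Numerically stable binary cross-entropy on logits, averaged."""
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))


def l1(a, b):
    return np.mean(np.abs(a - b))


def train_step_losses(real_x, real_y, generator_g, generator_f,
                      discriminator_x, discriminator_y, lam=10.0):
    """Forward passes (step 1) and losses (step 2) of one CycleGAN training step.
    real_x is a T1 batch, real_y a T2 batch; all models are plain callables."""
    # Step 1: predictions
    fake_y = generator_g(real_x)      # T1 -> T2
    cycled_x = generator_f(fake_y)    # back to T1 (round trip)
    fake_x = generator_f(real_y)      # T2 -> T1
    cycled_y = generator_g(fake_x)    # back to T2 (round trip)
    same_x = generator_f(real_x)      # identity mapping for Generator_f
    same_y = generator_g(real_y)      # identity mapping for Generator_g
    disc_real_x, disc_fake_x = discriminator_x(real_x), discriminator_x(fake_x)
    disc_real_y, disc_fake_y = discriminator_y(real_y), discriminator_y(fake_y)

    # Step 2: losses (both generators share the total cycle-consistency loss)
    total_cycle = lam * (l1(real_x, cycled_x) + l1(real_y, cycled_y))
    gen_g_loss = (bce_with_logits(np.ones_like(disc_fake_y), disc_fake_y)
                  + total_cycle + 0.5 * lam * l1(real_y, same_y))
    gen_f_loss = (bce_with_logits(np.ones_like(disc_fake_x), disc_fake_x)
                  + total_cycle + 0.5 * lam * l1(real_x, same_x))
    disc_x_loss = (bce_with_logits(np.ones_like(disc_real_x), disc_real_x)
                   + bce_with_logits(np.zeros_like(disc_fake_x), disc_fake_x))
    disc_y_loss = (bce_with_logits(np.ones_like(disc_real_y), disc_real_y)
                   + bce_with_logits(np.zeros_like(disc_fake_y), disc_fake_y))
    return {"gen_g": gen_g_loss, "gen_f": gen_f_loss,
            "disc_x": disc_x_loss, "disc_y": disc_y_loss}


rng = np.random.default_rng(0)
real_x = rng.uniform(-1, 1, size=(1, 256, 256, 1))   # stand-in normalized T1 batch
real_y = rng.uniform(-1, 1, size=(1, 256, 256, 1))   # stand-in normalized T2 batch
passthrough = lambda img: img                         # hypothetical stand-in "generator"
undecided = lambda img: np.zeros((img.shape[0], 30, 30, 1))  # hypothetical discriminator
losses = train_step_losses(real_x, real_y, passthrough, passthrough, undecided, undecided)
print(losses)
```

With pass-through generators the cycle and identity terms vanish, so the toy losses reduce to the adversarial terms alone. This makes the bookkeeping easy to check before plugging in real models.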
Below is an animation of the training process, with the predicted image plotted after every epoch (only a few images are plotted for brevity).
4. Final Predictions
Below are the predictions based on the trained model
References and Links
- Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Jun-Yan Zhu, Taesung Park, Phillip Isola, and Alexei A. Efros
- Generative Adversarial Networks for Image-to-Image Translation on Street View and MR Images by Simon Karlsson and Per Welander