GAN Neural Networks for Unpaired Image-Image Translation on MRI Images

Generate Synthetic T1/T2 weighted MRI Images using CycleGAN


Misdiagnosis in the medical field is a serious issue with adverse implications for the patient. With skilled radiologists overwhelmed by their volume of work, there is a critical need for Artificial Intelligence to assist them in making the right decisions.

Magnetic Resonance Imaging (MRI) is a key imaging technology that offers superb soft tissue contrast through different contrast mechanisms, such as T1-weighted and T2-weighted imaging. A radiologist often needs multi-contrast images to arrive at the right decision, but acquiring them is often cost prohibitive.

Deep Learning holds great promise for generating synthetic MRI images of different contrast types from existing MRI scans. Style transfer using CycleGAN can be used for this application.

Introduction to CycleGAN

A Generative Adversarial Network, or GAN for short, is a generative model used for an unsupervised machine learning task: automatically discovering and learning the regularities or patterns in input data in order to generate new, synthetic data.

GANs are designed with two sub-modules, a Generator and a Discriminator, which work in an adversarial manner. The Generator model generates new images, and the Discriminator tries to classify images as either real (from the dataset) or fake (from the Generator). The two models are trained together in an adversarial, zero-sum game until the discriminator model is fooled about half the time, meaning the generator model is able to generate realistic outputs.
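The zero-sum game can be written as a single objective that the Discriminator D tries to maximize and the Generator G tries to minimize. This is the standard GAN formulation, given here as background rather than as the exact loss used in this project; x denotes a real image, z the generator input, and D(·) the probability the discriminator assigns to "real".

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The first term rewards the discriminator for recognizing real images, and the second rewards it for rejecting generated ones, which is exactly what the generator works against.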

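Conditional GANs such as Pix2Pix can perform image-to-image translation for paired datasets, such as generating colored pictures from the corresponding black and white images. CycleGAN extends this idea to unpaired datasets, so the T1 and T2 images used for training do not need to be matched scans of the same anatomy.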

Figure: CycleGAN architecture


  1. Input pipeline and Normalization
  2. Model Building
    • Generator Model Design
    • Discriminator Model Design
    • Loss Functions
  3. Training using the available T1 and T2 images over multiple epochs
  4. Predictions using Final Model

1. Input Pipeline

The reference T1 and T2 images are loaded with a specified image size. Here, the images are loaded in grayscale mode at a size of (256, 256) into a TensorFlow BatchDataset.

The original CycleGAN paper recommends applying image augmentation techniques such as random flipping and random jittering to avoid overfitting. TensorFlow has built-in functions to process the data through these augmentation methods.
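As a rough illustration of what random jittering and flipping do, here is a minimal NumPy sketch. It is not the code used in this project: the function name `random_jitter`, the enlarged size of 286, and the nearest-neighbour resize are assumptions made for the example. In TensorFlow, the same steps would typically be built from `tf.image.resize`, `tf.image.random_crop` and `tf.image.random_flip_left_right`.

```python
import numpy as np

def random_jitter(image, rng, enlarged=286, size=256):
    """Enlarge the image, crop a random size x size window, then randomly mirror it."""
    h, w = image.shape[:2]
    # Nearest-neighbour enlargement (a simple stand-in for a proper resize).
    rows = np.arange(enlarged) * h // enlarged
    cols = np.arange(enlarged) * w // enlarged
    big = image[rows][:, cols]
    # Random crop back to the training size.
    top = rng.integers(0, enlarged - size + 1)
    left = rng.integers(0, enlarged - size + 1)
    crop = big[top:top + size, left:left + size]
    # Random horizontal flip with probability 0.5.
    if rng.random() < 0.5:
        crop = crop[:, ::-1]
    return crop

rng = np.random.default_rng(0)
image = rng.random((256, 256, 1)).astype(np.float32)  # placeholder grayscale slice
augmented = random_jitter(image, rng)
print(augmented.shape)
```

Each call yields a slightly shifted and possibly mirrored version of the same slice, so the networks rarely see exactly the same input twice.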

Normalization of input data: Large integer pixel values can slow the training process. Hence it is very important to normalize the data to the [0, 1] or [-1, 1] range.
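A minimal sketch of the [-1, 1] option, assuming the images are stored as 8-bit values in [0, 255] (the helper names `normalize` and `denormalize` are made up for this example):

```python
import numpy as np

def normalize(image):
    """Map 8-bit pixel values in [0, 255] to floats in [-1, 1]."""
    return image.astype(np.float32) / 127.5 - 1.0

def denormalize(image):
    """Map floats in [-1, 1] back to 8-bit pixel values for display."""
    return ((image + 1.0) * 127.5).round().clip(0, 255).astype(np.uint8)

pixels = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = normalize(pixels)
restored = denormalize(scaled)
print(scaled.min(), scaled.max())
```

The [-1, 1] range pairs naturally with a generator whose last layer uses a tanh activation, since tanh produces values in the same range.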

2. Model Building

The CycleGAN model consists of three important modules

  1. Generator
  2. Discriminator
  3. Loss functions

Generator Design

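The Generator is the module that actually generates the synthetic image. Unlike a classic GAN generator, which starts from a random noise signal, a CycleGAN generator starts from an input image in the source domain (for example, a T1 image) and translates it to the target domain (T2). The generator in this implementation follows the U-Net architecture.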

Below is the architecture of U-Net from its original paper

Briefly, the U-Net architecture has two paths: an encoder and a decoder.

Encoder: The left side of the network, where regular convolutions and max pooling are applied. Here, the image size gradually reduces while the depth gradually increases (256x256x1 → 1x1x512). The encoder network learns the “WHAT” in the image but loses the “WHERE” information.

Decoder: The right side of the network, where transposed convolutions are applied. Here, the image size gradually increases again (1x1x512 → 256x256x1). The decoder network recovers the “WHERE” information by gradually applying up-sampling.

Skip Connections: To get precise locations, the feature maps from the corresponding encoder level are concatenated with the output of the transposed convolutional layers.
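To make the shape changes concrete, the sketch below walks a 256x256x1 image through eight down-sampling stages and back up again, tracking only tensor shapes. The filter counts are an assumption borrowed from common pix2pix-style U-Net generators, not values confirmed for this project, and the functions `encoder_shapes` and `decoder_shapes` are hypothetical helpers.

```python
def encoder_shapes(size=256, channels=1,
                   filters=(64, 128, 256, 512, 512, 512, 512, 512)):
    """Each encoder stage halves height/width and sets a new depth."""
    shapes = [(size, size, channels)]
    for f in filters:
        size //= 2
        shapes.append((size, size, f))
    return shapes

def decoder_shapes(enc, up_filters=(512, 512, 512, 512, 256, 128, 64),
                   out_channels=1):
    """Each decoder stage doubles height/width, then concatenates the skip."""
    size = enc[-1][0]
    skips = reversed(enc[1:-1])  # encoder outputs, excluding input and bottleneck
    shapes = []
    for f, skip in zip(up_filters, skips):
        size *= 2
        assert skip[0] == size   # skip and up-sampled map share spatial size
        shapes.append((size, size, f + skip[2]))  # channels add up on concat
    size *= 2
    shapes.append((size, size, out_channels))     # final transposed convolution
    return shapes

enc = encoder_shapes()
dec = decoder_shapes(enc)
for shape in enc + dec:
    print(shape)
```

The bottleneck is 1x1x512, matching the figure quoted above, and the concatenation with the skip connections is why the decoder depths are larger than the number of filters in each up-sampling layer.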

Two Generators are defined in this implementation

  • Generator_g learns to transform image from T1 to T2
  • Generator_f learns to transform image from T2 to T1

Below are the initial outputs obtained by passing the reference images through the generator models.

Discriminator Design

The major task of the discriminator module is to classify an image as real or fake, where the fake images are the synthetic images produced by the generator.

In this example, the output of the last layer in the discriminator has a shape of 30x30. Each value in this 30x30 output classifies one patch of the discriminator's input image as real or fake.
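The 30x30 figure follows from the layer arithmetic. The sketch below assumes a pix2pix-style PatchGAN layout (three stride-2 convolutions with "same" padding, then two 4x4 stride-1 convolutions, each preceded by one pixel of zero padding); that layout is an assumption for illustration, and the helpers `conv_out` and `patchgan_output_size` are made up for this example.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution along one dimension."""
    if padding == "same":
        return -(-size // stride)          # ceil(size / stride)
    return (size - kernel) // stride + 1   # "valid" padding

def patchgan_output_size(size=256):
    for _ in range(3):
        size = conv_out(size, kernel=4, stride=2, padding="same")  # 128, 64, 32
    size += 2                                                      # zero padding: 34
    size = conv_out(size, kernel=4, stride=1, padding="valid")     # 31
    size += 2                                                      # zero padding: 33
    size = conv_out(size, kernel=4, stride=1, padding="valid")     # 30
    return size

patch_grid = patchgan_output_size(256)
print(patch_grid)
```

Under these assumptions a 256x256 input produces a 30x30 grid of real/fake scores, one per overlapping patch of the input.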

Two discriminators are used in this example

  • Discriminator_X learns to differentiate between real images X and generated images X
  • Discriminator_Y learns to differentiate between real images Y and generated images Y

Below is the initial output obtained by passing the synthetic image from the generator to the discriminator.

Loss Functions

Several loss functions are defined:

  1. Generator loss: the binary cross entropy loss between the discriminator's output for the generated images and an array of ones
  2. Discriminator loss: consists of two parts
    • real_loss is the binary cross entropy loss of the real images and an array of ones (since these are the real images)
    • generated_loss is the binary cross entropy loss of the generated images and an array of zeros (since these are the fake images)
    • total_loss is the sum of real_loss and generated_loss
  3. Cycle Consistency Loss: Cycle consistency means the result should be close to the original input. If a T1 image is translated to a T2 image and then translated back from T2 to T1, the resulting image should be close to the original T1 image
  4. Identity Loss: If an image that already belongs to a generator's target domain is fed to that generator, the output should be the same image or something close to it. For example, Generator_g (T1 to T2) should leave a real T2 image nearly unchanged
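The four losses can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the project's code: the discriminator is assumed to output raw logits, `LAMBDA = 10` and the 0.5 factor on the identity term are the weights suggested in the CycleGAN paper, and the L1 distance is assumed for the cycle and identity terms.

```python
import numpy as np

LAMBDA = 10.0  # assumed weight of the cycle-consistency term

def bce_with_logits(labels, logits):
    """Numerically stable binary cross entropy on raw discriminator logits."""
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))

def generator_loss(disc_generated):
    # The generator wants the discriminator to label its images as real (ones).
    return bce_with_logits(np.ones_like(disc_generated), disc_generated)

def discriminator_loss(disc_real, disc_generated):
    real_loss = bce_with_logits(np.ones_like(disc_real), disc_real)
    generated_loss = bce_with_logits(np.zeros_like(disc_generated), disc_generated)
    return real_loss + generated_loss

def cycle_loss(real_image, cycled_image):
    # Mean absolute difference between an image and its round-trip translation.
    return LAMBDA * np.mean(np.abs(real_image - cycled_image))

def identity_loss(real_image, same_image):
    # Penalizes a generator for changing an image already in its target domain.
    return 0.5 * LAMBDA * np.mean(np.abs(real_image - same_image))

undecided = np.zeros((30, 30))  # logits of 0, i.e. a 50/50 guess on every patch
print(generator_loss(undecided), discriminator_loss(undecided, undecided))
```

With an undecided discriminator, the generator loss equals log 2 and the discriminator loss equals 2 log 2, which is a handy sanity check when monitoring training.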

Optimizer: The Adam optimizer is used for both generators and both discriminators.
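For readers unfamiliar with Adam, a single update step can be sketched in NumPy as below. The hyperparameters (learning rate 2e-4, beta1 0.5) are the values commonly used with CycleGAN and are an assumption here, and `adam_step` is a hypothetical helper. In TensorFlow the equivalent optimizer would be created with `tf.keras.optimizers.Adam(2e-4, beta_1=0.5)`.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=2e-4, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1.0 - beta2) * grad ** 2    # running mean of squared gradients
    m_hat = m / (1.0 - beta1 ** t)               # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

weights = np.array([1.0, -1.0])
grads = np.array([2.0, -0.5])
weights, m, v = adam_step(weights, grads, m=np.zeros(2), v=np.zeros(2), t=1)
print(weights)
```

On the first step each weight moves by roughly the learning rate, regardless of the gradient's magnitude, which is one reason Adam is forgiving about gradient scale.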

3. Training

The training consists of four broad steps:

  1. Get the predictions: The generator and discriminator modules are used to get the predictions
  2. Calculate the loss: The loss functions defined above are used to calculate the losses
  3. Calculate the gradients using backpropagation
  4. Apply the gradients using the optimizer: The Adam optimizer, as defined in the original paper, updates the generator and discriminator weights
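The four steps can be seen in miniature in the toy loop below. It is deliberately far simpler than the real training: the two generators are replaced by single scalar weights (g(x) = a·x and f(y) = b·y), only the cycle-consistency loss is kept, the gradients are derived by hand instead of by backpropagation, and plain gradient descent stands in for Adam. All names and values are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.5, 1.5, size=64)  # toy stand-in for T1 samples
y = rng.uniform(1.0, 3.0, size=64)  # toy stand-in for T2 samples (unpaired)

a, b = 0.5, 0.5        # weights of the toy generators g(x) = a*x, f(y) = b*y
learning_rate = 0.005
losses = []

for step in range(200):
    # 1. Get the predictions: translate there and back again.
    cycled_x = b * (a * x)   # x -> g -> f
    cycled_y = a * (b * y)   # y -> f -> g

    # 2. Calculate the loss: L1 cycle consistency in both directions.
    err_x = cycled_x - x
    err_y = cycled_y - y
    loss = np.mean(np.abs(err_x)) + np.mean(np.abs(err_y))
    losses.append(loss)

    # 3. Calculate the gradients (hand-derived for this toy model).
    grad_a = np.mean(np.sign(err_x) * b * x) + np.mean(np.sign(err_y) * b * y)
    grad_b = np.mean(np.sign(err_x) * a * x) + np.mean(np.sign(err_y) * a * y)

    # 4. Apply the gradients (plain gradient descent in place of Adam).
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b

print(losses[0], losses[-1], a * b)
```

The cycle loss drives the product a·b toward 1, meaning the round trip returns the input. On its own it does not say what the translation should look like, which is why the adversarial losses are needed in the full model.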

Below is an animation of the training process. The predicted image is plotted after every epoch (only a few images are shown for brevity).

4. Final Predictions

Below are the predictions based on the trained model

References and Links

  1. Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks by Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros
  2. Generative Adversarial Networks for Image-to-Image Translation on Street View and MR Images by Simon Karlsson & Per Welander


