Deep Learning - Training

We use Keras as the deep learning framework with TensorFlow as the backend. Additionally, we assume that the data has shape (320, 456), which corresponds to the 25 micron atlas.

Training data

To avoid wasting time generating geometric augmentations on the fly, we assume the user has already precomputed them and stored them in an .h5 file. For more details, see Deep Learning - Generating a dataset.


The goal is to train the network via the fit_generator method of a keras.Model. In atlalign one can use the utility class atlalign.ml_utils.SupervisedGenerator. The main parameter its constructor expects is the path to the .h5 file, which needs to contain the following datasets:

  • Grayscale moving image (0-255 intensities) - shape (n, 320, 456)

  • Deltas for forward transformation (mov2reg) - shape (n, 320, 456, 2)

  • Deltas for inverse transformation (reg2mov) - shape (n, 320, 456, 2)

  • Coronal section in microns

  • Allen image section identifier

  • Allen dataset identifier

As additional parameters, one can specify

  • batch_size - batch size to train the network on

  • shuffle - if True, the dataset is shuffled along the sample dimension at the end of each epoch

  • augmenter_ref - instance of imgaug.augmenters.Augmenter representing intensity augmentations applied to the reference image

  • augmenter_mov - instance of imgaug.augmenters.Augmenter representing intensity augmentations applied to the moving image

  • return_inverse - if True, the inverse transformation is yielded as well
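To make the roles of batch_size and shuffle concrete, here is a minimal sketch of the index bookkeeping such a generator might perform. The class name BatchIndexSketch is hypothetical and this is not the atlalign implementation; in particular, dropping an incomplete last batch and shuffling once at construction are assumptions.

```python
import numpy as np


class BatchIndexSketch:
    """Hypothetical stand-in that only tracks which samples go into which batch."""

    def __init__(self, n_samples, batch_size=32, shuffle=True, seed=0):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.indexes = np.arange(n_samples)
        # Assumption: the initial order is shuffled too when shuffle=True.
        self.on_epoch_end()

    def __len__(self):
        # Assumption: an incomplete last batch is dropped.
        return len(self.indexes) // self.batch_size

    def __getitem__(self, i):
        # Sample indices that make up the i-th batch.
        return self.indexes[i * self.batch_size:(i + 1) * self.batch_size]

    def on_epoch_end(self):
        # Reshuffle along the sample dimension after every epoch.
        if self.shuffle:
            self.rng.shuffle(self.indexes)
```

With shuffle=False the batches always contain the same samples in the same order; with shuffle=True every epoch sees a new permutation of the same sample set.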

What does the SupervisedGenerator yield? It always yields 2 objects. If return_inverse=False, then:

  1. X - array of shape (batch_size, 320, 456, 2) representing the reference and moving images

  2. [reg_images, deltas_xy]

    • reg_images - array of shape (batch_size, 320, 456, 1) representing registered images

    • deltas_xy - array of shape (batch_size, 320, 456, 2) representing x and y displacements

If return_inverse=True then:

  1. [X_rm, X_mr]

    • X_rm - array of shape (batch_size, 320, 456, 2) representing the reference and moving images

    • X_mr - array of shape (batch_size, 320, 456, 2) representing the moving and reference images (swapped X_rm)

  2. [reg_images, deltas_xy, inv_deltas_xy]

    • reg_images - array of shape (batch_size, 320, 456, 1) representing registered images

    • deltas_xy - array of shape (batch_size, 320, 456, 2) representing x and y displacements

    • inv_deltas_xy - array of shape (batch_size, 320, 456, 2) representing x and y displacements of the inverse mapping

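from atlalign.ml_utils import SupervisedGenerator, augmenter_1

path = '/path/to/file.h5'

# Generator yielding only the forward transformation
gen = SupervisedGenerator(path, return_inverse=False)

# Generator yielding the inverse transformation as well
gen_inv = SupervisedGenerator(path, return_inverse=True)

# In both cases a batch consists of 2 objects: the inputs and the targets
assert len(gen[0]) == 2
assert len(gen_inv[0]) == 2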
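The structure of the yielded objects can be illustrated without atlalign. The following is a minimal NumPy sketch, not the library code: the function name assemble_batch is hypothetical, and the channel order (first-named image in channel 0) is an assumption based on the descriptions above.

```python
import numpy as np


def assemble_batch(ref, mov, reg, deltas_xy, inv_deltas_xy=None):
    """Arrange arrays the way the yielded objects are described above.

    ref, mov, reg: (batch_size, 320, 456); deltas: (batch_size, 320, 456, 2).
    Assumption: channel 0 holds the first-named image, channel 1 the second.
    """
    x_rm = np.stack([ref, mov], axis=-1)   # (batch_size, 320, 456, 2)
    reg_images = reg[..., np.newaxis]      # (batch_size, 320, 456, 1)
    if inv_deltas_xy is None:              # corresponds to return_inverse=False
        return x_rm, [reg_images, deltas_xy]
    x_mr = x_rm[..., ::-1]                 # same images, channel order swapped
    return [x_rm, x_mr], [reg_images, deltas_xy, inv_deltas_xy]


batch_size = 4
ref, mov, reg = np.zeros((3, batch_size, 320, 456))
deltas = np.zeros((batch_size, 320, 456, 2))
inputs, targets = assemble_batch(ref, mov, reg, deltas, inv_deltas_xy=deltas)
```

Here inputs is a list of two arrays of shape (4, 320, 456, 2) and targets is a list of three arrays, mirroring the return_inverse=True case.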

Custom Layers

Before delving into the architectures let us first describe 3 custom layers (implemented in atlalign.ml_utils.layers):

  • Affine2DVF - Turns a (2, 3) affine matrix into delta_x and delta_y (a displacement field) with a fixed shape

  • BilinearInterpolation - Differentiable version of the warp method of the DisplacementField

  • DVFComposition - Differentiable version of __call__ method of the DisplacementField
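To give an intuition for the first two layers, here is a minimal, non-differentiable NumPy sketch of both operations. It is not the atlalign implementation. The function names are hypothetical, and two conventions are assumptions: the affine matrix maps pixel coordinates (x, y, 1) to new coordinates, and an output pixel at (x, y) samples the input at (x + delta_x, y + delta_y), with out-of-range coordinates clipped to the border.

```python
import numpy as np


def affine_to_dvf(matrix, shape=(320, 456)):
    """Turn a (2, 3) affine matrix into delta_x and delta_y of the given shape."""
    h, w = shape
    ys, xs = np.mgrid[0:h, 0:w].astype(float)
    coords = np.stack([xs, ys, np.ones_like(xs)])  # (3, h, w)
    new_x, new_y = np.tensordot(np.asarray(matrix, dtype=float), coords, axes=1)
    # Displacement = new position minus original position.
    return new_x - xs, new_y - ys


def warp_bilinear(img, delta_x, delta_y):
    """Warp a 2D image with a displacement field using bilinear interpolation."""
    h, w = img.shape
    ys, xs = np.mgrid[0:h, 0:w].astype(float)
    x = np.clip(xs + delta_x, 0, w - 1)
    y = np.clip(ys + delta_y, 0, h - 1)
    x0, y0 = np.floor(x).astype(int), np.floor(y).astype(int)
    x1, y1 = np.minimum(x0 + 1, w - 1), np.minimum(y0 + 1, h - 1)
    wx, wy = x - x0, y - y0  # fractional parts act as interpolation weights
    top = img[y0, x0] * (1 - wx) + img[y0, x1] * wx
    bottom = img[y1, x0] * (1 - wx) + img[y1, x1] * wx
    return top * (1 - wy) + bottom * wy


identity = [[1, 0, 0], [0, 1, 0]]
dx, dy = affine_to_dvf(identity)  # both fields are zero everywhere
```

Under these assumptions the identity matrix yields a zero displacement field, and warping with a zero field returns the image unchanged.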


Ideally, we want to have two networks that take care of global and local transformations, respectively. In other words, the first network makes sure that the moving and reference images have the same scale, rotation and position. The second network allows for a more fine-grained alignment of specific parts of the image.

Based on experiments we highly recommend training these two networks separately. When both of them are good enough we can merge them into a single network.


The global network is conceptually identical to the Spatial Transformer Network (STN). The goal is to have a regressor network that predicts a set of parameters that fully define a transformation. The most common example (and also the one we implemented) is to find the parameters of a 2D affine transformation (6 parameters). Instead of using the actual ground truth matrix we train the network on a different task - image registration.

We provide a utility function atlalign.nn.supervised_global_model_factory that outputs a keras.Model corresponding to the chosen hyperparameters. See below an example of how to create the network that had the best performance during our experiments.

from atlalign.nn import supervised_global_model_factory

filters = (16, 16, 32, 32, 32)
dense_layers = (40,)
losses = ('perceptual_loss_net-lin_vgg', 'vector_distance')

model_g = supervised_global_model_factory(filters=filters,
                                          dense_layers=dense_layers,
                                          losses=losses)

This is what the model looks like inside:

Layer (type)                    Output Shape         Param #     Connected to
input_1 (InputLayer)            (None, 320, 456, 2)  0
conv2d_1 (Conv2D)               (None, 320, 456, 16) 304         input_1[0][0]
conv2d_2 (Conv2D)               (None, 320, 456, 16) 2320        conv2d_1[0][0]
max_pooling2d_1 (MaxPooling2D)  (None, 160, 228, 16) 0           conv2d_2[0][0]
conv2d_3 (Conv2D)               (None, 160, 228, 16) 2320        max_pooling2d_1[0][0]
conv2d_4 (Conv2D)               (None, 160, 228, 16) 2320        conv2d_3[0][0]
max_pooling2d_2 (MaxPooling2D)  (None, 80, 114, 16)  0           conv2d_4[0][0]
conv2d_5 (Conv2D)               (None, 80, 114, 32)  4640        max_pooling2d_2[0][0]
conv2d_6 (Conv2D)               (None, 80, 114, 32)  9248        conv2d_5[0][0]
max_pooling2d_3 (MaxPooling2D)  (None, 40, 57, 32)   0           conv2d_6[0][0]
conv2d_7 (Conv2D)               (None, 40, 57, 32)   9248        max_pooling2d_3[0][0]
conv2d_8 (Conv2D)               (None, 40, 57, 32)   9248        conv2d_7[0][0]
max_pooling2d_4 (MaxPooling2D)  (None, 20, 28, 32)   0           conv2d_8[0][0]
conv2d_9 (Conv2D)               (None, 20, 28, 32)   9248        max_pooling2d_4[0][0]
conv2d_10 (Conv2D)              (None, 20, 28, 32)   9248        conv2d_9[0][0]
max_pooling2d_5 (MaxPooling2D)  (None, 10, 14, 32)   0           conv2d_10[0][0]
flatten_1 (Flatten)             (None, 4480)         0           max_pooling2d_5[0][0]
dense_1 (Dense)                 (None, 40)           179240      flatten_1[0][0]
dense_2 (Dense)                 (None, 6)            246         dense_1[0][0]
reshape_1 (Reshape)             (None, 2, 3)         0           dense_2[0][0]
extract_moving (Lambda)         (None, 320, 456, 1)  0           input_1[0][0]
affine2dvf_1 (Affine2DVF)       (None, 320, 456, 2)  0           reshape_1[0][0]
img_registered (BilinearInterpo (None, 320, 456, 1)  0           extract_moving[0][0]
Total params: 237,630
Trainable params: 237,630
Non-trainable params: 0

Note that one can create a custom network as long as the inputs and the outputs are compatible with the SupervisedGenerator.
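The parameter counts in the summary above can be reproduced with a few lines of arithmetic. The sketch below assumes 3x3 convolution kernels and two convolutions per level followed by 2x2 max pooling, which is what the listed numbers and shapes imply; the helper names are hypothetical.

```python
def conv2d_params(in_channels, filters, kernel=3):
    """Weights plus biases of a 2D convolution with a square kernel."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_units, units):
    """Weights plus biases of a fully connected layer."""
    return in_units * units + units


filters = (16, 16, 32, 32, 32)
dense_layers = (40,)

total, channels = 0, 2  # the input stacks the reference and moving images
height, width = 320, 456
for f in filters:
    # Two convolutions per level, then max pooling halves both dimensions.
    total += conv2d_params(channels, f) + conv2d_params(f, f)
    channels = f
    height, width = height // 2, width // 2

units = height * width * channels  # size after flattening: 10 * 14 * 32
for d in dense_layers + (6,):  # the last layer outputs the 6 affine parameters
    total += dense_params(units, d)
    units = d

print(total)  # 237630, matching "Total params" in the summary
```

Most of the parameters sit in the first dense layer (179,240 of them), so the size of the flattened feature map dominates the model size.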


The most popular type of network to use for the local displacements is a UNet. Similarly to the global case, we provide a convenience factory function atlalign.nn.supervised_model_factory that outputs an instance of keras.Model given the selected hyperparameters.

See below an example with the best hyperparameters found during experiments.

import tensorflow as tf
from atlalign.nn import supervised_model_factory

start_filters = (16,)
downsample_filters = (16, 32, 32, 32)
middle_filters = (32,)
upsample_filters = (32, 32, 32, 32)
end_filters = (64, 64)

compute_inv = True

losses = ('perceptual_loss_net-lin_vgg', 'perceptual_loss_net-lin_vgg&vdclip2', 'perceptual_loss_net-lin_vgg')
losses_weights = (1, 1, 1)

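model_l = supervised_model_factory(start_filters=start_filters,
                                   downsample_filters=downsample_filters,
                                   middle_filters=middle_filters,
                                   upsample_filters=upsample_filters,
                                   end_filters=end_filters,
                                   compute_inv=compute_inv,
                                   losses=losses,
                                   losses_weights=losses_weights)

This is what the model looks like inside:
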
Layer (type)                    Output Shape         Param #     Connected to
reg_mov (InputLayer)            (None, 320, 456, 2)  0
cropping2d_1 (Cropping2D)       (None, 320, 448, 2)  0           reg_mov[0][0]
conv2d_1 (Conv2D)               (None, 320, 448, 16) 304         cropping2d_1[0][0]
leaky_re_lu_1 (LeakyReLU)       (None, 320, 448, 16) 0           conv2d_1[0][0]
max_pooling2d_1 (MaxPooling2D)  (None, 160, 224, 16) 0           leaky_re_lu_1[0][0]
conv2d_2 (Conv2D)               (None, 160, 224, 16) 2320        max_pooling2d_1[0][0]
leaky_re_lu_2 (LeakyReLU)       (None, 160, 224, 16) 0           conv2d_2[0][0]
max_pooling2d_2 (MaxPooling2D)  (None, 80, 112, 16)  0           leaky_re_lu_2[0][0]
conv2d_3 (Conv2D)               (None, 80, 112, 32)  4640        max_pooling2d_2[0][0]
leaky_re_lu_3 (LeakyReLU)       (None, 80, 112, 32)  0           conv2d_3[0][0]
max_pooling2d_3 (MaxPooling2D)  (None, 40, 56, 32)   0           leaky_re_lu_3[0][0]
conv2d_4 (Conv2D)               (None, 40, 56, 32)   9248        max_pooling2d_3[0][0]
leaky_re_lu_4 (LeakyReLU)       (None, 40, 56, 32)   0           conv2d_4[0][0]
max_pooling2d_4 (MaxPooling2D)  (None, 20, 28, 32)   0           leaky_re_lu_4[0][0]
conv2d_5 (Conv2D)               (None, 20, 28, 32)   9248        max_pooling2d_4[0][0]
leaky_re_lu_5 (LeakyReLU)       (None, 20, 28, 32)   0           conv2d_5[0][0]
conv2d_6 (Conv2D)               (None, 20, 28, 32)   9248        leaky_re_lu_5[0][0]
leaky_re_lu_6 (LeakyReLU)       (None, 20, 28, 32)   0           conv2d_6[0][0]
up_sampling2d_1 (UpSampling2D)  (None, 40, 56, 32)   0           leaky_re_lu_6[0][0]
conv2d_7 (Conv2D)               (None, 40, 56, 32)   9248        up_sampling2d_1[0][0]
leaky_re_lu_7 (LeakyReLU)       (None, 40, 56, 32)   0           conv2d_7[0][0]
concatenate_1 (Concatenate)     (None, 40, 56, 64)   0           leaky_re_lu_7[0][0]
conv2d_8 (Conv2D)               (None, 40, 56, 32)   18464       concatenate_1[0][0]
leaky_re_lu_8 (LeakyReLU)       (None, 40, 56, 32)   0           conv2d_8[0][0]
up_sampling2d_2 (UpSampling2D)  (None, 80, 112, 32)  0           leaky_re_lu_8[0][0]
conv2d_9 (Conv2D)               (None, 80, 112, 32)  9248        up_sampling2d_2[0][0]
leaky_re_lu_9 (LeakyReLU)       (None, 80, 112, 32)  0           conv2d_9[0][0]
concatenate_2 (Concatenate)     (None, 80, 112, 64)  0           leaky_re_lu_9[0][0]
conv2d_10 (Conv2D)              (None, 80, 112, 32)  18464       concatenate_2[0][0]
leaky_re_lu_10 (LeakyReLU)      (None, 80, 112, 32)  0           conv2d_10[0][0]
up_sampling2d_3 (UpSampling2D)  (None, 160, 224, 32) 0           leaky_re_lu_10[0][0]
conv2d_11 (Conv2D)              (None, 160, 224, 32) 9248        up_sampling2d_3[0][0]
leaky_re_lu_11 (LeakyReLU)      (None, 160, 224, 32) 0           conv2d_11[0][0]
concatenate_3 (Concatenate)     (None, 160, 224, 48) 0           leaky_re_lu_11[0][0]
conv2d_12 (Conv2D)              (None, 160, 224, 32) 13856       concatenate_3[0][0]
leaky_re_lu_12 (LeakyReLU)      (None, 160, 224, 32) 0           conv2d_12[0][0]
up_sampling2d_4 (UpSampling2D)  (None, 320, 448, 32) 0           leaky_re_lu_12[0][0]
conv2d_13 (Conv2D)              (None, 320, 448, 32) 9248        up_sampling2d_4[0][0]
leaky_re_lu_13 (LeakyReLU)      (None, 320, 448, 32) 0           conv2d_13[0][0]
concatenate_4 (Concatenate)     (None, 320, 448, 48) 0           leaky_re_lu_13[0][0]
conv2d_14 (Conv2D)              (None, 320, 448, 32) 13856       concatenate_4[0][0]
leaky_re_lu_14 (LeakyReLU)      (None, 320, 448, 32) 0           conv2d_14[0][0]
conv2d_15 (Conv2D)              (None, 320, 448, 64) 18496       leaky_re_lu_14[0][0]
leaky_re_lu_15 (LeakyReLU)      (None, 320, 448, 64) 0           conv2d_15[0][0]
conv2d_16 (Conv2D)              (None, 320, 448, 64) 36928       leaky_re_lu_15[0][0]
leaky_re_lu_16 (LeakyReLU)      (None, 320, 448, 64) 0           conv2d_16[0][0]
conv2d_17 (Conv2D)              (None, 320, 448, 2)  514         leaky_re_lu_16[0][0]
mov_reg (InputLayer)            (None, 320, 456, 2)  0
extract_moving (Lambda)         (None, 320, 456, 1)  0           reg_mov[0][0]
dvf (ZeroPadding2D)             (None, 320, 456, 2)  0           conv2d_17[0][0]
model_1 (Model)                 (None, 320, 456, 2)  192578      mov_reg[0][0]
img_registered (BilinearInterpo (None, 320, 456, 1)  0           extract_moving[0][0]
inv_dvf (Lambda)                (None, 320, 456, 2)  0           model_1[1][0]
Total params: 192,578
Trainable params: 192,578
Non-trainable params: 0
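The summary shows a Cropping2D layer that reduces the width from 456 to 448 and a ZeroPadding2D layer that restores it at the end. One plausible reading, stated here as an assumption rather than a documented design decision, is that four 2x2 max pooling stages require spatial dimensions divisible by 16 so that the upsampled feature maps match the skip connections. The helper below is hypothetical and only illustrates the arithmetic.

```python
def cropped_size(size, n_poolings):
    """Largest size not exceeding `size` that is divisible by 2 ** n_poolings."""
    factor = 2 ** n_poolings
    return (size // factor) * factor


n_poolings = 4  # max_pooling2d_1 to max_pooling2d_4 in the summary
height = cropped_size(320, n_poolings)  # 320 is already divisible by 16
width = cropped_size(456, n_poolings)   # 456 is reduced to 448
```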

One important thing to note is the boolean compute_inv. When it is True, the network learns not only to warp the moving image so that it is as similar to the reference as possible, but also the reverse: to warp the reference image so that it is as similar to the moving image as possible. Since SupervisedGenerator can also yield inverse displacement fields, this is achieved simply by sharing weights and swapping the order of the inputs.

Loss function

The loss function, together with the architecture, is the most important component. We experimented with many different losses and ideas and found that losses based on the perceptual loss are superior in the vast majority of cases. See The Unreasonable Effectiveness of Deep Features as a Perceptual Metric for more details.

The user can access the losses via two dictionaries:

  • atlalign.ml_utils.ALL_IMAGE_LOSSES - Losses on images (grayscale)

  • atlalign.ml_utils.ALL_DVF_LOSSES - Losses on displacement fields

One important insight is that we can also apply image losses to displacement fields, since a displacement field is nothing more than 2 images - delta_x and delta_y. This is implemented in atlalign.ml_utils.losses.DVF2IMG. Note that we also scale down the displacements by a constant.

from atlalign.ml_utils import ALL_DVF_LOSSES, ALL_IMAGE_LOSSES
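The idea of applying an image loss to a displacement field can be sketched in plain NumPy. This is not the DVF2IMG implementation: the function name dvf_to_img_loss is hypothetical, the scaling constant of 10 is a made-up placeholder, and averaging the two per-channel losses is an assumption. A simple mean squared error stands in for a real image loss.

```python
import numpy as np


def dvf_to_img_loss(image_loss, scale=10.0):
    """Wrap an image loss so that it can be applied to displacement fields.

    `scale` is a hypothetical constant used to scale the displacements down.
    """
    def loss(dvf_true, dvf_pred):
        per_channel = []
        for c in range(2):  # channel 0: delta_x, channel 1: delta_y
            # Keep a trailing channel axis so each slice looks like a grayscale image.
            true_img = dvf_true[..., c:c + 1] / scale
            pred_img = dvf_pred[..., c:c + 1] / scale
            per_channel.append(image_loss(true_img, pred_img))
        # Assumption: the two per-channel losses are simply averaged.
        return float(np.mean(per_channel))
    return loss


def mse(a, b):
    """Toy image loss used only for illustration."""
    return float(np.mean((a - b) ** 2))


dvf_mse = dvf_to_img_loss(mse)
```

Any callable that compares two image batches of shape (batch_size, 320, 456, 1) could be plugged in instead of mse.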


In what follows we describe some interesting and useful losses. They are all implemented in atlalign.ml_utils.losses.

Perceptual loss

An image loss that comes in four versions:

  • perceptual_loss_net-lin_alex

  • perceptual_loss_net-lin_vgg

  • perceptual_loss_net_alex

  • perceptual_loss_net_vgg

If the key contains the string lin, it refers to a model where a linear layer is inserted after the feature extractor. The second string refers to the CNN used to extract features.
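The naming scheme can be made explicit with a small parser. The function parse_perceptual_key is hypothetical and not part of atlalign; it only decodes the four keys listed above according to the description.

```python
def parse_perceptual_key(key):
    """Split a perceptual loss key into its two components."""
    prefix = "perceptual_loss_"
    if not key.startswith(prefix):
        raise ValueError(f"Not a perceptual loss key: {key}")
    # E.g. "net-lin_vgg" -> model part "net-lin", CNN part "vgg".
    model, cnn = key[len(prefix):].split("_")
    return {"linear_layer": "lin" in model.split("-"), "cnn": cnn}


info = parse_perceptual_key("perceptual_loss_net-lin_vgg")
```

For this key, info reports that a linear layer is used and that the feature extractor is vgg.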

Vector distance and its clipped version

A displacement field loss that represents the average Euclidean distance between the prediction and the ground truth. The average is taken over all pixels and all samples in the batch.

Note that instead of using the vector distance as the main loss, one might use it simply as a way to prevent the network from resorting to some exit strategies. We call this a clipped vector distance.

The figure below illustrates the idea behind clipping. The user specifies a threshold (20) and a power, and the actual loss is then computed as loss = (vd / threshold) ** power.

[Figure: Clipped vector distance]

See below the official keys of atlalign.ml_utils.ALL_DVF_LOSSES, but note that one can easily add other versions via atlalign.ml_utils.losses.VDClipper.

  • vector_distance

  • vdclip2 - threshold=20, power=2

  • vdclip3 - threshold=20, power=3

The idea behind having power > 1 is to punish the model for making big mistakes but be more forgiving on smaller ones.
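Both losses can be written down in a few lines of NumPy. This is a sketch of the formulas as stated above, not the atlalign code; the function names are hypothetical, and applying the formula per pixel before averaging is an assumption.

```python
import numpy as np


def vector_distance(dvf_true, dvf_pred):
    """Average Euclidean distance over all pixels and all samples."""
    vd = np.sqrt(np.sum((dvf_true - dvf_pred) ** 2, axis=-1))
    return float(vd.mean())


def clipped_vector_distance(dvf_true, dvf_pred, threshold=20, power=2):
    """loss = (vd / threshold) ** power, averaged over pixels and samples."""
    # Assumption: the formula is applied per pixel and then averaged.
    vd = np.sqrt(np.sum((dvf_true - dvf_pred) ** 2, axis=-1))
    return float(np.mean((vd / threshold) ** power))


true = np.zeros((2, 320, 456, 2))
small = np.zeros_like(true)
small[..., 0], small[..., 1] = 3.0, 4.0    # every pixel is off by 5
large = np.zeros_like(true)
large[..., 0] = 40.0                       # every pixel is off by 40

print(vector_distance(true, small))                   # 5.0
print(clipped_vector_distance(true, small, power=2))  # 0.0625
print(clipped_vector_distance(true, large, power=2))  # 4.0
```

An error of 5 pixels, well below the threshold, contributes almost nothing, while an error of 40 pixels is amplified, which is the behaviour described for power > 1.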


Mixer is a meta loss that simply takes two losses and computes their convex combination (by default just a mean). The corresponding keys in atlalign.ml_utils.ALL_DVF_LOSSES have the form of first&second.
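A convex combination of two losses and the first&second key convention can be sketched as follows. The functions mixer and resolve as well as the toy registry are hypothetical and only illustrate the idea; they are not the atlalign implementation.

```python
def mixer(first, second, weight=0.5):
    """Return a loss that is a convex combination of two losses."""
    if not 0 <= weight <= 1:
        raise ValueError("weight must lie in [0, 1]")

    def loss(y_true, y_pred):
        return weight * first(y_true, y_pred) + (1 - weight) * second(y_true, y_pred)

    return loss


def resolve(key, registry):
    """Look up a plain key or build a mixed loss from a 'first&second' key."""
    if "&" in key:
        first, second = key.split("&")
        return mixer(registry[first], registry[second])
    return registry[key]


# Toy registry with scalar losses, used only for illustration.
registry = {
    "l1": lambda y_true, y_pred: abs(y_true - y_pred),
    "l2": lambda y_true, y_pred: (y_true - y_pred) ** 2,
}
mixed = resolve("l1&l2", registry)
print(mixed(0.0, 3.0))  # 0.5 * 3 + 0.5 * 9 = 6.0
```

With the default weight of 0.5 the mixed loss is the mean of the two component losses.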

Saving model

After training, one can easily save the model with the utility function atlalign.ml_utils.save_model. The first argument is the actual keras.Model instance and the second is the path (without extension). The keyword argument separate allows the user to select whether to save the weights and the architecture separately or not. If they are saved separately, the information on the loss function and the optimizer (and its state) is lost.

