Source code for atlalign.visualization

"""A collection of utils for all visualization scripts."""

"""
    The package atlalign is a tool for registration of 2D images.

    Copyright (C) 2021 EPFL/Blue Brain Project

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

from atlalign.base import DisplacementField


def create_animation(
    df,
    img,
    frames_per_second=30,
    n_seconds=3,
    repeat=False,
    blit=False,
    cmap="gray",
    img_ref=None,
    n_ref=3,
    duration_ref=1,
):
    """Create a slow motion animation of a warping.

    Each displacement field is scaled from 0 to 1 in equal steps and applied
    to the image; every intermediate result becomes one frame.

    Parameters
    ----------
    df : DisplacementField or list
        A single ``DisplacementField`` representing one transformation, or a
        list of ``DisplacementField`` instances representing a pipeline of
        transformations applied in the given order.
    img : np.ndarray
        Image to be warped. Needs to have the same shape as `df` and dtype
        either uint8 or float32.
    frames_per_second : int, default 30
        Number of frames per second.
    n_seconds : int, default 3
        Number of seconds one df will last. Total number of seconds is
        `len(df) * n_seconds`.
    repeat : bool
        If True, animation is automatically restarted.
    blit : bool
        Controls whether blitting is used to optimize drawing.
    cmap : str, default 'gray'
        Only applicable if the image is grayscale.
    img_ref : np.array or None
        If supplied, the end of the animation alternates between `img_ref`
        and the final warped image `n_ref` times, each one being visible for
        `duration_ref` seconds.
    n_ref : int
        Number of times to switch between `img_ref` and the registered image.
        Only active when `img_ref` is not None.
    duration_ref : int
        Number of seconds each image is visible per switch.

    Returns
    -------
    ani : matplotlib.animation.ArtistAnimation
        Animation object that can be viewed in a jupyter notebook for example.

    Notes
    -----
    To make it viewable in a jupyter notebook one needs to do the following

    >>> from matplotlib import rc
    >>> rc('animation', html='jshtml')

    If you get errors using these settings consider replacing
    `html='jshtml'` by `html='html5'` above. Additionally, it is necessary to
    install the ffmpeg package. On Ubuntu this can be done:

    ```bash
    sudo apt install ffmpeg
    ```
    """
    # Timing and bookkeeping (steps_per_df is the frame count for ONE df)
    steps_per_df = n_seconds * frames_per_second
    frame_delay_ms = int(1000 / frames_per_second)
    frames = []
    transformations = df if isinstance(df, list) else [df]

    # Canvas without axes
    fig = plt.figure()
    plt.axis("off")

    # One artist per interpolation step of every transformation
    current_img = img
    for transformation in transformations:
        for step in range(steps_per_df + 1):
            partial_df = transformation * (step / steps_per_df)
            warped = partial_df.warp(current_img)
            frames.append([plt.imshow(warped, cmap=cmap)])

        # The next transformation starts from the fully warped image
        current_img = warped

    if img_ref is not None:
        moving_artist = frames[-1][0]
        reference_artist = plt.imshow(img_ref, cmap=cmap)

        for _ in range(n_ref):
            hold = int(frames_per_second * duration_ref)
            frames.extend([[reference_artist]] * hold)  # reference
            frames.extend([[moving_artist]] * hold)  # moving

    return animation.ArtistAnimation(
        fig,
        frames,
        interval=frame_delay_ms,
        blit=blit,
        repeat=repeat,
        repeat_delay=None,
    )
def create_grid(shape, grid_spacing=20, grid_thickness=3):
    """Create a grid to see warpings clearly.

    The result is a white (255) image with black (0) grid lines.

    Parameters
    ----------
    shape : tuple
        Tuple of (height, width) which represent the shape of the output
        image. Only the first two entries are used.
    grid_spacing : int
        Both horizontal and vertical spacing of consecutive lines.
    grid_thickness : int
        Thickness of all lines.

    Returns
    -------
    img_grid : np.ndarray
        An image of the grid.
    """
    grid_h, grid_w = shape[:2]

    # Start from a white canvas and paint the lines black
    canvas = np.full((grid_h, grid_w), 255.0)

    # Vertical lines: a band of columns every `grid_spacing` pixels
    for col in range(0, grid_w, grid_spacing):
        canvas[:, col : col + grid_thickness] = 0.0

    # Horizontal lines: a band of rows every `grid_spacing` pixels
    for row in range(0, grid_h, grid_spacing):
        canvas[row : row + grid_thickness, :] = 0.0

    return canvas
def create_segmentation_image(segmentation_array, colors_dict=None):
    """Turn segmentation array into a colorful image.

    Parameters
    ----------
    segmentation_array : np.array
        An array of shape (h, w) and integer dtype where each number
        represents a unique class.
    colors_dict : None or dict
        If None, then all classes are assigned a random color (except for 0
        which by default gets a black color). If dict, keys are integers
        representing classes and values are tuples of size 3 representing
        (R, G, B). A class missing from the dict gets a randomly generated
        color, which is also written back into the dict.

    Returns
    -------
    segmentation_img : np.array
        An RGB image of shape (h, w, 3) and dtype ``uint8``.
    colors_dict : dict
        Color (values) per class (keys) dictionary. If no `colors_dict` was
        passed then a new instance. If passed, then the same object, updated
        in place.

    Raises
    ------
    TypeError
        If `segmentation_array` does not have an integer dtype.
    """
    is_integer = np.issubdtype(segmentation_array.dtype, np.integer)
    if not is_integer:
        raise TypeError("Only integer valued classes are allowed.")

    if colors_dict is None:
        colors_dict = {0: np.array([0, 0, 0])}  # background

    segmentation_img = np.zeros(segmentation_array.shape + (3,), dtype=np.uint8)

    for label in np.unique(segmentation_array):
        if label in colors_dict:
            rgb = colors_dict[label]
        else:
            # Unknown class: draw a color and remember it for the caller
            rgb = np.random.randint(255, size=3)
            colors_dict[label] = rgb

        segmentation_img[segmentation_array == label] = rgb

    return segmentation_img, colors_dict
def generate_df_plots(df_true, df_pred, filepath=None, figsize=(15, 15)):
    """Generate displacement vector plots.

    Draws a 4 x 2 panel figure comparing the ground truth (left column) with
    the prediction (right column): norm, angle with respect to the positive
    x-axis, Jacobian and a warped grid.

    Parameters
    ----------
    df_true : DisplacementField
        Truth.
    df_pred : DisplacementField
        Prediction. Expected to have the same shape as `df_true`, since the
        same grid image and reference field are used for both.
    filepath : None or pathlib.Path
        If specified, the figure is saved to this path and then closed. If
        not specified, then the figure is shown.
    figsize : tuple
        Figure size forwarded to ``plt.subplots``.
    """
    # The import is placed here in order to avoid the tensorflow import coming
    # from atlalign.metrics in the module scope (it's very slow)
    import seaborn as sns

    from atlalign.metrics import angular_error_of

    plt.ioff()

    fig, (
        (ax_norm, ax_norm_p),
        (ax_angle, ax_angle_p),
        (ax_jacob, ax_jacob_p),
        (ax_grid, ax_grid_p),
    ) = plt.subplots(4, 2, figsize=figsize)

    # Take the image shape from the data instead of hard-coding (320, 456)
    norm_true = df_true.norm
    norm_pred = df_pred.norm
    shape = tuple(norm_true.shape[:2])

    # Reference field (unit translation along x) the angles are measured against
    df_base = DisplacementField.generate(
        shape, approach="affine_simple", translation_x=1
    )

    # Dedicated axes for the colorbars, one per heatmap row
    bar_norm = fig.add_axes([0.95, 0.772, 0.03, 0.2])
    bar_angle = fig.add_axes([0.95, 0.525, 0.03, 0.2])
    bar_jacob = fig.add_axes([0.95, 0.275, 0.03, 0.2])

    # Jacobian (shared color range so both panels are comparable)
    jacob_true = df_true.jacobian
    jacob_pred = df_pred.jacobian
    jacob_vmin = min(jacob_true.min(), jacob_pred.min())
    jacob_vmax = max(jacob_true.max(), jacob_pred.max())

    ax_jacob.set_axis_off()
    sns.heatmap(
        jacob_true,
        ax=ax_jacob,
        cbar_ax=bar_jacob,
        center=0,
        cmap="seismic_r",
        vmin=jacob_vmin,
        vmax=jacob_vmax,
    )
    ax_jacob.set_title("Jacobian - Ground Truth")

    ax_jacob_p.set_axis_off()
    sns.heatmap(
        jacob_pred,
        ax=ax_jacob_p,
        cbar_ax=bar_jacob,
        center=0,
        cmap="seismic_r",
        vmin=jacob_vmin,
        vmax=jacob_vmax,
    )
    ax_jacob_p.set_title("Jacobian - Predicted")

    # GRID
    img_grid = create_grid(shape)

    img_grid_warped = df_true.warp(img_grid)
    ax_grid.set_axis_off()
    ax_grid.imshow(img_grid_warped, cmap="gray")
    ax_grid.set_title("Warped Grid - Ground Truth")

    img_grid_warped_p = df_pred.warp(img_grid)
    ax_grid_p.set_axis_off()
    ax_grid_p.imshow(img_grid_warped_p, cmap="gray")
    ax_grid_p.set_title("Warped Grid - Predicted")

    # NORM (shared color range starting at zero)
    norm_vmin, norm_vmax = 0, max(norm_true.max(), norm_pred.max())

    ax_norm.set_axis_off()
    sns.heatmap(
        norm_true, ax=ax_norm, cbar_ax=bar_norm, vmin=norm_vmin, vmax=norm_vmax
    )
    ax_norm.set_title("Norm - Ground Truth")

    ax_norm_p.set_axis_off()
    sns.heatmap(
        norm_pred, ax=ax_norm_p, cbar_ax=bar_norm, vmin=norm_vmin, vmax=norm_vmax
    )
    ax_norm_p.set_title("Norm - Predicted")

    # Angle
    angle_vmin, angle_vmax = 0, 180

    ax_angle.set_title("Angle wrt positive x-axis - Ground Truth")
    ax_angle.set_axis_off()
    angles = angular_error_of([df_true], [df_base])[1]
    sns.heatmap(
        angles,
        ax=ax_angle,
        cbar_ax=bar_angle,
        mask=~np.isfinite(angles),  # hide pixels where the angle is undefined
        cmap="hot_r",
        vmin=angle_vmin,
        vmax=angle_vmax,
    )

    ax_angle_p.set_title("Angle wrt positive x-axis - Predicted")
    ax_angle_p.set_axis_off()
    angles = angular_error_of([df_pred], [df_base])[1]
    sns.heatmap(
        angles,
        ax=ax_angle_p,
        cbar_ax=bar_angle,
        mask=~np.isfinite(angles),
        cmap="hot_r",
        vmin=angle_vmin,
        vmax=angle_vmax,
    )

    if filepath is not None:
        try:
            fig.savefig(str(filepath))
        finally:
            # Release the figure; otherwise repeated calls accumulate open
            # figures in pyplot's global registry.
            plt.close(fig)
    else:
        plt.show()
def chain_predict(model, inp, n_iterations=1):
    """Run alignment recursively.

    Parameters
    ----------
    model : keras.models.Model
        A trained model whose inputs have shape (batch_size, h, w, 2) - last
        dimension represents stacking of atlas and input image. The outputs
        are of the same shape where the last dimension represents stacking of
        delta_x and delta_y of the displacement field.
    inp : np.ndarray
        An array of shape (h, w, 2) or (1, h, w, 2) representing the atlas and
        input image.
    n_iterations : int, default 1
        Number of times the model is applied. Each iteration feeds the latest
        unwarped image back into the model together with the atlas.

    Returns
    -------
    unwarped_img_list : list
        List of np.ndarrays of shape (h, w) representing the unwarped image at
        each iteration. The first element is the original input image, so the
        list has ``n_iterations + 1`` entries.

    Raises
    ------
    ValueError
        If `inp` is neither 3D nor 4D with a leading dimension of 1.
    """
    # Normalize the input to a batch of size one
    if inp.ndim == 3:
        inp_ = np.array([inp])
    elif inp.ndim == 4 and inp.shape[0] == 1:
        inp_ = inp
    else:
        raise ValueError("Input has incorrect shape of {}".format(inp.shape))

    # Read (h, w) from the batched array: on a 3D input, `inp.shape[1:3]`
    # would wrongly give (w, 2).
    shape = inp_.shape[1:3]

    # Accumulated transformation, starts as identity
    df = DisplacementField.generate(shape, approach="identity")

    img_atlas = inp_[0, :, :, 0]
    img_warped = inp_[0, :, :, 1]

    unwarped_img_list = [img_warped]

    for _ in range(n_iterations):
        # Stack the atlas with the most recent unwarped image: (1, h, w, 2)
        new_inputs = np.concatenate(
            (
                img_atlas[np.newaxis, :, :, np.newaxis],
                unwarped_img_list[-1][np.newaxis, :, :, np.newaxis],
            ),
            axis=3,
        )

        pred = model.predict(new_inputs)

        delta_x_pred = pred[0, ..., 0]
        delta_y_pred = pred[0, ..., 1]

        df_pred = DisplacementField(delta_x_pred, delta_y_pred)
        df_pred_inv = df_pred.pseudo_inverse(ds_f=8)

        # Compose with what has been accumulated so far; the accumulated field
        # is always applied to the ORIGINAL warped image.
        df = df_pred_inv(df).adjust()

        img_unwarped_pred = df.warp(img_warped)
        unwarped_img_list.append(img_unwarped_pred)

    return unwarped_img_list