# Source code for atlalign.label.new_GUI

"""Graphical User Interface for manual registration."""
# The package atlalign is a tool for registration of 2D images.
#
# Copyright (C) 2021 EPFL/Blue Brain Project
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from collections import deque
from copy import deepcopy

import matplotlib as mpl
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, RadioButtons, Slider
from skimage.color import gray2rgb

from atlalign.base import DisplacementField
from atlalign.visualization import create_grid

# Reset matplotlib to its built-in default style before building the GUI.
plt.style.use(style="default")


class HelperGlobal:
    """Hold the state of the manual-registration GUI (avoids global variables).

    Parameters
    ----------
    img_ref : np.ndarray
        Reference image. `run_gui` only accepts dtype == np.float32 (the
        default thresholds of 10 / 255 suggest intensities in [0, 1]).
    img_mov : np.ndarray
        Moving (input) image. Same dtype and shape as `img_ref`.
    mode : str, {'ref2mov', 'mov2ref'}
        If 'ref2mov' then the first point of each pair is clicked in the
        reference image and the second one in the moving image. For
        'mov2ref' it is vice versa.
    title : str
        Additional title of the figure.

    Attributes
    ----------
    img_ref_ : np.ndarray
        Copy of the reference image.
    img_mov_ : np.ndarray
        Copy of the moving image.
    img_reg : np.ndarray
        Continuously updated registered (warped moving) image.
    df : DisplacementField
        Latest displacement field.
    ax, ax_reg : matplotlib.axes.Axes
        `ax` - overlay of `img_ref_` and `img_mov_` together with a scatter of
        the current keypoints; it is the only axes that reacts to clicks.
        `ax_reg` - overlay of `img_ref_` and `img_reg` (or the warped grid).
    keypoints : dict
        Maps (x_ref, y_ref) pairs in the reference image to (x_mov, y_mov)
        tuples in the moving image. The None sentinel marks a pair that is
        still being entered. If `mode` = 'ref2mov' then n_ref = n_mov or
        n_ref = n_mov + 1, and in the second case the sentinel is used as a
        dictionary value. If `mode` = 'mov2ref' then n_ref = n_mov or
        n_ref = n_mov - 1, and in the second case the sentinel is used as a
        dictionary key.
    colors : dict
        Maps the keys of `keypoints` to the color of the corresponding pair.
    all_colors : deque
        All possible colors for the scatter plot. New pairs cycle through it
        indefinitely.
    epsilon : int
        Half-size (in pixels) of the square around the cursor that is searched
        when deleting a reference keypoint (key 'd'). The higher this parameter
        the less precise you need to be when deleting.
    """

    def __init__(self, img_ref, img_mov, mode, title):
        self.mode = mode
        self.title = title

        # Save internally original images (immutable)
        self.img_ref_ = img_ref.copy()
        self.img_mov_ = img_mov.copy()

        self.img_reg = img_mov.copy()  # continuously updated
        self.df = DisplacementField.generate(
            self.img_ref_.shape, approach="identity"
        )  # latest DVF
        self.grid_ = create_grid(
            shape=self.img_ref_.shape, grid_spacing=15, grid_thickness=2
        )  # the unwarped grid

        # dummy
        self.img_dummy = np.zeros(self.img_ref_.shape)

        # Hyperparameters
        self.colormaps = ["hot", "gray", "cool", "viridis", "spring"]
        self.cmap_ref = "gray"  # must be an element of self.colormaps
        self.cmap_movreg = "hot"  # must be an element of self.colormaps
        self.th_ref = 10 / 255
        self.th_movreg = 10 / 255
        self.ref_first = True
        self.show_grid = False
        self.show_arrows = True
        self.alpha_ref = 0.8
        self.alpha_movreg = 0.5
        self.alpha_movreg_prev = 0.0  # value swapped in by the alpha toggle key

        # interpolation related
        self.interpolation_methods = ["griddata", "rbf"]
        self.interpolation_method = "rbf"
        self.interpolation_method_prev = self.interpolation_method
        self.kernels = [
            "multiquadric",
            "inverse",
            "gaussian",
            "linear",
            "cubic",
            "quintic",
            "thin_plate",
        ]
        self.kernel = "thin_plate"
        self.kernel_prev = self.kernel

        # Visual
        self.marker_ref = "."
        self.marker_mov = "+"
        self.marker_size_ref = 7**2
        self.marker_size_mov = 7**2
        self.all_colors = deque(
            ["red", "green", "blue", "yellow", "orange", "pink", "brown", "cyan"]
        )

        # Attributes
        self.keypoints = {}  # (x_ref, y_ref) -> (x_mov, y_mov)
        self.keypoints_prev = {}  # snapshot used to detect keypoint changes
        self.colors = {}  # (x_ref, y_ref) -> color
        self.epsilon = 3
        self.symmetric_registration = False

        # Modify keyboard shortcuts
        self.key_pan = "a"
        self.key_zoom_rect = "s"
        self.key_delete_ref_point = "d"
        self.key_reset_zoom = "f"
        self.key_swap_alpha = " "

        # Remove all default key bindings to avoid clashes
        for key, value in mpl.rcParams.items():
            if key.startswith("keymap."):
                value.clear()
        mpl.rcParams["keymap.pan"] = [self.key_pan]
        mpl.rcParams["keymap.zoom"] = [self.key_zoom_rect]
        mpl.rcParams["keymap.home"] = [self.key_reset_zoom]

        self.key_descriptions = {
            self.key_pan: "pan",
            self.key_zoom_rect: "zoom rectangle",
            self.key_reset_zoom: "reset zoom",
            self.key_delete_ref_point: "delete ref point",
            self.key_swap_alpha: "toggle alpha",
        }

        # Axis
        self.fig, (self.ax, self.ax_reg) = plt.subplots(1, 2, figsize=(20, 20))
        self.fig.tight_layout()
        self.fig.suptitle(self.title, fontsize=17)
        self.ax.set_axis_off()
        self.ax_reg.set_axis_off()

        self._define_widgets()

        # Initialize plots
        self._draw()

    def _make_buttons(self, y_pos):
        """Create the row of buttons at the vertical figure position `y_pos`."""
        width, height = 0.15, 0.03

        # Reset: drop all keypoints (and their colors)
        self.reset_button = Button(plt.axes([0.01, y_pos, width, height]), "Reset")

        def on_reset(_event):
            self.keypoints = {}
            self.colors = {}
            self._update_plots()

        self.reset_button.on_clicked(on_reset)

        # Toggle symmetric registration; the label shows the current state
        self.toggle_symmetric_reg = Button(
            plt.axes([0.20, y_pos, width, height]), ""
        )

        def set_symmetric_reg_label():
            status_str = "[On]" if self.symmetric_registration else "[Off]"
            self.toggle_symmetric_reg.label.set_text(
                f"Symmetric Registration {status_str}"
            )

        def on_toggle_symmetric(_event):
            self.symmetric_registration = not self.symmetric_registration
            set_symmetric_reg_label()
            # The keypoints did not change, so the recomputation must be forced
            self._update_plots(force=True)

        set_symmetric_reg_label()
        self.toggle_symmetric_reg.on_clicked(on_toggle_symmetric)

        def make_toggle(x_pos, label, attr):
            """Create a button that flips the boolean attribute `attr`."""
            button = Button(plt.axes([x_pos, y_pos, width, height]), label)

            def on_clicked(_event):
                setattr(self, attr, not getattr(self, attr))
                self._update_plots()

            button.on_clicked(on_clicked)
            return button

        self.ref_first_button = make_toggle(0.4, "Change order", "ref_first")
        self.show_arrows_button = make_toggle(0.6, "Show arrows", "show_arrows")
        self.show_grid_button = make_toggle(0.8, "Show grid", "show_grid")

    def _define_widgets(self):
        """Define all widgets (buttons, sliders, radio buttons, quality panel)."""
        self._make_buttons(y_pos=0.01)

        axcolor = "lightgoldenrodyellow"

        def make_slider(bottom, label, attr, color):
            """Create a [0, 1] slider bound to the float attribute `attr`."""
            slider = Slider(
                plt.axes([0.25, bottom, 0.65, 0.03], facecolor=axcolor),
                label,
                0.0,
                1,
                valinit=getattr(self, attr),
                color=color,
            )

            def on_changed(val):
                setattr(self, attr, val)
                self._update_plots()

            slider.on_changed(on_changed)
            return slider

        def make_radio(rect, title, labels, attr):
            """Create radio buttons choosing one of `labels` for `attr`."""
            rax = plt.axes(rect, facecolor=axcolor)
            rax.set_title(title)
            radio = RadioButtons(
                rax, labels=labels, active=labels.index(getattr(self, attr))
            )

            def on_clicked(_label):
                setattr(self, attr, radio.value_selected)
                self._update_plots()

            radio.on_clicked(on_clicked)
            return radio

        # Thresholds
        self.th_ref_slider = make_slider(
            0.05, "Reference Threshold", "th_ref", [0.0, 1.0, 0.0]
        )
        self.th_movreg_slider = make_slider(
            0.075, "Moving/Registered Threshold", "th_movreg", [0.0, 1.0, 0.0]
        )

        # Alphas
        self.alpha_ref_slider = make_slider(
            0.1, "Alpha Ref", "alpha_ref", [0.0, 0.0, 1.0]
        )
        self.alpha_movreg_slider = make_slider(
            0.125, "Alpha Moving/Registered", "alpha_movreg", [0.0, 0.0, 1.0]
        )

        # Colormaps
        self.cmap_ref_radio = make_radio(
            [0.1, 0.85, 0.05, 0.10], "Ref", self.colormaps, "cmap_ref"
        )
        self.cmap_movreg_radio = make_radio(
            [0.05, 0.85, 0.05, 0.10], "Mov/Reg", self.colormaps, "cmap_movreg"
        )

        # Interpolation method and kernel
        self.interpolation_method_radio = make_radio(
            [0.85, 0.85, 0.05, 0.10],
            "Interpolator",
            self.interpolation_methods,
            "interpolation_method",
        )
        self.kernel_radio = make_radio(
            [0.9, 0.85, 0.09, 0.10], "Kernel", self.kernels, "kernel"
        )

        # Transform quality (text only, shown via the axes title)
        self.rax_quality = plt.axes([0.45, 0.8, 0.10, 0.05], facecolor=axcolor)
        self.rax_quality.set_axis_off()

    def _update_plots(self, force=False):
        """Recompute the registration if needed, then redraw the figure.

        The displacement field is recomputed when the keypoints changed and
        every pair is complete, when the interpolator or kernel changed, or
        when `force` is True. The figure is always redrawn, keeping the current
        zoom/pan limits of both axes.
        """
        new_kps = self.keypoints != self.keypoints_prev
        new_ip = (
            self.interpolation_method != self.interpolation_method_prev
            or self.kernel != self.kernel_prev
        )
        is_complete = all(
            k is not None and v is not None for k, v in self.keypoints.items()
        )

        if (new_kps and is_complete) or new_ip or force:
            self.keypoints_prev = deepcopy(self.keypoints)
            self.interpolation_method_prev = self.interpolation_method
            self.kernel_prev = self.kernel
            self._recompute_registration()

        # Redraw figure while keeping the zoom/limits
        ax_xlim, ax_ylim = self.ax.get_xlim(), self.ax.get_ylim()
        ax_reg_xlim, ax_reg_ylim = self.ax_reg.get_xlim(), self.ax_reg.get_ylim()
        self._draw()
        self.ax.set(xlim=ax_xlim, ylim=ax_ylim)
        self.ax_reg.set(xlim=ax_reg_xlim, ylim=ax_reg_ylim)

    def _recompute_registration(self):
        """Interpolate `self.df` from the complete pairs and warp the moving image."""
        all_kps = [
            (ref, mov)
            for ref, mov in self.keypoints.items()
            if ref is not None and mov is not None
        ]

        # No complete pair (e.g. only a pending point): use the identity
        if not all_kps:
            self.df = DisplacementField.generate(
                self.img_ref_.shape, approach="identity"
            )
            self.img_reg = self.img_mov_.copy()
            return

        # If symmetric registration, then add points mirrored via x -> width - x
        if self.symmetric_registration:
            _, width = self.img_ref_.shape[:2]
            mirrored_kps = [
                ((width - x1, y1), (width - x2, y2))
                for (x1, y1), (x2, y2) in all_kps
            ]
            all_kps.extend(mirrored_kps)

        coords_ref = np.array([ref for ref, _ in all_kps])
        coords_mov = np.array([mov for _, mov in all_kps])
        # control_points uses (row, col) = (y, x) ordering
        points = np.flip(coords_ref, axis=1)
        values_delta_x = coords_mov[:, 0] - coords_ref[:, 0]
        values_delta_y = coords_mov[:, 1] - coords_ref[:, 1]

        interpolator_kwargs = (
            {} if self.interpolation_method == "griddata" else {"function": self.kernel}
        )

        self.df = DisplacementField.generate(
            self.img_ref_.shape,
            approach="control_points",
            points=points,
            values_delta_x=values_delta_x,
            values_delta_y=values_delta_y,
            anchor_corners=True,
            interpolation_method=self.interpolation_method,
            interpolator_kwargs=interpolator_kwargs,
        )

        # Update warped image
        self.img_reg = self.df.warp(self.img_mov_.copy())

    def _draw(self):
        """Redraw both axes from scratch with the current images and keypoints."""
        self.ax.cla()
        self.ax_reg.cla()

        n_ref = sum(ref is not None for ref in self.keypoints)
        n_mov = sum(mov is not None for mov in self.keypoints.values())

        perc_good = np.sum(self.df.jacobian > 0) / np.prod(self.df.shape)
        average_disp = self.df.average_displacement
        key_shortcuts = ", ".join(
            f"{description}: {key!r}"
            for key, description in self.key_descriptions.items()
        )
        self.rax_quality.set_title(
            f"Transform quality: {perc_good:.2%}\n"
            f"Average displacement: {average_disp:.2f}\n\n"
            f"(Key shortcuts: {key_shortcuts})"
        )
        self.ax.set_title(
            f"Reference vs Moving (Interactive), ref: {n_ref}, mov: {n_mov}"
        )
        self.ax_reg.set_title(
            "Warping Grid" if self.show_grid else "Reference vs Registered"
        )
        self.ax.set_axis_off()
        self.ax_reg.set_axis_off()

        # Prepare images: intensities below the thresholds are set to 0
        img_ref = self.img_ref_.copy()
        img_ref[img_ref < self.th_ref] = 0
        img_mov = self.img_mov_.copy()
        img_mov[img_mov < self.th_movreg] = 0
        img_reg = self.img_reg.copy()
        img_reg[img_reg < self.th_movreg] = 0

        ref_layer = (img_ref, self.cmap_ref, self.alpha_ref)
        left_layers = [ref_layer, (img_mov, self.cmap_movreg, self.alpha_movreg)]
        right_layers = [ref_layer, (img_reg, self.cmap_movreg, self.alpha_movreg)]
        if not self.ref_first:
            # Draw the moving/registered image first, i.e. below the reference
            left_layers.reverse()
            right_layers.reverse()

        for img, cmap, alpha in left_layers:
            self.ax.imshow(img, cmap=cmap, alpha=alpha)

        if self.show_grid:
            # Warp a regular grid; pixels with a non-positive jacobian are
            # tinted red. Only computed when actually displayed.
            colored_grid = gray2rgb(self.grid_ / 255)
            colored_grid[self.df.jacobian <= 0, :] *= [0.9, 0, 0]
            self.ax_reg.imshow(self.df.warp(colored_grid))
        else:
            for img, cmap, alpha in right_layers:
                self.ax_reg.imshow(img, cmap=cmap, alpha=alpha)

        # Scatter plots; the dict order fixes the order of all lists below
        pairs = list(self.keypoints.items())
        ref_pts = [ref for ref, _ in pairs if ref is not None]
        mov_pts = [mov for _, mov in pairs if mov is not None]
        ref_colors = [self.colors[ref] for ref in ref_pts]
        # A pending pair in 'mov2ref' mode is colored via self.colors[None]
        mov_colors = [self.colors[ref] for ref, mov in pairs if mov is not None]

        # Reference points
        self.ax.scatter(
            [x for x, _ in ref_pts],
            [y for _, y in ref_pts],
            marker=self.marker_ref,
            s=self.marker_size_ref,
            c=ref_colors,
        )
        # Moving points
        self.ax.scatter(
            [x for x, _ in mov_pts],
            [y for _, y in mov_pts],
            marker=self.marker_mov,
            s=self.marker_size_mov,
            c=mov_colors,
        )

        ref_label = mlines.Line2D(
            [],
            [],
            color="black",
            marker=self.marker_ref,
            linestyle="None",
            markersize=self.marker_size_ref ** (1 / 2),
            label="ref",
        )
        mov_label = mlines.Line2D(
            [],
            [],
            color="black",
            marker=self.marker_mov,
            linestyle="None",
            markersize=self.marker_size_mov ** (1 / 2),
            label="mov",
        )

        if self.show_arrows:
            # One arrow per complete pair, from the reference to the moving point
            for ref, mov in pairs:
                if ref is None or mov is None:
                    continue
                self.ax.arrow(
                    ref[0],
                    ref[1],
                    mov[0] - ref[0],
                    mov[1] - ref[1],
                    color=self.colors[ref],
                )

        self.ax.legend(handles=[ref_label, mov_label])
        self.fig.canvas.draw_idle()

    def on_click(self, event):
        """Take action on a click: add the next point of a keypoint pair.

        Parameters
        ----------
        event : matplotlib.backend_bases.MouseEvent
            The mouse event; its `xdata`/`ydata` give the click position.

        Notes
        -----
        Clicks are ignored while zooming/panning and outside of `self.ax`.
        """
        # Navigate mode can be None, "ZOOM" or "PAN"; don't handle clicks if
        # zooming or panning
        if self.ax.get_navigate_mode() is not None:
            return

        # Only self.ax is interactive
        if event.inaxes != self.ax:
            return

        # imshow centers pixels on integer coordinates, so round (not truncate)
        x, y = int(round(event.xdata)), int(round(event.ydata))

        # A new pair starts when no pair is waiting for its second point
        pending_side = (
            self.keypoints.values() if self.mode == "ref2mov" else self.keypoints.keys()
        )
        new_pair_mode = all(point is not None for point in pending_side)

        if new_pair_mode:
            c = self.all_colors[0]
            self.all_colors.rotate()
            if self.mode == "ref2mov":
                self.keypoints[(x, y)] = None
                self.colors[(x, y)] = c
            else:
                self.keypoints[None] = (x, y)
                self.colors[None] = c
        elif self.mode == "ref2mov":
            pending_ref = next(k for k, v in self.keypoints.items() if v is None)
            self.keypoints[pending_ref] = (x, y)
        else:
            self.keypoints[(x, y)] = self.keypoints.pop(None)
            self.colors[(x, y)] = self.colors.pop(None)

        self._update_plots()

    def on_press(self, event):
        """Take action on a key press.

        'd' deletes the reference keypoint nearest to the cursor (within
        `epsilon` pixels along each axis); the space bar swaps the current and
        previous alpha of the moving/registered image.

        Parameters
        ----------
        event : matplotlib.backend_bases.KeyEvent
            The key event fired.
        """
        if event.key is None:
            return

        key_pressed = event.key.lower()

        if key_pressed == self.key_delete_ref_point:
            print("Handling delete")
            if event.inaxes != self.ax:
                return
            self._delete_ref_point_near(event.xdata, event.ydata)

        elif key_pressed == self.key_swap_alpha:
            self.alpha_movreg, self.alpha_movreg_prev = (
                self.alpha_movreg_prev,
                self.alpha_movreg,
            )
            # The slider callback stores the value and redraws
            self.alpha_movreg_slider.set_val(self.alpha_movreg)

    def _delete_ref_point_near(self, xdata, ydata):
        """Delete the pair whose reference point is closest to (xdata, ydata).

        Only reference points within `epsilon` pixels (inclusive, along each
        axis) are considered; nothing happens if there is none.
        """
        x, y = int(round(xdata)), int(round(ydata))

        # You can only remove reference points
        candidates = [
            ref
            for ref in self.keypoints
            if ref is not None
            and abs(ref[0] - x) <= self.epsilon
            and abs(ref[1] - y) <= self.epsilon
        ]
        if not candidates:
            return

        ref = min(candidates, key=lambda r: (r[0] - x) ** 2 + (r[1] - y) ** 2)
        del self.keypoints[ref]
        self.colors.pop(ref, None)
        print("Deleted {}".format(ref))
        self._update_plots()

    def run(self):
        """Connect the mouse/keyboard callbacks and show the figure (`plt.show`)."""
        # Register mouse and keyboard event callbacks
        self.fig.canvas.mpl_connect("button_press_event", self.on_click)
        self.fig.canvas.mpl_connect("key_press_event", self.on_press)

        # Show the plot window
        plt.show()
def run_gui(img_ref, img_mov, mode="ref2mov", title=""):
    """Graphical user interface for manual labeling.

    Notes
    -----
    If `mode` == 'ref2mov' then one first specifies the point in the reference
    image ('.' marker) and then the corresponding pixel in the moving image
    ('+' marker). Both points of a pair share the same color.

    To delete a specific pair, hover over its reference point in the
    interactive (left) axes and press 'd'. This deletes both the reference
    point and the moving point, but only the reference point can be targeted
    for deletions. The space bar toggles the alpha of the moving/registered
    image.

    Parameters
    ----------
    img_ref : np.ndarray
        Reference image. Needs to be dtype == np.float32.
    img_mov : np.ndarray
        Input (moving) image. Needs to be dtype == np.float32 and have the same
        shape (first two dimensions) as `img_ref`.
    mode : str, {'ref2mov', 'mov2ref'}
        If 'ref2mov' then the first point should be in the reference image and
        the other point in the moving one. For 'mov2ref' it is vice versa.
    title : str
        Additional title of the figure.

    Returns
    -------
    df : DisplacementField
        Displacement field corresponding to the last change before closing
        the window of the GUI.
    keypoints : dict
        Dictionary of keypoints.
    symmetric_registration : bool
        Whether or not the registration was symmetrized. If true then all the
        returned keypoints should be mirrored across a vertical line through
        the image. This can be done by setting x => (width - x) for all
        keypoints.
    img_reg : np.ndarray
        Registered image.
    interpolation_method : str
        Interpolation method.
    kernel : str
        Kernel.

    Raises
    ------
    ValueError
        If the images differ in their first two dimensions or `mode` is not
        one of {'ref2mov', 'mov2ref'}.
    TypeError
        If either image is not of dtype np.float32.
    """
    if img_ref.shape[:2] != img_mov.shape[:2]:
        raise ValueError(
            "The fixed and moving image need to have the same shape. "
            f"{img_ref.shape} vs. {img_mov.shape}"
        )

    if not (img_ref.dtype == np.float32 and img_mov.dtype == np.float32):
        raise TypeError("Only works with float32 dtype")

    if mode not in {"ref2mov", "mov2ref"}:
        raise ValueError("The mode can only be ref2mov or mov2ref.")

    helper = HelperGlobal(img_ref, img_mov, mode, title)
    helper.run()

    return (
        helper.df,
        helper.keypoints,
        helper.symmetric_registration,
        helper.img_reg,
        helper.interpolation_method,
        helper.kernel,
    )  # noqa