Source code for atlalign.base

"""Fundamental building blocks of the project.

Notes
-----
This module does not import any other module except for zoo. Be careful to keep this logic in order
to prevent cyclical imports.
"""

"""
    The package atlalign is a tool for registration of 2D images.

    Copyright (C) 2021 EPFL/Blue Brain Project

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import pathlib

import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import (
    LSQBivariateSpline,
    NearestNDInterpolator,
    Rbf,
    SmoothBivariateSpline,
    griddata,
)
from skimage.transform import resize

from atlalign.utils import griddata_custom
from atlalign.zoo import (
    affine,
    affine_simple,
    control_points,
    edge_stretching,
    paper,
    paper_microsoft,
    patch_shift,
    projective,
    single_frequency,
)

# Set/create the default caching folder
GLOBAL_CACHE_FOLDER = pathlib.Path.home() / ".atlalign"
if not GLOBAL_CACHE_FOLDER.exists():
    GLOBAL_CACHE_FOLDER.mkdir(parents=True)


[docs]class DisplacementField: """ A class representing a 2D displacement vector field. Notes ----- The dtype is enforced to be single-precision (float32) since opencv's remap function (used for warping) does not accept double-precision (float64). Attributes ---------- delta_x : np.ndarray A 2D array of dtype float32 that represents the displacement field in the x coordinate (columns). Positive values move the pixel to the right, negative move it to the left. delta_y : np.ndarray A 2D array of dtype float32 that represents the displacement field in the y coordinate (rows). Positive values move the pixel down, negative pixels move the pixels up. """
[docs] @classmethod def generate(cls, shape, approach="identity", **kwargs): """Construct different displacement vector fields (DVF) via factory method. Parameters ---------- shape : tuple A tuple representing the (height, width) of the displacement field. Note that if multiple channels passed then only the height and width is extracted. approach : str, {'affine', 'affine_simple', 'control_points', 'identity', 'microsoft', 'paper', 'patch_shift'} What approach to use for generating the DVF. kwargs Additional parameters that are passed into the the given approach function. Returns ------- DisplacementField An instance of a Displacement field. """ # Check - extremely important since no checks in the zoo if len(shape) == 3: shape_ = shape[ :2 ] # to make it easier for the user who passes img.shape of an RGB image elif len(shape) == 2: shape_ = shape else: raise ValueError( "The length of shape needs to be either 2 or 3, {} given".format( len(shape) ) ) if approach == "affine": kw = ["matrix"] kwargs_affine = { k: v for k, v in kwargs.items() if k in kw } # Check if passed any delta_x, delta_y = affine(shape_, **kwargs_affine) elif approach == "affine_simple": kw = [ "scale_x", "scale_y", "rotation", "translation_x", "translation_y", "shear", "apply_centering", ] kwargs_affine_simple = { k: v for k, v in kwargs.items() if k in kw } # Check if passed any delta_x, delta_y = affine_simple(shape_, **kwargs_affine_simple) elif approach == "control_points": kw = [ "points", "values_delta_x", "values_delta_y", "anchor_corners", "interpolation_method", "interpolator_kwargs", ] kwargs_control_points = { k: v for k, v in kwargs.items() if k in kw } # Check if passed any delta_x, delta_y = control_points(shape_, **kwargs_control_points) elif approach == "edge_stretching": kw = [ "edge_mask", "n_perturbation_points", "radius_max", "interpolation_method", "interpolator_kwargs", ] kwargs_edge_stretching = {k: v for k, v in kwargs.items() if k in kw} delta_x, delta_y = edge_stretching(shape_, **kwargs_edge_stretching) elif approach == "identity": kw = [] delta_x, delta_y = np.zeros(shape_), np.zeros(shape_) elif approach == "microsoft": kw = ["alpha", "sigma", "random_state"] kwargs_microsoft = { k: v for k, v in kwargs.items() if k in kw } # Check if passed any delta_x, delta_y = paper_microsoft(shape_, **kwargs_microsoft) elif approach == "paper": kw = [ "n_pixels", "v_min", "v_max", "kernel_sigma", "p", "random_state", ] # all possible keywords kwargs_paper = { k: v for k, v in kwargs.items() if k in kw } # Check if passed any delta_x, delta_y = paper(shape_, **kwargs_paper) elif approach == "patch_shift": kw = ["ul", "height", "width", "shift_size", "shift_direction"] kwargs_patch_shift = {k: v for k, v in kwargs.items() if k in kw} delta_x, delta_y = patch_shift(shape_, **kwargs_patch_shift) elif approach == "projective": kw = ["matrix"] kwargs_projective = { k: v for k, v in kwargs.items() if k in kw } # Check if passed any delta_x, delta_y = projective(shape_, **kwargs_projective) elif approach == "single_frequency": kw = [ "p", "grid_spacing", "n_perturbation_points", "radius_mean", "interpolation_method", "interpolator_kwargs", ] kwargs_single_frequency = {k: v for k, v in kwargs.items() if k in kw} delta_x, delta_y = single_frequency(shape_, **kwargs_single_frequency) else: raise ValueError("The approach {} is not valid".format(approach)) # Check if no illegal arguments (now its too late but beter than never:D) allowed_kw = set(kw) passed_kw = set(kwargs) if not passed_kw.issubset(allowed_kw): diff = passed_kw - allowed_kw raise ValueError( "Illegal arguments passed for approach {}: {}".format(approach, diff) ) return cls(delta_x=delta_x, delta_y=delta_y)
[docs] @classmethod def from_file(cls, file_path): """Load displacement field from a file. Parameters ---------- file_path : str or pathlib.Path Path to where the file is located. Returns ------- DisplacementField Instance of the Displacement field. """ if isinstance(file_path, str): file_path = pathlib.Path(file_path) elif isinstance(file_path, pathlib.Path): pass else: raise TypeError( "The file path needs to be either a string or a pathlib.Path." ) suffix = file_path.suffix if suffix == ".npy": deltas_xy = np.load(str(file_path)) if deltas_xy.ndim != 3: raise ValueError("Only supporting 3 dimensional arrays.") if deltas_xy.shape[2] != 2: raise ValueError( "The last dimensions needs to have 2 elements (delta_x, delta_y)" ) return cls(deltas_xy[..., 0], deltas_xy[..., 1]) else: raise ValueError("Unsupported suffix {}".format(suffix))
[docs] @classmethod def from_transform(cls, f_x, f_y): """Instantiate displacement field from actual transformations. Parameters ---------- f_x : np.array 2D array of shape (h, w) representing the x coordinate of the transformation. f_y : np.array 2D array of shape (h, w) representing the y coordinate of the transformation. Returns ------- DisplacementField Instance of the Displacement field. """ # checks shape_x, shape_y = f_x.shape, f_y.shape if not len(shape_x) == len(shape_y) == 2: raise ValueError("The transforms need to be 2D arrays") if not shape_x == shape_y: raise ValueError( "The width and height of x and y transforms do not match, {} vs {}".format( shape_x, shape_y ) ) shape = shape_x x, y = np.meshgrid(list(range(shape[1])), list(range(shape[0]))) delta_x = f_x - x delta_y = f_y - y return DisplacementField(delta_x, delta_y)
def __init__(self, delta_x, delta_y): # Checks shape_x, shape_y = delta_x.shape, delta_y.shape if not len(shape_x) == len(shape_y) == 2: raise ValueError("The displacement fields need to be 2D arrays") if not shape_x == shape_y: raise ValueError( "The width and height of x and y displacement field do not match, {} vs {}".format( shape_x, shape_y ) ) self.delta_x = delta_x.astype(np.float32) self.delta_y = delta_y.astype(np.float32) # Define more attributes self.shape = shape_x def __add__(self, other): """Addition of two displacement vector fields. Notes ----- Not useful at all. This is addition of the transformation implied by these displacement fields. u_sum(x) = F_sum(x) - x = F_1(x) + F_2(x) - x = u_1(x) + x + u_2(x) + x - x = u_1(x) + u_2(x) + x """ raise NotImplementedError("Still needs to be implemented") def __eq__(self, other): """Equality.""" if not isinstance(other, DisplacementField): raise TypeError( "The right hand side object is not DisplacementField but {}".format( type(other) ) ) return np.allclose(self.delta_x, other.delta_x) and np.allclose( self.delta_y, other.delta_y ) def __call__(self, other, interpolation="linear", border_mode="replicate", c=0): """Composition of transformations. Notes ----- This composition is only approximate since we need to approximate `self` on off-grid elements. Negative side effect is that composing with inverse will not necessarily lead to identity. Parameters ---------- other : DisplacementField An inner DVF. interpolation : str, {'nearest', 'linear', 'cubic', 'area', 'lanczos'} Regular grid interpolation method to be used. border_mode : str, {'constant', 'replicate', 'reflect', 'wrap', 'reflect101', 'transparent'} How to fill outside of the range values. See references for detailed explanation. c : float Only used if `border_mode='constant'` and represents the fill value. Returns ------- composition : DisplacementField Let F: x -> x + self and G: x -> x + other. Then the composition represents x: F(G(x)) - x. """ if not isinstance(other, DisplacementField): raise TypeError( "The inner object is not DisplacementField but {}".format(type(other)) ) if self.shape != other.shape: raise ValueError("Cannot compose DVF of different shapes!") # Think about self as 2 images delta_x and delta_y, and the final transformation also as 2 images with # intentieties being equal to the output vector. x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0]))) delta_x = ( other.warp( x + self.delta_x, interpolation=interpolation, border_mode=border_mode, c=c, ) - x ) delta_y = ( other.warp( y + self.delta_y, interpolation=interpolation, border_mode=border_mode, c=c, ) - y ) return DisplacementField(delta_x, delta_y) def __mul__(self, c): """Multiplication by a constant from the right. Parameters ---------- c : int or float A number. Returns ------- result : DisplacementField An instance of DisplacementField where both the `delta_x' and `delta_y` were elementwise multiplied by `c`. Raises ------ TypeError If `c` is not int or float. """ if not isinstance(c, (int, float)): raise TypeError("The constant c needs to be a number.") return DisplacementField(delta_x=c * self.delta_x, delta_y=c * self.delta_y) def __rmul__(self, c): """Multiplication by a constant from the left. Notes ----- Since we want this to be commutative we simply delegate all the logic to `__mul__` method. Parameters ---------- c : int or float A number. Returns ------- result : DisplacementField An instance of DisplacementField where both the `delta_x' and `delta_y` were elementwise multiplied by `c`. Raises ------ TypeError If `c` is not int or float. """ return self * c @property def average_displacement(self): """Average displacement per pixel.""" return self.norm.mean() @property def delta_x_scaled(self): """Scaled version of delta_x.""" return self.delta_x / self.shape[1] @property def delta_y_scaled(self): """Scaled version of delta_y.""" return self.delta_y / self.shape[0] @property def is_valid(self): """Check whether both delta_x and delta_y finite.""" return np.all(np.isfinite(self.delta_x)) and np.all(np.isfinite(self.delta_y)) @property def jacobian(self): """Compute determinant of a Jacobian per each pixel.""" delta_x = self.delta_x delta_y = self.delta_y a_11 = np.zeros(self.shape) a_12 = np.zeros(self.shape) a_21 = np.zeros(self.shape) a_22 = np.zeros(self.shape) # inside (symmetric) a_11[:, 1:-1] = 1 + (-delta_x[:, :-2] + delta_x[:, 2:]) / 2 a_12[1:-1, :] = (-delta_x[:-2, :] + delta_x[2:, :]) / 2 a_21[:, 1:-1] = (-delta_y[:, :-2] + delta_y[:, 2:]) / 2 a_22[1:-1, :] = 1 + (-delta_y[:-2, :] + delta_y[2:, :]) / 2 # edges (one-sided) a_11[:, 0] = 1 + (delta_x[:, 1] - delta_x[:, 0]) a_11[:, -1] = 1 + (delta_x[:, -1] - delta_x[:, -2]) a_12[0, :] = delta_x[1, :] - delta_x[0] a_12[-1, :] = delta_x[-1, :] - delta_x[-2, :] a_21[:, 0] = delta_y[:, 1] - delta_y[:, 0] a_21[:, -1] = delta_y[:, -1] - delta_y[:, -2] a_22[0, :] = 1 + delta_y[1, :] - delta_y[0] a_22[-1, :] = 1 + delta_y[-1, :] - delta_y[-2, :] res = np.multiply(a_11, a_22) - np.multiply(a_12, a_21) return res @property def n_pixels(self): """Count the number of pixels in the displacement field. Notes ----- Number of channels is ignored. """ return np.prod(self.shape[:2]) @property def norm(self): """Norm for each pixel.""" return np.sqrt(np.square(self.delta_x) + np.square(self.delta_y)) @property def outsiders(self): """For each pixels determines whether it is mapped outside of the image. Notes ----- An important thing to look out for since for each outsider the interpolator cannot use grid interpolation. """ x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0]))) f_x, f_y = x + self.delta_x, y + self.delta_y return np.logical_or( np.logical_or(0 > f_x, f_x >= self.shape[1]), np.logical_or(0 > f_y, f_y >= self.shape[0]), ) @property def summary(self): """Generate a summary of the displacement field. Returns ------- summary : pd.Series Summary series containing the most interesting values. """ import pandas as pd # Parameters eps = 1e-2 # Preparation df_dict = {} abs_x = abs(self.delta_x) abs_y = abs(self.delta_y) # PER COORDINATE # x df_dict["x_amax"] = abs_x.max() df_dict["x_amean"] = abs_x.mean() df_dict["x_mean"] = self.delta_x.mean() # y df_dict["y_amax"] = abs_y.max() df_dict["y_amean"] = abs_y.mean() df_dict["y_mean"] = self.delta_y.mean() # COMBINED df_dict["n_unchanged"] = np.logical_and(abs_x < eps, abs_y < eps).sum() df_dict["percent_unchanged"] = 100 * df_dict["n_unchanged"] / self.n_pixels df_dict["n_outside"] = self.outsiders.sum() df_dict["percent_outside"] = 100 * df_dict["n_outside"] / self.n_pixels return pd.Series(df_dict) @property def transformation(self): """Output the actual transformation rather than the displacement field. Returns ------- f_x : np.ndarray A 2D array of dtype float32. For each pixel in the fixed image what is the corresponding x coordinate in the moving image. f_y : np.ndarray A 2D array of dtype float32. For each pixel in the fixed image what is the corresponding y coordinate in the moving image. """ x, y = np.meshgrid( np.arange(self.shape[1], dtype=np.float32), np.arange(self.shape[0], dtype=np.float32), copy=False, ) # will guarantee the output is float32 f_x = x + self.delta_x f_y = y + self.delta_y return f_x, f_y
[docs] def anchor(self, h_kept=0.75, w_kept=0.75, ds_f=5, smooth=0): """Anchor and smoothen the displacement field. Embeds a rectangle inside of the domain and uses it as a regular subgrid to smoothen out the original displacement field via radial basis function interpolation. Additionally makes sure that the 4 corners have zero displacements. Parameters ---------- h_kept : int or float If ``int`` then represents the actual height of the rectangle to be embedded. If ``float`` then percentage of the df height. w_kept : int or float If ``int`` then represents the actual width of the rectangle to be embedded. If ``float`` then percentage of the df width. ds_f : ds_f Downsampling factor. The higher the quicker the interpolation but the more different the new df compared to the original. smooth : float If 0 then performs exact interpolation - transformation values on node points are equal to the original. If >0 then starts favoring smoothness over exact interpolation. Needs to be meddled with manually. Returns ------- DisplacementField Smoothened and anchored version of the original displacement field. """ h, w = self.shape center_r, center_c = int(h // 2), int(w // 2) h_kept_ = h_kept if isinstance(h_kept, int) else h_kept * h w_kept_ = w_kept if isinstance(w_kept, int) else w_kept * w h_half, w_half = int(h_kept_ // 2), int(w_kept_ // 2) h_range = list(range(center_r - h_half, center_r + h_half, ds_f)) w_range = list(range(center_c - w_half, center_c + w_half, ds_f)) y, x = np.meshgrid(h_range, w_range) points = np.stack([y.ravel(), x.ravel()], axis=-1) values_delta_x = np.array([self.delta_x[y, x] for (y, x) in points]) values_delta_y = np.array([self.delta_y[y, x] for (y, x) in points]) return DisplacementField.generate( shape=self.shape, approach="control_points", points=points, values_delta_x=values_delta_x, values_delta_y=values_delta_y, interpolation_method="rbf", interpolator_kwargs={"smooth": smooth, "function": "linear"}, )
[docs] def adjust(self, delta_x_max=None, delta_y_max=None, force_inside_border=True): """Adjust the displacement field. Notes ----- Not in place, returns a modified instance. Parameters ---------- delta_x_max : float Maximum absolute size of delta_x. If None, no limit is imposed. delta_y_max : float Maximum absolute size of delta_y. If None, no limit is imposed. force_inside_border : bool If True, then all displacement vector that would result in leaving the image are clipped. Returns ------- DisplacementField Adjusted DisplacementField. """ eps_final = 1e-8 # just to make sure everything is within the image region sign_x = np.sign(self.delta_x) sign_y = np.sign(self.delta_y) new_delta_x = ( np.minimum(delta_x_max, abs(self.delta_x)) * sign_x if delta_x_max is not None else self.delta_x ) new_delta_y = ( np.minimum(delta_y_max, abs(self.delta_y)) * sign_y if delta_y_max is not None else self.delta_y ) if force_inside_border: # will result in a lot of pixels mapping to the border c_matrix_x = np.ones( self.shape ) # minimum scaling that results in x not being outside c_matrix_y = np.ones( self.shape ) # minimum scaling that results in y not being outside # Preparation x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0]))) max_u = y max_d = self.shape[0] - y - 1 max_r = self.shape[1] - x - 1 max_l = x # Max delta_x - depends whether positive or negative max_delta_x = ( np.ones(self.shape) * np.inf ) # by default it can be as large as you want ix_delta_x_pos = self.delta_x > 0 ix_delta_x_zero = self.delta_x == 0 ix_delta_x_neg = self.delta_x < 0 max_delta_x[ix_delta_x_pos] = max_r[ix_delta_x_pos] max_delta_x[ix_delta_x_neg] = max_l[ix_delta_x_neg] # Max delta_y - depends whether positive or negative max_delta_y = np.ones(self.shape) * np.inf ix_delta_y_pos = self.delta_y > 0 ix_delta_y_zero = self.delta_y == 0 ix_delta_y_neg = self.delta_y < 0 # Experimental max_delta_y[ix_delta_y_pos] = max_d[ix_delta_y_pos] max_delta_y[ix_delta_y_neg] = max_u[ix_delta_y_neg] c_matrix_x[~ix_delta_x_zero] = np.minimum( 1, max_delta_x[~ix_delta_x_zero] / abs(new_delta_x[~ix_delta_x_zero]) ) c_matrix_y[~ix_delta_y_zero] = np.minimum( 1, max_delta_y[~ix_delta_y_zero] / abs(new_delta_y[~ix_delta_y_zero]) ) c_matrix = np.minimum(c_matrix_x, c_matrix_y) * (1 - eps_final) new_delta_x = c_matrix * new_delta_x new_delta_y = c_matrix * new_delta_y return DisplacementField(new_delta_x, new_delta_y)
[docs] def mask(self, mask_matrix, fill_value=0): """Mask a displacement field. Notes ----- Not in place, returns a modified instance. Parameters ---------- mask_matrix : np.array An array of dtype=bool where True represents a pixel that is supposed to be unchanged. False pixels are filled with `fill_value`. fill_value : float or tuple Value to fill the False pixels with. If tuple then fill_value_x, fill_value_y Returns ------- DisplacementField A new DisplacementField instance accordingly masked. """ if not mask_matrix.shape == self.shape: raise ValueError( "The mask array has an incorrect shape of {}.".format(mask_matrix.shape) ) if not mask_matrix.dtype == bool: raise TypeError( "The dtype of the array needs to be a bool, current dtype {}".format( mask_matrix.dtype ) ) delta_x_masked = self.delta_x.copy() delta_y_masked = self.delta_y.copy() if isinstance(fill_value, tuple): fill_value_x, fill_value_y = fill_value elif isinstance(fill_value, (float, int)): fill_value_x, fill_value_y = fill_value, fill_value else: raise TypeError("Incorrect type {} of fill_value".format(type(fill_value))) delta_x_masked[~mask_matrix] = fill_value_x delta_y_masked[~mask_matrix] = fill_value_y return DisplacementField(delta_x_masked, delta_y_masked)
[docs] def plot_dvf(self, ds_f=8, figsize=(15, 15), ax=None): """Plot displacement vector field. Notes ----- Still works in a weird way. Parameters ---------- ds_f : int Downsampling factor, i.e if `ds_f=8` every 8-th row and every 8th column printed. figsize : tuple Size of the figure. ax : matplotlib.Axes Axes upon which to plot. If None, create a new one Returns ------- ax : matplotlib.Axes Axes with the visualization. """ x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0]))) if ax is None: _, ax_quiver = plt.subplots(figsize=figsize) else: ax_quiver = ax ax_quiver.invert_yaxis() ax_quiver.quiver( x[::ds_f, ::ds_f], y[::ds_f, ::ds_f], self.delta_x[::ds_f, ::ds_f], -self.delta_y[::ds_f, ::ds_f], ) # matplotlib has positive delta y as up, in our case its down return ax_quiver
[docs] def plot_outside(self, figsize=(15, 15), ax=None): """Plot all pixels that are mapped outside of the image. Parameters ---------- figsize : tuple Size of the figure. ax : matplotlib.Axes Axes upon which to plot. If None, create a new one Returns ------- ax : matplotlib.Axes Axes with the visualization. """ res = np.zeros(self.shape, dtype=float) if ax is None: _, ax_outside = plt.subplots(figsize=figsize) else: ax_outside = ax res[self.outsiders] = 1 ax_outside.imshow(res, cmap="gray") return ax_outside
[docs] def plot_ranges( self, freq=10, figsize=(15, 10), kwargs_domain=None, kwargs_range=None, ax=None ): """Plot domain and the range of the mapping. Parameters ---------- freq : int Take every `freq` th pixel. The higher the more sparse. figsize : tuple Size of the figure. kwargs_domain : dict or None If ``dict`` then matplotlib kwargs to be passed into the domain scatter. kwargs_range : dict or None If ``dict`` then matplotlib kwargs to be passed into the range scatter. ax : matplotlib.Axes Axes upon which to plot. If None, create a new one. Returns ------- ax : matplotlib.Axes Axes with the visualization. """ # original range h, w = self.shape tx, ty = self.transformation x, y = [], [] x_r, y_r = [], [] i = 0 for r in range(h): for c in range(w): i += 1 if i % freq == 0: x.append(c) y.append(h - r) x_r.append(tx[r, c]) y_r.append(h - ty[r, c]) kwargs_domain = kwargs_domain or {"s": 0.1, "color": "blue"} kwargs_range = kwargs_range or {"s": 0.1, "color": "green"} if ax is None: _, ax_ranges = plt.subplots(figsize=figsize) else: ax_ranges = ax ax_ranges.scatter(x, y, label="Domain", **kwargs_domain) ax_ranges.scatter(x_r, y_r, label="Range", **kwargs_range) ax_ranges.legend() return ax_ranges
[docs] def pseudo_inverse( self, ds_f=1, interpolation_method="griddata_custom", interpolator_kwargs=None ): """Find the displacement field of the inverse mapping. Notes ----- Dangerously approximate and imprecise. Uses irregular grid interpolation. Parameters ---------- ds_f : int, optional Downsampling factor for all the interpolations. Note that ds_1 = 1 means no downsampling. Applied both to the x and y coordinates. interpolation_method : {'griddata', 'bspline', 'rbf'}, optional Interpolation method to use. interpolator_kwargs : dict, optional Additional parameters passed to the interpolator. Returns ------- DisplacementField An instance of the DisplacementField class representing the inverse mapping. """ interpolator_kwargs = interpolator_kwargs or {} x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0]))) xi = (y, x) x_r, y_r = x.ravel(), y.ravel() points = np.hstack( ( (y_r + self.delta_y.ravel()).reshape(-1, 1), (x_r + self.delta_x.ravel()).reshape(-1, 1), ) ) # Downsampling points = points[::ds_f] x_r_ds = x_r[::ds_f] y_r_ds = y_r[::ds_f] x_, y_ = points[:, 1], points[:, 0] if interpolation_method == "griddata": values_grid_x = griddata(points=points, values=x_r_ds, xi=xi) values_grid_y = griddata(points=points, values=y_r_ds, xi=xi) delta_x = values_grid_x.reshape(self.shape) - x delta_y = values_grid_y.reshape(self.shape) - y elif interpolation_method == "griddata_custom": # triangulation performed only once values_grid_x, values_grid_y = griddata_custom(points, x_r_ds, y_r_ds, xi) delta_x = values_grid_x.reshape(self.shape) - x delta_y = values_grid_y.reshape(self.shape) - y elif interpolation_method == "noop": # for benchmarking purposes delta_x, delta_y = np.zeros(self.shape), np.zeros(self.shape) elif interpolation_method == "smooth_bspline": ip_delta_x = SmoothBivariateSpline(x_, y_, x_r_ds, **interpolator_kwargs) ip_delta_y = SmoothBivariateSpline(x_, y_, y_r_ds, **interpolator_kwargs) delta_x = ip_delta_x(x_r, y_r, grid=False).reshape(self.shape) - x delta_y = ip_delta_y(x_r, y_r, grid=False).reshape(self.shape) - y elif interpolation_method == "LSQ_bspline": tx_ds_f = interpolator_kwargs.pop( "tx_ds_f", 1 ) # downsampling factor on x knots, not part of scipy kwargs ty_ds_f = interpolator_kwargs.pop( "ty_ds_f", 1 ) # downsampling factor on y knots, not part of scipy kwargs auto = interpolator_kwargs.pop("auto", False) if auto: # SEEMS TO CRASH THE KERNEL # Create a grid center only where deformation takes place eps = 0.1 range_x = np.unique(self.transformation[0][self.delta_x > eps]) range_y = np.unique(self.transformation[1][self.delta_y > eps]) tx_start, tx_end = max(0, np.floor(range_x.min())), min( self.shape[1] - 1, np.ceil(range_x.max()) ) ty_start, ty_end = max(0, np.floor(range_y.min())), min( self.shape[0] - 1, np.ceil(range_y.max()) ) tx = list(np.arange(tx_start, tx_end, dtype=int))[::tx_ds_f] ty = list(np.arange(ty_start, ty_end, dtype=int))[::ty_ds_f] else: tx, ty = ( list(range(self.shape[1]))[::tx_ds_f], list(range(self.shape[0]))[::ty_ds_f], ) ip_delta_x = LSQBivariateSpline( x_, y_, x_r_ds, tx, ty, **interpolator_kwargs ) ip_delta_y = LSQBivariateSpline( x_, y_, y_r_ds, tx, ty, **interpolator_kwargs ) delta_x = ip_delta_x(x_r, y_r, grid=False).reshape(self.shape) - x delta_y = ip_delta_y(x_r, y_r, grid=False).reshape(self.shape) - y elif interpolation_method == "rbf": ip_delta_x = Rbf(x_, y_, x_r_ds, **interpolator_kwargs) ip_delta_y = Rbf(x_, y_, y_r_ds, **interpolator_kwargs) delta_x = ip_delta_x(x_r, y_r).reshape(self.shape) - x delta_y = ip_delta_y(x_r, y_r).reshape(self.shape) - y else: raise ValueError( "Unrecognized interpolation_method: {}".format(interpolation_method) ) return DisplacementField(delta_x, delta_y)
[docs] def resize(self, new_shape): """Calculate a resized displacement vector field. Goal: df_resized.warp(img) ~ resized(df.warp(img)) Parameters ---------- new_shape : tuple Represents (new_height, new_width) of the resized displacement field. Returns ------- DisplacementField New DisplacementField with a shape of new_shape. """ if not isinstance(new_shape, tuple): raise TypeError("Incorrect type of new_shape: {}".format(type(new_shape))) if not len(new_shape) == 2: raise ValueError("The length of new shape must be 2") f_x, f_y = self.transformation new_f_x = resize(f_x, output_shape=new_shape) new_f_y = resize(f_y, output_shape=new_shape) return DisplacementField.from_transform(new_f_x, new_f_y)
[docs] def resize_constant(self, new_shape): """Calculate the resized displacement vector field that will have the same effect on original image. Goal: upsampled(df.warp(img_downsampled)) ~ df_resized.warp(img). Parameters ---------- new_shape : tuple Represents (new_height, new_width) of the resized displacement field. Returns ------- DisplacementField New DisplacementField with a shape of new_shape. Notes ----- Very useful when we perform registration on a smaller resolution image and then we want to resize it back to the original higher resolution shape. """ fx, fy = self.transformation x_ratio, y_ratio = new_shape[1] / self.shape[1], new_shape[0] / self.shape[0] fx_, fy_ = fx * x_ratio, fy * y_ratio new_f_x = resize(fx_, output_shape=new_shape) new_f_y = resize(fy_, output_shape=new_shape) return DisplacementField.from_transform(new_f_x, new_f_y)
[docs] def save(self, path): """Save displacement field as a .npy file. Notes ----- Can be loaded via `DisplacementField.from_file` class method. Parameters ---------- path : str or pathlib.Path Path to the file. """ path = pathlib.Path(path) if path.suffix == "": path = path.with_suffix(".npy") elif path.suffix == ".npy": pass else: raise ValueError("Invalid suffix {}".format(path.suffix)) np.save(path, np.stack([self.delta_x, self.delta_y], axis=2))
[docs] def warp(self, img, interpolation="linear", border_mode="constant", c=0): """Warp an input image based on the inner displacement field. Parameters ---------- img : np.ndarray Input image to which we will apply the transformation. Currently the only 3 supported dtypes are uint8, float32 and float64. The logic is for the `warped_img` to have the dtype (input dtype - output dtype). * uint8 - uint8 * float32 - float32 * float64 - float32 interpolation : str, {'nearest', 'linear', 'cubic', 'area', 'lanczos'} Regular grid interpolation method to be used. border_mode : str, {'constant', 'replicate', 'reflect', 'wrap', 'reflect101', 'transparent'} How to fill outside of the range values. See references for detailed explanation. c : float Only used if `border_mode='constant'` and represents the fill value. Returns ------- warped_img : np.ndarray Warped image. Note that the dtype will be the same as the input `img`. """ interpolation_mapper = { "nearest": cv2.INTER_NEAREST, "linear": cv2.INTER_LINEAR, "cubic": cv2.INTER_CUBIC, "area": cv2.INTER_AREA, "lanczos": cv2.INTER_LANCZOS4, } border_mode_mapper = { "constant": cv2.BORDER_CONSTANT, "replicate": cv2.BORDER_REPLICATE, "reflect": cv2.BORDER_REFLECT, "wrap": cv2.BORDER_WRAP, "reflect_101": cv2.BORDER_REFLECT101, "transparent": cv2.BORDER_TRANSPARENT, } if interpolation not in interpolation_mapper: raise KeyError( "Unsupported interpolation, available options: {}".format( interpolation_mapper.keys() ) ) if border_mode not in border_mode_mapper: raise KeyError( "Unsupported border_mode, available options: {}".format( border_mode_mapper.keys() ) ) dtype = img.dtype if dtype == np.float32 or dtype == np.uint8: img_ = img elif dtype == np.float64: img_ = img.astype(np.float32) dtype = np.float32 else: raise ValueError("Unsupported dtype {}.".format(dtype)) fx, fy = self.transformation return cv2.remap( img_, fx, fy, interpolation=interpolation_mapper[interpolation], borderMode=border_mode_mapper[border_mode], borderValue=c, )
[docs] def warp_annotation(self, img, approach="opencv"): """Warp an input annotation image based on the displacement field. If displacement falls outside of the image the logic is to replicate the border. This approach guarantees that no new labels are created. Notes ----- If approach is 'scipy' then calls ``scipy.spatial.cKDTree`` in the background with default Euclidian distance and exactly 1 nearest neighbor. Parameters ---------- img : np.ndarray Input annotation image. The allowed dtypes are currently int8, int16, int32 approach : str, {'scipy', 'opencv'} Approach to be used. Currently 'opencv' way faster. Returns ------- warped_img : np.ndarray Warped image. """ allowed_dtypes = ["int8", "int16", "int32"] input_dtype = img.dtype # CHECKS if not any([input_dtype == x for x in allowed_dtypes]): raise ValueError("The only allowed dtypes are {}".format(allowed_dtypes)) if approach == "scipy": x, y = np.meshgrid(list(range(self.shape[1])), list(range(self.shape[0]))) temp_all = np.hstack( (y.reshape(-1, 1), x.reshape(-1, 1), img[y, x].reshape(-1, 1)) ) inst = NearestNDInterpolator(temp_all[:, :2], temp_all[:, 2]) x_r, y_r = x.ravel(), y.ravel() coords = np.hstack( ( (y_r + self.delta_y.ravel()).reshape(-1, 1), (x_r + self.delta_x.ravel()).reshape(-1, 1), ) ) return inst(coords).reshape(self.shape).astype(input_dtype) elif approach == "opencv": # opencv keeps the same dtype apparently fx, fy = self.transformation return cv2.remap( img, fx, fy, cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE ) else: raise ValueError("Unrecognized approach {}".format(approach))