Source code for atlalign.utils

"""Collection of helper classes and function that do not deserve to be in base.py.

Notes
-----
This module cannot import from anywhere else within this project to prevent circular dependencies.

"""

"""
    The package atlalign is a tool for registration of 2D images.

    Copyright (C) 2021 EPFL/Blue Brain Project

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import numpy as np
import scipy.spatial.qhull as qhull


def _triangulate(xyz, uvw):
    """Perform Delaunay triangulation.

    Parameters
    ----------
    xyz : np.ndarray
        An array of shape (N, 2) where each row represents one point in 2D (stable points) for which we know the
        function value.

    uvw : np.ndarray
        An array of shape (K, 2) where each row represents one point in 2D (query point) for which we want to
        interpolate the function value.

    Returns
    -------
    vertices : np.ndarray
        An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
        vertices are always stable points (from range [O, N))

    wts : np.ndarray
        An array of shape (K, 3) representing the weights of respective vertices at each query point.

    """
    tri = qhull.Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, 2]
    bary = np.einsum("njk,nk->nj", temp[:, :2, :], delta)
    wts = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

    return vertices, wts


def _interpolate(values, vertices, wts):
    """Interpolate inside a triangle.

    Parameters
    ----------
    values : np.ndarray
        An array of shape (N,) that represents function value on the know points which are the vertices after
        Delaunay triangulation.

    vertices : np.ndarray
        An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
        vertices are always stable points (from range [O, N)).

    wts : np.ndarray
        An array of shape (K, 3) representing the weights of respective vertices at each query point.

    Returns
    -------
    interpolations : np.ndarray
        An array of shape (K,) representing the interpolated function values on the query points.

    """
    return np.einsum("nj,nj->n", np.take(values, vertices), wts)


[docs]def griddata_custom(points, values_f_1, values_f_2, xi): """Run griddata extensions that performs only one triangulation. Notes ----- The scipy implementation does not allow to separate triangulation from interpolation. Since we need to evaluate 2 different functions on the !same! non-regular grid if points the triangulation can be simply just done once and stored. Parameters ---------- points : np.ndarray An array of shape (N, 2) where each row represents one point in 2D for which we know the function value. values_f_1 : np.ndarray An array of shape (N,) where each row represents a value of function f_1 on the corresponding point in `points`. values_f_2 : np.ndarray An array of shape (N,) where each row represents a value of function f_2 on the corresponding point in `points`. xi : tuple Tuple of 2 np.ndarray of shapes (h, w) representing the x and y coordinates of the points where we want to interpolate data. Note that this is simply the result of `np.meshgrid` if our points of interest lie on a regular grid. Returns ------- f_1_interpolation_on_xi : np.ndarray An array of shape (h, w) representing the interpolation of f_1 on the `xi` points. f_2_interpolation_on_xi : np.ndarray An array of shape (h, w) representing the interpolation of f_2 on the `xi` points. References ---------- https://stackoverflow.com/questions/20915502/speedup-scipy-griddata-for-multiple-interpolations-between-two-irregular-grids # noqa """ if isinstance(xi, tuple): shape = xi[0].shape xi = np.hstack((xi[0].reshape(-1, 1), xi[1].reshape(-1, 1))) # possible speedup else: raise TypeError("The xi needs to be a tuple of equally shaped np.ndarrays.") vertices, wts = _triangulate(points, xi) f_1_interpolation_on_xi = _interpolate(values_f_1, vertices, wts).reshape(shape) f_2_interpolation_on_xi = _interpolate(values_f_2, vertices, wts).reshape(shape) return f_1_interpolation_on_xi, f_2_interpolation_on_xi
def _find_all_children(d, children_list=None): """Construct a list of all the ids of the children of a node and the node itself. Parameters ---------- d : dict Dictionary node from whom we want the list of all the children and children's children. children_list : list, default None List of children which has to be empty for the first iteration of the function. Returns ------- children_list : list List of children's ids. """ if children_list is None: children_list = [] for key, value in d.items(): if key == "id": children_list.append(value) if isinstance(value, list): for child in value: _find_all_children(child, children_list) return children_list def _find_concatenate_labels(d, chosen_depth, dict_of_labels=None, current_depth=0): """Construct a dictionary which has for each key, the value of the new label after concatenation. Parameters ---------- d : dict Dictionary node for which we want to concatenate some ids depending on the depth branch. chosen_depth : int Depth at which it is wanted to concatenate the labels. dict_of_labels : dict, default {} Dictionary of corresponding labels (empty at the first call). current_depth : int, default 0 Depth of the dictionary node. Returns ------- dict_of_labels: dict Dictionary of corresponding labels after concatenation of labels tree. """ if dict_of_labels is None: dict_of_labels = {} if current_depth < chosen_depth: for key, value in d.items(): if key == "id": dict_of_labels[value] = value if isinstance(value, list): current_depth = current_depth + 1 for child in value: _find_concatenate_labels( child, chosen_depth, dict_of_labels=dict_of_labels, current_depth=current_depth, ) else: children_list = [] _find_all_children(d, children_list) for key, value in d.items(): if key == "id": for child in children_list: dict_of_labels[child] = value return dict_of_labels
[docs]def find_labels_dic(segmentation_array, dic, chosen_depth): """Collapse existing labels into parent labels corresponding to the tree provided in a dictionary. Parameters ---------- segmentation_array : np.array Annotation array before the concatenation of the labels. dic : dict Dictionary of tree of labels. chosen_depth : int Depth at which it is wanted to concatenate the labels. Returns ------- new_segmentation_array : np.array New Annotation array with the concatenation of the labels at the desired depth. If a specific label does not exist in the tree it is assigned -1. """ labels_dic = _find_concatenate_labels(dic, chosen_depth) new_segmentation_array = segmentation_array.copy() all_labels = np.unique(segmentation_array) for label in all_labels: if label != 0: new_label = labels_dic.get(label, -1) new_segmentation_array[new_segmentation_array == label] = new_label return new_segmentation_array