Source code for atlalign.utils

"""Collection of helper classes and functions that do not deserve a dedicated module.

This module cannot import from anywhere else within this project to prevent circular dependencies.


    The package atlalign is a tool for registration of 2D images.

    Copyright (C) 2021 EPFL/Blue Brain Project

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import numpy as np
import scipy.spatial.qhull as qhull

def _triangulate(xyz, uvw):
    """Perform Delaunay triangulation and locate the query points inside it.

    Parameters
    ----------
    xyz : np.ndarray
        An array of shape (N, 2) where each row represents one point in 2D (stable points) for which we know the
        function value.

    uvw : np.ndarray
        An array of shape (K, 2) where each row represents one point in 2D (query point) for which we want to
        interpolate the function value.

    Returns
    -------
    vertices : np.ndarray
        An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
        vertices are always stable points (from range [0, N)).

    wts : np.ndarray
        An array of shape (K, 3) representing the weights (barycentric coordinates) of respective vertices at
        each query point. Each row sums up to 1.

    Notes
    -----
    According to the SciPy documentation `find_simplex` reports -1 for query points lying outside of the convex
    hull of `xyz`. `np.take` resolves that index to the last simplex, so such points are not masked: they receive
    extrapolated weights computed with respect to that simplex.
    """
    # Use the public namespace; `scipy.spatial.qhull` is only kept by SciPy as a deprecated private alias.
    from scipy.spatial import Delaunay

    tri = Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)

    # Per simplex, `transform` stacks the inverse affine map (first 2 rows) on top of the offset vector (last row).
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, 2]
    bary = np.einsum("njk,nk->nj", temp[:, :2, :], delta)

    # The third barycentric coordinate is implied by the constraint that all three sum up to 1.
    wts = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

    return vertices, wts

def _interpolate(values, vertices, wts):
    """Interpolate inside a triangle.

    Parameters
    ----------
    values : np.ndarray
        An array of shape (N,) that represents function values on the known points which are the vertices after
        Delaunay triangulation.

    vertices : np.ndarray
        An array of shape (K, 3) representing the triangle vertices of each query point. Note that these
        vertices are always stable points (from range [0, N)).

    wts : np.ndarray
        An array of shape (K, 3) representing the weights of respective vertices at each query point.

    Returns
    -------
    interpolations : np.ndarray
        An array of shape (K,) representing the interpolated function values on the query points.
    """
    # Gather the function value at each of the three vertices of every query point -> shape (K, 3).
    vertex_values = np.take(values, vertices)

    # Row-wise dot product between the vertex values and their weights.
    return np.einsum("nj,nj->n", vertex_values, wts)

def griddata_custom(points, values_f_1, values_f_2, xi):
    """Interpolate two functions given on the same scattered points using a single triangulation.

    Notes
    -----
    The scipy implementation does not allow to separate triangulation from interpolation. Since we need to
    evaluate 2 different functions on the very same non-regular grid of points, the triangulation can simply
    be done once and reused for both of them.

    Parameters
    ----------
    points : np.ndarray
        An array of shape (N, 2) where each row represents one point in 2D for which we know the function value.

    values_f_1 : np.ndarray
        An array of shape (N,) where each entry represents a value of function f_1 on the corresponding point
        in `points`.

    values_f_2 : np.ndarray
        An array of shape (N,) where each entry represents a value of function f_2 on the corresponding point
        in `points`.

    xi : tuple
        Tuple of 2 np.ndarray of shapes (h, w) representing the x and y coordinates of the points where we want
        to interpolate data. Note that this is simply the result of `np.meshgrid` if our points of interest
        lie on a regular grid.

    Returns
    -------
    f_1_interpolation_on_xi : np.ndarray
        An array of shape (h, w) representing the interpolation of f_1 on the `xi` points.

    f_2_interpolation_on_xi : np.ndarray
        An array of shape (h, w) representing the interpolation of f_2 on the `xi` points.

    Raises
    ------
    TypeError
        If `xi` is not a tuple.
    """
    if not isinstance(xi, tuple):
        raise TypeError("The xi needs to be a tuple of equally shaped np.ndarrays.")

    # Remember the grid layout so that the flat interpolation results can be folded back into it.
    grid_shape = xi[0].shape

    # Flatten both coordinate grids into the columns of a single (h * w, 2) array of query points.
    x_column = xi[0].reshape(-1, 1)
    y_column = xi[1].reshape(-1, 1)
    query_points = np.hstack((x_column, y_column))

    # The expensive part - done once and shared by both functions.
    vertices, wts = _triangulate(points, query_points)

    f_1_on_grid = _interpolate(values_f_1, vertices, wts).reshape(grid_shape)
    f_2_on_grid = _interpolate(values_f_2, vertices, wts).reshape(grid_shape)

    return f_1_on_grid, f_2_on_grid
def _find_all_children(d, children_list=None):
    """Construct a list of all the ids of the children of a node and the node itself.

    Parameters
    ----------
    d : dict
        Dictionary node from whom we want the list of all the children and children's children.

    children_list : list, default None
        List the ids are appended to. Leave it as None for the initial call, a fresh list is created then.

    Returns
    -------
    children_list : list
        List of the ids of the node and of all its descendants.
    """
    if children_list is None:
        children_list = []

    for key, value in d.items():
        if key == "id":
            children_list.append(value)
        if isinstance(value, list):
            for child in value:
                # Only nested nodes can carry ids; plain list payloads (numbers, strings, ...) are skipped.
                if isinstance(child, dict):
                    _find_all_children(child, children_list)

    return children_list


def _find_concatenate_labels(d, chosen_depth, dict_of_labels=None, current_depth=0):
    """Construct a dictionary which has for each key, the value of the new label after concatenation.

    Nodes lying above `chosen_depth` keep their own id. A node lying at `chosen_depth` (or deeper) absorbs all
    of its descendants, i.e. their ids are mapped to the id of that node.

    Parameters
    ----------
    d : dict
        Dictionary node for which we want to concatenate some ids depending on the depth branch.

    chosen_depth : int
        Depth at which it is wanted to concatenate the labels.

    dict_of_labels : dict, default None
        Dictionary of corresponding labels. Leave it as None for the initial call, a fresh dictionary is
        created then.

    current_depth : int, default 0
        Depth of the dictionary node.

    Returns
    -------
    dict_of_labels : dict
        Dictionary of corresponding labels after concatenation of labels tree.
    """
    if dict_of_labels is None:
        dict_of_labels = {}

    if current_depth < chosen_depth:
        # All nested nodes sit exactly one level below this node, no matter how many list-valued keys the
        # node has, so the depth is computed once rather than incremented for every list that is encountered.
        child_depth = current_depth + 1

        for key, value in d.items():
            if key == "id":
                dict_of_labels[value] = value
            if isinstance(value, list):
                for child in value:
                    # Plain list payloads (numbers, strings, ...) are not nodes and are skipped.
                    if isinstance(child, dict):
                        _find_concatenate_labels(
                            child,
                            chosen_depth,
                            dict_of_labels=dict_of_labels,
                            current_depth=child_depth,
                        )
    else:
        # Collapse the entire subtree (the node itself included) onto the id of this node.
        subtree_ids = _find_all_children(d)
        if "id" in d:
            node_id = d["id"]
            for subtree_id in subtree_ids:
                dict_of_labels[subtree_id] = node_id

    return dict_of_labels
def find_labels_dic(segmentation_array, dic, chosen_depth):
    """Collapse existing labels into parent labels corresponding to the tree provided in a dictionary.

    Parameters
    ----------
    segmentation_array : np.array
        Annotation array before the concatenation of the labels.

    dic : dict
        Dictionary of tree of labels.

    chosen_depth : int
        Depth at which it is wanted to concatenate the labels.

    Returns
    -------
    new_segmentation_array : np.array
        New annotation array (a copy, the input is left untouched) with the concatenation of the labels at
        the desired depth. If a specific label does not exist in the tree it is assigned -1. The label 0 is
        never remapped.

    Notes
    -----
    The output has the dtype of the input, therefore the -1 marker presumably requires a signed (or floating
    point) `segmentation_array` - TODO confirm against the callers.
    """
    labels_dic = _find_concatenate_labels(dic, chosen_depth)
    new_segmentation_array = segmentation_array.copy()

    for label in np.unique(segmentation_array):
        if label == 0:
            continue

        # The mask is taken from the untouched input rather than from the array being rewritten. Otherwise
        # a voxel that has just been relabelled would be picked up again once its new label gets processed.
        new_segmentation_array[segmentation_array == label] = labels_dic.get(label, -1)

    return new_segmentation_array