Source code for atlalign.augmentations

"""Module creating one-to-many augmentations."""

    The package atlalign is a tool for registration of 2D images.

    Copyright (C) 2021 EPFL/Blue Brain Project

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Lesser General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.

import h5py
import numpy as np
from skimage.feature import canny
from skimage.util import img_as_float32

from atlalign.base import DisplacementField

def load_dataset_in_memory(h5_path, dataset_name):
    """Read one complete dataset of an h5 file into memory.

    Parameters
    ----------
    h5_path : str
        Path to the h5 file. The file is opened in read-only mode and closed
        again before the function returns.
    dataset_name : str
        Key under which the dataset is stored inside the file.

    Returns
    -------
    data
        Result of slicing the whole dataset with ``[:]``, i.e. a copy of its
        content that stays usable after the file is closed.
    """
    with h5py.File(h5_path, "r") as h5_file:
        dataset = h5_file[dataset_name]
        # Slicing inside the ``with`` forces the read while the file is open.
        in_memory = dataset[:]

    return in_memory
class DatasetAugmenter:
    """Class that does the augmentation.

    Parameters
    ----------
    original_path : str
        Path to where the original dataset is located.

    Attributes
    ----------
    original_path : str
        Path to where the original dataset is located.
    n_orig : int
        Number of samples in the original dataset, taken as the length of its
        ``"image_id"`` dataset.
    """

    def __init__(self, original_path):
        self.original_path = original_path
        self.n_orig = len(load_dataset_in_memory(self.original_path, "image_id"))

    def augment(
        self,
        output_path,
        n_iter=10,
        anchor=True,
        p_reg=0.5,
        random_state=None,
        max_corrupted_pixels=500,
        ds_f=8,
        max_trials=5,
    ):
        """Augment the original dataset and create a new one.

        Note that this does not modify the original dataset: it is opened in
        read-only mode and all results are written to ``output_path``.

        The output file contains the datasets ``"img"``, ``"image_id"``,
        ``"dataset_id"``, ``"p"``, ``"deltas_xy"`` and ``"inv_deltas_xy"``,
        each with ``n_iter * self.n_orig`` entries. Entry ``i`` is derived
        from the original sample ``i % self.n_orig``.

        Parameters
        ----------
        output_path : str
            Path to where the new h5 file is stored. The file is opened in
            ``"w"`` mode.
        n_iter : int
            Number of augmented samples per each sample in the original
            dataset.
        anchor : bool
            If True, then the dvf is anchored before it is inverted.
        p_reg : float
            Probability that we start from a registered image (rather than
            the moving one).
        random_state : int or None
            Value passed to ``np.random.seed``. Note that this seeds the
            global NumPy random state.
        max_corrupted_pixels : int
            Maximum number of corrupted pixels allowed for a dvf - the actual
            number is computed as ``np.sum(df.jacobian < 0)``. Both the
            augmentation and its inverse need to stay strictly below it.
        ds_f : int
            Downsampling factor for inverses. 1 creates the least artifacts.
        max_trials : int
            The original sample (image and both displacements) is replicated
            as soon as the trial counter reaches this value, so at most
            ``max_trials - 1`` augmentation attempts are made per sample.
        """
        np.random.seed(random_state)

        n_new = n_iter * self.n_orig
        print(n_new)

        with h5py.File(self.original_path, "r") as f_orig:
            # extract
            dset_img_orig = f_orig["img"]
            dset_image_id_orig = f_orig["image_id"]
            dset_dataset_id_orig = f_orig["dataset_id"]
            dset_deltas_xy_orig = f_orig["deltas_xy"]
            dset_inv_deltas_xy_orig = f_orig["inv_deltas_xy"]
            dset_p_orig = f_orig["p"]

            # Spatial shape comes from the stored displacements (indexed as
            # [sample, ..., component] below) rather than being hard-coded.
            shape = tuple(dset_deltas_xy_orig.shape[1:3])

            with h5py.File(output_path, "w") as f_aug:
                dset_img_aug = f_aug.create_dataset(
                    "img", (n_new,) + shape, dtype="uint8"
                )
                dset_image_id_aug = f_aug.create_dataset(
                    "image_id", (n_new,), dtype="int"
                )
                dset_dataset_id_aug = f_aug.create_dataset(
                    "dataset_id", (n_new,), dtype="int"
                )
                dset_p_aug = f_aug.create_dataset("p", (n_new,), dtype="int")
                dset_deltas_xy_aug = f_aug.create_dataset(
                    "deltas_xy", (n_new,) + shape + (2,), dtype=np.float16
                )
                dset_inv_deltas_xy_aug = f_aug.create_dataset(
                    "inv_deltas_xy", (n_new,) + shape + (2,), dtype=np.float16
                )

                for i in range(n_new):
                    print(i)
                    i_orig = i % self.n_orig

                    mov2reg = DisplacementField(
                        dset_deltas_xy_orig[i_orig, ..., 0],
                        dset_deltas_xy_orig[i_orig, ..., 1],
                    )

                    # copy
                    dset_image_id_aug[i] = dset_image_id_orig[i_orig]
                    dset_dataset_id_aug[i] = dset_dataset_id_orig[i_orig]
                    dset_p_aug[i] = dset_p_orig[i_orig]

                    # ``<`` makes ``p_reg`` the probability of the registered
                    # branch; ``>`` would select it with ``1 - p_reg``.
                    use_reg = np.random.random() < p_reg
                    print("Using registered: {}".format(use_reg))

                    if not use_reg:
                        # mov != reg
                        img_mov = dset_img_orig[i_orig]
                    else:
                        # mov=reg
                        img_mov = mov2reg.warp(dset_img_orig[i_orig])
                        mov2reg = DisplacementField.generate(
                            shape, approach="identity"
                        )

                    is_nice = False
                    n_trials = 0

                    while not is_nice:
                        n_trials += 1

                        if n_trials == max_trials:
                            print("Replicating original: out of trials")
                            dset_img_aug[i] = dset_img_orig[i_orig]
                            dset_deltas_xy_aug[i] = dset_deltas_xy_orig[i_orig]
                            dset_inv_deltas_xy_aug[i] = dset_inv_deltas_xy_orig[
                                i_orig
                            ]
                            break

                        mov2art = self.generate_mov2art(img_mov)
                        # NOTE(review): ``mov2reg`` does not change between
                        # trials, so this inverse could be hoisted out of the
                        # loop if ``pseudo_inverse`` is deterministic - confirm.
                        reg2mov = mov2reg.pseudo_inverse(ds_f=ds_f)
                        reg2art = reg2mov(mov2art)

                        # anchor
                        if anchor:
                            print("ANCHORING")
                            reg2art = reg2art.anchor(
                                ds_f=50, smooth=0, h_kept=0.9, w_kept=0.9
                            )

                        art2reg = reg2art.pseudo_inverse(ds_f=ds_f)

                        is_nice = self._is_acceptable(
                            reg2art, art2reg, max_corrupted_pixels
                        )
                        if is_nice:
                            print("Check passed")
                        else:
                            print("Check failed")

                    # ``is_nice`` is only True when the loop ended through a
                    # passed check, i.e. the fallback above was not used.
                    if is_nice:
                        dset_img_aug[i] = mov2art.warp(img_mov)
                        dset_deltas_xy_aug[i] = np.stack(
                            [art2reg.delta_x, art2reg.delta_y], axis=-1
                        )
                        dset_inv_deltas_xy_aug[i] = np.stack(
                            [reg2art.delta_x, reg2art.delta_y], axis=-1
                        )

    @staticmethod
    def _is_acceptable(reg2art, art2reg, max_corrupted_pixels):
        """Check that a displacement and its inverse are usable.

        Both fields need to have only finite ``delta_x`` and ``delta_y`` and
        strictly fewer than ``max_corrupted_pixels`` pixels with a negative
        jacobian.
        """
        fields = (reg2art, art2reg)

        all_finite = all(
            np.all(np.isfinite(df.delta_x)) and np.all(np.isfinite(df.delta_y))
            for df in fields
        )
        few_corrupted = all(
            np.sum(df.jacobian < 0) < max_corrupted_pixels for df in fields
        )

        return bool(all_finite and few_corrupted)

    @staticmethod
    def generate_mov2art(img_mov, verbose=True, radius_max=60, use_normal=True):
        """Generate a random geometric augmentation of the moving image.

        An "edge_stretching" displacement is generated from the Canny edges
        of ``img_mov`` and multiplied by a random scalar. No inverse is
        computed here.

        Parameters
        ----------
        img_mov : np.ndarray
            Moving image. Its shape is used as the shape of the displacement.
        verbose : bool
            If True, the sampled scalar is printed.
        radius_max : int
            Forwarded as ``radius_max`` to ``DisplacementField.generate``.
        use_normal : bool
            If True, the scalar is sampled via ``np.random.normal(0.7, 0.3)``,
            otherwise via ``np.random.random()``.

        Returns
        -------
        mov2art
            The generated displacement multiplied by the sampled scalar.
        """
        shape = img_mov.shape
        img_mov_float = img_as_float32(img_mov)
        edge_mask = canny(img_mov_float)

        if use_normal:
            c = np.random.normal(0.7, 0.3)
        else:
            c = np.random.random()

        if verbose:
            print("Scalar: {}".format(c))

        mov2art = c * DisplacementField.generate(
            shape,
            approach="edge_stretching",
            edge_mask=edge_mask,
            interpolation_method="rbf",
            interpolator_kwargs={"function": "linear"},
            n_perturbation_points=6,
            radius_max=radius_max,
        )

        return mov2art