"""Module containing utilities for manipulation of keras models."""
"""
The package atlalign is a tool for registration of 2D images.
Copyright (C) 2021 EPFL/Blue Brain Project
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import json
import pathlib
from copy import deepcopy
from tensorflow.keras.layers import Lambda, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model as load_model_keras
from tensorflow.keras.models import model_from_json
from atlalign.ml_utils import (
Affine2DVF,
BilinearInterpolation,
DVFComposition,
ExtractMoving,
NoOp,
)
def merge_global_local(model_g, model_l, expose_global=False):
    """Chain a global aligner and a local aligner into a single model.

    The global model is applied to the overall input first. Its registered
    image is concatenated with the input stripped of its last channel and
    fed to the local model. The two displacement fields are then combined
    with a `DVFComposition` layer.

    Note (from the original author): this seems to also change the input
    models.

    Parameters
    ----------
    model_g : keras.Model
        Model performing global alignment. Only its first two outputs are
        used (an inverse model may expose a third one).

    model_l : keras.Model
        Model performing local alignment. If it has exactly two inputs, the
        same concatenated tensor is passed to both of them.

    expose_global : bool, optional
        If True, then the model has 4 outputs where the last two represent
        the image after global alignment and the corresponding dvfs. If
        False, then only 2 outputs: the image after both global and local
        alignment and the overall dvfs.

    Returns
    -------
    model_gl : keras.Model
        Model performing both the global and the local alignment.
    """
    # Slicing layer: keeps every channel of its input except the last one.
    drop_last_channel = Lambda(lambda x: x[..., :-1])

    local_inputs = model_l.inputs
    two_input_local = isinstance(local_inputs, list) and len(local_inputs) == 2

    # Tensors produced by the global stage.
    overall_input = model_g.input
    img_ref = drop_last_channel(overall_input)
    img_reg_g, dvfs_g = model_g.outputs[:2]

    # Input of the local stage.
    stacked = concatenate([img_ref, img_reg_g])
    # Quick hack: a two-input local model receives the same tensor twice.
    local_feed = [stacked, stacked] if two_input_local else stacked
    img_reg_gl, dvfs_l = model_l(local_feed)[:2]

    dvfs_gl = DVFComposition()([dvfs_g, dvfs_l])

    outputs = [img_reg_gl, dvfs_gl]
    if expose_global:
        outputs.extend([img_reg_g, dvfs_g])

    return Model(inputs=overall_input, outputs=outputs)
def save_model(model, path, separate=True, overwrite=True):
    """Save model.

    Parameters
    ----------
    model : keras.Model
        Keras model to be saved.

    path : str or pathlib.Path
        Path to where to save the serialized model. If `separate=True` then
        it needs to represent a folder name. Inside of the folder 2 files are
        created - weights (.h5) and architecture (.json). If `separate=False`
        then an extension `.h5` is added and architecture + weights are
        dumped into one file.

    separate : bool
        If True, then architecture and weights are stored separately. Note
        that if False then one might encounter issues when loading in a
        different Python environment (see references).

    overwrite : bool
        If True, then possible existing files/folders are overwritten.

    Raises
    ------
    ValueError
        If `path` has an extension.

    FileExistsError
        If `separate=True`, `overwrite=False` and either the architecture or
        the weights file already exists. Nothing is written in that case.

    References
    ----------
    [1] https://github.com/keras-team/keras/issues/9595
    """
    path = pathlib.Path(path)

    if path.suffix:
        raise ValueError("Please specify a path without extension (folder).")

    if not separate:
        # Make sure the destination folder exists, like in the separate mode.
        path.parent.mkdir(parents=True, exist_ok=True)
        model.save(str(path) + ".h5", overwrite=overwrite, save_format="h5")
        return

    path_architecture = path / (path.stem + ".json")
    path_weights = path / (path.stem + ".h5")

    if not overwrite:
        # Check both targets before touching the disk so that a refused save
        # never leaves a fresh architecture next to stale weights.
        for target in (path_architecture, path_weights):
            if target.exists():
                raise FileExistsError(
                    "The file {} already exists and overwriting is disabled.".format(
                        target
                    )
                )

    # The folder may already have been created before.
    path.mkdir(parents=True, exist_ok=True)

    # The JSON string from `to_json` is itself JSON-encoded; `load_model`
    # relies on this format when reading the architecture back.
    with path_architecture.open("w") as f_a:
        json.dump(model.to_json(), f_a)

    model.save_weights(str(path_weights), overwrite=overwrite)
def load_model(path, compile=False):
    """Load a model that uses custom `atlalign` layers.

    The benefit of using this function as opposed to the keras equivalent is
    that the user does not have to care about how the model was saved
    (whether architecture and weights were separated). Additionally, all
    possible custom layers are provided.

    Parameters
    ----------
    path : str or pathlib.Path
        If `path` is a folder then the folder is expected to have one `.h5`
        file (weights) and one `.json` (architecture). If `path` is a file
        then it needs to be an `.h5` and it needs to encapsulate both the
        weights and the architecture.

    compile : bool
        Only possible if `path` refers to a `.h5` file.

    Returns
    -------
    keras.Model
        Model ready for inference. If `compile=True` then also ready for
        continuing the training.

    Raises
    ------
    OSError
        If `path` is neither an existing file nor an existing folder.

    ValueError
        If `path` is a folder and `compile=True`, or if the folder does not
        contain exactly one `.h5` and exactly one `.json` file.
    """
    path = pathlib.Path(str(path))
    single_file = path.is_file()

    if not single_file:
        if not path.is_dir():
            raise OSError("The path {} does not exist.".format(str(path)))

        if compile:
            raise ValueError(
                "Cannot compile the model because weights and architecture stored separately."
            )

        # Single pass over the folder, sorting the entries by extension.
        h5_files, json_files = [], []
        for entry in path.iterdir():
            if entry.suffix == ".h5":
                h5_files.append(entry)
            elif entry.suffix == ".json":
                json_files.append(entry)

        if len(h5_files) != 1 or len(json_files) != 1:
            raise ValueError(
                "The folder {} needs to contain exactly one .h5 file and one .json file".format(
                    path
                )
            )

    # Custom layers needed for deserialization, shared by both loading modes.
    custom_objects = {
        "Affine2DVF": Affine2DVF,
        "DVFComposition": DVFComposition,
        "BilinearInterpolation": BilinearInterpolation,
        "ExtractMoving": ExtractMoving,
        "NoOp": NoOp,
    }

    if single_file:
        return load_model_keras(
            str(path), compile=compile, custom_objects=custom_objects
        )

    with json_files[0].open("r") as f:
        architecture = json.load(f)

    # The architecture is normally stored as a JSON-encoded string (the
    # output of `model.to_json()` dumped once more). A file holding the
    # plain JSON object is also accepted and re-serialized here.
    if not isinstance(architecture, str):
        architecture = json.dumps(architecture)

    model = model_from_json(architecture, custom_objects=custom_objects)
    model.load_weights(str(h5_files[0]))

    return model
def replace_lambda_in_config(input_config, output_format="dict", verbose=False):
    """Replace Lambda layers with full blown keras layers.

    This function only exists because we pretrained a lot of models with 2
    different Lambda layers and only after that realized that they cause
    issues during serialization. One can then use this function to just fix
    it.

    Notes
    -----
    To make this clear let us define the top dictionary as the one that has
    keys

        - 'backend'
        - 'config'
        - 'keras_version'
        - 'class_name'

    The bottom dictionary is top['config'].

    Parameters
    ----------
    input_config : dict or str or pathlib.Path or keras.Model
        Config containing an architecture generated by one of the functions
        in `atlalign.nn` possibly containing Lambda layers.

            - dict - the bottom dictionary (it is indexed with 'layers')
            - str - JSON string as returned by `model.to_json()`
            - pathlib.Path - `.json` file holding the JSON-encoded output of
              `model.to_json()`
            - keras.Model - a model instance

    output_format : str, {'dict', 'keras', 'json'}
        What output type to use. See below how to instantiate a model out of
        each of the formats

            - 'dict' - `keras.Model.from_config`
            - 'json' - `keras.models.model_from_json`
            - 'keras' - already a model instance

    verbose : bool
        If True, printing to standard output.

    Returns
    -------
    output_config : dict or str or keras.Model
        Architecture of the input network with all Lambda layers replaced by
        full blown layers. The type is determined by `output_format`. With
        'keras' the weights are copied over only when `input_config` is a
        model instance.

    Raises
    ------
    TypeError
        If `output_format` or the type of `input_config` is not supported.

    ValueError
        If `input_config` is a path whose extension is not `.json`.

    KeyError
        If a layer entry lacks 'class_name', a Lambda layer lacks 'config',
        or a Lambda layer has a name with no known replacement.
    """
    translation_map = {"extract_moving": "ExtractMoving", "inv_dvf": "NoOp"}

    # Fail fast: reject an unsupported output format before doing any work.
    if output_format not in ("dict", "json", "keras"):
        raise TypeError("Unrecognized output format: {}".format(output_format))

    def lambda_replacer(layer_dict, verbose=False):
        """Replace a layer specific dict (only if Lambda layer).

        Parameters
        ----------
        layer_dict : dict
            Corresponds to an element of
            `json.loads(model.to_json())['config']['layers']`.

        verbose : bool
            If True, printing to standard output.

        Returns
        -------
        dict
            A deep copy of `layer_dict`. If it is a Lambda layer, its class
            name and config are replaced by those of a predefined custom
            (non-Lambda) layer. Otherwise the copy is unmodified.
        """
        copy_dict = deepcopy(layer_dict)

        if "class_name" not in copy_dict:
            raise KeyError("Does not contain class_name")

        if copy_dict["class_name"] != "Lambda":
            if verbose:
                print("Not a Lambda")
            return copy_dict

        if "config" not in copy_dict:
            raise KeyError("Does not contain config")

        name = copy_dict["config"]["name"]

        if name not in translation_map:
            raise KeyError(
                "Stumbled upon a lambda layer with an unrecognized name: {}".format(
                    name
                )
            )

        # The replacement layer keeps the name and gets a fresh config.
        copy_dict["class_name"] = translation_map[name]
        copy_dict["config"] = {
            "name": name,
            "trainable": True,
        }

        return copy_dict

    # Normalize every supported input type to the bottom dictionary.
    if isinstance(input_config, str):
        # assuming it came from model.to_json()
        config_dict = json.loads(input_config)["config"]

    elif isinstance(input_config, dict):
        config_dict = input_config

    elif isinstance(input_config, pathlib.Path):
        if input_config.suffix != ".json":
            raise ValueError(
                "The only allowed extension is .json, {} is unsupported".format(
                    input_config.suffix
                )
            )
        with input_config.open("r") as f:
            # The file holds a JSON-encoded JSON string, hence two decodes.
            config_dict = json.loads(json.load(f))["config"]

    elif isinstance(input_config, Model):
        config_dict = input_config.get_config()

    else:
        raise TypeError(
            "Unsupported type of input_config: {}".format(type(input_config))
        )

    # Work on a copy so that the caller's config is never modified.
    hacked_config = deepcopy(config_dict)
    hacked_config["layers"] = []

    for layer_dict in config_dict["layers"]:
        replaced = lambda_replacer(layer_dict, verbose=verbose)
        hacked_config["layers"].append(replaced)

        if verbose:
            print("Before == after: {}".format(layer_dict == replaced))
            print("\n\n")

    if output_format == "dict":
        return hacked_config

    # Both remaining formats need an actual model instance.
    final_model = Model.from_config(
        hacked_config,
        custom_objects={
            "Affine2DVF": Affine2DVF,
            "DVFComposition": DVFComposition,
            "BilinearInterpolation": BilinearInterpolation,
            "ExtractMoving": ExtractMoving,
            "NoOp": NoOp,
        },
    )

    if output_format == "json":
        return final_model.to_json()

    # output_format == "keras"
    if isinstance(input_config, Model):
        final_model.set_weights(input_config.get_weights())

    return final_model