import tensorflow as tf
from tensorflow.keras import layers, Model
[docs]
def build_classification_tail(input_features: tf.Tensor, num_classes: int) -> tf.Tensor:
"""
Builds the classification tail of the multitask reconstructor model.
Parameters
----------
input_features : tf.Tensor
Input feature vector.
num_classes : int
Number of classes to classify into.
Returns
-------
tf.Tensor
Output tensor of the classification tail.
"""
x = layers.Dense(64)(input_features)
x = layers.LeakyReLU(negative_slope=0.1)(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(128)(x)
x = layers.LeakyReLU(negative_slope=0.1)(x)
output = layers.Dense(num_classes, activation='softmax', name="classification_output")(x)
return output
# Regression tail Future work could implement two networks, one to predict each of the real and complex parts of the parameter
[docs]
def build_regression_tail(input_features: tf.Tensor, num_outputs: int) -> tf.Tensor:
"""
Builds the regression tail of the multitask reconstructor model.
Parameters
----------
input_features : tf.Tensor
Input feature vector.
num_outputs : int
Number of regression outputs.
Returns
-------
tf.Tensor
Output tensor of the regression tail.
"""
x = layers.Dense(64)(input_features)
x = layers.LeakyReLU(negative_slope=0.1)(x)
x = layers.Dense(128)(x)
x = layers.LeakyReLU(negative_slope=0.1)(x)
x = layers.Dense(256)(x)
x = layers.LeakyReLU(negative_slope=0.1)(x)
x = layers.Dense(256)(x)
x = layers.LeakyReLU(negative_slope=0.1)(x)
x = layers.Dense(256)(x)
x = layers.LeakyReLU(negative_slope=0.1)(x)
x = layers.Dense(128)(x)
x = layers.LeakyReLU(negative_slope=0.1)(x)
output = layers.Dense(num_outputs, name="regression_output")(x)
return output
[docs]
def build_multitask_reconstructor(input_shape: tuple, num_classes: int, num_regression_outputs: int) -> tf.keras.Model:
"""
Builds the multitask reconstructor model.
Parameters
----------
input_shape : tuple
Shape of the input data.
num_classes : int
Number of classes to classify.
num_regression_outputs : int
Number of regression outputs.
Returns
-------
tf.keras.Model
Multitask reconstructor model.
"""
inputs = tf.keras.Input(shape=input_shape)
feature_extractor = build_feature_extractor(input_shape)
feature_vector = feature_extractor(inputs)
classification_tail = build_classification_tail(feature_vector, num_classes)
regression_input = layers.Concatenate()([feature_vector, classification_tail])
regression_tail = build_regression_tail(regression_input, num_regression_outputs)
return Model(inputs=inputs, outputs=[classification_tail, regression_tail])