import numpy as np
import matplotlib.pyplot as plt
import time
from qutip import Qobj
import tensorflow as tf
import warnings
from qsttoolkit.quantum import fidelity
from qsttoolkit.plots import plot_hinton, plot_husimi_Q, plot_wigner
from qsttoolkit.utils import _subplot_number, _subplot_figsize, _deprecation_warning, _no_longer_required_warning
from qsttoolkit.data.measurement import Husimi_Q_measurement_operators, photon_number_measurement_operators, measurement_operators
from qsttoolkit.tomography.loss import log_likelihood
##### Cholesky parametrization functions - once more are introduced, move to their own file #####
[docs]
def parametrize_density_matrix(rho: tf.Tensor) -> tf.Tensor:
    """
    Parametrizes a density matrix by its Cholesky factor.

    Parameters
    ----------
    rho : tf.Tensor
        Density matrix, or batch of density matrices ``(batch_size, dim, dim)``,
        to be parametrized. A ``qutip.Qobj`` is also accepted and converted with
        ``Qobj.full()``. ``tf.linalg.cholesky`` expects Hermitian
        positive-definite input.

    Returns
    -------
    tf.Tensor
        Lower-triangular factor ``L`` from ``tf.linalg.cholesky``, satisfying
        ``rho = L @ adjoint(L)``.
    """
    if isinstance(rho, Qobj):
        rho = rho.full()
    # NOTE(review): reconstruct_density_matrix forms adjoint(T) @ T, whereas this
    # factor satisfies rho = L @ adjoint(L). Passing the result straight into
    # reconstruct_density_matrix therefore yields adjoint(L) @ L, which in general
    # is not rho. Confirm which convention callers rely on.
    return tf.linalg.cholesky(rho)  # Lower triangular (batch_size, dim, dim)
[docs]
def parameterise_density_matrix(rho: tf.Tensor) -> tf.Tensor:
    """
    Deprecated spelling of ``parametrize_density_matrix``.

    Emits a deprecation warning, then forwards ``rho`` unchanged and returns
    whatever ``parametrize_density_matrix`` returns.
    """
    _deprecation_warning('parameterise_density_matrix', 'parametrize_density_matrix')
    factor = parametrize_density_matrix(rho)
    return factor
[docs]
def reconstruct_density_matrix(params: tf.Tensor, reg: float=1.0e-10, dim=None) -> tf.Tensor:
    """
    Reconstructs a unit-trace density matrix from a Cholesky-style factor.

    Computes ``rho = adjoint(T) @ T + reg * I`` and divides it by its trace.
    Works for a single factor ``(dim, dim)`` or a batch ``(..., dim, dim)``.
    In the batched case, each matrix is normalized by its own trace.

    Parameters
    ----------
    params : tf.Tensor
        Factor ``T``, or batch of factors, parametrizing the density matrix.
    reg : float
        Weight of the identity term added before normalization, to prevent
        singular matrices. Defaults to 1e-10.
    dim : optional
        No longer used. Passing a value only emits a warning; the dimension is
        inferred from ``params``.

    Returns
    -------
    tf.Tensor
        Reconstructed density matrix (or batch), with the dtype of
        ``adjoint(params) @ params``.
    """
    if dim is not None: _no_longer_required_warning('dim')
    # adjoint(T) @ T is Hermitian positive semi-definite for any T.
    rho = tf.matmul(tf.linalg.adjoint(params), params)
    # Regularization to prevent singular matrices (adding a small identity term).
    # The identity is built in rho's own dtype, so real and complex64 factors work
    # as well as complex128.
    size = tf.shape(rho)[-1]
    rho += reg * tf.eye(size, dtype=rho.dtype)
    # Normalize to trace 1. tf.linalg.trace returns shape rho.shape[:-2]. Two
    # trailing axes are added so it broadcasts per matrix, not against the last axis.
    trace = tf.linalg.trace(rho)
    rho /= trace[..., tf.newaxis, tf.newaxis]
    return rho
##### Define constraints - no longer used by MLE #####
[docs]
def trace_constraint(params: np.ndarray) -> float:
    """
    Deprecated trace constraint for the reconstructed density matrix.

    Unit trace is now guaranteed by reconstruct_density_matrix, which divides
    by the trace. This helper remains only for backward compatibility.

    Parameters
    ----------
    params : np.ndarray
        Parameters forwarded unchanged to reconstruct_density_matrix
        (historically described as a flattened vector of real parameters).

    Returns
    -------
    float
        Real part of the reconstructed trace minus 1; zero when the constraint holds.

    Warns
    -----
    DeprecationWarning
        Emitted on every call.
    """
    message = "The trace_constraint function is deprecated and will be removed in a future version. The trace of the density matrix is now enforced by reconstruct_density_matrix function."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    total = np.trace(reconstruct_density_matrix(params))
    return total.real - 1
[docs]
def positivity_constraint(params: np.ndarray) -> float:
    """
    Deprecated constraint ensuring the density matrix is positive semi-definite.

    reconstruct_density_matrix builds the matrix as adjoint(T) @ T, which is
    positive semi-definite by construction. This helper therefore remains only
    for backward compatibility.

    Parameters
    ----------
    params : np.ndarray
        Parameters forwarded unchanged to reconstruct_density_matrix
        (historically described as a flattened vector of real parameters).

    Returns
    -------
    float
        Smallest eigenvalue of the reconstructed density matrix; >= 0 when the
        constraint holds.

    Warns
    -----
    DeprecationWarning
        Emitted on every call.
    """
    # The message previously said "trace" (copied from trace_constraint); this
    # function concerns positivity.
    warnings.warn("The positivity_constraint function is deprecated and will be removed in a future version. The positivity of the density matrix is now enforced by reconstruct_density_matrix function.", DeprecationWarning, stacklevel=2)
    rho = reconstruct_density_matrix(params)
    eigenvalues = np.linalg.eigvalsh(rho)  # Eigenvalues of the Hermitian rho
    return np.min(eigenvalues)  # Should be >= 0
##### Parent class for all QST methods #####
[docs]
class QuantumStateTomography:
"""A parent class for all quantum state tomography methods."""
def __init__(self):
self.reconstructed_dm = None
self.progress_saves = None
self.fidelities = None
self.times = None
[docs]
def fidelity(self, true_dm: np.ndarray) -> float:
"""
Computes the fidelity between the true and reconstructed density matrices.
Parameters
----------
true_dm : np.ndarray
True density matrix.
Returns
-------
float
Fidelity between the true and reconstructed density matrices.
"""
if len(self.reconstructed_dm.shape) != 2: raise ValueError("Invalid shape of reconstructed density matrix.")
return fidelity(true_dm, self.reconstructed_dm)
[docs]
def plot_losses(self):
"""Plots the losses over epochs."""
plt.figure(figsize=(5, 4))
plt.plot(self.losses, label='Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Losses over epochs')
plt.show()
[docs]
def plot_fidelities(self, true_dm=None):
"""Plots the fidelity between the true and reconstructed density matrices over epochs."""
if true_dm is not None: _no_longer_required_warning('true_dm')
plt.figure(figsize=(5, 4))
plt.plot(self.fidelities)
plt.ylim(0,1)
plt.xlabel('Epoch')
plt.ylabel('Fidelity')
plt.title('Fidelity over epochs')
plt.show()
[docs]
def plot_times(self):
"""Plots the cumulative time taken for each epoch."""
plt.figure(figsize=(5, 4))
plt.plot(self.times)
plt.xlabel('Epoch')
plt.ylabel('Time (s)')
plt.title('Time taken after epochs')
plt.show()
[docs]
def plot_comparison_hintons(self, true_dm: np.ndarray):
"""
Plots Hinton diagrams of the true and reconstructed density matrices.
Parameters
----------
true_dm : np.ndarray
True density matrix.
"""
if type(true_dm) == Qobj:
true_dm = true_dm.full()
elif type(true_dm) == tf.Tensor:
true_dm = true_dm.numpy()
elif type(true_dm) != np.ndarray:
raise ValueError("unrecognized data type for true_dm.")
_, axs = plt.subplots(1, 2, figsize=(10, 5))
plot_hinton(true_dm, ax=axs[0], label='true density matrix')
if len(self.reconstructed_dm.shape) == 2:
reconstruction = self.reconstructed_dm
else:
raise ValueError("Invalid shape of reconstructed density matrix.")
plot_hinton(reconstruction, ax=axs[1], label='optimized density matrix')
plt.show()
[docs]
def plot_comparison_Hintons(self, true_dm: np.ndarray):
"""Deprecated alias for plot_comparison_hintons. Plots Hinton diagrams of the true and reconstructed density matrices."""
_deprecation_warning('plot_comparison_Hintons', 'plot_comparison_hintons')
return self.plot_comparison_hintons(true_dm)
[docs]
def plot_comparison_husimi_Qs(self, true_dm: np.ndarray, xgrid: np.ndarray, pgrid: np.ndarray):
"""
Plots the Husimi Q functions of the true and reconstructed density matrices.
Parameters
----------
true_dm : np.ndarray
True density matrix.
xgrid : np.ndarray
Phase space X quadrature grid.
pgrid : np.ndarray
Phase space P quadrature grid.
"""
if not isinstance(xgrid, np.ndarray) or not isinstance(pgrid, np.ndarray): raise ValueError("xgrid and pgrid must be numpy arrays.")
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
plot_husimi_Q(true_dm, xgrid, pgrid, fig=fig, ax=axs[0], label='true density matrix')
if len(self.reconstructed_dm.shape) == 2:
reconstruction = self.reconstructed_dm
else:
raise ValueError("Invalid shape of reconstructed density matrix.")
plot_husimi_Q(reconstruction, xgrid, pgrid, fig=fig, ax=axs[1], label='reconstructed density matrix')
plt.show()
[docs]
def plot_comparison_Husimi_Qs(self, true_dm: np.ndarray, xgrid: np.ndarray, pgrid: np.ndarray):
"""Deprecated alias for plot_comparison_husimi_Qs. Plots the Husimi Q functions of the true and reconstructed density matrices."""
_deprecation_warning('plot_comparison_Husimi_Qs', 'plot_comparison_husimi_Qs')
return self.plot_comparison_husimi_Qs(true_dm, xgrid, pgrid)
[docs]
def plot_comparison_wigners(self, true_dm: np.ndarray, xgrid: np.ndarray, pgrid: np.ndarray):
"""
Plots the Wigner functions of the true and reconstructed density matrices.
Parameters
----------
true_dm : np.ndarray
True density matrix.
xgrid : np.ndarray
Phase space X quadrature grid.
pgrid : np.ndarray
Phase space P quadrature grid.
"""
if not isinstance(xgrid, np.ndarray) or not isinstance(pgrid, np.ndarray): raise ValueError("xgrid and pgrid must be numpy arrays.")
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
plot_wigner(true_dm, xgrid, pgrid, fig=fig, ax=axs[0], label='true density matrix')
if len(self.reconstructed_dm.shape) == 2:
reconstruction = self.reconstructed_dm
else:
raise ValueError("Invalid shape of reconstructed density matrix.")
plot_wigner(reconstruction, xgrid, pgrid, fig=fig, ax=axs[1], label='reconstructed density matrix')
plt.show()
[docs]
class CustomQuantumStateTomography(QuantumStateTomography):
    """
    A class for designing custom quantum state tomography methods.

    Attributes
    ----------
    model : tf.keras.Model
        Model used for the reconstruction.
    training_step_fn : callable
        Function that defines the training step for the model. It is called as
        ``training_step_fn(model, measurement_data, measurement_operators, **kwargs)``
        and must return the generated density matrix and the loss. The generated
        density matrix is indexed with ``[0]``, so it presumably carries a leading
        batch axis (TODO confirm).
    """

    def __init__(self, model, training_step_fn):
        super().__init__()
        self.model = model
        self.training_step_fn = training_step_fn

    def reconstruct(self, measurement_data, measurement_operators, epochs, optimizer, verbose_interval: int=None, num_progress_saves: int=None, true_dm: tf.Tensor=None, time_log_interval: int=None, **kwargs):
        """
        Reconstructs the density matrix using the custom training step.

        Parameters
        ----------
        measurement_data : sequence of np.ndarray
            Frequency of each measurement outcome. Every element must be a numpy
            array, and ``len(measurement_data[0])`` must equal
            ``len(measurement_operators)``.
        measurement_operators : sequence
            Projective operators corresponding to the measurement outcomes.
        epochs : int
            Number of training epochs; must be at least 1.
        optimizer : tf.keras.optimizers.Optimizer
            Optimizer for the training step.
        verbose_interval : int
            Interval at which to print progress updates. Defaults to None.
        num_progress_saves : int
            Number of intermediate progress saves to make. Defaults to None.
        true_dm : tf.Tensor
            True density matrix used for calculating fidelities. Defaults to None.
        time_log_interval : int
            Interval at which to log the time taken after each epoch. Defaults to None.
        **kwargs
            Extra keyword arguments forwarded to ``training_step_fn`` on every epoch.

        Raises
        ------
        ValueError
            If any argument fails the input checks below.
        """
        # Input error handling
        if len(measurement_data[0]) != len(measurement_operators): raise ValueError("measurement_data[0] and measurement_operators must have the same length.")
        if not all([isinstance(data, np.ndarray) for data in measurement_data]): raise ValueError("All elements of measurement_data must be numpy arrays.")
        if not isinstance(epochs, int): raise ValueError("epochs must be an integer.")
        # With no epochs there is no generated_dm to store as the reconstruction.
        if epochs < 1: raise ValueError("epochs must be a positive integer.")
        if not isinstance(verbose_interval, int) and verbose_interval is not None: raise ValueError("verbose_interval must be an integer.")
        if not isinstance(num_progress_saves, int) and num_progress_saves is not None: raise ValueError("num_progress_saves must be an integer.")
        if not isinstance(time_log_interval, int) and time_log_interval is not None: raise ValueError("time_log_interval must be an integer.")

        self.optimizer = optimizer
        self.losses = []
        if num_progress_saves:
            # Clamp to 1: when num_progress_saves > epochs, the floor division gives 0,
            # and `epoch % 0` would raise ZeroDivisionError.
            progress_save_interval = max(1, epochs // num_progress_saves)
            self.progress_saves = []
        else:
            self.progress_saves = None
        self.fidelities = [] if true_dm is not None else None
        if time_log_interval:
            start_time = time.time()
            self.times = []
        else:
            self.times = None

        for epoch in range(epochs):
            # Forward pass through generator
            with tf.GradientTape() as tape:
                generated_dm, epoch_loss = self.training_step_fn(self.model, measurement_data, measurement_operators, **kwargs)
            # Backpropagation
            gradients = tape.gradient(epoch_loss, self.model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
            loss_value = epoch_loss.numpy()
            self.losses.append(loss_value)
            # Log fidelity
            epoch_fidelity = None
            if true_dm is not None:
                epoch_fidelity = fidelity(true_dm, generated_dm[0].numpy())
                self.fidelities.append(epoch_fidelity)
            # Save progress
            if num_progress_saves and epoch % progress_save_interval == 0:
                self.progress_saves.append(generated_dm[0].numpy())
            # Log progress
            if verbose_interval and epoch % verbose_interval == 0:
                print(f"Epoch {epoch}/{epochs}, Loss: {loss_value}, Fidelity: {epoch_fidelity}")
            # Log time (cumulative since the start of reconstruction)
            if time_log_interval and epoch % time_log_interval == 0:
                self.times.append(time.time() - start_time)

        self.reconstructed_dm = generated_dm[0].numpy()
        if verbose_interval: print('Reconstruction complete.')