Source code for torchquantum.operator.standard_gates.trainable_unitary

from ..op_types import *
from abc import ABCMeta
from torchquantum.macro import C_DTYPE
import torchquantum as tq
import torch
from torchquantum.functional import mat_dict
import torchquantum.functional as tqf


class TrainableUnitary(Operation, metaclass=ABCMeta):
    """Class for TrainableUnitary Gate.

    The gate's parameter tensor is itself a complex
    ``(1, 2**n_wires, 2**n_wires)`` matrix, applied through
    ``tqf.qubitunitaryfast``.
    """

    num_params = AnyNParams
    num_wires = AnyWires
    op_name = "trainableunitary"
    func = staticmethod(tqf.qubitunitaryfast)

    def build_params(self, trainable):
        """Build the parameters for the gate.

        Args:
            trainable (bool): Whether the parameters are trainable.

        Returns:
            torch.nn.Parameter: Uninitialized complex tensor of shape
            ``(1, 2**n_wires, 2**n_wires)`` with dtype ``C_DTYPE``.
            Its values are presumably filled in afterwards by
            :meth:`reset_params` (confirm against the Operator base class).
        """
        parameters = torch.nn.Parameter(
            torch.empty(1, 2**self.n_wires, 2**self.n_wires, dtype=C_DTYPE)
        )
        parameters.requires_grad = bool(trainable)
        return parameters

    def reset_params(self, init_params=None):
        """Reset the parameters.

        Args:
            init_params (torch.Tensor, optional): Initial matrix to copy into
                the parameters. It must be broadcastable to the parameter
                shape ``(1, 2**n_wires, 2**n_wires)``. It is copied as given;
                no unitarity check is performed. If ``None``, a random unitary
                matrix is generated instead.

        Returns:
            None.
        """
        if init_params is not None:
            # Previously this argument was silently ignored; honor it now.
            self.params.data.copy_(
                torch.as_tensor(init_params, dtype=self.params.dtype)
            )
            return

        mat = torch.randn((1, 2**self.n_wires, 2**self.n_wires), dtype=C_DTYPE)
        # For mat = U diag(S) Vh, the product U @ Vh is unitary: it is the
        # unitary factor of mat's polar decomposition. torch.linalg.svd
        # returns Vh (the conjugate transpose) directly and replaces the
        # deprecated torch.svd.
        u, _, vh = torch.linalg.svd(mat)
        self.params.data.copy_(u.matmul(vh))

    @classmethod
    def _matrix(cls, params):
        """Return the matrix representation computed from ``params``.

        Args:
            params (torch.Tensor): The gate's parameter tensor.

        Returns:
            The result of ``tqf.qubitunitaryfast(params)``.
        """
        # A classmethod (rather than a staticmethod with an explicit ``self``)
        # so that both ``cls._matrix(params)`` and ``self._matrix(params)``
        # bind ``params`` to the right argument.
        # NOTE(review): ``tqf.qubitunitaryfast`` is also this gate's apply
        # ``func``; confirm it accepts a bare matrix as its only argument.
        return tqf.qubitunitaryfast(params)
class TrainableUnitaryStrict(TrainableUnitary, metaclass=ABCMeta):
    """Class for Strict Unitary matrix gate.

    Reuses the parameter construction and initialization inherited from
    :class:`TrainableUnitary`. The only difference is that the gate is applied
    with ``tqf.qubitunitarystrict`` rather than ``tqf.qubitunitaryfast``.
    """

    # Identifier and apply-function specific to the strict variant.
    op_name = "trainableunitarystrict"
    func = staticmethod(tqf.qubitunitarystrict)

    # Same arity declarations as the parent: any number of params and wires.
    num_params = AnyNParams
    num_wires = AnyWires