Skip to content

Sensing circuits

queso.sensors.tc.sensor.Sensor

The Sensor class represents a quantum sensor. It is initialized with the number of qubits (n) and the number of layers (k) in the quantum circuit.

The class provides methods for creating the quantum circuit, calculating the quantum state, probabilities, and quantum Fisher information (QFI), sampling measurements, and more.

Attributes:

Name Type Description
n int

The number of qubits in the quantum circuit.

k int

The number of layers in the quantum circuit.

preparation function

The function to prepare the quantum state.

interaction function

The function to apply the interaction Hamiltonian.

detection function

The function to apply the detection Hamiltonian.

theta ndarray

The parameters for the preparation function.

phi float

The parameter for the interaction function.

mu ndarray

The parameters for the detection function.

layers dict

A dictionary containing the names of the preparation, interaction, and detection layers.

Methods:

Name Description
circuit

Returns the quantum circuit.

state

Returns the quantum state after the preparation and interaction layers (the detection layer is not applied).

probs

Returns the probabilities computed from the full circuit (preparation, interaction, and detection layers).

sample

Returns samples of measurements.

qfi

Returns the quantum Fisher information.

cfi

Returns the classical Fisher information.

entanglement

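Returns the entropy of the reduced density matrix of the prepared state, with the cut given by the first n // 2 qubits (the interaction layer is not applied).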

sample_over_phases

Returns samples of measurements over different phases.

Source code in queso/sensors/tc/sensor.py
class Sensor:
    """
    The Sensor class represents a quantum sensor. It is initialized with the number of qubits (n) and the number of layers (k) in the quantum circuit.

    The sensing circuit is built from three stages applied in order: a
    preparation layer (parameters `theta`), an interaction layer (parameter
    `phi`) and a detection layer (parameters `mu`). The class provides methods
    for creating the circuit, calculating the quantum state, probabilities,
    and Fisher information, sampling measurements, and more.

    Keyword Args:
        backend (str): "ket" (default) builds circuits with `tc.Circuit`,
            "dm" builds them with `tc.DMCircuit`. Any other value raises
            `ValueError`.
        preparation (str): Name of the preparation layer, default
            "hardware_efficient_ansatz".
        interaction (str): Name of the interaction layer, default "local_rz".
        detection (str): Name of the detection layer, default "local_r".
        Remaining keyword arguments are stored in `self.kwargs` and forwarded
        to `set_preparation`.

    Attributes:
        n (int): The number of qubits in the quantum circuit.
        k (int): The number of layers in the quantum circuit.
        kwargs (dict): The keyword arguments passed to the constructor.
        preparation (function): The function to prepare the quantum state.
        interaction (function): The function to apply the interaction Hamiltonian.
        detection (function): The function to apply the detection Hamiltonian.
        theta (jax.numpy.ndarray): The parameters for the preparation function.
        phi: The parameter for the interaction function. The methods below
            read `phi.shape` and call `phi.astype`, so an array-like value is
            expected.
        mu (jax.numpy.ndarray): The parameters for the detection function.
        layers (dict): A dictionary containing the names of the preparation, interaction, and detection layers.

    Methods:
        init_params(key=None): Returns random (theta, phi, mu) with the shapes of the stored parameters.
        circuit(theta, phi, mu): Returns the quantum circuit.
        state(theta, phi): Returns the quantum state after preparation and interaction.
        probs(theta, phi, mu): Returns the probabilities of the full circuit.
        sample(theta, phi, mu, key=None, n_shots=100, verbose=False): Returns samples of measurements.
        qfi(theta, phi): Returns the quantum Fisher information.
        cfi(theta, phi, mu): Returns the classical Fisher information.
        entanglement(theta, phi): Returns the entanglement entropy of the prepared state.
        sample_over_phases(theta, phis, mu, n_shots, key=None, verbose=False): Returns samples of measurements over different phases.
    """

    def __init__(
        self,
        n,
        k,
        **kwargs,
    ):
        self.n = n
        self.k = k
        self.kwargs = kwargs

        # Deliberately not called `backend`: `_sample` uses a module-level
        # `backend` object and the two must not be confused.
        circuit_backend = kwargs.get("backend", "ket")
        if circuit_backend == "ket":
            self._circ = tc.Circuit
        elif circuit_backend == "dm":
            self._circ = tc.DMCircuit
        else:
            raise ValueError(
                f"Unknown backend {circuit_backend!r}; expected 'ket' or 'dm'."
            )

        # tc.set_contractor(contractor)  # “auto”, “greedy”, “branch”, “plain”, “tng”, “custom”

        # default circuits
        preparation = kwargs.get("preparation", "hardware_efficient_ansatz")
        interaction = kwargs.get("interaction", "local_rz")
        detection = kwargs.get("detection", "local_r")

        self.preparation, self.theta = set_preparation(preparation, n, k, kwargs)
        self.interaction, self.phi = set_interaction(interaction)
        self.detection, self.mu = set_detection(detection, n, k)
        self.layers = dict(
            preparation=preparation, interaction=interaction, detection=detection
        )

    def init_params(self, key=None):
        """
        Draw uniform random parameters shaped like `theta`, `phi` and `mu`.

        Args:
            key: JAX PRNG key. If None, a key is seeded from `time.time_ns()`.

        Returns:
            tuple: `(theta, phi, mu)` arrays from `jax.random.uniform`, each
            drawn with its own subkey so the three draws are independent.
        """
        if key is None:
            key = jax.random.PRNGKey(time.time_ns())
        # One subkey per parameter set: reusing a single subkey would make the
        # three arrays share the same underlying random stream.
        theta_key, phi_key, mu_key = jax.random.split(key, 3)
        return (
            jax.random.uniform(theta_key, self.theta.shape),
            jax.random.uniform(phi_key, self.phi.shape),
            jax.random.uniform(mu_key, self.mu.shape),
        )

    def circuit(self, theta, phi, mu):
        """
        Build the full sensing circuit.

        Args:
            theta: Parameters passed to the preparation layer.
            phi: Parameter passed to the interaction layer.
            mu: Parameters passed to the detection layer.

        Returns:
            The circuit object after preparation, interaction and detection
            have been applied, in that order.
        """
        c = self._circ(self.n)
        c = self.preparation(c, theta, self.n, self.k)
        c = self.interaction(c, phi, self.n)
        c = self.detection(c, mu, self.n, self.k)
        return c

    @partial(jax.jit, static_argnums=(0,))
    def state(self, theta, phi):
        """
        Return the state after the preparation and interaction layers.

        The detection layer is not applied.

        Args:
            theta: Parameters passed to the preparation layer.
            phi: Parameter passed to the interaction layer.

        Returns:
            The result of `state()` on the partially built circuit.
        """
        c = self._circ(self.n)
        c = self.preparation(c, theta, self.n, self.k)
        c = self.interaction(c, phi, self.n)
        return c.state()

    @partial(jax.jit, static_argnums=(0,))
    def probs(self, theta, phi, mu):
        """
        Return the probabilities of the full circuit.

        Args:
            theta: Parameters passed to the preparation layer.
            phi: Parameter passed to the interaction layer.
            mu: Parameters passed to the detection layer.

        Returns:
            The result of `probability()` on the full circuit.
        """
        return self.circuit(theta, phi, mu).probability()

    @partial(jax.jit, static_argnums=(0,), backend="cpu")
    def _sample(self, theta, phi, mu, key):
        """
        Draw a single measurement shot of all `n` qubits.

        Jitted with `backend="cpu"`. The module-level `backend` object is
        seeded with `key` before measuring.

        Returns:
            The first element of the result of measuring qubits 0..n-1.
        """
        c = self.circuit(theta, phi, mu)

        backend.set_random_state(key)
        return c.measure(*list(range(self.n)))[0]

    # @partial(jax.jit, static_argnums=(0,))
    def sample(self, theta, phi, mu, key=None, n_shots=100, verbose=False):
        """
        Draw `n_shots` measurement shots, one `_sample` call per shot.

        Args:
            theta: Parameters passed to the preparation layer.
            phi: Parameter passed to the interaction layer.
            mu: Parameters passed to the detection layer.
            key: JAX PRNG key. If None, a key is seeded from `time.time_ns()`.
            n_shots (int): Number of shots; the key is split into this many subkeys.
            verbose (bool): Currently unused; kept for interface compatibility.

        Returns:
            jax.numpy.ndarray: The stacked shots cast to bool.
        """
        if key is None:
            key = jax.random.PRNGKey(time.time_ns())
        subkeys = jax.random.split(key, n_shots)
        shots = jnp.array(
            [self._sample(theta, phi, mu, subkey) for subkey in subkeys]
        ).astype("bool")
        return shots

    @partial(jax.jit, static_argnums=(0,))
    def qfi(self, theta, phi):
        """
        Return the quantum Fisher information with respect to `phi`.

        Evaluates 4 * Re(<dpsi|dpsi> - |<dpsi|psi>|^2), where psi is
        `self.state(theta, phi)` and dpsi is its Jacobian with respect to phi.

        Args:
            theta: Parameters passed to the preparation layer.
            phi: Parameter passed to the interaction layer.

        Returns:
            The Fisher information value after `squeeze()`.
        """
        psi = self.state(theta, phi)
        # holomorphic=True requires complex inputs, hence the casts.
        dpsi = jax.jacrev(self.state, argnums=1, holomorphic=True)(
            theta.astype("complex64"), phi.astype("complex64")
        )
        # NOTE(review): the row/column indexing below assumes `dpsi` is 1-D,
        # i.e. a scalar `phi` and a state-vector (not density-matrix) state
        # -- confirm for the "dm" backend.
        fi = (
            4
            * jnp.real(
                (
                    jnp.conj(dpsi[None, :]) @ dpsi[:, None]
                    - jnp.abs(jnp.conj(dpsi[None, :]) @ psi[:, None])**2
                )
            ).squeeze()
        )
        return fi

    @partial(jax.jit, static_argnums=(0,))
    def cfi(self, theta, phi, mu):
        """
        Return the classical Fisher information with respect to `phi`.

        Computes sum(dpr**2 / pr), where pr is `self.probs(theta, phi, mu)`
        and dpr is its Jacobian with respect to phi.

        Args:
            theta: Parameters passed to the preparation layer.
            phi: Parameter passed to the interaction layer.
            mu: Parameters passed to the detection layer.

        Returns:
            The summed Fisher information; NaN terms are skipped by `nansum`.
        """
        pr = self.probs(theta, phi, mu)
        dpr = jax.jacrev(self.probs, argnums=1, holomorphic=False)(theta, phi, mu)
        # fi = jnp.sum((jnp.power(dpr, 2) / pr))
        fi = jnp.nansum((jnp.power(dpr, 2) / pr))  # todo: check if removing nans helps/hurts numerical stability
        return fi

    @partial(jax.jit, static_argnums=(0,))
    def entanglement(self, theta, phi):
        """
        Return the entropy of a reduced density matrix of the prepared state.

        Only the preparation layer is applied; `phi` is accepted but unused.

        Args:
            theta: Parameters passed to the preparation layer.
            phi: Unused.

        Returns:
            The result of `tc.quantum.entropy` on the reduced density matrix
            obtained with the cut `[0, ..., n // 2 - 1]`.
        """
        # state = self.state(theta, phi)
        c = self._circ(self.n)
        c = self.preparation(c, theta, self.n, self.k)
        state = c.state()
        rho_A = tc.quantum.reduced_density_matrix(
            state, list(range(self.n // 2))
        )
        entropy = tc.quantum.entropy(rho_A)
        return entropy

    def sample_over_phases(self, theta, phis, mu, n_shots, key=None, verbose=False):
        """
        Sample measurement shots and compute probabilities for each phase.

        Args:
            theta: Parameters passed to the preparation layer.
            phis: Array of phase values; iterated along its first axis.
            mu: Parameters passed to the detection layer.
            n_shots (int): Number of shots drawn per phase.
            key: JAX PRNG key. If None, a key is seeded from `time.time_ns()`.
            verbose (bool): Forwarded to `sample`.

        Returns:
            tuple: `(data, probs)`, the shots and the probabilities stacked
            along a new leading axis, one entry per phase.
        """
        # Single throw-away shot, used to report where the samples live.
        check = self.sample(theta, 0.0, mu, key=key, n_shots=1, verbose=verbose)
        print(f"Sampling at φ = {phis}")
        # `Array.device()` is not available as a method in recent JAX
        # releases; prefer `devices()` and fall back for older versions.
        devices = getattr(check, "devices", None)
        print(devices() if callable(devices) else check.device())
        if key is None:
            key = jax.random.PRNGKey(time.time_ns())
        n_phis = phis.shape[0]
        keys = jax.random.split(key, n_phis)
        data = [
            self.sample(theta, phi, mu, key=subkey, n_shots=n_shots, verbose=verbose)
            for (phi, subkey) in tqdm(zip(phis, keys), total=n_phis)
        ]
        data = jnp.stack(data, axis=0)
        probs = jnp.stack([self.probs(theta, phi, mu) for phi in phis], axis=0)
        return data, probs