Skip to content

Sampling sensor data

queso.sample.circuit.sample_circuit(io, config, key, plot=False, progress=True)

Samples a quantum circuit based on the provided configuration and random key.

This function initializes a sensor with the given configuration, samples the circuit, and optionally plots the results. The sampled data is saved in an HDF5 file.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `io` | `IO` | An IO object for handling file operations. | required |
| `config` | `Configuration` | A Configuration object containing the parameters for the circuit. | required |
| `key` | `PRNGKey` | A random key for JAX operations. | required |
| `plot` | `bool` | If True, plots the true probabilities and relative frequencies. Defaults to False. | `False` |
| `progress` | `bool` | If True, displays progress information. Defaults to True. | `True` |

Returns:

Type Description

None

Source code in queso/sample/circuit.py
def sample_circuit(
    io: IO,
    config: Configuration,
    key: jax.random.PRNGKey,
    plot: bool = False,
    progress: bool = True,
):
    """
    Samples a quantum circuit based on the provided configuration and random key.

    The function builds a ``Sensor`` from ``config``, loads the ``theta`` and ``mu``
    datasets from ``circ.h5`` under ``io.path``, samples ``config.n_shots`` shots at
    ``config.n_phis`` evenly spaced phases spanning ``config.phi_range`` (both end
    points included), and writes ``probs``, ``shots``, ``counts`` and ``phis`` to
    ``train_samples.h5`` under ``io.path``.

    The JAX default device is set from the ``DEFAULT_DEVICE_SAMPLE_CIRC`` environment
    variable (``"cpu"`` when unset) before any work is done.

    Args:
        io (IO): An IO object for handling file operations.
        config (Configuration): A Configuration object containing the parameters for the circuit.
        key (jax.random.PRNGKey): A random key for JAX operations.
        plot (bool, optional): If True, plots the true probabilities and relative frequencies. Defaults to False.
        progress (bool, optional): Intended to toggle progress information. Defaults to True.
            NOTE(review): this flag is not consulted by the body; sampling is always
            run with ``verbose=True``.

    Returns:
        None
    """
    jax.config.update(
        "jax_default_device",
        jax.devices(os.getenv("DEFAULT_DEVICE_SAMPLE_CIRC", "cpu"))[0],
    )

    n = config.n
    k = config.k
    phi_range = config.phi_range
    n_phis = config.n_phis
    n_shots = config.n_shots
    kwargs = dict(
        preparation=config.preparation,
        interaction=config.interaction,
        detection=config.detection,
        backend=config.backend,
        n_ancilla=config.n_ancilla,
        gamma_dephasing=config.gamma_dephasing,
    )

    # %%
    print(f"Initializing sensor n={n}, k={k}")
    sensor = Sensor(n, k, **kwargs)

    # %%
    # The context manager guarantees the file handle is released even if a
    # dataset cannot be read; the arrays are materialized before it closes.
    with h5py.File(io.path.joinpath("circ.h5"), "r") as hf:
        theta = jnp.array(hf.get("theta"))
        mu = jnp.array(hf.get("mu"))

    # %% training data set
    print(
        f"Sampling {n_shots} shots for {n_phis} phase value between {phi_range[0]} and {phi_range[1]}."
    )
    # Evenly spaced grid including both end points. The denominator is clamped
    # so that a single-phase request yields [phi_range[0]] instead of 0/0.
    phis = (phi_range[1] - phi_range[0]) * jnp.arange(n_phis) / max(
        n_phis - 1, 1
    ) + phi_range[0]
    t0 = time.time()
    shots, probs = sensor.sample_over_phases(
        theta, phis, mu, n_shots=n_shots, verbose=True, key=key
    )
    t1 = time.time()
    print(f"Sampling took {t1 - t0} seconds.")

    # %%
    outcomes = sample_bin2int(shots, n)
    # One column per measurement outcome 0 .. 2**n - 1, counting occurrences
    # along axis 1. atleast_1d keeps a length-1 leading axis alive when there is
    # only one row, where squeeze() would otherwise produce a 0-d array that
    # cannot be stacked along axis=1.
    counts = jnp.stack(
        [
            jnp.atleast_1d(
                jnp.count_nonzero(outcomes == x, axis=(1,), keepdims=True).squeeze()
            )
            for x in range(2**n)
        ],
        axis=1,
    )
    freqs = counts / counts.sum(axis=-1, keepdims=True)
    bit_strings = sample_int2bin(jnp.arange(2**n), n)

    # %%
    if plot:
        # %%
        fig, axs = plt.subplots(nrows=2)
        sns.heatmap(probs, ax=axs[0], cbar_kws={"label": "True Probs."})
        sns.heatmap(freqs, ax=axs[1], cbar_kws={"label": "Rel. Freqs."})
        plt.show()
        io.save_figure(fig, filename="probs_freqs.png")

        colors = sns.color_palette("deep", n_colors=bit_strings.shape[0])
        fig, ax = plt.subplots()
        # Plot against the phases that were actually sampled. The grid above
        # includes the upper end point, so rebuilding the axis with
        # endpoint=False would shift every curve relative to its true phase.
        # Assumes probs/freqs have one row per entry of phis.
        xdata = phis
        for i in range(bit_strings.shape[0]):
            ax.plot(xdata, freqs[:, i], color=colors[i], ls="--", alpha=0.3)
        io.save_figure(fig, filename="liklihoods.png")

    # %%
    with h5py.File(io.path.joinpath("train_samples.h5"), "w") as hf:
        hf.create_dataset("probs", data=probs)
        hf.create_dataset("shots", data=shots)
        hf.create_dataset("counts", data=counts)
        hf.create_dataset("phis", data=phis)

    print("Finished sampling the circuits.")
    return

queso.sample.circuit_test.sample_circuit_testing(io, config, key, plot=False, progress=True)

Samples a quantum circuit for testing based on the provided configuration and random key.

This function initializes a sensor with the given configuration, samples the circuit, and saves the sampled data in an HDF5 file. The data is for testing estimators.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `io` | `IO` | An IO object for handling file operations. | required |
| `config` | `Configuration` | A Configuration object containing the parameters for the circuit. | required |
| `key` | `PRNGKey` | A random key for JAX operations. | required |
| `plot` | `bool` | If True, plots the true probabilities and relative frequencies. Defaults to False. | `False` |
| `progress` | `bool` | If True, displays progress information. Defaults to True. | `True` |

Returns:

Type Description

None

Source code in queso/sample/circuit_test.py
def sample_circuit_testing(
    io: IO,
    config: Configuration,
    key: jax.random.PRNGKey,
    plot: bool = False,
    progress: bool = True,
):
    """
    Samples a quantum circuit for testing based on the provided configuration and random key.

    The function builds a ``Sensor`` from ``config``, loads the ``theta`` and ``mu``
    datasets from ``circ.h5`` under ``io.path``, samples ``config.n_shots_test`` shots
    at each phase in ``config.phis_test``, and writes ``probs_test``, ``shots_test``
    and ``phis_test`` to ``test_samples.h5`` under ``io.path``. The data is for
    testing estimators.

    The JAX default device is set from the ``DEFAULT_DEVICE_SAMPLE_CIRC`` environment
    variable (``"cpu"`` when unset) before the sensor is created.

    Args:
        io (IO): An IO object for handling file operations.
        config (Configuration): A Configuration object containing the parameters for the circuit.
        key (jax.random.PRNGKey): A random key for JAX operations.
        plot (bool, optional): Accepted for signature parity with ``sample_circuit``. Defaults to False.
            NOTE(review): this flag is not consulted by the body; nothing is plotted.
        progress (bool, optional): Intended to toggle progress information. Defaults to True.
            NOTE(review): this flag is not consulted by the body; sampling is always
            run with ``verbose=True``.

    Returns:
        None
    """

    n = config.n
    k = config.k
    phis_test = jnp.array(config.phis_test)
    n_shots_test = config.n_shots_test
    kwargs = dict(
        preparation=config.preparation,
        interaction=config.interaction,
        detection=config.detection,
        backend=config.backend,
        n_ancilla=config.n_ancilla,
        gamma_dephasing=config.gamma_dephasing,
    )
    jax.config.update(
        "jax_default_device",
        jax.devices(os.getenv("DEFAULT_DEVICE_SAMPLE_CIRC", "cpu"))[0],
    )

    # %%
    print(f"Initializing sensor n={n}, k={k}")
    sensor = Sensor(n, k, **kwargs)

    # %%
    # The context manager guarantees the file handle is released even if a
    # dataset cannot be read; the arrays are materialized before it closes.
    with h5py.File(io.path.joinpath("circ.h5"), "r") as hf:
        theta = jnp.array(hf.get("theta"))
        mu = jnp.array(hf.get("mu"))

    # %% testing samples
    print(f"Sampling {n_shots_test} shots for {phis_test}.")
    t0 = time.time()
    shots_test, probs_test = sensor.sample_over_phases(
        theta, phis_test, mu, n_shots=n_shots_test, verbose=True, key=key
    )
    t1 = time.time()
    print(f"Sampling took {t1 - t0} seconds.")

    # %%
    # Written inside a context manager so a failed dataset write cannot leave
    # the output file open.
    with h5py.File(io.path.joinpath("test_samples.h5"), "w") as hf:
        hf.create_dataset("probs_test", data=probs_test)
        hf.create_dataset("shots_test", data=shots_test)
        hf.create_dataset("phis_test", data=phis_test)

    # %%
    print("Finished sampling the circuits for test data.")

    return