Skip to content

Commit

Permalink
Modularize PEC functionality (#2604)
Browse files Browse the repository at this point in the history
* remove old slow comment

* add circuit generation/combination functions

* add missing docstring args

* remove overload typehinting; better variable naming

* remove num_samples from output
  • Loading branch information
natestemen authored Dec 20, 2024
1 parent c4d6a44 commit 2dc8c4f
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 66 deletions.
8 changes: 7 additions & 1 deletion mitiq/pec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

from mitiq.pec.types import NoisyOperation, OperationRepresentation, NoisyBasis
from mitiq.pec.sampling import sample_sequence, sample_circuit
from mitiq.pec.pec import execute_with_pec, mitigate_executor, pec_decorator
from mitiq.pec.pec import (
execute_with_pec,
mitigate_executor,
pec_decorator,
combine_results,
generate_sampled_circuits,
)

from mitiq.pec.representations import (
represent_operation_with_global_depolarizing_noise,
Expand Down
129 changes: 99 additions & 30 deletions mitiq/pec/pec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Expand All @@ -37,6 +38,96 @@ class LargeSampleWarning(Warning):
)


def generate_sampled_circuits(
circuit: QPROGRAM,
representations: Sequence[OperationRepresentation],
precision: float = 0.03,
num_samples: int | None = None,
random_state: int | np.random.RandomState | None = None,
full_output: bool = False,
) -> list[QPROGRAM] | tuple[list[QPROGRAM], list[int], float]:
"""Generates a list of sampled circuits based on the given
quasi-probability representations.
Args:
circuit: The quantum circuit to be sampled.
representations: The quasi-probability representations of the circuit
operations.
precision: The desired precision for the sampling process.
Default is 0.03.
num_samples: The number of samples to generate. If None, the number of
samples is deduced based on the precision. Default is None.
random_state: The random state or seed for reproducibility.
full_output: If True, returns the signs and the norm along with the
sampled circuits. Default is False.
Returns:
A list of sampled circuits. If ``full_output`` is True, also returns a
list of signs, the norm.
Raises:
ValueError: If the precision is not within the interval (0, 1].
"""
if isinstance(random_state, int):
random_state = np.random.RandomState(random_state)

if not (0 < precision <= 1):
raise ValueError(
"The value of 'precision' should be within the interval (0, 1],"
f" but precision is {precision}."
)

# Get the 1-norm of the circuit quasi-probability representation
_, _, norm = sample_circuit(
circuit,
representations,
num_samples=1,
)

# Deduce the number of samples (if not given by the user)
if num_samples is None:
num_samples = int((norm / precision) ** 2)

if num_samples > 10**5:
warnings.warn(_LARGE_SAMPLE_WARN, LargeSampleWarning)

sampled_circuits, signs, _ = sample_circuit(
circuit,
representations,
random_state=random_state,
num_samples=num_samples,
)

if full_output:
return sampled_circuits, signs, norm
return sampled_circuits


def combine_results(
results: Iterable[float], norm: float, signs: Iterable[int]
) -> float:
"""Combine expectation values coming from probabilistically sampled
circuits.
Warning:
The ``results`` must be in the same order as the circuits were
generated.
Args:
results: Results as obtained from running circuits.
norm: The one-norm of the circuit representation.
signs: The signs corresponding to the positivity of the sampled
circuits.
Returns:
The PEC estimate of the expectation value.
"""
unbiased_estimators = [norm * s * val for s, val in zip(signs, results)]

pec_value = cast(float, np.average(unbiased_estimators))
return pec_value


def execute_with_pec(
circuit: QPROGRAM,
executor: Union[Executor, Callable[[QPROGRAM], QuantumResult]],
Expand Down Expand Up @@ -95,39 +186,16 @@ def execute_with_pec(
The error is estimated as ``pec_std / sqrt(num_samples)``, where
``pec_std`` is the standard deviation of the PEC samples, i.e., the
square root of the mean squared deviation of the sampled values from
``pec_value``. If ``full_output`` is ``True``, only ``pec_value`` is
``pec_value``. If ``full_output`` is ``False``, only ``pec_value`` is
returned.
"""
if isinstance(random_state, int):
random_state = np.random.RandomState(random_state)

if not (0 < precision <= 1):
raise ValueError(
"The value of 'precision' should be within the interval (0, 1],"
f" but precision is {precision}."
)

# Get the 1-norm of the circuit quasi-probability representation
_, _, norm = sample_circuit(
circuit,
representations,
num_samples=1,
)

# Deduce the number of samples (if not given by the user)
if not isinstance(num_samples, int):
num_samples = int((norm / precision) ** 2)

# Issue warning for very large sample size
if num_samples > 10**5:
warnings.warn(_LARGE_SAMPLE_WARN, LargeSampleWarning)

# Sample all the circuits
sampled_circuits, signs, _ = sample_circuit(
sampled_circuits, signs, norm = generate_sampled_circuits(
circuit,
representations,
precision,
num_samples,
random_state=random_state,
num_samples=num_samples,
full_output=True,
)

# Execute all sampled circuits
Expand All @@ -144,12 +212,13 @@ def execute_with_pec(
if not full_output:
return pec_value

num_circuits = len(sampled_circuits)
# Build dictionary with additional results and data
pec_data: Dict[str, Any] = {
"num_samples": num_samples,
"num_samples": num_circuits,
"precision": precision,
"pec_value": pec_value,
"pec_error": np.std(unbiased_estimators) / np.sqrt(num_samples),
"pec_error": np.std(unbiased_estimators) / np.sqrt(num_circuits),
"unbiased_estimators": unbiased_estimators,
"measured_expectation_values": results,
"sampled_circuits": sampled_circuits,
Expand Down
77 changes: 42 additions & 35 deletions mitiq/pec/tests/test_pec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

"""Unit tests for PEC."""

import warnings
from functools import partial
from typing import List, Optional
from unittest.mock import patch

import cirq
import numpy as np
Expand All @@ -21,10 +21,13 @@
from mitiq.pec import (
NoisyOperation,
OperationRepresentation,
combine_results,
execute_with_pec,
generate_sampled_circuits,
mitigate_executor,
pec_decorator,
)
from mitiq.pec.pec import LargeSampleWarning
from mitiq.pec.representations import (
represent_operations_in_circuit_with_local_depolarizing_noise,
)
Expand Down Expand Up @@ -258,8 +261,6 @@ def test_execute_with_pec_mitigates_noise(circuit, executor, circuit_type):
base_noise=BASE_NOISE,
qubits=[cirq.NamedQubit(name) for name in ("q_0", "q_1")],
)
# TODO: PEC with Qiskit is slow.
# See https://github.com/unitaryfund/mitiq/issues/507.
circuit, _ = convert_to_mitiq(circuit)
else:
reps = pauli_representations
Expand Down Expand Up @@ -366,64 +367,61 @@ def test_execute_with_pec_error_scaling(num_samples: int):


@pytest.mark.parametrize("precision", [0.2, 0.1])
def test_precision_option_in_execute_with_pec(precision: float):
def test_precision_option_used_in_num_samples(precision):
"""Tests that the 'precision' argument is used to deduce num_samples."""
# For a noiseless circuit we expect num_samples = 1/precision^2:
_, pec_data = execute_with_pec(
circuits, _, _ = generate_sampled_circuits(
oneq_circ,
partial(fake_executor, random_state=np.random.RandomState(0)),
representations=pauli_representations,
precision=precision,
force_run_all=True,
full_output=True,
random_state=1,
)
# The error should scale as precision
assert np.isclose(pec_data["pec_error"] / precision, 1.0, atol=0.15)
num_circuits = len(circuits)
# we expect num_samples = 1/precision^2:
assert np.isclose(precision**2 * num_circuits, 1, atol=0.2)

# Check precision is ignored when num_samples is given.
num_samples = 1
_, pec_data = execute_with_pec(

def test_precision_ignored_when_num_samples_present():
"""Check precision is ignored when num_samples is given."""
num_expected_circuits = 123
circuits, _, _ = generate_sampled_circuits(
oneq_circ,
partial(fake_executor, random_state=np.random.RandomState(0)),
representations=pauli_representations,
precision=precision,
num_samples=num_samples,
force_run_all=False,
precision=0.1,
num_samples=num_expected_circuits,
full_output=True,
random_state=1,
)
assert pec_data["num_samples"] == num_samples
num_circuits = len(circuits)
assert num_circuits == num_expected_circuits


@pytest.mark.parametrize("bad_value", (0, -1, 2))
def test_bad_precision_argument(bad_value: float):
def test_bad_precision_argument(bad_value):
"""Tests that if 'precision' is not within (0, 1] an error is raised."""
with pytest.raises(ValueError, match="The value of 'precision' should"):
execute_with_pec(
generate_sampled_circuits(
oneq_circ,
serial_executor,
representations=pauli_representations,
precision=bad_value,
)


def test_large_sample_size_warning():
"""Tests whether a warning is raised when PEC sample size
is greater than 10 ** 5.
"""
warnings.simplefilter("error")
with pytest.raises(
Warning,
match="The number of PEC samples is very large.",
):
execute_with_pec(
@patch("mitiq.pec.pec.sample_circuit")
def test_large_sample_size_warning(mock_sample_circuit):
"""Ensure a warning is raised when sample size is greater than 100k."""

mock_sample_circuit.return_value = ([], [], 0.911)

with pytest.warns(LargeSampleWarning):
generate_sampled_circuits(
oneq_circ,
partial(fake_executor, random_state=np.random.RandomState(0)),
representations=pauli_representations,
num_samples=100001,
force_run_all=False,
representations=[],
num_samples=100_001,
)

assert mock_sample_circuit.call_count == 2


def test_pec_data_with_full_output():
"""Tests that execute_with_pec mitigates the error of a noisy
Expand Down Expand Up @@ -649,3 +647,12 @@ def type_detecting_executor(circuit: QPROGRAM):
num_samples=1,
)
assert np.isclose(mitigated, 0.0)


def test_combining_results():
"""simple arithmetic test"""
results = [0.1, 0.2, 0.3]
norm = 23
signs = [1, -1, 1]
pec_estimate = combine_results(results, norm, signs)
assert np.isclose(pec_estimate, 1.53, atol=0.01)

0 comments on commit 2dc8c4f

Please sign in to comment.