Skip to content

Commit

Permalink
Execute TODO to have the prob mass auto added (quantumlib#3363)
Browse files Browse the repository at this point in the history
  • Loading branch information
tonybruguier authored Sep 29, 2020
1 parent fa99585 commit bba0153
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
36 changes: 22 additions & 14 deletions cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(self,
p_x: Optional[float] = None,
p_y: Optional[float] = None,
p_z: Optional[float] = None,
error_probabilities: Optional[Dict[str, float]] = None
) -> None:
error_probabilities: Optional[Dict[str, float]] = None,
tol: float = 1e-8) -> None:
r"""The asymmetric depolarizing channel.
This channel applies one of 4**n disjoint possibilities: nothing (the
Expand All @@ -58,7 +58,10 @@ def __init__(self,
p_y: The probability that a Pauli Y and no other gate occurs.
p_z: The probability that a Pauli Z and no other gate occurs.
error_probabilities: Dictionary of string (Pauli operator) to its
probability
probability. If the identity is missing from the list, it will
be added so that the total probability mass is 1.
tol: The tolerance used making sure the total probability mass is
equal to 1.
Examples of calls:
* Single qubit: AsymmetricDepolarizingChannel(0.2, 0.1, 0.3)
Expand All @@ -79,10 +82,10 @@ def __init__(self,
for k, v in error_probabilities.items():
value.validate_probability(v, f"p({k})")
sum_probs = sum(error_probabilities.values())
# TODO(tonybruguier): Instead of forcing the probabilities to add up
# to 1, check whether the identity is missing, and if that is the
# case, automatically add it with the missing probability mass.
if abs(sum_probs - 1.0) > 1e-6:
identity = 'I' * num_qubits
if sum_probs < 1.0 - tol and identity not in error_probabilities:
error_probabilities[identity] = 1.0 - sum_probs
elif abs(sum_probs - 1.0) > tol:
raise ValueError(
f"Probabilities do not add up to 1 but to {sum_probs}")
self._num_qubits = num_qubits
Expand Down Expand Up @@ -190,11 +193,12 @@ def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['error_probabilities'])


def asymmetric_depolarize(p_x: Optional[float] = None,
p_y: Optional[float] = None,
p_z: Optional[float] = None,
error_probabilities: Optional[Dict[str, float]] = None
) -> AsymmetricDepolarizingChannel:
def asymmetric_depolarize(
p_x: Optional[float] = None,
p_y: Optional[float] = None,
p_z: Optional[float] = None,
error_probabilities: Optional[Dict[str, float]] = None,
tol: float = 1e-8) -> AsymmetricDepolarizingChannel:
r"""Returns a AsymmetricDepolarizingChannel with given parameter.
This channel applies one of 4**n disjoint possibilities: nothing (the
Expand All @@ -215,7 +219,10 @@ def asymmetric_depolarize(p_x: Optional[float] = None,
p_y: The probability that a Pauli Y and no other gate occurs.
p_z: The probability that a Pauli Z and no other gate occurs.
error_probabilities: Dictionary of string (Pauli operator) to its
probability
probability. If the identity is missing from the list, it will
be added so that the total probability mass is 1.
tol: The tolerance used making sure the total probability mass is
equal to 1.
Examples of calls:
* Single qubit: AsymmetricDepolarizingChannel(0.2, 0.1, 0.3)
Expand All @@ -226,7 +233,8 @@ def asymmetric_depolarize(p_x: Optional[float] = None,
Raises:
ValueError: if the args or the sum of the args are not probabilities.
"""
return AsymmetricDepolarizingChannel(p_x, p_y, p_z, error_probabilities)
return AsymmetricDepolarizingChannel(p_x, p_y, p_z, error_probabilities,
tol)


@value.value_equality
Expand Down
7 changes: 7 additions & 0 deletions cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,13 @@ def test_bad_probs():
cirq.asymmetric_depolarize(error_probabilities={'X': 0.7, 'Y': 0.6})


def test_missing_prob_mass():
with pytest.raises(ValueError, match='Probabilities do not add up to 1'):
cirq.asymmetric_depolarize(error_probabilities={'X': 0.1, 'I': 0.2})
d = cirq.asymmetric_depolarize(error_probabilities={'X': 0.1})
np.testing.assert_almost_equal(d.error_probabilities['I'], 0.9)


def test_multi_asymmetric_depolarizing_channel():
d = cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})
np.testing.assert_almost_equal(
Expand Down

0 comments on commit bba0153

Please sign in to comment.