Skip to content

Commit

Permalink
Always return 1D vector from final_state_vector (quantumlib#5793)
Browse files Browse the repository at this point in the history
Currently there is a special case for 0-qubit tensors, which causes the state vector to be returned as a scalar instead of a 1D vector of length 2**num_qubits == 1.

Also change the `final_state_vector` property to use `_compat.cached_property`.
  • Loading branch information
maffoo authored Jul 18, 2022
1 parent e362cb6 commit 9f5e234
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
3 changes: 1 addition & 2 deletions cirq-core/cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

if TYPE_CHECKING:
import cirq
from numpy.typing import DTypeLike

CIRCUIT_LIKE = Union[circuits.Circuit, ops.Gate, ops.OP_TREE]
document(
Expand Down Expand Up @@ -112,7 +111,7 @@ def final_state_vector(
ignore_terminal_measurements: bool = False,
dtype: Type[np.complexfloating] = np.complex64,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
) -> 'np.ndarray':
) -> np.ndarray:
"""Returns the state vector resulting from acting operations on a state.
By default the input state is the computational basis zero state, in which
Expand Down
17 changes: 5 additions & 12 deletions cirq-core/cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
"""Abstract classes for simulations which keep track of state vector."""

import abc
from typing import Any, Dict, Iterator, Sequence, Type, TYPE_CHECKING, Generic, TypeVar, Optional
from typing import Any, Dict, Iterator, Sequence, Type, TYPE_CHECKING, Generic, TypeVar

import numpy as np

from cirq import ops, value, qis
from cirq._compat import proper_repr
from cirq import _compat, ops, value, qis
from cirq.sim import simulator, state_vector, simulator_base

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,16 +118,10 @@ def __init__(
final_simulator_state=final_simulator_state,
qubit_map=final_simulator_state.qubit_map,
)
self._final_state_vector: Optional[np.ndarray] = None

@property
@_compat.cached_property
def final_state_vector(self) -> np.ndarray:
if self._final_state_vector is None:
tensor = self._get_merged_sim_state().target_tensor
if tensor.ndim > 1:
tensor = tensor.reshape(-1)
self._final_state_vector = tensor
return self._final_state_vector
return self._get_merged_sim_state().target_tensor.reshape(-1)

def state_vector(self, copy: bool = False) -> np.ndarray:
"""Return the state vector at the end of the computation.
Expand Down Expand Up @@ -196,6 +189,6 @@ def _repr_pretty_(self, p: Any, cycle: bool):
def __repr__(self) -> str:
return (
'cirq.StateVectorTrialResult('
f'params={self.params!r}, measurements={proper_repr(self.measurements)}, '
f'params={self.params!r}, measurements={_compat.proper_repr(self.measurements)}, '
f'final_simulator_state={self._final_simulator_state!r})'
)
12 changes: 12 additions & 0 deletions cirq-core/cirq/sim/state_vector_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,18 @@ def test_state_vector_trial_state_vector_is_copy():
assert trial_result.state_vector(copy=True) is not final_simulator_state.target_tensor


def test_state_vector_trial_result_no_qubits():
initial_state_vector = np.array([1], dtype=np.complex64)
initial_state = initial_state_vector.reshape((2,) * 0) # reshape as tensor for 0 qubits
final_simulator_state = cirq.StateVectorSimulationState(qubits=[], initial_state=initial_state)
trial_result = cirq.StateVectorTrialResult(
params=cirq.ParamResolver({}), measurements={}, final_simulator_state=final_simulator_state
)
state_vector = trial_result.state_vector()
assert state_vector.shape == (1,)
assert np.array_equal(state_vector, initial_state_vector)


def test_str_big():
qs = cirq.LineQubit.range(10)
final_simulator_state = cirq.StateVectorSimulationState(
Expand Down

0 comments on commit 9f5e234

Please sign in to comment.