Skip to content

Commit

Permalink
chore: result plots moved in classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Henri-ColibrITD committed May 16, 2024
1 parent d8f859f commit 23c4262
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 112 deletions.
20 changes: 3 additions & 17 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,9 @@ Tools

Some additional tools are provided with our library. Even though most of them
are geared at internal usage, they are all presented here. Amongst them, the
ones most probable of being of use for you are in:

- the section :ref:`viz` presents visualization tools for several data types.
They might be integrated in these types if they are popular enough.
- the section :ref:`math` presents mathematical tools for linear algebra,
functions generalized to more data types, etc...

.. _viz:

Visualization
-------------

.. code-block::python
from mpqp.tools.visualization import *
.. automodule:: mpqp.tools.visualization
ones most probable of being of use for you are in the :ref:`math` section,
presenting mathematical tools for linear algebra, functions generalized to more
data types, etc...

.. _math:

Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/bell_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from mpqp.execution.devices import AWSDevice, IBMDevice
from mpqp.gates import CNOT, H
from mpqp.measures import BasisMeasure
from mpqp.tools.visualization import plot_results_sample_mode

# Declaration of the circuit with the right size
circuit = QCircuit(2, label="Bell pair")
Expand All @@ -18,6 +17,6 @@
results = run(circuit, [IBMDevice.AER_SIMULATOR, AWSDevice.BRAKET_LOCAL_SIMULATOR])
print(results)

plot_results_sample_mode(results)
results.plot(show=False)
circuit.display()
plt.show()
9 changes: 2 additions & 7 deletions examples/scripts/cirq_experiments.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# %%
import matplotlib.pyplot as plt

from mpqp import QCircuit
from mpqp.core.languages import Language
from mpqp.execution import run
from mpqp.execution.connection.google_connection import config_ionq_account
from mpqp.execution.devices import GOOGLEDevice, IBMDevice
from mpqp.gates import H, Rx, Ry, Rz
from mpqp.measures import BasisMeasure
from mpqp.tools.visualization import plot_results_sample_mode

# %%
circuit = QCircuit(3)
Expand All @@ -34,8 +31,7 @@
)
print(results)

plot_results_sample_mode(results)
plt.show()
results.plot()

# %%
cirq_circuit = circuit.to_other_language(Language.CIRQ)
Expand All @@ -54,5 +50,4 @@
results = run(circuit, [GOOGLEDevice.IONQ_SIMULATOR])
print(results)

plot_results_sample_mode(results)
plt.show()
results.plot()
5 changes: 1 addition & 4 deletions examples/scripts/demonstration.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Demonstration MPQP"""

import matplotlib.pyplot as plt
import numpy as np

from mpqp import QCircuit
from mpqp.execution import run
from mpqp.execution.devices import ATOSDevice, AWSDevice, GOOGLEDevice, IBMDevice
from mpqp.gates import *
from mpqp.measures import BasisMeasure
from mpqp.tools.visualization import plot_results_sample_mode

# Constructing the circuit
meas = BasisMeasure(list(range(3)), shots=2000)
Expand All @@ -31,8 +29,7 @@
print(results)

# Display the circuit
plot_results_sample_mode(results)
plt.show()
results.plot()

c = QCircuit([T(0), CNOT(0, 1), Ry(np.pi / 2, 2), S(1), CZ(2, 1), SWAP(2, 0)])
res = run(
Expand Down
72 changes: 72 additions & 0 deletions mpqp/execution/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
from __future__ import annotations

import math
import random
from numbers import Complex
from typing import Any, Optional

import numpy as np
import numpy.typing as npt
from matplotlib import pyplot as plt
from typeguard import typechecked

from mpqp.execution.devices import AvailableDevice
Expand Down Expand Up @@ -385,6 +387,50 @@ def __str__(self):
f"Job type {self.job.job_type} not implemented for __str__ method"
)

def plot(self, show: bool = True):
"""Extract sampling info from the result and construct the bar diagram
plot.
Args:
show: ``plt.show()`` is only executed if ``show``, useful to batch
plots.
"""
if show:
plt.figure()

x_array, y_array = self._to_display_lists()
x_axis = range(len(x_array))

plt.bar(x_axis, y_array, color=(*[random.random() for _ in range(3)], 0.9))
plt.xticks(x_axis, x_array, rotation=-60)
plt.xlabel("State")
plt.ylabel("Counts")
device = self.job.device
plt.title(type(device).__name__ + "\n" + device.name)

if show:
plt.show()

def _to_display_lists(self) -> tuple[list[str], list[int]]:
"""Transform a result into an x and y array containing the string of
basis state with the associated counts.
Returns:
The list of each basis state and the corresponding counts.
"""
if self.job.job_type != JobType.SAMPLE:
raise NotImplementedError(
f"{self.job.job_type} not handled, only {JobType.SAMPLE} is handled for now."
)
if self.job.measure is None:
raise ValueError(
f"{self.job=} has no measure, making the counting impossible"
)
n = self.job.measure.nb_qubits
x_array = [f"|{bin(i)[2:].zfill(n)}⟩" for i in range(2**n)]
y_array = self.counts
return x_array, y_array


@typechecked
class BatchResult:
Expand Down Expand Up @@ -454,6 +500,32 @@ def __repr__(self):
def __getitem__(self, index: int):
return self.results[index]

def plot(self, show: bool = True):
"""Display the result(s) using ``matplotlib.pyplot``.
The result(s) must be from a job who's ``job_type`` is ``SAMPLE``. They will
be displayed as histograms.
If a ``BatchResult`` is given, the contained results will be displayed in a
grid using subplots.
Args:
show: ``plt.show()`` is only executed if ``show``, useful to batch
plots.
"""
n_cols = math.ceil((len(self.results) + 1) // 2)
n_rows = math.ceil(len(self.results) / n_cols)

for index, result in enumerate(self.results):
plt.subplot(n_rows, n_cols, index + 1)

result.plot(show=False)

plt.tight_layout()

if show:
plt.show()


def clean_array(array: npt.NDArray[Any]) -> str:
"""TODO: doc"""
Expand Down
82 changes: 0 additions & 82 deletions mpqp/tools/visualization.py

This file was deleted.

0 comments on commit 23c4262

Please sign in to comment.