Skip to content

Commit

Permalink
[minor] tq2qiskit_op_history
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanrui-Wang committed Apr 12, 2023
1 parent bbd229b commit 17a2a85
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 15 deletions.
33 changes: 24 additions & 9 deletions examples/simple_vqe/new_simple_vqe.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import torchquantum as tq
import torch
from torchquantum.vqe_utils import parse_hamiltonian_file
import random
import numpy as np
import argparse
import torch.optim as optim

from torch.optim.lr_scheduler import CosineAnnealingLR
from torchquantum.measurement import expval_joint_analytical

from torchquantum.algorithms import VQE, Hamiltonian
from qiskit import QuantumCircuit

from torchquantum.plugins import qiskit2tq_op_history

if __name__ == "__main__":
hamil = Hamiltonian.from_file("./h2.txt")
Expand All @@ -24,6 +18,27 @@
{'name': 'cu3', 'wires': [0, 1], 'trainable': True},
{'name': 'cu3', 'wires': [1, 0], 'trainable': True},
]

# or alternatively, you can use the following code to generate the ops
circ = QuantumCircuit(2)
circ.h(0)
circ.rx(0.1, 1)
circ.cx(0, 1)
circ.u(0.1, 0.2, 0.3, 0)
circ.u(0.1, 0.2, 0.3, 0)
circ.cx(1, 0)
circ.u(0.1, 0.2, 0.3, 0)
circ.u(0.1, 0.2, 0.3, 0)
circ.cx(0, 1)
circ.u(0.1, 0.2, 0.3, 0)
circ.u(0.1, 0.2, 0.3, 0)
circ.cx(1, 0)
circ.u(0.1, 0.2, 0.3, 0)
circ.u(0.1, 0.2, 0.3, 0)

ops = qiskit2tq_op_history(circ)
print(ops)

ansatz = tq.QuantumModule.from_op_history(ops)
configs = {
"n_epochs": 10,
Expand Down
24 changes: 24 additions & 0 deletions test/plugins/test_qiskit2tq_op_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from torchquantum.plugins import qiskit2tq_op_history
import torchquantum as tq
from qiskit.circuit.random import random_circuit
from qiskit import QuantumCircuit


def test_qiskit2tp_op_history():
circ = QuantumCircuit(3, 3)
circ.h(0)
circ.rx(0.1, 1)
circ.cx(0, 1)
circ.cx(1, 2)
circ.u(0.1, 0.2, 0.3, 0)
print(circ)
ops = qiskit2tq_op_history(circ)
qmodule = tq.QuantumModule.from_op_history(ops)
print(qmodule.Operator_list)



if __name__ == '__main__':
import pdb
pdb.set_trace()
test_qiskit2tp_op_history()
10 changes: 5 additions & 5 deletions torchquantum/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def static_forward(self, q_device: tq.QuantumDevice):
# params=self.unitary
# )

def __repr__(self):
if self.Operator_list is not None:
return f"QuantumModule with Operator_list {self.Operator_list}"
else:
return "QuantumModule"
# def __repr__(self):
# if self.Operator_list is not None:
# return f"QuantumModule with Operator_list {self.Operator_list}"
# else:
# return "QuantumModule"

def get_unitary(self, q_device: tq.QuantumDevice, x=None):
original_wires_per_block = self.wires_per_block
Expand Down
98 changes: 97 additions & 1 deletion torchquantum/plugins/qiskit_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
"tq2qiskit",
"tq2qiskit_parameterized",
"qiskit2tq",
"qiskit2tq_op_history",
"qiskit2tq_Operator",
"tq2qiskit_measurement",
"tq2qiskit_expand_params",
"qiskit_assemble_circs",
Expand All @@ -33,6 +35,100 @@
]


def qiskit2tq_op_history(circ):
if getattr(circ, "_layout", None) is not None:
try:
p2v_orig = circ._layout.final_layout.get_physical_bits().copy()
except:
p2v_orig = circ._layout.get_physical_bits().copy()
p2v = {}
for p, v in p2v_orig.items():
if v.register.name == "q":
p2v[p] = v.index
else:
p2v[p] = f"{v.register.name}.{v.index}"
else:
p2v = {}
for p in range(circ.num_qubits):
p2v[p] = p

ops = []
for gate in circ.data:
op_name = gate[0].name
wires = list(map(lambda x: x.index, gate[1]))
wires = [p2v[wire] for wire in wires]
# sometimes the gate.params is ParameterExpression class
init_params = (
list(map(float, gate[0].params)) if len(gate[0].params) > 0 else None
)
print(op_name,)

if op_name in [
"h",
"x",
"y",
"z",
"s",
"t",
"sx",
"cx",
"cz",
"cy",
"swap",
"cswap",
"ccx",
]:
ops.append(
{
"name": op_name, # type: ignore
"wires": np.array(wires),
"params": None,
"inverse": False,
"trainable": False,
}
)
elif op_name in [
"rx",
"ry",
"rz",
"rxx",
"xx",
"ryy",
"yy",
"rzz",
"zz",
"rzx",
"zx",
"p",
"cp",
"crx",
"cry",
"crz",
"u1",
"cu1",
"u2",
"u3",
"cu3",
"u",
"cu",
]:
ops.append(
{
"name": op_name, # type: ignore
"wires": np.array(wires),
"params": init_params,
"inverse": False,
"trainable": True
})
elif op_name in ["barrier", "measure"]:
continue
else:
raise NotImplementedError(
f"{op_name} conversion to tq is currently not supported."
)
return ops


def qiskit_assemble_circs(encoders, fixed_layer, measurement):
circs_all = []
n_qubits = len(fixed_layer.qubits)
Expand Down Expand Up @@ -552,7 +648,7 @@ def op_history2qiskit_expand_params(n_wires, op_history, bsz):

# construct a tq QuantumModule object according to the qiskit QuantumCircuit
# object
def qiskit2tq_ops(circ: QuantumCircuit):
def qiskit2tq_Operator(circ: QuantumCircuit):
if getattr(circ, "_layout", None) is not None:
try:
p2v_orig = circ._layout.final_layout.get_physical_bits().copy()
Expand Down

0 comments on commit 17a2a85

Please sign in to comment.