Skip to content

Commit

Permalink
Merge pull request tensorflow#529 from tensorflow/bf_channel
Browse files Browse the repository at this point in the history
Added BF channel support.
  • Loading branch information
jaeyoo authored Mar 31, 2021
2 parents 6e67c0e + 91c4d8a commit ee12f8d
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 2 deletions.
33 changes: 33 additions & 0 deletions tensorflow_quantum/core/serialize/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,37 @@ def _phase_flip_channel_deserializer():
args=args)


def _bit_flip_channel_serializer():
"""Make standard serializer for BitFlip channel."""
args = [
# cirq channels can't contain symbols.
cirq.google.SerializingArg(serialized_name="p",
serialized_type=float,
op_getter=lambda x: x.gate.p),
cirq.google.SerializingArg(serialized_name="control_qubits",
serialized_type=str,
op_getter=lambda x: ''),
cirq.google.SerializingArg(serialized_name="control_values",
serialized_type=str,
op_getter=lambda x: '')
]
return cirq.google.GateOpSerializer(gate_type=cirq.BitFlipChannel,
serialized_gate_id="BF",
args=args,
can_serialize_predicate=_CONSTANT_TRUE)


def _bit_flip_channel_deserializer():
"""Make standard deserializer for BitFlip channel."""
args = [
cirq.google.DeserializingArg(serialized_name="p",
constructor_arg_name="p")
]
return cirq.google.GateOpDeserializer(serialized_gate_id="BF",
gate_constructor=cirq.BitFlipChannel,
args=args)


# Gates.
def _eigen_gate_serializer(gate_type, serialized_id):
"""Make standard serializer for eigen gates."""
Expand Down Expand Up @@ -699,6 +730,7 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
] + [
_amplitude_damp_channel_serializer(),
_asymmetric_depolarize_serializer(),
_bit_flip_channel_serializer(),
_depolarize_channel_serializer(),
_fsim_gate_serializer(),
_gad_channel_serializer(),
Expand All @@ -717,6 +749,7 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
] + [
_amplitude_damp_channel_deserializer(),
_asymmetric_depolarize_deserializer(),
_bit_flip_channel_deserializer(),
_depolarize_channel_deserializer(),
_fsim_gate_deserializer(),
_gad_channel_deserializer(),
Expand Down
6 changes: 5 additions & 1 deletion tensorflow_quantum/core/serialize/serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,11 @@ def _get_noise_proto_pairs():

# Phase flip.
(cirq.Circuit(cirq.phase_flip(p=0.1)(q0)),
_build_op_proto("PF", ['p'], [0.1], ['0_0']))
_build_op_proto("PF", ['p'], [0.1], ['0_0'])),

# Bit flip.
(cirq.Circuit(cirq.bit_flip(p=0.1)(q0)),
_build_op_proto("BF", ['p'], [0.1], ['0_0']))
]
return pairs

Expand Down
22 changes: 21 additions & 1 deletion tensorflow_quantum/core/src/circuit_parser_qsim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,26 @@ inline Status PhaseFlipChannel(const Operation& op,
return Status::OK();
}

inline Status BitFlipChannel(const Operation& op, const unsigned int num_qubits,
const unsigned int time,
NoisyQsimCircuit* ncircuit) {
int q;
bool unused;
float p;
Status u;
unused = absl::SimpleAtoi(op.qubits(0).id(), &q);

u = ParseProtoArg(op, "p", {}, &p);
if (!u.ok()) {
return u;
}

auto chan =
qsim::Cirq::BitFlipChannel<float>::Create(time, num_qubits - q - 1, p);
ncircuit->channels.push_back(chan);
return Status::OK();
}

tensorflow::Status ParseAppendChannel(const Operation& op,
const unsigned int num_qubits,
const unsigned int time,
Expand All @@ -750,7 +770,7 @@ tensorflow::Status ParseAppendChannel(const Operation& op,
{"DP", &DepolarizingChannel}, {"ADP", &AsymmetricDepolarizingChannel},
{"GAD", &GADChannel}, {"AD", &AmplitudeDampingChannel},
{"RST", &ResetChannel}, {"PD", &PhaseDampingChannel},
{"PF", &PhaseFlipChannel}};
{"PF", &PhaseFlipChannel}, {"BF", &BitFlipChannel}};

auto build_f = chan_func_map.find(op.gate().id());
if (build_f == chan_func_map.end()) {
Expand Down
36 changes: 36 additions & 0 deletions tensorflow_quantum/core/src/circuit_parser_qsim_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,42 @@ TEST(QsimCircuitParserTest, PhaseFlip) {
ASSERT_EQ(test_circuit.num_qubits, 1);
}

TEST(QsimCircuitParserTest, BitFlip) {
float p = 0.1234;
auto reference = qsim::Cirq::BitFlipChannel<float>::Create(0, 0, p);
Program program_proto;
Circuit* circuit_proto = program_proto.mutable_circuit();
circuit_proto->set_scheduling_strategy(circuit_proto->MOMENT_BY_MOMENT);
Moment* moments_proto = circuit_proto->add_moments();

// Add channel.
Operation* operations_proto = moments_proto->add_operations();
Gate* gate_proto = operations_proto->mutable_gate();
gate_proto->set_id("BF");

// Set the args.
google::protobuf::Map<std::string, Arg>* args_proto =
operations_proto->mutable_args();
(*args_proto)["p"] = MakeArg(p);

// Set the control args.
(*args_proto)["control_qubits"] = MakeControlArg("");
(*args_proto)["control_values"] = MakeControlArg("");

// Set the qubits.
Qubit* qubits_proto = operations_proto->add_qubits();
qubits_proto->set_id("0");

NoisyQsimCircuit test_circuit;

ASSERT_EQ(
NoisyQsimCircuitFromProgram(program_proto, {}, 1, false, &test_circuit),
tensorflow::Status::OK());
AssertChannelEqual(test_circuit.channels[0], reference);
ASSERT_EQ(test_circuit.channels.size(), 1);
ASSERT_EQ(test_circuit.num_qubits, 1);
}

TEST(QsimCircuitParserTest, NoisyEmpty) {
Program program_proto;
Circuit* circuit_proto = program_proto.mutable_circuit();
Expand Down
6 changes: 6 additions & 0 deletions tensorflow_quantum/python/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
cirq.ResetChannel,
cirq.PhaseDampingChannel,
cirq.PhaseFlipChannel,
cirq.BitFlipChannel,
]


Expand Down Expand Up @@ -84,6 +85,7 @@ def get_supported_channels():
channel_mapping[cirq.ResetChannel()] = 1
channel_mapping[cirq.PhaseDampingChannel(0.01)] = 1
channel_mapping[cirq.PhaseFlipChannel(0.01)] = 1
channel_mapping[cirq.BitFlipChannel(0.01)] = 1

return channel_mapping

Expand Down Expand Up @@ -534,6 +536,10 @@ def _channel_approx_eq(op_true, op_deser, atol=1e-5):
if isinstance(op_deser, cirq.PhaseFlipChannel):
return abs(op_true.p - op_deser.p) < atol

if isinstance(op_true, cirq.BitFlipChannel):
if isinstance(op_deser, cirq.BitFlipChannel):
return abs(op_true.p - op_deser.p) < atol

return False


Expand Down

0 comments on commit ee12f8d

Please sign in to comment.