Skip to content

Commit

Permalink
Parsing functionality for program-program. (tensorflow#382)
Browse files Browse the repository at this point in the history
* Parsing functionality for program-program.

* added empty case.
  • Loading branch information
MichaelBroughton authored Sep 17, 2020
1 parent 0712199 commit 18c57af
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 0 deletions.
84 changes: 84 additions & 0 deletions tensorflow_quantum/core/src/program_resolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,90 @@ Status ResolveQubitIds(Program* program, unsigned int* num_qubits,
return Status::OK();
}

Status ResolveQubitIds(Program* program, unsigned int* num_qubits,
std::vector<Program>* other_programs) {
if (program->circuit().moments().empty()) {
// (#679) Just ignore empty program.
// Number of qubits in empty programs is zero.
*num_qubits = 0;
return Status::OK();
}

absl::flat_hash_set<std::pair<std::pair<int, int>, std::string>> id_set;
for (const Moment& moment : program->circuit().moments()) {
for (const Operation& operation : moment.operations()) {
for (const Qubit& qubit : operation.qubits()) {
int r, c;
const std::vector<std::string> splits = absl::StrSplit(qubit.id(), "_");
if (splits.size() != 2) {
return Status(tensorflow::error::INVALID_ARGUMENT,
"Unable to parse qubit: " + qubit.id());
}
if (!absl::SimpleAtoi(splits[0], &r)) {
return Status(tensorflow::error::INVALID_ARGUMENT,
"Unable to parse qubit: " + qubit.id());
}
if (!absl::SimpleAtoi(splits[1], &c)) {
return Status(tensorflow::error::INVALID_ARGUMENT,
"Unable to parse qubit: " + qubit.id());
}
auto locs = std::pair<std::pair<int, int>, std::string>(
std::pair<int, int>(r, c), qubit.id());
id_set.insert(locs);
}
}
}
*num_qubits = id_set.size();

// call to std::sort will do (r1 < r2) || ((r1 == r2) && c1 < c2)
std::vector<std::pair<std::pair<int, int>, std::string>> ids(id_set.begin(),
id_set.end());
std::sort(ids.begin(), ids.end());

absl::flat_hash_map<std::string, std::string> id_to_index;
absl::flat_hash_set<std::string> id_ref;
for (size_t i = 0; i < ids.size(); i++) {
id_to_index[ids[i].second] = absl::StrCat(i);
id_ref.insert(ids[i].second);
}

// Replace the Program Qubit ids with the indices.
for (Moment& moment : *program->mutable_circuit()->mutable_moments()) {
for (Operation& operation : *moment.mutable_operations()) {
for (Qubit& qubit : *operation.mutable_qubits()) {
qubit.set_id(id_to_index.at(qubit.id()));
}
}
}

for (size_t i = 0; i < other_programs->size(); i++) {
// Replace the other_program Qubit ids with the indices.
absl::flat_hash_set<std::string> visited_qubits(id_ref);
for (Moment& moment :
*(other_programs->at(i)).mutable_circuit()->mutable_moments()) {
for (Operation& operation : *moment.mutable_operations()) {
for (Qubit& qubit : *operation.mutable_qubits()) {
visited_qubits.erase(qubit.id());
const auto result = id_to_index.find(qubit.id());
if (result == id_to_index.end()) {
return Status(tensorflow::error::INVALID_ARGUMENT,
"A paired circuit contains qubits not found in "
"reference circuit.");
}
qubit.set_id(result->second);
}
}
}
if (!visited_qubits.empty()) {
return Status(
tensorflow::error::INVALID_ARGUMENT,
"A reference circuit contains qubits not found in paired circuit.");
}
}

return Status::OK();
}

Status ResolveSymbols(
const absl::flat_hash_map<std::string, std::pair<int, float>>& param_map,
Program* program, bool resolve_all /*=true*/) {
Expand Down
8 changes: 8 additions & 0 deletions tensorflow_quantum/core/src/program_resolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ tensorflow::Status ResolveQubitIds(
cirq::google::api::v2::Program* program, unsigned int* num_qubits,
std::vector<tfq::proto::PauliSum>* p_sums = nullptr);

// Overload which allows for strict resolution of multiple programs.
// Will resolve GridQubits in `program` and then double check that
// all qubits in `other_programs` match and resolve them.
// Note: no nullptr default is done here to avoid signature resolutions issues.
tensorflow::Status ResolveQubitIds(
cirq::google::api::v2::Program* program, unsigned int* num_qubits,
std::vector<cirq::google::api::v2::Program>* other_programs);

// Resolves all of the symbols present in the Program. Iterates through all
// operations in all moments, and if any Args have a symbol, replaces the one-of
// with an ArgValue representing the value in the parameter map keyed by the
Expand Down
113 changes: 113 additions & 0 deletions tensorflow_quantum/core/src/program_resolution_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,119 @@ TEST(ProgramResolutionTest, ResolveQubitIds) {
EXPECT_EQ(num_qubits_alphabet, 4);
}

TEST(ProgramResolutionTest, ResolveQubitIdsPrograms) {
const std::string text = R"(
circuit {
moments {
operations {
qubits {
id: "0_0"
}
qubits {
id: "1_0"
}
}
}
moments {
operations {
qubits {
id: "0_0"
}
qubits {
id: "0_1"
}
}
}
}
)";

const std::string text_alphabet = R"(
circuit {
moments {
operations {
qubits {
id: "0_0"
}
qubits {
id: "1_0"
}
}
}
moments {
operations {
qubits {
id: "0_1"
}
qubits {
id: "0_3"
}
}
}
}
)";

const std::string text_empty = R"(
circuit {
}
)";

Program program, program_copy, empty_program;
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(text, &program));
ASSERT_TRUE(
google::protobuf::TextFormat::ParseFromString(text, &program_copy));

unsigned int num_qubits, num_qubits_empty;
std::vector<Program> vec({program_copy});
EXPECT_TRUE(ResolveQubitIds(&program, &num_qubits, &vec).ok());

// Test case where circuits are aligned.
EXPECT_EQ(program.circuit().moments(0).operations(0).qubits(0).id(), "0");
EXPECT_EQ(program.circuit().moments(0).operations(0).qubits(1).id(), "2");
EXPECT_EQ(program.circuit().moments(1).operations(0).qubits(0).id(), "0");
EXPECT_EQ(program.circuit().moments(1).operations(0).qubits(1).id(), "1");

EXPECT_EQ(vec[0].circuit().moments(0).operations(0).qubits(0).id(), "0");
EXPECT_EQ(vec[0].circuit().moments(0).operations(0).qubits(1).id(), "2");
EXPECT_EQ(vec[0].circuit().moments(1).operations(0).qubits(0).id(), "0");
EXPECT_EQ(vec[0].circuit().moments(1).operations(0).qubits(1).id(), "1");

// Test case where source circuit is smaller than paired circuit:
program.Clear();
program_copy.Clear();
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(text, &program));
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(text_alphabet,
&program_copy));

std::vector<Program> vec2({program_copy});

EXPECT_EQ(
ResolveQubitIds(&program, &num_qubits, &vec2),
tensorflow::Status(
tensorflow::error::INVALID_ARGUMENT,
"A paired circuit contains qubits not found in reference circuit."));

// Test case where paired circuit is smaller than source circuit:
program.Clear();
program_copy.Clear();
ASSERT_TRUE(
google::protobuf::TextFormat::ParseFromString(text_alphabet, &program));
ASSERT_TRUE(
google::protobuf::TextFormat::ParseFromString(text, &program_copy));

std::vector<Program> vec3({program_copy});

EXPECT_EQ(
ResolveQubitIds(&program, &num_qubits, &vec3),
tensorflow::Status(
tensorflow::error::INVALID_ARGUMENT,
"A reference circuit contains qubits not found in paired circuit."));

// Ensure empty case is consistent.
std::vector<Program> vec4;
EXPECT_TRUE(ResolveQubitIds(&empty_program, &num_qubits_empty, &vec4).ok());
EXPECT_EQ(num_qubits_empty, 0);
}

TEST(ProgramResolutionTest, ResolveSymbolsInvalidArg) {
const std::string text = R"(
circuit {
Expand Down

0 comments on commit 18c57af

Please sign in to comment.