Skip to content

Commit

Permalink
Some small util ops cleanup. (tensorflow#165)
Browse files Browse the repository at this point in the history
* Some small util ops cleanup.

* missed some tstrings.

Co-authored-by: Michael Broughton <[email protected]>
  • Loading branch information
MichaelBroughton and MichaelBroughton authored Mar 18, 2020
1 parent fe81a99 commit 260ad7e
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 15 deletions.
6 changes: 3 additions & 3 deletions tensorflow_quantum/core/ops/parse_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Status ParsePrograms(OpKernelContext* context, const std::string& input_name,
absl::StrCat("programs must be rank 1. Got rank ", input->dims(), "."));
}

const auto program_strings = input->vec<std::string>();
const auto program_strings = input->vec<tensorflow::tstring>();
const int num_programs = program_strings.dimension(0);
programs->assign(num_programs, Program());

Expand Down Expand Up @@ -161,7 +161,7 @@ Status GetPauliSums(OpKernelContext* context,
input->dims(), "."));
}

const auto sum_specs = input->matrix<std::string>();
const auto sum_specs = input->matrix<tensorflow::tstring>();
p_sums->reserve(sum_specs.dimension(0));
for (int i = 0; i < sum_specs.dimension(0); i++) {
std::vector<PauliSum> sub_ops;
Expand Down Expand Up @@ -208,7 +208,7 @@ Status GetSymbolMaps(OpKernelContext* context, std::vector<SymbolMap>* maps) {
input_values->dims(), "."));
}

const auto symbol_names = input_names->vec<std::string>();
const auto symbol_names = input_names->vec<tensorflow::tstring>();
const auto symbol_values = input_values->matrix<float>();

if (symbol_names.dimension(0) != symbol_values.dimension(1)) {
Expand Down
6 changes: 4 additions & 2 deletions tensorflow_quantum/core/ops/tfq_circuit_append_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,18 @@ class TfqCircuitAppendOp : public tensorflow::OpKernel {
tensorflow::Tensor *output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0, context->input(0).shape(), &output));
auto output_tensor = output->flat<std::string>();
auto output_tensor = output->flat<tensorflow::tstring>();

auto DoWork = [&](int start, int end) {
std::string temp;
for (int i = start; i < end; i++) {
for (int j = 0; j < programs_to_append.at(i).circuit().moments().size();
j++) {
Moment *new_moment = programs.at(i).mutable_circuit()->add_moments();
*new_moment = programs_to_append.at(i).circuit().moments(j);
}
programs.at(i).SerializeToString(&output_tensor(i));
programs.at(i).SerializeToString(&temp);
output_tensor(i) = temp;
}
};

Expand Down
6 changes: 4 additions & 2 deletions tensorflow_quantum/core/ops/tfq_ps_decompose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
tensorflow::Tensor *output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0, context->input(0).shape(), &output));
auto output_tensor = output->flat<std::string>();
auto output_tensor = output->flat<tensorflow::tstring>();

const int max_buffer_moments = 3;

auto DoWork = [&](int start, int end) {
for (int i = start; i < end; i++) {
Program cur_program = programs.at(i);
Program new_program;
std::string temp;
new_program.mutable_language()->set_gate_set("tfq_gate_set");
new_program.mutable_circuit()->set_scheduling_strategy(
Circuit::MOMENT_BY_MOMENT);
Expand Down Expand Up @@ -145,7 +146,8 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
}
}
}
new_program.SerializeToString(&output_tensor(i));
new_program.SerializeToString(&temp);
output_tensor(i) = temp;
}
};

Expand Down
24 changes: 18 additions & 6 deletions tensorflow_quantum/core/ops/tfq_ps_symbol_replace_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,26 @@ class TfqPsSymbolReplaceOp : public tensorflow::OpKernel {

// Parse the input string here.
const Tensor *symbols_tensor;
context->input("symbols", &symbols_tensor);
OP_REQUIRES_OK(context, context->input("symbols", &symbols_tensor));
OP_REQUIRES(
context, symbols_tensor->dims() == 1,
tensorflow::errors::InvalidArgument(absl::StrCat(
"symbols must be rank 1. Got rank ", symbols_tensor->dims(), ".")));

const auto symbols = symbols_tensor->vec<std::string>();
const auto symbols = symbols_tensor->vec<tensorflow::tstring>();
const size_t n_symbols = symbols.size();

// Parse the replacement string here.
const Tensor *replacement_symbols_tensor;
context->input("replacement_symbols", &replacement_symbols_tensor);
OP_REQUIRES_OK(context, context->input("replacement_symbols",
&replacement_symbols_tensor));
OP_REQUIRES(context, replacement_symbols_tensor->dims() == 1,
tensorflow::errors::InvalidArgument(absl::StrCat(
"replacement_symbols must be rank 1. Got rank ",
replacement_symbols_tensor->dims(), ".")));

const auto replacement_symbols =
replacement_symbols_tensor->vec<std::string>();
replacement_symbols_tensor->vec<tensorflow::tstring>();

OP_REQUIRES(context, symbols.size() == replacement_symbols.size(),
tensorflow::errors::InvalidArgument(absl::StrCat(
Expand All @@ -86,6 +87,7 @@ class TfqPsSymbolReplaceOp : public tensorflow::OpKernel {
int sidx = i % n_symbols;
int pidx = i / n_symbols;
std::string symbol_to_replace = symbols(sidx);
std::string temp_symbol_holder;
Program cur_program = programs.at(pidx);
for (int j = 0; j < cur_program.circuit().moments().size(); j++) {
Moment cur_moment = cur_program.circuit().moments().at(j);
Expand All @@ -98,14 +100,18 @@ class TfqPsSymbolReplaceOp : public tensorflow::OpKernel {
if (arg.symbol() == symbol_to_replace) {
// Copy the proto, modify the symbol and append to output.
Program temp(cur_program);

// temp_symbol_holder is needed to avoid call ambiguity for
// set_symbol below.
temp_symbol_holder = replacement_symbols(sidx);
temp.mutable_circuit()
->mutable_moments()
->at(j)
.mutable_operations()
->at(k)
.mutable_args()
->at(key)
.set_symbol(replacement_symbols(sidx));
.set_symbol(temp_symbol_holder);

std::string res;
temp.SerializeToString(&res);
Expand Down Expand Up @@ -147,7 +153,7 @@ class TfqPsSymbolReplaceOp : public tensorflow::OpKernel {
output_shape.AddDim(biggest_pad);
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

auto output_tensor = output->tensor<std::string, 3>();
auto output_tensor = output->tensor<tensorflow::tstring, 3>();

// TODO: investigate whether or not it is worth this parallelization at the
// end.
Expand Down Expand Up @@ -194,6 +200,12 @@ REGISTER_OP("TfqPsSymbolReplace")
TF_RETURN_IF_ERROR(
c->WithRank(c->input(2), 1, &replacement_symbols_shape));

c->set_output(
0, c->MakeShape(
{c->Dim(programs_shape, 0),
tensorflow::shape_inference::InferenceContext::kUnknownDim,
tensorflow::shape_inference::InferenceContext::kUnknownDim}));

return tensorflow::Status::OK();
});

Expand Down
10 changes: 8 additions & 2 deletions tensorflow_quantum/core/ops/tfq_ps_weights_from_symbols_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ class TfqPsWeightsFromSymbolOp : public tensorflow::OpKernel {

// Parse the input string here.
const Tensor *symbols_tensor;
context->input("symbols", &symbols_tensor);
OP_REQUIRES_OK(context, context->input("symbols", &symbols_tensor));
OP_REQUIRES(
context, symbols_tensor->dims() == 1,
tensorflow::errors::InvalidArgument(absl::StrCat(
"symbols must be rank 1. Got rank ", symbols_tensor->dims(), ".")));

const auto symbols = symbols_tensor->vec<std::string>();
const auto symbols = symbols_tensor->vec<tensorflow::tstring>();
const int n_symbols = symbols.size();

// (i,j,k) = the kth scalar value found for symbols(j) in programs(i).
Expand Down Expand Up @@ -176,6 +176,12 @@ REGISTER_OP("TfqPsWeightsFromSymbols")
tensorflow::shape_inference::ShapeHandle symbols_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &symbols_shape));

c->set_output(
0, c->MakeShape(
{c->Dim(programs_shape, 0),
tensorflow::shape_inference::InferenceContext::kUnknownDim,
tensorflow::shape_inference::InferenceContext::kUnknownDim}));

return tensorflow::Status::OK();
});

Expand Down

0 comments on commit 260ad7e

Please sign in to comment.