Skip to content

Commit

Permalink
try ints
Browse files Browse the repository at this point in the history
  • Loading branch information
zaqqwerty committed May 2, 2020
1 parent 5c383e6 commit 4dd711d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
11 changes: 6 additions & 5 deletions tensorflow_quantum/core/ops/parse_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ Status GetGradients(OpKernelContext* context,

tensorflow::Status GetNumSamples(
tensorflow::OpKernelContext* context,
std::vector<std::vector<unsigned int>>* parsed_num_samples) {
std::vector<std::vector<int>>* parsed_num_samples) {
const Tensor* input_num_samples;
Status status = context->input("num_samples", &input_num_samples);
if (!status.ok()) {
Expand All @@ -270,13 +270,14 @@ tensorflow::Status GetNumSamples(
input_num_samples->dims(), "."));
}

const auto matrix_num_samples = input_num_samples->matrix<unsigned int>();
const auto matrix_num_samples = input_num_samples->matrix<int>();
parsed_num_samples->reserve(matrix_num_samples.dimension(0));
for (unsigned int i = 0; i < matrix_num_samples.dimension(0); i++) {
std::vector<unsigned int> sub_parsed_num_samples;
std::vector<int> sub_parsed_num_samples;
sub_parsed_num_samples.reserve(matrix_num_samples.dimension(1));
for (int j = 0; j < matrix_num_samples.dimension(1); j++) {
sub_parsed_num_samples.push_back(matrix_num_samples(i, j));
for (unsigned int j = 0; j < matrix_num_samples.dimension(1); j++) {
const int num_samples = matrix_num_samples(i, j);
sub_parsed_num_samples.push_back(num_samples);
}
parsed_num_samples->push_back(sub_parsed_num_samples);
}
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_quantum/core/ops/parse_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ tensorflow::Status GetGradients(tensorflow::OpKernelContext* context,
// Parses the number of samples from the 'num_samples' input tensor.
tensorflow::Status GetNumSamples(
tensorflow::OpKernelContext* context,
std::vector<std::vector<unsigned int>>* parsed_num_samples);
std::vector<std::vector<int>>* parsed_num_samples);

} // namespace tfq

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class TfqSimulateSampledExpectationOp : public tensorflow::OpKernel {
programs.size(), " circuits and ", pauli_sums.size(),
" paulisums.")));

std::vector<std::vector<unsigned int>> num_samples;
std::vector<std::vector<int>> num_samples;
OP_REQUIRES_OK(context, GetNumSamples(context, &num_samples));

OP_REQUIRES(context, num_samples.size() == pauli_sums.size(),
Expand Down Expand Up @@ -172,7 +172,7 @@ REGISTER_OP("TfqSimulateSampledExpectation")
.Input("symbol_names: string")
.Input("symbol_values: float")
.Input("pauli_sums: string")
.Input("num_samples: uint32")
.Input("num_samples: int")
.Output("expectations: float")
.SetShapeFn([](tensorflow::shape_inference::InferenceContext *c) {
tensorflow::shape_inference::ShapeHandle programs_shape;
Expand Down

0 comments on commit 4dd711d

Please sign in to comment.