Skip to content

Commit 2df6cd3

Browse files
authored
Another try
1 parent 49292d1 commit 2df6cd3

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

tensorflow/compiler/xla/util.h

+12-5
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,23 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
195195
// 2. permutation.size() == input.size().
196196
template <template <typename...> class C, typename T>
197197
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
198-
C<T> input_) {
199-
tensorflow::gtl::ArraySlice<T> input(input_);
200-
CHECK(IsPermutation(permutation, input.size()));
201-
std::vector<T> output(input.size());
198+
C<T> input) {
199+
tensorflow::gtl::ArraySlice<T> data(input);
200+
CHECK(IsPermutation(permutation, data.size()));
201+
std::vector<T> output(data.size());
202202
for (size_t i = 0; i < permutation.size(); ++i) {
203-
output[permutation[i]] = input[i];
203+
output[permutation[i]] = data[i];
204204
}
205205
return output;
206206
}
207207

208+
// Override of the above that works around compile failures with vectors.
209+
template <typename T>
210+
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
211+
const std::vector<T>& input) {
212+
return Permute<std::vector, T>(permutation, input);
213+
}
214+
208215
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
209216
std::vector<int64> InversePermutation(
210217
tensorflow::gtl::ArraySlice<int64> input_permutation);

0 commit comments

Comments
 (0)