@@ -195,16 +195,23 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
195
195
// 2. permutation.size() == input.size().
196
196
template <template <typename ...> class C , typename T>
197
197
std::vector<T> Permute (tensorflow::gtl::ArraySlice<int64> permutation,
198
- C<T> input_ ) {
199
- tensorflow::gtl::ArraySlice<T> input (input_ );
200
- CHECK (IsPermutation (permutation, input .size ()));
201
- std::vector<T> output (input .size ());
198
+ C<T> input ) {
199
+ tensorflow::gtl::ArraySlice<T> data (input );
200
+ CHECK (IsPermutation (permutation, data .size ()));
201
+ std::vector<T> output (data .size ());
202
202
for (size_t i = 0 ; i < permutation.size (); ++i) {
203
- output[permutation[i]] = input [i];
203
+ output[permutation[i]] = data [i];
204
204
}
205
205
return output;
206
206
}
207
207
208
+ // Override of the above that works around compile failures with vectors.
209
+ template <typename T>
210
+ std::vector<T> Permute (tensorflow::gtl::ArraySlice<int64> permutation,
211
+ const std::vector<T>& input) {
212
+ return Permute<std::vector, T>(permutation, input);
213
+ }
214
+
208
215
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
209
216
std::vector<int64> InversePermutation (
210
217
tensorflow::gtl::ArraySlice<int64> input_permutation);
0 commit comments