-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtypes.cpp
122 lines (104 loc) · 2.72 KB
/
types.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include "types.h"
#include <cstddef>
#include <cstdint>
namespace SimpleInfer {
// Specialization for `float`: true only when `data_type` is
// DataType::kFloat32.
template <>
bool IsSameDataType<float>(const DataType data_type) {
    return data_type == DataType::kFloat32;
}
// Specialization for `double`: true only when `data_type` is
// DataType::kFloat64.
template <>
bool IsSameDataType<double>(const DataType data_type) {
    return data_type == DataType::kFloat64;
}
// Specialization for `int32_t`: true only when `data_type` is
// DataType::kInt32.
template <>
bool IsSameDataType<int32_t>(const DataType data_type) {
    return data_type == DataType::kInt32;
}
// Specialization for `int64_t`: true only when `data_type` is
// DataType::kInt64.
template <>
bool IsSameDataType<int64_t>(const DataType data_type) {
    return data_type == DataType::kInt64;
}
// Specialization for `int16_t`: true only when `data_type` is
// DataType::kInt16.
template <>
bool IsSameDataType<int16_t>(const DataType data_type) {
    return data_type == DataType::kInt16;
}
// Specialization for `int8_t`: true only when `data_type` is
// DataType::kInt8.
template <>
bool IsSameDataType<int8_t>(const DataType data_type) {
    return data_type == DataType::kInt8;
}
// Specialization for `uint8_t`: true only when `data_type` is
// DataType::kUint8.
template <>
bool IsSameDataType<uint8_t>(const DataType data_type) {
    return data_type == DataType::kUint8;
}
// Specialization for `bool`: true only when `data_type` is
// DataType::kBool.
template <>
bool IsSameDataType<bool>(const DataType data_type) {
    return data_type == DataType::kBool;
}
// Translates an integer type code into the matching DataType.
//
// Codes 1 through 12 map to the entries listed in the table below; every
// other value (including 0 and negatives) yields DataType::kNone.
// NOTE(review): the codes are presumably the numeric type ids used by PNNX
// model files -- confirm against the PNNX format before extending.
DataType PnnxToDataType(int type) {
    // Slot i holds the DataType for code i. Slot 0 is a placeholder so that
    // the code can be used directly as the index.
    static constexpr DataType kCodeToDataType[] = {
        DataType::kNone,        // 0 (unmapped)
        DataType::kFloat32,     // 1
        DataType::kFloat64,     // 2
        DataType::kFloat16,     // 3
        DataType::kInt32,       // 4
        DataType::kInt64,       // 5
        DataType::kInt16,       // 6
        DataType::kInt8,        // 7
        DataType::kUint8,       // 8
        DataType::kBool,        // 9
        DataType::kComplex64,   // 10
        DataType::kComplex128,  // 11
        DataType::kComplex32,   // 12
    };
    constexpr int kNumCodes = static_cast<int>(sizeof(kCodeToDataType) /
                                               sizeof(kCodeToDataType[0]));

    // Anything outside the table is treated as "no type".
    if (type < 0 || type >= kNumCodes) {
        return DataType::kNone;
    }

    return kCodeToDataType[type];
}
// Returns the size in bytes of a single element of `data_type`.
//
// Data types that are not listed in the switch below produce 0.
int ElementSize(const DataType data_type) {
    int bytes = 0;

    switch (data_type) {
        // One-byte elements.
        case DataType::kBool:
        case DataType::kInt8:
        case DataType::kUint8:
            bytes = 1;
            break;

        // Two-byte elements.
        case DataType::kInt16:
        case DataType::kFloat16:
            bytes = 2;
            break;

        // Four-byte elements.
        case DataType::kInt32:
        case DataType::kFloat32:
        case DataType::kComplex32:
            bytes = 4;
            break;

        // Eight-byte elements.
        case DataType::kInt64:
        case DataType::kFloat64:
        case DataType::kComplex64:
            bytes = 8;
            break;

        // Sixteen-byte elements.
        case DataType::kComplex128:
            bytes = 16;
            break;

        // Any other data type has no size here.
        default:
            break;
    }

    return bytes;
}
// Returns true when `shape0` and `shape1` have the same length and hold
// equal values at every index; two empty shapes compare equal.
bool IsSameShape(const std::vector<int>& shape0,
                 const std::vector<int>& shape1) {
    // std::vector's operator== checks the sizes first and then compares the
    // elements pairwise, which is exactly the contract of this function.
    return shape0 == shape1;
}
} // namespace SimpleInfer