forked from apple/coremltools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CoreMLPythonArray.mm
56 lines (45 loc) · 1.78 KB
/
CoreMLPythonArray.mm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#import "CoreMLPythonArray.h"
// MLMultiArray subclass that wraps the buffer of a pybind11 numpy array
// without copying it. The wrapped py::array is stored in m_array after
// initialization (presumably so the numpy buffer stays alive as long as this
// object does -- m_array is declared outside this file; confirm in the header).
@implementation PybindCompatibleArray

// Maps the numpy dtype of `array` to the matching MLMultiArrayDataType.
//
// Supported combinations (dtype kind / item size in bytes):
//   'i' / 4        -> MLMultiArrayDataTypeInt32
//   'f' / 4        -> MLMultiArrayDataTypeFloat32
//   'f' or 'd' / 8 -> MLMultiArrayDataTypeDouble
//
// Throws std::runtime_error for every other dtype.
+ (MLMultiArrayDataType)dataTypeOf:(py::array)array {
    const auto dt = array.dtype();
    const char kind = dt.kind();
    const size_t itemsize = static_cast<size_t>(dt.itemsize());

    if (kind == 'i' && itemsize == 4) {
        return MLMultiArrayDataTypeInt32;
    }
    if (kind == 'f' && itemsize == 4) {
        return MLMultiArrayDataTypeFloat32;
    }
    // NOTE(review): numpy normally reports kind 'f' for every float width, so
    // the 'd' check looks redundant; it is kept to preserve existing behavior.
    if ((kind == 'f' || kind == 'd') && itemsize == 8) {
        return MLMultiArrayDataTypeDouble;
    }

    // Append `kind` as a character. std::to_string(char) would promote it to
    // int and print the numeric character code (e.g. "105" instead of "i").
    throw std::runtime_error(std::string("Unsupported array type: ") + kind +
                             " with itemsize = " + std::to_string(itemsize));
}

// Returns the extent of every dimension of `array`, outermost first, boxed as
// NSNumbers. A zero-dimensional array yields an empty NSArray.
+ (NSArray<NSNumber *> *)shapeOf:(py::array)array {
    NSMutableArray<NSNumber *> *ret = [[NSMutableArray alloc] init];
    // Use the index type that ndim() itself returns to avoid a signed/unsigned
    // comparison in the loop condition.
    for (decltype(array.ndim()) axis = 0; axis < array.ndim(); ++axis) {
        const unsigned long long extent = static_cast<unsigned long long>(array.shape(axis));
        [ret addObject:[NSNumber numberWithUnsignedLongLong:extent]];
    }
    return ret;
}

// Returns the stride of every dimension of `array`, boxed as NSNumbers.
//
// numpy strides is in bytes.
// this type must return number of ELEMENTS! (as per mlkit)
//
// Throws std::runtime_error if the array reports a non-positive item size,
// because the byte stride could not be converted to an element count.
+ (NSArray<NSNumber *> *)stridesOf:(py::array)array {
    const long long itemsize = static_cast<long long>(array.itemsize());
    if (itemsize <= 0) {
        throw std::runtime_error("Cannot compute element strides: itemsize = " +
                                 std::to_string(itemsize));
    }

    NSMutableArray<NSNumber *> *ret = [[NSMutableArray alloc] init];
    for (decltype(array.ndim()) axis = 0; axis < array.ndim(); ++axis) {
        // Divide in signed arithmetic: numpy byte strides can be negative
        // (e.g. reversed views), and unsigned math would turn such a stride
        // into a huge positive element count.
        const long long byteStride = static_cast<long long>(array.strides(axis));
        [ret addObject:[NSNumber numberWithLongLong:byteStride / itemsize]];
    }
    return ret;
}

// Initializes the receiver as a view over the data of `array`.
//
// The shape, element type and element strides are derived from `array` via
// the class helpers above. No deallocator is passed to the superclass;
// the array handle is stored in m_array instead.
//
// Returns nil if the superclass initializer fails (its NSError is not
// requested). Propagates std::runtime_error from the helpers for unsupported
// dtypes.
- (PybindCompatibleArray *)initWithArray:(py::array)array {
    // Resolve the element type first so that an unsupported dtype is rejected
    // deterministically, before shape or stride computation runs. Evaluation
    // order of message arguments is otherwise unspecified.
    const MLMultiArrayDataType dataType = [[self class] dataTypeOf:array];
    NSArray<NSNumber *> *shape = [[self class] shapeOf:array];
    NSArray<NSNumber *> *strides = [[self class] stridesOf:array];

    self = [super initWithDataPointer:array.mutable_data()
                                shape:shape
                             dataType:dataType
                              strides:strides
                          deallocator:nil
                                error:nil];
    if (self) {
        m_array = array;
    }
    return self;
}

@end