-
Notifications
You must be signed in to change notification settings - Fork 66
/
uwnet.py
182 lines (135 loc) · 4.7 KB
/
uwnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import sys, os
from ctypes import *
import math
import random
# Load the compiled C library that lives next to this file. abspath() matters:
# before Python 3.9, running ``python uwnet.py`` from its own directory leaves
# ``__file__`` relative, so dirname() returns '' and the joined name has no
# slash. dlopen() then searches only the system library path, not the current
# directory, and fails to find the library. RTLD_GLOBAL is passed as the
# dlopen mode so the library's symbols are made globally available.
lib = CDLL(os.path.join(os.path.dirname(os.path.abspath(__file__)), "libuwnet.so"), RTLD_GLOBAL)
def c_array(ctype, values):
    """Return a new ctypes array of ``ctype`` holding a copy of ``values``.

    ``values`` must support ``len()``; the array is sized to match it and
    then filled element-by-element via slice assignment.
    """
    array_type = ctype * len(values)
    buf = array_type()
    buf[:] = values
    return buf
class IMAGE(Structure):
    """ctypes mirror of the C image struct.

    Fields:
        w, h, c: integer dimensions (presumably width, height, channels).
        data: pointer to the float pixel buffer; run_net_image treats it as
            holding w*h*c floats.

    ``+`` and ``-`` delegate to the C ``add_image`` / ``sub_image`` bindings.
    """

    _fields_ = [
        ("w", c_int),
        ("h", c_int),
        ("c", c_int),
        ("data", POINTER(c_float)),
    ]

    def __add__(self, other):
        return add_image(self, other)

    def __sub__(self, other):
        return sub_image(self, other)
class MATRIX(Structure):
    """ctypes mirror of the C matrix struct.

    Fields:
        rows, cols: matrix dimensions.
        data: pointer to the float storage.
        shallow: integer flag; run_net_image sets it to 1 when ``data`` is
            borrowed from an image buffer (presumably so the C side does not
            free it — NOTE(review): confirm against the C code).
    """

    _fields_ = [
        ("rows", c_int),
        ("cols", c_int),
        ("data", POINTER(c_float)),
        ("shallow", c_int),
    ]
class DATA(Structure):
    """ctypes mirror of the C dataset struct: inputs ``X`` and targets ``y``.

    Both members are embedded MATRIX values (not pointers).
    """

    _fields_ = [
        ("X", MATRIX),
        ("y", MATRIX),
    ]
class LAYER(Structure):
    """ctypes mirror of the C layer struct.

    ``_fields_`` is attached after the class statement because the callback
    signatures need ``POINTER(LAYER)``, which cannot be built until the
    class object exists (ctypes "incomplete types" pattern). Field order
    must match the C struct exactly.

    Note: the first field is named ``in`` (a Python keyword), so read it
    with ``getattr(layer, "in")``.
    """


LAYER._fields_ = [
    # Pointers to matrices owned elsewhere.
    ("in", POINTER(MATRIX)),
    ("out", POINTER(MATRIX)),
    ("delta", POINTER(MATRIX)),
    # Parameters and their gradients, stored inline.
    ("w", MATRIX),
    ("dw", MATRIX),
    ("b", MATRIX),
    ("db", MATRIX),
    # Integer tags; activation presumably holds one of LINEAR..SOFTMAX.
    ("activation", c_int),
    ("type", c_int),
    # C function pointers: forward(layer*, MATRIX) -> MATRIX,
    # backward(layer*, MATRIX) -> void, update(layer*, float, float, float) -> void.
    ("forward", CFUNCTYPE(MATRIX, POINTER(LAYER), MATRIX)),
    ("backward", CFUNCTYPE(None, POINTER(LAYER), MATRIX)),
    ("update", CFUNCTYPE(None, POINTER(LAYER), c_float, c_float, c_float)),
]
class NET(Structure):
    """ctypes mirror of the C network struct.

    ``layers`` points at a contiguous array of ``n`` LAYER structs
    (make_net builds one from a Python sequence).
    """

    _fields_ = [
        ("layers", POINTER(LAYER)),
        ("n", c_int),
    ]
# Activation codes stored in LAYER.activation. NOTE(review): these integers
# presumably mirror the C-side enum order — keep them in sync with the header.
LINEAR, LOGISTIC, RELU, LRELU, SOFTMAX = 0, 1, 2, 3, 4
# Raw bindings to the C image API. Each one fetches a symbol from ``lib`` and
# declares its C signature so ctypes passes IMAGE structs by value and
# converts the return value.

add_image = lib.add_image              # IMAGE + IMAGE -> IMAGE
add_image.argtypes = [IMAGE, IMAGE]
add_image.restype = IMAGE

sub_image = lib.sub_image              # IMAGE - IMAGE -> IMAGE
sub_image.argtypes = [IMAGE, IMAGE]
sub_image.restype = IMAGE

make_image = lib.make_image            # (int, int, int) -> IMAGE
make_image.argtypes = [c_int, c_int, c_int]
make_image.restype = IMAGE

free_image = lib.free_image
free_image.argtypes = [IMAGE]
# Declared like every other void-style binding in this file. Without an
# explicit restype, ctypes assumes the C function returns ``int`` and hands
# back an unspecified value. NOTE(review): assumes the C free_image returns
# void — confirm against the header.
free_image.restype = None

get_pixel = lib.get_pixel              # (IMAGE, int, int, int) -> float
get_pixel.argtypes = [IMAGE, c_int, c_int, c_int]
get_pixel.restype = c_float

set_pixel = lib.set_pixel              # (IMAGE, int, int, int, float) -> void
set_pixel.argtypes = [IMAGE, c_int, c_int, c_int, c_float]
set_pixel.restype = None

copy_image = lib.copy_image            # IMAGE -> IMAGE
copy_image.argtypes = [IMAGE]
copy_image.restype = IMAGE

clamp_image = lib.clamp_image          # IMAGE -> void
clamp_image.argtypes = [IMAGE]
clamp_image.restype = None

shift_image = lib.shift_image          # (IMAGE, int, float) -> void
shift_image.argtypes = [IMAGE, c_int, c_float]
shift_image.restype = None

# Raw loader: takes a bytes path (c_char_p). Python callers should use the
# load_image() wrapper, which handles the encoding.
load_image_lib = lib.load_image
load_image_lib.argtypes = [c_char_p]
load_image_lib.restype = IMAGE
def load_image(f):
    """Load an image from disk through the C ``load_image`` function.

    Args:
        f: path to the image, as ``str``, ``bytes`` or any ``os.PathLike``
            (e.g. ``pathlib.Path``). Text paths are encoded as UTF-8 before
            being passed to C; bytes are passed through unchanged.

    Returns:
        The IMAGE struct returned by the C loader.
    """
    # os.fspath lets Path objects through (the old ``f.encode`` raised
    # AttributeError for them) while leaving str and bytes as-is.
    path = os.fspath(f)
    if isinstance(path, str):
        path = path.encode('utf-8')
    return load_image_lib(path)
# Output file-type codes passed to save_image_options. NOTE(review): these
# presumably mirror the C-side enum order — keep them in sync with the header.
PNG, BMP, TGA, JPG = 0, 1, 2, 3

# Raw binding: (IMAGE, bytes path, file-type code, int option) -> void.
# save_image() passes 80 as the option for JPG and save_png() passes 0.
save_image_options_lib = lib.save_image_options
save_image_options_lib.argtypes, save_image_options_lib.restype = (
    [IMAGE, c_char_p, c_int, c_int],
    None,
)
def save_image(im, f):
    """Write ``im`` to disk in JPG format via the C ``save_image_options``.

    Args:
        im: the IMAGE to save.
        f: destination path as ``str``, ``bytes`` or ``os.PathLike``. Text is
            encoded as UTF-8 for the C side. NOTE(review): whether the C code
            appends a file extension is not visible here.

    Returns:
        The result of ``save_image_options_lib`` (declared void, so None).
    """
    path = os.fspath(f)
    if isinstance(path, str):
        path = path.encode('utf-8')
    # 80 is the integer option passed for JPG (presumably the JPEG quality).
    return save_image_options_lib(im, path, JPG, 80)
def save_png(im, f):
    """Write ``im`` to disk in PNG format via the C ``save_image_options``.

    Args:
        im: the IMAGE to save.
        f: destination path as ``str``, ``bytes`` or ``os.PathLike``. Text is
            encoded as UTF-8 for the C side.

    Returns:
        The result of ``save_image_options_lib`` (declared void, so None).
    """
    path = os.fspath(f)
    if isinstance(path, str):
        path = path.encode('utf-8')
    # The trailing integer option is 0 for PNG.
    return save_image_options_lib(im, path, PNG, 0)
# Resizing: (IMAGE, int, int) -> IMAGE.
nn_resize = lib.nn_resize
nn_resize.argtypes, nn_resize.restype = [IMAGE, c_int, c_int], IMAGE

bilinear_resize = lib.bilinear_resize
bilinear_resize.argtypes, bilinear_resize.restype = [IMAGE, c_int, c_int], IMAGE

# Training / evaluation on a NET and a DATA set.
train_image_classifier = lib.train_image_classifier
train_image_classifier.argtypes, train_image_classifier.restype = (
    [NET, DATA, c_int, c_int, c_float, c_float, c_float],
    None,
)

accuracy_net = lib.accuracy_net        # (NET, DATA) -> float
accuracy_net.argtypes, accuracy_net.restype = [NET, DATA], c_float

forward_net = lib.forward_net          # (NET, MATRIX) -> MATRIX
forward_net.argtypes, forward_net.restype = [NET, MATRIX], MATRIX

# Raw dataset loader taking two bytes paths; see load_image_classification_data().
load_image_classification_data_lib = lib.load_image_classification_data
load_image_classification_data_lib.argtypes, load_image_classification_data_lib.restype = (
    [c_char_p, c_char_p],
    DATA,
)
def load_image_classification_data(images, labels):
    """Load a classification dataset through the C loader.

    Args:
        images: first path argument forwarded to C.
        labels: second path argument forwarded to C.
            Each may be ``str``, ``bytes`` or ``os.PathLike``; text is encoded
            as UTF-8. NOTE(review): the expected file formats are defined by
            the C loader, not visible here.

    Returns:
        The DATA struct returned by the C loader.
    """
    paths = []
    for p in (images, labels):
        p = os.fspath(p)
        paths.append(p.encode('utf-8') if isinstance(p, str) else p)
    return load_image_classification_data_lib(*paths)
# Layer construction: (int, int, int) -> LAYER (returned by value).
make_connected_layer = lib.make_connected_layer
make_connected_layer.argtypes, make_connected_layer.restype = [c_int, c_int, c_int], LAYER

# Raw weight (de)serialisation: (NET, bytes path) -> void.
# Use the save_weights() / load_weights() wrappers from Python.
save_weights_lib = lib.save_weights
save_weights_lib.argtypes, save_weights_lib.restype = [NET, c_char_p], None

load_weights_lib = lib.load_weights
load_weights_lib.argtypes, load_weights_lib.restype = [NET, c_char_p], None
def save_weights(net, f):
    """Save ``net``'s weights to ``f`` through the C ``save_weights``.

    Args:
        net: the NET whose weights are written.
        f: destination path as ``str``, ``bytes`` or ``os.PathLike``; text is
            encoded as UTF-8 for the C side.
    """
    path = os.fspath(f)
    if isinstance(path, str):
        path = path.encode('utf-8')
    save_weights_lib(net, path)
def load_weights(net, f):
    """Load weights from ``f`` into ``net`` through the C ``load_weights``.

    Args:
        net: the NET to populate.
        f: source path as ``str``, ``bytes`` or ``os.PathLike``; text is
            encoded as UTF-8 for the C side.
    """
    path = os.fspath(f)
    if isinstance(path, str):
        path = path.encode('utf-8')
    load_weights_lib(net, path)
# Debug helper from the C library: MATRIX -> void.
print_matrix = lib.print_matrix
print_matrix.argtypes, print_matrix.restype = [MATRIX], None
def run_net_image(net, im):
    """Run ``net`` forward on a single image and return the output MATRIX.

    The image's float buffer is wrapped, not copied, as a 1-row MATRIX with
    ``w*h*c`` columns and ``shallow = 1``. ``im`` must therefore stay alive
    while the call runs.
    """
    # Positional order follows MATRIX._fields_: rows, cols, data, shallow.
    flat = MATRIX(1, im.w * im.h * im.c, im.data, 1)
    return forward_net(net, flat)
def make_net(layers):
    """Build a NET from a sized sequence of LAYER structs.

    The layers are copied into a freshly allocated ctypes ``LAYER`` array.
    Assigning that array to the ``layers`` pointer field makes ctypes keep a
    reference to it inside the NET object, so the storage outlives this call.
    """
    count = len(layers)
    net = NET()
    net.layers = (LAYER * count)(*layers)
    net.n = count
    return net
if __name__ == "__main__":
    # Smoke test: read a sample image (path is relative to the current working
    # directory) and write it back out via save_image under the name "hey".
    dog = load_image("data/dog.jpg")
    save_image(dog, "hey")