Skip to content

Commit

Permalink
fix issue hanbt#2
Browse files Browse the repository at this point in the history
  • Loading branch information
hanbingtao committed Aug 30, 2017
1 parent eff712e commit f8c5253
Showing 1 changed file with 178 additions and 0 deletions.
178 changes: 178 additions & 0 deletions mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import struct
from fc import *
from datetime import datetime


# 数据加载器基类
class Loader(object):
def __init__(self, path, count):
'''
初始化加载器
path: 数据文件路径
count: 文件中的样本个数
'''
self.path = path
self.count = count

def get_file_content(self):
'''
读取文件内容
'''
f = open(self.path, 'rb')
content = f.read()
f.close()
return content

def to_int(self, byte):
'''
将unsigned byte字符转换为整数
'''
return struct.unpack('B', byte)[0]


# 图像数据加载器
class ImageLoader(Loader):
def get_picture(self, content, index):
'''
内部函数,从文件中获取图像
'''
start = index * 28 * 28 + 16
picture = []
for i in range(28):
picture.append([])
for j in range(28):
picture[i].append(
self.to_int(content[start + i * 28 + j]))
return picture

def get_one_sample(self, picture):
'''
内部函数,将图像转化为样本的输入向量
'''
sample = []
for i in range(28):
for j in range(28):
sample.append(picture[i][j])
return sample

def load(self):
'''
加载数据文件,获得全部样本的输入向量
'''
content = self.get_file_content()
data_set = []
for index in range(self.count):
data_set.append(
self.get_one_sample(
self.get_picture(content, index)))
return data_set


# 标签数据加载器
class LabelLoader(Loader):
def load(self):
'''
加载数据文件,获得全部样本的标签向量
'''
content = self.get_file_content()
labels = []
for index in range(self.count):
labels.append(self.norm(content[index + 8]))
return labels

def norm(self, label):
'''
内部函数,将一个值转换为10维标签向量
'''
label_vec = []
label_value = self.to_int(label)
for i in range(10):
if i == label_value:
label_vec.append(0.9)
else:
label_vec.append(0.1)
return label_vec


def get_training_data_set():
'''
获得训练数据集
'''
image_loader = ImageLoader('train-images-idx3-ubyte', 60000)
label_loader = LabelLoader('train-labels-idx1-ubyte', 60000)
return image_loader.load(), label_loader.load()


def get_test_data_set():
'''
获得测试数据集
'''
image_loader = ImageLoader('t10k-images-idx3-ubyte', 10000)
label_loader = LabelLoader('t10k-labels-idx1-ubyte', 10000)
return image_loader.load(), label_loader.load()


def show(sample):
str = ''
for i in range(28):
for j in range(28):
if sample[i*28+j] != 0:
str += '*'
else:
str += ' '
str += '\n'
print str


def get_result(vec):
max_value_index = 0
max_value = 0
for i in range(len(vec)):
if vec[i] > max_value:
max_value = vec[i]
max_value_index = i
return max_value_index


def evaluate(network, test_data_set, test_labels):
error = 0
total = len(test_data_set)

for i in range(total):
label = get_result(test_labels[i])
predict = get_result(network.predict(test_data_set[i]))
if label != predict:
error += 1
return float(error) / float(total)


def now():
return datetime.now().strftime('%c')


def train_and_evaluate():
last_error_ratio = 1.0
epoch = 0
train_data_set, train_labels = transpose(get_training_data_set())
test_data_set, test_labels = transpose(get_test_data_set())
network = Network([784, 100, 10])
while True:
epoch += 1
network.train(train_labels, train_data_set, 0.01, 1)
print '%s epoch %d finished, loss %f' % (now(), epoch,
network.loss(train_labels[-1], network.predict(train_data_set[-1])))
if epoch % 2 == 0:
error_ratio = evaluate(network, test_data_set, test_labels)
print '%s after epoch %d, error ratio is %f' % (now(), epoch, error_ratio)
if error_ratio > last_error_ratio:
break
else:
last_error_ratio = error_ratio

if __name__ == '__main__':
train_and_evaluate()


0 comments on commit f8c5253

Please sign in to comment.