
Commit 7369328

李闯 authored and committed
add tf_classify_demo
1 parent 8e6d6aa commit 7369328

File tree

4 files changed: +432, -0 lines changed

tf_classify_demo/classify.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# coding=utf8
"""
Train a book-classification model with TensorFlow.
"""

import sys
import tensorflow as tf
from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse as parse
from sample_data import InputData

# Load the samples and build a softmax-regression model (no hidden layer).
samples = InputData.read_data_sets('./data/sample/samples')
config = tf.ConfigProto(device_count={'CPU': 4})
sess = tf.InteractiveSession(config=config)
feature_len = samples.dim_info.x_dim
x = tf.placeholder(tf.float32, [None, feature_len])
W = tf.Variable(tf.zeros([feature_len, samples.maps.group_id_size()]))
b = tf.Variable(tf.zeros([samples.maps.group_id_size()]))
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, samples.maps.group_id_size()])
# Cross-entropy loss; y is clipped to avoid log(0).
cross_entropy = tf.reduce_mean(
    -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cross_entropy)
tf.global_variables_initializer().run()
saver = tf.train.Saver()


def train(samples, sess, x, y, y_, train_step):
    """
    Simple classification model: softmax regression with no hidden layer.
    """
    samples.clear_word_vector()
    test_xs, test_ys = samples.test_sets()

    for i in range(10000):
        batch_xs, batch_ys = samples.next_batch(1)
        train_step.run({x: batch_xs, y_: batch_ys})

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval({x: test_xs, y_: test_ys}))
    saver.save(sess, 'data/model/model')


def predict(samples, sess, x, y, y_, train_step):
    """Predict the category of a single example title."""
    x_s = samples.generate_xs('数据科学入门')
    print(sess.run(tf.argmax(y, 1), feed_dict={x: x_s}))


class MyServer(BaseHTTPRequestHandler):
    """HTTP handler: GET /?q=<title> returns the predicted group id."""

    def do_GET(self):
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        arg_dict = parse.urlparse(self.path)
        if len(arg_dict.query) > 0 and 'q' in parse.parse_qs(arg_dict.query):
            q = parse.parse_qs(arg_dict.query)['q'][0]
            x_s = samples.generate_xs(q)
            local_group_id = sess.run(tf.argmax(y, 1), feed_dict={x: x_s})[0]
            group_id = samples.maps.real_group_id_map[str(local_group_id)]
            print("q=", q, "group_id=", group_id)
            self.wfile.write(bytes(str(group_id), "utf-8"))


def main(is_predict):
    if is_predict:
        # Restore the trained model and serve predictions on port 5001.
        saver.restore(sess, 'data/model/model')
        #predict(samples, sess, x, y, y_, train_step)
        myServer = HTTPServer(("0.0.0.0", 5001), MyServer)
        print("begin listen")
        myServer.serve_forever()
    else:
        train(samples, sess, x, y, y_, train_step)


if __name__ == '__main__':
    # "python classify.py train" trains and saves a model; any other invocation serves predictions.
    is_predict = True
    if len(sys.argv) > 1 and sys.argv[1] == "train":
        is_predict = False
    main(is_predict)
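
A quick way to exercise the script above (a minimal sketch: port 5001, the 'q' query parameter, and the "train" argument come from classify.py, while the 127.0.0.1 host and the example title are assumptions for illustration). First run `python classify.py train` to train and save a model, then run `python classify.py` to start the prediction server and query it from Python:

import urllib.parse
import urllib.request

# Build a GET request against the prediction endpoint started by classify.py.
query = urllib.parse.urlencode({'q': '数据科学入门'})   # example book title (assumed)
url = 'http://127.0.0.1:5001/?' + query                 # assumed local host; port 5001 from classify.py
with urllib.request.urlopen(url) as resp:
    print(resp.read().decode('utf-8'))                  # prints the predicted group id (1-5)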

tf_classify_demo/data/sample/samples

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
5 0-3岁孩子的正面管教
2 11处特工皇妃
5 50个教育法:我把三个儿子送入了斯坦福
1 Excel这么用就对了
1 JavaScript高级程序设计
1 PPT,要你好看
1 Python编程 从入门到实践
1 TensorFlow实战亿级流量网站架构核心技术
2 三生三世枕上书
4 你从未真正拼过
4 做人要稳,做事要狠
3 公司理财
3 古老东方投资术的现代指南
2 和你在一起才是全世界2:么么哒
3 国富论
5 好妈妈胜过好老师
5 如何培养出优秀的孩子
5 如何说孩子才会听
5 孩子:挑战
2 守夜者:罪案终结者的觉醒
2 官路十八弯4
4 小强升职记:时间管理故事书
5 当我遇见一个人:母婴关系决定孩子的一切关系
3 彼得•林奇的成功投资
2 意外事故 (少年绘明星系列丛书)
5 捕捉儿童敏感期
1 数学之美
2 朱元璋传
1 机器学习
1 机器学习实战
5 正面管教
2 步履不停
1 深入浅出数据分析
3 澄明之境:青泽谈投资之道
5 真正的蒙氏教育在家庭:蒙台梭利家庭教育解决方案
3 穷爸爸富爸爸
4 职场加分项:成为卓有成效的职业人
5 聪明的妈妈教方法
3 聪明的投资者
3 股票大作手回忆录
3 证券分析
2 这世界偷偷爱着你
4 这些道理没有人告诉过你
4 这就是我背叛自己的方式
5 这样跟孩子定规矩,孩子最不会抵触
5 遇见孩子,遇见更好的自己
1 马云:未来已来
1 鸟哥的Linux私房菜
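
Each line of this sample file is "<group id> <book title>": the leading digit (1-5, the keys of Maps.local_group_id_map in sample_data.py below) is the category label and the rest of the line is the title. A minimal sketch of how one line splits, using a line from the file above:

line = '1 数学之美'
group_id, title = line.split(' ', 1)   # -> ('1', '数学之美')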

tf_classify_demo/sample_data.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# coding=utf8
"""
Sample loading.
"""

import sys
import random
import jieba
import numpy as np
from word_vectors_loader import get_words_sizes, load_vectors

VECTORS_BIN = 'data/wordvec/vectors.bin'
TEST_COUNT = 5


class DimInfo(object):
    """
    Dimension information.
    """

    def __init__(self):
        # dimensionality of a single word vector
        self.vec_dim = 0
        # dimensionality of a sample's input x
        self.x_dim = 0
        # number of local word ids assigned so far (also the next id to hand out)
        self.max_word_id = 0

    def get_vec_dim(self):
        """get_vec_dim"""
        return self.vec_dim

    def get_x_dim(self):
        """get_x_dim"""
        return self.x_dim


class Maps(object):
    """
    Lookup maps between real ids and local (zero-based) ids.
    """

    def __init__(self):
        self.local_word_id_map = {}
        self.local_group_id_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
        self.real_group_id_map = {}
        for key in self.local_group_id_map:
            value = str(self.local_group_id_map[key])
            self.real_group_id_map[value] = int(key)

    def get_local_word_id_map(self):
        """get_local_word_id_map"""
        return self.local_word_id_map

    def get_local_group_id_map(self):
        """get_local_group_id_map"""
        return self.local_group_id_map

    def group_id_size(self):
        """Number of local group ids."""
        return len(self.local_group_id_map)


class InputData(object):
    """
    Sample loader.
    """

    def __init__(self):
        self.data = []
        self.test_data = []
        self.pos = 0
        self.word_vector_dict, self.word_id_dict = load_vectors(VECTORS_BIN)
        self.dim_info = DimInfo()
        self.maps = Maps()
        _, self.dim_info.vec_dim = get_words_sizes(VECTORS_BIN)
        self.dim_info.x_dim = len(self.word_vector_dict) * self.dim_info.vec_dim
        self.maps.local_word_id_map = {}

    def clear_word_vector(self):
        """
        Free some memory.
        """
        self.word_vector_dict.clear()
        self.word_id_dict.clear()

    @staticmethod
    def read_data_sets(file_name):
        """
        Read the sample file and load the data.
        """
        instance = InputData()
        file_object = open(file_name, 'r')
        while True:
            line = file_object.readline(1024)
            if line:
                line = line.strip()
                if len(line) == 0:
                    continue
                # each line is "<group_id> <book title>"
                split = line.split(' ')
                group_id = 0
                try:
                    group_id = int(split[0])
                except ValueError:
                    continue
                txt = ' '.join(split[1:])
                txt = txt.replace('None', '').strip()
                if len(txt) == 0:
                    continue

                vectors = {}
                seg_list = jieba.cut(txt)
                for seg in seg_list:
                    # the word2vec vocabulary is keyed by UTF-8 encoded words
                    seg_unicode = seg.encode('utf-8')
                    if seg_unicode in instance.word_vector_dict:
                        word_id = instance.word_id_dict[seg_unicode]
                        if word_id in instance.maps.local_word_id_map:
                            local_word_id = instance.maps.local_word_id_map[word_id]
                            vectors[local_word_id] = instance.word_vector_dict[seg_unicode]
                        else:
                            # assign the next local word id to this word
                            local_word_id = instance.dim_info.max_word_id
                            instance.maps.local_word_id_map[word_id] = local_word_id
                            vectors[local_word_id] = instance.word_vector_dict[seg_unicode]
                            instance.dim_info.max_word_id = instance.dim_info.max_word_id + 1

                # sparse vector representation of the title
                item = {'vectors': vectors,
                        'local_group_id': instance.maps.local_group_id_map[str(group_id)]}
                instance.data.append(item)
            else:
                break
        file_object.close()

        random.shuffle(instance.data)
        for _ in range(TEST_COUNT):
            instance.test_data.append(instance.data.pop())
        instance.dim_info.x_dim = instance.dim_info.max_word_id * instance.dim_info.vec_dim
        print("max_word_id=", instance.dim_info.max_word_id)
        print("x_dim=", instance.dim_info.x_dim)
        return instance

    def generate_xs(self, txt):
        """
        Generate the input vector from a piece of text.
        """
        x_s = []
        vectors = {}
        seg_list = jieba.cut(txt)
        for seg in seg_list:
            seg_unicode = seg.encode('utf-8')
            if seg_unicode in self.word_vector_dict:
                word_id = self.word_id_dict[seg_unicode]
                if word_id in self.maps.local_word_id_map:
                    local_word_id = self.maps.local_word_id_map[word_id]
                    vectors[local_word_id] = self.word_vector_dict[seg_unicode]

        # expand the sparse word vectors into one dense feature row
        x_array = np.zeros([self.dim_info.x_dim], dtype=np.float32)
        for word_id in vectors:
            vector = vectors[word_id]
            for index, weight in enumerate(vector):
                x_array[word_id * self.dim_info.vec_dim + index] = weight
        x_s.append(x_array)
        return x_s

    def next_batch(self, count):
        """
        Fetch a batch of samples.
        """
        x_s = []
        y_s = []
        if self.pos >= len(self.data):
            print("error")
            sys.exit(1)
        while count > 0:
            item = self.data[self.pos]
            vectors = item['vectors']
            local_group_id = item['local_group_id']
            x_array = np.zeros([self.dim_info.x_dim], dtype=np.float32)
            y_array = np.zeros(self.maps.group_id_size(), dtype=np.float32)
            y_array[local_group_id] = 1
            for word_id in vectors:
                vector = vectors[word_id]
                for index, weight in enumerate(vector):
                    x_array[word_id * self.dim_info.vec_dim + index] = weight
            x_s.append(x_array)
            y_s.append(y_array)
            self.pos = (self.pos + 1) % len(self.data)
            count = count - 1
        return x_s, y_s

    def test_sets(self):
        """
        Return the held-out test set.
        """
        x_s = []
        y_s = []
        for item in self.test_data:
            vectors = item['vectors']
            local_group_id = item['local_group_id']
            x_array = np.zeros([self.dim_info.x_dim], dtype=np.float32)
            y_array = np.zeros(self.maps.group_id_size(), dtype=np.float32)
            y_array[local_group_id] = 1
            for word_id in vectors:
                vector = vectors[word_id]
                for index, weight in enumerate(vector):
                    x_array[word_id * self.dim_info.vec_dim + index] = weight
            x_s.append(x_array)
            y_s.append(y_array)
        return x_s, y_s


if __name__ == '__main__':
    CLUES = InputData.read_data_sets('./data/sample/samples')
    XS, YS = CLUES.next_batch(2)
    print(XS)
    print(XS[0].shape)
    print(YS)
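
sample_data.py imports get_words_sizes and load_vectors from word_vectors_loader, a module not shown in this excerpt. A hypothetical stand-in sketch of the interface it appears to expect, inferred only from how the functions are called above; the names exist in the commit, but the return shapes and the dummy values here are assumptions, not the real implementation:

# Hypothetical stand-in for word_vectors_loader -- interface only, dummy values.

def get_words_sizes(path):
    """Return (vocabulary_size, vector_dimension); both numbers here are placeholders."""
    return 2, 3

def load_vectors(path):
    """Return (word_vector_dict, word_id_dict), keyed by UTF-8 encoded words (bytes),
    as sample_data.py looks words up after seg.encode('utf-8')."""
    word_vector_dict = {'机器'.encode('utf-8'): [0.1, 0.2, 0.3],
                        '学习'.encode('utf-8'): [0.4, 0.5, 0.6]}
    word_id_dict = {'机器'.encode('utf-8'): 0, '学习'.encode('utf-8'): 1}
    return word_vector_dict, word_id_dict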

0 commit comments