Skip to content

Commit

Permalink
Update cnn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wepe committed Jun 6, 2016
1 parent 1ead15f commit 970e34e
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions DeepLearning Tutorials/keras_usage/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cnn.py
CPU run command:
python cnn.py
2016.06.06更新:
这份代码是keras开发初期写的,当时keras还没有现在这么流行,文档也还没那么丰富,所以我当时写了一些简单的教程。
现在keras的API也发生了一些的变化,建议及推荐直接上keras.io看更加详细的教程。
'''
#导入各种用到的模块组件
from __future__ import absolute_import
Expand All @@ -19,6 +24,9 @@
from six.moves import range
from data import load_data
import random
import numpy as np

np.random.seed(1024) # for reproducibility



Expand Down Expand Up @@ -46,7 +54,7 @@
#border_mode可以是valid或者full,具体看这里说明:http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv.conv2d
#激活函数用tanh
#你还可以在model.add(Activation('tanh'))后加上dropout的技巧: model.add(Dropout(0.5))
model.add(Convolution2D(4, 5, 5, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Convolution2D(4, 5, 5, border_mode='valid',input_shape=(1,28,28)))
model.add(Activation('tanh'))


Expand Down Expand Up @@ -82,14 +90,14 @@
##############
#使用SGD + momentum
#model.compile里的参数loss就是损失函数(目标函数)
sgd = SGD(l2=0.0,lr=0.05, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")
sgd = SGD(lr=0.05, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)


#调用fit方法,就是一个训练过程. 训练的epoch数设为10,batch_size为100.
#数据经过随机打乱shuffle=True。verbose=1,训练过程中输出的信息,0、1、2三种方式都可以,无关紧要。show_accuracy=True,训练时每一个epoch都输出accuracy。
#validation_split=0.2,将20%的数据作为验证集。
model.fit(data, label, batch_size=100, nb_epoch=10,shuffle=True,verbose=1,show_accuracy=True,validation_split=0.2)
model.fit(data, label, batch_size=100, nb_epoch=10,shuffle=True,verbose=1,validation_split=0.2)


"""
Expand Down Expand Up @@ -123,4 +131,3 @@
progbar.add(X_batch.shape[0], values=[("train loss", loss),("accuracy:", accuracy)] )
"""

0 comments on commit 970e34e

Please sign in to comment.