Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wangchen1ren committed Jan 28, 2017
1 parent 847f399 commit 58cd0a9
Show file tree
Hide file tree
Showing 6 changed files with 608 additions and 15 deletions.
6 changes: 3 additions & 3 deletions 01_tensorflow_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
y = tf.placeholder(tf.float32, shape=[None, 2])

# weights and bias are the variables to be trained
weights = tf.Variable(tf.random_normal([6, 2]))
bias = tf.Variable(tf.zeros([2]))
weights = tf.Variable(tf.random_normal([6, 2]), name='weights')
bias = tf.Variable(tf.zeros([2]), name='bias')
y_pred = tf.nn.softmax(tf.matmul(X, weights) + bias)

# Minimise cost using cross entropy
Expand All @@ -61,7 +61,7 @@
# use session to run the calculation
with tf.Session() as sess:
# variables have to be initialized at the first place
tf.initialize_all_variables().run()
tf.global_variables_initializer().run()

# training loop
for epoch in range(10):
Expand Down
15 changes: 7 additions & 8 deletions 02_tensorflow_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
y_pred = tf.nn.softmax(tf.matmul(X, weights) + bias)

# add histogram summaries for weights, view on tensorboard
tf.histogram_summary('weights', weights)
tf.histogram_summary('bias', bias)
tf.summary.histogram('weights', weights)
tf.summary.histogram('bias', bias)

# Minimise cost using cross entropy
# NOTE: add a epsilon(1e-10) when calculate log(y_pred),
Expand All @@ -60,7 +60,7 @@
cross_entropy = - tf.reduce_sum(y_true * tf.log(y_pred + 1e-10),
reduction_indices=1)
cost = tf.reduce_mean(cross_entropy)
tf.scalar_summary('loss', cost)
tf.summary.scalar('loss', cost)

# use gradient descent optimizer to minimize cost
train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
Expand All @@ -69,7 +69,7 @@
correct_pred = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_pred, 1))
acc_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Add scalar summary for accuracy
tf.scalar_summary('accuracy', acc_op)
tf.summary.scalar('accuracy', acc_op)

global_step = tf.Variable(0, name='global_step', trainable=False)
# use saver to save and restore model
Expand All @@ -89,11 +89,11 @@
# use session to run the calculation
with tf.Session() as sess:
# create a log writer. run 'tensorboard --logdir=./logs'
writer = tf.train.SummaryWriter('./logs', sess.graph)
merged = tf.merge_all_summaries()
writer = tf.summary.FileWriter('./logs', sess.graph)
merged = tf.summary.merge_all()

# variables have to be initialized at the first place
tf.initialize_all_variables().run()
tf.global_variables_initializer().run()

# restore variables from checkpoint if exists
ckpt = tf.train.get_checkpoint_state(ckpt_dir)
Expand Down Expand Up @@ -125,7 +125,6 @@
global_step.assign(epoch).eval()
saver.save(sess, ckpt_dir + '/logistic.ckpt',
global_step=global_step)

print('Training complete!')

################################
Expand Down
37 changes: 37 additions & 0 deletions csv_to_tfrecords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#! -*- coding:utf-8 -*-

import pandas as pd
import tensorflow as tf


# convert train.csv to train.tfrecords
def transform_to_tfrecord():
data = pd.read_csv('data/train.csv')
tfrecord_file = 'train.tfrecords'

def int_feature(value):
return tf.train.Feature(
int64_list=tf.train.Int64List(value=[value]))

def float_feature(value):
return tf.train.Feature(
float_list=tf.train.FloatList(value=[value]))

writer = tf.python_io.TFRecordWriter(tfrecord_file)
for i in range(len(data)):
features = tf.train.Features(feature={
'Age': float_feature(data['Age'][i]),
'Survived': int_feature(data['Survived'][i]),
'Pclass': int_feature(data['Pclass'][i]),
'Parch': int_feature(data['Parch'][i]),
'SibSp': int_feature(data['SibSp'][i]),
'Sex': int_feature(1 if data['Sex'][i] == 'male' else 0),
'Fare': float_feature(data['Fare'][i])
})
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
writer.close()


if __name__ == '__main__':
transform_to_tfrecord()
Loading

0 comments on commit 58cd0a9

Please sign in to comment.