Skip to content

Commit

Permalink
clean up training loop params
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhughes27 committed Dec 28, 2016
1 parent 9b7dc7d commit 5d80b20
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
26 changes: 16 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@

sess.run(tf.global_variables_initializer())

# Train the Model
for i in range(20000):
batch = data.next_batch(100)
# Training loop variables
epochs = 30
batch_size = 100
num_samples = data.num_examples
step_size = int(num_samples / batch_size)

if i%100 == 0:
train_accuracy = loss.eval(feed_dict={model.x:batch[0], model.y_: batch[1], model.keep_prob: 1.0})
print("step: %d loss: %g"%(i, train_accuracy))
for epoch in range(epochs):
for i in range(step_size):
batch = data.next_batch(100)

train_step.run(feed_dict={model.x: batch[0], model.y_: batch[1], model.keep_prob: 0.5})
train_step.run(feed_dict={model.x: batch[0], model.y_: batch[1], model.keep_prob: 0.8})

# Save the Model
saver = tf.train.Saver()
saver.save(sess, "model.ckpt")
if i%10 == 0:
loss_value = loss.eval(feed_dict={model.x:batch[0], model.y_: batch[1], model.keep_prob: 1.0})
print("epoch: %d step: %d loss: %g"%(epoch, epoch * batch_size + i, loss_value))

saver = tf.train.Saver()
saver.save(sess, "model.ckpt")
print("model saved")
4 changes: 4 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def __init__(self):
self._index_in_epoch = 0
self._num_examples = self._X.shape[0]

@property
def num_examples(self):
return self._num_examples

def next_batch(self, batch_size):
start = self._index_in_epoch
self._index_in_epoch += batch_size
Expand Down

0 comments on commit 5d80b20

Please sign in to comment.