Skip to content

Commit

Permalink
update resnet-small.py
Browse files Browse the repository at this point in the history
  • Loading branch information
shuokay committed Jan 14, 2016
1 parent 09260a2 commit 424bd24
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions resnet-small.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def get_dataiter(batch_size=128):
train_dataiter, test_dataiter = get_dataiter(batch_size=batch_size)
finetune=False
if finetune==False:
model = mx.model.FeedForward(ctx=mx.gpu(0), symbol=softmax, num_epoch=10000, learning_rate=0.1, momentum=0.9, wd=0.0001, \
model = mx.model.FeedForward(ctx=mx.gpu(0), symbol=softmax, num_epoch=70, learning_rate=0.1, momentum=0.9, wd=0.0001, \
initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
# initializer=mx.init.Xavier(),
# initializer=mx.init.Normal(),
lr_scheduler=mx.lr_scheduler.FactorScheduler(step =100000000000, factor = 0.95)
)
model.fit(X=train_dataiter, eval_data=test_dataiter, batch_end_callback=mx.callback.Speedometer(batch_size),epoch_end_callback=mx.callback.do_checkpoint("./models/resnet"))
else:
loaded = mx.model.FeedForward.load('models/resnet', 17)
loaded = mx.model.FeedForward.load('models/resnet', 70)
continue_model = mx.model.FeedForward(ctx=mx.gpu(0), symbol = loaded.symbol, arg_params = loaded.arg_params, aux_params = loaded.aux_params, num_epoch=10000, learning_rate=0.01, momentum=0.9, wd=0.0001)
continue_model.fit(X=train_dataiter, eval_data=test_dataiter, batch_end_callback=mx.callback.Speedometer(batch_size),epoch_end_callback=mx.callback.do_checkpoint("./models/resnet"))

0 comments on commit 424bd24

Please sign in to comment.