update lr in the beginning of each epoch
junyanz committed Jul 19, 2020
1 parent 80ba217 commit fd29199
Showing 3 changed files with 4 additions and 5 deletions.
data/image_folder.py: 3 changes (1 addition, 2 deletions)
@@ -44,8 +44,7 @@ def __init__(self, root, transform=None, return_paths=False,
         imgs = make_dataset(root)
         if len(imgs) == 0:
             raise(RuntimeError("Found 0 images in: " + root + "\n"
-                               "Supported image extensions are: " +
-                               ",".join(IMG_EXTENSIONS)))
+                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
 
         self.root = root
         self.imgs = imgs
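The image_folder.py change above only joins two source lines: Python concatenates adjacent string literals at parse time, so the error message that gets raised is identical before and after. A standalone sketch of that equivalence (the root path and IMG_EXTENSIONS values below are made-up placeholders, not the repo's):

```python
# Adjacent string literals are concatenated by the parser, so splitting the
# message across lines (old code) or keeping it on one line (new code)
# builds exactly the same string.  Placeholder values for illustration only.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png']   # stand-in list, not the repo's
root = '/tmp/empty_dir'                      # made-up path

old_style = ("Found 0 images in: " + root + "\n"
             "Supported image extensions are: " +
             ",".join(IMG_EXTENSIONS))
new_style = ("Found 0 images in: " + root + "\n"
             "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

assert old_style == new_style
print(new_style)
```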
models/base_model.py: 3 changes (2 additions, 1 deletion)
@@ -115,14 +115,15 @@ def get_image_paths(self):

     def update_learning_rate(self):
         """Update learning rates for all the networks; called at the end of every epoch"""
+        old_lr = self.optimizers[0].param_groups[0]['lr']
         for scheduler in self.schedulers:
             if self.opt.lr_policy == 'plateau':
                 scheduler.step(self.metric)
             else:
                 scheduler.step()
 
         lr = self.optimizers[0].param_groups[0]['lr']
-        print('learning rate = %.7f' % lr)
+        print('learning rate %.7f -> %.7f' % (old_lr, lr))
 
     def get_current_visuals(self):
         """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
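For context on why update_learning_rate branches on lr_policy: ReduceLROnPlateau.step() expects the tracked metric, while the other scheduler types are stepped without arguments. Below is a minimal sketch with a toy optimizer and made-up loss values (the factor, patience, and fake_losses settings are illustrative, not the repo's configuration):

```python
import torch

# Toy parameter/optimizer pair standing in for the networks managed by BaseModel.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)

# 'plateau' policy: the scheduler only decays when the metric stops improving,
# which is why it has to be stepped with a value (scheduler.step(self.metric)).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1)

fake_losses = [1.0, 0.9, 0.9, 0.9, 0.9]       # made-up metric that stalls
for epoch, metric in enumerate(fake_losses, start=1):
    old_lr = optimizer.param_groups[0]['lr']  # read the lr before stepping
    scheduler.step(metric)                    # pass the metric, as the 'plateau' branch does
    lr = optimizer.param_groups[0]['lr']      # read the (possibly reduced) lr
    print('epoch %d: learning rate %.7f -> %.7f' % (epoch, old_lr, lr))
```

Non-plateau policies (LambdaLR, StepLR, and similar) are stepped as scheduler.step() with no argument, matching the else branch.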
train.py: 3 changes (1 addition, 2 deletions)
@@ -40,7 +40,7 @@
         iter_data_time = time.time()    # timer for data loading per iteration
         epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
         visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch
-
+        model.update_learning_rate()    # update learning rates in the beginning of every epoch.
         for i, data in enumerate(dataset):  # inner loop within one epoch
             iter_start_time = time.time()  # timer for computation per iteration
             if total_iters % opt.print_freq == 0:
@@ -75,4 +75,3 @@
             model.save_networks(epoch)
 
         print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
-        model.update_learning_rate()                     # update learning rates at the end of every epoch.
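A self-contained sketch of the new call order in the epoch loop: the scheduler is stepped, and the lr transition printed, before the epoch's iterations run. The tiny model, fake data, and the 1/(epoch+1) schedule are placeholders rather than the repo's setup, and with this ordering PyTorch may emit its "lr_scheduler.step() before optimizer.step()" warning once, on the very first epoch:

```python
import torch
import torch.nn.functional as F

# Placeholder model, data, and schedule (not the repo's); the point is only the
# call order: the learning rate is updated at the beginning of every epoch.
net = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1.0 / (e + 1))

data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]  # fake "dataset"
n_epochs = 5

for epoch in range(1, n_epochs + 1):
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step()                                    # update lr at the beginning of the epoch
    lr = optimizer.param_groups[0]['lr']
    print('learning rate %.7f -> %.7f' % (old_lr, lr))
    for x, y in data:                                   # inner loop within one epoch
        optimizer.zero_grad()
        loss = F.mse_loss(net(x), y)
        loss.backward()
        optimizer.step()
    print('End of epoch %d / %d' % (epoch, n_epochs))
```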
