diff --git a/data/image_folder.py b/data/image_folder.py index d0b4b308a82..80c0b9a2d93 100644 --- a/data/image_folder.py +++ b/data/image_folder.py @@ -44,8 +44,7 @@ def __init__(self, root, transform=None, return_paths=False, imgs = make_dataset(root) if len(imgs) == 0: raise(RuntimeError("Found 0 images in: " + root + "\n" - "Supported image extensions are: " + - ",".join(IMG_EXTENSIONS))) + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs diff --git a/models/base_model.py b/models/base_model.py index 9cfb761897e..6de961b51a2 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -115,6 +115,7 @@ def get_image_paths(self): def update_learning_rate(self): """Update learning rates for all the networks; called at the end of every epoch""" + old_lr = self.optimizers[0].param_groups[0]['lr'] for scheduler in self.schedulers: if self.opt.lr_policy == 'plateau': scheduler.step(self.metric) @@ -122,7 +123,7 @@ def update_learning_rate(self): scheduler.step() lr = self.optimizers[0].param_groups[0]['lr'] - print('learning rate = %.7f' % lr) + print('learning rate %.7f -> %.7f' % (old_lr, lr)) def get_current_visuals(self): """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" diff --git a/train.py b/train.py index 9ff8b3d3f15..2852652df82 100644 --- a/train.py +++ b/train.py @@ -40,7 +40,7 @@ iter_data_time = time.time() # timer for data loading per iteration epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch - + model.update_learning_rate() # update learning rates at the beginning of every epoch. 
for i, data in enumerate(dataset): # inner loop within one epoch iter_start_time = time.time() # timer for computation per iteration if total_iters % opt.print_freq == 0: @@ -75,4 +75,3 @@ model.save_networks(epoch) print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) - model.update_learning_rate() # update learning rates at the end of every epoch.