Skip to content

Commit

Permalink
add images to tfevents
Browse files Browse the repository at this point in the history
  • Loading branch information
tehutahu committed Feb 20, 2022
1 parent da161cd commit b3648fc
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions training/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ def setup_snapshot_image_grid(training_set, random_seed=0):

#----------------------------------------------------------------------------

def image_covert(img, drange, num=8):
lo, hi = drange
img = np.asarray(img, dtype=np.float32)
img = (img - lo) * (255 / (hi - lo))
img = np.rint(img).clip(0, 255).astype(np.uint8)
return img[:num]

#----------------------------------------------------------------------------

def save_image_grid(img, fname, drange, grid_size):
lo, hi = drange
img = np.asarray(img, dtype=np.float32)
Expand Down Expand Up @@ -349,6 +358,7 @@ def training_loop(
print('Aborting...')

# Save image snapshot.
images = None
if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
Expand Down Expand Up @@ -407,6 +417,9 @@ def training_loop(
stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
for name, value in stats_metrics.items():
stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
if images is not None:
images = image_covert(images, [-1, 1], num=8)
stats_tfevents.add_images('fakes', images, global_step=global_step, walltime=walltime)
stats_tfevents.flush()
if progress_fn is not None:
progress_fn(cur_nimg // 1000, total_kimg)
Expand Down

0 comments on commit b3648fc

Please sign in to comment.