-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_LSGAN.py
79 lines (65 loc) · 2.85 KB
/
train_LSGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import math
import argparse
import tqdm
import tensorflow as tf
from utils import str_to_bool, get_config, find_config, check_dataset_config
from utils import allow_memory_growth, ImageLoader
from models import LSGAN
def main():
    """Train an LSGAN from a YAML config, optionally resuming a checkpoint.

    Command-line options:
        -mg / --memory_growth: parsed with ``str_to_bool`` (default True);
            when true, ``allow_memory_growth()`` is called first.
        -c / --config: path of the YAML config (default
            ``configs/LSGAN/lsun.yaml``). Ignored when ``--checkpoint`` is
            given; the config is then located with ``find_config``.
        -ckpt / --checkpoint: checkpoint handed to ``LSGAN``. ``None``
            (default) starts a fresh run, in which case
            ``model.copy_conf`` is called with the config path.

    The step and epoch budgets are reconciled first. If ``conf['epochs']`` is
    smaller than ``ceil(conf['steps'] / steps_per_epoch)``, the epoch count is
    raised. Otherwise ``conf['steps']`` is set to
    ``conf['epochs'] * steps_per_epoch``. Training stops once the global step
    reaches ``conf['steps']``.

    During training, ``model.test`` runs every ``conf['test_step']`` steps on a
    fixed latent batch, and ``model.save`` runs every ``conf['save_step']``
    steps. A final ``model.save`` is made if the last trained step was not
    already saved.

    Raises:
        ValueError: if the training dataset yields no batches.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-mg', '--memory_growth', type=str_to_bool,
                            default=True)
    arg_parser.add_argument('-c', '--config', type=str,
                            default='configs/LSGAN/lsun.yaml')
    arg_parser.add_argument('-ckpt', '--checkpoint', type=str,
                            default=None)
    args = vars(arg_parser.parse_args())

    tf.keras.backend.set_image_data_format('channels_first')
    if args['memory_growth']:
        allow_memory_growth()

    # When resuming, read the config that belongs to the checkpoint instead of
    # the --config argument.
    if args['checkpoint'] is not None:
        args['config'] = find_config(args['checkpoint'])
    conf = get_config(args['config'])
    check_dataset_config(conf)

    # Load dataset
    dataset_conf = conf['dataset']
    loader = ImageLoader(data_txt_file=dataset_conf['train_data_txt'])
    train_dataset = loader.get_dataset(batch_size=conf['batch_size'],
                                       new_size=(conf['input_size'],) * 2,
                                       cache=dataset_conf['cache'])
    # Created once and reused for every model.test call, so generated samples
    # are comparable across test steps.
    test_data = tf.random.normal(
        shape=(conf['test_batch_size'], conf['latent_dim']),
        seed=conf['random_seed'])

    # Model initiate
    model = LSGAN(conf, args['checkpoint'])
    if args['checkpoint'] is None:
        model.copy_conf(args['config'])

    # Start train
    # NOTE(review): assumes get_dataset returns a dataset with a known length
    # (e.g. a finite tf.data.Dataset); len() raises otherwise.
    steps_per_epoch = len(train_dataset)
    if steps_per_epoch == 0:
        raise ValueError('training dataset is empty: {}'.format(
            dataset_conf['train_data_txt']))

    epoch_by_step = math.ceil(conf['steps'] / steps_per_epoch)
    if conf['epochs'] < epoch_by_step:
        conf['epochs'] = epoch_by_step
    else:
        conf['steps'] = conf['epochs'] * steps_per_epoch

    test_step = conf['test_step']
    save_step = conf['save_step']
    end_step = conf['steps']

    # Convert the checkpoint step counter to a plain int so the epoch
    # arithmetic and range() below operate on Python integers.
    current_step = int(model.ckpt.step.numpy())
    start_epoch = current_step // steps_per_epoch + 1
    # Step of the most recent save. Initialised to the starting step so that a
    # run which trains nothing does not save again.
    last_saved_step = current_step

    # The context managers close both progress bars even when training stops
    # mid-epoch.
    with tqdm.trange(start_epoch, conf['epochs'] + 1,
                     position=0, leave=True) as pbar:
        for epoch in pbar:
            # Do not start another epoch once the step budget is used up.
            # This also covers resuming from a checkpoint already past end_step.
            if current_step >= end_step:
                break
            pbar.set_postfix({'Current Epoch': epoch})
            with tqdm.tqdm(train_dataset, leave=False) as sub_pbar:
                for image_batch in sub_pbar:
                    log_dict = model.train(image_batch)
                    current_step = int(model.ckpt.step.numpy())
                    if current_step % 10 == 0:
                        sub_pbar.set_postfix({
                            'Step': current_step,
                            'G': '{:.4f}'.format(log_dict['loss/gen']),
                            'D': '{:.4f}'.format(log_dict['loss/dis'])})
                    if test_step and current_step % test_step == 0:
                        model.test(test_data, current_step, save=True)
                    if save_step and current_step % save_step == 0:
                        model.save()
                        last_saved_step = current_step
                    # Use '>=' rather than '==' so a counter that is already
                    # past end_step still stops training.
                    if current_step >= end_step:
                        break

    # Persist the final weights unless the periodic save already did so.
    if current_step != last_saved_step:
        model.save()
if __name__ == '__main__':
    # Start training only when this file is executed as a script; importing
    # the module has no side effects beyond the definitions above.
    main()