-
Notifications
You must be signed in to change notification settings - Fork 53
/
saveload.py
38 lines (25 loc) · 917 Bytes
/
saveload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pickle
from tensorflow import Session
import os
import tensorflow as tf
def main(save_path, sess):
    """Save or restore all trainable TensorFlow variables through a pickle file.

    If nothing exists at ``save_path``:
        Every variable in ``tf.trainable_variables()`` is evaluated in
        ``sess``. The values are pickled to ``save_path`` as a
        ``{variable.name: value}`` dict.
    Otherwise:
        That dict is loaded from ``save_path``. Each value is assigned back
        to the current trainable variable with the same name.

    Args:
        save_path: Path of the pickle file to create (save) or read (restore).
        sess: Session used to evaluate / assign the variables.

    Returns:
        None.

    Raises:
        KeyError: When restoring, if the file names a variable that is not
            among the current trainable variables.
    """
    if not os.path.exists(save_path):
        variables = tf.trainable_variables()
        # Evaluate before creating the file. If sess.run fails, no empty
        # pickle is left behind to be mistaken for a checkpoint on the next
        # call.
        values = sess.run(variables)
        with open(save_path, "wb") as file:
            pickle.dump({var.name: val for var, val in zip(variables, values)}, file)
    else:
        # NOTE(review): pickle.load can execute arbitrary code. Only restore
        # weight files that come from a trusted source.
        with open(save_path, "rb") as file:
            saved = pickle.load(file)
        v_dic = {v.name: v for v in tf.trainable_variables()}
        for key, value in saved.items():
            sess.run(tf.assign(v_dic[key], value))
def load_np(save_path):
    """Load a pickled dict of weights, print its keys, and return it.

    This is intended for files such as the one ``main`` writes, i.e. a
    ``{variable_name: value}`` dict. Each key is printed as ``Key name: <key>``
    (keys must be ``str``).

    Args:
        save_path: Path of the pickle file to read.

    Returns:
        The unpickled dict.

    Raises:
        FileNotFoundError: If nothing exists at ``save_path``. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    if not os.path.exists(save_path):
        raise FileNotFoundError("No saved weights at that location")
    # Must be opened read-only: mode "wb" would truncate the saved weights
    # and leave a handle pickle cannot read from.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(save_path, "rb") as file:
        v_dict = pickle.load(file)
    for key in v_dict:
        print("Key name: " + key)
    return v_dict
if __name__ == '__main__':
    # Script entry: python saveload.py SAVE_PATH
    from sys import argv, exit
    if len(argv) != 2:
        exit("usage: python saveload.py SAVE_PATH")
    # main() needs a path string and a live session. Passing argv itself, as
    # before, always raised TypeError.
    # NOTE(review): this script builds no model graph itself, so the default
    # graph is expected to have no trainable variables here. Saving would
    # write an empty dict. Callers that build a model should import main()
    # instead.
    with Session() as sess:
        exit(main(argv[1], sess))