Skip to content

Commit

Permalink
zodb对象存储提高磁盘读写效率
Browse files Browse the repository at this point in the history
  • Loading branch information
SoCool1345 committed Jun 21, 2022
1 parent e326d9b commit dae03e2
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 53 deletions.
22 changes: 13 additions & 9 deletions UIplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from config import CONFIG



if CONFIG['use_frame'] == 'paddle':
from paddle_net import PolicyValueNet
elif CONFIG['use_frame'] == 'pytorch':
Expand All @@ -25,8 +26,10 @@ def get_action(self, move):
# move从鼠标点击事件触发
# print('当前是player2在操作')
# print(board.current_player_color)
move = move_action2move_id[move]

if move_action2move_id.__contains__(move):
move = move_action2move_id[move]
else:
move = -1
# move = random.choice(board.availables)
return move

Expand All @@ -37,7 +40,7 @@ def set_player_ind(self, p):
if CONFIG['use_frame'] == 'paddle':
policy_value_net = PolicyValueNet(model_file='current_policy.model')
elif CONFIG['use_frame'] == 'pytorch':
policy_value_net = PolicyValueNet(model_file='current_policy.pkl')
policy_value_net = PolicyValueNet(model_file='models/current_policy_batch1900.model')
else:
print('暂不支持您选择的框架')

Expand All @@ -57,7 +60,7 @@ def set_player_ind(self, p):
# 创建指定大小的窗口
screen = pygame.display.set_mode(size)
# 设置窗口标题
pygame.display.set_caption("一心炼银")
pygame.display.set_caption("中国象棋")

# 加载一个列表进行图像的绘制
# 列表表示的棋盘初始化,红子在上,黑子在下,禁止对该列表进行编辑,使用时必须使用深拷贝
Expand Down Expand Up @@ -199,7 +202,7 @@ def board2image(board):

player1 = MCTSPlayer(policy_value_net.policy_value_fn,
c_puct=5,
n_playout=800,
n_playout=1000,
is_selfplay=0)


Expand Down Expand Up @@ -271,10 +274,11 @@ def board2image(board):
swicth_player = False
if len(move_action) == 4:
move = player_in_turn.get_action(move_action) # 当前玩家代理拿到动作
board.do_move(move) # 棋盘做出改变
swicth_player = True
move_action = ''
draw_fire = False
if move != -1:
board.do_move(move) # 棋盘做出改变
swicth_player = True
move_action = ''
draw_fire = False

end, winner = board.game_end()
if end:
Expand Down
53 changes: 34 additions & 19 deletions collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from game import Board, Game, move_action2move_id, move_id2move_action, flip_map
from mcts import MCTSPlayer
from config import CONFIG
from my_zodb import MyZODB


if CONFIG['use_frame'] == 'paddle':
Expand All @@ -34,6 +35,7 @@ def __init__(self, init_model=None):
self.data_buffer = deque(maxlen=self.buffer_size)
self.iters = 0


# 从主体加载模型
def load_model(self):
if CONFIG['use_frame'] == 'paddle':
Expand Down Expand Up @@ -82,27 +84,40 @@ def collect_selfplay_data(self, n_games=1):
self.episode_len = len(play_data)
# 增加数据
play_data = self.get_equi_data(play_data)

if os.path.exists(CONFIG['train_data_buffer_path']):
while True:
while True:
try:
db = MyZODB()
self.iters = db.dump(db.getMaxIters()+1,play_data)
print("存储完成")
break
except:
print("存储失败")
time.sleep(1)
finally:
try:
with open(CONFIG['train_data_buffer_path'], 'rb') as data_dict:
data_file = pickle.load(data_dict)
self.data_buffer = data_file['data_buffer']
self.iters = data_file['iters']
del data_file
self.iters += 1
self.data_buffer.extend(play_data)
print('成功载入数据')
break
db.close()
except:
time.sleep(30)
else:
self.data_buffer.extend(play_data)
self.iters += 1
data_dict = {'data_buffer': self.data_buffer, 'iters': self.iters}
with open(CONFIG['train_data_buffer_path'], 'wb') as data_file:
pickle.dump(data_dict, data_file)
pass
# if os.path.exists(CONFIG['train_data_buffer_path']):
# while True:
# try:
# with open(CONFIG['train_data_buffer_path'], 'rb') as data_dict:
# data_file = pickle.load(data_dict)
# self.data_buffer = data_file['data_buffer']
# self.iters = data_file['iters']
# del data_file
# self.iters += 1
# self.data_buffer.extend(play_data)
# print('成功载入数据')
# break
# except:
# time.sleep(30)
# else:
# self.data_buffer.extend(play_data)
# self.iters += 1
# data_dict = {'data_buffer': self.data_buffer, 'iters': self.iters}
# with open(CONFIG['train_data_buffer_path'], 'wb') as data_file:
# pickle.dump(data_dict, data_file)
return self.iters

def run(self):
Expand Down
10 changes: 5 additions & 5 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
CONFIG = {
'kill_action': 60,
'dirichlet': 0.15, # 国际象棋,0.3;日本将棋,0.15;围棋,0.03
'play_out': 20, # 每次移动的模拟次数
'kill_action': 30, #和棋回合数
'dirichlet': 0.2, # 国际象棋,0.3;日本将棋,0.15;围棋,0.03
'play_out': 500, # 每次移动的模拟次数
'c_puct': 5, # u的权重
'buffer_size': 100000, # 经验池大小
'buffer_size': 10000, # 经验池大小
'paddle_model_path': 'current_policy.model', # paddle模型路径
'pytorch_model_path': 'current_policy.pkl', # pytorch模型路径
'train_data_buffer_path': 'train_data_buffer.pkl', # 数据容器的路径
'batch_size': 512, # 每次更新的train_step数量
'batch_size': 128, # 每次更新的train_step数量
'kl_targ': 0.02, # kl散度控制
'epochs' : 5, # 每次更新的train_step数量
'game_batch_num': 3000, # 训练更新的次数
Expand Down
90 changes: 90 additions & 0 deletions my_zodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import time

import ZODB, ZODB.FileStorage,ZODB.FileStorage.interfaces
import persistent,persistent.list
import transaction
from BTrees.IOBTree import IOBTree
from BTrees.OOBTree import OOBTree

from config import CONFIG


class MyZODB():
    """Thin wrapper around a ZODB FileStorage used as the self-play
    training-data buffer.

    Layout: ``root.data`` is an ``IOBTree`` mapping an integer iteration
    number to the list of play records collected in that iteration.
    """

    def __init__(self, data_path='data/train_data_buffer.fs'):
        self.storage = ZODB.FileStorage.FileStorage(data_path)
        # NOTE(review): large_record_size=1<<64 is effectively "no limit";
        # confirm this is intended rather than, say, 1<<26.
        self.db = ZODB.DB(self.storage, cache_size=1, large_record_size=1 << 64)
        self.connection = self.db.open()
        self.root = self.connection.root()
        self.gc_cnt = 0  # dump() calls since the last pack(); pack roughly every 20
        # Ensure root.data exists and is an IOBTree. The original loop spun
        # forever on a fresh database: reading the missing attribute raised
        # AttributeError on every pass; getattr() with a default fixes that.
        while True:
            try:
                if not isinstance(getattr(self.root, 'data', None), IOBTree):
                    self.root.data = IOBTree()
                    transaction.commit()
                break
            except Exception:
                # Likely a ZODB write conflict with a concurrent worker:
                # roll back and retry after a short pause.
                transaction.abort()
                time.sleep(5)
        self.data = self.root.data

    def close(self):
        """Release the connection, the database, then the storage."""
        self.connection.close()
        self.db.close()
        self.storage.close()

    def getMaxIters(self):
        """Return the largest stored iteration key, or 0 when empty."""
        return self.data.maxKey() if len(self.data) > 0 else 0

    def getMinIters(self):
        """Return the smallest stored iteration key, or 0 when empty."""
        return self.data.minKey() if len(self.data) > 0 else 0

    def __delitem__(self, key):
        # Delegate deletion to the underlying BTree.
        del self.data[key]

    def gc(self, buffer_size=CONFIG['buffer_size']):
        """Trim the buffer to about *buffer_size* iterations and
        periodically pack the FileStorage.

        Fixes vs. the original: ``self.data.__sizeof__()`` measured the
        BTree's in-memory byte footprint, not the number of entries, and
        the local name shadowed builtin ``len``; ``len(self.data)`` is the
        intended entry count.
        """
        count = len(self.data)
        if count > buffer_size:
            self.delAny(count // 10)  # drop the oldest ~10% of iterations
        if self.gc_cnt > 20:
            self.pack()
            self.gc_cnt = 0

    def delAny(self, num):
        """Delete the *num* oldest iterations, then commit once.

        Fixes vs. the original: ``self.db.getMinIters()`` looked the method
        up on the ZODB ``DB`` object (AttributeError) instead of this
        wrapper, and ``num`` could arrive as a float from ``gc``.
        """
        for _ in range(int(num)):
            if len(self.data) == 0:
                break  # nothing left to delete
            del self.data[self.getMinIters()]
        transaction.commit()

    def pack(self):
        """Compact the FileStorage, discarding non-current record revisions."""
        self.db.pack(time.time())
        transaction.commit()

    def load(self):
        """Return ``(max_iteration, flat_list_of_all_play_records)``.

        Uses ``getMaxIters()`` so an empty store yields ``(0, [])`` instead
        of ``ValueError`` from ``maxKey()`` on an empty BTree.
        """
        return self.getMaxIters(), [d for ds in self.data.values() for d in ds]

    def dump(self, iters, data_buffer):
        """Store *data_buffer* under the first free key >= *iters* and
        return the key actually used (several workers may race for keys)."""
        self.gc_cnt += 1
        key = iters
        while key in self.data:  # probe upward past keys taken by others
            key += 1
        self.data[key] = data_buffer
        transaction.commit()
        return key

class Book(persistent.Persistent):
    """A persistent book record: a title plus a growable author list."""

    def __init__(self, title):
        self.title = title
        self.authors = []

    def add_author(self, author):
        """Append *author* and mark this object dirty for ZODB."""
        self.authors.append(author)
        # Mutating a plain list is invisible to the persistence machinery,
        # so the object must be flagged as changed explicitly.
        self._p_changed = True

if __name__ == '__main__':
    # Ad-hoc maintenance script: open the training-data store, compact it,
    # then report its on-disk size and the highest stored iteration key.
    t = MyZODB('data/train_data_buffer.fs')
    t.pack()
    print(t.storage.getSize())
    print(t.getMaxIters())
    t.close()
52 changes: 32 additions & 20 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pickle
import time
from config import CONFIG

from my_zodb import MyZODB

if CONFIG['use_frame'] == 'paddle':
from paddle_net import PolicyValueNet
Expand All @@ -29,7 +29,6 @@ def __init__(self, init_model=None):
self.kl_targ = CONFIG['kl_targ'] # kl散度控制
self.check_freq = 100 # 保存模型的频率
self.game_batch_num = CONFIG['game_batch_num'] # 训练更新的次数

if init_model:
try:
self.policy_value_net = PolicyValueNet(model_file=init_model)
Expand All @@ -45,7 +44,7 @@ def __init__(self, init_model=None):
def policy_updata(self):
"""更新策略价值网络"""
mini_batch = random.sample(self.data_buffer, self.batch_size)

# print(mini_batch[0][1],mini_batch[1][1])
state_batch = [data[0] for data in mini_batch]
state_batch = np.array(state_batch).astype('float32')

Expand Down Expand Up @@ -79,7 +78,7 @@ def policy_updata(self):
self.lr_multiplier /= 1.5
elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
self.lr_multiplier *= 1.5

# print(old_v.flatten(),new_v.flatten())
explained_var_old = (1 -
np.var(np.array(winner_batch) - old_v.flatten()) /
np.var(np.array(winner_batch)))
Expand All @@ -91,8 +90,8 @@ def policy_updata(self):
"lr_multiplier:{:.3f},"
"loss:{},"
"entropy:{},"
"explained_var_old:{:.3f},"
"explained_var_new:{:.3f}"
"explained_var_old:{:.9f},"
"explained_var_new:{:.9f}"
).format(kl,
self.lr_multiplier,
loss,
Expand All @@ -106,27 +105,40 @@ def run(self):
try:
for i in range(self.game_batch_num):
time.sleep(30) # 每10分钟更新一次模型
# while True:
# try:
# with open(CONFIG['train_data_buffer_path'], 'rb') as data_dict:
# data_file = pickle.load(data_dict)
# self.data_buffer = data_file['data_buffer']
# self.iters = data_file['iters']
# del data_file
# print('已载入数据')
# break
# except:
# time.sleep(30)
while True:
try:
with open(CONFIG['train_data_buffer_path'], 'rb') as data_dict:
data_file = pickle.load(data_dict)
self.data_buffer = data_file['data_buffer']
self.iters = data_file['iters']
del data_file
print('已载入数据')
mydb = MyZODB()
mydb.gc()
self.iters,self.data_buffer = mydb.load()
break
except:
time.sleep(30)
time.sleep(5)
finally:
try:
mydb.close()
except:
pass
print('step i {}: '.format(self.iters))
if len(self.data_buffer) > self.batch_size:
loss, entropy = self.policy_updata()
# 保存模型
if CONFIG['use_frame'] == 'paddle':
self.policy_value_net.save_model(CONFIG['paddle_model_path'])
elif CONFIG['use_frame'] == 'pytorch':
self.policy_value_net.save_model(CONFIG['pytorch_model_path'])
else:
print('不支持所选框架')
# 保存模型
if CONFIG['use_frame'] == 'paddle':
self.policy_value_net.save_model(CONFIG['paddle_model_path'])
elif CONFIG['use_frame'] == 'pytorch':
self.policy_value_net.save_model(CONFIG['pytorch_model_path'])
else:
print('不支持所选框架')
if (i + 1) % self.check_freq == 0:
print('current selfplay batch: {}'.format(i + 1))
self.policy_value_net.save_model('models/current_policy_batch{}.model'.format(i + 1))
Expand Down

0 comments on commit dae03e2

Please sign in to comment.