Skip to content

Commit

Permalink
Update helloworld_SAC_TD3_single_file.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonv1943 committed Apr 10, 2024
1 parent 87cf325 commit 31d791e
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions helloworld/helloworld_SAC_TD3_single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def train_sac_td3_for_pendulum():


def train_sac_td3_for_lunar_lander():
agent_class = [AgentSAC, AgentTD3][1] # DRL algorithm name
agent_class = [AgentSAC, AgentTD3][0] # DRL algorithm name
env_class = gym.make
env_args = {
'env_name': 'LunarLanderContinuous-v2', # A lander learns to land on a landing pad
Expand All @@ -618,39 +618,34 @@ def train_sac_td3_for_lunar_lander():
'if_discrete': False # continuous action space, symbols → direction, value → force
}
get_gym_env_args(env=gym.make('LunarLanderContinuous-v2'), if_print=True) # return env_args
# pip install box2d box2d-kengz --user
# https://stackoverflow.com/questions/50037674/attributeerror-module-box2d-has-no-attribute-rand-limit-swigconstant

args = Config(agent_class, env_class, env_args) # see `config.py Arguments()` for hyperparameter explanation
args.break_step = int(8e4) # break training if 'total_step > break_step'
args.net_dims = (128, 128) # the middle layer dimension of MultiLayer Perceptron
args.horizon_len = 128 # collect horizon_len step while exploring, then update network
args.repeat_times = 1.0 # repeatedly update network using ReplayBuffer to keep critic's loss small
args.state_value_tau = 0.1 # todo
args.state_value_tau = 0.01 # todo
# args.state_value_tau = 0.001 # todo
# args.state_value_tau = 0.000 # todo
# todo YonV1943 2022-10-31 15:34:34 something wrong with the state_std and value_std !!!!!!!!!!
args.state_value_tau = 0.1 # do rolling normalization on state using soft update tau

args.gpu_id = GPU_ID
args.random_seed = GPU_ID
train_agent(args)
"""
cumulative returns range: -1500 < -140 < 200 < 280
SAC
SAC on CPU
| step time | avgR stdR avgS | objC objA
| 1.01e+04 88 | 19.53 148.64 362 | 1.93 23.59
| 2.02e+04 294 | -60.15 120.83 805 | 2.59 60.84
| 3.03e+04 617 | -50.82 46.35 965 | 3.53 104.68
| 4.04e+04 1051 | -55.18 22.74 972 | 2.58 90.86
| 5.06e+04 1560 | 172.70 84.48 664 | 2.06 66.80
| 6.07e+04 2175 | 211.03 90.33 511 | 2.07 55.08
TD3
"""


if __name__ == '__main__':
GPU_ID = int(sys.argv[1]) # todo
GPU_ID = int(sys.argv[1]) if len(sys.argv) > 1 else -1
# train_sac_td3_for_pendulum()
train_sac_td3_for_lunar_lander()

0 comments on commit 31d791e

Please sign in to comment.