Skip to content

Commit

Permalink
get best checkpoint available using mean reward
Browse files Browse the repository at this point in the history
  • Loading branch information
skourta committed Nov 14, 2023
1 parent 5a7ad25 commit 4464988
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
12 changes: 10 additions & 2 deletions rl_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@

Config.init()

# Ensure the output directory exists (creating parents as needed).
output_path = Path(args.output_path)
output_path.mkdir(parents=True, exist_ok=True)

# Empty the output directory so stale results from a previous run
# don't mix with this evaluation's output. Only unlink regular files:
# Path.unlink() raises on directories, so any subdirectory would
# otherwise crash the script here.
for stale_entry in output_path.iterdir():
    if stale_entry.is_file():
        stale_entry.unlink()

# read the ip and port from the server_address file
ip_and_port = ""
while ip_and_port == "":
Expand All @@ -56,7 +64,8 @@
# -1 is the "auto" sentinel: fall back to every CPU Ray can see.
if num_workers == -1:
    num_workers = int(ray.available_resources()["CPU"])

print(f"num workers: {num_workers}")
print(f"dataset size: {dataset_size}")

# Split the dataset evenly across the workers; the remainder is
# handed out separately so every program gets assigned exactly once.
num_programs_per_task, programs_remaining = divmod(dataset_size, num_workers)
Expand Down Expand Up @@ -93,7 +102,6 @@

explorations.append(benchmark_actor.explore_benchmarks.remote())

print(len(explorations))
while len(explorations) > 0:
# Wait for actors to finish their exploration
done, explorations = ray.wait(explorations)
Expand Down
16 changes: 15 additions & 1 deletion rllib_ray_utils/evaluators/lstm_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

Expand Down Expand Up @@ -54,10 +55,23 @@ def __init__(
.resources(num_gpus=0)
.debugging(log_level="WARN")
)

# Restore the Tune experiment that produced the training checkpoints
# and select the trial with the best mean episode reward, instead of
# blindly restoring whatever path the config points at.
restored_tuner = tune.Tuner.restore(config.ray.restore_checkpoint)
result_grid = restored_tuner.get_results()
best_result = result_grid.get_best_result("episode_reward_mean", "max")

# best_result.best_checkpoints is a list of (Checkpoint, metrics-dict)
# pairs; scan it for the entry with the highest episode_reward_mean.
# (Removed a leftover bare `best_result` expression — it was a no-op.)
best_checkpoint = None
highest_reward = float("-inf")
for checkpoint in best_result.best_checkpoints:
    episode_reward_mean = checkpoint[1]["episode_reward_mean"]
    if episode_reward_mean > highest_reward:
        highest_reward = episode_reward_mean
        best_checkpoint = checkpoint

# Build the Algorithm instance using the config, then restore its
# state from the best checkpoint found above (index 0 of the
# (Checkpoint, metrics) pair is the checkpoint itself).
self.algo = self.config_model.build()
self.algo.restore(best_checkpoint[0])
# Counter for benchmark programs evaluated so far.
self.num_programs_done = 0

# explore schedules for benchmarks
Expand Down

0 comments on commit 4464988

Please sign in to comment.