Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New info API for vectorized environments #2657 #2773

Merged
merged 49 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8267cee
WIP refactor info API sync vector.
gianlucadecola Apr 22, 2022
26da959
Add missing untracked file.
gianlucadecola Apr 22, 2022
fca5d32
Add info strategy to reset_wait.
gianlucadecola Apr 22, 2022
b22f9b2
Add interface and docstring.
gianlucadecola Apr 23, 2022
00b876d
info with strategy pattern on async vector env.
gianlucadecola Apr 24, 2022
90fb9ec
Add default to async vecenv.
gianlucadecola Apr 24, 2022
ceb8d2e
episode statistics for asyncvecnev.
gianlucadecola Apr 24, 2022
fee3722
Add tests info strategy format.
gianlucadecola Apr 24, 2022
891b927
Add info strategy to reset_wait.
gianlucadecola Apr 24, 2022
efe22cb
refactor and cleanup.
gianlucadecola Apr 25, 2022
8061306
Code cleanup. Add tests.
gianlucadecola Apr 25, 2022
029980c
Add tests for video recording with new info format.
gianlucadecola Apr 25, 2022
8667dda
fix test case.
gianlucadecola Apr 25, 2022
c3855e5
fix camelcase.
gianlucadecola Apr 26, 2022
d09cc0c
rename enum.
gianlucadecola Apr 26, 2022
b05f7e6
update tests, docstrings, cleanup.
gianlucadecola Apr 27, 2022
ad003bb
Changes brax strategy to numpy. add_strategy method in StrategyFactor…
gianlucadecola Apr 30, 2022
a1340c5
fix docstring and logging format.
gianlucadecola Apr 30, 2022
f704758
Set Brax info format as default. Remove classic info format. Update t…
gianlucadecola May 8, 2022
ad89471
breaking the wrong loop.
gianlucadecola May 8, 2022
4a4efe9
WIP: wrapper.
gianlucadecola May 8, 2022
9693b35
Add wrapper for brax to classic info.
gianlucadecola May 10, 2022
36051b7
WIP: wrapper with nested RecordEpisodeStatistic.
gianlucadecola May 10, 2022
f2b4ab3
Add tests. Refactor docstrings. Cleanup.
gianlucadecola May 13, 2022
35041d9
cleanup.
gianlucadecola May 13, 2022
cb2b993
patch conflicts.
gianlucadecola May 13, 2022
0d1522d
rebase and conflicts.
gianlucadecola May 13, 2022
be23655
new pre-commit conventions.
gianlucadecola May 13, 2022
9532b51
docstring.
gianlucadecola May 13, 2022
6593114
renaming.
gianlucadecola May 14, 2022
479bc8b
incorporate info_processor in vecEnv.
gianlucadecola May 14, 2022
b9a862b
renaming. Create info dict only if needed.
gianlucadecola May 14, 2022
761b576
remove all brax references. update docstring. Update duplicate test.
gianlucadecola May 15, 2022
7a488b7
reviews.
gianlucadecola May 16, 2022
0e3e201
pre-commit.
gianlucadecola May 16, 2022
d2f8b1b
reviews.
gianlucadecola May 17, 2022
03e0627
docstring.
gianlucadecola May 17, 2022
e433d34
cleanup blank lines.
gianlucadecola May 17, 2022
ee556aa
add support for numpy dtypes.
gianlucadecola May 17, 2022
f92b75e
docstring fix.
gianlucadecola May 17, 2022
17b8cd3
formatting.
gianlucadecola May 17, 2022
29ec6bf
naming.
gianlucadecola May 18, 2022
5e2aead
assert correct info from wrappers chaining. Test correct wrappers cha…
gianlucadecola May 18, 2022
a2b186a
simplify episode_statistics.
gianlucadecola May 19, 2022
555bacc
change args orer.
gianlucadecola May 19, 2022
d6eb5e7
update tests.
gianlucadecola May 19, 2022
db21ebc
wip: refactor episode_statistics.
gianlucadecola May 20, 2022
0a02bd5
Merge branch 'master' into new-info-api-vecenv
gianlucadecola May 23, 2022
659b8fc
Add test for add_vecore_episode_statistics.
gianlucadecola May 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Code cleanup. Add tests.
  • Loading branch information
gianlucadecola committed May 13, 2022
commit 80613060a9ac24804ef2fdbce0fca2b4d11e22b9
1 change: 0 additions & 1 deletion gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
AlreadyPendingCallError,
ClosedEnvironmentError,
CustomSpaceError,
InvalidInfoFormat,
NoAsyncCallError,
)
from gym.vector.utils import (
Expand Down
10 changes: 7 additions & 3 deletions gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Module of wrapper classes."""
from gym import error
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.autoreset import AutoResetWrapper
from gym.wrappers.clip_action import ClipAction
Expand All @@ -8,8 +8,12 @@
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
from gym.wrappers.order_enforcing import OrderEnforcing
from gym.wrappers.pixel_observation import PixelObservationWrapper
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_episode_statistics import (
BraxVecEnvStatsInfoStrategy,
ClassicVecEnvStatsInfoStrategy,
RecordEpisodeStatistics,
get_statistic_info_strategy,
)
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.resize_observation import ResizeObservation
Expand Down
2 changes: 1 addition & 1 deletion gym/wrappers/record_episode_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_statistic_info_strategy(wrapped_env_strategy: str):
}
if wrapped_env_strategy not in strategies:
raise NoMatchingInfoStrategy(
"Wrapped environment has an info format of type %s which is not a processable format by this wrappers. Please use one in %s"
"Wrapped environment has an info format of type %s which is not a processable format by this wrapper. Please use one in %s"
% (wrapped_env_strategy, list(strategies.keys()))
)
return strategies[wrapped_env_strategy]
Expand Down
49 changes: 33 additions & 16 deletions tests/vector/test_info_format_strategies.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,57 @@
import pytest

from gym.vector.utils import BraxVecEnvInfoStrategy, ClassicVecEnvInfoStrategy


def test_classic_vec_env_info_strategy():
NUM_ENVS = 3

infos = ClassicVecEnvInfoStrategy(NUM_ENVS)
for i in range(NUM_ENVS):
from gym.error import InvalidInfoFormat
from gym.vector.utils import (
BraxVecEnvInfoStrategy,
ClassicVecEnvInfoStrategy,
get_info_strategy,
)


@pytest.mark.parametrize(("num_envs"), [3])
def test_classic_vec_env_info_strategy(num_envs):
infos = ClassicVecEnvInfoStrategy(num_envs)
for i in range(num_envs):
info = {"example_info": i}
infos.add_info(info, i)

gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
expected_info = [{"example_info": 0}, {"example_info": 1}, {"example_info": 2}]
assert expected_info == infos.get_info()


def test_brax_vec_env_info_strategy():
@pytest.mark.parametrize(("num_envs"), [3])
def test_brax_vec_env_info_strategy(num_envs):
NUM_ENVS = 3

infos = BraxVecEnvInfoStrategy(NUM_ENVS)
for i in range(NUM_ENVS):
infos = BraxVecEnvInfoStrategy(num_envs)
for i in range(num_envs):
info = {"example_info": i}
infos.add_info(info, i)

expected_info = {"example_info": [0, 1, 2]}
assert expected_info == infos.get_info()


def test_brax_vec_env_info_strategy_with_nones():
NUM_ENVS = 5

infos = BraxVecEnvInfoStrategy(NUM_ENVS)
for i in range(NUM_ENVS):
@pytest.mark.parametrize(("num_envs"), [5])
def test_brax_vec_env_info_strategy_with_nones(num_envs):
infos = BraxVecEnvInfoStrategy(num_envs)
for i in range(num_envs):
if i % 2 == 0:
info = {"example_info": i}
infos.add_info(info, i)

gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
expected_info = {"example_info": [0, None, 2, None, 4]}
assert expected_info == infos.get_info()


@pytest.mark.parametrize(("info_format"), [("classic"), ("brax"), ("non_existent")])
def test_get_info_strategy(info_format):
if info_format == "classic":
InfoStrategy = get_info_strategy(info_format)
assert InfoStrategy == ClassicVecEnvInfoStrategy
elif info_format == "brax":
InfoStrategy = get_info_strategy(info_format)
assert InfoStrategy == BraxVecEnvInfoStrategy
else:
with pytest.raises(InvalidInfoFormat):
get_info_strategy(info_format)
43 changes: 42 additions & 1 deletion tests/wrappers/test_record_episode_statistics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import pytest

import gym
from gym.wrappers import RecordEpisodeStatistics
from gym.error import NoMatchingInfoStrategy
from gym.wrappers import (
BraxVecEnvStatsInfoStrategy,
ClassicVecEnvStatsInfoStrategy,
RecordEpisodeStatistics,
get_statistic_info_strategy,
)


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
Expand Down Expand Up @@ -55,3 +61,38 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous):
assert "episode" in info
assert all([item in info["episode"] for item in ["r", "l", "t"]])
break


@pytest.mark.parametrize(("num_envs", "asynchronous"), [(3, False), (3, True)])
def test_episode_statistics_brax_info(num_envs, asynchronous):
envs = gym.vector.make(
"CartPole-v1", asynchronous=asynchronous, num_envs=num_envs, info_format="brax"
)
envs = RecordEpisodeStatistics(envs)
envs.reset()
dones = [False for _ in range(num_envs)]
actions = np.array([1, 0, 1])
while not any(dones):
_, _, dones, infos = envs.step(actions)

assert "episode" in infos
assert len(infos["episode"]) == num_envs
assert "terminal_observation" in infos
for i in range(num_envs):
if dones[i]:
assert infos["terminal_observation"][i] is not None
else:
assert infos["terminal_observation"][i] is None


@pytest.mark.parametrize(("info_format"), [("classic"), ("brax"), ("non_existent")])
def test_get_statistic_info_strategy(info_format):
if info_format == "classic":
InfoStrategy = get_statistic_info_strategy(info_format)
assert InfoStrategy == ClassicVecEnvStatsInfoStrategy
elif info_format == "brax":
InfoStrategy = get_statistic_info_strategy(info_format)
assert InfoStrategy == BraxVecEnvStatsInfoStrategy
else:
with pytest.raises(NoMatchingInfoStrategy):
get_statistic_info_strategy(info_format)