Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New info API for vectorized environments #2657 #2773

Merged
merged 49 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8267cee
WIP refactor info API sync vector.
gianlucadecola Apr 22, 2022
26da959
Add missing untracked file.
gianlucadecola Apr 22, 2022
fca5d32
Add info strategy to reset_wait.
gianlucadecola Apr 22, 2022
b22f9b2
Add interface and docstring.
gianlucadecola Apr 23, 2022
00b876d
info with strategy pattern on async vector env.
gianlucadecola Apr 24, 2022
90fb9ec
Add default to async vecenv.
gianlucadecola Apr 24, 2022
ceb8d2e
episode statistics for asyncvecnev.
gianlucadecola Apr 24, 2022
fee3722
Add tests info strategy format.
gianlucadecola Apr 24, 2022
891b927
Add info strategy to reset_wait.
gianlucadecola Apr 24, 2022
efe22cb
refactor and cleanup.
gianlucadecola Apr 25, 2022
8061306
Code cleanup. Add tests.
gianlucadecola Apr 25, 2022
029980c
Add tests for video recording with new info format.
gianlucadecola Apr 25, 2022
8667dda
fix test case.
gianlucadecola Apr 25, 2022
c3855e5
fix camelcase.
gianlucadecola Apr 26, 2022
d09cc0c
rename enum.
gianlucadecola Apr 26, 2022
b05f7e6
update tests, docstrings, cleanup.
gianlucadecola Apr 27, 2022
ad003bb
Changes brax strategy to numpy. add_strategy method in StrategyFactor…
gianlucadecola Apr 30, 2022
a1340c5
fix docstring and logging format.
gianlucadecola Apr 30, 2022
f704758
Set Brax info format as default. Remove classic info format. Update t…
gianlucadecola May 8, 2022
ad89471
breaking the wrong loop.
gianlucadecola May 8, 2022
4a4efe9
WIP: wrapper.
gianlucadecola May 8, 2022
9693b35
Add wrapper for brax to classic info.
gianlucadecola May 10, 2022
36051b7
WIP: wrapper with nested RecordEpisodeStatistic.
gianlucadecola May 10, 2022
f2b4ab3
Add tests. Refactor docstrings. Cleanup.
gianlucadecola May 13, 2022
35041d9
cleanup.
gianlucadecola May 13, 2022
cb2b993
patch conflicts.
gianlucadecola May 13, 2022
0d1522d
rebase and conflicts.
gianlucadecola May 13, 2022
be23655
new pre-commit conventions.
gianlucadecola May 13, 2022
9532b51
docstring.
gianlucadecola May 13, 2022
6593114
renaming.
gianlucadecola May 14, 2022
479bc8b
incorporate info_processor in vecEnv.
gianlucadecola May 14, 2022
b9a862b
renaming. Create info dict only if needed.
gianlucadecola May 14, 2022
761b576
remove all brax references. update docstring. Update duplicate test.
gianlucadecola May 15, 2022
7a488b7
reviews.
gianlucadecola May 16, 2022
0e3e201
pre-commit.
gianlucadecola May 16, 2022
d2f8b1b
reviews.
gianlucadecola May 17, 2022
03e0627
docstring.
gianlucadecola May 17, 2022
e433d34
cleanup blank lines.
gianlucadecola May 17, 2022
ee556aa
add support for numpy dtypes.
gianlucadecola May 17, 2022
f92b75e
docstring fix.
gianlucadecola May 17, 2022
17b8cd3
formatting.
gianlucadecola May 17, 2022
29ec6bf
naming.
gianlucadecola May 18, 2022
5e2aead
assert correct info from wrappers chaining. Test correct wrappers cha…
gianlucadecola May 18, 2022
a2b186a
simplify episode_statistics.
gianlucadecola May 19, 2022
555bacc
change args orer.
gianlucadecola May 19, 2022
d6eb5e7
update tests.
gianlucadecola May 19, 2022
db21ebc
wip: refactor episode_statistics.
gianlucadecola May 20, 2022
0a02bd5
Merge branch 'master' into new-info-api-vecenv
gianlucadecola May 23, 2022
659b8fc
Add test for add_vecore_episode_statistics.
gianlucadecola May 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ class AsyncVectorEnv(VectorEnv):
space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`,
or :class:`~gym.spaces.Dict`) and :obj:`shared_memory` is ``True``.

InvalidInfoFormat
If the info format does not matches any of the available.

Example
-------

Expand Down Expand Up @@ -308,8 +311,10 @@ def reset_wait(
self._state = AsyncState.DEFAULT

if return_info:
results, infos = zip(*results)
infos = list(infos)
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)

if not self.shared_memory:
self.observations = concatenate(
Expand Down Expand Up @@ -406,10 +411,20 @@ def step_wait(self, timeout=None):
f"The call to `step_wait` has timed out after {timeout} second(s)."
)

results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
observations_list, rewards, dones, infos = [], [], [], {}
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, done, info = result
infos = self._add_info(infos, info, i)
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved

successes.append(success)
observations_list.append(obs)
rewards.append(rew)
dones.append(done)

self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
observations_list, rewards, dones, infos = zip(*results)

if not self.shared_memory:
self.observations = concatenate(
Expand Down
23 changes: 16 additions & 7 deletions gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class SyncVectorEnv(VectorEnv):
:obj:`observation_space` (or, by default, the observation space of
the first sub-environment).

InvalidInfoFormat
If the info format does not matches any of the available.

Example
-------

Expand All @@ -50,7 +53,13 @@ class SyncVectorEnv(VectorEnv):
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32)
"""

def __init__(self, env_fns, observation_space=None, action_space=None, copy=True):
def __init__(
self,
env_fns,
observation_space=None,
action_space=None,
copy=True,
):
self.env_fns = env_fns
self.envs = [env_fn() for env_fn in env_fns]
self.copy = copy
Expand Down Expand Up @@ -98,8 +107,8 @@ def reset_wait(

self._dones[:] = False
observations = []
data_list = []
for env, single_seed in zip(self.envs, seed):
infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):

kwargs = {}
if single_seed is not None:
Expand All @@ -115,7 +124,7 @@ def reset_wait(
else:
observation, data = env.reset(**kwargs)
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
observations.append(observation)
data_list.append(data)
infos = self._add_info(infos, data, i)

self.observations = concatenate(
self.single_observation_space, observations, self.observations
Expand All @@ -125,20 +134,20 @@ def reset_wait(
else:
return (
deepcopy(self.observations) if self.copy else self.observations
), data_list
), infos

def step_async(self, actions):
self._actions = iterate(self.action_space, actions)

def step_wait(self):
observations, infos = [], []
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
observation, self._rewards[i], self._dones[i], info = env.step(action)
if self._dones[i]:
info["terminal_observation"] = observation
observation = env.reset()
observations.append(observation)
infos.append(info)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
Expand Down
23 changes: 23 additions & 0 deletions gym/vector/vector_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Optional, Union

import numpy as np

import gym
from gym.logger import deprecation
from gym.vector.utils.spaces import batch_space
Expand Down Expand Up @@ -211,6 +213,27 @@ def seed(self, seed=None):
"Please use `env.reset(seed=seed) instead in VectorEnvs."
)

def _add_info(self, infos: dict, info: dict, env_num: int):
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
for k in info.keys():
if k not in infos:
info_array, array_mask = self._init_info_array(type(info[k]))
else:
info_array, array_mask = infos[k], infos[f"_{k}"]

info_array[env_num], array_mask[env_num] = info[k], True
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos

def _init_info_array(self, dtype: type) -> np.ndarray:
if dtype not in [int, float, bool]:
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
dtype = object
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
array = np.zeros(self.num_envs, dtype=dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some explicit handling for numpy arrays too? I think the current implementation can cause some problems through a mixup between e.g. np.float32 and np.float64, or it will even default to the object type. And I can easily see passing an array in the info dict as a desired functionality

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem for me, atm if an array is added to the info dict the resulting output will fall in the object dtype resulting in this:

"terminal_observation": [
        array([0.2097086, -0.5355723, -0.21343598, -0.11173592], dtype=float32),
        None,
        None
    ]

Something I can think of is making the output in this form:

"terminal_observation": [
        array(
            [[0.2097086, -0.5355723, -0.21343598, -0.11173592],
            [0., 0., 0., 0.],
            [0., 0., 0., 0.]],
           dtype=float32),
    ]

I'm not sure this adhere to the format of google-brax tho

array[:] = None
else:
array = np.zeros(self.num_envs, dtype=dtype)
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask

def __del__(self):
if not getattr(self, "closed", True):
self.close()
Expand Down
3 changes: 2 additions & 1 deletion gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module of wrapper classes."""
from gym import error
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.autoreset import AutoResetWrapper
from gym.wrappers.clip_action import ClipAction
Expand All @@ -8,7 +9,6 @@
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
from gym.wrappers.order_enforcing import OrderEnforcing
from gym.wrappers.pixel_observation import PixelObservationWrapper
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
from gym.wrappers.rescale_action import RescaleAction
Expand All @@ -17,3 +17,4 @@
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.transform_observation import TransformObservation
from gym.wrappers.transform_reward import TransformReward
from gym.wrappers.vec_info_to_classic import ClassicVectorInfo
109 changes: 98 additions & 11 deletions gym/wrappers/record_episode_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,89 @@
import gym


class ClassicStatsInfo:
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
"""Manage episode statistics."""

def __init__(self, num_envs: int):
"""Classic EpisodeStatics info.

Args:
num_envs (int): number of environments.
"""
self.info = {}

def add_info(self, infos: dict, env_num: int):
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
"""Add info.

Args:
infos (dict): info dict of the environment.
env_num (int): environment number.
"""
self.info = {**self.info, **infos}

def add_episode_statistics(self, infos: dict, env_num: int):
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
"""Add episode statistics.

Args:
infos (dict): info dict of the environment.
env_num (int): env number.
"""
self.info = {**self.info, **infos}

def get_info(self):
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
"""Return info."""
return self.info


class BraxVecEnvStatsInfo:
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
"""Manage episode statistics in the Brax format for vectorized envs."""

def __init__(self, num_envs: int):
"""Brax-style episode statistics.

Args:
num_envs (int): number of environments.
"""
self.num_envs = num_envs
self.info = {}

def add_info(self, info: dict, env_num: int):
"""Add info.

Args:
info (dict): info dict of the environment.
env_num (int): environment number.
"""
self.info = {**self.info, **info}
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved

def add_episode_statistics(self, info: dict, env_num: int):
"""Add episode statistics.

Add statistics coming from the vectorized environment.

Args:
info (dict): info dict of the environment.
env_num (int): env number of the vectorized environments.
"""
episode_info = info["episode"]

self.info["episode"] = self.info.get("episode", {})

self.info["_episode"] = self.info.get(
"_episode", np.zeros(self.num_envs, dtype=bool)
)
self.info["_episode"][env_num] = True

for k in episode_info.keys():
info_array = self.info["episode"].get(k, np.zeros(self.num_envs))
info_array[env_num] = episode_info[k]
self.info["episode"][k] = info_array

def get_info(self):
"""Returns info."""
return self.info


class RecordEpisodeStatistics(gym.Wrapper):
"""This wrapper will keep track of cumulative rewards and episode lengths.

Expand Down Expand Up @@ -46,6 +129,10 @@ def __init__(self, env: gym.Env, deque_size: int = 100):
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
self.is_vector_env = getattr(env, "is_vector_env", False)
if self.is_vector_env:
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
self.stats_info_processor = BraxVecEnvStatsInfo
else:
self.stats_info_processor = ClassicStatsInfo

def reset(self, **kwargs):
"""Resets the environment using kwargs and resets the episode returns and lengths."""
Expand All @@ -56,35 +143,35 @@ def reset(self, **kwargs):

def step(self, action):
"""Steps through the environment, recording the episode statistics."""
infos_processor = self.stats_info_processor(self.num_envs)
gianlucadecola marked this conversation as resolved.
Show resolved Hide resolved
observations, rewards, dones, infos = super().step(action)
self.episode_returns += rewards
self.episode_lengths += 1
if not self.is_vector_env:
infos = [infos]
dones = [dones]
else:
infos = list(infos) # Convert infos to mutable type
dones = list(dones)

for i in range(len(dones)):
if dones[i]:
infos[i] = infos[i].copy()
infos_processor.add_info(infos, i)
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {
"r": episode_return,
"l": episode_length,
"t": round(time.perf_counter() - self.t0, 6),
"episode": {
"r": episode_return,
"l": episode_length,
"t": round(time.perf_counter() - self.t0, 6),
}
}
infos[i]["episode"] = episode_info
infos_processor.add_episode_statistics(episode_info, i)
self.return_queue.append(episode_return)
self.length_queue.append(episode_length)
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
if self.is_vector_env:
infos = tuple(infos)
return (
observations,
rewards,
dones if self.is_vector_env else dones[0],
infos if self.is_vector_env else infos[0],
infos_processor.get_info(),
)
Loading