diff --git a/gym/__init__.py b/gym/__init__.py index 70fad3d6119..8fc7c62abf9 100644 --- a/gym/__init__.py +++ b/gym/__init__.py @@ -31,8 +31,8 @@ def sanity_check_dependencies(): sanity_check_dependencies() -from gym.core import Env, Space +from gym.core import Env, Space, Wrapper from gym.envs import make, spec from gym.scoreboard.api import upload -__all__ = ["Env", "Space", "make", "spec", "upload"] +__all__ = ["Env", "Space", "Wrapper", "make", "spec", "upload"] diff --git a/gym/core.py b/gym/core.py index 02928a544ae..03d9d9fe894 100644 --- a/gym/core.py +++ b/gym/core.py @@ -2,6 +2,7 @@ logger = logging.getLogger(__name__) import numpy as np +import weakref from gym import error, monitoring from gym.utils import closer, reraise @@ -52,6 +53,7 @@ def __new__(cls, *args, **kwargs): env._env_closer_id = env_closer.register(env) env._closed = False env._configured = False + env._unwrapped = None # Will be automatically set when creating an environment via 'make' env.spec = None @@ -238,6 +240,42 @@ def configure(self, *args, **kwargs): else: raise + def build(self): + """[EXPERIMENTAL: may be removed in a later version of Gym] Builds an + environment by applying any provided wrappers, with the + outmost wrapper supplied first. This method is automatically + invoked by 'gym.make', and should be manually invoked if + instantiating an environment by hand. + + Notes: + The default implementation will wrap the environment in the + list of wrappers provided in self.metadata['wrappers'], in reverse + order. So for example, given: + + class FooEnv(gym.Env): + metadata = { + 'wrappers': [Wrapper1, Wrapper2] + } + + Calling 'env.build' will return 'Wrapper1(Wrapper2(env))'. + + Returns: + gym.Env: A potentially wrapped environment instance. + + """ + wrapped = self + for wrapper in reversed(self.metadata.get('wrappers', [])): + wrapped = wrapper(wrapped) + return wrapped + + @property + def unwrapped(self): + """Avoid refcycles by making this into a property.""" + if self._unwrapped is not None: + return self._unwrapped + else: + return self + def __del__(self): self.close() @@ -247,10 +285,9 @@ def __str__(self): # Space-related abstractions class Space(object): - """ - Provides a classification state spaces and action spaces, - so you can write generic code that applies to any Environment. - E.g. to choose a random action. + """Defines the observation and action spaces, so you can write generic + code that applies to any Env. For example, you can choose a random + action. """ def sample(self, seed=0): @@ -275,3 +312,34 @@ def from_jsonable(self, sample_n): """Convert a JSONable data type to a batch of samples from this space.""" # By default, assume identity is JSONable return sample_n + +class Wrapper(Env): + def __init__(self, env): + self.env = env + self.metadata = env.metadata + self.action_space = env.action_space + self.observation_space = env.observation_space + self.reward_range = env.reward_range + self.spec = env.spec + self._unwrapped = env.unwrapped + + def _step(self, action): + return self.env.step(action) + + def _reset(self): + return self.env.reset() + + def _render(self, mode='human', close=False): + return self.env.render(mode, close) + + def _close(self): + return self.env.close() + + def _configure(self, *args, **kwargs): + return self.env.configure(*args, **kwargs) + + def _seed(self, seed=None): + return self.env.seed(seed) + + def __str__(self): + return '<{}{} instance>'.format(type(self).__name__, self.env) diff --git a/gym/envs/registration.py b/gym/envs/registration.py index 5fa02b69634..360d65c6aab 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -63,6 +63,8 @@ def make(self): # Make the enviroment aware of which spec it came from. env.spec = self + env = env.build() + return env def __repr__(self):