Skip to content

Commit d818750

Browse files
authored
Add a complete compatibility wrapper (#3066)
* Added a new compatibility wrapper along with tests * Fix for 3.6 typing * Fix for 3.6 typing again * Add make integration * Unrelated change that for some reason is necessary to fix pyright * Ignore weird (and very non-critical) type check bug * Adjust old tests * Rename the compatibility argument in make * Rename the compatibility argument in register and envspec * Documentation updates * Remove test envs from the registry * Some rogue renames * Add nicer str and repr to the compatibility layer * Reorder the compatibility layer application * Add metadata to test envs * Add proper handling of automatic human rendering * Add auto human rendering to reset * Enable setting render_mode in gym.make * Documentation update * Fix an unrelated stochastic test
1 parent 2f33096 commit d818750

File tree

8 files changed

+299
-37
lines changed

8 files changed

+299
-37
lines changed

gym/envs/registration.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
HumanRendering,
2727
OrderEnforcing,
2828
RenderCollection,
29-
StepAPICompatibility,
3029
TimeLimit,
3130
)
31+
from gym.wrappers.compatibility import EnvCompatibility
3232
from gym.wrappers.env_checker import PassiveEnvChecker
3333

3434
if sys.version_info < (3, 10):
@@ -141,7 +141,7 @@ class EnvSpec:
141141
order_enforce: bool = field(default=True)
142142
autoreset: bool = field(default=False)
143143
disable_env_checker: bool = field(default=False)
144-
apply_step_compatibility: bool = field(default=False)
144+
apply_api_compatibility: bool = field(default=False)
145145

146146
# Environment arguments
147147
kwargs: dict = field(default_factory=dict)
@@ -440,7 +440,7 @@ def register(
440440
order_enforce: bool = True,
441441
autoreset: bool = False,
442442
disable_env_checker: bool = False,
443-
apply_step_compatibility: bool = False,
443+
apply_api_compatibility: bool = False,
444444
**kwargs,
445445
):
446446
"""Register an environment with gym.
@@ -459,7 +459,7 @@ def register(
459459
order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order
460460
autoreset: If to add the autoreset wrapper such that reset does not need to be called.
461461
disable_env_checker: If to disable the environment checker for the environment. Recommended to False.
462-
apply_step_compatibility: If to apply the `StepAPICompatibility` wrapper.
462+
apply_api_compatibility: If to apply the `StepAPICompatibility` wrapper.
463463
**kwargs: arbitrary keyword arguments which are passed to the environment constructor
464464
"""
465465
global registry, current_namespace
@@ -490,7 +490,7 @@ def register(
490490
order_enforce=order_enforce,
491491
autoreset=autoreset,
492492
disable_env_checker=disable_env_checker,
493-
apply_step_compatibility=apply_step_compatibility,
493+
apply_api_compatibility=apply_api_compatibility,
494494
**kwargs,
495495
)
496496
_check_spec_register(new_spec)
@@ -503,7 +503,7 @@ def make(
503503
id: Union[str, EnvSpec],
504504
max_episode_steps: Optional[int] = None,
505505
autoreset: bool = False,
506-
apply_step_compatibility: Optional[bool] = None,
506+
apply_api_compatibility: Optional[bool] = None,
507507
disable_env_checker: Optional[bool] = None,
508508
**kwargs,
509509
) -> Env:
@@ -515,10 +515,10 @@ def make(
515515
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
516516
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
517517
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
518-
apply_step_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that
518+
apply_api_compatibility: Whether to wrap the environment with the `StepAPICompatibility` wrapper that
519519
converts the environment step from a done bool to return termination and truncation bools.
520-
By default, the argument is None to which the environment specification `apply_step_compatibility` is used
521-
which defaults to False. Otherwise, the value of `apply_step_compatibility` is used.
520+
By default, the argument is None to which the environment specification `apply_api_compatibility` is used
521+
which defaults to False. Otherwise, the value of `apply_api_compatibility` is used.
522522
If `True`, the wrapper is applied otherwise, the wrapper is not applied.
523523
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
524524
(which is by default False, running the environment checker),
@@ -628,6 +628,14 @@ def make(
628628
f"The environment creator metadata doesn't include `render_modes`, contains: {list(env_creator.metadata.keys())}"
629629
)
630630

631+
if apply_api_compatibility is True or (
632+
apply_api_compatibility is None and spec_.apply_api_compatibility is True
633+
):
634+
# If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator
635+
render_mode = _kwargs.pop("render_mode", None)
636+
else:
637+
render_mode = None
638+
631639
try:
632640
env = env_creator(**_kwargs)
633641
except TypeError as e:
@@ -648,18 +656,18 @@ def make(
648656
spec_.kwargs = _kwargs
649657
env.unwrapped.spec = spec_
650658

659+
# Add step API wrapper
660+
if apply_api_compatibility is True or (
661+
apply_api_compatibility is None and spec_.apply_api_compatibility is True
662+
):
663+
env = EnvCompatibility(env, render_mode)
664+
651665
# Run the environment checker as the lowest level wrapper
652666
if disable_env_checker is False or (
653667
disable_env_checker is None and spec_.disable_env_checker is False
654668
):
655669
env = PassiveEnvChecker(env)
656670

657-
# Add step API wrapper
658-
if apply_step_compatibility is True or (
659-
apply_step_compatibility is None and spec_.apply_step_compatibility is True
660-
):
661-
env = StepAPICompatibility(env, output_truncation_bool=True)
662-
663671
# Add the order enforcing wrapper
664672
if spec_.order_enforce:
665673
env = OrderEnforcing(env)

gym/wrappers/compatibility.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""A compatibility wrapper converting an old-style environment into a valid environment."""
2+
import sys
3+
from typing import Any, Dict, Optional, Tuple
4+
5+
import gym
6+
from gym.core import ObsType
7+
from gym.utils.step_api_compatibility import convert_to_terminated_truncated_step_api
8+
9+
if sys.version_info >= (3, 8):
10+
from typing import Protocol, runtime_checkable
11+
elif sys.version_info >= (3, 7):
12+
from typing_extensions import Protocol, runtime_checkable
13+
else:
14+
Protocol = object
15+
runtime_checkable = lambda x: x # noqa: E731
16+
17+
18+
@runtime_checkable
19+
class LegacyEnv(Protocol):
20+
"""A protocol for environments using the old step API."""
21+
22+
observation_space: gym.Space
23+
action_space: gym.Space
24+
25+
def reset(self) -> Any:
26+
"""Reset the environment and return the initial observation."""
27+
...
28+
29+
def step(self, action: Any) -> Tuple[Any, float, bool, Dict]:
30+
"""Run one timestep of the environment's dynamics."""
31+
...
32+
33+
def render(self, mode: Optional[str] = "human") -> Any:
34+
"""Render the environment."""
35+
...
36+
37+
def close(self):
38+
"""Close the environment."""
39+
...
40+
41+
def seed(self, seed: Optional[int] = None):
42+
"""Set the seed for this env's random number generator(s)."""
43+
...
44+
45+
46+
class EnvCompatibility(gym.Env):
47+
r"""A wrapper which can transform an environment from the old API to the new API.
48+
49+
Old step API refers to step() method returning (observation, reward, done, info), and reset() only retuning the observation.
50+
New step API refers to step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info).
51+
(Refer to docs for details on the API change)
52+
53+
Known limitations:
54+
- Environments that use `self.np_random` might not work as expected.
55+
"""
56+
57+
def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None):
58+
"""A wrapper which converts old-style envs to valid modern envs.
59+
60+
Some information may be lost in the conversion, so we recommend updating your environment.
61+
62+
Args:
63+
old_env (LegacyEnv): the env to wrap, implemented with the old API
64+
render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render
65+
"""
66+
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
67+
self.render_mode = render_mode
68+
self.reward_range = getattr(old_env, "reward_range", None)
69+
self.spec = getattr(old_env, "spec", None)
70+
self.env = old_env
71+
72+
self.observation_space = old_env.observation_space
73+
self.action_space = old_env.action_space
74+
75+
def reset(
76+
self, seed: Optional[int] = None, options: Optional[dict] = None
77+
) -> Tuple[ObsType, dict]:
78+
"""Resets the environment.
79+
80+
Args:
81+
seed: the seed to reset the environment with
82+
options: the options to reset the environment with
83+
84+
Returns:
85+
(observation, info)
86+
"""
87+
if seed is not None:
88+
self.env.seed(seed)
89+
# Options are ignored
90+
91+
if self.render_mode == "human":
92+
self.render()
93+
94+
return self.env.reset(), {}
95+
96+
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
97+
"""Steps through the environment.
98+
99+
Args:
100+
action: action to step through the environment with
101+
102+
Returns:
103+
(observation, reward, terminated, truncated, info)
104+
"""
105+
obs, reward, done, info = self.env.step(action)
106+
107+
if self.render_mode == "human":
108+
self.render()
109+
110+
return convert_to_terminated_truncated_step_api((obs, reward, done, info))
111+
112+
def render(self) -> Any:
113+
"""Renders the environment.
114+
115+
Returns:
116+
The rendering of the environment, depending on the render mode
117+
"""
118+
return self.env.render(mode=self.render_mode)
119+
120+
def close(self):
121+
"""Closes the environment."""
122+
self.env.close()
123+
124+
def __str__(self):
125+
"""Returns the wrapper name and the unwrapped environment string."""
126+
return f"<{type(self).__name__}{self.env}>"
127+
128+
def __repr__(self):
129+
"""Returns the string representation of the wrapper."""
130+
return str(self)

gym/wrappers/step_api_compatibility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class StepAPICompatibility(gym.Wrapper):
2222
>>> env = gym.make("CartPole-v1")
2323
>>> env # wrapper not applied by default, set to new API
2424
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
25-
>>> env = gym.make("CartPole-v1", apply_step_compatibility=True) # set to old API
25+
>>> env = gym.make("CartPole-v1", apply_api_compatibility=True) # set to old API
2626
<StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
2727
>>> env = StepAPICompatibility(CustomEnv(), apply_step_compatibility=False) # manually using wrapper on unregistered envs
2828

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ cloudpickle>=1.2.0
33
importlib_metadata>=4.8.0; python_version < '3.10'
44
gym_notices>=0.0.4
55
dataclasses==0.8; python_version == '3.6'
6+
typing_extensions==4.3.0; python_version == '3.7'
67
opencv-python>=3.0
78
lz4>=3.1.0
89
matplotlib>=3.0

tests/envs/test_compatibility.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import sys
2+
from typing import Any, Dict, Optional, Tuple
3+
4+
import numpy as np
5+
6+
import gym
7+
from gym.spaces import Discrete
8+
from gym.wrappers.compatibility import EnvCompatibility, LegacyEnv
9+
10+
11+
class LegacyEnvExplicit(LegacyEnv, gym.Env):
12+
"""Legacy env that explicitly implements the old API."""
13+
14+
observation_space = Discrete(1)
15+
action_space = Discrete(1)
16+
metadata = {"render.modes": ["human", "rgb_array"]}
17+
18+
def __init__(self):
19+
pass
20+
21+
def reset(self):
22+
return 0
23+
24+
def step(self, action):
25+
return 0, 0, False, {}
26+
27+
def render(self, mode="human"):
28+
if mode == "human":
29+
return
30+
elif mode == "rgb_array":
31+
return np.zeros((1, 1, 3), dtype=np.uint8)
32+
33+
def close(self):
34+
pass
35+
36+
def seed(self, seed=None):
37+
pass
38+
39+
40+
class LegacyEnvImplicit(gym.Env):
41+
"""Legacy env that implicitly implements the old API as a protocol."""
42+
43+
observation_space = Discrete(1)
44+
action_space = Discrete(1)
45+
metadata = {"render.modes": ["human", "rgb_array"]}
46+
47+
def __init__(self):
48+
pass
49+
50+
def reset(self): # type: ignore
51+
return 0 # type: ignore
52+
53+
def step(self, action: Any) -> Tuple[int, float, bool, Dict]:
54+
return 0, 0.0, False, {}
55+
56+
def render(self, mode: Optional[str] = "human") -> Any:
57+
if mode == "human":
58+
return
59+
elif mode == "rgb_array":
60+
return np.zeros((1, 1, 3), dtype=np.uint8)
61+
62+
def close(self):
63+
pass
64+
65+
def seed(self, seed: Optional[int] = None):
66+
pass
67+
68+
69+
def test_explicit():
70+
old_env = LegacyEnvExplicit()
71+
assert isinstance(old_env, LegacyEnv)
72+
env = EnvCompatibility(old_env, render_mode="rgb_array")
73+
assert env.observation_space == Discrete(1)
74+
assert env.action_space == Discrete(1)
75+
assert env.reset() == (0, {})
76+
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
77+
assert env.step(0) == (0, 0, False, False, {})
78+
assert env.render().shape == (1, 1, 3)
79+
env.close()
80+
81+
82+
def test_implicit():
83+
old_env = LegacyEnvImplicit()
84+
if sys.version_info >= (3, 7):
85+
# We need to give up on typing in Python 3.6
86+
assert isinstance(old_env, LegacyEnv)
87+
env = EnvCompatibility(old_env, render_mode="rgb_array")
88+
assert env.observation_space == Discrete(1)
89+
assert env.action_space == Discrete(1)
90+
assert env.reset() == (0, {})
91+
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
92+
assert env.step(0) == (0, 0, False, False, {})
93+
assert env.render().shape == (1, 1, 3)
94+
env.close()
95+
96+
97+
def test_make_compatibility_in_spec():
98+
gym.register(
99+
id="LegacyTestEnv-v0",
100+
entry_point=LegacyEnvExplicit,
101+
apply_api_compatibility=True,
102+
)
103+
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
104+
assert env.observation_space == Discrete(1)
105+
assert env.action_space == Discrete(1)
106+
assert env.reset() == (0, {})
107+
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
108+
assert env.step(0) == (0, 0, False, False, {})
109+
img = env.render()
110+
assert isinstance(img, np.ndarray)
111+
assert img.shape == (1, 1, 3) # type: ignore
112+
env.close()
113+
del gym.envs.registration.registry["LegacyTestEnv-v0"]
114+
115+
116+
def test_make_compatibility_in_make():
117+
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
118+
env = gym.make(
119+
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
120+
)
121+
assert env.observation_space == Discrete(1)
122+
assert env.action_space == Discrete(1)
123+
assert env.reset() == (0, {})
124+
assert env.reset(seed=0, options={"some": "option"}) == (0, {})
125+
assert env.step(0) == (0, 0, False, False, {})
126+
img = env.render()
127+
assert isinstance(img, np.ndarray)
128+
assert img.shape == (1, 1, 3) # type: ignore
129+
env.close()
130+
del gym.envs.registration.registry["LegacyTestEnv-v0"]

0 commit comments

Comments
 (0)