Skip to content

Commit 1486d33

Browse files
Fix reset info being lost in vector environments (#3111)
* Fix reset info * Added test for checking vector info
1 parent 21e6e27 commit 1486d33

File tree

4 files changed

+73
-6
lines changed

4 files changed

+73
-6
lines changed

gym/vector/async_vector_env.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,9 +566,10 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
566566
info,
567567
) = env.step(data)
568568
if terminated or truncated:
569-
old_observation = observation
569+
old_observation, old_info = observation, info
570570
observation, info = env.reset()
571571
info["final_observation"] = old_observation
572+
info["final_info"] = old_info
572573
pipe.send(((observation, reward, terminated, truncated, info), True))
573574
elif command == "seed":
574575
env.seed(data)
@@ -636,10 +637,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
636637
info,
637638
) = env.step(data)
638639
if terminated or truncated:
639-
old_observation = observation
640+
old_observation, old_info = observation, info
640641
observation, info = env.reset()
641642
info["final_observation"] = old_observation
642-
643+
info["final_info"] = old_info
643644
write_to_shared_memory(
644645
observation_space, index, observation, shared_memory
645646
)

gym/vector/sync_vector_env.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,10 @@ def step_wait(self):
150150
) = env.step(action)
151151

152152
if self._terminateds[i] or self._truncateds[i]:
153-
old_observation = observation
153+
old_observation, old_info = observation, info
154154
observation, info = env.reset()
155155
info["final_observation"] = old_observation
156+
info["final_info"] = old_info
156157
observations.append(observation)
157158
infos = self._add_info(infos, info, i)
158159
self.observations = concatenate(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"box2d": ["box2d-py==2.3.5", "pygame==2.1.0", "swig==4.*"],
1919
"classic_control": ["pygame==2.1.0"],
2020
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
21-
"mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
21+
"mujoco": ["mujoco==2.2", "imageio>=2.14.1"],
2222
"toy_text": ["pygame==2.1.0"],
2323
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
2424
}

tests/vector/test_vector_env.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
from functools import partial
2+
13
import numpy as np
24
import pytest
35

4-
from gym.spaces import Tuple
6+
from gym.spaces import Discrete, Tuple
57
from gym.vector.async_vector_env import AsyncVectorEnv
68
from gym.vector.sync_vector_env import SyncVectorEnv
79
from gym.vector.vector_env import VectorEnv
10+
from tests.testing_env import GenericTestEnv
811
from tests.vector.utils import CustomSpace, make_env
912

1013

@@ -58,3 +61,65 @@ def test_custom_space_vector_env():
5861

5962
assert isinstance(env.single_action_space, CustomSpace)
6063
assert isinstance(env.action_space, Tuple)
64+
65+
66+
@pytest.mark.parametrize(
67+
"vectoriser",
68+
(
69+
SyncVectorEnv,
70+
partial(AsyncVectorEnv, shared_memory=True),
71+
partial(AsyncVectorEnv, shared_memory=False),
72+
),
73+
ids=["Sync", "Async with shared memory", "Async without shared memory"],
74+
)
75+
def test_final_obs_info(vectoriser):
76+
"""Tests that the vector environments correctly return the final observation and info."""
77+
78+
def reset_fn(self, seed=None, options=None):
79+
return 0, {"reset": True}
80+
81+
def thunk():
82+
return GenericTestEnv(
83+
action_space=Discrete(4),
84+
observation_space=Discrete(4),
85+
reset_fn=reset_fn,
86+
step_fn=lambda self, action: (
87+
action if action < 3 else 0,
88+
0,
89+
action >= 3,
90+
False,
91+
{"action": action},
92+
),
93+
)
94+
95+
env = vectoriser([thunk])
96+
obs, info = env.reset()
97+
assert obs == np.array([0]) and info == {
98+
"reset": np.array([True]),
99+
"_reset": np.array([True]),
100+
}
101+
102+
obs, _, termination, _, info = env.step([1])
103+
assert (
104+
obs == np.array([1])
105+
and termination == np.array([False])
106+
and info == {"action": np.array([1]), "_action": np.array([True])}
107+
)
108+
109+
obs, _, termination, _, info = env.step([2])
110+
assert (
111+
obs == np.array([2])
112+
and termination == np.array([False])
113+
and info == {"action": np.array([2]), "_action": np.array([True])}
114+
)
115+
116+
obs, _, termination, _, info = env.step([3])
117+
assert (
118+
obs == np.array([0])
119+
and termination == np.array([True])
120+
and info["reset"] == np.array([True])
121+
)
122+
assert "final_observation" in info and "final_info" in info
123+
assert info["final_observation"] == np.array([0]) and info["final_info"] == {
124+
"action": 3
125+
}

0 commit comments

Comments
 (0)