Skip to content

Commit

Permalink
Fix type annotations of buffers (#1700)
Browse files Browse the repository at this point in the history
* Fix type annotation and replay buffer

* Exclude pytype check

* Remove some pytype specific annotaiton and update changelog

* Fix HerReplayBuffer type hints

* try remove   # type: ignore[assignment]

* revert change

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
araffin and qgallouedec committed Sep 28, 2023
1 parent fab6cb3 commit c6c660e
Show file tree
Hide file tree
Showing 13 changed files with 107 additions and 84 deletions.
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.2.0a5 (WIP)
Release 2.2.0a6 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -49,6 +49,9 @@ Others:
- Fixed ``stable_baselines3/common/vec_env/vec_video_recorder.py`` type hints
- Fixed ``stable_baselines3/common/save_util.py`` type hints
- Updated docker images to Ubuntu Jammy using micromamba 1.5
- Fixed ``stable_baselines3/common/buffers.py`` type hints
- Fixed ``stable_baselines3/her/her_replay_buffer.py`` type hints
- Buffers do no call an additional ``.copy()`` when storing new transitions

Documentation:
^^^^^^^^^^^^^^
Expand Down
14 changes: 11 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,27 @@ line-length = 127
[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]
# Checked with mypy
exclude = [
"stable_baselines3/common/buffers.py",
"stable_baselines3/common/base_class.py",
"stable_baselines3/common/callbacks.py",
"stable_baselines3/common/on_policy_algorithm.py",
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py"
]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/buffers.py$
| stable_baselines3/common/distributions.py$
stable_baselines3/common/distributions.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/her/her_replay_buffer.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$
)"""
Expand Down
8 changes: 2 additions & 6 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,7 @@ def _setup_learn(
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
assert self.env is not None
# pytype: disable=annotation-type-mismatch
self._last_obs = self.env.reset() # type: ignore[assignment]
# pytype: enable=annotation-type-mismatch
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
Expand Down Expand Up @@ -707,7 +705,7 @@ def load( # noqa: C901

# Gym -> Gymnasium space conversion
for key in {"observation_space", "action_space"}:
data[key] = _convert_space(data[key]) # pytype: disable=unsupported-operands
data[key] = _convert_space(data[key])

if env is not None:
# Wrap first if needed
Expand All @@ -726,14 +724,12 @@ def load( # noqa: C901
if "env" in data:
env = data["env"]

# pytype: disable=not-instantiable,wrong-keyword-args
model = cls(
policy=data["policy_class"],
env=env,
device=device,
_init_setup_model=False, # type: ignore[call-arg]
)
# pytype: enable=not-instantiable,wrong-keyword-args

# load parameters
model.__dict__.update(data)
Expand Down Expand Up @@ -776,7 +772,7 @@ def load( # noqa: C901
# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error
model.policy.reset_noise() # type: ignore[operator]
return model

def get_parameters(self) -> Dict[str, Dict]:
Expand Down
Loading

0 comments on commit c6c660e

Please sign in to comment.