From 9fd3a3e8ab084b0743ae2614d2a3d1091f019e2a Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Wed, 28 Dec 2022 20:52:29 +0100 Subject: [PATCH 1/6] Deprecation warning for shared layers in Mlpextractor --- stable_baselines3/common/torch_layers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 105d39f37..cc810208c 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,3 +1,4 @@ +import warnings from itertools import zip_longest from typing import Dict, List, Tuple, Type, Union @@ -159,6 +160,9 @@ class MlpExtractor(nn.Module): It is formatted like ``dict(vf=[], pi=[])``. If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed. + Depredcation note: shared layers in ``net_arch`` are deprecated, please use separate + pi and vf networks (e.g. net_arch=[dict(pi=[...], vf=[...])]) + For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value network of size 255 and a single non-shared layer of size 128 for the policy network, the following layers_spec would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128 @@ -189,6 +193,15 @@ def __init__( value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network last_layer_dim_shared = feature_dim + if len(net_arch) > 0 and isinstance(net_arch[0], int): + warnings.warn( + ( + "Shared layers in the mlp_extractor are deprecated, please use separate pi and vf networks" + "(e.g. net_arch=[dict(pi=[...], vf=[...])])" + ), + DeprecationWarning, + ) + # Iterate through the shared layers and build the shared parts of the network for layer in net_arch: if isinstance(layer, int): # Check that this is a shared layer From b3c79a35d66f530b23e0ed7a4ee6e512ea7ed504 Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Wed, 28 Dec 2022 21:05:04 +0100 Subject: [PATCH 2/6] Updated changelog --- docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b5b170176..b52587a4d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -52,6 +52,7 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ - You should now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` +- Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua) Others: ^^^^^^^ From 653b4da3693038f62dd6d0acdc01596b0de99edd Mon Sep 17 00:00:00 2001 From: AlexPasqua Date: Thu, 29 Dec 2022 18:50:52 +0100 Subject: [PATCH 3/6] Updated custom policy doc --- docs/guide/custom_policy.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 616fc4964..f9433499b 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -121,6 +121,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t .. warning:: If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``. + Please note that this option is **deprecated**, therefore in a future release the layers in the ``mlp_extractor`` will have to be non-shared. .. code-block:: python @@ -243,6 +244,11 @@ On-Policy Algorithms Shared Networks --------------- +.. warning:: + Shared layers in the the ``mlp_extractor`` are **deprecated**. + In a future release all layers will have to be non-shared. + If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_). + The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows to specify the amount and size of the hidden layers and how many of them are shared between the policy network and the value network. It is assumed to be a list with the following structure: From 1bbc6594a31c4b46c7aea41b69af57d33de46863 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 4 Jan 2023 10:43:26 +0100 Subject: [PATCH 4/6] Update doc and deprecation --- docs/guide/custom_policy.rst | 100 +++++++----------- docs/misc/changelog.rst | 47 +++++--- stable_baselines3/common/base_class.py | 2 +- stable_baselines3/common/envs/identity_env.py | 2 +- stable_baselines3/common/policies.py | 28 ++++- stable_baselines3/common/torch_layers.py | 48 +++++---- stable_baselines3/version.txt | 2 +- tests/test_custom_policy.py | 13 ++- 8 files changed, 136 insertions(+), 106 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index f0e571340..337ffab44 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -51,8 +51,6 @@ Each of these network have a features extractor followed by a fully-connected ne .. image:: ../_static/img/sb3_policy.png -.. .. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif - Custom Network Architecture ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -90,13 +88,13 @@ using ``policy_kwargs`` parameter: # of two layers of size 32 each with Relu activation function # Note: an extra linear layer will be added on top of the pi and the vf nets, respectively policy_kwargs = dict(activation_fn=th.nn.ReLU, - net_arch=[dict(pi=[32, 32], vf=[32, 32])]) + net_arch=dict(pi=[32, 32], vf=[32, 32])) # Create the agent model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) # Retrieve the environment env = model.get_env() # Train the agent - model.learn(total_timesteps=100000) + model.learn(total_timesteps=20_000) # Save the agent model.save("ppo_cartpole") @@ -114,7 +112,7 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t .. note:: - By default the features extractor is shared between the actor and the critic to save computation (when applicable). + For on-policy algorithms, the features extractor is shared by default between the actor and the critic to save computation (when applicable). However, this can be changed setting ``share_features_extractor=False`` in the ``policy_kwargs`` (both for on-policy and off-policy algorithms). @@ -241,7 +239,7 @@ downsampling and "vector" with a single linear layer. On-Policy Algorithms ^^^^^^^^^^^^^^^^^^^^ -Shared Networks +Custom Networks --------------- .. warning:: @@ -249,61 +247,48 @@ Shared Networks In a future release all layers will have to be non-shared. If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_). -The ``net_arch`` parameter of ``A2C`` and ``PPO`` policies allows to specify the amount and size of the hidden layers and how many -of them are shared between the policy network and the value network. It is assumed to be a list with the following -structure: - -1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer. - If the number of ints is zero, there will be no shared layers. -2. An optional dict, to specify the following non-shared layers for the value network and the policy network. - It is formatted like ``dict(vf=[], pi=[])``. - If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed. - -In short: ``[, dict(vf=[], pi=[])]``. - -Examples -~~~~~~~~ - -Two shared layers of size 128: ``net_arch=[128, 128]`` - - -.. code-block:: none +.. warning:: + In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change + to match the one of off-policy algorithms: it will create **separate** networks (instead of shared currently) + for the actor and the critic, with the same architecture. - obs - | - <128> - | - <128> - / \ - action value +If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``, +you can pass a dictionary of the following structure: ``dict(pi=[], vf=[])``. -Value network deeper than policy network, first layer shared: ``net_arch=[128, dict(vf=[256, 256])]`` +For example, if you want a different architecture for the actor (aka ``pi``) and the critic ( value-function aka ``vf``) networks, +then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``. -.. code-block:: none - - obs - | - <128> - / \ - action <256> - | - <256> - | - value +.. Otherwise, to have actor and critic that share the same network architecture, +.. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each). +Examples +~~~~~~~~ -Initially shared then diverging: ``[128, dict(vf=[256], pi=[16])]`` +.. TODO(antonin): uncomment when shared network is removed +.. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]`` +.. +.. .. code-block:: none +.. +.. obs +.. / \ +.. <128> <128> +.. | | +.. <128> <128> +.. | | +.. action value + +Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])`` .. code-block:: none - obs - | - <128> - / \ - <16> <256> - | | - action value + obs + / \ + <32> <64> + | | + <32> <64> + | | + action value Advanced Example @@ -408,21 +393,16 @@ If your task requires even more granular control over the policy/value architect Off-Policy Algorithms ^^^^^^^^^^^^^^^^^^^^^ -If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG`` or ``TD3``, -you can pass a dictionary of the following structure: ``dict(qf=[], pi=[])``. +If you need a network architecture that is different for the actor and the critic when using ``SAC``, ``DDPG``, ``TQC`` or ``TD3``, +you can pass a dictionary of the following structure: ``dict(pi=[], qf=[])``. For example, if you want a different architecture for the actor (aka ``pi``) and the critic (Q-function aka ``qf``) networks, -then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``. +then you can specify ``net_arch=dict(pi=[64, 64], qf=[400, 300])``. Otherwise, to have actor and critic that share the same network architecture, you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each). -.. note:: - Compared to their on-policy counterparts, no shared layers (other than the features extractor) - between the actor and the critic are allowed (to prevent issues with target networks). - - .. note:: For advanced customization of off-policy algorithms policies, please take a look at the code. A good understanding of the algorithm used is required, see discussion in `issue #425 `_ diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d57c2d90d..19d5a5fcc 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,9 +4,16 @@ Changelog ========== -Release 1.7.0a11 (WIP) +Release 1.7.0a12 (WIP) -------------------------- +.. warning:: + + Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO. + This feature will be removed in SB3 v1.8.0 and the behavior of ``net_arch=[64, 64]`` + will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms. + + .. note:: A2C and PPO saved with SB3 < 1.7.0 will show a warning about @@ -34,8 +41,15 @@ New Features: - Added ``normalized_image`` parameter to ``NatureCNN`` and ``CombinedExtractor`` - Added support for Python 3.10 -SB3-Contrib -^^^^^^^^^^^ +`SB3-Contrib`_ +^^^^^^^^^^^^^^ +- Fixed a bug in ``RecurrentPPO`` where the lstm states where incorrectly reshaped for ``n_lstm_layers > 1`` (thanks @kolbytn) +- Fixed ``RuntimeError: rnn: hx is not contiguous`` while predicting terminal values for ``RecurrentPPO`` when ``n_lstm_layers > 1`` + +`RL Zoo`_ +^^^^^^^^^ +- Added support for python file for configuration +- Added ``monitor_kwargs`` parameter Bug Fixes: ^^^^^^^^^^ @@ -100,8 +114,12 @@ New Features: - Added progress bar callback - The `RL Zoo `_ can now be installed as a package (``pip install rl_zoo3``) -SB3-Contrib -^^^^^^^^^^^ +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ +- RL Zoo is now a python package and can be installed using ``pip install rl_zoo3`` Bug Fixes: ^^^^^^^^^^ @@ -136,8 +154,8 @@ New Features: - Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio) - The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys -SB3-Contrib -^^^^^^^^^^^ +`SB3-Contrib`_ +^^^^^^^^^^^^^^ - Fixed the issue of wrongly passing policy arguments when using ``CnnLstmPolicy`` or ``MultiInputLstmPolicy`` with ``RecurrentPPO`` (@mlodel) Bug Fixes: @@ -193,8 +211,8 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ -SB3-Contrib -^^^^^^^^^^^ +`SB3-Contrib`_ +^^^^^^^^^^^^^^ - Added Recurrent PPO (PPO LSTM). See https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53 @@ -247,8 +265,8 @@ New Features: depending on desired maximum width of output. - Allow PPO to turn of advantage normalization (see `PR #763 `_) @vwxyzjn -SB3-Contrib -^^^^^^^^^^^ +`SB3-Contrib`_ +^^^^^^^^^^^^^^ - coming soon: Cross Entropy Method, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/62 Bug Fixes: @@ -310,8 +328,8 @@ New Features: - Added ``skip`` option to ``VecTransposeImage`` to skip transforming the channel order when the heuristic is wrong - Added ``copy()`` and ``combine()`` methods to ``RunningMeanStd`` -SB3-Contrib -^^^^^^^^^^^ +`SB3-Contrib`_ +^^^^^^^^^^^^^^ - Added Trust Region Policy Optimization (TRPO) (@cyprienc) - Added Augmented Random Search (ARS) (@sgillen) - Coming soon: PPO LSTM, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/53 @@ -1138,7 +1156,8 @@ and `Quentin Gallouédec`_ (aka @qgallouedec). .. _Quentin Gallouédec: https://gallouedec.com/ .. _@qgallouedec: https://github.com/qgallouedec - +.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib +.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo Contributors: ------------- diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 9d5fa0308..a71043d37 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -617,7 +617,7 @@ def set_parameters( f"expected {objects_needing_update}, got {updated_objects}" ) - @classmethod + @classmethod # noqa: C901 def load( cls: Type[SelfBaseAlgorithm], path: Union[str, pathlib.Path, io.BufferedIOBase], diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index ea1126811..9635e5319 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -4,7 +4,7 @@ import numpy as np from gym import spaces -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import GymStepReturn T = TypeVar("T", int, np.ndarray) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index a752a7af4..0cf491761 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -418,7 +418,8 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + # TODO(antonin): update type annotation when we remove shared network support + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -451,12 +452,28 @@ def __init__( normalize_images=normalize_images, ) + # Convert [dict()] to dict() as shared network are deprecated + if isinstance(net_arch, list) and len(net_arch) > 0: + if isinstance(net_arch[0], dict): + warnings.warn( + ( + "As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, " + "you should now pass directly a dictionary and not a list " + "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])" + ), + ) + net_arch = net_arch[0] + else: + # Note: deprecation warning will be emitted + # by the MlpExtractor constructor + pass + # Default network architecture, from stable-baselines if net_arch is None: if features_extractor_class == NatureCNN: net_arch = [] else: - net_arch = [dict(pi=[64, 64], vf=[64, 64])] + net_arch = dict(pi=[64, 64], vf=[64, 64]) self.net_arch = net_arch self.activation_fn = activation_fn @@ -472,7 +489,8 @@ def __init__( self.pi_features_extractor = self.features_extractor self.vf_features_extractor = self.make_features_extractor() # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor - if len(net_arch) > 0 and not isinstance(net_arch[0], dict): + # TODO(antonin): update the check once we change net_arch behavior + if isinstance(net_arch, list) and len(net_arch) > 0: raise ValueError( "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor" ) @@ -752,7 +770,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -825,7 +843,7 @@ def __init__( observation_space: spaces.Dict, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index aaf92136a..a51caa03b 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -181,7 +181,7 @@ class MlpExtractor(nn.Module): def __init__( self, feature_dim: int, - net_arch: List[Union[int, Dict[str, List[int]]]], + net_arch: Union[Dict[str, List[int]], List[Union[int, Dict[str, List[int]]]]], activation_fn: Type[nn.Module], device: Union[th.device, str] = "auto", ) -> None: @@ -194,32 +194,38 @@ def __init__( value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network last_layer_dim_shared = feature_dim - if len(net_arch) > 0 and isinstance(net_arch[0], int): + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], int): warnings.warn( ( - "Shared layers in the mlp_extractor are deprecated, please use separate pi and vf networks" - "(e.g. net_arch=[dict(pi=[...], vf=[...])])" + "Shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, " + "please use separate pi and vf networks " + "(e.g. net_arch=dict(pi=[...], vf=[...]))" ), DeprecationWarning, ) - # Iterate through the shared layers and build the shared parts of the network - for layer in net_arch: - if isinstance(layer, int): # Check that this is a shared layer - # TODO: give layer a meaningful name - shared_net.append(nn.Linear(last_layer_dim_shared, layer)) # add linear of size layer - shared_net.append(activation_fn()) - last_layer_dim_shared = layer - else: - assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" - if "pi" in layer: - assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers." - policy_only_layers = layer["pi"] - - if "vf" in layer: - assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers." - value_only_layers = layer["vf"] - break # From here on the network splits up in policy and value network + # TODO(antonin): update behavior for net_arch=[64, 64] + # once shared networks are removed + if isinstance(net_arch, dict): + policy_only_layers = net_arch["pi"] + value_only_layers = net_arch["vf"] + else: + # Iterate through the shared layers and build the shared parts of the network + for layer in net_arch: + if isinstance(layer, int): # Check that this is a shared layer + shared_net.append(nn.Linear(last_layer_dim_shared, layer)) # add linear of size layer + shared_net.append(activation_fn()) + last_layer_dim_shared = layer + else: + assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" + if "pi" in layer: + assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers." + policy_only_layers = layer["pi"] + + if "vf" in layer: + assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers." + value_only_layers = layer["vf"] + break # From here on the network splits up in policy and value network last_layer_dim_pi = last_layer_dim_shared last_layer_dim_vf = last_layer_dim_shared diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a02b7e49b..77ca7d320 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a11 +1.7.0a12 diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 4688bb973..11241b3b0 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -8,10 +8,13 @@ @pytest.mark.parametrize( "net_arch", [ - [12, dict(vf=[16], pi=[8])], - [4], [], + dict(vf=[16], pi=[8]), + # [] behavior will change + [4], [4, 4], + # All values below are deprecated + [12, dict(vf=[16], pi=[8])], [12, dict(vf=[8, 4], pi=[8])], [12, dict(vf=[8], pi=[8, 4])], [12, dict(pi=[8])], @@ -19,7 +22,11 @@ ) @pytest.mark.parametrize("model_class", [A2C, PPO]) def test_flexible_mlp(model_class, net_arch): - _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300) + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], int): + with pytest.warns(): + _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300) + else: + _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300) @pytest.mark.parametrize("net_arch", [[], [4], [4, 4], dict(qf=[8], pi=[8, 4])]) From a5d284a1720bc7d7e218f53ec96ca632c4d3bdb5 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 4 Jan 2023 10:47:09 +0100 Subject: [PATCH 5/6] Fix doc build --- docs/misc/changelog.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 19d5a5fcc..5a134540c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -8,8 +8,8 @@ Release 1.7.0a12 (WIP) -------------------------- .. warning:: - - Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO. + + Shared layers in MLP policy (``mlp_extractor``) are now deprecated for PPO, A2C and TRPO. This feature will be removed in SB3 v1.8.0 and the behavior of ``net_arch=[64, 64]`` will create **separate** networks with the same architecture, to be consistent with the off-policy algorithms. From b88cb2a3cae106263e6df0c3ba75260fdf468cd2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 4 Jan 2023 16:27:11 +0100 Subject: [PATCH 6/6] Minor edits --- docs/guide/custom_policy.rst | 6 +----- stable_baselines3/common/torch_layers.py | 4 ++-- tests/test_custom_policy.py | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 337ffab44..c9e598e6c 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -325,7 +325,7 @@ If your task requires even more granular control over the policy/value architect last_layer_dim_pi: int = 64, last_layer_dim_vf: int = 64, ): - super(CustomNetwork, self).__init__() + super().__init__() # IMPORTANT: # Save output dimensions, used to create the distributions @@ -361,8 +361,6 @@ If your task requires even more granular control over the policy/value architect observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Callable[[float], float], - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, - activation_fn: Type[nn.Module] = nn.Tanh, *args, **kwargs, ): @@ -371,8 +369,6 @@ If your task requires even more granular control over the policy/value architect observation_space, action_space, lr_schedule, - net_arch, - activation_fn, # Pass remaining arguments to base class *args, **kwargs, diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index a51caa03b..302d9b187 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -161,8 +161,8 @@ class MlpExtractor(nn.Module): It is formatted like ``dict(vf=[], pi=[])``. If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed. - Depredcation note: shared layers in ``net_arch`` are deprecated, please use separate - pi and vf networks (e.g. net_arch=[dict(pi=[...], vf=[...])]) + Deprecation note: shared layers in ``net_arch`` are deprecated, please use separate + pi and vf networks (e.g. net_arch=dict(pi=[...], vf=[...])) For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value network of size 255 and a single non-shared layer of size 128 for the policy network, the following layers_spec diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 11241b3b0..85c3d37d1 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize("model_class", [A2C, PPO]) def test_flexible_mlp(model_class, net_arch): if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], int): - with pytest.warns(): + with pytest.warns(DeprecationWarning): _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300) else: _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300)