Stable Baselines3
Ver. 1.2.8 (2023-09-25)
This module provides wrapper classes for integrating Stable Baselines3 (SB3) policy algorithms into MLPro.
See also: https://pypi.org/project/stable-baselines3/
- class mlpro.wrappers.sb3.DummyEnv(p_observation_space=None, p_action_space=None)
Bases: Env
Dummy environment class. It is required because some SB3 policy algorithms cannot be instantiated without an environment. For now, it only needs to provide the observation space and the action space.
- observation_space: spaces.Space[ObsType] = None
- action_space: spaces.Space[ActType] = None
- compute_reward(achieved_goal: int | ndarray, desired_goal: int | ndarray, _info: Dict[str, Any] | None) float32
- seed(seed=None)
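As a rough illustration of the idea, the sketch below defines a placeholder environment that carries only the two spaces and refuses to be stepped. It uses only the standard library; `SpaceSpec` and `PlaceholderEnv` are hypothetical names that are not part of MLPro or SB3, where the spaces would be Gym-style space objects.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class SpaceSpec:
    """Hypothetical stand-in for a Gym-style space: a shape and value bounds."""
    shape: Tuple[int, ...]
    low: float
    high: float


class PlaceholderEnv:
    """Carries the spaces an algorithm inspects at construction; it cannot be stepped."""

    def __init__(self, observation_space: Optional[SpaceSpec] = None,
                 action_space: Optional[SpaceSpec] = None) -> None:
        self.observation_space = observation_space
        self.action_space = action_space

    def reset(self):
        raise NotImplementedError("placeholder environment: no dynamics defined")

    def step(self, action):
        raise NotImplementedError("placeholder environment: no dynamics defined")


# The algorithm only needs to read the two spaces.
env = PlaceholderEnv(
    observation_space=SpaceSpec(shape=(4,), low=-1.0, high=1.0),
    action_space=SpaceSpec(shape=(1,), low=-2.0, high=2.0),
)
```

The actual experience then comes from elsewhere (in this module, from the MLPro training loop), so the placeholder never has to simulate anything.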
- class mlpro.wrappers.sb3.VecExtractDictObs(venv: VecEnv, observation_space: Space | None = None, action_space: Space | None = None)
Bases: VecEnvWrapper
A vectorized wrapper that filters a specific key from dictionary observations. It is used to incorporate Hindsight Experience Replay (HER) into off-policy algorithms and is similar to Gym’s FilterObservation wrapper.
- reset() ndarray
Reset all the environments and return an array of observations, or a tuple of observation arrays.
If step_async is still doing work, that work will be cancelled and step_wait() should not be called until step_async() is invoked again.
- Returns:
observation
- step_async(actions: ndarray) None
Tell all the environments to start taking a step with the given actions. Call step_wait() to get the results of the step.
You should not call this if a step_async run is already pending.
- step_wait() Tuple[ndarray | Dict[str, ndarray] | Tuple[ndarray, ...], ndarray, ndarray, List[Dict]]
Wait for the step taken with step_async().
- Returns:
observation, reward, done, information
- _abc_impl = <_abc._abc_data object>
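The two-phase stepping protocol and the key filtering described above can be sketched with plain Python objects. This is a minimal illustration under stated assumptions, not the SB3 or MLPro implementation: `ToyVecEnv`, `ExtractKey`, and the `key` parameter are hypothetical, and lists stand in for arrays.

```python
from typing import Dict, List, Optional, Tuple

Obs = Dict[str, List[float]]


class ToyVecEnv:
    """Hypothetical two-environment vector env returning dictionary observations."""

    def __init__(self) -> None:
        self._pending: Optional[List[float]] = None
        self._state = [0.0, 0.0]

    def _obs(self) -> Obs:
        return {"observation": list(self._state), "desired_goal": [1.0, 1.0]}

    def reset(self) -> Obs:
        self._pending = None  # cancel any step that was still pending
        self._state = [0.0, 0.0]
        return self._obs()

    def step_async(self, actions: List[float]) -> None:
        if self._pending is not None:
            raise RuntimeError("a step is already pending")
        self._pending = list(actions)

    def step_wait(self) -> Tuple[Obs, List[float], List[bool], List[dict]]:
        if self._pending is None:
            raise RuntimeError("step_async() must be called first")
        actions, self._pending = self._pending, None
        self._state = [s + a for s, a in zip(self._state, actions)]
        rewards = [-abs(1.0 - s) for s in self._state]
        dones = [s >= 1.0 for s in self._state]
        return self._obs(), rewards, dones, [{} for _ in self._state]


class ExtractKey:
    """Forwards calls to the wrapped env and returns one key of each observation."""

    def __init__(self, venv: ToyVecEnv, key: str = "observation") -> None:
        self.venv, self.key = venv, key

    def reset(self) -> List[float]:
        return self.venv.reset()[self.key]

    def step_async(self, actions: List[float]) -> None:
        self.venv.step_async(actions)

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        return obs[self.key], rewards, dones, infos


env = ExtractKey(ToyVecEnv())
first_obs = env.reset()
env.step_async([0.5, 1.0])                    # phase 1: submit one action per env
obs, rewards, dones, infos = env.step_wait()  # phase 2: collect the results
```

Calling `step_async()` twice without an intervening `step_wait()` raises an error in this sketch, which mirrors the rule that only one step may be pending at a time.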
- class mlpro.wrappers.sb3.WrPolicySB32MLPro(p_sb3_policy, p_cycle_limit, p_observation_space: MSpace, p_action_space: MSpace, p_ada: bool = True, p_visualize: bool = False, p_logging=True, p_num_envs: int = 1, p_desired_goals=None)
Bases: Wrapper, Policy
This class wraps a Stable Baselines3 (SB3) policy for use as an MLPro policy. It is aimed especially at on-policy algorithms.
- Parameters:
p_sb3_policy – SB3 Policy
p_cycle_limit – Maximum number of cycles
p_observation_space (MSpace) – Observation Space
p_action_space (MSpace) – Environment Action Space
p_ada (bool) – Adaptability. Defaults to True.
p_visualize (bool) – Boolean switch for visualisation. Default = False.
p_logging – Log level (see constants of class Log). Default = Log.C_LOG_ALL.
p_num_envs (int) – Number of environments, specifically for vectorized environment.
p_desired_goals (list, Optional) – Desired state goals for Hindsight Experience Replay (HER).
- C_TYPE = 'Wrapper SB3 -> MLPro'
- C_WRAPPED_PACKAGE = 'stable_baselines3'
- C_MINIMUM_VERSION = '2.1.0'
- _adapt_off_policy(p_sars_elem: SARSElement) bool
- _adapt_on_policy(p_sars_elem: SARSElement) bool
- _clear_buffer_on_policy()
- _clear_buffer_off_policy()
- _add_buffer_off_policy(p_buffer_element: SARSElement)
Redefines the add_buffer function. Instead of adding to the MLPro SARBuffer, off-policy algorithms use the internal buffer from SB3.
If you are incorporating HER, please read the following description. The observation space must contain at least three elements: observation, desired_goal, and achieved_goal. Here, desired_goal specifies the goal that the agent should attempt to achieve, achieved_goal is the goal that it has actually achieved so far, and observation contains the usual observations of the environment.
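To make the required layout concrete, the sketch below builds a dictionary observation with the three keys and shows a hindsight relabelling step in which the achieved goal is substituted for the desired goal. The sparse reward and the helper names `validate`, `sparse_reward`, and `relabel` are assumptions for illustration; they do not reproduce the internal code of this wrapper or of SB3.

```python
from typing import Dict, List

Obs = Dict[str, List[float]]
REQUIRED_KEYS = ("observation", "achieved_goal", "desired_goal")


def validate(obs: Obs) -> None:
    """Raise if a key required for HER is missing."""
    missing = [k for k in REQUIRED_KEYS if k not in obs]
    if missing:
        raise KeyError(f"missing keys for HER: {missing}")


def sparse_reward(achieved_goal: List[float], desired_goal: List[float],
                  tolerance: float = 0.05) -> float:
    """Assumed reward: 0.0 if the goal is reached within tolerance, otherwise -1.0."""
    distance = sum((a - d) ** 2 for a, d in zip(achieved_goal, desired_goal)) ** 0.5
    return 0.0 if distance <= tolerance else -1.0


def relabel(obs: Obs) -> Obs:
    """Return a copy in which the achieved goal is treated as the target."""
    new_obs = {k: list(v) for k, v in obs.items()}
    new_obs["desired_goal"] = list(obs["achieved_goal"])
    return new_obs


obs = {
    "observation": [0.2, -0.1, 0.0],
    "achieved_goal": [0.2, -0.1],
    "desired_goal": [1.0, 1.0],
}
validate(obs)
original_reward = sparse_reward(obs["achieved_goal"], obs["desired_goal"])
hindsight = relabel(obs)
hindsight_reward = sparse_reward(hindsight["achieved_goal"], hindsight["desired_goal"])
```

In this sketch the original transition is a failure (reward -1.0), while the relabelled copy counts as a success (reward 0.0), which is the effect HER relies on when rewards are sparse.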
- _add_buffer_on_policy(p_buffer_element: SARSElement)
Redefines the add_buffer function. Instead of adding to the MLPro SARBuffer, on-policy algorithms use the internal buffer from SB3.
- _add_additional_buffer(p_buffer_element: SARSElement)
Cross References
Howto RL-WP-004: Train an Agent with SB3