Stable Baselines3

Ver. 1.2.8 (2023-09-25)

This module provides wrapper classes for integrating stable baselines3 policy algorithms.

See also: https://pypi.org/project/stable-baselines3/

class mlpro.wrappers.sb3.DummyEnv(p_observation_space=None, p_action_space=None)

Bases: Env

Dummy environment class. This is required because some SB3 policy algorithms need an environment instance. For now, it only holds the observation space and the action space (a minimal instantiation is sketched below the class listing).

observation_space: spaces.Space[ObsType] = None
action_space: spaces.Space[ActType] = None
compute_reward(achieved_goal: int | ndarray, desired_goal: int | ndarray, _info: Dict[str, Any] | None) float32
seed(seed=None)
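
The dummy environment is typically created internally by the policy wrapper, but it can also be instantiated directly. A minimal sketch, assuming gymnasium Box spaces; the shapes and bounds are illustrative assumptions only:

    import numpy as np
    import gymnasium as gym

    from mlpro.wrappers.sb3 import DummyEnv

    # The SB3 algorithm only inspects these two spaces; no stepping takes place.
    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
    act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

    dummy_env = DummyEnv(p_observation_space=obs_space, p_action_space=act_space)

    print(dummy_env.observation_space)   # Box(-1.0, 1.0, (4,), float32)
    print(dummy_env.action_space)        # Box(-1.0, 1.0, (1,), float32)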
class mlpro.wrappers.sb3.VecExtractDictObs(venv: VecEnv, observation_space: Space | None = None, action_space: Space | None = None)

Bases: VecEnvWrapper

A vectorized wrapper for filtering a specific key from dictionary observations. This is used to incorporate HER into off-policy algorithms. Similar to Gym's FilterObservation wrapper. A short sketch of the vectorized step protocol follows the class listing.

reset() ndarray

Reset all the environments and return an array of observations, or a tuple of observation arrays.

If step_async is still doing work, that work will be cancelled and step_wait() should not be called until step_async() is invoked again.

Returns:

observation

step_async(actions: ndarray) None

Tell all the environments to start taking a step with the given actions. Call step_wait() to get the results of the step.

You should not call this if a step_async run is already pending.

step_wait() Tuple[ndarray | Dict[str, ndarray] | Tuple[ndarray, ...], ndarray, ndarray, List[Dict]]

Wait for the step taken with step_async().

Returns:

observation, reward, done, information

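The methods above follow the standard SB3 VecEnv protocol. A minimal sketch of the reset/step cycle, using a plain DummyVecEnv and the Pendulum-v1 environment as illustrative stand-ins (VecExtractDictObs itself is normally created inside the policy wrapper):

    import numpy as np
    import gymnasium as gym
    from stable_baselines3.common.vec_env import DummyVecEnv

    vec_env = DummyVecEnv([lambda: gym.make("Pendulum-v1")])

    obs = vec_env.reset()                             # array of observations, one per env
    actions = np.zeros((vec_env.num_envs, 1), dtype=np.float32)

    vec_env.step_async(actions)                       # dispatch the step to all environments
    obs, rewards, dones, infos = vec_env.step_wait()  # collect the results
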
class mlpro.wrappers.sb3.WrPolicySB32MLPro(p_sb3_policy, p_cycle_limit, p_observation_space: MSpace, p_action_space: MSpace, p_ada: bool = True, p_visualize: bool = False, p_logging=True, p_num_envs: int = 1, p_desired_goals=None)

Bases: Wrapper, Policy

This class provides a policy wrapper for Stable Baselines 3 (SB3) algorithms, especially on-policy algorithms. A construction sketch follows the parameter list.

Parameters:
  • p_sb3_policy – SB3 Policy

  • p_cycle_limit – Maximum number of cycles

  • p_observation_space (MSpace) – Observation Space

  • p_action_space (MSpace) – Environment Action Space

  • p_ada (bool) – Adaptability. Defaults to True.

  • p_visualize (bool) – Boolean switch for visualisation. Default = False.

  • p_logging – Log level (see constants of class Log). Default = Log.C_LOG_ALL.

  • p_num_envs (int) – Number of environments, specifically for vectorized environment.

  • p_desired_goals (list, Optional) – Desired state goals for Hindsight Experience Replay (HER).

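A construction sketch, loosely following Howto RL-WP-004 (see Cross References). Here, my_env stands for any MLPro environment instance and the numeric settings are illustrative assumptions:

    from stable_baselines3 import PPO
    from mlpro.wrappers.sb3 import WrPolicySB32MLPro

    # Plain SB3 algorithm; the training environment is handled by the wrapper,
    # so model setup is deferred here.
    policy_sb3 = PPO(policy="MlpPolicy",
                     env=None,
                     n_steps=100,
                     _init_setup_model=False,
                     device="cpu")

    policy_wrapped = WrPolicySB32MLPro(p_sb3_policy=policy_sb3,
                                       p_cycle_limit=1000,
                                       p_observation_space=my_env.get_state_space(),
                                       p_action_space=my_env.get_action_space(),
                                       p_ada=True,
                                       p_visualize=False,
                                       p_logging=True)

The wrapped policy can then be plugged into an MLPro agent in place of a native policy.
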
C_TYPE = 'Wrapper SB3 -> MLPro'
C_WRAPPED_PACKAGE = 'stable_baselines3'
C_MINIMUM_VERSION = '2.1.0'
_compute_action_on_policy(p_obs: State) Action
_compute_action_off_policy(p_obs: State) Action
_adapt_off_policy(p_sars_elem: SARSElement) bool
_adapt_on_policy(p_sars_elem: SARSElement) bool
_clear_buffer_on_policy()
_clear_buffer_off_policy()
_add_buffer_off_policy(p_buffer_element: SARSElement)

Redefines the add_buffer function. Instead of adding to the MLPro SARBuffer, the internal SB3 buffer is used for off-policy algorithms.

If you are incorporating HER, please note the following: the observation space is required to contain at least three elements, namely observation, desired_goal, and achieved_goal. Here, desired_goal specifies the goal that the agent should attempt to achieve, achieved_goal is the goal that the agent has actually achieved so far, and observation contains the actual observations of the environment as usual. A sketch of such an observation space follows below.
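
Such a goal-aware observation space can be illustrated with a gymnasium Dict space; the shapes and bounds below are assumptions for illustration only:

    import numpy as np
    from gymnasium import spaces

    observation_space = spaces.Dict({
        "observation":   spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32),
        "desired_goal":  spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
        "achieved_goal": spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
    })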

_add_buffer_on_policy(p_buffer_element: SARSElement)

Redefines the add_buffer function. Instead of adding to the MLPro SARBuffer, the internal SB3 buffer is used for on-policy algorithms.

_add_additional_buffer(p_buffer_element: SARSElement)

Cross References

  • Howto RL-WP-004: Train an Agent with SB3