Howto 13 - (RL) Comparison Native and Wrapper SB3 Policy

Ver. 1.0.5 (2021-12-24)

This module shows how to train an on-policy algorithm (PPO) from Stable Baselines 3 using the MLPro SB3 policy wrapper and compares the results with native SB3 training.

Prerequisites

Please install the following packages to run this example properly: mlpro, gym, numpy, matplotlib, pandas, torch and stable-baselines3.

Example Code

## -------------------------------------------------------------------------------------------------
## -- Project : FH-SWF Automation Technology - Common Code Base (CCB)
## -- Package : mlpro
## -- Module  : Howto 13 - Comparison Native and Wrapper SB3 Policy
## -------------------------------------------------------------------------------------------------
## -- History :
## -- yyyy-mm-dd  Ver.      Auth.    Description
## -- 2021-10-27  0.0.0     MRD      Creation
## -- 2021-10-27  1.0.0     MRD      Released first version
## -- 2021-11-16  1.0.1     DA       Refactoring
## -- 2021-12-07  1.0.2     DA       Refactoring
## -- 2021-12-20  1.0.3     DA       Refactoring
## -- 2021-12-23  1.0.4     MRD      Small change on custom _reset Wrapper
## -- 2021-12-24  1.0.5     DA       Replaced separator in log line by Training.C_LOG_SEPARATOR
## -------------------------------------------------------------------------------------------------

"""
Ver. 1.0.5 (2021-12-24)

This module shows how to train an on-policy algorithm (PPO) from Stable Baselines 3 using the MLPro SB3 policy wrapper and compares the results with native SB3 training.
"""

import gym
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from datetime import timedelta
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from mlpro.rl.models import *
from mlpro.bf.plot import DataPlotting
from mlpro.wrappers.openai_gym import WrEnvGYM2MLPro
from mlpro.wrappers.sb3 import WrPolicySB32MLPro
from pathlib import Path

# 1 Parameters

if __name__ == "__main__":
    # 1.1 Parameters for demo mode
    logging     = Log.C_LOG_ALL
    visualize   = True
    path        = str(Path.home())
    max_episode = 400
 
else:
    # 1.2 Parameters for internal unit test
    logging     = Log.C_LOG_NOTHING
    visualize   = False
    path        = None
    max_episode = 200

mva_window = 1
buffer_size = 100
policy_kwargs = dict(activation_fn=torch.nn.Tanh,
                     net_arch=[dict(pi=[10, 10], vf=[10, 10])])

# 2 Implement your own RL scenario
class MyScenario(RLScenario):

    C_NAME      = 'Matrix'

    def _setup(self, p_mode, p_ada, p_logging):
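        # A custom env wrapper that overrides _reset() so the Gym environment is not
        # reseeded on every reset; the seed set once via gym_env.seed(1) below stays in
        # effect, which keeps this run comparable to the native SB3 run further down.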
        class CustomWrapperFixedSeed(WrEnvGYM2MLPro):
            def _reset(self, p_seed=None):
                self.log(self.C_LOG_TYPE_I, 'Reset')

                # 1 Reset Gym environment and determine initial state
                observation = self._gym_env.reset()
                obs         = DataObject(observation)

                # 2 Create state object from Gym observation
                state   = State(self._state_space)
                state.set_values(obs.get_data())
                state.set_success(True)
                self._set_state(state)

        # 1 Setup environment
        gym_env     = gym.make('CartPole-v1')
        gym_env.seed(1)
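        # Note: gym_env.seed() and reset() returning only the observation follow the
        # classic Gym API in use at the time of writing (gym < 0.26).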
        self._env   = CustomWrapperFixedSeed(gym_env, p_logging=p_logging) 

        # 2 Instantiate policy from SB3
        # env is set to None; it will be set up later inside the wrapper.
        # _init_setup_model is set to False; _setup_model() will be called manually
        # inside the wrapper.

        # PPO
        policy_sb3 = PPO(
                    policy="MlpPolicy", 
                    env=None,
                    n_steps=buffer_size,
                    _init_setup_model=False,
                    policy_kwargs=policy_kwargs,
                    seed=1)

        # 3 Wrap the policy
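        # The wrapper exposes the SB3 algorithm as a regular MLPro policy object, so it
        # can be plugged into the standard MLPro Agent created in step 4 below.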
        self.policy_wrapped = WrPolicySB32MLPro(
                p_sb3_policy=policy_sb3, 
                p_observation_space=self._env.get_state_space(),
                p_action_space=self._env.get_action_space(),
                p_ada=p_ada,
                p_logging=p_logging)
        
        # 4 Setup standard single-agent with own policy
        return Agent(
            p_policy=self.policy_wrapped,   
            p_envmodel=None,
            p_name='Smith',
            p_ada=p_ada,
            p_logging=p_logging
        )




# 3 Instantiate training
training        = RLTraining(
    p_scenario_cls=MyScenario,
    p_cycle_limit=1000,      
    p_collect_states=True,
    p_collect_actions=True,
    p_collect_rewards=True,
    p_collect_eval=True,
    p_path=path,
    p_visualize=visualize,
    p_logging=logging )


# 4 Train the wrapped SB3 policy
training.run()


# 5 Create Plotting Class
class MyDataPlotting(DataPlotting):
    def get_plots(self):
        """
        A function to plot data
        """
        for name in self.data.names:
            maxval  = 0
            minval  = 0
            if self.printing[name][0]:
                fig     = plt.figure(figsize=(7,7))
                raw   = []
                label   = []
                ax = fig.subplots(1,1)
                ax.set_title(name)
                ax.grid(True, which="both", axis="both")
                for fr_id in self.data.frame_id[name]:
                    raw.append(np.sum(self.data.get_values(name,fr_id)))
                    if self.printing[name][1] == -1:
                        maxval = max(raw)
                        minval = min(raw)
                    else:
                        maxval = self.printing[name][2]
                        minval = self.printing[name][1]
                    
                    label.append("%s"%fr_id)
                ax.plot(raw)
                ax.set_ylim(minval-(abs(minval)*0.1), maxval+(maxval*0.1))
                ax.set_xlabel("Episode")
                ax.legend(label, bbox_to_anchor = (1,0.5), loc = "center left")
                self.plots[0].append(name)
                self.plots[1].append(ax)
                if self.showing:
                    plt.show()
                else:
                    plt.close(fig)

# 6 Plot 1: rewards of the MLPro-wrapped training
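# data_printing selects which logged variables are plotted: the first flag enables the
# plot for that variable, and a second value of -1 lets MyDataPlotting.get_plots()
# derive the y-axis limits from the data itself.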
data_printing   = {"Cycle":        [False],
                    "Day":          [False],
                    "Second":       [False],
                    "Microsecond":  [False],
                    "Smith":        [True,-1]}


mem = training.get_results().ds_rewards
mem_plot    = MyDataPlotting(mem, p_showing=False, p_printing=data_printing)
mem_plot.get_plots()
wrapper_plot = mem_plot.plots

# 7 Create a callback for the native SB3 training
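# The callback mirrors MLPro's episode handling for the native SB3 run: it logs episode
# starts and ends and stores the per-step rewards in an RLDataStoring object, so the
# native results can be plotted with the same MyDataPlotting class as the wrapped ones.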
class CustomCallback(BaseCallback, Log):
    """
    A custom callback that derives from ``BaseCallback``.

    :param p_verbose: (int) Verbosity level 0: no output, 1: info, 2: debug
    """

    C_TYPE                  = 'Wrapper'
    C_NAME                  = 'SB3 Policy'

    def __init__(self, p_verbose=0, p_logging=Log.C_LOG_ALL):
        super(CustomCallback, self).__init__(p_verbose)
        # Explicitly initialize the MLPro Log mixin so the callback uses the configured log level
        Log.__init__(self, p_logging=p_logging)
        reward_space = Set()
        reward_space.add_dim(Dimension(0, "Native"))
        self.ds_rewards  = RLDataStoring(reward_space)
        self.episode_num = 0
        self.total_cycle = 0
        self.cycles = 0
        self.plots = None

        self.continue_training = True
        self.rewards_cnt = []

    def _on_training_start(self) -> None:
        self.log(self.C_LOG_TYPE_I, Training.C_LOG_SEPARATOR)
        self.log(self.C_LOG_TYPE_I, '-- Episode', self.episode_num, 'started...')
        self.log(self.C_LOG_TYPE_I, Training.C_LOG_SEPARATOR, '\n')
        self.ds_rewards.add_episode(self.episode_num)

    def _on_step(self) -> bool:
        # With Cycle Limit
        self.ds_rewards.memorize_row(self.total_cycle, timedelta(0,0,0), self.locals.get("rewards"))
        self.total_cycle += 1
        self.cycles += 1
        if self.locals.get("infos")[0]:
            self.log(self.C_LOG_TYPE_I, Training.C_LOG_SEPARATOR)
            self.log(self.C_LOG_TYPE_I, '-- Episode', self.episode_num, 'finished after', self.total_cycle + 1, 'cycles')
            self.log(self.C_LOG_TYPE_I, Training.C_LOG_SEPARATOR, '\n\n')
            self.episode_num += 1
            self.total_cycle = 0
            self.ds_rewards.add_episode(self.episode_num)
            self.log(self.C_LOG_TYPE_I, Training.C_LOG_SEPARATOR)
            self.log(self.C_LOG_TYPE_I, '-- Episode', self.episode_num, 'started...')
            self.log(self.C_LOG_TYPE_I, Training.C_LOG_SEPARATOR, '\n')
        
        return True

    def _on_training_end(self) -> None:
        data_printing   = {"Cycle":        [False],
                            "Day":          [False],
                            "Second":       [False],
                            "Microsecond":  [False],
                            "Native":        [True,-1]}
        mem_plot    = MyDataPlotting(self.ds_rewards, p_showing=False, p_printing=data_printing)
        mem_plot.get_plots()
        self.plots = mem_plot.plots

# 8 Run the native SB3 training
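# The native PPO instance uses the same seed, network architecture, buffer size and
# number of timesteps (1000, matching p_cycle_limit above) as the wrapped run, so the
# resulting reward curves are directly comparable.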
gym_env     = gym.make('CartPole-v1')
gym_env.seed(1)
policy_sb3 = PPO(
                policy="MlpPolicy", 
                env=gym_env,
                n_steps=buffer_size,
                verbose=0,
                policy_kwargs=policy_kwargs,
                seed=1)

cus_callback = CustomCallback(p_logging=logging)
policy_sb3.learn(total_timesteps=1000, callback=cus_callback)
native_plot = cus_callback.plots

# 9 Comparison plot: native vs. wrapped rewards
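# With mva_window = 1 the rolling mean is effectively a no-op; increase the window size
# to smooth the per-episode reward curves.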
native_ydata = native_plot[1][0].lines[0].get_ydata()
wrapper_ydata = wrapper_plot[1][0].lines[0].get_ydata()
smoothed_native = pd.Series.rolling(pd.Series(native_ydata), mva_window).mean()
smoothed_native = [elem for elem in smoothed_native]
smoothed_wrapper = pd.Series.rolling(pd.Series(wrapper_ydata), mva_window).mean()
smoothed_wrapper = [elem for elem in smoothed_wrapper]
plt.plot(smoothed_native, label="Native")
plt.plot(smoothed_wrapper, label="Wrapper")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.legend()

if __name__ == "__main__":
    plt.show()

Results

Running the example trains the wrapped SB3 policy in MLPro and the same PPO setup natively in SB3, then shows both per-episode reward curves in one plot. Since both runs use the same seed and hyperparameters, the two curves should match closely, indicating that the wrapped policy reproduces the native SB3 behaviour.