Extending launch.py for SB3 and Ray
Schola supports extending the launch.py scripts for both Ray and Stable-Baselines3 (SB3) with additional callbacks and logging. This is done with Python plugins, which let Schola automatically discover new code that extends it. A plugin must subclass the appropriate base class and register an entry point in the appropriately named group. The launch.py scripts then discover the plugin automatically and use it to modify the training process. The steps below show how to extend the schola.scripts.sb3.launch and schola.scripts.ray.launch scripts specifically.
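Entry-point discovery of this kind is normally built on the standard library's importlib.metadata module. The sketch below illustrates the general mechanism only. It is not Schola's actual loader, and the helper name discover_extensions is hypothetical.

```python
import sys
from importlib.metadata import entry_points
from typing import List


def discover_extensions(group: str) -> List[type]:
    """Load every object registered under an entry-point group.

    Hypothetical helper illustrating plugin discovery; not part of Schola.
    """
    if sys.version_info >= (3, 10):
        # Python 3.10+ can select a single group directly.
        eps = entry_points(group=group)
    else:
        # Python 3.8/3.9 return a dict mapping group names to entry points.
        eps = entry_points().get(group, [])
    # Each entry point resolves to the object named after the colon,
    # e.g. "my_package.extension:ExampleSb3Extension".
    return [ep.load() for ep in eps]


if __name__ == "__main__":
    for cls in discover_extensions("schola.plugins.sb3.launch"):
        print(f"Found launcher extension: {cls.__name__}")
```

Because discovery only reads installed package metadata, a plugin becomes visible as soon as its package is installed in the same environment as Schola, with no changes to launch.py itself.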
Extending schola.scripts.sb3.launch
You can extend schola.scripts.sb3.launch with additional callbacks, KVWriters for logging, and command-line arguments. Below is an example of a plugin that adds a CSV logger and a callback that dumps the log every N timesteps.
- Create a new class that inherits from Sb3LauncherExtension and implement whichever of these methods your plugin needs: get_extra_KVWriters(), get_extra_callbacks(), and add_plugin_args_to_parser().
from dataclasses import dataclass
from typing import List
import argparse

from schola.scripts.common import Sb3LauncherExtension
from stable_baselines3.common.callbacks import BaseCallback, LogEveryNTimesteps
from stable_baselines3.common.logger import CSVOutputFormat, KVWriter


@dataclass
class ExampleSb3Extension(Sb3LauncherExtension):
    csv_save_path: str = "./output.csv"
    log_frequency: int = 1000

    def get_extra_KVWriters(self) -> List[KVWriter]:
        # Also write logged key/value pairs to a CSV file.
        return [CSVOutputFormat(self.csv_save_path)]

    def get_extra_callbacks(self) -> List[BaseCallback]:
        # Dump the logger's recorded values every `log_frequency` timesteps.
        return [LogEveryNTimesteps(n_steps=self.log_frequency)]

    @classmethod
    def add_plugin_args_to_parser(cls, parser: argparse.ArgumentParser):
        """
        Add example logging arguments to the parser.

        Parameters
        ----------
        parser : argparse.ArgumentParser
            The parser to which the arguments will be added.
        """
        group = parser.add_argument_group("CSV Logging")
        # Defaults mirror the dataclass fields so omitting a flag does not yield None.
        group.add_argument("--csv-save-path", type=str, default="./output.csv",
                           help="The path to save the CSV file to")
        group.add_argument("--log-frequency", type=int, default=1000,
                           help="The number of timesteps between log dumps")
- Create a new Python package with an entry point in the schola.plugins.sb3.launch group that points to your new class.
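Entry points are declared in the package's metadata: in setup.py through the entry_points argument (a dict keyed by group name), or in pyproject.toml through a [project.entry-points."schola.plugins.sb3.launch"] table. Each entry maps a name to a "module:attribute" string. The sketch below, which assumes Python 3.9 or later and uses the hypothetical package my_schola_plugin and entry name csv_logging, shows how the standard library interprets such a string.

```python
from importlib.metadata import EntryPoint

# One entry registered under the "schola.plugins.sb3.launch" group.
# "csv_logging" and "my_schola_plugin.extension" are hypothetical names.
ep = EntryPoint(
    name="csv_logging",
    value="my_schola_plugin.extension:ExampleSb3Extension",
    group="schola.plugins.sb3.launch",
)

# The part before the colon is the importable module; the part after it is
# the attribute that ep.load() returns once the package is installed.
print(ep.module)  # my_schola_plugin.extension
print(ep.attr)    # ExampleSb3Extension
```

After the package is installed (for example with pip install -e .), calling ep.load() imports my_schola_plugin.extension and returns the class. The Ray launcher works the same way with the schola.plugins.ray.launch group.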
Extending schola.scripts.ray.launch
You can extend schola.scripts.ray.launch with additional callbacks and command-line arguments. Below is an example of a plugin that adds support for logging with Weights & Biases (wandb).
- Create a new class that inherits from RLLibLauncherExtension and implement whichever of these methods your plugin needs: get_extra_callbacks() and add_plugin_args_to_parser().
from dataclasses import dataclass
from typing import Optional
import argparse

from schola.scripts.common import RLLibLauncherExtension

try:
    # Location of the callback in recent Ray 2.x releases.
    from ray.air.integrations.wandb import WandbLoggerCallback
except ImportError:
    # Older Ray releases expose the callback under ray.tune.
    from ray.tune.integration.wandb import WandbLoggerCallback


@dataclass
class ExampleRayExtension(RLLibLauncherExtension):
    experiment_id: Optional[str] = None

    def get_extra_callbacks(self):
        # Send Tune trial results to the W&B project named by --experiment-id.
        return [WandbLoggerCallback(project=self.experiment_id)]

    @classmethod
    def add_plugin_args_to_parser(cls, parser: argparse.ArgumentParser):
        """
        Add example logging arguments to the parser.

        Parameters
        ----------
        parser : argparse.ArgumentParser
            The parser to which the arguments will be added.
        """
        group = parser.add_argument_group("Wandb Logging")
        group.add_argument("--experiment-id", type=str,
                           help="The Weights & Biases project to log to")
- Create a new Python package with an entry point in the schola.plugins.ray.launch group that points to your new class.
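This page does not describe how the launcher turns parsed command-line arguments into an extension instance. One plausible pattern, shown here only as a sketch under that assumption, is to build the dataclass from the argparse namespace by matching field names. argparse converts --experiment-id into the attribute experiment_id, which is why the option and the field line up. StandInRayExtension and build_extension are hypothetical stand-ins so the example runs without Schola or Ray installed.

```python
import argparse
from dataclasses import dataclass, fields
from typing import List, Optional


@dataclass
class StandInRayExtension:
    """Hypothetical stand-in mirroring ExampleRayExtension's fields."""

    experiment_id: Optional[str] = None

    @classmethod
    def add_plugin_args_to_parser(cls, parser: argparse.ArgumentParser) -> None:
        group = parser.add_argument_group("Wandb Logging")
        group.add_argument("--experiment-id", type=str, default=None,
                           help="The Weights & Biases project to log to")


def build_extension(cls, argv: List[str]):
    """Parse argv and pass only the options that match dataclass fields."""
    parser = argparse.ArgumentParser()
    cls.add_plugin_args_to_parser(parser)
    args = vars(parser.parse_args(argv))  # "--experiment-id" -> "experiment_id"
    field_names = {f.name for f in fields(cls)}
    return cls(**{k: v for k, v in args.items() if k in field_names})


ext = build_extension(StandInRayExtension, ["--experiment-id", "my-project"])
print(ext.experiment_id)  # my-project
```

Under this pattern, a flag with no argparse default is passed as None and overrides the dataclass default, so giving each option a default that matches its field keeps the two in sync.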