schola.sb3.action_space_patch.PatchedPPO
Class Definition
class schola.sb3.action_space_patch.PatchedPPO( policy, env, learning_rate=0.0003, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None, normalize_advantage=True, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, use_sde=False, sde_sample_freq=-1, target_kl=None, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
Bases: PPO
Parameters
policy
Type: str | Type[ActorCriticPolicy]
The policy model to use (MlpPolicy, CnnPolicy, …).
env
Type: Env | VecEnv | str
The environment to learn from.
learning_rate
Type: float | Callable[[float], float]
Default: 0.0003
The learning rate; it can be a function of the current progress remaining (from 1 to 0).
n_steps
Type: int
Default: 2048
The number of steps to run for each environment per update.
batch_size
Type: int
Default: 64
Minibatch size.
n_epochs
Type: int
Default: 10
Number of epochs when optimizing the surrogate loss.
gamma
Type: float
Default: 0.99
Discount factor.
gae_lambda
Type: float
Default: 0.95
Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
clip_range
Type: float | Callable[[float], float]
Default: 0.2
Clipping parameter; it can be a function of the current progress remaining (from 1 to 0).
clip_range_vf
Type: None | float | Callable[[float], float]
Default: None
Clipping parameter for the value function; it can be a function of the current progress remaining.
normalize_advantage
Type: bool
Default: True
Whether or not to normalize the advantage.
ent_coef
Type: float
Default: 0.0
Entropy coefficient for the loss calculation.
vf_coef
Type: float
Default: 0.5
Value function coefficient for the loss calculation.
max_grad_norm
Type: float
Default: 0.5
The maximum value for the gradient clipping.
use_sde
Type: bool
Default: False
Whether to use generalized State Dependent Exploration (gSDE).
sde_sample_freq
Type: int
Default: -1
Sample a new noise matrix every n steps when using gSDE (-1 means only sample at the beginning of the rollout).
target_kl
Type: float | None
Default: None
Limit the KL divergence between updates.
stats_window_size
Type: int
Default: 100
Window size for the rollout logging, specifying the number of episodes to average the reported success rate and episode lengths over.
tensorboard_log
Type: str | None
Default: None
The log location for tensorboard (if None, no logging).
policy_kwargs
Type: Dict[str, Any] | None
Default: None
Additional arguments to be passed to the policy on creation.
verbose
Type: int
Default: 0
The verbosity level: 0 for no output, 1 for info messages, 2 for debug messages.
seed
Type: int | None
Default: None
Seed for the pseudo-random generators.
device
Type: device | str
Default: 'auto'
Device (cpu, cuda, …) on which the code should be run.
_init_setup_model
Type: bool
Default: True
Whether or not to build the network at the creation of the instance.
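Because PatchedPPO subclasses PPO and keeps the same constructor signature, it can be instantiated exactly like a standard Stable Baselines3 PPO model. The sketch below is illustrative only: it assumes a Gymnasium environment, and the environment id "CartPole-v1" and the hyperparameter values are placeholders rather than schola requirements.

import gymnasium as gym

from schola.sb3.action_space_patch import PatchedPPO

# Placeholder environment purely for illustration; in practice this would be
# whatever Env / VecEnv you are training against.
env = gym.make("CartPole-v1")

model = PatchedPPO(
    "MlpPolicy",        # policy alias, as with standard PPO
    env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    verbose=1,
)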
Attributes
logger
Getter for the logger object.
policy_aliases
Policy aliases for the PPO algorithm.
rollout_buffer
The rollout buffer used for collecting experiences.
policy
The policy model.
observation_space
The observation space of the environment.
action_space
The action space of the environment.
n_envs
The number of parallel environments.
lr_schedule
The learning rate schedule.
Methods
__init__
__init__( policy, env, learning_rate=0.0003, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, clip_range_vf=None, normalize_advantage=True, ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, use_sde=False, sde_sample_freq=-1, target_kl=None, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
Initialize the PatchedPPO algorithm.
collect_rollouts
collect_rollouts(env, callback, rollout_buffer, n_rollout_steps)
Collect experiences using the current policy and fill a RolloutBuffer.
get_env
get_env()
Returns the current environment (can be None if not defined).
get_parameters
get_parameters()
Return the parameters of the agent.
get_vec_normalize_env
get_vec_normalize_env()
Return the VecNormalize wrapper of the training env if it exists.
learn
learn(total_timesteps, callback=None, log_interval=1, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="PPO", eval_log_path=None, reset_num_timesteps=True)
Return a trained model.
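A minimal training sketch, assuming a model constructed as in the example above. Only total_timesteps is required; the remaining keywords simply echo defaults from the signature.

model.learn(
    total_timesteps=100_000,   # example budget, not a recommended value
    log_interval=1,
    tb_log_name="PPO",
    reset_num_timesteps=True,
)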
load
load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)
Load the model from a zip-file.
predict
predict(observation, state=None, episode_start=None, deterministic=False)
Get the policy action from an observation (and optional hidden state).
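A minimal inference loop using predict(), assuming the Gymnasium-style environment and model from the earlier examples.

obs, _info = env.reset()
for _ in range(1_000):
    # deterministic=True picks the most likely action instead of sampling
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, _info = env.step(action)
    if terminated or truncated:
        obs, _info = env.reset()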
save
save(path, exclude=None, include=None)
Save all the attributes of the object and the model parameters in a zip-file.
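A save/load round trip (sketch). The path "patched_ppo_model" is an arbitrary example; passing env to load is only needed if training will continue afterwards.

model.save("patched_ppo_model")

# Later, or in another process:
loaded_model = PatchedPPO.load("patched_ppo_model", env=env, device="auto")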
set_env
set_env(env, force_reset=True)
Checks the validity of the environment and, if it is coherent, sets it as the current environment.
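A sketch of swapping in a fresh environment before resuming training; the environment here is again a placeholder.

new_env = gym.make("CartPole-v1")  # placeholder; any compatible env
model.set_env(new_env, force_reset=True)
model.learn(total_timesteps=10_000, reset_num_timesteps=False)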
set_logger
set_logger(logger)
Setter for the logger object.
set_parameters
set_parameters(load_path_or_dict, exact_match=True, device='auto')
Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).
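A sketch of copying parameters between two models of the same architecture via get_parameters and set_parameters, reusing the model and env from the examples above.

params = model.get_parameters()

other_model = PatchedPPO("MlpPolicy", env)
other_model.set_parameters(params, exact_match=True, device="auto")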
set_random_seed
set_random_seed(seed=None)
Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space).
train
train()
Update policy using the currently gathered rollout buffer.