schola.sb3.utils.SB3PPOModel
Class Definition
class schola.sb3.utils.SB3PPOModel(policy, action_space)
Bases: SB3ScholaModel
Parameters
policy
Type: Policy
The policy model.
action_space
Type: Space
The action space.
Attributes
T_destination
call_super_init
dump_patches
training
Methods
__init__
__init__(policy, action_space)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
add_module
add_module(name, module)
Add a child module to the current module.
apply
apply(fn)
Apply fn recursively to every submodule (as returned by .children()) as well as self.
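SB3PPOModel appears to inherit apply() from torch.nn.Module, as the method list suggests, so apply() can be used for custom weight initialization. This is a minimal sketch. Building a real SB3PPOModel requires a Stable-Baselines3 policy, so a small nn.Sequential stands in for it here, and init_linear is a hypothetical helper.

```python
import torch
from torch import nn

# Hypothetical stand-in for an SB3PPOModel; apply() behaves the same on any nn.Module.
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

def init_linear(m: nn.Module) -> None:
    # Called once per module: children first, then the container itself.
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight)
        nn.init.zeros_(m.bias)

model.apply(init_linear)  # returns model, so calls can be chained
```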
bfloat16
bfloat16()
Casts all floating point parameters and buffers to bfloat16 datatype.
buffers
buffers(recurse=True)
Return an iterator over module buffers.
children
children()
Return an iterator over immediate children modules.
compile
compile(*args, **kwargs)
Compile this Module’s forward using torch.compile().
cpu
cpu()
Move all model parameters and buffers to the CPU.
cuda
cuda(device=None)
Move all model parameters and buffers to the GPU.
double
double()
Casts all floating point parameters and buffers to double datatype.
eval
eval()
Set the module in evaluation mode.
extra_repr
extra_repr()
Return the extra representation of the module.
float
float()
Casts all floating point parameters and buffers to float datatype.
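The following short sketch shows the in-place dtype casts. It uses a hypothetical nn.Linear stand-in rather than a real SB3PPOModel. The integer buffer count (also hypothetical) shows that only floating point tensors are converted.

```python
import torch
from torch import nn

model = nn.Linear(3, 1)  # hypothetical stand-in for an SB3PPOModel
# Hypothetical integer buffer: casts skip non-floating-point tensors.
model.register_buffer("count", torch.zeros(1, dtype=torch.long))

# double() converts floating point parameters and buffers in place and returns self.
same = model.double()
assert same is model and model.weight.dtype == torch.float64

model.float()  # back to float32
```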
forward
forward(*args)
Define the computation performed at every call.
get_buffer
get_buffer(target)
Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state
get_extra_state()
Return any extra state to include in the module’s state_dict.
get_logits
get_logits(x)
get_parameter
get_parameter(target)
Return the parameter given by target if it exists, otherwise throw an error.
get_submodule
get_submodule(target)
Return the submodule given by target if it exists, otherwise throw an error.
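Targets are dotted paths of attribute names. The sketch below uses a hypothetical nested nn.Sequential in place of an SB3PPOModel, because the actual submodule names of SB3PPOModel are not documented here.

```python
from torch import nn

# Hypothetical nested stand-in; dotted targets follow attribute names.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), nn.Linear(8, 2)),
)

head = model.get_submodule("1.1")           # the inner Linear(8, 2)
weight = model.get_parameter("1.1.weight")  # same tensor as head.weight

try:
    model.get_submodule("2")                # no such child
    missing_raises = False
except AttributeError:
    missing_raises = True
```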
half
half()
Casts all floating point parameters and buffers to half datatype.
ipu
ipu(device=None)
Move all model parameters and buffers to the IPU.
load_state_dict
load_state_dict(state_dict, strict=True, assign=False)
Copy parameters and buffers from state_dict into this module and its descendants.
modules
modules()
Return an iterator over all modules in the network.
mtia
mtia(device=None)
Move all model parameters and buffers to the MTIA.
named_buffers
named_buffers(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children
named_children()
Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules
named_modules(memo=None, prefix='', remove_duplicate=True)
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
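This sketch shows how the child iterators differ from the full-module iterators, using a hypothetical nested stand-in for an SB3PPOModel:

```python
from torch import nn

model = nn.Sequential(  # hypothetical stand-in
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), nn.Linear(8, 2)),
)

# named_children() stops at the first level; named_modules() recurses and
# includes the root module under the empty name "".
child_names = [name for name, _ in model.named_children()]
all_names = [name for name, _ in model.named_modules()]
```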
named_parameters
named_parameters(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
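These iterators are often used to count trainable weights. The sketch below uses a hypothetical two-layer stand-in, and the expected count follows from the layer shapes.

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))  # hypothetical stand-in

param_names = [name for name, _ in model.named_parameters()]
# Total number of scalar weights: (4*8 + 8) + (8*2 + 2) = 58
n_params = sum(p.numel() for p in model.parameters())
```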
parameters
parameters(recurse=True)
Return an iterator over module parameters.
register_backward_hook
register_backward_hook(hook)
Register a backward hook on the module.
register_buffer
register_buffer(name, tensor, persistent=True)
Add a buffer to the module.
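Buffers hold tensor state that the optimizer does not train. The minimal sketch below uses a hypothetical Normalizer module (not part of Schola) to show how persistent=False keeps a buffer out of state_dict():

```python
import torch
from torch import nn

class Normalizer(nn.Module):  # hypothetical module, not part of Schola
    def __init__(self, dim: int) -> None:
        super().__init__()
        # Persistent buffers are saved in state_dict(); non-persistent ones are not.
        self.register_buffer("mean", torch.zeros(dim))
        self.register_buffer("scratch", torch.zeros(dim), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - self.mean

norm = Normalizer(3)
```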
register_forward_hook
register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)
Register a forward hook on the module.
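Forward hooks make it easy to inspect intermediate activations, for example when debugging a policy network. The sketch below runs on a hypothetical stand-in network, and save_hidden and captured are illustrative names:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # hypothetical stand-in
captured = {}

def save_hidden(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
    # Forward hooks receive (module, positional inputs, output) after forward() runs.
    captured["hidden"] = output.detach()

handle = model[1].register_forward_hook(save_hidden)
out = model(torch.randn(5, 4))  # calling the model runs the hook on model[1]
handle.remove()  # detach the hook once it is no longer needed
```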
register_forward_pre_hook
register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)
Register a forward pre-hook on the module.
register_full_backward_hook
register_full_backward_hook(hook, prepend=False)
Register a backward hook on the module.
register_full_backward_pre_hook
register_full_backward_pre_hook(hook, prepend=False)
Register a backward pre-hook on the module.
register_load_state_dict_post_hook
register_load_state_dict_post_hook(hook)
Register a post-hook to be run after module’s load_state_dict() is called.
register_load_state_dict_pre_hook
register_load_state_dict_pre_hook(hook)
Register a pre-hook to be run before module’s load_state_dict() is called.
register_module
register_module(name, module)
Alias for add_module().
register_parameter
register_parameter(name, param)
Add a parameter to the module.
register_state_dict_post_hook
register_state_dict_post_hook(hook)
Register a post-hook for the state_dict() method.
register_state_dict_pre_hook
register_state_dict_pre_hook(hook)
Register a pre-hook for the state_dict() method.
requires_grad_
requires_grad_(requires_grad=True)
Change if autograd should record operations on parameters in this module.
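Passing False freezes parameters, which is the usual way to fine-tune only part of a network. Here is a sketch on a hypothetical stand-in:

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))  # hypothetical stand-in

# Freeze the first layer: autograd stops tracking its parameters.
model[0].requires_grad_(False)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
```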
save_as_onnx
save_as_onnx(export_path, onnx_opset=17)
set_extra_state
set_extra_state(state)
Set extra state contained in the loaded state_dict.
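get_extra_state() and set_extra_state() let a module round-trip non-tensor state through state_dict(). The minimal sketch below uses a hypothetical StepCounter module. This page does not say whether SB3PPOModel overrides these hooks.

```python
from torch import nn

class StepCounter(nn.Module):  # hypothetical module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.steps = 0  # plain Python state, not a tensor

    def get_extra_state(self):
        # Stored in state_dict() under the "_extra_state" key.
        return {"steps": self.steps}

    def set_extra_state(self, state) -> None:
        # Called by load_state_dict() with the object returned above.
        self.steps = state["steps"]

src = StepCounter()
src.steps = 7
dst = StepCounter()
dst.load_state_dict(src.state_dict())
```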
set_submodule
set_submodule(target, module)
Set the submodule given by target if it exists, otherwise throw an error.
share_memory
share_memory()
See torch.Tensor.share_memory_().
state_dict
state_dict(*args, destination=None, prefix='', keep_vars=False)
Return a dictionary containing references to the whole state of the module.
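This sketch shows a state_dict() / load_state_dict() round trip between two modules with the same architecture, using hypothetical nn.Linear stand-ins. With keep_vars=False, the returned tensors are detached but still share storage with the module, so clone them if you need an independent snapshot.

```python
import torch
from torch import nn

src = nn.Linear(4, 2)  # hypothetical stand-ins with identical architecture
dst = nn.Linear(4, 2)

state = src.state_dict()             # OrderedDict: "weight", "bias" -> tensors
result = dst.load_state_dict(state)  # strict=True: keys must match exactly
```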
to
to(*args, **kwargs)
Move and/or cast the parameters and buffers.
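The following device placement sketch falls back to the CPU when no CUDA device is present. An nn.Linear stand-in replaces a real SB3PPOModel:

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(2, 2).to(device)  # nn.Module.to() moves in place and returns self

x = torch.randn(3, 2, device=device)  # inputs must live on the same device
y = model(x)
```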
to_empty
to_empty(*, device, recurse=True)
Move the parameters and buffers to the specified device without copying storage.
train
train(mode=True)
Set the module in training mode.
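eval() and train() set the training flag (listed under Attributes) on the module and all of its submodules. Layers such as dropout read that flag. Here is a sketch on a hypothetical stand-in:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))  # hypothetical stand-in with dropout

model.eval()  # equivalent to model.train(False); sets .training on every submodule
x = torch.ones(1, 4)
with torch.no_grad():
    deterministic = torch.equal(model(x), model(x))  # dropout is a no-op in eval mode

model.train()  # back to training mode
```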
type
type(dst_type)
Casts all parameters and buffers to dst_type.
xpu
xpu(device=None)
Move all model parameters and buffers to the XPU.
zero_grad
zero_grad(set_to_none=True)
Reset gradients of all model parameters.
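This sketch of the gradient reset uses a hypothetical nn.Linear stand-in. It assumes PyTorch 2.x, where set_to_none defaults to True as in the signature above:

```python
import torch
from torch import nn

model = nn.Linear(3, 1)  # hypothetical stand-in
loss = model(torch.ones(2, 3)).sum()
loss.backward()
had_grad = model.weight.grad is not None

# With the default set_to_none=True, gradients are reset to None rather than zero tensors.
model.zero_grad()
```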