schola.sb3.utils.SB3A2CModel
Class Definition
class schola.sb3.utils.SB3A2CModel(policy, action_space)
Bases: SB3PPOModel
Parameters
policy
Type: Policy
The policy model.
action_space
Type: Space
The action space.
Attributes
T_destination
call_super_init
dump_patches
training
Methods
__init__
__init__(policy, action_space)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
add_module
add_module(name, module)
Add a child module to the current module.
apply
apply(fn)
Apply fn recursively to every submodule (as returned by .children()) as well as self.
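A common use of apply is custom weight initialization. The sketch below uses a small stand-in nn.Sequential instead of an SB3A2CModel, because building one requires an SB3 A2C policy. It assumes the model inherits apply unchanged from torch.nn.Module, as the member listing suggests. The init_weights function is hypothetical.

```python
import torch
from torch import nn

# Stand-in network; an SB3A2CModel is assumed to behave the same way
# because it exposes the inherited torch.nn.Module.apply().
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

visited = []

def init_weights(module: nn.Module) -> None:
    """Hypothetical initializer: Xavier-init Linear weights and zero their biases."""
    visited.append(type(module).__name__)
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)

# apply() recurses into the children first and then calls fn on the module itself.
net.apply(init_weights)
print(visited)  # ['Linear', 'ReLU', 'Linear', 'Sequential']
```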
bfloat16
bfloat16()
Cast all floating-point parameters and buffers to the bfloat16 datatype.
buffers
buffers(recurse=True)
Return an iterator over module buffers.
children
children()
Return an iterator over immediate children modules.
compile
compile(*args, **kwargs)
Compile this Module’s forward using torch.compile().
cpu
cpu()
Move all model parameters and buffers to the CPU.
cuda
cuda(device=None)
Move all model parameters and buffers to the GPU.
double
double()
Cast all floating-point parameters and buffers to the double datatype.
eval
eval()
Set the module in evaluation mode.
extra_repr
extra_repr()
Return the extra representation of the module.
float
float()
Cast all floating-point parameters and buffers to the float datatype.
forward
forward(*args)
Define the computation performed at every call.
get_buffer
get_buffer(target)
Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state
get_extra_state()
Return any extra state to include in the module’s state_dict.
get_logits
get_logits(x)
get_parameter
get_parameter(target)
Return the parameter given by target if it exists, otherwise throw an error.
get_submodule
get_submodule(target)
Return the submodule given by target if it exists, otherwise throw an error.
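For get_buffer, get_parameter and get_submodule, target is a dotted path of the same form that named_modules() and named_parameters() yield. In the sketch below, the stand-in module and its submodule names (encoder, head) are hypothetical and are not the real layout of an SB3A2CModel.

```python
from torch import nn

# Hypothetical stand-in model with a nested submodule.
model = nn.Module()
model.encoder = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
model.head = nn.Linear(8, 2)

# Dotted targets walk the module tree; "0" indexes into the Sequential.
first_layer = model.get_submodule("encoder.0")
bias = model.get_parameter("head.bias")

print(type(first_layer).__name__)  # Linear
print(tuple(bias.shape))           # (2,)

# A missing path raises AttributeError rather than returning None.
try:
    model.get_submodule("decoder")
except AttributeError as err:
    print("missing:", err)
```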
half
half()
Cast all floating-point parameters and buffers to the half datatype.
ipu
ipu(device=None)
Move all model parameters and buffers to the IPU.
load_state_dict
load_state_dict(state_dict, strict=True, assign=False)
Copy parameters and buffers from state_dict into this module and its descendants.
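A minimal round trip through state_dict and load_state_dict is sketched below, using stand-in networks rather than an SB3A2CModel. With strict=True the keys must match exactly. With strict=False, load_state_dict returns the mismatched keys instead of raising.

```python
import torch
from torch import nn

torch.manual_seed(0)
src = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dst = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Keys are dotted names: "0.weight", "0.bias", "2.weight", "2.bias".
state = src.state_dict()
dst.load_state_dict(state)  # strict=True by default

# After loading, both networks compute identical outputs.
x = torch.randn(3, 4)
assert torch.equal(src(x), dst(x))

# Load only the first layer; strict=False reports what was not covered.
partial = {k: v for k, v in state.items() if k.startswith("0.")}
result = dst.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias']
print(result.unexpected_keys)  # []
```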
modules
modules()
Return an iterator over all modules in the network.
mtia
mtia(device=None)
Move all model parameters and buffers to the MTIA.
named_buffers
named_buffers(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children
named_children()
Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules
named_modules(memo=None, prefix='', remove_duplicate=True)
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters
named_parameters(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters
parameters(recurse=True)
Return an iterator over module parameters.
register_backward_hook
register_backward_hook(hook)
Register a backward hook on the module.
register_buffer
register_buffer(name, tensor, persistent=True)
Add a buffer to the module.
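The persistent flag controls whether a buffer is written to state_dict(). Non-persistent buffers still move with the module on .to() or .cuda() calls, but they are not saved. The RunningNorm module below is a hypothetical illustration and is not part of Schola.

```python
import torch
from torch import nn

class RunningNorm(nn.Module):
    """Hypothetical module illustrating persistent vs. non-persistent buffers."""

    def __init__(self, size: int) -> None:
        super().__init__()
        # Persistent buffer: included in state_dict().
        self.register_buffer("mean", torch.zeros(size))
        # Non-persistent buffer: tracked by the module but excluded from state_dict().
        self.register_buffer("scratch", torch.zeros(size), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - self.mean

m = RunningNorm(3)
print(list(m.state_dict().keys()))              # ['mean']
print([name for name, _ in m.named_buffers()])  # ['mean', 'scratch']
print(len(list(m.parameters())))                # 0, since buffers are not parameters
```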
register_forward_hook
register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)
Register a forward hook on the module.
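Forward hooks are a convenient way to inspect intermediate activations, for example while debugging a policy network before export. The sketch below attaches a hook to a layer of a stand-in network, not to a real SB3A2CModel. The hook receives the module, the tuple of positional inputs and the output. Returning None leaves the output unchanged.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured = {}

def save_output(module, inputs, output):
    # Store a detached copy so the capture does not keep the autograd graph alive.
    captured["hidden"] = output.detach()

handle = net[1].register_forward_hook(save_output)
out = net(torch.randn(5, 4))
handle.remove()  # the returned handle detaches the hook

print(captured["hidden"].shape)  # torch.Size([5, 8])
```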
register_forward_pre_hook
register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)
Register a forward pre-hook on the module.
register_full_backward_hook
register_full_backward_hook(hook, prepend=False)
Register a backward hook on the module.
register_full_backward_pre_hook
register_full_backward_pre_hook(hook, prepend=False)
Register a backward pre-hook on the module.
register_load_state_dict_post_hook
register_load_state_dict_post_hook(hook)
Register a post-hook to be run after the module’s load_state_dict() is called.
register_load_state_dict_pre_hook
register_load_state_dict_pre_hook(hook)
Register a pre-hook to be run before the module’s load_state_dict() is called.
register_module
register_module(name, module)
Alias for add_module().
register_parameter
register_parameter(name, param)
Add a parameter to the module.
register_state_dict_post_hook
register_state_dict_post_hook(hook)
Register a post-hook for the state_dict() method.
register_state_dict_pre_hook
register_state_dict_pre_hook(hook)
Register a pre-hook for the state_dict() method.
requires_grad_
requires_grad_(requires_grad=True)
Change whether autograd should record operations on parameters in this module.
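One way to freeze part of a model is to call requires_grad_(False) on a submodule. You can then confirm the result with named_parameters(). The sketch below uses a stand-in network. Which layers of an actual SB3A2CModel you would freeze is an application choice that this page does not specify.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first layer in place (requires_grad_ also returns the module).
net[0].requires_grad_(False)

trainable = [name for name, p in net.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']

net(torch.randn(3, 4)).sum().backward()
print(net[0].weight.grad is None)  # True: frozen parameters get no gradient
print(net[2].weight.grad is None)  # False
```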
save_as_onnx
save_as_onnx(export_path, onnx_opset=17)
set_extra_state
set_extra_state(state)
Set extra state contained in the loaded state_dict.
set_submodule
set_submodule(target, module)
Set the submodule given by target if it exists, otherwise throw an error.
share_memory
share_memory()
See torch.Tensor.share_memory_().
state_dict
state_dict(*args, destination=None, prefix='', keep_vars=False)
Return a dictionary containing references to the whole state of the module.
to
to(*args, **kwargs)
Move and/or cast the parameters and buffers.
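Module.to() accepts a dtype, a device, or both. Unlike Tensor.to(), it modifies the module in place and returns the same module. The sketch uses a stand-in nn.Linear. The device line falls back to the CPU when CUDA is unavailable.

```python
import torch
from torch import nn

net = nn.Linear(4, 2)

# Cast floating-point parameters and buffers; non-floating buffers are not cast.
net.to(torch.float64)
print(net.weight.dtype)  # torch.float64

# Move to a device, choosing CUDA only if it is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
print(net.weight.device.type)

# In-place semantics: the returned object is the module itself.
assert net.to(torch.float32) is net
```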
to_empty
to_empty(*, device, recurse=True)
Move the parameters and buffers to the specified device without copying storage.
train
train(mode=True)
Set the module in training mode.
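train() and eval() set the training attribute (listed under Attributes) on the module and on every submodule. Layers such as dropout check this attribute. Switching to eval() before running inference or exporting is common practice. This page does not say whether save_as_onnx does this itself, so the sketch below, on a stand-in network, only shows the flag behavior.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))

net.eval()  # equivalent to net.train(False)
print(net.training, net[1].training)  # False False

# In eval mode dropout is a no-op, so repeated calls give identical outputs.
x = torch.randn(2, 4)
with torch.no_grad():
    assert torch.equal(net(x), net(x))

net.train()  # mode=True by default; applied recursively
print(net.training, net[1].training)  # True True
```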
type
type(dst_type)
Cast all parameters and buffers to dst_type.
xpu
xpu(device=None)
Move all model parameters and buffers to the XPU.
zero_grad
zero_grad(set_to_none=True)
Reset gradients of all model parameters.
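The set_to_none flag decides how zero_grad resets gradients. The default (True) sets each .grad to None. With False, the existing gradient tensors are kept and filled with zeros. The sketch below shows both on a stand-in nn.Linear.

```python
import torch
from torch import nn

net = nn.Linear(4, 2)
net(torch.randn(3, 4)).sum().backward()
print(net.weight.grad is not None)  # True

# Default: gradients are dropped entirely.
net.zero_grad()
print(net.weight.grad)  # None

# set_to_none=False: gradient tensors are kept but zeroed in place.
net(torch.randn(3, 4)).sum().backward()
net.zero_grad(set_to_none=False)
print(torch.count_nonzero(net.weight.grad).item())  # 0
```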