ScholaModel
Full path:
schola.core.model.ScholaModel
A PyTorch Module that is compatible with Schola inference. All models have the following properties to allow for easy conversion to ONNX.
ScholaModel(observation_space, action_space)
Bases: StatefulModelMixin
Parameters
- observation_space (gym.Space) - The observation space of the model. If not a gym.spaces.Dict, it will be wrapped in a Dict with a single key “obs”.
- action_space (gym.Space) - The action space of the model. If not a gym.spaces.Dict, it will be wrapped in a Dict with a single key “action”.
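The wrapping rule described above can be sketched without any dependencies. This is only an illustration: a plain Python dict stands in for gym.spaces.Dict, the strings stand in for real spaces, and `wrap_space` is a hypothetical helper name, not part of Schola's API.

```python
from typing import Any, Dict


def wrap_space(space: Any, default_key: str) -> Dict[str, Any]:
    """Return `space` unchanged if it is already dict-like, else wrap it.

    Assumption for illustration: a plain dict plays the role of gym.spaces.Dict.
    """
    if isinstance(space, dict):
        return space
    return {default_key: space}


# A bare space receives the single default key; a dict passes through untouched.
wrapped_obs = wrap_space("Box(4,)", "obs")
wrapped_act = wrap_space({"move": "Discrete(3)"}, "action")
```

Under this rule, downstream code can always treat both spaces as keyed collections, whether or not the caller supplied a Dict.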
Methods
__init__
__init__(observation_space, action_space)
Parameters
- observation_space (Space)
- action_space (Space)
export_onnx_program
export_onnx_program(onnx_opset = 21)
Export the model as an ONNX program.
Parameters
onnx_opset (int) - The ONNX opset version to use for the export.
forward
forward(*args)
Parameters
args (Any)
get_logit_dimensions
get_logit_dimensions()
Get the flat dimensions of the action spaces.
Returns
Flat size per action dict key (gymnasium.spaces.flatdim on each subspace).
Return type: Dict[str, int]
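As a rough illustration of what "flat size per key" means, the sketch below computes the sizes by hand for the common space kinds. The `(kind, argument)` tuples and the `flat_size` helper are hypothetical stand-ins for gymnasium spaces and gymnasium.spaces.flatdim, so the example runs without gymnasium installed. The MultiBinary case assumes an integer size.

```python
from math import prod
from typing import Dict, Tuple

# Hypothetical descriptors standing in for gymnasium spaces:
# ("box", shape), ("discrete", n), ("multi_discrete", nvec), ("multi_binary", n)
SpaceSpec = Tuple[str, object]


def flat_size(spec: SpaceSpec) -> int:
    """Flattened size of one space descriptor."""
    kind, arg = spec
    if kind == "box":
        return prod(arg)   # number of elements in the Box shape
    if kind == "discrete":
        return int(arg)    # one slot per discrete choice
    if kind == "multi_discrete":
        return sum(arg)    # one slot per choice, summed over sections
    if kind == "multi_binary":
        return int(arg)    # one slot per binary flag
    raise ValueError(f"unsupported space kind: {kind}")


action_space: Dict[str, SpaceSpec] = {
    "steer": ("box", (2,)),
    "gear": ("discrete", 5),
    "buttons": ("multi_discrete", (3, 2)),
}

logit_dims = {key: flat_size(spec) for key, spec in action_space.items()}
# logit_dims == {"steer": 2, "gear": 5, "buttons": 5}
```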
make_box_output
make_box_output(logits, space_name = 'action')
Map logits to a gymnasium.spaces.Box action slice (identity for Box).
Parameters
- logits (torch.Tensor) - Logits slice for space_name (typically shaped for one fundamental space).
- space_name (str) - Key in action_space; used only for symmetry with the other make_* helpers.
make_discrete_output
make_discrete_output(logits, space_name = 'action')
Map logits to a gymnasium.spaces.Discrete action (argmax).
Parameters
- logits (torch.Tensor) - Logits for the discrete branch.
- space_name (str) - Key in action_space (unused for Discrete; kept for API uniformity).
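The argmax mapping can be shown on a plain list of scores. This is a dependency-free sketch of the idea, not Schola's tensor implementation, and the `argmax` helper is a hypothetical name.

```python
from typing import Sequence


def argmax(logits: Sequence[float]) -> int:
    """Index of the largest logit."""
    return max(range(len(logits)), key=logits.__getitem__)


# Three candidate actions; index 1 has the highest logit.
discrete_action = argmax([0.1, 2.3, -1.0])
```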
make_fundamental_output
make_fundamental_output(logits, space_name = 'action')
Dispatch to the appropriate make_*_output helper for space_name.
Parameters
- logits (torch.Tensor) - Logits slice for the fundamental space at space_name.
- space_name (str) - Key in action_space.
make_multi_binary_output
make_multi_binary_output(logits, space_name = 'action')
Map logits to a gymnasium.spaces.MultiBinary action.
Parameters
- logits (torch.Tensor) - Logits for the multi-binary branch.
- space_name (str) - Key in action_space (unused; kept for API uniformity).
make_multi_discrete_output
make_multi_discrete_output(logits, space_name = 'action')
Map logits to a gymnasium.spaces.MultiDiscrete action (per-section argmax).
Parameters
- logits (torch.Tensor) - Concatenated logits aligned with action_space[space_name].nvec.
- space_name (str) - Key of the MultiDiscrete subspace in action_space.
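"Per-section argmax" means the concatenated logits are cut into consecutive sections whose lengths are given by nvec, and the argmax is taken inside each section. A minimal sketch on plain lists follows. The `multi_discrete_argmax` helper is hypothetical and works on a single unbatched vector, which may differ from how Schola handles batches.

```python
from typing import List, Sequence


def multi_discrete_argmax(logits: Sequence[float], nvec: Sequence[int]) -> List[int]:
    """Take the argmax inside each consecutive section of `logits`.

    Section i has length nvec[i]; the sections must cover `logits` exactly.
    """
    if len(logits) != sum(nvec):
        raise ValueError("logits length must equal sum(nvec)")
    actions = []
    start = 0
    for n in nvec:
        section = logits[start:start + n]
        actions.append(max(range(n), key=section.__getitem__))
        start += n
    return actions


# nvec = [3, 2]: the first three logits choose among 3 options,
# the last two choose among 2 options.
multi_action = multi_discrete_argmax([0.2, 1.5, -0.3, 4.0, 0.1], [3, 2])
```

Each returned index is relative to its own section, so it is always in the range [0, nvec[i]).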
make_outputs
make_outputs(logits)
Split concatenated logits and produce one output tensor per action key.
Parameters
logits (torch.Tensor) - Concatenated logits over action branches (sequence dimensions flattened to batch).
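The splitting step can be sketched as slicing one flat vector into per-key chunks using the flat size of each action key. This sketch assumes the chunks are concatenated in the same order as the keys of the size dictionary, and it uses a plain list for a single unbatched vector. The `split_logits` helper is hypothetical.

```python
from typing import Dict, List, Sequence


def split_logits(logits: Sequence[float], flat_dims: Dict[str, int]) -> Dict[str, List[float]]:
    """Slice `logits` into one chunk per key, in the dict's insertion order."""
    if len(logits) != sum(flat_dims.values()):
        raise ValueError("logits length must equal the sum of flat_dims")
    chunks = {}
    start = 0
    for key, size in flat_dims.items():
        chunks[key] = list(logits[start:start + size])
        start += size
    return chunks


flat_dims = {"steer": 2, "gear": 3}
chunks = split_logits([0.5, -0.5, 1.0, 2.0, 3.0], flat_dims)
# chunks == {"steer": [0.5, -0.5], "gear": [1.0, 2.0, 3.0]}
```

Each chunk would then be handed to the helper that matches its space type to produce the final action value for that key.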
save_as_onnx
save_as_onnx(export_path, onnx_opset = 21)
Export this model to an .onnx file on disk.
Parameters
- export_path (str) - Output file path; parent directories are created if missing.
- onnx_opset (int) - ONNX opset passed to export_onnx_program().
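The "parent directories are created if missing" behaviour means an export path may point into a folder that does not exist yet. The standard-library sketch below shows one way such a step can work. `ensure_parent_dir` is a hypothetical helper and does not claim to mirror Schola's internals.

```python
import os
import tempfile


def ensure_parent_dir(export_path: str) -> None:
    """Create the directory that will contain `export_path`, if needed."""
    parent = os.path.dirname(export_path)
    if parent:  # a bare filename has no parent directory to create
        os.makedirs(parent, exist_ok=True)


with tempfile.TemporaryDirectory() as tmp:
    # Neither "models" nor "v1" exists before the call.
    target = os.path.join(tmp, "models", "v1", "policy.onnx")
    ensure_parent_dir(target)
    parent_created = os.path.isdir(os.path.dirname(target))
```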
Attributes
observation_space
observation_space : gym.spaces.Dict
The observation space of the model.
action_space
action_space : gym.spaces.Dict
The action space of the ScholaModel.
flat_dims
flat_dims : Dict[str, int]
A dictionary of the flat dimensions of the action spaces. Used to convert logits outputs to the correct output shapes.
input_obs_keys
input_obs_keys
Keys of observation_space in forward / export input order.
output_action_keys
output_action_keys : list of str
Keys of action_space in forward / export output order.