
ScholaModel

Full path: schola.core.model.ScholaModel

A PyTorch module compatible with Schola inference. All models expose the properties listed under Attributes, which make conversion to ONNX straightforward.

ScholaModel(observation_space, action_space)

Bases: StatefulModelMixin

Parameters

  • observation_space (gym.Space) - The observation space of the model. If not a gym.spaces.Dict, it will be wrapped in a Dict with a single key “obs”.

  • action_space (gym.Space) - The action space of the model. If not a gym.spaces.Dict, it will be wrapped in a Dict with a single key “action”.

Methods

__init__

__init__(observation_space, action_space)

Parameters

  • observation_space (Space)

  • action_space (Space)


export_onnx_program

export_onnx_program(onnx_opset = 21)

Export the model as an ONNX program.

Parameters

  • onnx_opset (int) - The ONNX opset version to use for the export.

forward

forward(*args)

Parameters

  • args (Any)

get_logit_dimensions

get_logit_dimensions()

Get the flat dimensions of the action spaces.

Returns

Flat size per action dict key (gymnasium.spaces.flatdim on each subspace).

Return type: Dict[str, int]


make_box_output

make_box_output(logits, space_name = 'action')

Map logits to a gymnasium.spaces.Box action slice (identity for Box).

Parameters

  • logits (torch.Tensor) - Logits slice for space_name (typically shaped for one fundamental space).

  • space_name (str) - Key in action_space used only for symmetry with other make_* helpers.


make_discrete_output

make_discrete_output(logits, space_name = 'action')

Map logits to a gymnasium.spaces.Discrete action (argmax).

Parameters

  • logits (torch.Tensor) - Logits for the discrete branch.

  • space_name (str) - Key in action_space (unused for Discrete; kept for API uniformity).
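A minimal sketch of the argmax mapping, assuming logits are batch-first with shape (batch, n) for a Discrete(n) space. discrete_from_logits is a hypothetical name, not a Schola function.

```python
import torch


def discrete_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Pick the highest-scoring action index along the last dimension."""
    return torch.argmax(logits, dim=-1)


# Batch of 2 samples for a Discrete(4) action space.
logits = torch.tensor([[0.1, 2.0, -1.0, 0.3],
                       [1.5, 0.2, 0.9, 3.1]])
print(discrete_from_logits(logits))  # tensor([1, 3])
```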


make_fundamental_output

make_fundamental_output(logits, space_name = 'action')

Dispatch to the make_*_output helper that matches the type of action_space[space_name].

Parameters

  • logits (torch.Tensor) - Logits slice for the fundamental space at space_name.

  • space_name (str) - Key in action_space.


make_multi_binary_output

make_multi_binary_output(logits, space_name = 'action')

Map logits to a gymnasium.spaces.MultiBinary action.

Parameters

  • logits (torch.Tensor) - Logits for the multi-binary branch.

  • space_name (str) - Key in action_space (unused; kept for API uniformity).


make_multi_discrete_output

make_multi_discrete_output(logits, space_name = 'action')

Map logits to a gymnasium.spaces.MultiDiscrete action (per-section argmax).

Parameters

  • logits (torch.Tensor) - Concatenated logits aligned with action_space[space_name].nvec.

  • space_name (str) - Key of the MultiDiscrete subspace in action_space.
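Per-section argmax can be sketched with torch.split, assuming batch-first logits and a flat (1-D) nvec. multi_discrete_from_logits is a hypothetical helper, not part of the Schola API.

```python
import torch


def multi_discrete_from_logits(logits: torch.Tensor, nvec: list[int]) -> torch.Tensor:
    """Split logits into one section per sub-action and argmax each section."""
    sections = torch.split(logits, nvec, dim=-1)
    return torch.stack([s.argmax(dim=-1) for s in sections], dim=-1)


# MultiDiscrete([3, 2]): 3 + 2 = 5 logits per sample.
logits = torch.tensor([[0.2, 1.7, 0.1, -0.4, 0.9],
                       [2.2, 0.0, 0.3, 1.1, 0.5]])
print(multi_discrete_from_logits(logits, [3, 2]))
# tensor([[1, 1],
#         [0, 0]])
```

The first sample picks index 1 from the 3-way section and index 1 from the 2-way section; the second picks index 0 from both.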


make_outputs

make_outputs(logits)

Split concatenated logits and produce one output tensor per action key.

Parameters

  • logits (torch.Tensor) - Concatenated logits over action branches (sequence dimensions flattened to batch).
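A sketch of the splitting step, assuming logits are concatenated along the last dimension in output_action_keys order, with widths taken from flat_dims. The keys and widths below are made up, and split_logits is hypothetical. Each slice would then be mapped to an action, presumably through the matching make_*_output helper.

```python
import torch

# Hypothetical action layout: key -> flat logit width, plus the output order.
flat_dims = {"move": 2, "select": 3}
output_action_keys = ["move", "select"]


def split_logits(logits: torch.Tensor) -> dict[str, torch.Tensor]:
    """Cut concatenated logits into one slice per action key."""
    widths = [flat_dims[k] for k in output_action_keys]
    chunks = torch.split(logits, widths, dim=-1)
    return dict(zip(output_action_keys, chunks))


# Batch of 2 with concatenated logits of width 2 + 3 = 5.
logits = torch.arange(10, dtype=torch.float32).reshape(2, 5)
parts = split_logits(logits)
print(parts["move"].shape, parts["select"].shape)
# torch.Size([2, 2]) torch.Size([2, 3])

# Logits with a sequence dimension (batch=4, seq=3) are flattened to batch first.
seq_logits = torch.zeros(4, 3, 5)
flat = seq_logits.reshape(-1, seq_logits.shape[-1])  # shape (12, 5)
```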

save_as_onnx

save_as_onnx(export_path, onnx_opset = 21)

Export this model to an .onnx file on disk.

Parameters

  • export_path (str) - Output file path; parent directories are created if missing.

  • onnx_opset (int) - ONNX opset passed to export_onnx_program().

Attributes

observation_space

observation_space : gym.spaces.Dict

The observation space of the model.


action_space

action_space : gym.spaces.Dict

The action space of the ScholaModel.


flat_dims

flat_dims : Dict[str, int]

A dictionary mapping each action key to the flat dimension of its subspace. Used to convert logit outputs to the correct output shapes.


input_obs_keys

input_obs_keys

Keys of observation_space in forward / export input order.
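Assuming forward(*args) takes one tensor per observation key, positional inputs can be assembled from a dict of observations in this order. The key names and shapes below are hypothetical.

```python
import torch

# Hypothetical observation keys, in the order forward() expects them.
input_obs_keys = ["camera", "velocity"]

# Observations arrive as a dict whose insertion order may differ.
obs = {
    "velocity": torch.tensor([[0.5, -0.2]]),
    "camera": torch.zeros(1, 16),
}

# Build the positional argument tuple for model(*args) in input_obs_keys order.
args = tuple(obs[key] for key in input_obs_keys)
print([a.shape for a in args])  # [torch.Size([1, 16]), torch.Size([1, 2])]
```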


output_action_keys

output_action_keys : List[str]

Keys of action_space in forward / export output order.