Source code for schola.scripts.ray.launch

Copied!


# Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All Rights Reserved.

"""
Script to train an rllib model using Schola.
"""
from argparse import ArgumentParser
from schola.ray.utils import export_onnx_from_policy
from typing import Any, Dict, Type, Union
import traceback

from schola.ray.env import BaseEnv
from schola.core.env import ScholaEnv
from schola.core.utils import get_plugins

import ray
from ray import air, tune
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.appo.appo import APPOConfig
from ray.rllib.algorithms.impala.impala import IMPALAConfig
from ray.tune.registry import register_env
from schola.scripts.common import (
    add_unreal_process_args,
    add_checkpoint_args,
)
from dataclasses import fields
from ray.rllib.policy.policy import Policy
from schola.scripts.ray.settings import TrainingSettings, ResumeSettings, LoggingSettings, NetworkArchitectureSettings, ResourceSettings, RLlibScriptArgs, PPOSettings, APPOSettings, IMPALASettings

def make_parser():
    """
    Build the command-line parser used to launch RLlib training.

    The top-level parser receives the Unreal process arguments, one argument
    group per settings dataclass (training, logging, checkpoint/resume,
    network architecture, resources) and a mandatory sub-command selecting
    the algorithm (``PPO``, ``APPO`` or ``IMPALA``).

    Returns
    -------
    ArgumentParser
        The argument parser for the script.
    """
    root = ArgumentParser(prog="Launch Schola Examples with RLlib")
    add_unreal_process_args(root)

    # Each settings class fills in the options of its own argument group.
    # The groups are created in this order so the help output keeps its layout.
    TrainingSettings.populate_arg_group(root.add_argument_group("Training Arguments"))
    LoggingSettings.populate_arg_group(root.add_argument_group("Logging Arguments"))
    # Resume options are added to the group returned by add_checkpoint_args.
    ResumeSettings.populate_arg_group(add_checkpoint_args(root))
    NetworkArchitectureSettings.populate_arg_group(
        root.add_argument_group("Network Architecture Arguments")
    )
    ResourceSettings.populate_arg_group(root.add_argument_group("Resource Arguments"))

    # One sub-command per supported algorithm; its options come from the
    # parent parser supplied by the matching settings class.
    algorithm_commands = root.add_subparsers(required=True, help="Choose the algorithm to use")
    supported_algorithms = (
        ("PPO", "Proximal Policy Optimization", PPOSettings),
        ("APPO", "Asynchronous Proximal Policy Optimization", APPOSettings),
        ("IMPALA", "Importance Weighted Actor-Learner Architecture", IMPALASettings),
    )
    for command, summary, settings_cls in supported_algorithms:
        algorithm_commands.add_parser(command, help=summary, parents=[settings_cls.get_parser()])

    return root
def get_dataclass_args(args: Dict[str, Any], dataclass: Type[Any]) -> Dict[str, Any]:
    """
    Get the arguments for a dataclass from a dictionary, potentially containing additional arguments.

    Parameters
    ----------
    args : Dict[str,Any]
        The dictionary of arguments.
    dataclass : Type[Any]
        The dataclass to get the arguments for.

    Returns
    -------
    Dict[str,Any]
        The entries of ``args`` whose key is the name of a field of
        ``dataclass``, in the same order as they appear in ``args``.

    Raises
    ------
    TypeError
        Raised by ``dataclasses.fields`` if ``dataclass`` is not a dataclass.
    """
    # Build the set of field names once instead of once per key of `args`.
    field_names = {f.name for f in fields(dataclass)}
    return {k: v for k, v in args.items() if k in field_names}
def main_from_cli() -> tune.ExperimentAnalysis:
    """
    Parse the command line and launch training with ray.

    The flat namespace produced by the parser is split into one keyword
    dictionary per settings dataclass, each dataclass is instantiated, and the
    results are bundled into an ``RLlibScriptArgs`` that is handed to `main`.
    Plugins discovered under ``schola.plugins.ray.launch`` add their own
    arguments to the parser and are instantiated from the parsed values.

    Returns
    -------
    tune.ExperimentAnalysis
        The results of the training

    See Also
    --------
    main : The main function for launching training with ray
    """
    cli_parser = make_parser()

    discovered_plugins = get_plugins("schola.plugins.ray.launch")
    for plugin in discovered_plugins:
        plugin.add_plugin_args_to_parser(cli_parser)

    namespace = cli_parser.parse_args()
    cli_values = vars(namespace)
    # NOTE(review): presumably set as a default by the selected algorithm
    # sub-parser; it is not defined in this file.
    algorithm_cls = namespace.algorithm_settings_class

    # Split the flat CLI values into one keyword dictionary per dataclass.
    algorithm_kwargs = get_dataclass_args(cli_values, algorithm_cls)
    training_kwargs = get_dataclass_args(cli_values, TrainingSettings)
    logging_kwargs = get_dataclass_args(cli_values, LoggingSettings)
    resume_kwargs = get_dataclass_args(cli_values, ResumeSettings)
    network_kwargs = get_dataclass_args(cli_values, NetworkArchitectureSettings)
    resource_kwargs = get_dataclass_args(cli_values, ResourceSettings)
    script_kwargs = get_dataclass_args(cli_values, RLlibScriptArgs)

    # Build the dataclasses from the dictionaries.
    algorithm_settings = algorithm_cls(**algorithm_kwargs)
    training_settings = TrainingSettings(**training_kwargs)
    logging_settings = LoggingSettings(**logging_kwargs)
    resume_settings = ResumeSettings(**resume_kwargs)
    network_settings = NetworkArchitectureSettings(**network_kwargs)
    resource_settings = ResourceSettings(**resource_kwargs)

    # Each plugin is itself instantiated from the CLI values matching its fields.
    plugins = [
        plugin(**get_dataclass_args(cli_values, plugin))
        for plugin in discovered_plugins
    ]

    script_args = RLlibScriptArgs(
        algorithm_settings=algorithm_settings,
        training_settings=training_settings,
        logging_settings=logging_settings,
        resume_settings=resume_settings,
        network_architecture_settings=network_settings,
        resource_settings=resource_settings,
        plugins=plugins,
        **script_kwargs,
    )
    return main(script_args)
def main(args: RLlibScriptArgs) -> tune.ExperimentAnalysis:
    """
    Main function for launching training with ray.

    A temporary Schola environment is opened to read the agent display names,
    ray is initialised, the environment is registered with RLlib, the
    algorithm configuration is assembled and ``tune.run`` is executed. Ray is
    always shut down afterwards. If requested, the policy from the last
    checkpoint is exported to ONNX.

    Parameters
    ----------
    args : RLlibScriptArgs
        The arguments for the script as a dataclass

    Returns
    -------
    tune.ExperimentAnalysis
        The results of the training
    """
    # collect the names of the agents by creating a temporary environment
    schola_env = ScholaEnv(
        args.make_unreal_connection(),
        verbosity=args.logging_settings.schola_verbosity,
    )
    try:
        agent_names = schola_env.agent_display_names[0]
    finally:
        # Close the temporary environment even if reading the names fails,
        # so it is not left open when the error propagates.
        schola_env.close()

    # Clusters configure resources automatically
    if args.resource_settings.using_cluster:
        ray.init()
    else:
        ray.init(
            num_cpus=args.resource_settings.num_cpus,
            num_gpus=args.resource_settings.num_gpus,
        )

    def env_creator(env_config):
        # env_config is supplied by the caller of the registered creator but is not used here.
        env = BaseEnv(
            args.make_unreal_connection(),
            verbosity=args.logging_settings.schola_verbosity,
        )
        return env

    def policy_mapping_fn(agent_id, episode=None, worker=None, **kwargs):
        # Agents are mapped to the policy named after their display name.
        return agent_names[agent_id]

    register_env("schola_env", env_creator)

    # Note: New Ray Stack doesn't support Vectorized MultiAgent environments yet so the old stack is better
    config: Union[PPOConfig, APPOConfig, IMPALAConfig] = (
        args.algorithm_settings.rllib_config()
        .api_stack(
            enable_rl_module_and_learner=False,
            enable_env_runner_and_connector_v2=False,
        )
        .environment(
            "schola_env",
            clip_rewards=False,
            clip_actions=True,
            normalize_actions=False,
        )
        .framework("torch")
        .env_runners(
            num_env_runners=0,
            num_envs_per_env_runner=1,
        )
        .multi_agent(
            # One policy per distinct display name.
            policies={
                agent_name: PolicySpec(observation_space=None, action_space=None)
                for agent_name in set(agent_names.values())
            },
            policy_mapping_fn=policy_mapping_fn,
            policies_to_train=None,  # default to training all policies
        )
        .resources(
            num_cpus_for_main_process=args.resource_settings.num_cpus_for_main_process,
            num_gpus=args.resource_settings.num_gpus,
        )
        .learners(
            num_learners=args.resource_settings.num_learners,
            num_cpus_per_learner=args.resource_settings.num_cpus_per_learner,
            num_gpus_per_learner=args.resource_settings.num_gpus_per_learner,
        )
        .training(
            lr=args.training_settings.learning_rate,
            gamma=args.training_settings.gamma,
            num_sgd_iter=args.training_settings.num_sgd_iter,
            train_batch_size_per_learner=args.training_settings.train_batch_size_per_learner,
            minibatch_size=args.training_settings.minibatch_size,
            model={
                "fcnet_hiddens": args.network_architecture_settings.fcnet_hiddens,
                "fcnet_activation": args.network_architecture_settings.activation.layer,
                "free_log_std": False,  # onnx fails to load if this is set to True
                "use_attention": args.network_architecture_settings.use_attention,
                "attention_dim": args.network_architecture_settings.attention_dim,
            },
            **args.algorithm_settings.get_settings_dict()
        )
    )

    stop = {
        "timesteps_total": args.training_settings.timesteps,
    }

    # Plugins may contribute additional callbacks for tune.
    callbacks = []
    for plugin in args.plugins:
        callbacks += plugin.get_extra_callbacks()

    print("Starting training")
    try:
        results = tune.run(
            args.algorithm_settings.name,
            config=config,
            stop=stop,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=args.save_freq if args.enable_checkpoints else 0,
                checkpoint_at_end=args.save_final_policy,
            ),
            restore=args.resume_settings.resume_from,
            verbose=args.logging_settings.rllib_verbosity,
            storage_path=args.checkpoint_dir,
            callbacks=callbacks,
        )
        last_checkpoint = results.get_last_checkpoint()
        print("Training complete")
    finally:
        # Always shutdown ray and release the environment from training even if there is an error
        # will reraise the error unless a control flow statement is added
        ray.shutdown()

    if args.export_onnx:
        export_onnx_from_policy(
            Policy.from_checkpoint(last_checkpoint), results.trials[-1].path
        )
        print("Models exported to ONNX at ", results.trials[-1].path)
    return results
def debug_main_from_cli() -> None:
    """
    Run `main_from_cli`, report any error, and wait for user input before returning.

    Any ``Exception`` raised during training is printed as a traceback instead
    of propagating. The prompt is shown whether or not training succeeded.

    See Also
    --------
    main_from_cli : The main function for launching training with ray from the command line
    main : The main function for launching training with ray
    """
    try:
        main_from_cli()
    except Exception:
        # Show the failure rather than letting it propagate.
        traceback.print_exc()
    finally:
        # Block until the user responds, on both the success and failure paths.
        input("Press any key to close:")
# Script entry point: when executed directly, run the wrapper that prints any
# error and waits for user input before closing.
if __name__ == "__main__": debug_main_from_cli()

Related pages

  • Visit the Schola product page for download links and more information.

Looking for more documentation on GPUOpen?

AMD GPUOpen software blogs

Our handy software release blogs will help you make good use of our tools, SDKs, and effects, as well as sharing the latest features with new releases.

GPUOpen Manuals

Don’t miss our manual documentation! And if slide decks are what you’re after, you’ll find 100+ of our finest presentations here.

AMD GPUOpen Performance Guides

The home of great performance and optimization advice for AMD RDNA™ 2 GPUs, AMD Ryzen™ CPUs, and so much more.

Getting started: AMD GPUOpen software

New or fairly new to AMD’s tools, libraries, and effects? This is the best place to get started on GPUOpen!

AMD GPUOpen Getting Started Development and Performance

Looking for tips on getting started with developing and/or optimizing your game, whether on AMD hardware or generally? We’ve got you covered!

AMD GPUOpen Technical blogs

Browse our technical blogs, and find valuable advice on developing with AMD hardware, ray tracing, Vulkan®, DirectX®, Unreal Engine, and lots more.

Find out more about our software!

AMD GPUOpen Effects - AMD FidelityFX technologies

Create wonder. No black boxes. Meet the AMD FidelityFX SDK!

AMD GPUOpen Samples

Browse all our useful samples. Perfect for when you’re needing to get started, want to integrate one of our libraries, and much more.

AMD GPUOpen developer SDKs

Discover what our SDK technologies can offer you. Query hardware or software, manage memory, create rendering applications or machine learning, and much more!

AMD GPUOpen Developer Tools

Analyze, Optimize, Profile, Benchmark. We provide you with the developer tools you need to make sure your game is the best it can be!