Source code for schola.sb3.utils

Copied!


# Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights Reserved.

"""
Utility functions for working with stable baselines 3
"""

from collections import OrderedDict
from typing import Dict, List, Tuple, Union
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.base_class import BaseAlgorithm
import os
from argparse import ArgumentParser
import gymnasium as gym
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnvObs,
    VecEnv,
    VecEnvWrapper
)
# The below code is adapted from https://github.com/DLR-RM/stable-baselines3/blob/v2.2.1/docs/guide/export.rst

#The MIT License
#
#Copyright (c) 2019 Antonin Raffin
#
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#
#The above copyright notice and this permission notice shall be included in
#all copies or substantial portions of the Software.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
#THE SOFTWARE.

# Modifications Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved.

# we don't always include the value network here since we don't need it for inference
class OnnxablePolicy(th.nn.Module):
    """
    Bundle the inference-relevant pieces of a stable baselines policy into a
    single PyTorch Module so it can be traced and exported to ONNX.

    Parameters
    ----------
    extractor : th.nn.Module
        Called on the model input; must return a pair
        ``(action_hidden, value_hidden)`` of latent tensors.
    action_net : th.nn.Module
        Head applied to the action latent to produce the action output.
    value_net : th.nn.Module
        Head applied to the value latent; only evaluated when
        ``include_value_net`` is True.
    include_value_net : bool
        When True, ``forward`` returns ``(actions, values)``; otherwise it
        returns only the actions.

    Attributes
    ----------
    extractor : th.nn.Module
        The stored feature extractor.
    action_net : th.nn.Module
        The stored action head.
    value_net : th.nn.Module
        The stored value head.
    include_value_net : bool
        Whether ``forward`` also emits the value head's output.
    """

    def __init__(self, extractor: th.nn.Module, action_net : th.nn.Module, value_net: th.nn.Module, include_value_net:bool=False):
        super().__init__()
        # Assignment order is kept stable so submodules register in the same
        # order (affects named_children / state_dict key ordering).
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net
        self.include_value_net = include_value_net

    def forward(self, x : th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """
        Run the extractor and heads on ``x``.

        Returns the action head output, or ``(action_output, value_output)``
        when ``include_value_net`` is set.
        """
        latent_pi, latent_vf = self.extractor(x)
        actions = self.action_net(latent_pi)
        # Guard clause: the value head is skipped entirely for action-only export.
        if not self.include_value_net:
            return actions
        return actions, self.value_net(latent_vf)
def save_model_as_onnx(model : BaseAlgorithm, export_path:str) -> None:
    """
    Save a stable baselines model as an ONNX file.

    Only the policy's ``mlp_extractor`` and ``action_net`` contribute to the
    exported graph. The value network is handed to ``OnnxablePolicy`` but
    excluded, because ``include_value_net`` is left at its default.

    Parameters
    ----------
    model : stable_baselines3.common.base_class.BaseAlgorithm
        The model to save.
    export_path : str
        The path to save the model to. Missing parent directories are
        created. A bare filename (no directory component) is written to the
        current working directory. Any ``os.PathLike`` is also accepted.
    """
    new_model = OnnxablePolicy(
        model.policy.mlp_extractor,
        model.policy.action_net,
        model.policy.value_net,
    )

    # Make parent directories if they don't exist. os.path.dirname returns ""
    # for a bare filename (nothing to create) and handles platform separators
    # and PathLike inputs. The previous rsplit("/") approach instead turned
    # "model.onnx" into a directory named "model.onnx". exist_ok avoids the
    # exists()/makedirs() race.
    directory_path = os.path.dirname(export_path)
    if directory_path:
        os.makedirs(directory_path, exist_ok=True)

    # The dummy input has the flattened observation shape (no batch axis).
    # No dynamic_axes are passed, so the exported input shape is presumably
    # fixed to this shape — TODO confirm with the Unreal-side consumer.
    input_dim = gym.spaces.utils.flatten_space(model.observation_space).shape

    # Export the model to ONNX
    print("Exporting model to ONNX")
    with open(export_path, "w+b") as f:
        th.onnx.export(
            new_model,
            (th.rand(input_dim),),
            f,
            opset_version=9,
            input_names=["input"],
        )
    print("Model exported to ONNX")
# end of adapted code
def convert_ckpt_to_onnx_for_unreal(
    trainer=PPO,
    model_path="./ckpt/ppo_final.zip",
    export_path="./ckpt/OnnxFiles/Model.onnx",
) -> None:
    """
    Load a saved stable baselines checkpoint and re-export it as ONNX for Unreal.

    Parameters
    ----------
    trainer : type[stable_baselines3.common.base_class.BaseAlgorithm]
        The algorithm class (not an instance) whose ``load`` method restores
        the checkpoint. Defaults to ``PPO``.
    model_path : str
        Path of the checkpoint to load.
    export_path : str
        Destination path for the ONNX file, passed straight to
        ``save_model_as_onnx``.
    """
    save_model_as_onnx(trainer.load(model_path), export_path)
class VecMergeDictActionWrapper(VecEnvWrapper):
    """
    Vectorized wrapper that presents a dictionary action space as one merged
    action space.

    The sub-spaces of the wrapped environment's action space are combined via
    the first sub-space's ``merge`` method, so they must be mutually
    compatible. Actions passed to ``step``/``step_async`` are forwarded to the
    wrapped environment unchanged.

    Parameters
    ----------
    venv : VecEnv
        The vectorized environment being wrapped. Its ``action_space`` must
        expose a ``spaces`` mapping of sub-spaces.
    """

    def __init__(self, venv: VecEnv):
        sub_spaces = [*venv.action_space.spaces.values()]
        assert sub_spaces, "No Action Spaces to merge."
        # NOTE(review): every sub-space, including the first, is passed to
        # ``merge``. This presumes ``merge`` combines all of its arguments
        # (e.g. a classmethod) rather than merging them into the receiver —
        # confirm against the space implementation.
        merged_space = sub_spaces[0].merge(*sub_spaces)
        super().__init__(venv=venv, action_space=merged_space)

    def reset(self) -> VecEnvObs:
        """Reset the wrapped environment and return its observations as-is."""
        return self.venv.reset()

    def step(self, action: np.ndarray) -> Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]:
        """
        Step the wrapped environment with ``action`` and return its result as-is.

        Calls the wrapped environment's ``step`` directly rather than routing
        through this wrapper's ``step_async``/``step_wait``.
        """
        return self.venv.step(action)

    def step_async(self, actions: np.ndarray) -> None:
        """Forward ``actions`` unchanged to the wrapped environment's ``step_async``."""
        self.venv.step_async(actions)

    def step_wait(self) -> Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]:
        """Return the wrapped environment's ``step_wait`` result unchanged."""
        return self.venv.step_wait()

Related pages

  • Visit the Schola product page for download links and more information.

Looking for more documentation on GPUOpen?

AMD GPUOpen software blogs

Our handy software release blogs will help you make good use of our tools, SDKs, and effects, and share the latest features with each new release.

GPUOpen Manuals

Don’t miss our manual documentation! And if slide decks are what you’re after, you’ll find 100+ of our finest presentations here.

AMD GPUOpen Performance Guides

The home of great performance and optimization advice for AMD RDNA™ 2 GPUs, AMD Ryzen™ CPUs, and so much more.

Getting started: AMD GPUOpen software

New or fairly new to AMD’s tools, libraries, and effects? This is the best place to get started on GPUOpen!

AMD GPUOpen Getting Started Development and Performance

Looking for tips on getting started with developing and/or optimizing your game, whether on AMD hardware or generally? We’ve got you covered!

AMD GPUOpen Technical blogs

Browse our technical blogs, and find valuable advice on developing with AMD hardware, ray tracing, Vulkan®, DirectX®, Unreal Engine, and lots more.

Find out more about our software!

AMD GPUOpen Effects - AMD FidelityFX technologies

Create wonder. No black boxes. Meet the AMD FidelityFX SDK!

AMD GPUOpen Samples

Browse all our useful samples. Perfect for when you need to get started, want to integrate one of our libraries, and much more.

AMD GPUOpen developer SDKs

Discover what our SDK technologies can offer you. Query hardware or software, manage memory, create rendering applications or machine learning, and much more!

AMD GPUOpen Developer Tools

Analyze, Optimize, Profile, Benchmark. We provide you with the developer tools you need to make sure your game is the best it can be!