Source code for schola.sb3.env

Copied!


# Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved.

"""
Implementation of stable_baselines3.common.vec_env.VecEnv backed by a Schola Environment.
"""

from collections import OrderedDict
from difflib import unified_diff
from functools import cached_property
from typing import Dict, List, Optional, Tuple, TypeVar, TypedDict, Union
from venv import logger
import stable_baselines3 as sb3
from schola.core.env import ScholaEnv
from stable_baselines3.common.vec_env import VecEnv as Sb3VecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs
from schola.core.spaces import (
    DiscreteSpace,
    MultiDiscreteSpace,
    BoxSpace,
    MultiBinarySpace,
    DictSpace
)

import numpy as np
import gymnasium as gym
from schola.core.error_manager import EnvironmentException
from schola.core.utils import nested_get, IdManager
import logging

# Generic type variable. It is not referenced by the VecEnv class in this
# module; NOTE(review): presumably kept for annotations elsewhere — confirm
# before removing.
T = TypeVar("T")


class VecEnv(Sb3VecEnv):
    """Stable-Baselines3 ``VecEnv`` backed by a Schola environment.

    Every agent of every Schola (Unreal) environment is exposed to SB3 as one
    sub-environment, so SB3's ``num_envs`` equals ``ScholaEnv.num_agents``.
    All agents must share the same observation and action spaces.
    """

    def __init__(self, unreal_connection, verbosity: int = 0):
        """
        Parameters
        ----------
        unreal_connection
            Connection object forwarded unchanged to ``ScholaEnv``.
        verbosity : int, optional
            Verbosity level forwarded to ``ScholaEnv``, by default 0.

        Raises
        ------
        AssertionError
            If any agent's action or observation space differs from the
            spaces of the agent with flat id 0.
        """
        self._env = ScholaEnv(
            unreal_connection,
            verbosity,
        )
        self.id_manager = IdManager(self._env.ids)

        # We just use the default UID (0) to get the shared space definitions.
        obs_space = self._env.get_obs_space(*self.id_manager[0])
        action_space = self._env.get_action_space(*self.id_manager[0])

        # SB3 needs one space for all sub-environments: verify every agent matches.
        for env_id, agent_id in self.id_manager.id_list:
            agent_action_space = self._env.get_action_space(env_id, agent_id)
            assert agent_action_space == action_space, (
                f"Action Space Mismatch on Agent:{agent_id} in Env {env_id}.\n"
                f"Got: {agent_action_space}\nExpected:{action_space}"
            )
            agent_obs_space = self._env.get_obs_space(env_id, agent_id)
            assert agent_obs_space == obs_space, (
                f"Observation Space Mismatch on Agent:{agent_id} in Env {env_id}.\n"
                f"Got: {agent_obs_space}\nExpected:{obs_space}"
            )

        # Use a named logger: module-level logging.debug() calls
        # logging.basicConfig() on the root logger when it has no handlers.
        module_logger = logging.getLogger(__name__)
        module_logger.debug("%s", action_space)
        module_logger.debug("%s", obs_space)

        self.reset_infos = [{} for _ in range(self._env.num_agents)]
        # Pending seed/options; consumed (and cleared) by the next reset().
        self._seed: Optional[int] = None
        self.options: Optional[Dict[str, str]] = None
        super().__init__(self._env.num_agents, obs_space, action_space)

    def _resolve_indices(self, indices=None) -> List[int]:
        """Turn an SB3 ``indices`` argument into a list of sub-env (agent) indices.

        ``None`` selects every sub-environment, a single integer selects one,
        and any other iterable of integers is materialized as given.
        """
        if indices is None:
            return list(range(self._env.num_agents))
        if isinstance(indices, (int, np.integer)):
            return [int(indices)]
        return list(indices)

    def _scatter_nested(self, nested, target: List[dict]) -> List[dict]:
        """Write ``nested[env_id][agent_id]`` into ``target`` at each agent's flat uid.

        ``target`` is modified in place and also returned for convenience.
        Agents absent from ``nested`` keep their existing ``target`` entry.
        """
        for env_id in nested:
            for agent_id in nested[env_id]:
                # Safe because we only visit (env_id, agent_id) pairs present in nested.
                target[self.id_manager[env_id, agent_id]] = nested[env_id][agent_id]
        return target

    def close(self) -> None:
        """Close the underlying Schola environment."""
        return self._env.close()

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Method calls are not supported by this backend.

        Returns
        -------
        list
            ``None`` for each queried sub-environment (one per agent by default).
        """
        return [None for _ in self._resolve_indices(indices)]

    def get_attr(self, attr_name, indices=None):
        """Attribute access is not supported by this backend.

        Returns
        -------
        list
            ``None`` for each queried sub-environment (one per agent by default).
        """
        return [None for _ in self._resolve_indices(indices)]

    def reset(self) -> Dict[str, np.ndarray]:
        """Hard-reset all environments and return the initial observations.

        Any seed/options stored via :meth:`seed` / :meth:`set_options` are
        passed to the hard reset and then cleared so they apply only once.
        Per-agent reset infos are stored in ``self.reset_infos``.

        Returns
        -------
        Dict[str, np.ndarray]
            Observations for all agents, batched by ``_flatten_obs`` against
            ``self.observation_space``.
        """
        obs, nested_infos = self._env.hard_reset(seeds=self._seed, options=self.options)
        self._seed = None
        self.options = None
        self._scatter_nested(nested_infos, self.reset_infos)
        # Flatten the observations, converting from a nested dict to a list indexed by uid.
        obs = self.id_manager.flatten_id_dict(obs)
        # Flatten even more, for SB3 compatibility.
        return _flatten_obs(obs, self.observation_space)

    def env_is_wrapped(self, wrapper_class, indices=None) -> List[bool]:
        """No sub-environment is wrapped by a gym wrapper in this backend.

        Returns
        -------
        List[bool]
            ``False`` for each queried sub-environment (one per agent by default).
        """
        return [False for _ in self._resolve_indices(indices)]

    def seed(self, seed: Optional[int] = None) -> None:
        """Store ``seed`` for the next :meth:`reset`.

        Parameters
        ----------
        seed : Optional[int], optional
            The seed to use. ``None`` leaves any pending seed unchanged.
        """
        if seed is not None:
            self._seed = seed

    def set_options(self, options: Optional[Dict[str, str]] = None) -> None:
        """
        Set the options for the environment.

        Parameters
        ----------
        options : Optional[Dict[str,str]], optional
            The options to set, by default None. ``None`` leaves any pending
            options unchanged; options are consumed by the next :meth:`reset`.
        """
        if options is not None:
            self.options = options

    def set_attr(self, attr_name, value, indices=None):
        """Attribute assignment is not supported by this backend; does nothing."""
        ...

    def _unflatten_action(self, flat_action: np.ndarray) -> OrderedDict:
        """Split one flat action vector into per-key slices of ``self.action_space``.

        Slices are taken consecutively, each ``len(space)`` elements long, in
        the iteration order of ``self.action_space.items()``.
        """
        unflattened = OrderedDict()
        start_dim = 0
        for name, space in self.action_space.items():
            # TODO apply any necessary DTYPE conversions here (everything is a float32 otherwise)
            unflattened[name] = flat_action[start_dim:start_dim + len(space)]
            start_dim += len(space)
        return unflattened

    def step_async(self, actions: Union[List[np.ndarray], List[Dict[str, np.ndarray]]]) -> None:
        """Send one action per agent to the Schola environment.

        Parameters
        ----------
        actions : Union[List[np.ndarray], List[Dict[str, np.ndarray]]]
            One entry per agent (flat uid order): either flat arrays, which are
            split per action-space key, or already-keyed dicts.
        """
        # Actions can come in as a list of flattened tensors, so unflatten them.
        if isinstance(actions[0], np.ndarray):
            unflattened_actions = [self._unflatten_action(action) for action in actions]
        else:
            # Actions came in as dicts already.
            unflattened_actions = actions
        # Convert the flat per-agent list to a nested {env_id: {agent_id: action}} dict.
        self._env.send_actions(self.id_manager.nest_id_list(unflattened_actions))

    def step_wait(self) -> Tuple[Dict[str, np.ndarray], np.ndarray, np.ndarray, List[Dict[str, str]]]:
        """Collect the step results and auto-reset environments whose episode ended.

        Following the SB3 ``VecEnv`` convention, finished environments are
        soft-reset immediately: the last observation of the episode goes into
        ``infos[uid]["terminal_observation"]`` and the first observation of
        the new episode is returned in its place.

        Returns
        -------
        Tuple
            ``(observations, rewards, dones, infos)``; ``dones`` is a bool
            array with one entry per agent.

        Raises
        ------
        EnvironmentException
            If only some of the agents in one environment are done.
        """
        observations, rewards, terminateds, truncateds, nested_infos = self._env.poll()
        num_agents = self._env.num_agents

        # Zeroed bool array (as in SB3's DummyVecEnv) rather than uninitialized floats.
        array_dones = np.zeros((num_agents,), dtype=bool)
        array_rewards = np.asarray(self.id_manager.flatten_id_dict(rewards))
        array_observations = self.id_manager.flatten_id_dict(observations)
        infos = self._scatter_nested(nested_infos, [{} for _ in range(num_agents)])

        envs_to_reset = []
        for env_id, agent_id_list in enumerate(self.id_manager.ids):
            env_dones = []
            for agent_id in agent_id_list:
                uid = self.id_manager[env_id, agent_id]
                done = bool(
                    nested_get(truncateds, (env_id, agent_id), False)
                    or nested_get(terminateds, (env_id, agent_id), False)
                )
                array_dones[uid] = done
                env_dones.append(done)
            # We don't currently handle the case where one agent ends early.
            if any(env_dones):
                if not all(env_dones):
                    raise EnvironmentException(
                        f"SB3 with multi-agent environments does not support agents completing at different steps.\n"
                        f"Env {env_id} had agents in different completion states."
                    )
                envs_to_reset.append(env_id)

        # Following the SB3 VecEnv guideline, we self-reset.
        if envs_to_reset:
            resetted_obs, reset_infos = self._env.soft_reset(envs_to_reset)
            self.reset_infos = self._scatter_nested(reset_infos, [{} for _ in range(num_agents)])
            for env_id in envs_to_reset:
                for agent_id in self.id_manager.partial_get(env_id):
                    uid = self.id_manager[env_id, agent_id]
                    # Observations of the last step of the episode go into the info section.
                    infos[uid]["terminal_observation"] = observations[env_id][agent_id]
                    # Same missing-key default as the done computation above.
                    truncated = nested_get(truncateds, (env_id, agent_id), False)
                    terminated = nested_get(terminateds, (env_id, agent_id), False)
                    infos[uid]["TimeLimit.truncated"] = bool(truncated) and not bool(terminated)
                    # Put the first observation of the new episode into the returned obs.
                    array_observations[uid] = resetted_obs[env_id][agent_id]

        return _flatten_obs(array_observations, self.observation_space), array_rewards, array_dones, infos

Related pages

  • Visit the Schola product page for download links and more information.

Looking for more documentation on GPUOpen?

AMD GPUOpen software blogs

Our handy software release blogs help you make good use of our tools, SDKs, and effects, and share the latest features in each new release.

GPUOpen Manuals

Don’t miss our manual documentation! And if slide decks are what you’re after, you’ll find 100+ of our finest presentations here.

AMD GPUOpen Performance Guides

The home of great performance and optimization advice for AMD RDNA™ 2 GPUs, AMD Ryzen™ CPUs, and so much more.

Getting started: AMD GPUOpen software

New or fairly new to AMD’s tools, libraries, and effects? This is the best place to get started on GPUOpen!

AMD GPUOpen Getting Started Development and Performance

Looking for tips on getting started with developing and optimizing your game, whether for AMD hardware or in general? We’ve got you covered!

AMD GPUOpen Technical blogs

Browse our technical blogs, and find valuable advice on developing with AMD hardware, ray tracing, Vulkan®, DirectX®, Unreal Engine, and lots more.

Find out more about our software!

AMD GPUOpen Effects - AMD FidelityFX technologies

Create wonder. No black boxes. Meet the AMD FidelityFX SDK!

AMD GPUOpen Samples

Browse all our useful samples. Perfect for when you need to get started, want to integrate one of our libraries, and much more.

AMD GPUOpen developer SDKs

Discover what our SDK technologies can offer you. Query hardware or software, manage memory, create rendering applications or machine learning, and much more!

AMD GPUOpen Developer Tools

Analyze, Optimize, Profile, Benchmark. We provide you with the developer tools you need to make sure your game is the best it can be!