Source code for schola.core.env

# Copyright (c) 2023 Advanced Micro Devices, Inc. All Rights Reserved.
"""
Main Schola Environment 
"""

from schola.core.unreal_connections import UnrealConnection
from schola.core.error_manager import NoAgentsException, NoEnvironmentsException
import schola.generated.GymConnector_pb2 as gym_communication
import schola.generated.GymConnector_pb2_grpc as gym_grpc
import schola.generated.Definitions_pb2 as env_definitions
import schola.generated.State_pb2 as state
from schola.core.spaces import DictSpace
import logging
import numpy as np
import atexit
from typing import Any, List, Dict, Optional, Tuple, Union, TypeVar


T = TypeVar("T")

# A dictionary keyed by environment id, whose values are dictionaries mapping agent ids to some TypeVar T.
EnvAgentIdDict = Dict[int, Dict[int, T]]
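
# For illustration: a value of type EnvAgentIdDict[float] (e.g. rewards) has the
# shape sketched below, keyed first by environment id and then by agent id. The
# concrete numbers here are made up.
#
#     rewards: EnvAgentIdDict[float] = {
#         0: {0: 1.0, 1: -0.5},  # environment 0, agents 0 and 1
#         1: {0: 0.25},          # environment 1, agent 0
#     }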

class ScholaEnv:
    """
    A Gym-Like Environment that wraps a connection to the Unreal Engine, running the Schola Plugin for Unreal.

    Parameters
    ----------
    unreal_connection : UnrealConnection
        The connection to the Unreal Engine.
    verbosity : int, default=0
        The verbosity level for the environment.
    environment_start_timeout : int, default=45
        The time to wait for the environment to start, in seconds.

    Attributes
    ----------
    unreal_connection : UnrealConnection
        The connection to the Unreal Engine.
    gym_stub : gym_grpc.GymServiceStub
        The gRPC stub for the Gym Service.
    ids : List[List[int]]
        A nested list of all the environments and their active agents.
    agent_display_names : List[Dict[int,str]]
        A list of mappings from the id to the display name for each agent in each environment.
    obs_defns : Dict[int,Dict[int,DictSpace]]
        The observation space definitions for each agent in each environment.
    action_defns : Dict[int,Dict[int,DictSpace]]
        The action space definitions for each agent in each environment.
    steps : int
        The number of steps taken in the current episode of the environment.
    next_action : Dict[int,Dict[int,Any]], optional
        The next action to be taken by each agent in each environment.

    Raises
    ------
    NoEnvironmentsException
        If there are no environment definitions.
    NoAgentsException
        If there are no agents defined for any environment.
    """
    def __init__(
        self,
        unreal_connection: UnrealConnection,
        verbosity: int = 0,
        environment_start_timeout: int = 45,
    ):
        log_level = logging.WARNING
        log_level = logging.INFO if verbosity == 1 else log_level
        log_level = logging.DEBUG if verbosity >= 2 else log_level
        logging.basicConfig(
            format="%(asctime)s:%(levelname)s:%(message)s", level=log_level
        )
        logging.info("creating channel")
        self.unreal_connection = unreal_connection
        self.unreal_connection.start()
        atexit.register(self.close)
        self.gym_stub: gym_grpc.GymServiceStub = self.unreal_connection.connect_stubs(gym_grpc.GymServiceStub)[0]
        # The server might still be booting up if we have a standalone connection,
        # so wait up to environment_start_timeout seconds for it to become ready
        start_msg = gym_communication.GymConnectorStartRequest()
        self.gym_stub.StartGymConnector(start_msg, timeout=environment_start_timeout, wait_for_ready=True)
        logging.info("requesting environment definition")
        self.ids: List[List[int]] = []
        self.agent_display_names: List[Dict[int, str]] = []
        # ids is set here
        self._define_environment()
        self.steps: int = 0
        self.next_action: Optional[Dict[int, Dict[int, Any]]] = None
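
    # Illustrative construction sketch (hedged): `make_connection` is a
    # hypothetical placeholder for obtaining a concrete UnrealConnection from
    # schola.core.unreal_connections; the real factory and its arguments depend
    # on whether you attach to the editor or launch a standalone build.
    #
    #     connection = make_connection()  # hypothetical helper returning an UnrealConnection
    #     env = ScholaEnv(connection, verbosity=1, environment_start_timeout=60)
    #     ...
    #     env.close()  # also registered with atexit, so explicit calls are optional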
    def _create_space_definitions(self, defn_map: List[env_definitions.EnvironmentDefinition]) -> None:
        """
        Create space definitions for observation and action spaces.

        Parameters
        ----------
        defn_map : List[env_definitions.EnvironmentDefinition]
            The environment definitions, indexable by environment id, each containing agent definitions.
        """
        self.obs_defns: Dict[int, Dict[int, DictSpace]] = {}
        self.action_defns: Dict[int, Dict[int, DictSpace]] = {}
        for env_id, env_defn in enumerate(defn_map):
            for agent_id, agent_defn in env_defn.agent_definitions.items():
                obs_space = DictSpace.from_proto(agent_defn.obs_space)
                if agent_defn.normalize_obs:
                    obs_space = obs_space.to_normalized()
                self.obs_defns.setdefault(env_id, {}).setdefault(agent_id, obs_space)
                self.action_defns.setdefault(env_id, {}).setdefault(
                    agent_id, DictSpace.from_proto(agent_defn.action_space)
                )
    def get_obs_space(self, env_id: int, agent_id: int) -> DictSpace:
        """
        Get the observation space for a specific environment and agent.

        Parameters
        ----------
        env_id : int
            The ID of the environment.
        agent_id : int
            The ID of the agent.

        Returns
        -------
        DictSpace
            The observation space for the specified environment and agent.
        """
        return self.obs_defns[env_id][agent_id]
    def get_action_space(self, env_id: int, agent_id: int) -> DictSpace:
        """
        Get the action space for a specific environment and agent.

        Parameters
        ----------
        env_id : int
            The ID of the environment.
        agent_id : int
            The ID of the agent.

        Returns
        -------
        DictSpace
            The action space for the specified environment and agent.
        """
        return self.action_defns[env_id][agent_id]
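
    # Illustrative sketch: reading the spaces for agent 0 in environment 0,
    # assuming `env` is a constructed ScholaEnv and that DictSpace follows the
    # gymnasium Space API (e.g. exposes sample()), which this module alone does
    # not guarantee.
    #
    #     obs_space = env.get_obs_space(0, 0)
    #     action_space = env.get_action_space(0, 0)
    #     random_action = action_space.sample()  # assumes a gymnasium-style sample()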
    def _define_environment(self) -> None:
        """
        Define the environment.

        This method retrieves the training definition from the gym stub and
        defines the environment based on the retrieved data. It populates the
        `ids` attribute with a nested list of all the environments and their
        active agents. It also populates the `agent_display_names` attribute
        with a nested dict mapping the id to the display names for each agent
        in each environment. Finally, it calls the `_create_space_definitions`
        method to create space definitions for the environment.

        Raises
        ------
        NoEnvironmentsException
            If there are no environment definitions.
        NoAgentsException
            If there are no agents defined for any environment.
        """
        training_defn: env_definitions.TrainingDefinition = self.gym_stub.RequestTrainingDefinition(
            gym_communication.TrainingDefinitionRequest()
        )
        # just a nested list of all the environments and their active agents
        self.ids: List[List[int]] = [
            [agent_id for agent_id in env_defn.agent_definitions]
            for env_defn in training_defn.environment_definitions
        ]
        if len(self.ids) == 0:
            raise NoEnvironmentsException()
        for env_id, agent_id_list in enumerate(self.ids):
            if len(agent_id_list) == 0:
                raise NoAgentsException(env_id)
        self.agent_display_names = [
            {agent_id: env_defn.agent_definitions[agent_id].name for agent_id in self.ids[i]}
            for i, env_defn in enumerate(training_defn.environment_definitions)
        ]
        self._create_space_definitions(training_defn.environment_definitions)
    def poll(self) -> Tuple[EnvAgentIdDict[Dict[str, Any]], EnvAgentIdDict[float], EnvAgentIdDict[bool], EnvAgentIdDict[bool], EnvAgentIdDict[Dict[str, str]]]:
        """
        Polls the environment for the current state.

        Returns
        -------
        observations : EnvAgentIdDict[Dict[str,Any]]
            A dictionary, keyed by the environment and agent Id, containing the observations for each agent.
        rewards : EnvAgentIdDict[float]
            A dictionary, keyed by the environment and agent Id, containing the reward for each agent.
        terminateds : EnvAgentIdDict[bool]
            A dictionary, keyed by the environment and agent Id, containing the termination flag for each agent.
        truncateds : EnvAgentIdDict[bool]
            A dictionary, keyed by the environment and agent Id, containing the truncation flag for each agent.
        infos : EnvAgentIdDict[Dict[str,str]]
            A dictionary, keyed by the environment and agent Id, containing the information dictionary for each agent.
        """
        if self.steps == 0:
            logging.info("Starting Epoch")

        # convert the pending actions into their proto representation
        state_update = gym_communication.TrainingStateUpdate()
        for env_id in self.next_action:
            env_update = state_update.updates[env_id].step
            for agent_id in self.next_action[env_id]:
                agent_update = env_update.updates[agent_id]
                self.action_defns[env_id][agent_id].fill_proto(
                    agent_update.actions, self.next_action[env_id][agent_id]
                )
        state_update.status = gym_communication.CommunicatorStatus.GOOD
        logging.debug(state_update)

        # send it to Unreal
        training_state = self.gym_stub.UpdateState(state_update)

        # convert the proto into observations, rewards, terminateds, truncateds and infos
        self.steps += 1
        logging.debug(training_state)
        observations, rewards, terminateds, truncateds, infos = (
            self._convert_state_to_tuple(training_state)
        )
        logging.debug(observations)

        # if the state update contained no observations, poll again
        if len(observations.keys()) < 1:
            return self.poll()
        return observations, rewards, terminateds, truncateds, infos
    def send_actions(self, action: EnvAgentIdDict[Dict[str, Any]]) -> None:
        """
        Send actions to all agents and environments.

        Parameters
        ----------
        action : EnvAgentIdDict[Dict[str,Any]]
            A dictionary, keyed by the environment and agent Id, containing the actions for all active environments and agents.

        Notes
        -----
        The actions are not sent to Unreal until poll is called.

        See Also
        --------
        poll : Where the actions are actually sent to Unreal
        """
        self.next_action = action
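
    # Illustrative interaction loop (hedged sketch): actions staged with
    # send_actions are only transmitted on the next poll() call. This assumes
    # `env` is a constructed ScholaEnv and `choose_action` is a hypothetical
    # policy function, not part of this module.
    #
    #     obs, infos = env.hard_reset()
    #     actions = {
    #         env_id: {agent_id: choose_action(obs[env_id][agent_id])
    #                  for agent_id in agent_ids}
    #         for env_id, agent_ids in enumerate(env.ids)
    #     }
    #     env.send_actions(actions)
    #     obs, rewards, terminateds, truncateds, infos = env.poll()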
    def hard_reset(
        self,
        env_ids: Optional[List[int]] = None,
        seeds: Union[None, List[int], int] = None,
        options: Union[List[Dict[str, str]], Dict[str, str], None] = None,
    ):
        """
        Perform a hard reset on the environment.

        Parameters
        ----------
        env_ids : Optional[List[int]]
            A list of environment IDs to reset. If None, all environments will be reset. Default is None.
        seeds : Union[None, List[int], int]
            The seeds to use for random number generation. If an int is provided, it will be used as the seed for all environments. If a list of ints is provided, each environment will be assigned a seed from the list. Default is None.
        options : Union[List[Dict[str,str]], Dict[str,str], None]
            The options to set for each environment. If a list of dictionaries is provided, each environment will be assigned the corresponding dictionary of options. If a single dictionary is provided, all environments will be assigned the same options. Default is None.

        Returns
        -------
        List
            A list of environment IDs that were reset.

        Raises
        ------
        AssertionError
            If seeds are passed as a list whose length does not match the number of environments.
        AssertionError
            If options are passed as a list whose length does not match the number of environments.

        Notes
        -----
        - If seeds are provided, the environment will be seeded with the specified values.
        - If options are provided, the environment will be configured with the specified options.

        See Also
        --------
        gymnasium.Env.reset : The equivalent operation in gymnasium
        """
        if seeds is not None and isinstance(seeds, int):
            self.seed_sequence = np.random.SeedSequence(entropy=seeds)
            self.np_random = np.random.default_rng(self.seed_sequence.spawn(1)[0])

        target_env_ids = env_ids if env_ids else range(self.num_envs)

        # abort any in-progress work
        state_update = gym_communication.TrainingStateUpdate()

        # generate the per-environment seeds up front
        if seeds is not None:
            if isinstance(seeds, list):
                assert len(seeds) == self.num_envs, "Number of seeds must match number of environments, if passed as list"
                self.seeds = seeds
            else:
                # Note: this converts the uint32 to a python int
                self.seeds = [np.int32(x.generate_state(1)).item() for x in self.seed_sequence.spawn(self.num_envs)]

        for env_id in target_env_ids:
            reset_msg = state_update.updates[env_id].reset
            if seeds is not None:
                reset_msg.seed = self.seeds[env_id]
            if options is not None:
                if isinstance(options, list):
                    assert len(options) == self.num_envs, "Number of options dictionaries must match number of environments, if passed as list"
                    env_options = options[env_id]
                else:
                    env_options = options
                # convert the option values to strings
                for key in env_options:
                    reset_msg.options[key] = str(env_options[key])

        # send the message without waiting for the response
        self.gym_stub.UpdateState.future(state_update)

        # reset everyone
        return self.soft_reset(target_env_ids)
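
    # Illustrative reset sketch showing the accepted seed/option forms. With an
    # int seed, per-environment seeds are derived from a single SeedSequence;
    # with a list, its length must equal num_envs. Values here are arbitrary.
    #
    #     obs, infos = env.hard_reset(seeds=42)                       # one seed shared across envs
    #     obs, infos = env.hard_reset(seeds=[1, 2, 3])                # one seed per env (num_envs == 3)
    #     obs, infos = env.hard_reset(options={"difficulty": "hard"}) # same options for all envs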
    def soft_reset(self, ids: Optional[List[int]] = None) -> Tuple[EnvAgentIdDict[Dict[str, Any]], EnvAgentIdDict[Dict[str, str]]]:
        """
        Soft reset the environment, by waiting for Unreal to reset itself and send a post-reset state to Python.

        Parameters
        ----------
        ids : List[int], optional
            A list of environment IDs to reset. If not provided or set to None, all environment IDs will be reset.

        Returns
        -------
        observations : EnvAgentIdDict[Dict[str,Any]]
            A dictionary, keyed by the environment and agent Id, containing the observations of the agents in the environments immediately following a reset.
        infos : EnvAgentIdDict[Dict[str,str]]
            A dictionary, keyed by the environment and agent Id, containing the infos of the agents in the environment.
        """
        if ids is None or len(ids) == 0:
            ids = range(len(self.ids))
        self.steps = 0

        # send an empty request for an update
        logging.info("requesting environment state post reset")
        logging.info(
            f"Waiting for environment(s) {','.join([str(x) for x in ids])} to reset"
        )
        state_request = gym_communication.InitialTrainingStateRequest()
        env_state: state.TrainingState = self.gym_stub.RequestInitialTrainingState(state_request)
        logging.debug(env_state)
        logging.info("initial environment state received")
        # Note: other portions were removed here for compatibility with gym, as opposed to gymnasium
        return self._convert_reset_state_to_tuple(env_state)
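
    # Illustrative sketch: soft_reset is normally invoked via hard_reset, but it
    # can also be called directly to wait for a reset performed on the Unreal
    # side, per the docstring above.
    #
    #     obs, infos = env.soft_reset()        # wait for all environments
    #     obs, infos = env.soft_reset([0, 2])  # wait only for environments 0 and 2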
    @property
    def num_agents(self) -> int:
        """
        Return the total number of agents in the environment.

        Returns
        -------
        int
            The total number of agents.
        """
        return sum([len(x) for x in self.ids])

    @property
    def num_envs(self) -> int:
        """
        Return the number of environments.

        Returns
        -------
        int
            The number of environments.
        """
        return len(self.ids)
    def close(self) -> None:
        """
        Closes the connection to the Unreal Engine and cleans up any resources.

        It is safe to call this method multiple times.

        See Also
        --------
        gymnasium.Env.close : The equivalent operation in gymnasium
        """
        # if the connection is active
        if self.unreal_connection.is_active:
            state_update = gym_communication.TrainingStateUpdate()
            state_update.status = gym_communication.CommunicatorStatus.CLOSED
            self.gym_stub.UpdateState.future(state_update)
            logging.info("Sending closed msg to Unreal")
        # this closes the event loop as well
        # this method is safe to call multiple times
        self.unreal_connection.close()
    def _convert_reset_state_to_tuple(self, reset_state: state.TrainingState) -> Tuple[EnvAgentIdDict[Dict[str, Any]], EnvAgentIdDict[Dict[str, str]]]:
        """
        Convert the reset state, from a protobuf message to a tuple of observations and info.

        Parameters
        ----------
        reset_state : state.TrainingState
            The reset state object.

        Returns
        -------
        observations : EnvAgentIdDict[Dict[str,Any]]
            A dictionary, keyed by the environment and agent Id, containing the observations of the agents in the environments immediately following a reset.
        infos : EnvAgentIdDict[Dict[str,str]]
            A dictionary, keyed by the environment and agent Id, containing the infos of the agents in the environment.
        """
        observations = {}
        info = {}
        for env_id, env_state in reset_state.environment_states.items():
            for agent_id, agent_state in env_state.agent_states.items():
                proc_obs = self.get_obs_space(env_id, agent_id).process_data(
                    agent_state.observations
                )
                observations.setdefault(env_id, {})[agent_id] = proc_obs
                info.setdefault(env_id, {})[agent_id] = dict(agent_state.info)
        return observations, info

    def _convert_state_to_tuple(self, training_state: state.TrainingState) -> Tuple[EnvAgentIdDict[Dict[str, Any]], EnvAgentIdDict[float], EnvAgentIdDict[bool], EnvAgentIdDict[bool], EnvAgentIdDict[Dict[str, str]]]:
        """
        Convert a training state, from a protobuf message to a tuple of observations, rewards, terminateds, truncateds and infos.

        Parameters
        ----------
        training_state : state.TrainingState
            The training state object.

        Returns
        -------
        observations : EnvAgentIdDict[Dict[str,Any]]
            A dictionary, keyed by the environment and agent Id, containing the observations for each agent.
        rewards : EnvAgentIdDict[float]
            A dictionary, keyed by the environment and agent Id, containing the reward for each agent.
        terminateds : EnvAgentIdDict[bool]
            A dictionary, keyed by the environment and agent Id, containing the termination flag for each agent.
        truncateds : EnvAgentIdDict[bool]
            A dictionary, keyed by the environment and agent Id, containing the truncation flag for each agent.
        infos : EnvAgentIdDict[Dict[str,str]]
            A dictionary, keyed by the environment and agent Id, containing the information dictionary for each agent.
        """
        observations = {}
        rewards = {}
        completeds = {}
        truncateds = {}
        info = {}
        for env_id, env_state in enumerate(training_state.environment_states):
            for agent_id, agent_state in env_state.agent_states.items():
                proc_obs = self.get_obs_space(env_id, agent_id).process_data(
                    agent_state.observations
                )
                observations.setdefault(env_id, {})[agent_id] = proc_obs
                rewards.setdefault(env_id, {})[agent_id] = agent_state.reward
                completeds.setdefault(env_id, {})[agent_id] = (
                    agent_state.status == state.Status.COMPLETED
                )
                truncateds.setdefault(env_id, {})[agent_id] = (
                    agent_state.status == state.Status.TRUNCATED
                )
                info.setdefault(env_id, {})[agent_id] = dict(agent_state.info)
        return observations, rewards, completeds, truncateds, info
