diff --git a/tests/units/sc2/states/test_protoss_state.py b/tests/units/sc2/states/test_protoss_state.py index bfcc8ec..cc451e0 100644 --- a/tests/units/sc2/states/test_protoss_state.py +++ b/tests/units/sc2/states/test_protoss_state.py @@ -1,5 +1,6 @@ import unittest +from pysc2.env import sc2_env from pysc2.lib.named_array import NamedDict from urnai.sc2.states.protoss_state import ProtossState @@ -26,6 +27,7 @@ def test_protoss_state_no_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.protoss assert state.dimension == 9 assert len(state.state[0]) == 9 @@ -63,5 +65,6 @@ def test_protoss_state_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.protoss assert state.dimension == 22 + ((5 * 5) * 2) assert len(state.state[0]) == 22 + ((5 * 5) * 2) diff --git a/tests/units/sc2/states/test_starcraft2_state.py b/tests/units/sc2/states/test_starcraft2_state.py new file mode 100644 index 0000000..692821c --- /dev/null +++ b/tests/units/sc2/states/test_starcraft2_state.py @@ -0,0 +1,67 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.starcraft2_state import StarCraft2State + + +class TestStarCraft2State(unittest.TestCase): + + + def test_starcraft2_state_no_raw(self): + # GIVEN + state = StarCraft2State(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_starcraft2_state_raw(self): + # GIVEN + state = StarCraft2State(grid_size=5, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 9 + ((5 * 5) * 2) + assert len(state.state[0]) == 9 + ((5 * 5) * 2) diff --git a/tests/units/sc2/states/test_terran_state.py b/tests/units/sc2/states/test_terran_state.py index 40d4633..eb758db 100644 --- a/tests/units/sc2/states/test_terran_state.py +++ b/tests/units/sc2/states/test_terran_state.py @@ -1,5 +1,6 @@ import unittest +from pysc2.env import sc2_env from pysc2.lib.named_array import NamedDict from urnai.sc2.states.terran_state import TerranState @@ -26,6 +27,7 @@ def test_terran_state_no_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.terran assert state.dimension == 9 assert len(state.state[0]) == 9 @@ -63,5 +65,6 @@ def test_terran_state_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.terran assert state.dimension == 22 + ((4 * 4) * 2) assert len(state.state[0]) == 22 + ((4 * 4) * 2) diff --git a/tests/units/sc2/states/test_utils.py b/tests/units/sc2/states/test_utils.py index 78cd601..eb238fa 100644 --- a/tests/units/sc2/states/test_utils.py +++ b/tests/units/sc2/states/test_utils.py @@ -1,5 +1,6 @@ import unittest +from pysc2.lib import features from pysc2.lib.named_array import NamedDict from urnai.sc2.states.utils import ( @@ -16,21 +17,21 @@ def test_append_player_and_enemy_grids(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -51,21 +52,21 @@ def test_create_raw_units_amount_dict(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -73,7 +74,7 @@ def test_create_raw_units_amount_dict(self): ] }) # WHEN - dict = create_raw_units_amount_dict(obs, 1) + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.SELF) # THEN assert len(dict) == 2 assert dict[1] == 1 @@ -85,21 +86,21 @@ def test_create_raw_units_amount_dict_alliance(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -107,7 +108,7 @@ def test_create_raw_units_amount_dict_alliance(self): ] }) # WHEN - dict = create_raw_units_amount_dict(obs, 4) + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.ENEMY) # THEN assert len(dict) == 1 assert dict[1] == 0 @@ -130,21 +131,21 @@ def test_create_raw_units_amount_dict_no_units_alliance(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -152,7 +153,7 @@ def test_create_raw_units_amount_dict_no_units_alliance(self): ] }) # WHEN - dict = create_raw_units_amount_dict(obs, 3) + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.NEUTRAL) # THEN assert len(dict) == 0 assert dict == {} \ No newline at end of file diff --git a/tests/units/sc2/states/test_zerg_state.py b/tests/units/sc2/states/test_zerg_state.py index 348fdf9..c1e25f0 100644 --- a/tests/units/sc2/states/test_zerg_state.py +++ b/tests/units/sc2/states/test_zerg_state.py @@ -1,5 +1,6 @@ import unittest +from pysc2.env import sc2_env from pysc2.lib.named_array import NamedDict from urnai.sc2.states.zerg_state import ZergState @@ -26,6 +27,7 @@ def test_zerg_state_no_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.zerg assert state.dimension == 9 assert len(state.state[0]) == 9 @@ -63,5 +65,6 @@ def test_zerg_state_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.zerg assert state.dimension == 22 + ((6 * 6) * 2) assert len(state.state[0]) == 22 + ((6 * 6) * 2) diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py index 0fe6b9e..966e12e 100644 --- a/urnai/sc2/states/protoss_state.py +++ b/urnai/sc2/states/protoss_state.py @@ -2,15 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.constants import SC2Constants -from urnai.sc2.states.utils import ( - append_player_and_enemy_grids, - create_raw_units_amount_dict, -) -from urnai.states.state_base import StateBase +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict -class ProtossState(StateBase): +class ProtossState(StarCraft2State): def __init__( self, @@ -18,64 +14,34 @@ def __init__( use_raw_units: bool = True, raw_resolution: int = 64, ): - self.grid_size = grid_size + super().__init__(grid_size, use_raw_units, raw_resolution) self.player_race = sc2_env.Race.protoss - self.use_raw_units = use_raw_units - self.raw_resolution = raw_resolution - self.reset() def update(self, obs): - new_state = [ - # Adds general information from the player. - obs.player.minerals / SC2Constants.MAX_MINERALS, - obs.player.vespene / SC2Constants.MAX_VESPENE, - obs.player.food_cap / SC2Constants.MAX_UNITS, - obs.player.food_used / SC2Constants.MAX_UNITS, - obs.player.food_army / SC2Constants.MAX_UNITS, - obs.player.food_workers / SC2Constants.MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, - obs.player.army_count / SC2Constants.MAX_UNITS, - obs.player.idle_worker_count / SC2Constants.MAX_UNITS, - ] + state = super().update(obs) if self.use_raw_units: raw_units_amount_dict = create_raw_units_amount_dict( obs, sc2_env.features.PlayerRelative.SELF) - new_state.extend( - [ - # Adds information related to player's Protoss units/buildings. - raw_units_amount_dict[units.Protoss.Nexus], - raw_units_amount_dict[units.Protoss.Pylon], - raw_units_amount_dict[units.Protoss.Assimilator], - raw_units_amount_dict[units.Protoss.Forge], - raw_units_amount_dict[units.Protoss.Gateway], - raw_units_amount_dict[units.Protoss.CyberneticsCore], - raw_units_amount_dict[units.Protoss.PhotonCannon], - raw_units_amount_dict[units.Protoss.RoboticsFacility], - raw_units_amount_dict[units.Protoss.Stargate], - raw_units_amount_dict[units.Protoss.TwilightCouncil], - raw_units_amount_dict[units.Protoss.RoboticsBay], - raw_units_amount_dict[units.Protoss.TemplarArchive], - raw_units_amount_dict[units.Protoss.DarkShrine], - ] - ) - new_state = append_player_and_enemy_grids( - obs, new_state, self.grid_size, self.raw_resolution - ) - - self._dimension = len(new_state) - final_state = np.expand_dims(new_state, axis=0) - self._state = final_state - return final_state - - @property - def state(self): - return self._state - - @property - def dimension(self): - return self._dimension - - def reset(self): - self._state = None - self._dimension = None + units_amount_info = [ + raw_units_amount_dict[units.Protoss.Nexus], + raw_units_amount_dict[units.Protoss.Pylon], + raw_units_amount_dict[units.Protoss.Assimilator], + raw_units_amount_dict[units.Protoss.Forge], + raw_units_amount_dict[units.Protoss.Gateway], + raw_units_amount_dict[units.Protoss.CyberneticsCore], + raw_units_amount_dict[units.Protoss.PhotonCannon], + raw_units_amount_dict[units.Protoss.RoboticsFacility], + raw_units_amount_dict[units.Protoss.Stargate], + raw_units_amount_dict[units.Protoss.TwilightCouncil], + raw_units_amount_dict[units.Protoss.RoboticsBay], + raw_units_amount_dict[units.Protoss.TemplarArchive], + raw_units_amount_dict[units.Protoss.DarkShrine], + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file diff --git a/urnai/sc2/states/starcraft2_state.py b/urnai/sc2/states/starcraft2_state.py new file mode 100644 index 0000000..dc0b987 --- /dev/null +++ b/urnai/sc2/states/starcraft2_state.py @@ -0,0 +1,55 @@ +import numpy as np + +from urnai.constants import SC2Constants +from urnai.sc2.states.utils import append_player_and_enemy_grids +from urnai.states.state_base import StateBase + + +class StarCraft2State(StateBase): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + self.grid_size = grid_size + self.use_raw_units = use_raw_units + self.raw_resolution = raw_resolution + self.reset() + + def update(self, obs): + new_state = [ + # Adds general information from the player. + obs.player.minerals / SC2Constants.MAX_MINERALS, + obs.player.vespene / SC2Constants.MAX_VESPENE, + obs.player.food_cap / SC2Constants.MAX_UNITS, + obs.player.food_used / SC2Constants.MAX_UNITS, + obs.player.food_army / SC2Constants.MAX_UNITS, + obs.player.food_workers / SC2Constants.MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, + obs.player.army_count / SC2Constants.MAX_UNITS, + obs.player.idle_worker_count / SC2Constants.MAX_UNITS, + ] + + if self.use_raw_units: + new_state = append_player_and_enemy_grids( + obs, new_state, self.grid_size, self.raw_resolution + ) + + self._dimension = len(new_state) + final_state = np.expand_dims(new_state, axis=0) + self._state = final_state + return final_state + + @property + def state(self): + return self._state + + @property + def dimension(self): + return self._dimension + + def reset(self): + self._state = None + self._dimension = None diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index 283ff69..0f6ec8e 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -2,15 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.constants import SC2Constants -from urnai.sc2.states.utils import ( - append_player_and_enemy_grids, - create_raw_units_amount_dict, -) -from urnai.states.state_base import StateBase +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict -class TerranState(StateBase): +class TerranState(StarCraft2State): def __init__( self, @@ -18,66 +14,36 @@ def __init__( use_raw_units: bool = True, raw_resolution: int = 64, ): - self.grid_size = grid_size + super().__init__(grid_size, use_raw_units, raw_resolution) self.player_race = sc2_env.Race.terran - self.use_raw_units = use_raw_units - self.raw_resolution = raw_resolution - self.reset() def update(self, obs): - new_state = [ - # Adds general information from the player. - obs.player.minerals / SC2Constants.MAX_MINERALS, - obs.player.vespene / SC2Constants.MAX_VESPENE, - obs.player.food_cap / SC2Constants.MAX_UNITS, - obs.player.food_used / SC2Constants.MAX_UNITS, - obs.player.food_army / SC2Constants.MAX_UNITS, - obs.player.food_workers / SC2Constants.MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, - obs.player.army_count / SC2Constants.MAX_UNITS, - obs.player.idle_worker_count / SC2Constants.MAX_UNITS, - ] + state = super().update(obs) if self.use_raw_units: raw_units_amount_dict = create_raw_units_amount_dict( obs, sc2_env.features.PlayerRelative.SELF) - new_state.extend( - [ - # Adds information related to player's Terran units/buildings. - raw_units_amount_dict[units.Terran.CommandCenter] - + raw_units_amount_dict[units.Terran.OrbitalCommand] - + raw_units_amount_dict[units.Terran.PlanetaryFortress] / 2, - raw_units_amount_dict[units.Terran.SupplyDepot] / 18, - raw_units_amount_dict[units.Terran.Refinery] / 4, - raw_units_amount_dict[units.Terran.EngineeringBay], - raw_units_amount_dict[units.Terran.Armory], - raw_units_amount_dict[units.Terran.MissileTurret] / 4, - raw_units_amount_dict[units.Terran.SensorTower] / 1, - raw_units_amount_dict[units.Terran.Bunker] / 4, - raw_units_amount_dict[units.Terran.FusionCore], - raw_units_amount_dict[units.Terran.GhostAcademy], - raw_units_amount_dict[units.Terran.Barracks] / 3, - raw_units_amount_dict[units.Terran.Factory] / 2, - raw_units_amount_dict[units.Terran.Starport] / 2, - ] - ) - new_state = append_player_and_enemy_grids( - obs, new_state, self.grid_size, self.raw_resolution - ) - - self._dimension = len(new_state) - final_state = np.expand_dims(new_state, axis=0) - self._state = final_state - return final_state - - @property - def state(self): - return self._state - - @property - def dimension(self): - return self._dimension - - def reset(self): - self._state = None - self._dimension = None + units_amount_info = [ + raw_units_amount_dict[units.Terran.CommandCenter] + + raw_units_amount_dict[units.Terran.OrbitalCommand] + + raw_units_amount_dict[units.Terran.PlanetaryFortress] / 2, + raw_units_amount_dict[units.Terran.SupplyDepot] / 18, + raw_units_amount_dict[units.Terran.Refinery] / 4, + raw_units_amount_dict[units.Terran.EngineeringBay], + raw_units_amount_dict[units.Terran.Armory], + raw_units_amount_dict[units.Terran.MissileTurret] / 4, + raw_units_amount_dict[units.Terran.SensorTower] / 1, + raw_units_amount_dict[units.Terran.Bunker] / 4, + raw_units_amount_dict[units.Terran.FusionCore], + raw_units_amount_dict[units.Terran.GhostAcademy], + raw_units_amount_dict[units.Terran.Barracks] / 3, + raw_units_amount_dict[units.Terran.Factory] / 2, + raw_units_amount_dict[units.Terran.Starport] / 2, + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py index 1448696..2443cb1 100644 --- a/urnai/sc2/states/zerg_state.py +++ b/urnai/sc2/states/zerg_state.py @@ -2,15 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.constants import SC2Constants -from urnai.sc2.states.utils import ( - append_player_and_enemy_grids, - create_raw_units_amount_dict, -) -from urnai.states.state_base import StateBase +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict -class ZergState(StateBase): +class ZergState(StarCraft2State): def __init__( self, @@ -18,64 +14,34 @@ def __init__( use_raw_units: bool = True, raw_resolution: int = 64, ): - self.grid_size = grid_size + super().__init__(grid_size, use_raw_units, raw_resolution) self.player_race = sc2_env.Race.zerg - self.use_raw_units = use_raw_units - self.raw_resolution = raw_resolution - self.reset() def update(self, obs): - new_state = [ - # Adds general information from the player. - obs.player.minerals / SC2Constants.MAX_MINERALS, - obs.player.vespene / SC2Constants.MAX_VESPENE, - obs.player.food_cap / SC2Constants.MAX_UNITS, - obs.player.food_used / SC2Constants.MAX_UNITS, - obs.player.food_army / SC2Constants.MAX_UNITS, - obs.player.food_workers / SC2Constants.MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, - obs.player.army_count / SC2Constants.MAX_UNITS, - obs.player.idle_worker_count / SC2Constants.MAX_UNITS, - ] + state = super().update(obs) if self.use_raw_units: raw_units_amount_dict = create_raw_units_amount_dict( obs, sc2_env.features.PlayerRelative.SELF) - new_state.extend( - [ - # Adds information related to player's Zerg units/buildings. - raw_units_amount_dict[units.Zerg.BanelingNest], - raw_units_amount_dict[units.Zerg.EvolutionChamber], - raw_units_amount_dict[units.Zerg.Extractor], - raw_units_amount_dict[units.Zerg.Hatchery], - raw_units_amount_dict[units.Zerg.HydraliskDen], - raw_units_amount_dict[units.Zerg.InfestationPit], - raw_units_amount_dict[units.Zerg.LurkerDen], - raw_units_amount_dict[units.Zerg.NydusNetwork], - raw_units_amount_dict[units.Zerg.RoachWarren], - raw_units_amount_dict[units.Zerg.SpawningPool], - raw_units_amount_dict[units.Zerg.SpineCrawler], - raw_units_amount_dict[units.Zerg.Spire], - raw_units_amount_dict[units.Zerg.SporeCrawler], - ] - ) - new_state = append_player_and_enemy_grids( - obs, new_state, self.grid_size, self.raw_resolution - ) - - self._dimension = len(new_state) - final_state = np.expand_dims(new_state, axis=0) - self._state = final_state - return final_state - - @property - def state(self): - return self._state - - @property - def dimension(self): - return self._dimension - - def reset(self): - self._state = None - self._dimension = None + units_amount_info = [ + raw_units_amount_dict[units.Zerg.BanelingNest], + raw_units_amount_dict[units.Zerg.EvolutionChamber], + raw_units_amount_dict[units.Zerg.Extractor], + raw_units_amount_dict[units.Zerg.Hatchery], + raw_units_amount_dict[units.Zerg.HydraliskDen], + raw_units_amount_dict[units.Zerg.InfestationPit], + raw_units_amount_dict[units.Zerg.LurkerDen], + raw_units_amount_dict[units.Zerg.NydusNetwork], + raw_units_amount_dict[units.Zerg.RoachWarren], + raw_units_amount_dict[units.Zerg.SpawningPool], + raw_units_amount_dict[units.Zerg.SpineCrawler], + raw_units_amount_dict[units.Zerg.Spire], + raw_units_amount_dict[units.Zerg.SporeCrawler], + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file