diff --git a/tests/units/sc2/states/__init__.py b/tests/units/sc2/states/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/units/sc2/states/test_protoss_state.py b/tests/units/sc2/states/test_protoss_state.py new file mode 100644 index 00000000..cc451e09 --- /dev/null +++ b/tests/units/sc2/states/test_protoss_state.py @@ -0,0 +1,70 @@ +import unittest + +from pysc2.env import sc2_env +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.protoss_state import ProtossState + + +class TestProtossState(unittest.TestCase): + + + def test_protoss_state_no_raw(self): + # GIVEN + state = ProtossState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + assert state.player_race == sc2_env.Race.protoss + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_protoss_state_raw(self): + # GIVEN + state = ProtossState(grid_size=5, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.player_race == sc2_env.Race.protoss + assert state.dimension == 22 + ((5 * 5) * 2) + assert len(state.state[0]) == 22 + ((5 * 5) * 2) diff --git a/tests/units/sc2/states/test_starcraft2_state.py b/tests/units/sc2/states/test_starcraft2_state.py new file mode 100644 index 00000000..692821c0 --- /dev/null +++ b/tests/units/sc2/states/test_starcraft2_state.py @@ -0,0 +1,67 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.starcraft2_state import StarCraft2State + + +class TestStarCraft2State(unittest.TestCase): + + + def test_starcraft2_state_no_raw(self): + # GIVEN + state = StarCraft2State(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_starcraft2_state_raw(self): + # GIVEN + state = StarCraft2State(grid_size=5, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 9 + ((5 * 5) * 2) + assert len(state.state[0]) == 9 + ((5 * 5) * 2) diff --git a/tests/units/sc2/states/test_terran_state.py b/tests/units/sc2/states/test_terran_state.py new file mode 100644 index 00000000..eb758dbe --- /dev/null +++ b/tests/units/sc2/states/test_terran_state.py @@ -0,0 +1,70 @@ +import unittest + +from pysc2.env import sc2_env +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.terran_state import TerranState + + +class TestTerranState(unittest.TestCase): + + + def test_terran_state_no_raw(self): + # GIVEN + state = TerranState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + assert state.player_race == sc2_env.Race.terran + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_terran_state_raw(self): + # GIVEN + state = TerranState(grid_size=4, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.player_race == sc2_env.Race.terran + assert state.dimension == 22 + ((4 * 4) * 2) + assert len(state.state[0]) == 22 + ((4 * 4) * 2) diff --git a/tests/units/sc2/states/test_utils.py b/tests/units/sc2/states/test_utils.py new file mode 100644 index 00000000..eb238fac --- /dev/null +++ b/tests/units/sc2/states/test_utils.py @@ -0,0 +1,159 @@ +import unittest + +from pysc2.lib import features +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.utils import ( + append_player_and_enemy_grids, + create_raw_units_amount_dict, +) + + +class TestAuxSC2State(unittest.TestCase): + + def test_append_player_and_enemy_grids(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.ENEMY, + 'build_progress': 100, + 'x': 63, + 'y': 63, + }), + ] + }) + new_state = [] + # WHEN + new_state = append_player_and_enemy_grids(obs, new_state, 3, 64) + # THEN + assert len(new_state) == (18) + assert new_state[8] == 0.005 + assert new_state[9] == 0.01 + + def test_create_raw_units_amount_dict(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.ENEMY, + 'build_progress': 100, + 'x': 63, + 'y': 63, + }), + ] + }) + # WHEN + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.SELF) + # THEN + assert len(dict) == 2 + assert dict[1] == 1 + assert dict[2] == 1 + + def test_create_raw_units_amount_dict_alliance(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.ENEMY, + 'build_progress': 100, + 'x': 63, + 'y': 63, + }), + ] + }) + # WHEN + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.ENEMY) + # THEN + assert len(dict) == 1 + assert dict[1] == 0 + assert dict[2] == 1 + + def test_create_raw_units_amount_dict_no_units(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [] + }) + # WHEN + dict = create_raw_units_amount_dict(obs) + # THEN + assert len(dict) == 0 + assert dict == {} + + def test_create_raw_units_amount_dict_no_units_alliance(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.SELF, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': features.PlayerRelative.ENEMY, + 'build_progress': 100, + 'x': 63, + 'y': 63, + }), + ] + }) + # WHEN + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.NEUTRAL) + # THEN + assert len(dict) == 0 + assert dict == {} \ No newline at end of file diff --git a/tests/units/sc2/states/test_zerg_state.py b/tests/units/sc2/states/test_zerg_state.py new file mode 100644 index 00000000..c1e25f05 --- /dev/null +++ b/tests/units/sc2/states/test_zerg_state.py @@ -0,0 +1,70 @@ +import unittest + +from pysc2.env import sc2_env +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.zerg_state import ZergState + + +class TestZergState(unittest.TestCase): + + + def test_zerg_state_no_raw(self): + # GIVEN + state = ZergState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + assert state.player_race == sc2_env.Race.zerg + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_zerg_state_raw(self): + # GIVEN + state = ZergState(grid_size=6, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.player_race == sc2_env.Race.zerg + assert state.dimension == 22 + ((6 * 6) * 2) + assert len(state.state[0]) == 22 + ((6 * 6) * 2) diff --git a/urnai/constants.py b/urnai/constants.py new file mode 100644 index 00000000..1be00b52 --- /dev/null +++ b/urnai/constants.py @@ -0,0 +1,8 @@ +from enum import IntEnum + + +class SC2Constants(IntEnum): + MAX_UNITS = 200 + MAX_MINERALS = 6000 + MAX_VESPENE = 6000 + BUILD_PROGRESS_COMPLETE = 100 \ No newline at end of file diff --git a/urnai/sc2/states/__init__.py b/urnai/sc2/states/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py new file mode 100644 index 00000000..966e12ed --- /dev/null +++ b/urnai/sc2/states/protoss_state.py @@ -0,0 +1,47 @@ +import numpy as np +from pysc2.env import sc2_env +from pysc2.lib import units + +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict + + +class ProtossState(StarCraft2State): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + super().__init__(grid_size, use_raw_units, raw_resolution) + self.player_race = sc2_env.Race.protoss + + def update(self, obs): + state = super().update(obs) + + if self.use_raw_units: + raw_units_amount_dict = create_raw_units_amount_dict( + obs, sc2_env.features.PlayerRelative.SELF) + units_amount_info = [ + raw_units_amount_dict[units.Protoss.Nexus], + raw_units_amount_dict[units.Protoss.Pylon], + raw_units_amount_dict[units.Protoss.Assimilator], + raw_units_amount_dict[units.Protoss.Forge], + raw_units_amount_dict[units.Protoss.Gateway], + raw_units_amount_dict[units.Protoss.CyberneticsCore], + raw_units_amount_dict[units.Protoss.PhotonCannon], + raw_units_amount_dict[units.Protoss.RoboticsFacility], + raw_units_amount_dict[units.Protoss.Stargate], + raw_units_amount_dict[units.Protoss.TwilightCouncil], + raw_units_amount_dict[units.Protoss.RoboticsBay], + raw_units_amount_dict[units.Protoss.TemplarArchive], + raw_units_amount_dict[units.Protoss.DarkShrine], + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file diff --git a/urnai/sc2/states/starcraft2_state.py b/urnai/sc2/states/starcraft2_state.py new file mode 100644 index 00000000..dc0b987e --- /dev/null +++ b/urnai/sc2/states/starcraft2_state.py @@ -0,0 +1,55 @@ +import numpy as np + +from urnai.constants import SC2Constants +from urnai.sc2.states.utils import append_player_and_enemy_grids +from urnai.states.state_base import StateBase + + +class StarCraft2State(StateBase): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + self.grid_size = grid_size + self.use_raw_units = use_raw_units + self.raw_resolution = raw_resolution + self.reset() + + def update(self, obs): + new_state = [ + # Adds general information from the player. + obs.player.minerals / SC2Constants.MAX_MINERALS, + obs.player.vespene / SC2Constants.MAX_VESPENE, + obs.player.food_cap / SC2Constants.MAX_UNITS, + obs.player.food_used / SC2Constants.MAX_UNITS, + obs.player.food_army / SC2Constants.MAX_UNITS, + obs.player.food_workers / SC2Constants.MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, + obs.player.army_count / SC2Constants.MAX_UNITS, + obs.player.idle_worker_count / SC2Constants.MAX_UNITS, + ] + + if self.use_raw_units: + new_state = append_player_and_enemy_grids( + obs, new_state, self.grid_size, self.raw_resolution + ) + + self._dimension = len(new_state) + final_state = np.expand_dims(new_state, axis=0) + self._state = final_state + return final_state + + @property + def state(self): + return self._state + + @property + def dimension(self): + return self._dimension + + def reset(self): + self._state = None + self._dimension = None diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py new file mode 100644 index 00000000..285a8d45 --- /dev/null +++ b/urnai/sc2/states/terran_state.py @@ -0,0 +1,49 @@ +import numpy as np +from pysc2.env import sc2_env +from pysc2.lib import units + +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict + + +class TerranState(StarCraft2State): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + super().__init__(grid_size, use_raw_units, raw_resolution) + self.player_race = sc2_env.Race.terran + + def update(self, obs): + state = super().update(obs) + + if self.use_raw_units: + raw_units_amount_dict = create_raw_units_amount_dict( + obs, sc2_env.features.PlayerRelative.SELF) + units_amount_info = [ + raw_units_amount_dict[units.Terran.CommandCenter] + + raw_units_amount_dict[units.Terran.OrbitalCommand] + + raw_units_amount_dict[units.Terran.PlanetaryFortress], + raw_units_amount_dict[units.Terran.SupplyDepot], + raw_units_amount_dict[units.Terran.Refinery], + raw_units_amount_dict[units.Terran.EngineeringBay], + raw_units_amount_dict[units.Terran.Armory], + raw_units_amount_dict[units.Terran.MissileTurret], + raw_units_amount_dict[units.Terran.SensorTower], + raw_units_amount_dict[units.Terran.Bunker], + raw_units_amount_dict[units.Terran.FusionCore], + raw_units_amount_dict[units.Terran.GhostAcademy], + raw_units_amount_dict[units.Terran.Barracks], + raw_units_amount_dict[units.Terran.Factory], + raw_units_amount_dict[units.Terran.Starport], + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file diff --git a/urnai/sc2/states/utils.py b/urnai/sc2/states/utils.py new file mode 100644 index 00000000..a5347ada --- /dev/null +++ b/urnai/sc2/states/utils.py @@ -0,0 +1,61 @@ +import math +from collections import defaultdict + +import numpy as np +from pysc2.lib import features + +from urnai.constants import SC2Constants + + +def append_player_and_enemy_grids( + obs : list, + new_state : list, + grid_size : int, + raw_resolution : int, + ) -> list: + new_state = append_alliance_grid( + features.PlayerRelative.ENEMY, obs, new_state, grid_size, raw_resolution + ) + new_state = append_alliance_grid( + features.PlayerRelative.SELF, obs, new_state, grid_size, raw_resolution + ) + return new_state + +def append_alliance_grid( + alliance : features.PlayerRelative, + obs : list, + new_state : list, + grid_size : int, + raw_resolution : int, + ) -> list: + """ Instead of making a vector for all coordnates on the map, we'll + discretize our unit space and use a grid to store unit positions + by marking a square as 1 if there's any enemy/ally on it.""" + grid = np.zeros((grid_size, grid_size)) + + units = [unit for unit in obs.raw_units if + unit.alliance == alliance] + raw_to_grid_ratio = raw_resolution / grid_size + + for index in range(0, len(units)): + y = int(math.ceil((units[index].x + 1) / raw_to_grid_ratio)) + x = int(math.ceil((units[index].y + 1) / raw_to_grid_ratio)) + grid[x - 1][y - 1] += 1 + + # Normalizing the values to always be between 0 and 1 + grid = grid / SC2Constants.MAX_UNITS + + new_state.extend(grid.flatten()) + + return new_state + +def create_raw_units_amount_dict( + obs : list, + alliance : features.PlayerRelative = features.PlayerRelative.SELF + ) -> defaultdict: + dict = defaultdict(lambda: 0) + for unit in obs.raw_units: + if (unit.alliance == alliance + and unit.build_progress == SC2Constants.BUILD_PROGRESS_COMPLETE): + dict[unit.unit_type] += 1 + return dict diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py new file mode 100644 index 00000000..2443cb1a --- /dev/null +++ b/urnai/sc2/states/zerg_state.py @@ -0,0 +1,47 @@ +import numpy as np +from pysc2.env import sc2_env +from pysc2.lib import units + +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict + + +class ZergState(StarCraft2State): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + super().__init__(grid_size, use_raw_units, raw_resolution) + self.player_race = sc2_env.Race.zerg + + def update(self, obs): + state = super().update(obs) + + if self.use_raw_units: + raw_units_amount_dict = create_raw_units_amount_dict( + obs, sc2_env.features.PlayerRelative.SELF) + units_amount_info = [ + raw_units_amount_dict[units.Zerg.BanelingNest], + raw_units_amount_dict[units.Zerg.EvolutionChamber], + raw_units_amount_dict[units.Zerg.Extractor], + raw_units_amount_dict[units.Zerg.Hatchery], + raw_units_amount_dict[units.Zerg.HydraliskDen], + raw_units_amount_dict[units.Zerg.InfestationPit], + raw_units_amount_dict[units.Zerg.LurkerDen], + raw_units_amount_dict[units.Zerg.NydusNetwork], + raw_units_amount_dict[units.Zerg.RoachWarren], + raw_units_amount_dict[units.Zerg.SpawningPool], + raw_units_amount_dict[units.Zerg.SpineCrawler], + raw_units_amount_dict[units.Zerg.Spire], + raw_units_amount_dict[units.Zerg.SporeCrawler], + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file