From 51cd0c41083edfa5f099d14b5424b7d5a12f6ca9 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Tue, 19 Mar 2024 16:06:28 -0300 Subject: [PATCH 01/11] chore(issue-73): Added initial Terran State class --- urnai/sc2/states/__init__.py | 0 urnai/sc2/states/terran_state.py | 121 +++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 urnai/sc2/states/__init__.py create mode 100644 urnai/sc2/states/terran_state.py diff --git a/urnai/sc2/states/__init__.py b/urnai/sc2/states/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py new file mode 100644 index 00000000..c3bd47a5 --- /dev/null +++ b/urnai/sc2/states/terran_state.py @@ -0,0 +1,121 @@ +import math + +import numpy as np +from pysc2.env import sc2_env +from pysc2.lib import features, units + +from urnai.states.state_base import StateBase + + +class TerranState(StateBase): + + def __init__(self, grid_size=4): + + self.grid_size = grid_size + # Size of the state returned + # 22: number of data added to the state (number of minerals, army_count, etc) + # 2 * 4 ** 2: 2 grids of size 4 x 4, representing enemy and player units + self._state_size = int(22 + 2 * (self.grid_size ** 2)) + self.player_race = sc2_env.Race.terran + self.base_top_left = None + self._state = None + + def update(self, obs): + if obs.game_loop[0] < 80 and self.base_top_left is None: + + commandcenter = get_units_by_type(obs, units.Terran.CommandCenter) + + if len(commandcenter) > 0: + townhall = commandcenter[0] + self.player_race = sc2_env.Race.terran + + self.base_top_left = (townhall.x < 32) + + new_state = [] + # Adds general information from the player. + new_state.append(obs.player.minerals / 6000) + new_state.append(obs.player.vespene / 6000) + new_state.append(obs.player.food_cap / 200) + new_state.append(obs.player.food_used / 200) + new_state.append(obs.player.food_army / 200) + new_state.append(obs.player.food_workers / 200) + new_state.append((obs.player.food_cap - obs.player.food_used) / 200) + new_state.append(obs.player.army_count / 200) + new_state.append(obs.player.idle_worker_count / 200) + + # Adds information related to player's Terran units/buildings. + new_state.append(get_my_units_amount(obs, units.Terran.CommandCenter) + + get_my_units_amount(obs, units.Terran.OrbitalCommand) + + get_my_units_amount(obs, units.Terran.PlanetaryFortress) / 2) + new_state.append(get_my_units_amount(obs, units.Terran.SupplyDepot) / 18) + new_state.append(get_my_units_amount(obs, units.Terran.Refinery) / 4) + new_state.append(get_my_units_amount(obs, units.Terran.EngineeringBay)) + new_state.append(get_my_units_amount(obs, units.Terran.Armory)) + new_state.append(get_my_units_amount(obs, units.Terran.MissileTurret) / 4) + new_state.append(get_my_units_amount(obs, units.Terran.SensorTower)/1) + new_state.append(get_my_units_amount(obs, units.Terran.Bunker)/4) + new_state.append(get_my_units_amount(obs, units.Terran.FusionCore)) + new_state.append(get_my_units_amount(obs, units.Terran.GhostAcademy)) + new_state.append(get_my_units_amount(obs, units.Terran.Barracks) / 3) + new_state.append(get_my_units_amount(obs, units.Terran.Factory) / 2) + new_state.append(get_my_units_amount(obs, units.Terran.Starport) / 2) + + # Instead of making a vector for all coordnates on the map, we'll + # discretize our enemy space + # and use a 4x4 grid to store enemy positions by marking a square as 1 if + # there's any enemy on it. + + enemy_grid = np.zeros((self.grid_size, self.grid_size)) + player_grid = np.zeros((self.grid_size, self.grid_size)) + + enemy_units = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.ENEMY] + player_units = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.SELF] + + for i in range(0, len(enemy_units)): + y = int(math.ceil((enemy_units[i].x + 1) / 64 / self.grid_size)) + x = int(math.ceil((enemy_units[i].y + 1) / 64 / self.grid_size)) + enemy_grid[x - 1][y - 1] += 1 + + for i in range(0, len(player_units)): + y = int(math.ceil((player_units[i].x + 1) / (64 / self.grid_size))) + x = int(math.ceil((player_units[i].y + 1) / (64 / self.grid_size))) + player_grid[x - 1][y - 1] += 1 + + if not self.base_top_left: + enemy_grid = np.rot90(enemy_grid, 2) + player_grid = np.rot90(player_grid, 2) + + # Normalizing the values to always be between 0 and 1 + # (since the max amount of units in SC2 is 200) + enemy_grid = enemy_grid / 200 + player_grid = player_grid / 200 + + new_state.extend(enemy_grid.flatten()) + new_state.extend(player_grid.flatten()) + final_state = np.expand_dims(new_state, axis=0) + + self._state = final_state + return final_state + + @property + def state(self): + return self._state + + @property + def dimension(self): + return self._state_size + + def reset(self): + self._state = None + self.base_top_left = None + +def get_my_units_amount(obs, unit_type): + return len(get_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) + +def get_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF): + return [unit for unit in obs.raw_units + if unit.unit_type == unit_type + and unit.alliance == alliance + and unit.build_progress == 100] \ No newline at end of file From e90542f4bb997cf23518509f58e1475c6901122a Mon Sep 17 00:00:00 2001 From: RickFqt Date: Thu, 18 Apr 2024 15:46:42 -0300 Subject: [PATCH 02/11] refactor: Adjusted to reviewer's suggestions --- urnai/sc2/states/states_utils.py | 11 +++ urnai/sc2/states/terran_state.py | 165 ++++++++++++++----------------- 2 files changed, 84 insertions(+), 92 deletions(-) create mode 100644 urnai/sc2/states/states_utils.py diff --git a/urnai/sc2/states/states_utils.py b/urnai/sc2/states/states_utils.py new file mode 100644 index 00000000..172dc60f --- /dev/null +++ b/urnai/sc2/states/states_utils.py @@ -0,0 +1,11 @@ +from pysc2.lib import features + + +def get_my_units_amount(obs, unit_type): + return len(get_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) + +def get_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF): + return [unit for unit in obs.raw_units + if unit.unit_type == unit_type + and unit.alliance == alliance + and unit.build_progress == 100] \ No newline at end of file diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index c3bd47a5..8e4c3067 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -4,118 +4,99 @@ from pysc2.env import sc2_env from pysc2.lib import features, units +from urnai.sc2.states.states_utils import get_my_units_amount from urnai.states.state_base import StateBase class TerranState(StateBase): - - def __init__(self, grid_size=4): + def __init__(self, grid_size=4): self.grid_size = grid_size # Size of the state returned # 22: number of data added to the state (number of minerals, army_count, etc) # 2 * 4 ** 2: 2 grids of size 4 x 4, representing enemy and player units - self._state_size = int(22 + 2 * (self.grid_size ** 2)) + self.dimension = int(22 + 2 * (self.grid_size ** 2)) self.player_race = sc2_env.Race.terran - self.base_top_left = None - self._state = None + self.reset() def update(self, obs): - if obs.game_loop[0] < 80 and self.base_top_left is None: - - commandcenter = get_units_by_type(obs, units.Terran.CommandCenter) - - if len(commandcenter) > 0: - townhall = commandcenter[0] - self.player_race = sc2_env.Race.terran - - self.base_top_left = (townhall.x < 32) - - new_state = [] - # Adds general information from the player. - new_state.append(obs.player.minerals / 6000) - new_state.append(obs.player.vespene / 6000) - new_state.append(obs.player.food_cap / 200) - new_state.append(obs.player.food_used / 200) - new_state.append(obs.player.food_army / 200) - new_state.append(obs.player.food_workers / 200) - new_state.append((obs.player.food_cap - obs.player.food_used) / 200) - new_state.append(obs.player.army_count / 200) - new_state.append(obs.player.idle_worker_count / 200) - - # Adds information related to player's Terran units/buildings. - new_state.append(get_my_units_amount(obs, units.Terran.CommandCenter) + - get_my_units_amount(obs, units.Terran.OrbitalCommand) + - get_my_units_amount(obs, units.Terran.PlanetaryFortress) / 2) - new_state.append(get_my_units_amount(obs, units.Terran.SupplyDepot) / 18) - new_state.append(get_my_units_amount(obs, units.Terran.Refinery) / 4) - new_state.append(get_my_units_amount(obs, units.Terran.EngineeringBay)) - new_state.append(get_my_units_amount(obs, units.Terran.Armory)) - new_state.append(get_my_units_amount(obs, units.Terran.MissileTurret) / 4) - new_state.append(get_my_units_amount(obs, units.Terran.SensorTower)/1) - new_state.append(get_my_units_amount(obs, units.Terran.Bunker)/4) - new_state.append(get_my_units_amount(obs, units.Terran.FusionCore)) - new_state.append(get_my_units_amount(obs, units.Terran.GhostAcademy)) - new_state.append(get_my_units_amount(obs, units.Terran.Barracks) / 3) - new_state.append(get_my_units_amount(obs, units.Terran.Factory) / 2) - new_state.append(get_my_units_amount(obs, units.Terran.Starport) / 2) - - # Instead of making a vector for all coordnates on the map, we'll - # discretize our enemy space - # and use a 4x4 grid to store enemy positions by marking a square as 1 if - # there's any enemy on it. - - enemy_grid = np.zeros((self.grid_size, self.grid_size)) - player_grid = np.zeros((self.grid_size, self.grid_size)) - - enemy_units = [unit for unit in obs.raw_units if - unit.alliance == features.PlayerRelative.ENEMY] - player_units = [unit for unit in obs.raw_units if - unit.alliance == features.PlayerRelative.SELF] - - for i in range(0, len(enemy_units)): - y = int(math.ceil((enemy_units[i].x + 1) / 64 / self.grid_size)) - x = int(math.ceil((enemy_units[i].y + 1) / 64 / self.grid_size)) - enemy_grid[x - 1][y - 1] += 1 - - for i in range(0, len(player_units)): - y = int(math.ceil((player_units[i].x + 1) / (64 / self.grid_size))) - x = int(math.ceil((player_units[i].y + 1) / (64 / self.grid_size))) - player_grid[x - 1][y - 1] += 1 - - if not self.base_top_left: - enemy_grid = np.rot90(enemy_grid, 2) - player_grid = np.rot90(player_grid, 2) - - # Normalizing the values to always be between 0 and 1 - # (since the max amount of units in SC2 is 200) - enemy_grid = enemy_grid / 200 - player_grid = player_grid / 200 - - new_state.extend(enemy_grid.flatten()) - new_state.extend(player_grid.flatten()) + new_state = [ + # Adds general information from the player. + obs.player.minerals / 6000, + obs.player.vespene / 6000, + obs.player.food_cap / 200, + obs.player.food_used / 200, + obs.player.food_army / 200, + obs.player.food_workers / 200, + (obs.player.food_cap - obs.player.food_used) / 200, + obs.player.army_count / 200, + obs.player.idle_worker_count / 200, + # Adds information related to player's Terran units/buildings. + get_my_units_amount(obs, units.Terran.CommandCenter) + + get_my_units_amount(obs, units.Terran.OrbitalCommand) + + get_my_units_amount(obs, units.Terran.PlanetaryFortress) / 2, + get_my_units_amount(obs, units.Terran.SupplyDepot) / 18, + get_my_units_amount(obs, units.Terran.Refinery) / 4, + get_my_units_amount(obs, units.Terran.EngineeringBay), + get_my_units_amount(obs, units.Terran.Armory), + get_my_units_amount(obs, units.Terran.MissileTurret) / 4, + get_my_units_amount(obs, units.Terran.SensorTower)/1, + get_my_units_amount(obs, units.Terran.Bunker)/4, + get_my_units_amount(obs, units.Terran.FusionCore), + get_my_units_amount(obs, units.Terran.GhostAcademy), + get_my_units_amount(obs, units.Terran.Barracks) / 3, + get_my_units_amount(obs, units.Terran.Factory) / 2, + get_my_units_amount(obs, units.Terran.Starport) / 2 + ] + + new_state = append_player_and_enemy_grids(obs, new_state, self.grid_size) + final_state = np.expand_dims(new_state, axis=0) - self._state = final_state + self.state = final_state return final_state @property def state(self): - return self._state + return self.state @property def dimension(self): - return self._state_size + return self.dimension def reset(self): - self._state = None - self.base_top_left = None - -def get_my_units_amount(obs, unit_type): - return len(get_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) - -def get_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF): - return [unit for unit in obs.raw_units - if unit.unit_type == unit_type - and unit.alliance == alliance - and unit.build_progress == 100] \ No newline at end of file + self.state = None + +def append_player_and_enemy_grids(obs, new_state, grid_size): + """ Instead of making a vector for all coordnates on the map, we'll + discretize our enemy space and use a grid to store enemy positions + by marking a square as 1 if there's any enemy on it.""" + enemy_grid = np.zeros((grid_size, grid_size)) + player_grid = np.zeros((grid_size, grid_size)) + + enemy_units = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.ENEMY] + player_units = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.SELF] + + for enemy_index in range(0, len(enemy_units)): + #TODO: Fix the "64" number, probably should be the size of the map + y = int(math.ceil((enemy_units[enemy_index].x + 1) / (64 / grid_size))) + x = int(math.ceil((enemy_units[enemy_index].y + 1) / (64 / grid_size))) + enemy_grid[x - 1][y - 1] += 1 + + for player_index in range(0, len(player_units)): + #TODO: Fix the "64" number, probably should be the size of the map + y = int(math.ceil((player_units[player_index].x + 1) / (64 / grid_size))) + x = int(math.ceil((player_units[player_index].y + 1) / (64 / grid_size))) + player_grid[x - 1][y - 1] += 1 + + # Normalizing the values to always be between 0 and 1 + # (since the max amount of units in SC2 is 200) + enemy_grid = enemy_grid / 200 + player_grid = player_grid / 200 + + new_state.extend(enemy_grid.flatten()) + new_state.extend(player_grid.flatten()) + + return new_state From cbe7fc73de9358b71045e6503e27dbe2d8325751 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Fri, 21 Jun 2024 10:49:12 -0300 Subject: [PATCH 03/11] refactor: Changed TerranState to work with raw units --- urnai/sc2/states/states_utils.py | 6 +-- urnai/sc2/states/terran_state.py | 81 +++++++++++++++++--------------- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/urnai/sc2/states/states_utils.py b/urnai/sc2/states/states_utils.py index 172dc60f..c5327fa3 100644 --- a/urnai/sc2/states/states_utils.py +++ b/urnai/sc2/states/states_utils.py @@ -1,10 +1,10 @@ from pysc2.lib import features -def get_my_units_amount(obs, unit_type): - return len(get_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) +def get_my_raw_units_amount(obs, unit_type): + return len(get_raw_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) -def get_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF): +def get_raw_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF): return [unit for unit in obs.raw_units if unit.unit_type == unit_type and unit.alliance == alliance diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index 8e4c3067..a9b49b75 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -4,19 +4,22 @@ from pysc2.env import sc2_env from pysc2.lib import features, units -from urnai.sc2.states.states_utils import get_my_units_amount +from urnai.sc2.states.states_utils import get_my_raw_units_amount from urnai.states.state_base import StateBase class TerranState(StateBase): - def __init__(self, grid_size=4): + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): self.grid_size = grid_size - # Size of the state returned - # 22: number of data added to the state (number of minerals, army_count, etc) - # 2 * 4 ** 2: 2 grids of size 4 x 4, representing enemy and player units - self.dimension = int(22 + 2 * (self.grid_size ** 2)) self.player_race = sc2_env.Race.terran + self.use_raw_units = use_raw_units + self.raw_resolution = raw_resolution self.reset() def update(self, obs): @@ -30,44 +33,49 @@ def update(self, obs): obs.player.food_workers / 200, (obs.player.food_cap - obs.player.food_used) / 200, obs.player.army_count / 200, - obs.player.idle_worker_count / 200, - # Adds information related to player's Terran units/buildings. - get_my_units_amount(obs, units.Terran.CommandCenter) + - get_my_units_amount(obs, units.Terran.OrbitalCommand) + - get_my_units_amount(obs, units.Terran.PlanetaryFortress) / 2, - get_my_units_amount(obs, units.Terran.SupplyDepot) / 18, - get_my_units_amount(obs, units.Terran.Refinery) / 4, - get_my_units_amount(obs, units.Terran.EngineeringBay), - get_my_units_amount(obs, units.Terran.Armory), - get_my_units_amount(obs, units.Terran.MissileTurret) / 4, - get_my_units_amount(obs, units.Terran.SensorTower)/1, - get_my_units_amount(obs, units.Terran.Bunker)/4, - get_my_units_amount(obs, units.Terran.FusionCore), - get_my_units_amount(obs, units.Terran.GhostAcademy), - get_my_units_amount(obs, units.Terran.Barracks) / 3, - get_my_units_amount(obs, units.Terran.Factory) / 2, - get_my_units_amount(obs, units.Terran.Starport) / 2 + obs.player.idle_worker_count / 200 ] - new_state = append_player_and_enemy_grids(obs, new_state, self.grid_size) - + if(self.use_raw_units): + new_state.extend([ + # Adds information related to player's Terran units/buildings. + get_my_raw_units_amount(obs, units.Terran.CommandCenter) + + get_my_raw_units_amount(obs, units.Terran.OrbitalCommand) + + get_my_raw_units_amount(obs, units.Terran.PlanetaryFortress) / 2, + get_my_raw_units_amount(obs, units.Terran.SupplyDepot) / 18, + get_my_raw_units_amount(obs, units.Terran.Refinery) / 4, + get_my_raw_units_amount(obs, units.Terran.EngineeringBay), + get_my_raw_units_amount(obs, units.Terran.Armory), + get_my_raw_units_amount(obs, units.Terran.MissileTurret) / 4, + get_my_raw_units_amount(obs, units.Terran.SensorTower)/1, + get_my_raw_units_amount(obs, units.Terran.Bunker)/4, + get_my_raw_units_amount(obs, units.Terran.FusionCore), + get_my_raw_units_amount(obs, units.Terran.GhostAcademy_raw), + get_my_raw_units_amount(obs, units.Terran.Barracks) / 3, + get_my_raw_units_amount(obs, units.Terran.Factory) / 2, + get_my_raw_units_amount(obs, units.Terran.Starport) / 2 + ]) + new_state = append_player_and_enemy_grids( + obs, new_state, self.grid_size, self.raw_resolution) + + self._dimension = len(new_state) final_state = np.expand_dims(new_state, axis=0) - - self.state = final_state + self._state = final_state return final_state @property def state(self): - return self.state + return self._state @property def dimension(self): - return self.dimension + return self._dimension def reset(self): - self.state = None + self._state = None + self._dimension = None -def append_player_and_enemy_grids(obs, new_state, grid_size): +def append_player_and_enemy_grids(obs, new_state, grid_size, raw_resolution): """ Instead of making a vector for all coordnates on the map, we'll discretize our enemy space and use a grid to store enemy positions by marking a square as 1 if there's any enemy on it.""" @@ -78,17 +86,16 @@ def append_player_and_enemy_grids(obs, new_state, grid_size): unit.alliance == features.PlayerRelative.ENEMY] player_units = [unit for unit in obs.raw_units if unit.alliance == features.PlayerRelative.SELF] + raw_to_grid_ratio = raw_resolution / grid_size for enemy_index in range(0, len(enemy_units)): - #TODO: Fix the "64" number, probably should be the size of the map - y = int(math.ceil((enemy_units[enemy_index].x + 1) / (64 / grid_size))) - x = int(math.ceil((enemy_units[enemy_index].y + 1) / (64 / grid_size))) + y = int(math.ceil((enemy_units[enemy_index].x + 1) / raw_to_grid_ratio)) + x = int(math.ceil((enemy_units[enemy_index].y + 1) / raw_to_grid_ratio)) enemy_grid[x - 1][y - 1] += 1 for player_index in range(0, len(player_units)): - #TODO: Fix the "64" number, probably should be the size of the map - y = int(math.ceil((player_units[player_index].x + 1) / (64 / grid_size))) - x = int(math.ceil((player_units[player_index].y + 1) / (64 / grid_size))) + y = int(math.ceil((player_units[player_index].x + 1) / raw_to_grid_ratio)) + x = int(math.ceil((player_units[player_index].y + 1) / raw_to_grid_ratio)) player_grid[x - 1][y - 1] += 1 # Normalizing the values to always be between 0 and 1 From e1c2dd6540b87cf3896728ee12795d844b9c32d2 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Tue, 2 Jul 2024 14:18:56 -0300 Subject: [PATCH 04/11] test: Added test for sc2 terran state --- tests/units/sc2/states/__init__.py | 0 .../sc2/states/test_aux_methods_sc2_state.py | 218 ++++++++++++++++++ tests/units/sc2/states/test_sc2_state.py | 69 ++++++ urnai/sc2/states/states_utils.py | 36 +++ urnai/sc2/states/terran_state.py | 97 +++----- 5 files changed, 357 insertions(+), 63 deletions(-) create mode 100644 tests/units/sc2/states/__init__.py create mode 100644 tests/units/sc2/states/test_aux_methods_sc2_state.py create mode 100644 tests/units/sc2/states/test_sc2_state.py diff --git a/tests/units/sc2/states/__init__.py b/tests/units/sc2/states/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/units/sc2/states/test_aux_methods_sc2_state.py b/tests/units/sc2/states/test_aux_methods_sc2_state.py new file mode 100644 index 00000000..c29b2bf7 --- /dev/null +++ b/tests/units/sc2/states/test_aux_methods_sc2_state.py @@ -0,0 +1,218 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.states_utils import ( + append_player_and_enemy_grids, + get_my_raw_units_amount, + get_raw_units_by_type, +) + + +class TestAuxSC2State(unittest.TestCase): + + def test_append_player_and_enemy_grids(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 4, + 'build_progress': 100, + 'x': 63, + 'y': 63, + }), + ] + }) + new_state = [] + # WHEN + new_state = append_player_and_enemy_grids(obs, new_state, 3, 64) + # THEN + print(new_state) + assert len(new_state) == (18) + assert new_state[8] == 0.005 + assert new_state[9] == 0.01 + + def test_get_my_raw_units_amount(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + amount = get_my_raw_units_amount(obs, 1) + # THEN + assert amount == 1 + + def test_get_my_raw_units_amount_no_units(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [] + }) + # WHEN + amount = get_my_raw_units_amount(obs, 1) + # THEN + assert amount == 0 + + def test_get_my_raw_units_amount_no_units_of_type(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + amount = get_my_raw_units_amount(obs, 1) + # THEN + assert amount == 0 + + def test_get_my_raw_units_amount_no_units_of_alliance(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 2, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 1, + 'alliance': 2, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + amount = get_my_raw_units_amount(obs, 1) + # THEN + assert amount == 0 + + def test_get_my_raw_units_amount_no_units_of_build_progress(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 50, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 50, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + amount = get_my_raw_units_amount(obs, 1) + # THEN + assert amount == 0 + + def test_get_raw_units_by_type(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + units = get_raw_units_by_type(obs, 1) + # THEN + assert len(units) == 1 + + def test_get_raw_units_by_type_no_units(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [] + }) + # WHEN + units = get_raw_units_by_type(obs, 1) + # THEN + assert len(units) == 0 + + def test_get_raw_units_by_type_no_units_of_type(self): + # GIVEN + obs = NamedDict({ + 'raw_units': [ + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + units = get_raw_units_by_type(obs, 1) + # THEN + assert len(units) == 0 \ No newline at end of file diff --git a/tests/units/sc2/states/test_sc2_state.py b/tests/units/sc2/states/test_sc2_state.py new file mode 100644 index 00000000..add4f73b --- /dev/null +++ b/tests/units/sc2/states/test_sc2_state.py @@ -0,0 +1,69 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.terran_state import TerranState + + +class TestSC2State(unittest.TestCase): + + + def test_terran_state_no_raw(self): + # GIVEN + state = TerranState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + print(state.state) + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + + def test_terran_state_raw(self): + # GIVEN + state = TerranState(grid_size=4, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 22 + ((4 * 4) * 2) + assert len(state.state[0]) == 22 + ((4 * 4) * 2) \ No newline at end of file diff --git a/urnai/sc2/states/states_utils.py b/urnai/sc2/states/states_utils.py index c5327fa3..ed94c096 100644 --- a/urnai/sc2/states/states_utils.py +++ b/urnai/sc2/states/states_utils.py @@ -1,6 +1,42 @@ +import math + +import numpy as np from pysc2.lib import features +def append_player_and_enemy_grids(obs, new_state, grid_size, raw_resolution): + """ Instead of making a vector for all coordnates on the map, we'll + discretize our enemy space and use a grid to store enemy positions + by marking a square as 1 if there's any enemy on it.""" + enemy_grid = np.zeros((grid_size, grid_size)) + player_grid = np.zeros((grid_size, grid_size)) + + enemy_units = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.ENEMY] + player_units = [unit for unit in obs.raw_units if + unit.alliance == features.PlayerRelative.SELF] + raw_to_grid_ratio = raw_resolution / grid_size + + for enemy_index in range(0, len(enemy_units)): + y = int(math.ceil((enemy_units[enemy_index].x + 1) / raw_to_grid_ratio)) + x = int(math.ceil((enemy_units[enemy_index].y + 1) / raw_to_grid_ratio)) + enemy_grid[x - 1][y - 1] += 1 + + for player_index in range(0, len(player_units)): + y = int(math.ceil((player_units[player_index].x + 1) / raw_to_grid_ratio)) + x = int(math.ceil((player_units[player_index].y + 1) / raw_to_grid_ratio)) + player_grid[x - 1][y - 1] += 1 + + # Normalizing the values to always be between 0 and 1 + # (since the max amount of units in SC2 is 200) + enemy_grid = enemy_grid / 200 + player_grid = player_grid / 200 + + new_state.extend(enemy_grid.flatten()) + new_state.extend(player_grid.flatten()) + + return new_state + def get_my_raw_units_amount(obs, unit_type): return len(get_raw_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index a9b49b75..e9c8c0ce 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -1,21 +1,22 @@ -import math - import numpy as np from pysc2.env import sc2_env -from pysc2.lib import features, units +from pysc2.lib import units -from urnai.sc2.states.states_utils import get_my_raw_units_amount +from urnai.sc2.states.states_utils import ( + append_player_and_enemy_grids, + get_my_raw_units_amount, +) from urnai.states.state_base import StateBase class TerranState(StateBase): def __init__( - self, - grid_size: int = 4, - use_raw_units: bool = True, - raw_resolution: int = 64, - ): + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): self.grid_size = grid_size self.player_race = sc2_env.Race.terran self.use_raw_units = use_raw_units @@ -33,30 +34,33 @@ def update(self, obs): obs.player.food_workers / 200, (obs.player.food_cap - obs.player.food_used) / 200, obs.player.army_count / 200, - obs.player.idle_worker_count / 200 + obs.player.idle_worker_count / 200, ] - if(self.use_raw_units): - new_state.extend([ - # Adds information related to player's Terran units/buildings. - get_my_raw_units_amount(obs, units.Terran.CommandCenter) + - get_my_raw_units_amount(obs, units.Terran.OrbitalCommand) + - get_my_raw_units_amount(obs, units.Terran.PlanetaryFortress) / 2, - get_my_raw_units_amount(obs, units.Terran.SupplyDepot) / 18, - get_my_raw_units_amount(obs, units.Terran.Refinery) / 4, - get_my_raw_units_amount(obs, units.Terran.EngineeringBay), - get_my_raw_units_amount(obs, units.Terran.Armory), - get_my_raw_units_amount(obs, units.Terran.MissileTurret) / 4, - get_my_raw_units_amount(obs, units.Terran.SensorTower)/1, - get_my_raw_units_amount(obs, units.Terran.Bunker)/4, - get_my_raw_units_amount(obs, units.Terran.FusionCore), - get_my_raw_units_amount(obs, units.Terran.GhostAcademy_raw), - get_my_raw_units_amount(obs, units.Terran.Barracks) / 3, - get_my_raw_units_amount(obs, units.Terran.Factory) / 2, - get_my_raw_units_amount(obs, units.Terran.Starport) / 2 - ]) + if self.use_raw_units: + new_state.extend( + [ + # Adds information related to player's Terran units/buildings. + get_my_raw_units_amount(obs, units.Terran.CommandCenter) + + get_my_raw_units_amount(obs, units.Terran.OrbitalCommand) + + get_my_raw_units_amount(obs, units.Terran.PlanetaryFortress) / 2, + get_my_raw_units_amount(obs, units.Terran.SupplyDepot) / 18, + get_my_raw_units_amount(obs, units.Terran.Refinery) / 4, + get_my_raw_units_amount(obs, units.Terran.EngineeringBay), + get_my_raw_units_amount(obs, units.Terran.Armory), + get_my_raw_units_amount(obs, units.Terran.MissileTurret) / 4, + get_my_raw_units_amount(obs, units.Terran.SensorTower) / 1, + get_my_raw_units_amount(obs, units.Terran.Bunker) / 4, + get_my_raw_units_amount(obs, units.Terran.FusionCore), + get_my_raw_units_amount(obs, units.Terran.GhostAcademy), + get_my_raw_units_amount(obs, units.Terran.Barracks) / 3, + get_my_raw_units_amount(obs, units.Terran.Factory) / 2, + get_my_raw_units_amount(obs, units.Terran.Starport) / 2, + ] + ) new_state = append_player_and_enemy_grids( - obs, new_state, self.grid_size, self.raw_resolution) + obs, new_state, self.grid_size, self.raw_resolution + ) self._dimension = len(new_state) final_state = np.expand_dims(new_state, axis=0) @@ -74,36 +78,3 @@ def dimension(self): def reset(self): self._state = None self._dimension = None - -def append_player_and_enemy_grids(obs, new_state, grid_size, raw_resolution): - """ Instead of making a vector for all coordnates on the map, we'll - discretize our enemy space and use a grid to store enemy positions - by marking a square as 1 if there's any enemy on it.""" - enemy_grid = np.zeros((grid_size, grid_size)) - player_grid = np.zeros((grid_size, grid_size)) - - enemy_units = [unit for unit in obs.raw_units if - unit.alliance == features.PlayerRelative.ENEMY] - player_units = [unit for unit in obs.raw_units if - unit.alliance == features.PlayerRelative.SELF] - raw_to_grid_ratio = raw_resolution / grid_size - - for enemy_index in range(0, len(enemy_units)): - y = int(math.ceil((enemy_units[enemy_index].x + 1) / raw_to_grid_ratio)) - x = int(math.ceil((enemy_units[enemy_index].y + 1) / raw_to_grid_ratio)) - enemy_grid[x - 1][y - 1] += 1 - - for player_index in range(0, len(player_units)): - y = int(math.ceil((player_units[player_index].x + 1) / raw_to_grid_ratio)) - x = int(math.ceil((player_units[player_index].y + 1) / raw_to_grid_ratio)) - player_grid[x - 1][y - 1] += 1 - - # Normalizing the values to always be between 0 and 1 - # (since the max amount of units in SC2 is 200) - enemy_grid = enemy_grid / 200 - player_grid = player_grid / 200 - - new_state.extend(enemy_grid.flatten()) - new_state.extend(player_grid.flatten()) - - return new_state From b1d21e4a940ae7420b0bb499622af6d9195af04b Mon Sep 17 00:00:00 2001 From: RickFqt Date: Tue, 2 Jul 2024 14:39:45 -0300 Subject: [PATCH 05/11] feat: Added protoss and zerg states --- tests/units/sc2/states/test_sc2_state.py | 123 ++++++++++++++++++++++- urnai/sc2/states/protoss_state.py | 78 ++++++++++++++ urnai/sc2/states/zerg_state.py | 78 ++++++++++++++ 3 files changed, 277 insertions(+), 2 deletions(-) create mode 100644 urnai/sc2/states/protoss_state.py create mode 100644 urnai/sc2/states/zerg_state.py diff --git a/tests/units/sc2/states/test_sc2_state.py b/tests/units/sc2/states/test_sc2_state.py index add4f73b..6ca7cf59 100644 --- a/tests/units/sc2/states/test_sc2_state.py +++ b/tests/units/sc2/states/test_sc2_state.py @@ -2,7 +2,9 @@ from pysc2.lib.named_array import NamedDict +from urnai.sc2.states.protoss_state import ProtossState from urnai.sc2.states.terran_state import TerranState +from urnai.sc2.states.zerg_state import ZergState class TestSC2State(unittest.TestCase): @@ -30,7 +32,6 @@ def test_terran_state_no_raw(self): assert state.dimension == 9 assert len(state.state[0]) == 9 - def test_terran_state_raw(self): # GIVEN state = TerranState(grid_size=4, use_raw_units=True) @@ -66,4 +67,122 @@ def test_terran_state_raw(self): state.update(obs) # THEN assert state.dimension == 22 + ((4 * 4) * 2) - assert len(state.state[0]) == 22 + ((4 * 4) * 2) \ No newline at end of file + assert len(state.state[0]) == 22 + ((4 * 4) * 2) + + def test_protoss_state_no_raw(self): + # GIVEN + state = ProtossState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + print(state.state) + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_protoss_state_raw(self): + # GIVEN + state = ProtossState(grid_size=5, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 22 + ((5 * 5) * 2) + assert len(state.state[0]) == 22 + ((5 * 5) * 2) + + def test_zerg_state_no_raw(self): + # GIVEN + state = ZergState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + print(state.state) + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_zerg_state_raw(self): + # GIVEN + state = ZergState(grid_size=6, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 22 + ((6 * 6) * 2) + assert len(state.state[0]) == 22 + ((6 * 6) * 2) diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py new file mode 100644 index 00000000..f176ada2 --- /dev/null +++ b/urnai/sc2/states/protoss_state.py @@ -0,0 +1,78 @@ +import numpy as np +from pysc2.env import sc2_env +from pysc2.lib import units + +from urnai.sc2.states.states_utils import ( + append_player_and_enemy_grids, + get_my_raw_units_amount, +) +from urnai.states.state_base import StateBase + + +class ProtossState(StateBase): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + self.grid_size = grid_size + self.player_race = sc2_env.Race.protoss + self.use_raw_units = use_raw_units + self.raw_resolution = raw_resolution + self.reset() + + def update(self, obs): + new_state = [ + # Adds general information from the player. + obs.player.minerals / 6000, + obs.player.vespene / 6000, + obs.player.food_cap / 200, + obs.player.food_used / 200, + obs.player.food_army / 200, + obs.player.food_workers / 200, + (obs.player.food_cap - obs.player.food_used) / 200, + obs.player.army_count / 200, + obs.player.idle_worker_count / 200, + ] + + if self.use_raw_units: + new_state.extend( + [ + # Adds information related to player's Protoss units/buildings. + get_my_raw_units_amount(obs, units.Protoss.Nexus), + get_my_raw_units_amount(obs, units.Protoss.Pylon), + get_my_raw_units_amount(obs, units.Protoss.Assimilator), + get_my_raw_units_amount(obs, units.Protoss.Forge), + get_my_raw_units_amount(obs, units.Protoss.Gateway), + get_my_raw_units_amount(obs, units.Protoss.CyberneticsCore), + get_my_raw_units_amount(obs, units.Protoss.PhotonCannon), + get_my_raw_units_amount(obs, units.Protoss.RoboticsFacility), + get_my_raw_units_amount(obs, units.Protoss.Stargate), + get_my_raw_units_amount(obs, units.Protoss.TwilightCouncil), + get_my_raw_units_amount(obs, units.Protoss.RoboticsBay), + get_my_raw_units_amount(obs, units.Protoss.TemplarArchive), + get_my_raw_units_amount(obs, units.Protoss.DarkShrine), + ] + ) + new_state = append_player_and_enemy_grids( + obs, new_state, self.grid_size, self.raw_resolution + ) + + self._dimension = len(new_state) + final_state = np.expand_dims(new_state, axis=0) + self._state = final_state + return final_state + + @property + def state(self): + return self._state + + @property + def dimension(self): + return self._dimension + + def reset(self): + self._state = None + self._dimension = None diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py new file mode 100644 index 00000000..48045ff2 --- /dev/null +++ b/urnai/sc2/states/zerg_state.py @@ -0,0 +1,78 @@ +import numpy as np +from pysc2.env import sc2_env +from pysc2.lib import units + +from urnai.sc2.states.states_utils import ( + append_player_and_enemy_grids, + get_my_raw_units_amount, +) +from urnai.states.state_base import StateBase + + +class ZergState(StateBase): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + self.grid_size = grid_size + self.player_race = sc2_env.Race.zerg + self.use_raw_units = use_raw_units + self.raw_resolution = raw_resolution + self.reset() + + def update(self, obs): + new_state = [ + # Adds general information from the player. + obs.player.minerals / 6000, + obs.player.vespene / 6000, + obs.player.food_cap / 200, + obs.player.food_used / 200, + obs.player.food_army / 200, + obs.player.food_workers / 200, + (obs.player.food_cap - obs.player.food_used) / 200, + obs.player.army_count / 200, + obs.player.idle_worker_count / 200, + ] + + if self.use_raw_units: + new_state.extend( + [ + # Adds information related to player's Zerg units/buildings. + get_my_raw_units_amount(obs, units.Zerg.BanelingNest), + get_my_raw_units_amount(obs, units.Zerg.EvolutionChamber), + get_my_raw_units_amount(obs, units.Zerg.Extractor), + get_my_raw_units_amount(obs, units.Zerg.Hatchery), + get_my_raw_units_amount(obs, units.Zerg.HydraliskDen), + get_my_raw_units_amount(obs, units.Zerg.InfestationPit), + get_my_raw_units_amount(obs, units.Zerg.LurkerDen), + get_my_raw_units_amount(obs, units.Zerg.NydusNetwork), + get_my_raw_units_amount(obs, units.Zerg.RoachWarren), + get_my_raw_units_amount(obs, units.Zerg.SpawningPool), + get_my_raw_units_amount(obs, units.Zerg.SpineCrawler), + get_my_raw_units_amount(obs, units.Zerg.Spire), + get_my_raw_units_amount(obs, units.Zerg.SporeCrawler), + ] + ) + new_state = append_player_and_enemy_grids( + obs, new_state, self.grid_size, self.raw_resolution + ) + + self._dimension = len(new_state) + final_state = np.expand_dims(new_state, axis=0) + self._state = final_state + return final_state + + @property + def state(self): + return self._state + + @property + def dimension(self): + return self._dimension + + def reset(self): + self._state = None + self._dimension = None From 373aede364a4c1c58b07adf59a68b24a6e5918c3 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Fri, 5 Jul 2024 16:02:55 -0300 Subject: [PATCH 06/11] refactor: Added constants to improve readability --- urnai/sc2/states/protoss_state.py | 23 ++++++---- urnai/sc2/states/state_constants.py | 4 ++ urnai/sc2/states/states_utils.py | 69 ++++++++++++++++++----------- urnai/sc2/states/terran_state.py | 23 ++++++---- urnai/sc2/states/zerg_state.py | 23 ++++++---- 5 files changed, 88 insertions(+), 54 deletions(-) create mode 100644 urnai/sc2/states/state_constants.py diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py index f176ada2..31d7b3b3 100644 --- a/urnai/sc2/states/protoss_state.py +++ b/urnai/sc2/states/protoss_state.py @@ -2,6 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units +from urnai.sc2.states.state_constants import ( + MAX_MINERALS, + MAX_UNITS, + MAX_VESPENE, +) from urnai.sc2.states.states_utils import ( append_player_and_enemy_grids, get_my_raw_units_amount, @@ -26,15 +31,15 @@ def __init__( def update(self, obs): new_state = [ # Adds general information from the player. - obs.player.minerals / 6000, - obs.player.vespene / 6000, - obs.player.food_cap / 200, - obs.player.food_used / 200, - obs.player.food_army / 200, - obs.player.food_workers / 200, - (obs.player.food_cap - obs.player.food_used) / 200, - obs.player.army_count / 200, - obs.player.idle_worker_count / 200, + obs.player.minerals / MAX_MINERALS, + obs.player.vespene / MAX_VESPENE, + obs.player.food_cap / MAX_UNITS, + obs.player.food_used / MAX_UNITS, + obs.player.food_army / MAX_UNITS, + obs.player.food_workers / MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / MAX_UNITS, + obs.player.army_count / MAX_UNITS, + obs.player.idle_worker_count / MAX_UNITS, ] if self.use_raw_units: diff --git a/urnai/sc2/states/state_constants.py b/urnai/sc2/states/state_constants.py new file mode 100644 index 00000000..e02b4bbc --- /dev/null +++ b/urnai/sc2/states/state_constants.py @@ -0,0 +1,4 @@ +MAX_UNITS = 200 +MAX_MINERALS = 6000 +MAX_VESPENE = 6000 +BUILD_PROGRESS_COMPLETE = 100 \ No newline at end of file diff --git a/urnai/sc2/states/states_utils.py b/urnai/sc2/states/states_utils.py index ed94c096..f3bf46da 100644 --- a/urnai/sc2/states/states_utils.py +++ b/urnai/sc2/states/states_utils.py @@ -3,45 +3,60 @@ import numpy as np from pysc2.lib import features +from urnai.sc2.states.state_constants import BUILD_PROGRESS_COMPLETE, MAX_UNITS + + +def append_player_and_enemy_grids( + obs : list, + new_state : list, + grid_size : int, + raw_resolution : int, + ) -> list: + new_state = append_grid( + obs, new_state, grid_size, raw_resolution, features.PlayerRelative.ENEMY + ) + new_state = append_grid( + obs, new_state, grid_size, raw_resolution, features.PlayerRelative.SELF + ) + return new_state -def append_player_and_enemy_grids(obs, new_state, grid_size, raw_resolution): +def append_grid( + obs : list, + new_state : list, + grid_size : int, + raw_resolution : int, + alliance : features.PlayerRelative, + ) -> list: """ Instead of making a vector for all coordnates on the map, we'll - discretize our enemy space and use a grid to store enemy positions - by marking a square as 1 if there's any enemy on it.""" - enemy_grid = np.zeros((grid_size, grid_size)) - player_grid = np.zeros((grid_size, grid_size)) - - enemy_units = [unit for unit in obs.raw_units if - unit.alliance == features.PlayerRelative.ENEMY] - player_units = [unit for unit in obs.raw_units if - unit.alliance == features.PlayerRelative.SELF] - raw_to_grid_ratio = raw_resolution / grid_size + discretize our unit space and use a grid to store unit positions + by marking a square as 1 if there's any enemy/ally on it.""" + grid = np.zeros((grid_size, grid_size)) - for enemy_index in range(0, len(enemy_units)): - y = int(math.ceil((enemy_units[enemy_index].x + 1) / raw_to_grid_ratio)) - x = int(math.ceil((enemy_units[enemy_index].y + 1) / raw_to_grid_ratio)) - enemy_grid[x - 1][y - 1] += 1 + units = [unit for unit in obs.raw_units if + unit.alliance == alliance] + raw_to_grid_ratio = raw_resolution / grid_size - for player_index in range(0, len(player_units)): - y = int(math.ceil((player_units[player_index].x + 1) / raw_to_grid_ratio)) - x = int(math.ceil((player_units[player_index].y + 1) / raw_to_grid_ratio)) - player_grid[x - 1][y - 1] += 1 + for index in range(0, len(units)): + y = int(math.ceil((units[index].x + 1) / raw_to_grid_ratio)) + x = int(math.ceil((units[index].y + 1) / raw_to_grid_ratio)) + grid[x - 1][y - 1] += 1 # Normalizing the values to always be between 0 and 1 - # (since the max amount of units in SC2 is 200) - enemy_grid = enemy_grid / 200 - player_grid = player_grid / 200 + grid = grid / MAX_UNITS - new_state.extend(enemy_grid.flatten()) - new_state.extend(player_grid.flatten()) + new_state.extend(grid.flatten()) return new_state -def get_my_raw_units_amount(obs, unit_type): +def get_my_raw_units_amount(obs : list, unit_type : int) -> int: return len(get_raw_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) -def get_raw_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF): +def get_raw_units_by_type( + obs : list, + unit_type : int, + alliance : features.PlayerRelative = features.PlayerRelative.SELF, + ) -> list: return [unit for unit in obs.raw_units if unit.unit_type == unit_type and unit.alliance == alliance - and unit.build_progress == 100] \ No newline at end of file + and unit.build_progress == BUILD_PROGRESS_COMPLETE] \ No newline at end of file diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index e9c8c0ce..08c26f63 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -2,6 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units +from urnai.sc2.states.state_constants import ( + MAX_MINERALS, + MAX_UNITS, + MAX_VESPENE, +) from urnai.sc2.states.states_utils import ( append_player_and_enemy_grids, get_my_raw_units_amount, @@ -26,15 +31,15 @@ def __init__( def update(self, obs): new_state = [ # Adds general information from the player. - obs.player.minerals / 6000, - obs.player.vespene / 6000, - obs.player.food_cap / 200, - obs.player.food_used / 200, - obs.player.food_army / 200, - obs.player.food_workers / 200, - (obs.player.food_cap - obs.player.food_used) / 200, - obs.player.army_count / 200, - obs.player.idle_worker_count / 200, + obs.player.minerals / MAX_MINERALS, + obs.player.vespene / MAX_VESPENE, + obs.player.food_cap / MAX_UNITS, + obs.player.food_used / MAX_UNITS, + obs.player.food_army / MAX_UNITS, + obs.player.food_workers / MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / MAX_UNITS, + obs.player.army_count / MAX_UNITS, + obs.player.idle_worker_count / MAX_UNITS, ] if self.use_raw_units: diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py index 48045ff2..15114301 100644 --- a/urnai/sc2/states/zerg_state.py +++ b/urnai/sc2/states/zerg_state.py @@ -2,6 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units +from urnai.sc2.states.state_constants import ( + MAX_MINERALS, + MAX_UNITS, + MAX_VESPENE, +) from urnai.sc2.states.states_utils import ( append_player_and_enemy_grids, get_my_raw_units_amount, @@ -26,15 +31,15 @@ def __init__( def update(self, obs): new_state = [ # Adds general information from the player. - obs.player.minerals / 6000, - obs.player.vespene / 6000, - obs.player.food_cap / 200, - obs.player.food_used / 200, - obs.player.food_army / 200, - obs.player.food_workers / 200, - (obs.player.food_cap - obs.player.food_used) / 200, - obs.player.army_count / 200, - obs.player.idle_worker_count / 200, + obs.player.minerals / MAX_MINERALS, + obs.player.vespene / MAX_VESPENE, + obs.player.food_cap / MAX_UNITS, + obs.player.food_used / MAX_UNITS, + obs.player.food_army / MAX_UNITS, + obs.player.food_workers / MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / MAX_UNITS, + obs.player.army_count / MAX_UNITS, + obs.player.idle_worker_count / MAX_UNITS, ] if self.use_raw_units: From 98166c909531ee955f5a7e583872421f26ca44d2 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Thu, 11 Jul 2024 09:23:12 -0300 Subject: [PATCH 07/11] refactor: Added constants file --- urnai/constants.py | 8 ++++++++ urnai/sc2/states/protoss_state.py | 24 ++++++++++-------------- urnai/sc2/states/state_constants.py | 4 ---- urnai/sc2/states/states_utils.py | 6 +++--- urnai/sc2/states/terran_state.py | 24 ++++++++++-------------- urnai/sc2/states/zerg_state.py | 24 ++++++++++-------------- 6 files changed, 41 insertions(+), 49 deletions(-) create mode 100644 urnai/constants.py delete mode 100644 urnai/sc2/states/state_constants.py diff --git a/urnai/constants.py b/urnai/constants.py new file mode 100644 index 00000000..1be00b52 --- /dev/null +++ b/urnai/constants.py @@ -0,0 +1,8 @@ +from enum import IntEnum + + +class SC2Constants(IntEnum): + MAX_UNITS = 200 + MAX_MINERALS = 6000 + MAX_VESPENE = 6000 + BUILD_PROGRESS_COMPLETE = 100 \ No newline at end of file diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py index 31d7b3b3..159ca391 100644 --- a/urnai/sc2/states/protoss_state.py +++ b/urnai/sc2/states/protoss_state.py @@ -2,11 +2,7 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.sc2.states.state_constants import ( - MAX_MINERALS, - MAX_UNITS, - MAX_VESPENE, -) +from urnai.constants import SC2Constants from urnai.sc2.states.states_utils import ( append_player_and_enemy_grids, get_my_raw_units_amount, @@ -31,15 +27,15 @@ def __init__( def update(self, obs): new_state = [ # Adds general information from the player. - obs.player.minerals / MAX_MINERALS, - obs.player.vespene / MAX_VESPENE, - obs.player.food_cap / MAX_UNITS, - obs.player.food_used / MAX_UNITS, - obs.player.food_army / MAX_UNITS, - obs.player.food_workers / MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / MAX_UNITS, - obs.player.army_count / MAX_UNITS, - obs.player.idle_worker_count / MAX_UNITS, + obs.player.minerals / SC2Constants.MAX_MINERALS, + obs.player.vespene / SC2Constants.MAX_VESPENE, + obs.player.food_cap / SC2Constants.MAX_UNITS, + obs.player.food_used / SC2Constants.MAX_UNITS, + obs.player.food_army / SC2Constants.MAX_UNITS, + obs.player.food_workers / SC2Constants.MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, + obs.player.army_count / SC2Constants.MAX_UNITS, + obs.player.idle_worker_count / SC2Constants.MAX_UNITS, ] if self.use_raw_units: diff --git a/urnai/sc2/states/state_constants.py b/urnai/sc2/states/state_constants.py deleted file mode 100644 index e02b4bbc..00000000 --- a/urnai/sc2/states/state_constants.py +++ /dev/null @@ -1,4 +0,0 @@ -MAX_UNITS = 200 -MAX_MINERALS = 6000 -MAX_VESPENE = 6000 -BUILD_PROGRESS_COMPLETE = 100 \ No newline at end of file diff --git a/urnai/sc2/states/states_utils.py b/urnai/sc2/states/states_utils.py index f3bf46da..7f945d19 100644 --- a/urnai/sc2/states/states_utils.py +++ b/urnai/sc2/states/states_utils.py @@ -3,7 +3,7 @@ import numpy as np from pysc2.lib import features -from urnai.sc2.states.state_constants import BUILD_PROGRESS_COMPLETE, MAX_UNITS +from urnai.constants import SC2Constants def append_player_and_enemy_grids( @@ -42,7 +42,7 @@ def append_grid( grid[x - 1][y - 1] += 1 # Normalizing the values to always be between 0 and 1 - grid = grid / MAX_UNITS + grid = grid / SC2Constants.MAX_UNITS new_state.extend(grid.flatten()) @@ -59,4 +59,4 @@ def get_raw_units_by_type( return [unit for unit in obs.raw_units if unit.unit_type == unit_type and unit.alliance == alliance - and unit.build_progress == BUILD_PROGRESS_COMPLETE] \ No newline at end of file + and unit.build_progress == SC2Constants.BUILD_PROGRESS_COMPLETE] \ No newline at end of file diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index 08c26f63..851d7e50 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -2,11 +2,7 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.sc2.states.state_constants import ( - MAX_MINERALS, - MAX_UNITS, - MAX_VESPENE, -) +from urnai.constants import SC2Constants from urnai.sc2.states.states_utils import ( append_player_and_enemy_grids, get_my_raw_units_amount, @@ -31,15 +27,15 @@ def __init__( def update(self, obs): new_state = [ # Adds general information from the player. - obs.player.minerals / MAX_MINERALS, - obs.player.vespene / MAX_VESPENE, - obs.player.food_cap / MAX_UNITS, - obs.player.food_used / MAX_UNITS, - obs.player.food_army / MAX_UNITS, - obs.player.food_workers / MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / MAX_UNITS, - obs.player.army_count / MAX_UNITS, - obs.player.idle_worker_count / MAX_UNITS, + obs.player.minerals / SC2Constants.MAX_MINERALS, + obs.player.vespene / SC2Constants.MAX_VESPENE, + obs.player.food_cap / SC2Constants.MAX_UNITS, + obs.player.food_used / SC2Constants.MAX_UNITS, + obs.player.food_army / SC2Constants.MAX_UNITS, + obs.player.food_workers / SC2Constants.MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, + obs.player.army_count / SC2Constants.MAX_UNITS, + obs.player.idle_worker_count / SC2Constants.MAX_UNITS, ] if self.use_raw_units: diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py index 15114301..441073e8 100644 --- a/urnai/sc2/states/zerg_state.py +++ b/urnai/sc2/states/zerg_state.py @@ -2,11 +2,7 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.sc2.states.state_constants import ( - MAX_MINERALS, - MAX_UNITS, - MAX_VESPENE, -) +from urnai.constants import SC2Constants from urnai.sc2.states.states_utils import ( append_player_and_enemy_grids, get_my_raw_units_amount, @@ -31,15 +27,15 @@ def __init__( def update(self, obs): new_state = [ # Adds general information from the player. - obs.player.minerals / MAX_MINERALS, - obs.player.vespene / MAX_VESPENE, - obs.player.food_cap / MAX_UNITS, - obs.player.food_used / MAX_UNITS, - obs.player.food_army / MAX_UNITS, - obs.player.food_workers / MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / MAX_UNITS, - obs.player.army_count / MAX_UNITS, - obs.player.idle_worker_count / MAX_UNITS, + obs.player.minerals / SC2Constants.MAX_MINERALS, + obs.player.vespene / SC2Constants.MAX_VESPENE, + obs.player.food_cap / SC2Constants.MAX_UNITS, + obs.player.food_used / SC2Constants.MAX_UNITS, + obs.player.food_army / SC2Constants.MAX_UNITS, + obs.player.food_workers / SC2Constants.MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, + obs.player.army_count / SC2Constants.MAX_UNITS, + obs.player.idle_worker_count / SC2Constants.MAX_UNITS, ] if self.use_raw_units: From 94263bc1843132c7aafcd4425539aea385e92f65 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Fri, 12 Jul 2024 15:04:34 -0300 Subject: [PATCH 08/11] refactor: Changed files organization --- tests/units/sc2/states/test_protoss_state.py | 68 +++++++ tests/units/sc2/states/test_sc2_state.py | 188 ------------------ tests/units/sc2/states/test_terran_state.py | 68 +++++++ ...aux_methods_sc2_state.py => test_utils.py} | 24 +-- tests/units/sc2/states/test_zerg_state.py | 68 +++++++ urnai/sc2/states/protoss_state.py | 30 +-- urnai/sc2/states/terran_state.py | 34 ++-- .../sc2/states/{states_utils.py => utils.py} | 20 +- urnai/sc2/states/zerg_state.py | 30 +-- 9 files changed, 275 insertions(+), 255 deletions(-) create mode 100644 tests/units/sc2/states/test_protoss_state.py delete mode 100644 tests/units/sc2/states/test_sc2_state.py create mode 100644 tests/units/sc2/states/test_terran_state.py rename tests/units/sc2/states/{test_aux_methods_sc2_state.py => test_utils.py} (89%) create mode 100644 tests/units/sc2/states/test_zerg_state.py rename urnai/sc2/states/{states_utils.py => utils.py} (76%) diff --git a/tests/units/sc2/states/test_protoss_state.py b/tests/units/sc2/states/test_protoss_state.py new file mode 100644 index 00000000..1abab130 --- /dev/null +++ b/tests/units/sc2/states/test_protoss_state.py @@ -0,0 +1,68 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.protoss_state import ProtossState + + +class TestProtossState(unittest.TestCase): + + + def test_protoss_state_no_raw(self): + # GIVEN + state = ProtossState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + print(state.state) + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_protoss_state_raw(self): + # GIVEN + state = ProtossState(grid_size=5, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 22 + ((5 * 5) * 2) + assert len(state.state[0]) == 22 + ((5 * 5) * 2) diff --git a/tests/units/sc2/states/test_sc2_state.py b/tests/units/sc2/states/test_sc2_state.py deleted file mode 100644 index 6ca7cf59..00000000 --- a/tests/units/sc2/states/test_sc2_state.py +++ /dev/null @@ -1,188 +0,0 @@ -import unittest - -from pysc2.lib.named_array import NamedDict - -from urnai.sc2.states.protoss_state import ProtossState -from urnai.sc2.states.terran_state import TerranState -from urnai.sc2.states.zerg_state import ZergState - - -class TestSC2State(unittest.TestCase): - - - def test_terran_state_no_raw(self): - # GIVEN - state = TerranState(use_raw_units=False) - obs = NamedDict({ - 'player': NamedDict({ - 'minerals': 100, - 'vespene': 100, - 'food_cap': 200, - 'food_used': 100, - 'food_army': 50, - 'food_workers': 50, - 'army_count': 20, - 'idle_worker_count': 10, - }) - }) - # WHEN - state.update(obs) - # THEN - print(state.state) - assert state.dimension == 9 - assert len(state.state[0]) == 9 - - def test_terran_state_raw(self): - # GIVEN - state = TerranState(grid_size=4, use_raw_units=True) - obs = NamedDict({ - 'player': NamedDict({ - 'minerals': 100, - 'vespene': 100, - 'food_cap': 200, - 'food_used': 100, - 'food_army': 50, - 'food_workers': 50, - 'army_count': 20, - 'idle_worker_count': 10, - }), - 'raw_units': [ - NamedDict({ - 'unit_type': 1, - 'alliance': 1, - 'build_progress': 100, - 'x': 1, - 'y': 1, - }), - NamedDict({ - 'unit_type': 2, - 'alliance': 1, - 'build_progress': 100, - 'x': 2, - 'y': 2, - }), - ] - }) - # WHEN - state.update(obs) - # THEN - assert state.dimension == 22 + ((4 * 4) * 2) - assert len(state.state[0]) == 22 + ((4 * 4) * 2) - - def test_protoss_state_no_raw(self): - # GIVEN - state = ProtossState(use_raw_units=False) - obs = NamedDict({ - 'player': NamedDict({ - 'minerals': 100, - 'vespene': 100, - 'food_cap': 200, - 'food_used': 100, - 'food_army': 50, - 'food_workers': 50, - 'army_count': 20, - 'idle_worker_count': 10, - }) - }) - # WHEN - state.update(obs) - # THEN - print(state.state) - assert state.dimension == 9 - assert len(state.state[0]) == 9 - - def test_protoss_state_raw(self): - # GIVEN - state = ProtossState(grid_size=5, use_raw_units=True) - obs = NamedDict({ - 'player': NamedDict({ - 'minerals': 100, - 'vespene': 100, - 'food_cap': 200, - 'food_used': 100, - 'food_army': 50, - 'food_workers': 50, - 'army_count': 20, - 'idle_worker_count': 10, - }), - 'raw_units': [ - NamedDict({ - 'unit_type': 1, - 'alliance': 1, - 'build_progress': 100, - 'x': 1, - 'y': 1, - }), - NamedDict({ - 'unit_type': 2, - 'alliance': 1, - 'build_progress': 100, - 'x': 2, - 'y': 2, - }), - ] - }) - # WHEN - state.update(obs) - # THEN - assert state.dimension == 22 + ((5 * 5) * 2) - assert len(state.state[0]) == 22 + ((5 * 5) * 2) - - def test_zerg_state_no_raw(self): - # GIVEN - state = ZergState(use_raw_units=False) - obs = NamedDict({ - 'player': NamedDict({ - 'minerals': 100, - 'vespene': 100, - 'food_cap': 200, - 'food_used': 100, - 'food_army': 50, - 'food_workers': 50, - 'army_count': 20, - 'idle_worker_count': 10, - }) - }) - # WHEN - state.update(obs) - # THEN - print(state.state) - assert state.dimension == 9 - assert len(state.state[0]) == 9 - - def test_zerg_state_raw(self): - # GIVEN - state = ZergState(grid_size=6, use_raw_units=True) - obs = NamedDict({ - 'player': NamedDict({ - 'minerals': 100, - 'vespene': 100, - 'food_cap': 200, - 'food_used': 100, - 'food_army': 50, - 'food_workers': 50, - 'army_count': 20, - 'idle_worker_count': 10, - }), - 'raw_units': [ - NamedDict({ - 'unit_type': 1, - 'alliance': 1, - 'build_progress': 100, - 'x': 1, - 'y': 1, - }), - NamedDict({ - 'unit_type': 2, - 'alliance': 1, - 'build_progress': 100, - 'x': 2, - 'y': 2, - }), - ] - }) - # WHEN - state.update(obs) - # THEN - assert state.dimension == 22 + ((6 * 6) * 2) - assert len(state.state[0]) == 22 + ((6 * 6) * 2) diff --git a/tests/units/sc2/states/test_terran_state.py b/tests/units/sc2/states/test_terran_state.py new file mode 100644 index 00000000..4b8d3910 --- /dev/null +++ b/tests/units/sc2/states/test_terran_state.py @@ -0,0 +1,68 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.terran_state import TerranState + + +class TestTerranState(unittest.TestCase): + + + def test_terran_state_no_raw(self): + # GIVEN + state = TerranState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + print(state.state) + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_terran_state_raw(self): + # GIVEN + state = TerranState(grid_size=4, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 22 + ((4 * 4) * 2) + assert len(state.state[0]) == 22 + ((4 * 4) * 2) diff --git a/tests/units/sc2/states/test_aux_methods_sc2_state.py b/tests/units/sc2/states/test_utils.py similarity index 89% rename from tests/units/sc2/states/test_aux_methods_sc2_state.py rename to tests/units/sc2/states/test_utils.py index c29b2bf7..9e87edfe 100644 --- a/tests/units/sc2/states/test_aux_methods_sc2_state.py +++ b/tests/units/sc2/states/test_utils.py @@ -2,9 +2,9 @@ from pysc2.lib.named_array import NamedDict -from urnai.sc2.states.states_utils import ( +from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_my_raw_units_amount, + get_raw_units_amount, get_raw_units_by_type, ) @@ -47,7 +47,7 @@ def test_append_player_and_enemy_grids(self): assert new_state[8] == 0.005 assert new_state[9] == 0.01 - def test_get_my_raw_units_amount(self): + def test_get_raw_units_amount(self): # GIVEN obs = NamedDict({ 'raw_units': [ @@ -68,21 +68,21 @@ def test_get_my_raw_units_amount(self): ] }) # WHEN - amount = get_my_raw_units_amount(obs, 1) + amount = get_raw_units_amount(obs, 1) # THEN assert amount == 1 - def test_get_my_raw_units_amount_no_units(self): + def test_get_raw_units_amount_no_units(self): # GIVEN obs = NamedDict({ 'raw_units': [] }) # WHEN - amount = get_my_raw_units_amount(obs, 1) + amount = get_raw_units_amount(obs, 1) # THEN assert amount == 0 - def test_get_my_raw_units_amount_no_units_of_type(self): + def test_get_raw_units_amount_no_units_of_type(self): # GIVEN obs = NamedDict({ 'raw_units': [ @@ -103,11 +103,11 @@ def test_get_my_raw_units_amount_no_units_of_type(self): ] }) # WHEN - amount = get_my_raw_units_amount(obs, 1) + amount = get_raw_units_amount(obs, 1) # THEN assert amount == 0 - def test_get_my_raw_units_amount_no_units_of_alliance(self): + def test_get_raw_units_amount_no_units_of_alliance(self): # GIVEN obs = NamedDict({ 'raw_units': [ @@ -128,11 +128,11 @@ def test_get_my_raw_units_amount_no_units_of_alliance(self): ] }) # WHEN - amount = get_my_raw_units_amount(obs, 1) + amount = get_raw_units_amount(obs, 1) # THEN assert amount == 0 - def test_get_my_raw_units_amount_no_units_of_build_progress(self): + def test_get_raw_units_amount_no_units_of_build_progress(self): # GIVEN obs = NamedDict({ 'raw_units': [ @@ -153,7 +153,7 @@ def test_get_my_raw_units_amount_no_units_of_build_progress(self): ] }) # WHEN - amount = get_my_raw_units_amount(obs, 1) + amount = get_raw_units_amount(obs, 1) # THEN assert amount == 0 diff --git a/tests/units/sc2/states/test_zerg_state.py b/tests/units/sc2/states/test_zerg_state.py new file mode 100644 index 00000000..0da32622 --- /dev/null +++ b/tests/units/sc2/states/test_zerg_state.py @@ -0,0 +1,68 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.zerg_state import ZergState + + +class TestZergState(unittest.TestCase): + + + def test_zerg_state_no_raw(self): + # GIVEN + state = ZergState(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + print(state.state) + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_zerg_state_raw(self): + # GIVEN + state = ZergState(grid_size=6, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 22 + ((6 * 6) * 2) + assert len(state.state[0]) == 22 + ((6 * 6) * 2) diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py index 159ca391..38ccabd0 100644 --- a/urnai/sc2/states/protoss_state.py +++ b/urnai/sc2/states/protoss_state.py @@ -3,9 +3,9 @@ from pysc2.lib import units from urnai.constants import SC2Constants -from urnai.sc2.states.states_utils import ( +from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_my_raw_units_amount, + get_raw_units_amount, ) from urnai.states.state_base import StateBase @@ -42,19 +42,19 @@ def update(self, obs): new_state.extend( [ # Adds information related to player's Protoss units/buildings. - get_my_raw_units_amount(obs, units.Protoss.Nexus), - get_my_raw_units_amount(obs, units.Protoss.Pylon), - get_my_raw_units_amount(obs, units.Protoss.Assimilator), - get_my_raw_units_amount(obs, units.Protoss.Forge), - get_my_raw_units_amount(obs, units.Protoss.Gateway), - get_my_raw_units_amount(obs, units.Protoss.CyberneticsCore), - get_my_raw_units_amount(obs, units.Protoss.PhotonCannon), - get_my_raw_units_amount(obs, units.Protoss.RoboticsFacility), - get_my_raw_units_amount(obs, units.Protoss.Stargate), - get_my_raw_units_amount(obs, units.Protoss.TwilightCouncil), - get_my_raw_units_amount(obs, units.Protoss.RoboticsBay), - get_my_raw_units_amount(obs, units.Protoss.TemplarArchive), - get_my_raw_units_amount(obs, units.Protoss.DarkShrine), + get_raw_units_amount(obs, units.Protoss.Nexus), + get_raw_units_amount(obs, units.Protoss.Pylon), + get_raw_units_amount(obs, units.Protoss.Assimilator), + get_raw_units_amount(obs, units.Protoss.Forge), + get_raw_units_amount(obs, units.Protoss.Gateway), + get_raw_units_amount(obs, units.Protoss.CyberneticsCore), + get_raw_units_amount(obs, units.Protoss.PhotonCannon), + get_raw_units_amount(obs, units.Protoss.RoboticsFacility), + get_raw_units_amount(obs, units.Protoss.Stargate), + get_raw_units_amount(obs, units.Protoss.TwilightCouncil), + get_raw_units_amount(obs, units.Protoss.RoboticsBay), + get_raw_units_amount(obs, units.Protoss.TemplarArchive), + get_raw_units_amount(obs, units.Protoss.DarkShrine), ] ) new_state = append_player_and_enemy_grids( diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index 851d7e50..68cd7150 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -3,9 +3,9 @@ from pysc2.lib import units from urnai.constants import SC2Constants -from urnai.sc2.states.states_utils import ( +from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_my_raw_units_amount, + get_raw_units_amount, ) from urnai.states.state_base import StateBase @@ -42,21 +42,21 @@ def update(self, obs): new_state.extend( [ # Adds information related to player's Terran units/buildings. - get_my_raw_units_amount(obs, units.Terran.CommandCenter) - + get_my_raw_units_amount(obs, units.Terran.OrbitalCommand) - + get_my_raw_units_amount(obs, units.Terran.PlanetaryFortress) / 2, - get_my_raw_units_amount(obs, units.Terran.SupplyDepot) / 18, - get_my_raw_units_amount(obs, units.Terran.Refinery) / 4, - get_my_raw_units_amount(obs, units.Terran.EngineeringBay), - get_my_raw_units_amount(obs, units.Terran.Armory), - get_my_raw_units_amount(obs, units.Terran.MissileTurret) / 4, - get_my_raw_units_amount(obs, units.Terran.SensorTower) / 1, - get_my_raw_units_amount(obs, units.Terran.Bunker) / 4, - get_my_raw_units_amount(obs, units.Terran.FusionCore), - get_my_raw_units_amount(obs, units.Terran.GhostAcademy), - get_my_raw_units_amount(obs, units.Terran.Barracks) / 3, - get_my_raw_units_amount(obs, units.Terran.Factory) / 2, - get_my_raw_units_amount(obs, units.Terran.Starport) / 2, + get_raw_units_amount(obs, units.Terran.CommandCenter) + + get_raw_units_amount(obs, units.Terran.OrbitalCommand) + + get_raw_units_amount(obs, units.Terran.PlanetaryFortress) / 2, + get_raw_units_amount(obs, units.Terran.SupplyDepot) / 18, + get_raw_units_amount(obs, units.Terran.Refinery) / 4, + get_raw_units_amount(obs, units.Terran.EngineeringBay), + get_raw_units_amount(obs, units.Terran.Armory), + get_raw_units_amount(obs, units.Terran.MissileTurret) / 4, + get_raw_units_amount(obs, units.Terran.SensorTower) / 1, + get_raw_units_amount(obs, units.Terran.Bunker) / 4, + get_raw_units_amount(obs, units.Terran.FusionCore), + get_raw_units_amount(obs, units.Terran.GhostAcademy), + get_raw_units_amount(obs, units.Terran.Barracks) / 3, + get_raw_units_amount(obs, units.Terran.Factory) / 2, + get_raw_units_amount(obs, units.Terran.Starport) / 2, ] ) new_state = append_player_and_enemy_grids( diff --git a/urnai/sc2/states/states_utils.py b/urnai/sc2/states/utils.py similarity index 76% rename from urnai/sc2/states/states_utils.py rename to urnai/sc2/states/utils.py index 7f945d19..d36a3559 100644 --- a/urnai/sc2/states/states_utils.py +++ b/urnai/sc2/states/utils.py @@ -12,20 +12,20 @@ def append_player_and_enemy_grids( grid_size : int, raw_resolution : int, ) -> list: - new_state = append_grid( - obs, new_state, grid_size, raw_resolution, features.PlayerRelative.ENEMY + new_state = append_alliance_grid( + features.PlayerRelative.ENEMY, obs, new_state, grid_size, raw_resolution ) - new_state = append_grid( - obs, new_state, grid_size, raw_resolution, features.PlayerRelative.SELF + new_state = append_alliance_grid( + features.PlayerRelative.SELF, obs, new_state, grid_size, raw_resolution ) return new_state -def append_grid( +def append_alliance_grid( + alliance : features.PlayerRelative, obs : list, new_state : list, grid_size : int, raw_resolution : int, - alliance : features.PlayerRelative, ) -> list: """ Instead of making a vector for all coordnates on the map, we'll discretize our unit space and use a grid to store unit positions @@ -48,8 +48,12 @@ def append_grid( return new_state -def get_my_raw_units_amount(obs : list, unit_type : int) -> int: - return len(get_raw_units_by_type(obs, unit_type, features.PlayerRelative.SELF)) +def get_raw_units_amount( + obs : list, + unit_type : int, + alliance : features.PlayerRelative = features.PlayerRelative.SELF, + ) -> int: + return len(get_raw_units_by_type(obs, unit_type, alliance)) def get_raw_units_by_type( obs : list, diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py index 441073e8..bf905c29 100644 --- a/urnai/sc2/states/zerg_state.py +++ b/urnai/sc2/states/zerg_state.py @@ -3,9 +3,9 @@ from pysc2.lib import units from urnai.constants import SC2Constants -from urnai.sc2.states.states_utils import ( +from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_my_raw_units_amount, + get_raw_units_amount, ) from urnai.states.state_base import StateBase @@ -42,19 +42,19 @@ def update(self, obs): new_state.extend( [ # Adds information related to player's Zerg units/buildings. - get_my_raw_units_amount(obs, units.Zerg.BanelingNest), - get_my_raw_units_amount(obs, units.Zerg.EvolutionChamber), - get_my_raw_units_amount(obs, units.Zerg.Extractor), - get_my_raw_units_amount(obs, units.Zerg.Hatchery), - get_my_raw_units_amount(obs, units.Zerg.HydraliskDen), - get_my_raw_units_amount(obs, units.Zerg.InfestationPit), - get_my_raw_units_amount(obs, units.Zerg.LurkerDen), - get_my_raw_units_amount(obs, units.Zerg.NydusNetwork), - get_my_raw_units_amount(obs, units.Zerg.RoachWarren), - get_my_raw_units_amount(obs, units.Zerg.SpawningPool), - get_my_raw_units_amount(obs, units.Zerg.SpineCrawler), - get_my_raw_units_amount(obs, units.Zerg.Spire), - get_my_raw_units_amount(obs, units.Zerg.SporeCrawler), + get_raw_units_amount(obs, units.Zerg.BanelingNest), + get_raw_units_amount(obs, units.Zerg.EvolutionChamber), + get_raw_units_amount(obs, units.Zerg.Extractor), + get_raw_units_amount(obs, units.Zerg.Hatchery), + get_raw_units_amount(obs, units.Zerg.HydraliskDen), + get_raw_units_amount(obs, units.Zerg.InfestationPit), + get_raw_units_amount(obs, units.Zerg.LurkerDen), + get_raw_units_amount(obs, units.Zerg.NydusNetwork), + get_raw_units_amount(obs, units.Zerg.RoachWarren), + get_raw_units_amount(obs, units.Zerg.SpawningPool), + get_raw_units_amount(obs, units.Zerg.SpineCrawler), + get_raw_units_amount(obs, units.Zerg.Spire), + get_raw_units_amount(obs, units.Zerg.SporeCrawler), ] ) new_state = append_player_and_enemy_grids( From 4d2f0551efb7b9caa23afdd3a21d15bf74fc66d2 Mon Sep 17 00:00:00 2001 From: RickFqt Date: Fri, 12 Jul 2024 17:24:49 -0300 Subject: [PATCH 09/11] refactor: Changed units amount functions --- tests/units/sc2/states/test_protoss_state.py | 1 - tests/units/sc2/states/test_terran_state.py | 1 - tests/units/sc2/states/test_utils.py | 130 +++++-------------- tests/units/sc2/states/test_zerg_state.py | 1 - urnai/sc2/states/protoss_state.py | 30 +++-- urnai/sc2/states/terran_state.py | 34 ++--- urnai/sc2/states/utils.py | 27 ++-- urnai/sc2/states/zerg_state.py | 30 +++-- 8 files changed, 96 insertions(+), 158 deletions(-) diff --git a/tests/units/sc2/states/test_protoss_state.py b/tests/units/sc2/states/test_protoss_state.py index 1abab130..bfcc8ecf 100644 --- a/tests/units/sc2/states/test_protoss_state.py +++ b/tests/units/sc2/states/test_protoss_state.py @@ -26,7 +26,6 @@ def test_protoss_state_no_raw(self): # WHEN state.update(obs) # THEN - print(state.state) assert state.dimension == 9 assert len(state.state[0]) == 9 diff --git a/tests/units/sc2/states/test_terran_state.py b/tests/units/sc2/states/test_terran_state.py index 4b8d3910..40d46330 100644 --- a/tests/units/sc2/states/test_terran_state.py +++ b/tests/units/sc2/states/test_terran_state.py @@ -26,7 +26,6 @@ def test_terran_state_no_raw(self): # WHEN state.update(obs) # THEN - print(state.state) assert state.dimension == 9 assert len(state.state[0]) == 9 diff --git a/tests/units/sc2/states/test_utils.py b/tests/units/sc2/states/test_utils.py index 9e87edfe..78cd601a 100644 --- a/tests/units/sc2/states/test_utils.py +++ b/tests/units/sc2/states/test_utils.py @@ -4,8 +4,7 @@ from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_raw_units_amount, - get_raw_units_by_type, + create_raw_units_amount_dict, ) @@ -42,12 +41,11 @@ def test_append_player_and_enemy_grids(self): # WHEN new_state = append_player_and_enemy_grids(obs, new_state, 3, 64) # THEN - print(new_state) assert len(new_state) == (18) assert new_state[8] == 0.005 assert new_state[9] == 0.01 - def test_get_raw_units_amount(self): + def test_create_raw_units_amount_dict(self): # GIVEN obs = NamedDict({ 'raw_units': [ @@ -65,29 +63,28 @@ def test_get_raw_units_amount(self): 'x': 2, 'y': 2, }), + NamedDict({ + 'unit_type': 2, + 'alliance': 4, + 'build_progress': 100, + 'x': 63, + 'y': 63, + }), ] }) # WHEN - amount = get_raw_units_amount(obs, 1) - # THEN - assert amount == 1 - - def test_get_raw_units_amount_no_units(self): - # GIVEN - obs = NamedDict({ - 'raw_units': [] - }) - # WHEN - amount = get_raw_units_amount(obs, 1) + dict = create_raw_units_amount_dict(obs, 1) # THEN - assert amount == 0 + assert len(dict) == 2 + assert dict[1] == 1 + assert dict[2] == 1 - def test_get_raw_units_amount_no_units_of_type(self): + def test_create_raw_units_amount_dict_alliance(self): # GIVEN obs = NamedDict({ 'raw_units': [ NamedDict({ - 'unit_type': 2, + 'unit_type': 1, 'alliance': 1, 'build_progress': 100, 'x': 1, @@ -100,64 +97,34 @@ def test_get_raw_units_amount_no_units_of_type(self): 'x': 2, 'y': 2, }), - ] - }) - # WHEN - amount = get_raw_units_amount(obs, 1) - # THEN - assert amount == 0 - - def test_get_raw_units_amount_no_units_of_alliance(self): - # GIVEN - obs = NamedDict({ - 'raw_units': [ - NamedDict({ - 'unit_type': 1, - 'alliance': 2, - 'build_progress': 100, - 'x': 1, - 'y': 1, - }), NamedDict({ - 'unit_type': 1, - 'alliance': 2, + 'unit_type': 2, + 'alliance': 4, 'build_progress': 100, - 'x': 2, - 'y': 2, + 'x': 63, + 'y': 63, }), ] }) # WHEN - amount = get_raw_units_amount(obs, 1) + dict = create_raw_units_amount_dict(obs, 4) # THEN - assert amount == 0 + assert len(dict) == 1 + assert dict[1] == 0 + assert dict[2] == 1 - def test_get_raw_units_amount_no_units_of_build_progress(self): + def test_create_raw_units_amount_dict_no_units(self): # GIVEN obs = NamedDict({ - 'raw_units': [ - NamedDict({ - 'unit_type': 1, - 'alliance': 1, - 'build_progress': 50, - 'x': 1, - 'y': 1, - }), - NamedDict({ - 'unit_type': 1, - 'alliance': 1, - 'build_progress': 50, - 'x': 2, - 'y': 2, - }), - ] + 'raw_units': [] }) # WHEN - amount = get_raw_units_amount(obs, 1) + dict = create_raw_units_amount_dict(obs) # THEN - assert amount == 0 + assert len(dict) == 0 + assert dict == {} - def test_get_raw_units_by_type(self): + def test_create_raw_units_amount_dict_no_units_alliance(self): # GIVEN obs = NamedDict({ 'raw_units': [ @@ -175,44 +142,17 @@ def test_get_raw_units_by_type(self): 'x': 2, 'y': 2, }), - ] - }) - # WHEN - units = get_raw_units_by_type(obs, 1) - # THEN - assert len(units) == 1 - - def test_get_raw_units_by_type_no_units(self): - # GIVEN - obs = NamedDict({ - 'raw_units': [] - }) - # WHEN - units = get_raw_units_by_type(obs, 1) - # THEN - assert len(units) == 0 - - def test_get_raw_units_by_type_no_units_of_type(self): - # GIVEN - obs = NamedDict({ - 'raw_units': [ - NamedDict({ - 'unit_type': 2, - 'alliance': 1, - 'build_progress': 100, - 'x': 1, - 'y': 1, - }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': 4, 'build_progress': 100, - 'x': 2, - 'y': 2, + 'x': 63, + 'y': 63, }), ] }) # WHEN - units = get_raw_units_by_type(obs, 1) + dict = create_raw_units_amount_dict(obs, 3) # THEN - assert len(units) == 0 \ No newline at end of file + assert len(dict) == 0 + assert dict == {} \ No newline at end of file diff --git a/tests/units/sc2/states/test_zerg_state.py b/tests/units/sc2/states/test_zerg_state.py index 0da32622..348fdf99 100644 --- a/tests/units/sc2/states/test_zerg_state.py +++ b/tests/units/sc2/states/test_zerg_state.py @@ -26,7 +26,6 @@ def test_zerg_state_no_raw(self): # WHEN state.update(obs) # THEN - print(state.state) assert state.dimension == 9 assert len(state.state[0]) == 9 diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py index 38ccabd0..0fe6b9e1 100644 --- a/urnai/sc2/states/protoss_state.py +++ b/urnai/sc2/states/protoss_state.py @@ -5,7 +5,7 @@ from urnai.constants import SC2Constants from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_raw_units_amount, + create_raw_units_amount_dict, ) from urnai.states.state_base import StateBase @@ -39,22 +39,24 @@ def update(self, obs): ] if self.use_raw_units: + raw_units_amount_dict = create_raw_units_amount_dict( + obs, sc2_env.features.PlayerRelative.SELF) new_state.extend( [ # Adds information related to player's Protoss units/buildings. - get_raw_units_amount(obs, units.Protoss.Nexus), - get_raw_units_amount(obs, units.Protoss.Pylon), - get_raw_units_amount(obs, units.Protoss.Assimilator), - get_raw_units_amount(obs, units.Protoss.Forge), - get_raw_units_amount(obs, units.Protoss.Gateway), - get_raw_units_amount(obs, units.Protoss.CyberneticsCore), - get_raw_units_amount(obs, units.Protoss.PhotonCannon), - get_raw_units_amount(obs, units.Protoss.RoboticsFacility), - get_raw_units_amount(obs, units.Protoss.Stargate), - get_raw_units_amount(obs, units.Protoss.TwilightCouncil), - get_raw_units_amount(obs, units.Protoss.RoboticsBay), - get_raw_units_amount(obs, units.Protoss.TemplarArchive), - get_raw_units_amount(obs, units.Protoss.DarkShrine), + raw_units_amount_dict[units.Protoss.Nexus], + raw_units_amount_dict[units.Protoss.Pylon], + raw_units_amount_dict[units.Protoss.Assimilator], + raw_units_amount_dict[units.Protoss.Forge], + raw_units_amount_dict[units.Protoss.Gateway], + raw_units_amount_dict[units.Protoss.CyberneticsCore], + raw_units_amount_dict[units.Protoss.PhotonCannon], + raw_units_amount_dict[units.Protoss.RoboticsFacility], + raw_units_amount_dict[units.Protoss.Stargate], + raw_units_amount_dict[units.Protoss.TwilightCouncil], + raw_units_amount_dict[units.Protoss.RoboticsBay], + raw_units_amount_dict[units.Protoss.TemplarArchive], + raw_units_amount_dict[units.Protoss.DarkShrine], ] ) new_state = append_player_and_enemy_grids( diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index 68cd7150..283ff694 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -5,7 +5,7 @@ from urnai.constants import SC2Constants from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_raw_units_amount, + create_raw_units_amount_dict, ) from urnai.states.state_base import StateBase @@ -39,24 +39,26 @@ def update(self, obs): ] if self.use_raw_units: + raw_units_amount_dict = create_raw_units_amount_dict( + obs, sc2_env.features.PlayerRelative.SELF) new_state.extend( [ # Adds information related to player's Terran units/buildings. - get_raw_units_amount(obs, units.Terran.CommandCenter) - + get_raw_units_amount(obs, units.Terran.OrbitalCommand) - + get_raw_units_amount(obs, units.Terran.PlanetaryFortress) / 2, - get_raw_units_amount(obs, units.Terran.SupplyDepot) / 18, - get_raw_units_amount(obs, units.Terran.Refinery) / 4, - get_raw_units_amount(obs, units.Terran.EngineeringBay), - get_raw_units_amount(obs, units.Terran.Armory), - get_raw_units_amount(obs, units.Terran.MissileTurret) / 4, - get_raw_units_amount(obs, units.Terran.SensorTower) / 1, - get_raw_units_amount(obs, units.Terran.Bunker) / 4, - get_raw_units_amount(obs, units.Terran.FusionCore), - get_raw_units_amount(obs, units.Terran.GhostAcademy), - get_raw_units_amount(obs, units.Terran.Barracks) / 3, - get_raw_units_amount(obs, units.Terran.Factory) / 2, - get_raw_units_amount(obs, units.Terran.Starport) / 2, + raw_units_amount_dict[units.Terran.CommandCenter] + + raw_units_amount_dict[units.Terran.OrbitalCommand] + + raw_units_amount_dict[units.Terran.PlanetaryFortress] / 2, + raw_units_amount_dict[units.Terran.SupplyDepot] / 18, + raw_units_amount_dict[units.Terran.Refinery] / 4, + raw_units_amount_dict[units.Terran.EngineeringBay], + raw_units_amount_dict[units.Terran.Armory], + raw_units_amount_dict[units.Terran.MissileTurret] / 4, + raw_units_amount_dict[units.Terran.SensorTower] / 1, + raw_units_amount_dict[units.Terran.Bunker] / 4, + raw_units_amount_dict[units.Terran.FusionCore], + raw_units_amount_dict[units.Terran.GhostAcademy], + raw_units_amount_dict[units.Terran.Barracks] / 3, + raw_units_amount_dict[units.Terran.Factory] / 2, + raw_units_amount_dict[units.Terran.Starport] / 2, ] ) new_state = append_player_and_enemy_grids( diff --git a/urnai/sc2/states/utils.py b/urnai/sc2/states/utils.py index d36a3559..a5347ada 100644 --- a/urnai/sc2/states/utils.py +++ b/urnai/sc2/states/utils.py @@ -1,4 +1,5 @@ import math +from collections import defaultdict import numpy as np from pysc2.lib import features @@ -48,19 +49,13 @@ def append_alliance_grid( return new_state -def get_raw_units_amount( - obs : list, - unit_type : int, - alliance : features.PlayerRelative = features.PlayerRelative.SELF, - ) -> int: - return len(get_raw_units_by_type(obs, unit_type, alliance)) - -def get_raw_units_by_type( - obs : list, - unit_type : int, - alliance : features.PlayerRelative = features.PlayerRelative.SELF, - ) -> list: - return [unit for unit in obs.raw_units - if unit.unit_type == unit_type - and unit.alliance == alliance - and unit.build_progress == SC2Constants.BUILD_PROGRESS_COMPLETE] \ No newline at end of file +def create_raw_units_amount_dict( + obs : list, + alliance : features.PlayerRelative = features.PlayerRelative.SELF + ) -> defaultdict: + dict = defaultdict(lambda: 0) + for unit in obs.raw_units: + if (unit.alliance == alliance + and unit.build_progress == SC2Constants.BUILD_PROGRESS_COMPLETE): + dict[unit.unit_type] += 1 + return dict diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py index bf905c29..1448696a 100644 --- a/urnai/sc2/states/zerg_state.py +++ b/urnai/sc2/states/zerg_state.py @@ -5,7 +5,7 @@ from urnai.constants import SC2Constants from urnai.sc2.states.utils import ( append_player_and_enemy_grids, - get_raw_units_amount, + create_raw_units_amount_dict, ) from urnai.states.state_base import StateBase @@ -39,22 +39,24 @@ def update(self, obs): ] if self.use_raw_units: + raw_units_amount_dict = create_raw_units_amount_dict( + obs, sc2_env.features.PlayerRelative.SELF) new_state.extend( [ # Adds information related to player's Zerg units/buildings. - get_raw_units_amount(obs, units.Zerg.BanelingNest), - get_raw_units_amount(obs, units.Zerg.EvolutionChamber), - get_raw_units_amount(obs, units.Zerg.Extractor), - get_raw_units_amount(obs, units.Zerg.Hatchery), - get_raw_units_amount(obs, units.Zerg.HydraliskDen), - get_raw_units_amount(obs, units.Zerg.InfestationPit), - get_raw_units_amount(obs, units.Zerg.LurkerDen), - get_raw_units_amount(obs, units.Zerg.NydusNetwork), - get_raw_units_amount(obs, units.Zerg.RoachWarren), - get_raw_units_amount(obs, units.Zerg.SpawningPool), - get_raw_units_amount(obs, units.Zerg.SpineCrawler), - get_raw_units_amount(obs, units.Zerg.Spire), - get_raw_units_amount(obs, units.Zerg.SporeCrawler), + raw_units_amount_dict[units.Zerg.BanelingNest], + raw_units_amount_dict[units.Zerg.EvolutionChamber], + raw_units_amount_dict[units.Zerg.Extractor], + raw_units_amount_dict[units.Zerg.Hatchery], + raw_units_amount_dict[units.Zerg.HydraliskDen], + raw_units_amount_dict[units.Zerg.InfestationPit], + raw_units_amount_dict[units.Zerg.LurkerDen], + raw_units_amount_dict[units.Zerg.NydusNetwork], + raw_units_amount_dict[units.Zerg.RoachWarren], + raw_units_amount_dict[units.Zerg.SpawningPool], + raw_units_amount_dict[units.Zerg.SpineCrawler], + raw_units_amount_dict[units.Zerg.Spire], + raw_units_amount_dict[units.Zerg.SporeCrawler], ] ) new_state = append_player_and_enemy_grids( From 98585081b99eba5eeee5cf0f88a2ec502416306c Mon Sep 17 00:00:00 2001 From: RickFqt Date: Tue, 16 Jul 2024 11:06:24 -0300 Subject: [PATCH 10/11] feat: Added starcraft2 general state class --- tests/units/sc2/states/test_protoss_state.py | 3 + .../units/sc2/states/test_starcraft2_state.py | 67 ++++++++++++++ tests/units/sc2/states/test_terran_state.py | 3 + tests/units/sc2/states/test_utils.py | 31 ++++--- tests/units/sc2/states/test_zerg_state.py | 3 + urnai/sc2/states/protoss_state.py | 88 ++++++------------ urnai/sc2/states/starcraft2_state.py | 55 +++++++++++ urnai/sc2/states/terran_state.py | 92 ++++++------------- urnai/sc2/states/zerg_state.py | 88 ++++++------------ 9 files changed, 230 insertions(+), 200 deletions(-) create mode 100644 tests/units/sc2/states/test_starcraft2_state.py create mode 100644 urnai/sc2/states/starcraft2_state.py diff --git a/tests/units/sc2/states/test_protoss_state.py b/tests/units/sc2/states/test_protoss_state.py index bfcc8ecf..cc451e09 100644 --- a/tests/units/sc2/states/test_protoss_state.py +++ b/tests/units/sc2/states/test_protoss_state.py @@ -1,5 +1,6 @@ import unittest +from pysc2.env import sc2_env from pysc2.lib.named_array import NamedDict from urnai.sc2.states.protoss_state import ProtossState @@ -26,6 +27,7 @@ def test_protoss_state_no_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.protoss assert state.dimension == 9 assert len(state.state[0]) == 9 @@ -63,5 +65,6 @@ def test_protoss_state_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.protoss assert state.dimension == 22 + ((5 * 5) * 2) assert len(state.state[0]) == 22 + ((5 * 5) * 2) diff --git a/tests/units/sc2/states/test_starcraft2_state.py b/tests/units/sc2/states/test_starcraft2_state.py new file mode 100644 index 00000000..692821c0 --- /dev/null +++ b/tests/units/sc2/states/test_starcraft2_state.py @@ -0,0 +1,67 @@ +import unittest + +from pysc2.lib.named_array import NamedDict + +from urnai.sc2.states.starcraft2_state import StarCraft2State + + +class TestStarCraft2State(unittest.TestCase): + + + def test_starcraft2_state_no_raw(self): + # GIVEN + state = StarCraft2State(use_raw_units=False) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }) + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 9 + assert len(state.state[0]) == 9 + + def test_starcraft2_state_raw(self): + # GIVEN + state = StarCraft2State(grid_size=5, use_raw_units=True) + obs = NamedDict({ + 'player': NamedDict({ + 'minerals': 100, + 'vespene': 100, + 'food_cap': 200, + 'food_used': 100, + 'food_army': 50, + 'food_workers': 50, + 'army_count': 20, + 'idle_worker_count': 10, + }), + 'raw_units': [ + NamedDict({ + 'unit_type': 1, + 'alliance': 1, + 'build_progress': 100, + 'x': 1, + 'y': 1, + }), + NamedDict({ + 'unit_type': 2, + 'alliance': 1, + 'build_progress': 100, + 'x': 2, + 'y': 2, + }), + ] + }) + # WHEN + state.update(obs) + # THEN + assert state.dimension == 9 + ((5 * 5) * 2) + assert len(state.state[0]) == 9 + ((5 * 5) * 2) diff --git a/tests/units/sc2/states/test_terran_state.py b/tests/units/sc2/states/test_terran_state.py index 40d46330..eb758dbe 100644 --- a/tests/units/sc2/states/test_terran_state.py +++ b/tests/units/sc2/states/test_terran_state.py @@ -1,5 +1,6 @@ import unittest +from pysc2.env import sc2_env from pysc2.lib.named_array import NamedDict from urnai.sc2.states.terran_state import TerranState @@ -26,6 +27,7 @@ def test_terran_state_no_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.terran assert state.dimension == 9 assert len(state.state[0]) == 9 @@ -63,5 +65,6 @@ def test_terran_state_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.terran assert state.dimension == 22 + ((4 * 4) * 2) assert len(state.state[0]) == 22 + ((4 * 4) * 2) diff --git a/tests/units/sc2/states/test_utils.py b/tests/units/sc2/states/test_utils.py index 78cd601a..eb238fac 100644 --- a/tests/units/sc2/states/test_utils.py +++ b/tests/units/sc2/states/test_utils.py @@ -1,5 +1,6 @@ import unittest +from pysc2.lib import features from pysc2.lib.named_array import NamedDict from urnai.sc2.states.utils import ( @@ -16,21 +17,21 @@ def test_append_player_and_enemy_grids(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -51,21 +52,21 @@ def test_create_raw_units_amount_dict(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -73,7 +74,7 @@ def test_create_raw_units_amount_dict(self): ] }) # WHEN - dict = create_raw_units_amount_dict(obs, 1) + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.SELF) # THEN assert len(dict) == 2 assert dict[1] == 1 @@ -85,21 +86,21 @@ def test_create_raw_units_amount_dict_alliance(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -107,7 +108,7 @@ def test_create_raw_units_amount_dict_alliance(self): ] }) # WHEN - dict = create_raw_units_amount_dict(obs, 4) + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.ENEMY) # THEN assert len(dict) == 1 assert dict[1] == 0 @@ -130,21 +131,21 @@ def test_create_raw_units_amount_dict_no_units_alliance(self): 'raw_units': [ NamedDict({ 'unit_type': 1, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 1, 'y': 1, }), NamedDict({ 'unit_type': 2, - 'alliance': 1, + 'alliance': features.PlayerRelative.SELF, 'build_progress': 100, 'x': 2, 'y': 2, }), NamedDict({ 'unit_type': 2, - 'alliance': 4, + 'alliance': features.PlayerRelative.ENEMY, 'build_progress': 100, 'x': 63, 'y': 63, @@ -152,7 +153,7 @@ def test_create_raw_units_amount_dict_no_units_alliance(self): ] }) # WHEN - dict = create_raw_units_amount_dict(obs, 3) + dict = create_raw_units_amount_dict(obs, features.PlayerRelative.NEUTRAL) # THEN assert len(dict) == 0 assert dict == {} \ No newline at end of file diff --git a/tests/units/sc2/states/test_zerg_state.py b/tests/units/sc2/states/test_zerg_state.py index 348fdf99..c1e25f05 100644 --- a/tests/units/sc2/states/test_zerg_state.py +++ b/tests/units/sc2/states/test_zerg_state.py @@ -1,5 +1,6 @@ import unittest +from pysc2.env import sc2_env from pysc2.lib.named_array import NamedDict from urnai.sc2.states.zerg_state import ZergState @@ -26,6 +27,7 @@ def test_zerg_state_no_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.zerg assert state.dimension == 9 assert len(state.state[0]) == 9 @@ -63,5 +65,6 @@ def test_zerg_state_raw(self): # WHEN state.update(obs) # THEN + assert state.player_race == sc2_env.Race.zerg assert state.dimension == 22 + ((6 * 6) * 2) assert len(state.state[0]) == 22 + ((6 * 6) * 2) diff --git a/urnai/sc2/states/protoss_state.py b/urnai/sc2/states/protoss_state.py index 0fe6b9e1..966e12ed 100644 --- a/urnai/sc2/states/protoss_state.py +++ b/urnai/sc2/states/protoss_state.py @@ -2,15 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.constants import SC2Constants -from urnai.sc2.states.utils import ( - append_player_and_enemy_grids, - create_raw_units_amount_dict, -) -from urnai.states.state_base import StateBase +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict -class ProtossState(StateBase): +class ProtossState(StarCraft2State): def __init__( self, @@ -18,64 +14,34 @@ def __init__( use_raw_units: bool = True, raw_resolution: int = 64, ): - self.grid_size = grid_size + super().__init__(grid_size, use_raw_units, raw_resolution) self.player_race = sc2_env.Race.protoss - self.use_raw_units = use_raw_units - self.raw_resolution = raw_resolution - self.reset() def update(self, obs): - new_state = [ - # Adds general information from the player. - obs.player.minerals / SC2Constants.MAX_MINERALS, - obs.player.vespene / SC2Constants.MAX_VESPENE, - obs.player.food_cap / SC2Constants.MAX_UNITS, - obs.player.food_used / SC2Constants.MAX_UNITS, - obs.player.food_army / SC2Constants.MAX_UNITS, - obs.player.food_workers / SC2Constants.MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, - obs.player.army_count / SC2Constants.MAX_UNITS, - obs.player.idle_worker_count / SC2Constants.MAX_UNITS, - ] + state = super().update(obs) if self.use_raw_units: raw_units_amount_dict = create_raw_units_amount_dict( obs, sc2_env.features.PlayerRelative.SELF) - new_state.extend( - [ - # Adds information related to player's Protoss units/buildings. - raw_units_amount_dict[units.Protoss.Nexus], - raw_units_amount_dict[units.Protoss.Pylon], - raw_units_amount_dict[units.Protoss.Assimilator], - raw_units_amount_dict[units.Protoss.Forge], - raw_units_amount_dict[units.Protoss.Gateway], - raw_units_amount_dict[units.Protoss.CyberneticsCore], - raw_units_amount_dict[units.Protoss.PhotonCannon], - raw_units_amount_dict[units.Protoss.RoboticsFacility], - raw_units_amount_dict[units.Protoss.Stargate], - raw_units_amount_dict[units.Protoss.TwilightCouncil], - raw_units_amount_dict[units.Protoss.RoboticsBay], - raw_units_amount_dict[units.Protoss.TemplarArchive], - raw_units_amount_dict[units.Protoss.DarkShrine], - ] - ) - new_state = append_player_and_enemy_grids( - obs, new_state, self.grid_size, self.raw_resolution - ) - - self._dimension = len(new_state) - final_state = np.expand_dims(new_state, axis=0) - self._state = final_state - return final_state - - @property - def state(self): - return self._state - - @property - def dimension(self): - return self._dimension - - def reset(self): - self._state = None - self._dimension = None + units_amount_info = [ + raw_units_amount_dict[units.Protoss.Nexus], + raw_units_amount_dict[units.Protoss.Pylon], + raw_units_amount_dict[units.Protoss.Assimilator], + raw_units_amount_dict[units.Protoss.Forge], + raw_units_amount_dict[units.Protoss.Gateway], + raw_units_amount_dict[units.Protoss.CyberneticsCore], + raw_units_amount_dict[units.Protoss.PhotonCannon], + raw_units_amount_dict[units.Protoss.RoboticsFacility], + raw_units_amount_dict[units.Protoss.Stargate], + raw_units_amount_dict[units.Protoss.TwilightCouncil], + raw_units_amount_dict[units.Protoss.RoboticsBay], + raw_units_amount_dict[units.Protoss.TemplarArchive], + raw_units_amount_dict[units.Protoss.DarkShrine], + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file diff --git a/urnai/sc2/states/starcraft2_state.py b/urnai/sc2/states/starcraft2_state.py new file mode 100644 index 00000000..dc0b987e --- /dev/null +++ b/urnai/sc2/states/starcraft2_state.py @@ -0,0 +1,55 @@ +import numpy as np + +from urnai.constants import SC2Constants +from urnai.sc2.states.utils import append_player_and_enemy_grids +from urnai.states.state_base import StateBase + + +class StarCraft2State(StateBase): + + def __init__( + self, + grid_size: int = 4, + use_raw_units: bool = True, + raw_resolution: int = 64, + ): + self.grid_size = grid_size + self.use_raw_units = use_raw_units + self.raw_resolution = raw_resolution + self.reset() + + def update(self, obs): + new_state = [ + # Adds general information from the player. + obs.player.minerals / SC2Constants.MAX_MINERALS, + obs.player.vespene / SC2Constants.MAX_VESPENE, + obs.player.food_cap / SC2Constants.MAX_UNITS, + obs.player.food_used / SC2Constants.MAX_UNITS, + obs.player.food_army / SC2Constants.MAX_UNITS, + obs.player.food_workers / SC2Constants.MAX_UNITS, + (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, + obs.player.army_count / SC2Constants.MAX_UNITS, + obs.player.idle_worker_count / SC2Constants.MAX_UNITS, + ] + + if self.use_raw_units: + new_state = append_player_and_enemy_grids( + obs, new_state, self.grid_size, self.raw_resolution + ) + + self._dimension = len(new_state) + final_state = np.expand_dims(new_state, axis=0) + self._state = final_state + return final_state + + @property + def state(self): + return self._state + + @property + def dimension(self): + return self._dimension + + def reset(self): + self._state = None + self._dimension = None diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index 283ff694..0f6ec8e7 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -2,15 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.constants import SC2Constants -from urnai.sc2.states.utils import ( - append_player_and_enemy_grids, - create_raw_units_amount_dict, -) -from urnai.states.state_base import StateBase +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict -class TerranState(StateBase): +class TerranState(StarCraft2State): def __init__( self, @@ -18,66 +14,36 @@ def __init__( use_raw_units: bool = True, raw_resolution: int = 64, ): - self.grid_size = grid_size + super().__init__(grid_size, use_raw_units, raw_resolution) self.player_race = sc2_env.Race.terran - self.use_raw_units = use_raw_units - self.raw_resolution = raw_resolution - self.reset() def update(self, obs): - new_state = [ - # Adds general information from the player. - obs.player.minerals / SC2Constants.MAX_MINERALS, - obs.player.vespene / SC2Constants.MAX_VESPENE, - obs.player.food_cap / SC2Constants.MAX_UNITS, - obs.player.food_used / SC2Constants.MAX_UNITS, - obs.player.food_army / SC2Constants.MAX_UNITS, - obs.player.food_workers / SC2Constants.MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, - obs.player.army_count / SC2Constants.MAX_UNITS, - obs.player.idle_worker_count / SC2Constants.MAX_UNITS, - ] + state = super().update(obs) if self.use_raw_units: raw_units_amount_dict = create_raw_units_amount_dict( obs, sc2_env.features.PlayerRelative.SELF) - new_state.extend( - [ - # Adds information related to player's Terran units/buildings. - raw_units_amount_dict[units.Terran.CommandCenter] - + raw_units_amount_dict[units.Terran.OrbitalCommand] - + raw_units_amount_dict[units.Terran.PlanetaryFortress] / 2, - raw_units_amount_dict[units.Terran.SupplyDepot] / 18, - raw_units_amount_dict[units.Terran.Refinery] / 4, - raw_units_amount_dict[units.Terran.EngineeringBay], - raw_units_amount_dict[units.Terran.Armory], - raw_units_amount_dict[units.Terran.MissileTurret] / 4, - raw_units_amount_dict[units.Terran.SensorTower] / 1, - raw_units_amount_dict[units.Terran.Bunker] / 4, - raw_units_amount_dict[units.Terran.FusionCore], - raw_units_amount_dict[units.Terran.GhostAcademy], - raw_units_amount_dict[units.Terran.Barracks] / 3, - raw_units_amount_dict[units.Terran.Factory] / 2, - raw_units_amount_dict[units.Terran.Starport] / 2, - ] - ) - new_state = append_player_and_enemy_grids( - obs, new_state, self.grid_size, self.raw_resolution - ) - - self._dimension = len(new_state) - final_state = np.expand_dims(new_state, axis=0) - self._state = final_state - return final_state - - @property - def state(self): - return self._state - - @property - def dimension(self): - return self._dimension - - def reset(self): - self._state = None - self._dimension = None + units_amount_info = [ + raw_units_amount_dict[units.Terran.CommandCenter] + + raw_units_amount_dict[units.Terran.OrbitalCommand] + + raw_units_amount_dict[units.Terran.PlanetaryFortress] / 2, + raw_units_amount_dict[units.Terran.SupplyDepot] / 18, + raw_units_amount_dict[units.Terran.Refinery] / 4, + raw_units_amount_dict[units.Terran.EngineeringBay], + raw_units_amount_dict[units.Terran.Armory], + raw_units_amount_dict[units.Terran.MissileTurret] / 4, + raw_units_amount_dict[units.Terran.SensorTower] / 1, + raw_units_amount_dict[units.Terran.Bunker] / 4, + raw_units_amount_dict[units.Terran.FusionCore], + raw_units_amount_dict[units.Terran.GhostAcademy], + raw_units_amount_dict[units.Terran.Barracks] / 3, + raw_units_amount_dict[units.Terran.Factory] / 2, + raw_units_amount_dict[units.Terran.Starport] / 2, + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file diff --git a/urnai/sc2/states/zerg_state.py b/urnai/sc2/states/zerg_state.py index 1448696a..2443cb1a 100644 --- a/urnai/sc2/states/zerg_state.py +++ b/urnai/sc2/states/zerg_state.py @@ -2,15 +2,11 @@ from pysc2.env import sc2_env from pysc2.lib import units -from urnai.constants import SC2Constants -from urnai.sc2.states.utils import ( - append_player_and_enemy_grids, - create_raw_units_amount_dict, -) -from urnai.states.state_base import StateBase +from urnai.sc2.states.starcraft2_state import StarCraft2State +from urnai.sc2.states.utils import create_raw_units_amount_dict -class ZergState(StateBase): +class ZergState(StarCraft2State): def __init__( self, @@ -18,64 +14,34 @@ def __init__( use_raw_units: bool = True, raw_resolution: int = 64, ): - self.grid_size = grid_size + super().__init__(grid_size, use_raw_units, raw_resolution) self.player_race = sc2_env.Race.zerg - self.use_raw_units = use_raw_units - self.raw_resolution = raw_resolution - self.reset() def update(self, obs): - new_state = [ - # Adds general information from the player. - obs.player.minerals / SC2Constants.MAX_MINERALS, - obs.player.vespene / SC2Constants.MAX_VESPENE, - obs.player.food_cap / SC2Constants.MAX_UNITS, - obs.player.food_used / SC2Constants.MAX_UNITS, - obs.player.food_army / SC2Constants.MAX_UNITS, - obs.player.food_workers / SC2Constants.MAX_UNITS, - (obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS, - obs.player.army_count / SC2Constants.MAX_UNITS, - obs.player.idle_worker_count / SC2Constants.MAX_UNITS, - ] + state = super().update(obs) if self.use_raw_units: raw_units_amount_dict = create_raw_units_amount_dict( obs, sc2_env.features.PlayerRelative.SELF) - new_state.extend( - [ - # Adds information related to player's Zerg units/buildings. - raw_units_amount_dict[units.Zerg.BanelingNest], - raw_units_amount_dict[units.Zerg.EvolutionChamber], - raw_units_amount_dict[units.Zerg.Extractor], - raw_units_amount_dict[units.Zerg.Hatchery], - raw_units_amount_dict[units.Zerg.HydraliskDen], - raw_units_amount_dict[units.Zerg.InfestationPit], - raw_units_amount_dict[units.Zerg.LurkerDen], - raw_units_amount_dict[units.Zerg.NydusNetwork], - raw_units_amount_dict[units.Zerg.RoachWarren], - raw_units_amount_dict[units.Zerg.SpawningPool], - raw_units_amount_dict[units.Zerg.SpineCrawler], - raw_units_amount_dict[units.Zerg.Spire], - raw_units_amount_dict[units.Zerg.SporeCrawler], - ] - ) - new_state = append_player_and_enemy_grids( - obs, new_state, self.grid_size, self.raw_resolution - ) - - self._dimension = len(new_state) - final_state = np.expand_dims(new_state, axis=0) - self._state = final_state - return final_state - - @property - def state(self): - return self._state - - @property - def dimension(self): - return self._dimension - - def reset(self): - self._state = None - self._dimension = None + units_amount_info = [ + raw_units_amount_dict[units.Zerg.BanelingNest], + raw_units_amount_dict[units.Zerg.EvolutionChamber], + raw_units_amount_dict[units.Zerg.Extractor], + raw_units_amount_dict[units.Zerg.Hatchery], + raw_units_amount_dict[units.Zerg.HydraliskDen], + raw_units_amount_dict[units.Zerg.InfestationPit], + raw_units_amount_dict[units.Zerg.LurkerDen], + raw_units_amount_dict[units.Zerg.NydusNetwork], + raw_units_amount_dict[units.Zerg.RoachWarren], + raw_units_amount_dict[units.Zerg.SpawningPool], + raw_units_amount_dict[units.Zerg.SpineCrawler], + raw_units_amount_dict[units.Zerg.Spire], + raw_units_amount_dict[units.Zerg.SporeCrawler], + ] + state = np.squeeze(state) + state = np.append(state, units_amount_info) + self._dimension = len(state) + state = np.expand_dims(state, axis=0) + self._state = state + + return state \ No newline at end of file From ca80acf06c5446f7602b65575e16999d0db8a64b Mon Sep 17 00:00:00 2001 From: RickFqt Date: Tue, 16 Jul 2024 17:56:26 -0300 Subject: [PATCH 11/11] refactor: Removed magic numbers from terran state --- urnai/sc2/states/terran_state.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/urnai/sc2/states/terran_state.py b/urnai/sc2/states/terran_state.py index 0f6ec8e7..285a8d45 100644 --- a/urnai/sc2/states/terran_state.py +++ b/urnai/sc2/states/terran_state.py @@ -26,19 +26,19 @@ def update(self, obs): units_amount_info = [ raw_units_amount_dict[units.Terran.CommandCenter] + raw_units_amount_dict[units.Terran.OrbitalCommand] - + raw_units_amount_dict[units.Terran.PlanetaryFortress] / 2, - raw_units_amount_dict[units.Terran.SupplyDepot] / 18, - raw_units_amount_dict[units.Terran.Refinery] / 4, + + raw_units_amount_dict[units.Terran.PlanetaryFortress], + raw_units_amount_dict[units.Terran.SupplyDepot], + raw_units_amount_dict[units.Terran.Refinery], raw_units_amount_dict[units.Terran.EngineeringBay], raw_units_amount_dict[units.Terran.Armory], - raw_units_amount_dict[units.Terran.MissileTurret] / 4, - raw_units_amount_dict[units.Terran.SensorTower] / 1, - raw_units_amount_dict[units.Terran.Bunker] / 4, + raw_units_amount_dict[units.Terran.MissileTurret], + raw_units_amount_dict[units.Terran.SensorTower], + raw_units_amount_dict[units.Terran.Bunker], raw_units_amount_dict[units.Terran.FusionCore], raw_units_amount_dict[units.Terran.GhostAcademy], - raw_units_amount_dict[units.Terran.Barracks] / 3, - raw_units_amount_dict[units.Terran.Factory] / 2, - raw_units_amount_dict[units.Terran.Starport] / 2, + raw_units_amount_dict[units.Terran.Barracks], + raw_units_amount_dict[units.Terran.Factory], + raw_units_amount_dict[units.Terran.Starport], ] state = np.squeeze(state) state = np.append(state, units_amount_info)