Skip to content

Commit

Permalink
feat: Added starcraft2 general state class
Browse files Browse the repository at this point in the history
  • Loading branch information
RickFqt committed Jul 16, 2024
1 parent 4d2f055 commit 9858508
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 200 deletions.
3 changes: 3 additions & 0 deletions tests/units/sc2/states/test_protoss_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

from pysc2.env import sc2_env
from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.protoss_state import ProtossState
Expand All @@ -26,6 +27,7 @@ def test_protoss_state_no_raw(self):
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.protoss
assert state.dimension == 9
assert len(state.state[0]) == 9

Expand Down Expand Up @@ -63,5 +65,6 @@ def test_protoss_state_raw(self):
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.protoss
assert state.dimension == 22 + ((5 * 5) * 2)
assert len(state.state[0]) == 22 + ((5 * 5) * 2)
67 changes: 67 additions & 0 deletions tests/units/sc2/states/test_starcraft2_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.starcraft2_state import StarCraft2State


class TestStarCraft2State(unittest.TestCase):


def test_starcraft2_state_no_raw(self):
# GIVEN
state = StarCraft2State(use_raw_units=False)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
})
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 9
assert len(state.state[0]) == 9

def test_starcraft2_state_raw(self):
# GIVEN
state = StarCraft2State(grid_size=5, use_raw_units=True)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
}),
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
]
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 9 + ((5 * 5) * 2)
assert len(state.state[0]) == 9 + ((5 * 5) * 2)
3 changes: 3 additions & 0 deletions tests/units/sc2/states/test_terran_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

from pysc2.env import sc2_env
from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.terran_state import TerranState
Expand All @@ -26,6 +27,7 @@ def test_terran_state_no_raw(self):
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.terran
assert state.dimension == 9
assert len(state.state[0]) == 9

Expand Down Expand Up @@ -63,5 +65,6 @@ def test_terran_state_raw(self):
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.terran
assert state.dimension == 22 + ((4 * 4) * 2)
assert len(state.state[0]) == 22 + ((4 * 4) * 2)
31 changes: 16 additions & 15 deletions tests/units/sc2/states/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

from pysc2.lib import features
from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.utils import (
Expand All @@ -16,21 +17,21 @@ def test_append_player_and_enemy_grids(self):
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
Expand All @@ -51,29 +52,29 @@ def test_create_raw_units_amount_dict(self):
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, 1)
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.SELF)
# THEN
assert len(dict) == 2
assert dict[1] == 1
Expand All @@ -85,29 +86,29 @@ def test_create_raw_units_amount_dict_alliance(self):
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, 4)
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.ENEMY)
# THEN
assert len(dict) == 1
assert dict[1] == 0
Expand All @@ -130,29 +131,29 @@ def test_create_raw_units_amount_dict_no_units_alliance(self):
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, 3)
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.NEUTRAL)
# THEN
assert len(dict) == 0
assert dict == {}
3 changes: 3 additions & 0 deletions tests/units/sc2/states/test_zerg_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

from pysc2.env import sc2_env
from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.zerg_state import ZergState
Expand All @@ -26,6 +27,7 @@ def test_zerg_state_no_raw(self):
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.zerg
assert state.dimension == 9
assert len(state.state[0]) == 9

Expand Down Expand Up @@ -63,5 +65,6 @@ def test_zerg_state_raw(self):
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.zerg
assert state.dimension == 22 + ((6 * 6) * 2)
assert len(state.state[0]) == 22 + ((6 * 6) * 2)
88 changes: 27 additions & 61 deletions urnai/sc2/states/protoss_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,80 +2,46 @@
from pysc2.env import sc2_env
from pysc2.lib import units

from urnai.constants import SC2Constants
from urnai.sc2.states.utils import (
append_player_and_enemy_grids,
create_raw_units_amount_dict,
)
from urnai.states.state_base import StateBase
from urnai.sc2.states.starcraft2_state import StarCraft2State
from urnai.sc2.states.utils import create_raw_units_amount_dict


class ProtossState(StateBase):
class ProtossState(StarCraft2State):

def __init__(
self,
grid_size: int = 4,
use_raw_units: bool = True,
raw_resolution: int = 64,
):
self.grid_size = grid_size
super().__init__(grid_size, use_raw_units, raw_resolution)
self.player_race = sc2_env.Race.protoss
self.use_raw_units = use_raw_units
self.raw_resolution = raw_resolution
self.reset()

def update(self, obs):
new_state = [
# Adds general information from the player.
obs.player.minerals / SC2Constants.MAX_MINERALS,
obs.player.vespene / SC2Constants.MAX_VESPENE,
obs.player.food_cap / SC2Constants.MAX_UNITS,
obs.player.food_used / SC2Constants.MAX_UNITS,
obs.player.food_army / SC2Constants.MAX_UNITS,
obs.player.food_workers / SC2Constants.MAX_UNITS,
(obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS,
obs.player.army_count / SC2Constants.MAX_UNITS,
obs.player.idle_worker_count / SC2Constants.MAX_UNITS,
]
state = super().update(obs)

if self.use_raw_units:
raw_units_amount_dict = create_raw_units_amount_dict(
obs, sc2_env.features.PlayerRelative.SELF)
new_state.extend(
[
# Adds information related to player's Protoss units/buildings.
raw_units_amount_dict[units.Protoss.Nexus],
raw_units_amount_dict[units.Protoss.Pylon],
raw_units_amount_dict[units.Protoss.Assimilator],
raw_units_amount_dict[units.Protoss.Forge],
raw_units_amount_dict[units.Protoss.Gateway],
raw_units_amount_dict[units.Protoss.CyberneticsCore],
raw_units_amount_dict[units.Protoss.PhotonCannon],
raw_units_amount_dict[units.Protoss.RoboticsFacility],
raw_units_amount_dict[units.Protoss.Stargate],
raw_units_amount_dict[units.Protoss.TwilightCouncil],
raw_units_amount_dict[units.Protoss.RoboticsBay],
raw_units_amount_dict[units.Protoss.TemplarArchive],
raw_units_amount_dict[units.Protoss.DarkShrine],
]
)
new_state = append_player_and_enemy_grids(
obs, new_state, self.grid_size, self.raw_resolution
)

self._dimension = len(new_state)
final_state = np.expand_dims(new_state, axis=0)
self._state = final_state
return final_state

@property
def state(self):
return self._state

@property
def dimension(self):
return self._dimension

def reset(self):
self._state = None
self._dimension = None
units_amount_info = [
raw_units_amount_dict[units.Protoss.Nexus],
raw_units_amount_dict[units.Protoss.Pylon],
raw_units_amount_dict[units.Protoss.Assimilator],
raw_units_amount_dict[units.Protoss.Forge],
raw_units_amount_dict[units.Protoss.Gateway],
raw_units_amount_dict[units.Protoss.CyberneticsCore],
raw_units_amount_dict[units.Protoss.PhotonCannon],
raw_units_amount_dict[units.Protoss.RoboticsFacility],
raw_units_amount_dict[units.Protoss.Stargate],
raw_units_amount_dict[units.Protoss.TwilightCouncil],
raw_units_amount_dict[units.Protoss.RoboticsBay],
raw_units_amount_dict[units.Protoss.TemplarArchive],
raw_units_amount_dict[units.Protoss.DarkShrine],
]
state = np.squeeze(state)
state = np.append(state, units_amount_info)
self._dimension = len(state)
state = np.expand_dims(state, axis=0)
self._state = state

return state
Loading

0 comments on commit 9858508

Please sign in to comment.