Skip to content

Commit 9858508

Browse files
committed
feat: Added starcraft2 general state class
1 parent 4d2f055 commit 9858508

File tree

9 files changed

+230
-200
lines changed

9 files changed

+230
-200
lines changed

tests/units/sc2/states/test_protoss_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
from pysc2.env import sc2_env
34
from pysc2.lib.named_array import NamedDict
45

56
from urnai.sc2.states.protoss_state import ProtossState
@@ -26,6 +27,7 @@ def test_protoss_state_no_raw(self):
2627
# WHEN
2728
state.update(obs)
2829
# THEN
30+
assert state.player_race == sc2_env.Race.protoss
2931
assert state.dimension == 9
3032
assert len(state.state[0]) == 9
3133

@@ -63,5 +65,6 @@ def test_protoss_state_raw(self):
6365
# WHEN
6466
state.update(obs)
6567
# THEN
68+
assert state.player_race == sc2_env.Race.protoss
6669
assert state.dimension == 22 + ((5 * 5) * 2)
6770
assert len(state.state[0]) == 22 + ((5 * 5) * 2)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import unittest
2+
3+
from pysc2.lib.named_array import NamedDict
4+
5+
from urnai.sc2.states.starcraft2_state import StarCraft2State
6+
7+
8+
class TestStarCraft2State(unittest.TestCase):
9+
10+
11+
def test_starcraft2_state_no_raw(self):
12+
# GIVEN
13+
state = StarCraft2State(use_raw_units=False)
14+
obs = NamedDict({
15+
'player': NamedDict({
16+
'minerals': 100,
17+
'vespene': 100,
18+
'food_cap': 200,
19+
'food_used': 100,
20+
'food_army': 50,
21+
'food_workers': 50,
22+
'army_count': 20,
23+
'idle_worker_count': 10,
24+
})
25+
})
26+
# WHEN
27+
state.update(obs)
28+
# THEN
29+
assert state.dimension == 9
30+
assert len(state.state[0]) == 9
31+
32+
def test_starcraft2_state_raw(self):
33+
# GIVEN
34+
state = StarCraft2State(grid_size=5, use_raw_units=True)
35+
obs = NamedDict({
36+
'player': NamedDict({
37+
'minerals': 100,
38+
'vespene': 100,
39+
'food_cap': 200,
40+
'food_used': 100,
41+
'food_army': 50,
42+
'food_workers': 50,
43+
'army_count': 20,
44+
'idle_worker_count': 10,
45+
}),
46+
'raw_units': [
47+
NamedDict({
48+
'unit_type': 1,
49+
'alliance': 1,
50+
'build_progress': 100,
51+
'x': 1,
52+
'y': 1,
53+
}),
54+
NamedDict({
55+
'unit_type': 2,
56+
'alliance': 1,
57+
'build_progress': 100,
58+
'x': 2,
59+
'y': 2,
60+
}),
61+
]
62+
})
63+
# WHEN
64+
state.update(obs)
65+
# THEN
66+
assert state.dimension == 9 + ((5 * 5) * 2)
67+
assert len(state.state[0]) == 9 + ((5 * 5) * 2)

tests/units/sc2/states/test_terran_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
from pysc2.env import sc2_env
34
from pysc2.lib.named_array import NamedDict
45

56
from urnai.sc2.states.terran_state import TerranState
@@ -26,6 +27,7 @@ def test_terran_state_no_raw(self):
2627
# WHEN
2728
state.update(obs)
2829
# THEN
30+
assert state.player_race == sc2_env.Race.terran
2931
assert state.dimension == 9
3032
assert len(state.state[0]) == 9
3133

@@ -63,5 +65,6 @@ def test_terran_state_raw(self):
6365
# WHEN
6466
state.update(obs)
6567
# THEN
68+
assert state.player_race == sc2_env.Race.terran
6669
assert state.dimension == 22 + ((4 * 4) * 2)
6770
assert len(state.state[0]) == 22 + ((4 * 4) * 2)

tests/units/sc2/states/test_utils.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
from pysc2.lib import features
34
from pysc2.lib.named_array import NamedDict
45

56
from urnai.sc2.states.utils import (
@@ -16,21 +17,21 @@ def test_append_player_and_enemy_grids(self):
1617
'raw_units': [
1718
NamedDict({
1819
'unit_type': 1,
19-
'alliance': 1,
20+
'alliance': features.PlayerRelative.SELF,
2021
'build_progress': 100,
2122
'x': 1,
2223
'y': 1,
2324
}),
2425
NamedDict({
2526
'unit_type': 2,
26-
'alliance': 1,
27+
'alliance': features.PlayerRelative.SELF,
2728
'build_progress': 100,
2829
'x': 2,
2930
'y': 2,
3031
}),
3132
NamedDict({
3233
'unit_type': 2,
33-
'alliance': 4,
34+
'alliance': features.PlayerRelative.ENEMY,
3435
'build_progress': 100,
3536
'x': 63,
3637
'y': 63,
@@ -51,29 +52,29 @@ def test_create_raw_units_amount_dict(self):
5152
'raw_units': [
5253
NamedDict({
5354
'unit_type': 1,
54-
'alliance': 1,
55+
'alliance': features.PlayerRelative.SELF,
5556
'build_progress': 100,
5657
'x': 1,
5758
'y': 1,
5859
}),
5960
NamedDict({
6061
'unit_type': 2,
61-
'alliance': 1,
62+
'alliance': features.PlayerRelative.SELF,
6263
'build_progress': 100,
6364
'x': 2,
6465
'y': 2,
6566
}),
6667
NamedDict({
6768
'unit_type': 2,
68-
'alliance': 4,
69+
'alliance': features.PlayerRelative.ENEMY,
6970
'build_progress': 100,
7071
'x': 63,
7172
'y': 63,
7273
}),
7374
]
7475
})
7576
# WHEN
76-
dict = create_raw_units_amount_dict(obs, 1)
77+
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.SELF)
7778
# THEN
7879
assert len(dict) == 2
7980
assert dict[1] == 1
@@ -85,29 +86,29 @@ def test_create_raw_units_amount_dict_alliance(self):
8586
'raw_units': [
8687
NamedDict({
8788
'unit_type': 1,
88-
'alliance': 1,
89+
'alliance': features.PlayerRelative.SELF,
8990
'build_progress': 100,
9091
'x': 1,
9192
'y': 1,
9293
}),
9394
NamedDict({
9495
'unit_type': 2,
95-
'alliance': 1,
96+
'alliance': features.PlayerRelative.SELF,
9697
'build_progress': 100,
9798
'x': 2,
9899
'y': 2,
99100
}),
100101
NamedDict({
101102
'unit_type': 2,
102-
'alliance': 4,
103+
'alliance': features.PlayerRelative.ENEMY,
103104
'build_progress': 100,
104105
'x': 63,
105106
'y': 63,
106107
}),
107108
]
108109
})
109110
# WHEN
110-
dict = create_raw_units_amount_dict(obs, 4)
111+
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.ENEMY)
111112
# THEN
112113
assert len(dict) == 1
113114
assert dict[1] == 0
@@ -130,29 +131,29 @@ def test_create_raw_units_amount_dict_no_units_alliance(self):
130131
'raw_units': [
131132
NamedDict({
132133
'unit_type': 1,
133-
'alliance': 1,
134+
'alliance': features.PlayerRelative.SELF,
134135
'build_progress': 100,
135136
'x': 1,
136137
'y': 1,
137138
}),
138139
NamedDict({
139140
'unit_type': 2,
140-
'alliance': 1,
141+
'alliance': features.PlayerRelative.SELF,
141142
'build_progress': 100,
142143
'x': 2,
143144
'y': 2,
144145
}),
145146
NamedDict({
146147
'unit_type': 2,
147-
'alliance': 4,
148+
'alliance': features.PlayerRelative.ENEMY,
148149
'build_progress': 100,
149150
'x': 63,
150151
'y': 63,
151152
}),
152153
]
153154
})
154155
# WHEN
155-
dict = create_raw_units_amount_dict(obs, 3)
156+
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.NEUTRAL)
156157
# THEN
157158
assert len(dict) == 0
158159
assert dict == {}

tests/units/sc2/states/test_zerg_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22

3+
from pysc2.env import sc2_env
34
from pysc2.lib.named_array import NamedDict
45

56
from urnai.sc2.states.zerg_state import ZergState
@@ -26,6 +27,7 @@ def test_zerg_state_no_raw(self):
2627
# WHEN
2728
state.update(obs)
2829
# THEN
30+
assert state.player_race == sc2_env.Race.zerg
2931
assert state.dimension == 9
3032
assert len(state.state[0]) == 9
3133

@@ -63,5 +65,6 @@ def test_zerg_state_raw(self):
6365
# WHEN
6466
state.update(obs)
6567
# THEN
68+
assert state.player_race == sc2_env.Race.zerg
6669
assert state.dimension == 22 + ((6 * 6) * 2)
6770
assert len(state.state[0]) == 22 + ((6 * 6) * 2)

urnai/sc2/states/protoss_state.py

Lines changed: 27 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,80 +2,46 @@
22
from pysc2.env import sc2_env
33
from pysc2.lib import units
44

5-
from urnai.constants import SC2Constants
6-
from urnai.sc2.states.utils import (
7-
append_player_and_enemy_grids,
8-
create_raw_units_amount_dict,
9-
)
10-
from urnai.states.state_base import StateBase
5+
from urnai.sc2.states.starcraft2_state import StarCraft2State
6+
from urnai.sc2.states.utils import create_raw_units_amount_dict
117

128

13-
class ProtossState(StateBase):
9+
class ProtossState(StarCraft2State):
1410

1511
def __init__(
1612
self,
1713
grid_size: int = 4,
1814
use_raw_units: bool = True,
1915
raw_resolution: int = 64,
2016
):
21-
self.grid_size = grid_size
17+
super().__init__(grid_size, use_raw_units, raw_resolution)
2218
self.player_race = sc2_env.Race.protoss
23-
self.use_raw_units = use_raw_units
24-
self.raw_resolution = raw_resolution
25-
self.reset()
2619

2720
def update(self, obs):
28-
new_state = [
29-
# Adds general information from the player.
30-
obs.player.minerals / SC2Constants.MAX_MINERALS,
31-
obs.player.vespene / SC2Constants.MAX_VESPENE,
32-
obs.player.food_cap / SC2Constants.MAX_UNITS,
33-
obs.player.food_used / SC2Constants.MAX_UNITS,
34-
obs.player.food_army / SC2Constants.MAX_UNITS,
35-
obs.player.food_workers / SC2Constants.MAX_UNITS,
36-
(obs.player.food_cap - obs.player.food_used) / SC2Constants.MAX_UNITS,
37-
obs.player.army_count / SC2Constants.MAX_UNITS,
38-
obs.player.idle_worker_count / SC2Constants.MAX_UNITS,
39-
]
21+
state = super().update(obs)
4022

4123
if self.use_raw_units:
4224
raw_units_amount_dict = create_raw_units_amount_dict(
4325
obs, sc2_env.features.PlayerRelative.SELF)
44-
new_state.extend(
45-
[
46-
# Adds information related to player's Protoss units/buildings.
47-
raw_units_amount_dict[units.Protoss.Nexus],
48-
raw_units_amount_dict[units.Protoss.Pylon],
49-
raw_units_amount_dict[units.Protoss.Assimilator],
50-
raw_units_amount_dict[units.Protoss.Forge],
51-
raw_units_amount_dict[units.Protoss.Gateway],
52-
raw_units_amount_dict[units.Protoss.CyberneticsCore],
53-
raw_units_amount_dict[units.Protoss.PhotonCannon],
54-
raw_units_amount_dict[units.Protoss.RoboticsFacility],
55-
raw_units_amount_dict[units.Protoss.Stargate],
56-
raw_units_amount_dict[units.Protoss.TwilightCouncil],
57-
raw_units_amount_dict[units.Protoss.RoboticsBay],
58-
raw_units_amount_dict[units.Protoss.TemplarArchive],
59-
raw_units_amount_dict[units.Protoss.DarkShrine],
60-
]
61-
)
62-
new_state = append_player_and_enemy_grids(
63-
obs, new_state, self.grid_size, self.raw_resolution
64-
)
65-
66-
self._dimension = len(new_state)
67-
final_state = np.expand_dims(new_state, axis=0)
68-
self._state = final_state
69-
return final_state
70-
71-
@property
72-
def state(self):
73-
return self._state
74-
75-
@property
76-
def dimension(self):
77-
return self._dimension
78-
79-
def reset(self):
80-
self._state = None
81-
self._dimension = None
26+
units_amount_info = [
27+
raw_units_amount_dict[units.Protoss.Nexus],
28+
raw_units_amount_dict[units.Protoss.Pylon],
29+
raw_units_amount_dict[units.Protoss.Assimilator],
30+
raw_units_amount_dict[units.Protoss.Forge],
31+
raw_units_amount_dict[units.Protoss.Gateway],
32+
raw_units_amount_dict[units.Protoss.CyberneticsCore],
33+
raw_units_amount_dict[units.Protoss.PhotonCannon],
34+
raw_units_amount_dict[units.Protoss.RoboticsFacility],
35+
raw_units_amount_dict[units.Protoss.Stargate],
36+
raw_units_amount_dict[units.Protoss.TwilightCouncil],
37+
raw_units_amount_dict[units.Protoss.RoboticsBay],
38+
raw_units_amount_dict[units.Protoss.TemplarArchive],
39+
raw_units_amount_dict[units.Protoss.DarkShrine],
40+
]
41+
state = np.squeeze(state)
42+
state = np.append(state, units_amount_info)
43+
self._dimension = len(state)
44+
state = np.expand_dims(state, axis=0)
45+
self._state = state
46+
47+
return state

0 commit comments

Comments
 (0)