Skip to content

Commit 4d2f055

Browse files
committed
refactor: Changed units amount functions
1 parent 94263bc commit 4d2f055

File tree

8 files changed

+96
-158
lines changed

8 files changed

+96
-158
lines changed

tests/units/sc2/states/test_protoss_state.py

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def test_protoss_state_no_raw(self):
2626
# WHEN
2727
state.update(obs)
2828
# THEN
29-
print(state.state)
3029
assert state.dimension == 9
3130
assert len(state.state[0]) == 9
3231

tests/units/sc2/states/test_terran_state.py

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def test_terran_state_no_raw(self):
2626
# WHEN
2727
state.update(obs)
2828
# THEN
29-
print(state.state)
3029
assert state.dimension == 9
3130
assert len(state.state[0]) == 9
3231

tests/units/sc2/states/test_utils.py

+35-95
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
from urnai.sc2.states.utils import (
66
append_player_and_enemy_grids,
7-
get_raw_units_amount,
8-
get_raw_units_by_type,
7+
create_raw_units_amount_dict,
98
)
109

1110

@@ -42,12 +41,11 @@ def test_append_player_and_enemy_grids(self):
4241
# WHEN
4342
new_state = append_player_and_enemy_grids(obs, new_state, 3, 64)
4443
# THEN
45-
print(new_state)
4644
assert len(new_state) == (18)
4745
assert new_state[8] == 0.005
4846
assert new_state[9] == 0.01
4947

50-
def test_get_raw_units_amount(self):
48+
def test_create_raw_units_amount_dict(self):
5149
# GIVEN
5250
obs = NamedDict({
5351
'raw_units': [
@@ -65,29 +63,28 @@ def test_get_raw_units_amount(self):
6563
'x': 2,
6664
'y': 2,
6765
}),
66+
NamedDict({
67+
'unit_type': 2,
68+
'alliance': 4,
69+
'build_progress': 100,
70+
'x': 63,
71+
'y': 63,
72+
}),
6873
]
6974
})
7075
# WHEN
71-
amount = get_raw_units_amount(obs, 1)
72-
# THEN
73-
assert amount == 1
74-
75-
def test_get_raw_units_amount_no_units(self):
76-
# GIVEN
77-
obs = NamedDict({
78-
'raw_units': []
79-
})
80-
# WHEN
81-
amount = get_raw_units_amount(obs, 1)
76+
dict = create_raw_units_amount_dict(obs, 1)
8277
# THEN
83-
assert amount == 0
78+
assert len(dict) == 2
79+
assert dict[1] == 1
80+
assert dict[2] == 1
8481

85-
def test_get_raw_units_amount_no_units_of_type(self):
82+
def test_create_raw_units_amount_dict_alliance(self):
8683
# GIVEN
8784
obs = NamedDict({
8885
'raw_units': [
8986
NamedDict({
90-
'unit_type': 2,
87+
'unit_type': 1,
9188
'alliance': 1,
9289
'build_progress': 100,
9390
'x': 1,
@@ -100,64 +97,34 @@ def test_get_raw_units_amount_no_units_of_type(self):
10097
'x': 2,
10198
'y': 2,
10299
}),
103-
]
104-
})
105-
# WHEN
106-
amount = get_raw_units_amount(obs, 1)
107-
# THEN
108-
assert amount == 0
109-
110-
def test_get_raw_units_amount_no_units_of_alliance(self):
111-
# GIVEN
112-
obs = NamedDict({
113-
'raw_units': [
114-
NamedDict({
115-
'unit_type': 1,
116-
'alliance': 2,
117-
'build_progress': 100,
118-
'x': 1,
119-
'y': 1,
120-
}),
121100
NamedDict({
122-
'unit_type': 1,
123-
'alliance': 2,
101+
'unit_type': 2,
102+
'alliance': 4,
124103
'build_progress': 100,
125-
'x': 2,
126-
'y': 2,
104+
'x': 63,
105+
'y': 63,
127106
}),
128107
]
129108
})
130109
# WHEN
131-
amount = get_raw_units_amount(obs, 1)
110+
dict = create_raw_units_amount_dict(obs, 4)
132111
# THEN
133-
assert amount == 0
112+
assert len(dict) == 1
113+
assert dict[1] == 0
114+
assert dict[2] == 1
134115

135-
def test_get_raw_units_amount_no_units_of_build_progress(self):
116+
def test_create_raw_units_amount_dict_no_units(self):
136117
# GIVEN
137118
obs = NamedDict({
138-
'raw_units': [
139-
NamedDict({
140-
'unit_type': 1,
141-
'alliance': 1,
142-
'build_progress': 50,
143-
'x': 1,
144-
'y': 1,
145-
}),
146-
NamedDict({
147-
'unit_type': 1,
148-
'alliance': 1,
149-
'build_progress': 50,
150-
'x': 2,
151-
'y': 2,
152-
}),
153-
]
119+
'raw_units': []
154120
})
155121
# WHEN
156-
amount = get_raw_units_amount(obs, 1)
122+
dict = create_raw_units_amount_dict(obs)
157123
# THEN
158-
assert amount == 0
124+
assert len(dict) == 0
125+
assert dict == {}
159126

160-
def test_get_raw_units_by_type(self):
127+
def test_create_raw_units_amount_dict_no_units_alliance(self):
161128
# GIVEN
162129
obs = NamedDict({
163130
'raw_units': [
@@ -175,44 +142,17 @@ def test_get_raw_units_by_type(self):
175142
'x': 2,
176143
'y': 2,
177144
}),
178-
]
179-
})
180-
# WHEN
181-
units = get_raw_units_by_type(obs, 1)
182-
# THEN
183-
assert len(units) == 1
184-
185-
def test_get_raw_units_by_type_no_units(self):
186-
# GIVEN
187-
obs = NamedDict({
188-
'raw_units': []
189-
})
190-
# WHEN
191-
units = get_raw_units_by_type(obs, 1)
192-
# THEN
193-
assert len(units) == 0
194-
195-
def test_get_raw_units_by_type_no_units_of_type(self):
196-
# GIVEN
197-
obs = NamedDict({
198-
'raw_units': [
199-
NamedDict({
200-
'unit_type': 2,
201-
'alliance': 1,
202-
'build_progress': 100,
203-
'x': 1,
204-
'y': 1,
205-
}),
206145
NamedDict({
207146
'unit_type': 2,
208-
'alliance': 1,
147+
'alliance': 4,
209148
'build_progress': 100,
210-
'x': 2,
211-
'y': 2,
149+
'x': 63,
150+
'y': 63,
212151
}),
213152
]
214153
})
215154
# WHEN
216-
units = get_raw_units_by_type(obs, 1)
155+
dict = create_raw_units_amount_dict(obs, 3)
217156
# THEN
218-
assert len(units) == 0
157+
assert len(dict) == 0
158+
assert dict == {}

tests/units/sc2/states/test_zerg_state.py

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def test_zerg_state_no_raw(self):
2626
# WHEN
2727
state.update(obs)
2828
# THEN
29-
print(state.state)
3029
assert state.dimension == 9
3130
assert len(state.state[0]) == 9
3231

urnai/sc2/states/protoss_state.py

+16-14
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from urnai.constants import SC2Constants
66
from urnai.sc2.states.utils import (
77
append_player_and_enemy_grids,
8-
get_raw_units_amount,
8+
create_raw_units_amount_dict,
99
)
1010
from urnai.states.state_base import StateBase
1111

@@ -39,22 +39,24 @@ def update(self, obs):
3939
]
4040

4141
if self.use_raw_units:
42+
raw_units_amount_dict = create_raw_units_amount_dict(
43+
obs, sc2_env.features.PlayerRelative.SELF)
4244
new_state.extend(
4345
[
4446
# Adds information related to player's Protoss units/buildings.
45-
get_raw_units_amount(obs, units.Protoss.Nexus),
46-
get_raw_units_amount(obs, units.Protoss.Pylon),
47-
get_raw_units_amount(obs, units.Protoss.Assimilator),
48-
get_raw_units_amount(obs, units.Protoss.Forge),
49-
get_raw_units_amount(obs, units.Protoss.Gateway),
50-
get_raw_units_amount(obs, units.Protoss.CyberneticsCore),
51-
get_raw_units_amount(obs, units.Protoss.PhotonCannon),
52-
get_raw_units_amount(obs, units.Protoss.RoboticsFacility),
53-
get_raw_units_amount(obs, units.Protoss.Stargate),
54-
get_raw_units_amount(obs, units.Protoss.TwilightCouncil),
55-
get_raw_units_amount(obs, units.Protoss.RoboticsBay),
56-
get_raw_units_amount(obs, units.Protoss.TemplarArchive),
57-
get_raw_units_amount(obs, units.Protoss.DarkShrine),
47+
raw_units_amount_dict[units.Protoss.Nexus],
48+
raw_units_amount_dict[units.Protoss.Pylon],
49+
raw_units_amount_dict[units.Protoss.Assimilator],
50+
raw_units_amount_dict[units.Protoss.Forge],
51+
raw_units_amount_dict[units.Protoss.Gateway],
52+
raw_units_amount_dict[units.Protoss.CyberneticsCore],
53+
raw_units_amount_dict[units.Protoss.PhotonCannon],
54+
raw_units_amount_dict[units.Protoss.RoboticsFacility],
55+
raw_units_amount_dict[units.Protoss.Stargate],
56+
raw_units_amount_dict[units.Protoss.TwilightCouncil],
57+
raw_units_amount_dict[units.Protoss.RoboticsBay],
58+
raw_units_amount_dict[units.Protoss.TemplarArchive],
59+
raw_units_amount_dict[units.Protoss.DarkShrine],
5860
]
5961
)
6062
new_state = append_player_and_enemy_grids(

urnai/sc2/states/terran_state.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from urnai.constants import SC2Constants
66
from urnai.sc2.states.utils import (
77
append_player_and_enemy_grids,
8-
get_raw_units_amount,
8+
create_raw_units_amount_dict,
99
)
1010
from urnai.states.state_base import StateBase
1111

@@ -39,24 +39,26 @@ def update(self, obs):
3939
]
4040

4141
if self.use_raw_units:
42+
raw_units_amount_dict = create_raw_units_amount_dict(
43+
obs, sc2_env.features.PlayerRelative.SELF)
4244
new_state.extend(
4345
[
4446
# Adds information related to player's Terran units/buildings.
45-
get_raw_units_amount(obs, units.Terran.CommandCenter)
46-
+ get_raw_units_amount(obs, units.Terran.OrbitalCommand)
47-
+ get_raw_units_amount(obs, units.Terran.PlanetaryFortress) / 2,
48-
get_raw_units_amount(obs, units.Terran.SupplyDepot) / 18,
49-
get_raw_units_amount(obs, units.Terran.Refinery) / 4,
50-
get_raw_units_amount(obs, units.Terran.EngineeringBay),
51-
get_raw_units_amount(obs, units.Terran.Armory),
52-
get_raw_units_amount(obs, units.Terran.MissileTurret) / 4,
53-
get_raw_units_amount(obs, units.Terran.SensorTower) / 1,
54-
get_raw_units_amount(obs, units.Terran.Bunker) / 4,
55-
get_raw_units_amount(obs, units.Terran.FusionCore),
56-
get_raw_units_amount(obs, units.Terran.GhostAcademy),
57-
get_raw_units_amount(obs, units.Terran.Barracks) / 3,
58-
get_raw_units_amount(obs, units.Terran.Factory) / 2,
59-
get_raw_units_amount(obs, units.Terran.Starport) / 2,
47+
raw_units_amount_dict[units.Terran.CommandCenter]
48+
+ raw_units_amount_dict[units.Terran.OrbitalCommand]
49+
+ raw_units_amount_dict[units.Terran.PlanetaryFortress] / 2,
50+
raw_units_amount_dict[units.Terran.SupplyDepot] / 18,
51+
raw_units_amount_dict[units.Terran.Refinery] / 4,
52+
raw_units_amount_dict[units.Terran.EngineeringBay],
53+
raw_units_amount_dict[units.Terran.Armory],
54+
raw_units_amount_dict[units.Terran.MissileTurret] / 4,
55+
raw_units_amount_dict[units.Terran.SensorTower] / 1,
56+
raw_units_amount_dict[units.Terran.Bunker] / 4,
57+
raw_units_amount_dict[units.Terran.FusionCore],
58+
raw_units_amount_dict[units.Terran.GhostAcademy],
59+
raw_units_amount_dict[units.Terran.Barracks] / 3,
60+
raw_units_amount_dict[units.Terran.Factory] / 2,
61+
raw_units_amount_dict[units.Terran.Starport] / 2,
6062
]
6163
)
6264
new_state = append_player_and_enemy_grids(

urnai/sc2/states/utils.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from collections import defaultdict
23

34
import numpy as np
45
from pysc2.lib import features
@@ -48,19 +49,13 @@ def append_alliance_grid(
4849

4950
return new_state
5051

51-
def get_raw_units_amount(
52-
obs : list,
53-
unit_type : int,
54-
alliance : features.PlayerRelative = features.PlayerRelative.SELF,
55-
) -> int:
56-
return len(get_raw_units_by_type(obs, unit_type, alliance))
57-
58-
def get_raw_units_by_type(
59-
obs : list,
60-
unit_type : int,
61-
alliance : features.PlayerRelative = features.PlayerRelative.SELF,
62-
) -> list:
63-
return [unit for unit in obs.raw_units
64-
if unit.unit_type == unit_type
65-
and unit.alliance == alliance
66-
and unit.build_progress == SC2Constants.BUILD_PROGRESS_COMPLETE]
52+
def create_raw_units_amount_dict(
53+
obs : list,
54+
alliance : features.PlayerRelative = features.PlayerRelative.SELF
55+
) -> defaultdict:
56+
dict = defaultdict(lambda: 0)
57+
for unit in obs.raw_units:
58+
if (unit.alliance == alliance
59+
and unit.build_progress == SC2Constants.BUILD_PROGRESS_COMPLETE):
60+
dict[unit.unit_type] += 1
61+
return dict

0 commit comments

Comments
 (0)