Skip to content

Commit 51cd0c4

Browse files
committed
chore(issue-73): Added initial Terran State class
1 parent 6d1f718 commit 51cd0c4

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

urnai/sc2/states/__init__.py

Whitespace-only changes.

urnai/sc2/states/terran_state.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import math
2+
3+
import numpy as np
4+
from pysc2.env import sc2_env
5+
from pysc2.lib import features, units
6+
7+
from urnai.states.state_base import StateBase
8+
9+
10+
class TerranState(StateBase):
11+
12+
def __init__(self, grid_size=4):
13+
14+
self.grid_size = grid_size
15+
# Size of the state returned
16+
# 22: number of data added to the state (number of minerals, army_count, etc)
17+
# 2 * 4 ** 2: 2 grids of size 4 x 4, representing enemy and player units
18+
self._state_size = int(22 + 2 * (self.grid_size ** 2))
19+
self.player_race = sc2_env.Race.terran
20+
self.base_top_left = None
21+
self._state = None
22+
23+
def update(self, obs):
24+
if obs.game_loop[0] < 80 and self.base_top_left is None:
25+
26+
commandcenter = get_units_by_type(obs, units.Terran.CommandCenter)
27+
28+
if len(commandcenter) > 0:
29+
townhall = commandcenter[0]
30+
self.player_race = sc2_env.Race.terran
31+
32+
self.base_top_left = (townhall.x < 32)
33+
34+
new_state = []
35+
# Adds general information from the player.
36+
new_state.append(obs.player.minerals / 6000)
37+
new_state.append(obs.player.vespene / 6000)
38+
new_state.append(obs.player.food_cap / 200)
39+
new_state.append(obs.player.food_used / 200)
40+
new_state.append(obs.player.food_army / 200)
41+
new_state.append(obs.player.food_workers / 200)
42+
new_state.append((obs.player.food_cap - obs.player.food_used) / 200)
43+
new_state.append(obs.player.army_count / 200)
44+
new_state.append(obs.player.idle_worker_count / 200)
45+
46+
# Adds information related to player's Terran units/buildings.
47+
new_state.append(get_my_units_amount(obs, units.Terran.CommandCenter) +
48+
get_my_units_amount(obs, units.Terran.OrbitalCommand) +
49+
get_my_units_amount(obs, units.Terran.PlanetaryFortress) / 2)
50+
new_state.append(get_my_units_amount(obs, units.Terran.SupplyDepot) / 18)
51+
new_state.append(get_my_units_amount(obs, units.Terran.Refinery) / 4)
52+
new_state.append(get_my_units_amount(obs, units.Terran.EngineeringBay))
53+
new_state.append(get_my_units_amount(obs, units.Terran.Armory))
54+
new_state.append(get_my_units_amount(obs, units.Terran.MissileTurret) / 4)
55+
new_state.append(get_my_units_amount(obs, units.Terran.SensorTower)/1)
56+
new_state.append(get_my_units_amount(obs, units.Terran.Bunker)/4)
57+
new_state.append(get_my_units_amount(obs, units.Terran.FusionCore))
58+
new_state.append(get_my_units_amount(obs, units.Terran.GhostAcademy))
59+
new_state.append(get_my_units_amount(obs, units.Terran.Barracks) / 3)
60+
new_state.append(get_my_units_amount(obs, units.Terran.Factory) / 2)
61+
new_state.append(get_my_units_amount(obs, units.Terran.Starport) / 2)
62+
63+
# Instead of making a vector for all coordnates on the map, we'll
64+
# discretize our enemy space
65+
# and use a 4x4 grid to store enemy positions by marking a square as 1 if
66+
# there's any enemy on it.
67+
68+
enemy_grid = np.zeros((self.grid_size, self.grid_size))
69+
player_grid = np.zeros((self.grid_size, self.grid_size))
70+
71+
enemy_units = [unit for unit in obs.raw_units if
72+
unit.alliance == features.PlayerRelative.ENEMY]
73+
player_units = [unit for unit in obs.raw_units if
74+
unit.alliance == features.PlayerRelative.SELF]
75+
76+
for i in range(0, len(enemy_units)):
77+
y = int(math.ceil((enemy_units[i].x + 1) / 64 / self.grid_size))
78+
x = int(math.ceil((enemy_units[i].y + 1) / 64 / self.grid_size))
79+
enemy_grid[x - 1][y - 1] += 1
80+
81+
for i in range(0, len(player_units)):
82+
y = int(math.ceil((player_units[i].x + 1) / (64 / self.grid_size)))
83+
x = int(math.ceil((player_units[i].y + 1) / (64 / self.grid_size)))
84+
player_grid[x - 1][y - 1] += 1
85+
86+
if not self.base_top_left:
87+
enemy_grid = np.rot90(enemy_grid, 2)
88+
player_grid = np.rot90(player_grid, 2)
89+
90+
# Normalizing the values to always be between 0 and 1
91+
# (since the max amount of units in SC2 is 200)
92+
enemy_grid = enemy_grid / 200
93+
player_grid = player_grid / 200
94+
95+
new_state.extend(enemy_grid.flatten())
96+
new_state.extend(player_grid.flatten())
97+
final_state = np.expand_dims(new_state, axis=0)
98+
99+
self._state = final_state
100+
return final_state
101+
102+
@property
103+
def state(self):
104+
return self._state
105+
106+
@property
107+
def dimension(self):
108+
return self._state_size
109+
110+
def reset(self):
111+
self._state = None
112+
self.base_top_left = None
113+
114+
def get_my_units_amount(obs, unit_type):
115+
return len(get_units_by_type(obs, unit_type, features.PlayerRelative.SELF))
116+
117+
def get_units_by_type(obs, unit_type, alliance=features.PlayerRelative.SELF):
118+
return [unit for unit in obs.raw_units
119+
if unit.unit_type == unit_type
120+
and unit.alliance == alliance
121+
and unit.build_progress == 100]

0 commit comments

Comments
 (0)