Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolves #73 #82

Merged
merged 11 commits into from
Jul 17, 2024
Empty file.
67 changes: 67 additions & 0 deletions tests/units/sc2/states/test_protoss_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.protoss_state import ProtossState


class TestProtossState(unittest.TestCase):


def test_protoss_state_no_raw(self):
# GIVEN
state = ProtossState(use_raw_units=False)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
})
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 9
assert len(state.state[0]) == 9

def test_protoss_state_raw(self):
# GIVEN
state = ProtossState(grid_size=5, use_raw_units=True)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
}),
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
]
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 22 + ((5 * 5) * 2)
assert len(state.state[0]) == 22 + ((5 * 5) * 2)
67 changes: 67 additions & 0 deletions tests/units/sc2/states/test_terran_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.terran_state import TerranState


class TestTerranState(unittest.TestCase):


def test_terran_state_no_raw(self):
# GIVEN
state = TerranState(use_raw_units=False)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
})
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 9
assert len(state.state[0]) == 9

def test_terran_state_raw(self):
# GIVEN
state = TerranState(grid_size=4, use_raw_units=True)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
}),
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
]
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 22 + ((4 * 4) * 2)
assert len(state.state[0]) == 22 + ((4 * 4) * 2)
158 changes: 158 additions & 0 deletions tests/units/sc2/states/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import unittest

from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.utils import (
append_player_and_enemy_grids,
create_raw_units_amount_dict,
)


class TestAuxSC2State(unittest.TestCase):

def test_append_player_and_enemy_grids(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
new_state = []
# WHEN
new_state = append_player_and_enemy_grids(obs, new_state, 3, 64)
# THEN
assert len(new_state) == (18)
assert new_state[8] == 0.005
assert new_state[9] == 0.01

def test_create_raw_units_amount_dict(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, 1)
# THEN
assert len(dict) == 2
assert dict[1] == 1
assert dict[2] == 1

def test_create_raw_units_amount_dict_alliance(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, 4)
# THEN
assert len(dict) == 1
assert dict[1] == 0
assert dict[2] == 1

def test_create_raw_units_amount_dict_no_units(self):
# GIVEN
obs = NamedDict({
'raw_units': []
})
# WHEN
dict = create_raw_units_amount_dict(obs)
# THEN
assert len(dict) == 0
assert dict == {}

def test_create_raw_units_amount_dict_no_units_alliance(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': 4,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, 3)
# THEN
assert len(dict) == 0
assert dict == {}
67 changes: 67 additions & 0 deletions tests/units/sc2/states/test_zerg_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.zerg_state import ZergState


class TestZergState(unittest.TestCase):


def test_zerg_state_no_raw(self):
# GIVEN
state = ZergState(use_raw_units=False)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
})
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 9
assert len(state.state[0]) == 9

def test_zerg_state_raw(self):
# GIVEN
state = ZergState(grid_size=6, use_raw_units=True)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
}),
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
]
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 22 + ((6 * 6) * 2)
assert len(state.state[0]) == 22 + ((6 * 6) * 2)
8 changes: 8 additions & 0 deletions urnai/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import IntEnum


class SC2Constants(IntEnum):
MAX_UNITS = 200
MAX_MINERALS = 6000
MAX_VESPENE = 6000
BUILD_PROGRESS_COMPLETE = 100
Empty file added urnai/sc2/states/__init__.py
Empty file.
Loading