Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolves #73 #82

Merged
merged 11 commits into from
Jul 17, 2024
Empty file.
70 changes: 70 additions & 0 deletions tests/units/sc2/states/test_protoss_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import unittest

from pysc2.env import sc2_env
from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.protoss_state import ProtossState


class TestProtossState(unittest.TestCase):


def test_protoss_state_no_raw(self):
# GIVEN
state = ProtossState(use_raw_units=False)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
})
})
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.protoss
assert state.dimension == 9
assert len(state.state[0]) == 9

def test_protoss_state_raw(self):
# GIVEN
state = ProtossState(grid_size=5, use_raw_units=True)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
}),
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
]
})
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.protoss
assert state.dimension == 22 + ((5 * 5) * 2)
assert len(state.state[0]) == 22 + ((5 * 5) * 2)
67 changes: 67 additions & 0 deletions tests/units/sc2/states/test_starcraft2_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest

from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.starcraft2_state import StarCraft2State


class TestStarCraft2State(unittest.TestCase):


def test_starcraft2_state_no_raw(self):
# GIVEN
state = StarCraft2State(use_raw_units=False)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
})
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 9
assert len(state.state[0]) == 9

def test_starcraft2_state_raw(self):
# GIVEN
state = StarCraft2State(grid_size=5, use_raw_units=True)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
}),
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
]
})
# WHEN
state.update(obs)
# THEN
assert state.dimension == 9 + ((5 * 5) * 2)
assert len(state.state[0]) == 9 + ((5 * 5) * 2)
70 changes: 70 additions & 0 deletions tests/units/sc2/states/test_terran_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import unittest

from pysc2.env import sc2_env
from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.terran_state import TerranState


class TestTerranState(unittest.TestCase):


def test_terran_state_no_raw(self):
# GIVEN
state = TerranState(use_raw_units=False)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
})
})
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.terran
assert state.dimension == 9
assert len(state.state[0]) == 9

def test_terran_state_raw(self):
# GIVEN
state = TerranState(grid_size=4, use_raw_units=True)
obs = NamedDict({
'player': NamedDict({
'minerals': 100,
'vespene': 100,
'food_cap': 200,
'food_used': 100,
'food_army': 50,
'food_workers': 50,
'army_count': 20,
'idle_worker_count': 10,
}),
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': 1,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': 1,
'build_progress': 100,
'x': 2,
'y': 2,
}),
]
})
# WHEN
state.update(obs)
# THEN
assert state.player_race == sc2_env.Race.terran
assert state.dimension == 22 + ((4 * 4) * 2)
assert len(state.state[0]) == 22 + ((4 * 4) * 2)
159 changes: 159 additions & 0 deletions tests/units/sc2/states/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import unittest

from pysc2.lib import features
from pysc2.lib.named_array import NamedDict

from urnai.sc2.states.utils import (
append_player_and_enemy_grids,
create_raw_units_amount_dict,
)


class TestAuxSC2State(unittest.TestCase):

def test_append_player_and_enemy_grids(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
new_state = []
# WHEN
new_state = append_player_and_enemy_grids(obs, new_state, 3, 64)
# THEN
assert len(new_state) == (18)
assert new_state[8] == 0.005
assert new_state[9] == 0.01

def test_create_raw_units_amount_dict(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.SELF)
# THEN
assert len(dict) == 2
assert dict[1] == 1
assert dict[2] == 1

def test_create_raw_units_amount_dict_alliance(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.ENEMY)
# THEN
assert len(dict) == 1
assert dict[1] == 0
assert dict[2] == 1

def test_create_raw_units_amount_dict_no_units(self):
# GIVEN
obs = NamedDict({
'raw_units': []
})
# WHEN
dict = create_raw_units_amount_dict(obs)
# THEN
assert len(dict) == 0
assert dict == {}

def test_create_raw_units_amount_dict_no_units_alliance(self):
# GIVEN
obs = NamedDict({
'raw_units': [
NamedDict({
'unit_type': 1,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 1,
'y': 1,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.SELF,
'build_progress': 100,
'x': 2,
'y': 2,
}),
NamedDict({
'unit_type': 2,
'alliance': features.PlayerRelative.ENEMY,
'build_progress': 100,
'x': 63,
'y': 63,
}),
]
})
# WHEN
dict = create_raw_units_amount_dict(obs, features.PlayerRelative.NEUTRAL)
# THEN
assert len(dict) == 0
assert dict == {}
Loading