Commit c04fbba

feat: Added initial trainer to solve collectables
1 parent cc69f3e · commit c04fbba

14 files changed (+2197 −8 lines)

sc2_collectables_trainer.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+import pathlib
+import sys
+
+sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))
+
+from absl import app
+from pysc2.env import sc2_env
+
+from urnai.models.dqn_pytorch import DQNPytorch
+from urnai.sc2.actions.collectables import CollectablesActionSpace
+from urnai.sc2.agents.sc2_agent import SC2Agent
+from urnai.sc2.environments.sc2environment import SC2Env
+from urnai.sc2.rewards.collectables import CollectablesReward
+from urnai.sc2.states.collectables import CollectablesState
+from urnai.trainers.trainer import Trainer
+
+
+def declare_trainer():
+    players = [sc2_env.Agent(sc2_env.Race.terran)]
+    env = SC2Env(map_name='CollectMineralShards', visualize=False,
+                 step_mul=16, players=players)
+
+
+    action_space = CollectablesActionSpace()
+    state_builder = CollectablesState()
+    reward_builder = CollectablesReward()
+
+    model = DQNPytorch(action_space, state_builder)
+
+    agent = SC2Agent(action_space, state_builder, model, reward_builder)
+
+    trainer = Trainer(env, agent,
+                      max_training_episodes=200, max_steps_training=100000,
+                      max_playing_episodes=200, max_steps_playing=100000)
+    return trainer
+
+def main(unused_argv):
+    try:
+        trainer = declare_trainer()
+        trainer.train()
+        # trainer.play()
+
+    except KeyboardInterrupt:
+        pass
+
+
+if __name__ == '__main__':
+    app.run(main)
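
For context, the script is launched through absl's app.run and, as committed, only runs the training loop (trainer.play() is left commented out). A minimal sketch of a variant main() that also evaluates the trained agent, assuming Trainer.play() can be called on the same instance right after train() (that ordering is not shown in this commit):

def main(unused_argv):
    try:
        trainer = declare_trainer()
        trainer.train()  # runs up to max_training_episodes / max_steps_training
        trainer.play()   # assumption: then lets the trained agent play, up to max_playing_episodes
    except KeyboardInterrupt:
        pass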

urnai/agents/agent_base.py

Lines changed: 8 additions & 8 deletions
@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 
-from urnai.actions.action_base import ActionBase
 from urnai.actions.action_space_base import ActionSpaceBase
 from urnai.models.model_base import ModelBase
+from urnai.rewards.reward_base import RewardBase
 from urnai.states.state_base import StateBase
 
 
@@ -11,7 +11,7 @@ class AgentBase(ABC):
     def __init__(self, action_space : ActionSpaceBase,
                  state_space : StateBase,
                  model : ModelBase,
-                 reward):
+                 reward : RewardBase):
 
         self.action_space = action_space
         self.state_space = state_space
@@ -28,12 +28,12 @@ def __init__(self, action_space : ActionSpaceBase,
     def step(self) -> None:
         ...
 
-    @abstractmethod
-    def choose_action(self, action_space : ActionSpaceBase) -> ActionBase:
-        """
-        Method that contains the agent's strategy for choosing actions
-        """
-        ...
+    # @abstractmethod
+    # def choose_action(self, action_space : ActionSpaceBase) -> ActionBase:
+    #     """
+    #     Method that contains the agent's strategy for choosing actions
+    #     """
+    #     ...
 
     def reset(self, episode=0) -> None:
         """

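The practical effect of this change is twofold: reward is now type-hinted as RewardBase (matching the CollectablesReward passed to SC2Agent in the trainer above), and choose_action is dropped from the abstract interface, so concrete agents no longer have to override it. A hedged sketch of a minimal subclass, assuming step() is the only abstract method left visible in this diff (the class name below is hypothetical):

from urnai.agents.agent_base import AgentBase


class NoOpAgent(AgentBase):
    """Illustrative agent: no choose_action override is required anymore."""

    def step(self) -> None:
        # A real agent would use self.model / self.action_space here;
        # their exact APIs are not part of this commit.
        pass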