Skip to content

Commit

Permalink
feat: Added initial trainer to solve collectables
Browse files Browse the repository at this point in the history
  • Loading branch information
RickFqt committed Aug 12, 2024
1 parent cc69f3e commit c04fbba
Show file tree
Hide file tree
Showing 14 changed files with 2,197 additions and 8 deletions.
48 changes: 48 additions & 0 deletions sc2_collectables_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pathlib
import sys

sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent))

from absl import app
from pysc2.env import sc2_env

from urnai.models.dqn_pytorch import DQNPytorch
from urnai.sc2.actions.collectables import CollectablesActionSpace
from urnai.sc2.agents.sc2_agent import SC2Agent
from urnai.sc2.environments.sc2environment import SC2Env
from urnai.sc2.rewards.collectables import CollectablesReward
from urnai.sc2.states.collectables import CollectablesState
from urnai.trainers.trainer import Trainer


def declare_trainer():
players = [sc2_env.Agent(sc2_env.Race.terran)]
env = SC2Env(map_name='CollectMineralShards', visualize=False,
step_mul=16, players=players)


action_space = CollectablesActionSpace()
state_builder = CollectablesState()
reward_builder = CollectablesReward()

model = DQNPytorch(action_space, state_builder)

agent = SC2Agent(action_space, state_builder, model, reward_builder)

trainer = Trainer(env, agent,
max_training_episodes=200, max_steps_training=100000,
max_playing_episodes=200, max_steps_playing=100000)
return trainer

def main(unused_argv):
try:
trainer = declare_trainer()
trainer.train()
# trainer.play()

except KeyboardInterrupt:
pass


if __name__ == '__main__':
app.run(main)
16 changes: 8 additions & 8 deletions urnai/agents/agent_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod

from urnai.actions.action_base import ActionBase
from urnai.actions.action_space_base import ActionSpaceBase
from urnai.models.model_base import ModelBase
from urnai.rewards.reward_base import RewardBase
from urnai.states.state_base import StateBase


Expand All @@ -11,7 +11,7 @@ class AgentBase(ABC):
def __init__(self, action_space : ActionSpaceBase,
state_space : StateBase,
model : ModelBase,
reward):
reward : RewardBase):

self.action_space = action_space
self.state_space = state_space
Expand All @@ -28,12 +28,12 @@ def __init__(self, action_space : ActionSpaceBase,
def step(self) -> None:
...

@abstractmethod
def choose_action(self, action_space : ActionSpaceBase) -> ActionBase:
"""
Method that contains the agent's strategy for choosing actions
"""
...
# @abstractmethod
# def choose_action(self, action_space : ActionSpaceBase) -> ActionBase:
# """
# Method that contains the agent's strategy for choosing actions
# """
# ...

def reset(self, episode=0) -> None:
"""
Expand Down
Loading

0 comments on commit c04fbba

Please sign in to comment.