From 4890fdda4806315260c5dd02770d80948186c583 Mon Sep 17 00:00:00 2001 From: Ben Black Date: Wed, 20 Jan 2021 08:19:06 -0700 Subject: [PATCH] added two player support to ALE --- src/ale_interface.cpp | 92 +++++++++++-- src/ale_interface.hpp | 15 ++- src/ale_python_interface.hpp | 25 ++-- src/environment/ale_state.cpp | 176 +++++++++++-------------- src/environment/ale_state.hpp | 28 ++-- src/environment/stella_environment.cpp | 130 +++++++++++------- src/environment/stella_environment.hpp | 9 +- src/games/RomSettings.cpp | 48 ++++++- src/games/RomSettings.hpp | 13 ++ src/games/RomSettings2P.hpp | 67 ++++++++++ src/games/RomSettings4P.hpp | 63 +++++++++ src/games/supported/Pong.cpp | 51 ++++--- src/games/supported/Pong.hpp | 12 +- tests/test_python_interface.py | 13 ++ 14 files changed, 536 insertions(+), 206 deletions(-) create mode 100644 src/games/RomSettings2P.hpp create mode 100644 src/games/RomSettings4P.hpp diff --git a/src/ale_interface.cpp b/src/ale_interface.cpp index 312047ee8..cd50062ea 100644 --- a/src/ale_interface.cpp +++ b/src/ale_interface.cpp @@ -266,12 +266,49 @@ void ALEInterface::reset_game() { environment->reset(); } // Indicates if the game has ended. bool ALEInterface::game_over() const { return environment->isTerminal(); } +// The remaining number of lives. Can only be called in one player modes +int ALEInterface::lives(){ + if (romSettings == nullptr) { + throw std::runtime_error("ROM not set"); + } + else { + if (numPlayersActive() == 1) { + return romSettings->lives(); + } + else { + throw std::runtime_error("called `lives` in a multiplayer mode. Call allLives() instead."); + } + } +} + // The remaining number of lives. -int ALEInterface::lives() { +std::vector ALEInterface::allLives() { if (romSettings == nullptr) { throw std::runtime_error("ROM not set"); } else { - return romSettings->lives(); + int num_players = this->numPlayersActive(); + if(num_players == 1) { + return { + romSettings->lives() + }; + } + else if (num_players == 2) { + return { + romSettings->lives(), + romSettings->livesP2() + }; + } + else if(num_players == 4) { + return { + romSettings->lives(), + romSettings->livesP2(), + romSettings->livesP3(), + romSettings->livesP4() + }; + } + else{ + throw std::runtime_error("ALE only support 1,2 and 4 players"); + } } } @@ -280,37 +317,76 @@ int ALEInterface::lives() { // when necessary - this method will keep pressing buttons on the // game over screen. reward_t ALEInterface::act(Action action) { - reward_t reward = environment->act(action, PLAYER_B_NOOP); + reward_t reward = environment->act({action})[0]; if (theOSystem->p_display_screen != NULL) { theOSystem->p_display_screen->display_screen(); while (theOSystem->p_display_screen->manual_control_engaged()) { Action user_action = theOSystem->p_display_screen->getUserAction(); - reward += environment->act(user_action, PLAYER_B_NOOP); + reward += environment->act({user_action})[0]; theOSystem->p_display_screen->display_screen(); } } return reward; } +// Takes a vector of actions, one for each player in the game mode +// Does not allow user input from the screen +std::vector ALEInterface::act(std::vector action) { + if (romSettings == nullptr) { + throw std::runtime_error("ROM not set"); + } + if (action.size() != numPlayersActive()) { + throw std::runtime_error("number of players active in the mode is not equal to the action size given to act"); + } + + return environment->act(action); +} + // Returns the vector of modes available for the current game. // This should be called only after the rom is loaded. -ModeVect ALEInterface::getAvailableModes() { - return romSettings->getAvailableModes(); +ModeVect ALEInterface::getAvailableModes(int num_players) { + if(num_players == 1){ + return romSettings->getAvailableModes(); + } + else if(num_players == 2){ + return romSettings->get2PlayerModes(); + } + else if(num_players == 3){ + return ModeVect{}; + } + else if(num_players == 4){ + return romSettings->get4PlayerModes(); + } + else { + throw std::runtime_error(std::to_string(num_players)+" is not a valid number of players, only 1-2 players allowed."); + } } // Sets the mode of the game. // The mode must be an available mode. // This should be called only after the rom is loaded. void ALEInterface::setMode(game_mode_t m) { - //We first need to make sure m is an available mode + // We first need to make sure m is an available mode ModeVect available = romSettings->getAvailableModes(); - if (find(available.begin(), available.end(), m) != available.end()) { + ModeVect available2P = romSettings->get2PlayerModes(); + ModeVect available4P = romSettings->get4PlayerModes(); + + available.insert(available.end(),available2P.begin(),available2P.end()); + available.insert(available.end(),available4P.begin(),available4P.end()); + + if (std::find(available.begin(), available.end(), m) != available.end()) { environment->setMode(m); } else { throw std::runtime_error("Invalid game mode requested"); } } +// Number of players active in the current game mode +// also the, number of actions expected by act +int ALEInterface::numPlayersActive() { + return environment->getState().getNumActivePlayers(); +} + //Returns the vector of difficulties available for the current game. //This should be called only after the rom is loaded. DifficultyVect ALEInterface::getAvailableDifficulties() { diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp index 344fc72e0..79ce40176 100644 --- a/src/ale_interface.hpp +++ b/src/ale_interface.hpp @@ -92,6 +92,10 @@ class ALEInterface { // when necessary - this method will keep pressing buttons on the // game over screen. reward_t act(Action action); + // multiplayer version of the act function. + // takes in one action per player and + // returns one reward per player + std::vector act(std::vector action); // Indicates if the game has ended. bool game_over() const; @@ -101,7 +105,7 @@ class ALEInterface { // Returns the vector of modes available for the current game. // This should be called only after the rom is loaded. - ModeVect getAvailableModes(); + ModeVect getAvailableModes(int num_players=1); // Sets the mode of the game. // The mode must be an available mode (otherwise it throws an exception). @@ -113,6 +117,10 @@ class ALEInterface { // game mode changes only take effect when the environment is reset. game_mode_t getMode() const { return environment->getMode(); } + // Number of players active in the current game mode + // also the, number of actions expected by act + int numPlayersActive(); + //Returns the vector of difficulties available for the current game. //This should be called only after the rom is loaded. Notice // that there are 2 levers, the right and left switches. They @@ -145,9 +153,12 @@ class ALEInterface { // Returns the frame number since the loading of the ROM int getFrameNumber(); - // The remaining number of lives. + // The remaining number of lives for player 1. int lives(); + // lives for all players + std::vector allLives(); + // Returns the frame number since the start of the current episode int getEpisodeFrameNumber() const; diff --git a/src/ale_python_interface.hpp b/src/ale_python_interface.hpp index 4e2cdbfd5..c731e6dcf 100644 --- a/src/ale_python_interface.hpp +++ b/src/ale_python_interface.hpp @@ -42,10 +42,6 @@ class ALEPythonInterface : public ALEInterface { py::array_t getScreenRGB(); py::array_t getScreenGrayscale(); - inline reward_t act(unsigned int action) { - return ALEInterface::act((Action)action); - } - inline py::tuple getScreenDims() { const ALEScreen& screen = ALEInterface::getScreen(); return py::make_tuple(screen.height(), screen.width()); @@ -58,6 +54,16 @@ class ALEPythonInterface : public ALEInterface { } // namespace ale + +inline std::vector convert(std::vector a){ + std::vector v(a.size()); + for (size_t i = 0; i < a.size(); i++) { + v[i] = (ale::Action)(a[i]); + } + return v; +} + + PYBIND11_MODULE(ale_py, m) { m.attr("__version__") = py::str(ALE_VERSION_STR); #ifdef __USE_SDL @@ -128,13 +134,15 @@ PYBIND11_MODULE(ale_py, m) { .def("setFloat", &ale::ALEPythonInterface::setFloat) .def("loadROM", &ale::ALEPythonInterface::loadROM) .def_static("isSupportedRom", &ale::ALEPythonInterface::isSupportedRom) - .def("act", (ale::reward_t(ale::ALEPythonInterface::*)(uint32_t)) & - ale::ALEPythonInterface::act) - .def("act", (ale::reward_t(ale::ALEInterface::*)(ale::Action)) & - ale::ALEInterface::act) + .def("act", [](ale::ALEPythonInterface & ale, uint32_t a){ return ale.act((ale::Action)(a)); }) + .def("act", [](ale::ALEPythonInterface & ale, ale::Action a){ return ale.act(a); }) + .def("act", [](ale::ALEPythonInterface & ale, std::vector a){ return ale.act(a); }) + .def("act", [](ale::ALEPythonInterface & ale, std::vector a){ return ale.act(convert(a)); }) .def("game_over", &ale::ALEPythonInterface::game_over) .def("reset_game", &ale::ALEPythonInterface::reset_game) + .def("numPlayersActive", &ale::ALEPythonInterface::numPlayersActive) .def("getAvailableModes", &ale::ALEPythonInterface::getAvailableModes) + .def("getAvailableModes", [](ale::ALEPythonInterface & ale){ return ale.getAvailableModes(); }) .def("setMode", &ale::ALEPythonInterface::setMode) .def("getAvailableDifficulties", &ale::ALEPythonInterface::getAvailableDifficulties) @@ -143,6 +151,7 @@ PYBIND11_MODULE(ale_py, m) { .def("getMinimalActionSet", &ale::ALEPythonInterface::getMinimalActionSet) .def("getFrameNumber", &ale::ALEPythonInterface::getFrameNumber) .def("lives", &ale::ALEPythonInterface::lives) + .def("allLives", &ale::ALEPythonInterface::allLives) .def("getEpisodeFrameNumber", &ale::ALEPythonInterface::getEpisodeFrameNumber) .def("getScreen", (void (ale::ALEPythonInterface::*)( diff --git a/src/environment/ale_state.cpp b/src/environment/ale_state.cpp index 088a0d341..486bb1a0f 100644 --- a/src/environment/ale_state.cpp +++ b/src/environment/ale_state.cpp @@ -28,34 +28,42 @@ namespace ale { /** Default constructor - loads settings from system */ ALEState::ALEState() - : m_left_paddle(PADDLE_DEFAULT_VALUE), - m_right_paddle(PADDLE_DEFAULT_VALUE), - m_paddle_min(PADDLE_MIN), + : m_paddle_min(PADDLE_MIN), m_paddle_max(PADDLE_MAX), m_frame_number(0), m_episode_frame_number(0), m_mode(0), - m_difficulty(0) {} + m_difficulty(0), + m_num_players(1) { + for (int i = 0; i < 4; i++) { + m_paddle[i] = PADDLE_DEFAULT_VALUE; + } + } ALEState::ALEState(const ALEState& rhs, const std::string& serialized) - : m_left_paddle(rhs.m_left_paddle), - m_right_paddle(rhs.m_right_paddle), - m_paddle_min(rhs.m_paddle_min), + : m_paddle_min(rhs.m_paddle_min), m_paddle_max(rhs.m_paddle_max), m_frame_number(rhs.m_frame_number), m_episode_frame_number(rhs.m_episode_frame_number), m_serialized_state(serialized), m_mode(rhs.m_mode), - m_difficulty(rhs.m_difficulty) {} + m_difficulty(rhs.m_difficulty), + m_num_players(rhs.m_num_players) { + for (int i = 0; i < 4; i++) { + m_paddle[i] = rhs.m_paddle[i]; + } + } ALEState::ALEState(const std::string& serialized) { Deserializer des(serialized); - this->m_left_paddle = des.getInt(); - this->m_right_paddle = des.getInt(); + for (int i = 0; i < 4; i++) { + this->m_paddle[i] = des.getInt(); + } this->m_frame_number = des.getInt(); this->m_episode_frame_number = des.getInt(); this->m_mode = des.getInt(); this->m_difficulty = des.getInt(); + this->m_num_players = des.getInt(); this->m_serialized_state = des.getString(); this->m_paddle_min = des.getInt(); this->m_paddle_max = des.getInt(); @@ -81,14 +89,16 @@ void ALEState::load(OSystem* osystem, RomSettings* settings, std::string md5, settings->loadState(deser); // Copy over other member variables - m_left_paddle = rhs.m_left_paddle; - m_right_paddle = rhs.m_right_paddle; + for (int i = 0; i < 4; i++) { + m_paddle[i] = rhs.m_paddle[i]; + } m_paddle_min = rhs.m_paddle_min; m_paddle_max = rhs.m_paddle_max; m_frame_number = rhs.m_frame_number; m_episode_frame_number = rhs.m_episode_frame_number; m_mode = rhs.m_mode; m_difficulty = rhs.m_difficulty; + m_num_players = rhs.m_num_players; } ALEState ALEState::save(OSystem* osystem, RomSettings* settings, @@ -118,12 +128,14 @@ void ALEState::resetEpisodeFrameNumber() { m_episode_frame_number = 0; } std::string ALEState::serialize() { Serializer ser; - ser.putInt(this->m_left_paddle); - ser.putInt(this->m_right_paddle); + for (int i = 0; i < 4; i++) { + ser.putInt(this->m_paddle[i]); + } ser.putInt(this->m_frame_number); ser.putInt(this->m_episode_frame_number); ser.putInt(this->m_mode); ser.putInt(this->m_difficulty); + ser.putInt(this->m_num_players); ser.putString(this->m_serialized_state); ser.putInt(this->m_paddle_min); ser.putInt(this->m_paddle_max); @@ -140,20 +152,25 @@ int ALEState::calcPaddleResistance(int x_val) { void ALEState::resetPaddles(Event* event) { int paddle_default = (m_paddle_min + m_paddle_max) / 2; - setPaddles(event, paddle_default, paddle_default); + for (int i = 0; i < 4; i++) { + setPaddle(event, paddle_default, i); + } } -void ALEState::setPaddles(Event* event, int left, int right) { - m_left_paddle = left; - m_right_paddle = right; +void ALEState::setPaddle(Event* event, int paddle_val, int paddle_num) { + m_paddle[paddle_num] = paddle_val; // Compute the "resistance" (this is for vestigal clarity) - int left_resistance = calcPaddleResistance(m_left_paddle); - int right_resistance = calcPaddleResistance(m_right_paddle); - + int resitance = calcPaddleResistance(paddle_val); + + Event::Type paddle_resists[] = { + Event::PaddleZeroResistance, + Event::PaddleOneResistance, + Event::PaddleTwoResistance, + Event::PaddleThreeResistance + }; // Update the events with the new resistances - event->set(Event::PaddleZeroResistance, left_resistance); - event->set(Event::PaddleOneResistance, right_resistance); + event->set(paddle_resists[paddle_num], resitance); } void ALEState::setPaddleLimits(int paddle_min_val, int paddle_max_val) { @@ -164,51 +181,37 @@ void ALEState::setPaddleLimits(int paddle_min_val, int paddle_max_val) { } /* ********************************************************************* - * Updates the positions of the paddles, and sets an event for - * updating the corresponding paddle's resistance + * Updates the positions of the paddle indicated by paddle_num, + * and sets an event for updating the corresponding paddle's resistance * ********************************************************************/ -void ALEState::updatePaddlePositions(Event* event, int delta_left, - int delta_right) { +void ALEState::updatePaddlePosition(Event* event, int delta, + int paddle_num) { // Cap paddle outputs - - m_left_paddle += delta_left; - if (m_left_paddle < m_paddle_min) { - m_left_paddle = m_paddle_min; + m_paddle[paddle_num] += delta; + if (m_paddle[paddle_num] < m_paddle_min) { + m_paddle[paddle_num] = m_paddle_min; } - if (m_left_paddle > m_paddle_max) { - m_left_paddle = m_paddle_max; - } - - m_right_paddle += delta_right; - if (m_right_paddle < m_paddle_min) { - m_right_paddle = m_paddle_min; - } - if (m_right_paddle > m_paddle_max) { - m_right_paddle = m_paddle_max; + if (m_paddle[paddle_num] > m_paddle_max) { + m_paddle[paddle_num] = m_paddle_max; } // Now set the paddle to their new value - setPaddles(event, m_left_paddle, m_right_paddle); + setPaddle(event, m_paddle[paddle_num], paddle_num); } -void ALEState::applyActionPaddles(Event* event, int player_a_action, - int player_b_action) { - // Reset keys - resetKeys(event); - +// Apply the action for the paddle given by pnum +void ALEState::applyActionPaddle(Event* event, int action, int pnum) { // First compute whether we should increase or decrease the paddle position - // (for both left and right players) - int delta_left; - int delta_right; + int delta; - switch (player_a_action) { + switch (action) { case PLAYER_A_RIGHT: case PLAYER_A_RIGHTFIRE: case PLAYER_A_UPRIGHT: case PLAYER_A_DOWNRIGHT: case PLAYER_A_UPRIGHTFIRE: case PLAYER_A_DOWNRIGHTFIRE: - delta_left = -PADDLE_DELTA; + delta = -PADDLE_DELTA; break; case PLAYER_A_LEFT: @@ -217,45 +220,28 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, case PLAYER_A_DOWNLEFT: case PLAYER_A_UPLEFTFIRE: case PLAYER_A_DOWNLEFTFIRE: - delta_left = PADDLE_DELTA; + delta = PADDLE_DELTA; break; default: - delta_left = 0; + delta = 0; break; } - switch (player_b_action) { - case PLAYER_B_RIGHT: - case PLAYER_B_RIGHTFIRE: - case PLAYER_B_UPRIGHT: - case PLAYER_B_DOWNRIGHT: - case PLAYER_B_UPRIGHTFIRE: - case PLAYER_B_DOWNRIGHTFIRE: - delta_right = -PADDLE_DELTA; - break; - - case PLAYER_B_LEFT: - case PLAYER_B_LEFTFIRE: - case PLAYER_B_UPLEFT: - case PLAYER_B_DOWNLEFT: - case PLAYER_B_UPLEFTFIRE: - case PLAYER_B_DOWNLEFTFIRE: - delta_right = PADDLE_DELTA; - break; - default: - delta_right = 0; - break; - } - - // Now update the paddle positions - updatePaddlePositions(event, delta_left, delta_right); + // Now update the paddle position + updatePaddlePosition(event, delta, pnum); // Handle reset - if (player_a_action == RESET || player_b_action == RESET) + if (action == RESET) event->set(Event::ConsoleReset, 1); + Event::Type paddle_fires[] = { + Event::PaddleZeroFire, + Event::PaddleOneFire, + Event::PaddleTwoFire, + Event::PaddleThreeFire + }; // Now add the fire event - switch (player_a_action) { + switch (action) { case PLAYER_A_FIRE: case PLAYER_A_UPFIRE: case PLAYER_A_RIGHTFIRE: @@ -265,24 +251,7 @@ void ALEState::applyActionPaddles(Event* event, int player_a_action, case PLAYER_A_UPLEFTFIRE: case PLAYER_A_DOWNRIGHTFIRE: case PLAYER_A_DOWNLEFTFIRE: - event->set(Event::PaddleZeroFire, 1); - break; - default: - // Nothing - break; - } - - switch (player_b_action) { - case PLAYER_B_FIRE: - case PLAYER_B_UPFIRE: - case PLAYER_B_RIGHTFIRE: - case PLAYER_B_LEFTFIRE: - case PLAYER_B_DOWNFIRE: - case PLAYER_B_UPRIGHTFIRE: - case PLAYER_B_UPLEFTFIRE: - case PLAYER_B_DOWNRIGHTFIRE: - case PLAYER_B_DOWNLEFTFIRE: - event->set(Event::PaddleOneFire, 1); + event->set(paddle_fires[pnum], 1); break; default: // Nothing @@ -522,6 +491,8 @@ void ALEState::resetKeys(Event* event) { // also reset paddle fire event->set(Event::PaddleZeroFire, 0); event->set(Event::PaddleOneFire, 0); + event->set(Event::PaddleTwoFire, 0); + event->set(Event::PaddleThreeFire, 0); // Set the difficulty switches accordingly for this time step. setDifficultySwitches(event, m_difficulty); @@ -529,11 +500,12 @@ void ALEState::resetKeys(Event* event) { bool ALEState::equals(ALEState& rhs) { return (rhs.m_serialized_state == this->m_serialized_state && - rhs.m_left_paddle == this->m_left_paddle && - rhs.m_right_paddle == this->m_right_paddle && + std::equal(rhs.m_paddle,rhs.m_paddle+4,this->m_paddle) && rhs.m_frame_number == this->m_frame_number && rhs.m_episode_frame_number == this->m_episode_frame_number && - rhs.m_mode == this->m_mode && rhs.m_difficulty == this->m_difficulty); + rhs.m_mode == this->m_mode && + rhs.m_difficulty == this->m_difficulty && + rhs.m_num_players == this->m_num_players); } } // namespace ale diff --git a/src/environment/ale_state.hpp b/src/environment/ale_state.hpp index 52b1147d4..0b7ab1052 100644 --- a/src/environment/ale_state.hpp +++ b/src/environment/ale_state.hpp @@ -57,10 +57,10 @@ class ALEState { //Apply the special select action void pressSelect(Event* event_obj); - /** Applies paddle actions. This actually modifies the game state by updating the paddle - * resistances. */ - void applyActionPaddles(Event* event_obj, int player_a_action, - int player_b_action); + /** Applies paddle action for paddle for player pnm. + * This actually modifies the game state by updating the paddle + * resistances. */ + void applyActionPaddle(Event* event_obj, int action, int pnum); /** Sets the joystick events. No effect until the emulator is run forward. */ void setActionJoysticks(Event* event_obj, int player_a_action, int player_b_action); @@ -90,8 +90,17 @@ class ALEState { //Get the current mode we are in. game_mode_t getCurrentMode() const { return m_mode; } + //Save the number of players in the current mode. + void setNumActivePlayers(int value) { m_num_players = value; } + + //Get the number of players in the current mode. + int getNumActivePlayers() const { return m_num_players; } + std::string serialize(); + /** Reset key presses */ + void resetKeys(Event* event_obj); + protected: // Let StellaEnvironment access these methods: they are needed for emulation purposes friend class StellaEnvironment; @@ -107,17 +116,14 @@ class ALEState { ALEState save(OSystem* osystem, RomSettings* settings, std::string md5, bool save_system); - /** Reset key presses */ - void resetKeys(Event* event_obj); - /** Sets the paddle to a given position */ - void setPaddles(Event* event_obj, int left, int right); + void setPaddle(Event* event_obj, int paddle_val, int paddle_num); /** Set the paddle min/max values */ void setPaddleLimits(int paddle_min_val, int paddle_max_val); /** Updates the paddle position by a delta amount. */ - void updatePaddlePositions(Event* event_obj, int delta_x, int delta_y); + void updatePaddlePosition(Event* event_obj, int delta, int paddle_num); /** Calculates the Paddle resistance, based on the given x val */ int calcPaddleResistance(int x_val); @@ -126,8 +132,7 @@ class ALEState { void setDifficultySwitches(Event* event_obj, unsigned int value); private: - int m_left_paddle; // Current value for the left-paddle - int m_right_paddle; // Current value for the right-paddle + int m_paddle[4]; // Current value for the paddles int m_paddle_min; // Minimum value for paddle int m_paddle_max; // Maximum value for paddle @@ -139,6 +144,7 @@ class ALEState { game_mode_t m_mode; // The current mode we are in difficulty_t m_difficulty; // The current difficulty we are in + int m_num_players; // The number of players in the current mode }; } // namespace ale diff --git a/src/environment/stella_environment.cpp b/src/environment/stella_environment.cpp index 8de7ade05..e48cf0c1a 100644 --- a/src/environment/stella_environment.cpp +++ b/src/environment/stella_environment.cpp @@ -15,7 +15,7 @@ * **************************************************************************** */ -#include "environment/stella_environment.hpp" +#include "stella_environment.hpp" #include @@ -29,8 +29,7 @@ StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings) m_phosphor_blend(osystem), m_screen(m_osystem->console().mediaSource().height(), m_osystem->console().mediaSource().width()), - m_player_a_action(PLAYER_A_NOOP), - m_player_b_action(PLAYER_B_NOOP) { + m_actions(4, PLAYER_A_NOOP) { // Determine whether this is a paddle-based game if (m_osystem->console().properties().get(Controller_Left) == "PADDLES" || m_osystem->console().properties().get(Controller_Right) == "PADDLES") { @@ -47,7 +46,7 @@ StellaEnvironment::StellaEnvironment(OSystem* osystem, RomSettings* settings) m_cartridge_md5 = m_osystem->console().properties().get(Cartridge_MD5); // Set current mode to the ROM's default mode - m_state.setCurrentMode(settings->getDefaultMode()); + setMode(settings->getDefaultMode()); m_max_num_frames_per_episode = m_osystem->settings().getInt("max_num_frames_per_episode"); @@ -101,7 +100,7 @@ void StellaEnvironment::reset() { // Apply necessary actions specified by the rom itself ActionVect startingActions = m_settings->getStartingActions(); for (size_t i = 0; i < startingActions.size(); i++) { - emulate(startingActions[i], PLAYER_B_NOOP); + emulate({startingActions[i], startingActions[i]}); } } @@ -137,50 +136,37 @@ void StellaEnvironment::restoreSystemState(const ALEState& target_state) { m_state.load(m_osystem, m_settings, m_cartridge_md5, target_state, true); } -void StellaEnvironment::noopIllegalActions(Action& player_a_action, - Action& player_b_action) { - if (player_a_action < (Action)PLAYER_B_NOOP && - !m_settings->isLegal(player_a_action)) { - player_a_action = (Action)PLAYER_A_NOOP; +void StellaEnvironment::noopIllegalAction(Action& action) { + if ((!m_settings->isLegal(action) && action < (Action)PLAYER_B_NOOP) || action == RESET) { + action = PLAYER_A_NOOP; } - // Also drop RESET, which doesn't play nice with our clean notions of RL environments - else if (player_a_action == RESET) - player_a_action = (Action)PLAYER_A_NOOP; - - if (player_b_action < (Action)RESET && - !m_settings->isLegal((Action)((int)player_b_action - PLAYER_B_NOOP))) { - player_b_action = (Action)PLAYER_B_NOOP; - } else if (player_b_action == RESET) - player_b_action = (Action)PLAYER_B_NOOP; } reward_t StellaEnvironment::act(Action player_a_action, Action player_b_action) { + return act(std::vector{player_a_action,(Action)(player_b_action - PLAYER_B_NOOP)}).at(0); +} + +std::vector StellaEnvironment::act(std::vector actions) { // Total reward received as we repeat the action - reward_t sum_rewards = 0; + std::vector sum_rewards(actions.size(),0); Random& rng = m_osystem->rng(); // Apply the same action for a given number of times... note that act() will refuse to emulate // past the terminal state - for (size_t i = 0; i < m_frame_skip; i++) { - // Stochastically drop actions, according to m_repeat_action_probability - if (rng.nextDouble() >= m_repeat_action_probability) - m_player_a_action = player_a_action; - // @todo Possibly optimize by avoiding call to rand() when player B is "off" ? - if (rng.nextDouble() >= m_repeat_action_probability) - m_player_b_action = player_b_action; - - // If so desired, request one frame's worth of sound (this does nothing if recording - // is not enabled) - m_osystem->sound().recordNextFrame(); - - // Similarly record screen as needed - if (m_screen_exporter.get() != NULL) - m_screen_exporter->saveNext(m_screen); - - // Use the stored actions, which may or may not have changed this frame - sum_rewards += oneStepAct(m_player_a_action, m_player_b_action); + for (size_t j = 0; j < m_frame_skip; j++) { + // Stochastically drop actions, according to mm_repeat_action_probability + for (size_t i = 0; i < 4; i++) { + if (i < actions.size()) { + if (rng.nextDouble() >= m_repeat_action_probability) + m_actions[i] = actions[i]; + } + else { + m_actions[i] = PLAYER_A_NOOP; + } + } + oneStepAct(m_actions, sum_rewards); } return sum_rewards; @@ -191,29 +177,48 @@ void StellaEnvironment::softReset() { emulate(RESET, PLAYER_B_NOOP, m_num_reset_steps); // Reset previous actions to NOOP for correct action repeating - m_player_a_action = PLAYER_A_NOOP; - m_player_b_action = PLAYER_B_NOOP; + for (Action & a : m_actions) { + a = PLAYER_A_NOOP; + } } /** Applies the given actions (e.g. updating paddle positions when the paddle is used) * and performs one simulation step in Stella. */ -reward_t StellaEnvironment::oneStepAct(Action player_a_action, - Action player_b_action) { +void StellaEnvironment::oneStepAct(std::vector actions,std::vector & rewards) { // Once in a terminal state, refuse to go any further (special actions must be handled // outside of this environment; in particular reset() should be called rather than passing // RESET or SYSTEM_RESET. if (isTerminal()) - return 0; + return; + + // If so desired, request one frame's worth of sound (this does nothing if recording + // is not enabled) + m_osystem->sound().recordNextFrame(); + + // Similarly record screen as needed + if (m_screen_exporter.get() != NULL) + m_screen_exporter->saveNext(m_screen); // Convert illegal actions into NOOPs; actions such as reset are always legal - noopIllegalActions(player_a_action, player_b_action); + for(Action & a : actions){ + noopIllegalAction(a); + } // Emulate in the emulator - emulate(player_a_action, player_b_action); + emulate(actions); // Increment the number of frames seen so far m_state.incrementFrame(); - return m_settings->getReward(); + rewards.at(0) += m_settings->getReward(); + if(rewards.size() > 1){ + rewards.at(1) += m_settings->getRewardP2(); + } + if(rewards.size() > 2){ + rewards.at(2) += m_settings->getRewardP3(); + } + if(rewards.size() > 3){ + rewards.at(3) += m_settings->getRewardP4(); + } } bool StellaEnvironment::isTerminal() const { @@ -237,26 +242,57 @@ void StellaEnvironment::setDifficulty(difficulty_t value) { m_state.setDifficulty(value); } +// helper function for setMode +bool in_modes(const ModeVect & modes, game_mode_t m){ + return std::find(modes.begin(), modes.end(), m) != modes.end(); +} + void StellaEnvironment::setMode(game_mode_t value) { + int num_players; + if (in_modes(m_settings->getAvailableModes(), value)) { + num_players = 1; + } + else if (in_modes(m_settings->get2PlayerModes(), value)) { + num_players = 2; + } + else if(in_modes(m_settings->get4PlayerModes(), value)){ + num_players = 4; + } + else { + throw std::runtime_error("Invalid game mode requested"); + } + m_state.setNumActivePlayers(num_players); m_state.setCurrentMode(value); } void StellaEnvironment::emulate(Action player_a_action, Action player_b_action, size_t num_steps) { + emulate({player_a_action,(Action)(player_b_action - PLAYER_B_NOOP)},num_steps); +} +void StellaEnvironment::emulate(std::vector actions, + size_t num_steps) { Event* event = m_osystem->event(); + for(Action a : actions){ + assert ((a < PLAYER_B_NOOP || a >= RESET) && "Actions in multiplayer cannot use the PLAYER_B actions. Rather, action lists should indicate the player by the position in the input vector"); + } // Handle paddles separately: we have to manually update the paddle positions at each step if (m_use_paddles) { // Run emulator forward for 'num_steps' for (size_t t = 0; t < num_steps; t++) { // Update paddle position at every step - m_state.applyActionPaddles(event, player_a_action, player_b_action); + m_state.resetKeys(event); + for (size_t p = 0; p < actions.size(); p++) { + m_state.applyActionPaddle(event, actions[p], p); + } m_osystem->console().mediaSource().update(); m_settings->step(m_osystem->console().system()); } } else { // In joystick mode we only need to set the action events once + Action player_b_action = actions.size() >= 2 ? (Action)(actions[1] + PLAYER_B_NOOP) : PLAYER_B_NOOP; + Action player_a_action = actions.at(0); m_state.setActionJoysticks(event, player_a_action, player_b_action); for (size_t t = 0; t < num_steps; t++) { diff --git a/src/environment/stella_environment.hpp b/src/environment/stella_environment.hpp index b7bbf801e..5d50bbfdf 100644 --- a/src/environment/stella_environment.hpp +++ b/src/environment/stella_environment.hpp @@ -66,6 +66,7 @@ class StellaEnvironment { * number plus the frame skip. */ reward_t act(Action player_a_action, Action player_b_action); + std::vector act(std::vector actions); /** This functions emulates a push on the reset button of the console */ void softReset(); @@ -114,15 +115,17 @@ class StellaEnvironment { private: /** This applies an action exactly one time step. Helper function to act(). */ - reward_t oneStepAct(Action player_a_action, Action player_b_action); + void oneStepAct(std::vector actions,std::vector & rewards); /** Actually emulates the emulator for a given number of steps. */ void emulate(Action player_a_action, Action player_b_action, size_t num_steps = 1); + void emulate(std::vector actions, size_t num_steps = 1); + /** Drops illegal actions, such as the fire button in skiing. Note that this is different * from the minimal set of actions. */ - void noopIllegalActions(Action& player_a_action, Action& player_b_action); + void noopIllegalAction(Action& action); /** Processes the current emulator screen and saves it in m_screen */ void processScreen(); @@ -152,7 +155,7 @@ class StellaEnvironment { std::unique_ptr m_screen_exporter; // Automatic screen recorder // The last actions taken by our players - Action m_player_a_action, m_player_b_action; + std::vector m_actions; }; } // namespace ale diff --git a/src/games/RomSettings.cpp b/src/games/RomSettings.cpp index d7f3fe0e0..c5da25eb5 100644 --- a/src/games/RomSettings.cpp +++ b/src/games/RomSettings.cpp @@ -67,10 +67,15 @@ game_mode_t RomSettings::getDefaultMode() { // By default, return the first available mode, or 0 if none are listed ModeVect available_modes = getAvailableModes(); if (available_modes.empty()) { - return 0; - } else { - return available_modes[0]; + available_modes = get2PlayerModes(); + if (available_modes.empty()) { + available_modes = get4PlayerModes(); + if (available_modes.empty()) { + return 0; + } + } } + return available_modes[0]; } DifficultyVect RomSettings::getAvailableDifficulties() { @@ -82,4 +87,41 @@ bool RomSettings::isModeSupported(game_mode_t m) { return std::find(modes.begin(), modes.end(), m) != modes.end(); } +reward_t RomSettings::getRewardP2() const { + return 0; +} + +int RomSettings::livesP2() { + throw std::logic_error("2 player method used for 1 player game"); +} + +ModeVect RomSettings::get2PlayerModes() { + // return no modes for 2 players by default + return ModeVect{}; +} + +reward_t RomSettings::getRewardP3() const { + throw std::logic_error("4 player method used game that does not support 4 players"); +} + +reward_t RomSettings::getRewardP4() const { + throw std::logic_error("4 player method used game that does not support 4 players"); +} + +int RomSettings::livesP3() { + throw std::logic_error("4 player method used game that does not support 4 players"); +} + +int RomSettings::livesP4() { + throw std::logic_error("4 player method used game that does not support 4 players"); +} + +ModeVect RomSettings::get3PlayerModes() { + return ModeVect{}; +} + +ModeVect RomSettings::get4PlayerModes() { + return ModeVect{}; +} + } // namespace ale diff --git a/src/games/RomSettings.hpp b/src/games/RomSettings.hpp index c35894dc7..281469da7 100644 --- a/src/games/RomSettings.hpp +++ b/src/games/RomSettings.hpp @@ -125,6 +125,19 @@ class RomSettings { // By default, there is only one available difficulty. virtual DifficultyVect getAvailableDifficulties(); + //two player methods. all fail when on a single player game + virtual reward_t getRewardP2() const; + virtual int livesP2(); + virtual ModeVect get2PlayerModes(); + + // methods for 4 player games. all raise an error when called by default + virtual reward_t getRewardP3() const; + virtual reward_t getRewardP4() const; + virtual int livesP3(); + virtual int livesP4(); + virtual ModeVect get3PlayerModes(); + virtual ModeVect get4PlayerModes(); + protected: // Helper function that checks if our settings support this given mode. bool isModeSupported(game_mode_t m); diff --git a/src/games/RomSettings2P.hpp b/src/games/RomSettings2P.hpp new file mode 100644 index 000000000..d7688a14f --- /dev/null +++ b/src/games/RomSettings2P.hpp @@ -0,0 +1,67 @@ +/* ***************************************************************************** + * The line 78 is based on Xitari's code, from Google Inc. + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + * ***************************************************************************** + * A.L.E (Arcade Learning Environment) + * Copyright (c) 2009-2013 by Yavar Naddaf, Joel Veness, Marc G. Bellemare and + * the Reinforcement Learning and Artificial Intelligence Laboratory + * Released under the GNU General Public License; see License.txt for details. + * + * Based on: Stella -- "An Atari 2600 VCS Emulator" + * Copyright (c) 1995-2007 by Bradford W. Mott and the Stella team + * + * ***************************************************************************** + * + * RomSettings2P.hpp + * + * The interface to describe two player games as RL environments. It provides terminal + * and reward information. + * ***************************************************************************** + */ + +#ifndef __ROMSETTINGS2P_HPP__ +#define __ROMSETTINGS2P_HPP__ + +#include "RomSettings.hpp" + + +namespace ale { + +// rom support interface +class RomSettings2P : public RomSettings { + public: + RomSettings2P() {} + + virtual ~RomSettings2P() {} + + // get the most recently observed reward for player 2 + virtual reward_t getRewardP2() const = 0; + + // Remaining lives. + virtual int livesP2() { + return isTerminal() ? 0 : 1; + } + + // Returns a list of mode that the game can be played with two players. + // note that this list should be disjoint from getAvailableModes + virtual ModeVect get2PlayerModes() = 0; + + // there must be a setMode method implemented for two player games + virtual void setMode(game_mode_t m, System&, std::unique_ptr) = 0; +}; + +} // namespace ale + +#endif // __ROMSETTINGS2P_HPP__ diff --git a/src/games/RomSettings4P.hpp b/src/games/RomSettings4P.hpp new file mode 100644 index 000000000..8010c6978 --- /dev/null +++ b/src/games/RomSettings4P.hpp @@ -0,0 +1,63 @@ +/* ***************************************************************************** + * The line 78 is based on Xitari's code, from Google Inc. + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + * ***************************************************************************** + * A.L.E (Arcade Learning Environment) + * Copyright (c) 2009-2013 by Yavar Naddaf, Joel Veness, Marc G. Bellemare and + * the Reinforcement Learning and Artificial Intelligence Laboratory + * Released under the GNU General Public License; see License.txt for details. + * + * Based on: Stella -- "An Atari 2600 VCS Emulator" + * Copyright (c) 1995-2007 by Bradford W. Mott and the Stella team + * + * ***************************************************************************** + * + * RomSettings4P.hpp + * + * The interface to describe games as RL environments. It provides terminal and + * reward information. + * ***************************************************************************** + */ + +#ifndef __ROMSETTINGS4P_HPP__ +#define __ROMSETTINGS4P_HPP__ + +#include "RomSettings2P.hpp" + + +namespace ale { + +// rom support interface +class RomSettings4P : public RomSettings2P { + public: + RomSettings4P() {} + + virtual ~RomSettings4P() {} + + // gets reward for players 3 and 4 + virtual reward_t getRewardP3() const = 0; + virtual reward_t getRewardP4() const = 0; + // gets lives left for players 3 and 4 + virtual int livesP3() { return isTerminal() ? 0 : 1; } + virtual int livesP4() { return isTerminal() ? 0 : 1; } + // gets list of avaliable modes for player 4 + virtual ModeVect get4PlayerModes() = 0; + // enforces that 4 player games must override the setMode function + virtual void setMode(game_mode_t m, System&, std::unique_ptr) = 0; +}; + +} // namespace ale + +#endif // __ROMSETTINGS4P_HPP__ diff --git a/src/games/supported/Pong.cpp b/src/games/supported/Pong.cpp index d4ae83680..09f1d1432 100644 --- a/src/games/supported/Pong.cpp +++ b/src/games/supported/Pong.cpp @@ -42,6 +42,11 @@ bool PongSettings::isTerminal() const { return m_terminal; }; /* get the most recently observed reward */ reward_t PongSettings::getReward() const { return m_reward; } +reward_t PongSettings::getRewardP2() const { return -m_reward; } +//P3 is on same team as P1, P2-P4 +reward_t PongSettings::getRewardP3() const { return m_reward; }; +reward_t PongSettings::getRewardP4() const { return -m_reward; }; + /* is an action part of the minimal set? */ bool PongSettings::isMinimal(const Action& a) const { @@ -81,11 +86,28 @@ void PongSettings::loadState(Deserializer& ser) { // returns a list of mode that the game can be played in ModeVect PongSettings::getAvailableModes() { - ModeVect modes(getNumModes()); - for (unsigned int i = 0; i < modes.size(); i++) { - modes[i] = i; - } - return modes; + return {1, 2}; +} +ModeVect PongSettings::get2PlayerModes() { + return {3, 4, + 9, 10, + 13,14, + 19,20, + 23,24,25,26,27,28, + 35,36, + 39,40, + 43,44,45,46}; +} +ModeVect PongSettings::get4PlayerModes() { + return {5,6,7,8, + 11,12, + 15,16,17,18, + 21,22, + 29,30,31,32, + 33,34, + 37,38, + 41,42, + 47,48,49,50}; } // set the mode of the game @@ -93,19 +115,14 @@ ModeVect PongSettings::getAvailableModes() { void PongSettings::setMode( game_mode_t m, System& system, std::unique_ptr environment) { - if (m < getNumModes()) { - // read the mode we are currently in - unsigned char mode = readRam(&system, 0x96); - // press select until the correct mode is reached - while (mode != m) { - environment->pressSelect(2); - mode = readRam(&system, 0x96); - } - //reset the environment to apply changes. - environment->softReset(); - } else { - throw std::runtime_error("This mode doesn't currently exist for this game"); + game_mode_t target = m - 1; + + // press select until the correct mode is reached + while (readRam(&system, 0x96) != target) { + environment->pressSelect(2); } + //reset the environment to apply changes. + environment->softReset(); } // The left difficulty switch sets the width of the CPU opponent's bat. diff --git a/src/games/supported/Pong.hpp b/src/games/supported/Pong.hpp index 8c5e485a9..e7e41ff61 100644 --- a/src/games/supported/Pong.hpp +++ b/src/games/supported/Pong.hpp @@ -28,12 +28,12 @@ #ifndef __PONG_HPP__ #define __PONG_HPP__ -#include "games/RomSettings.hpp" +#include "../RomSettings4P.hpp" namespace ale { /* RL wrapper for Pong */ -class PongSettings : public RomSettings { +class PongSettings : public RomSettings4P { public: PongSettings(); @@ -45,6 +45,9 @@ class PongSettings : public RomSettings { // get the most recently observed reward reward_t getReward() const override; + reward_t getRewardP2() const override; + reward_t getRewardP3() const override; + reward_t getRewardP4() const override; // the rom-name const char* rom() const override { return "pong"; } @@ -52,9 +55,6 @@ class PongSettings : public RomSettings { // The md5 checksum of the ROM that this game supports const char* md5() const override { return "60e0ea3cbe0913d39803477945e9e5ec"; } - // get the available number of modes - unsigned int getNumModes() const { return 2; } - // create a new instance of the rom RomSettings* clone() const override; @@ -79,6 +79,8 @@ class PongSettings : public RomSettings { // returns a list of mode that the game can be played in // in this game, there are 2 available modes ModeVect getAvailableModes() override; + ModeVect get2PlayerModes() override; + ModeVect get4PlayerModes() override; // set the mode of the game // the given mode must be one returned by the previous function diff --git a/tests/test_python_interface.py b/tests/test_python_interface.py index fe94fb336..de71e325e 100644 --- a/tests/test_python_interface.py +++ b/tests/test_python_interface.py @@ -46,10 +46,16 @@ def test_int_config(tetris): assert value == 10 +def test_num_active_players(tetris): + assert tetris.numPlayersActive() == 1 + + def test_act(tetris): enum = tetris.getLegalActionSet() tetris.act(enum[0]) # NOOP tetris.act(0) # integer instead of enum + rew = tetris.act([0]) # list + assert len(rew) == 1 def test_game_over(tetris): @@ -97,6 +103,9 @@ def test_get_minimal_action_set(tetris): def test_get_available_modes(tetris): modes = tetris.getAvailableModes() assert len(modes) == 1 and modes[0] == 0 + assert tetris.getAvailableModes(1) == tetris.getAvailableModes() + assert tetris.getAvailableModes(2) == [] + assert tetris.getAvailableModes(4) == [] def test_set_mode(tetris): @@ -134,6 +143,10 @@ def test_lives(tetris): assert tetris.lives() == 0 +def test_allLives(tetris): + assert tetris.allLives() == [0] + + def test_get_episode_frame_number(tetris): tetris.setInt("frame_skip", 1) for _ in range(10):