Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor frame processing & add frame maxing #495

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 43 additions & 25 deletions src/ale_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
#include "emucore/Console.hxx"
#include "emucore/Props.hxx"
#include "emucore/MD5.hxx"
#include "environment/ale_screen.hpp"
#include "environment/phosphor_blend.hpp"
#include "environment/frame_max.hpp"
#include "environment/frame_identity.hpp"
#include "games/RomSettings.hpp"

namespace fs = std::filesystem;
Expand Down Expand Up @@ -160,6 +162,25 @@ void ALEInterface::loadROM(fs::path rom_file) {
// before the StellaEnvironment is constructed.
romSettings->modifyEnvironmentSettings(theOSystem->settings());

bool shouldFrameAverage = theOSystem->settings().getBool("frame_average");
if (theOSystem->settings().getBool("color_averaging")) {
Logger::Warning << "color_averaging is deprecated. Please use frame_average instead." << std::endl;
shouldFrameAverage = true;
}
bool shouldFrameMax = theOSystem->settings().getBool("frame_max");

if (shouldFrameAverage && shouldFrameMax) {
throw new std::runtime_error("Cannot enable both frame averaging and frame maxing.");
}

if (shouldFrameAverage) {
frameProcessor = std::make_unique<PhosphorBlend>(theOSystem->colourPalette());
} else if (shouldFrameMax) {
frameProcessor = std::make_unique<FrameMax>(theOSystem->colourPalette());
} else {
frameProcessor = std::make_unique<FrameIdentity>(theOSystem->colourPalette());
}

environment.reset(new StellaEnvironment(theOSystem.get(), romSettings.get()));
max_num_frames = theOSystem->settings().getInt("max_num_frames_per_episode");
environment->reset();
Expand Down Expand Up @@ -328,41 +349,35 @@ int ALEInterface::getEpisodeFrameNumber() const {
return environment->getEpisodeFrameNumber();
}

// Returns the current game screen
const ALEScreen& ALEInterface::getScreen() const { return environment->getScreen(); }

//This method should receive an empty vector to fill it with
//the grayscale colours
void ALEInterface::getScreenGrayscale(
std::vector<unsigned char>& grayscale_output_buffer) const {
size_t w = environment->getScreen().width();
size_t h = environment->getScreen().height();
size_t screen_size = w * h;

pixel_t* ale_screen_data = environment->getScreen().getArray();
theOSystem->colourPalette().applyPaletteGrayscale(
grayscale_output_buffer, ale_screen_data, screen_size);
frameProcessor->processGrayscale(
theOSystem->console().mediaSource(),
grayscale_output_buffer.data()
);
}

//This method should receive a vector to fill it with
//the RGB colours. The first positions contain the red colours,
//followed by the green colours and then the blue colours
void ALEInterface::getScreenRGB(std::vector<unsigned char>& output_rgb_buffer) const {
size_t w = environment->getScreen().width();
size_t h = environment->getScreen().height();
size_t screen_size = w * h;

pixel_t* ale_screen_data = environment->getScreen().getArray();

theOSystem->colourPalette().applyPaletteRGB(output_rgb_buffer,
ale_screen_data, screen_size);
void ALEInterface::getScreenRGB(std::vector<unsigned char>& rgb_output_buffer) const {
frameProcessor->processRGB(
theOSystem->console().mediaSource(),
rgb_output_buffer.data()
);
}

// Returns the current RAM content
const ALERAM& ALEInterface::getRAM() const { return environment->getRAM(); }
void ALEInterface::getRAM(std::array<uint8_t, 128>& output_ram_buffer) const {
for (size_t i = 0; i < 128; ++i) {
output_ram_buffer[i] = theOSystem->console().system().peek(i + 0x80);
}
}

// Set byte at memory address
void ALEInterface::setRAM(size_t memory_index, byte_t value) {
void ALEInterface::setRAM(size_t memory_index, uint8_t value) {
if (memory_index < 0 || memory_index >= 128){
throw std::runtime_error("setRAM index out of bounds.");
}
Expand All @@ -386,13 +401,16 @@ void ALEInterface::restoreSystemState(const ALEState& state) {
}

void ALEInterface::saveScreenPNG(const std::string& filename) {
ScreenExporter exporter(theOSystem->colourPalette());
exporter.save(environment->getScreen(), filename);
ScreenExporter exporter(theOSystem->console().mediaSource(), theOSystem->colourPalette());
exporter.save(filename);
}

ScreenExporter*
ALEInterface::createScreenExporter(const std::string& filename) const {
return new ScreenExporter(theOSystem->colourPalette(), filename);
return new ScreenExporter(
theOSystem->console().mediaSource(), theOSystem->colourPalette(),
filename
);
}

} // namespace ale
16 changes: 9 additions & 7 deletions src/ale_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@
#include "emucore/OSystem.hxx"
#include "games/Roms.hpp"
#include "environment/stella_environment.hpp"
#include "environment/frame_processor.hpp"
#include "common/ScreenExporter.hpp"
#include "common/Log.hpp"
#include "version.hpp"

#include <cstdint>
#include <vector>
#include <array>
#include <string>
#include <optional>
#include <memory>
Expand Down Expand Up @@ -149,24 +153,21 @@ class ALEInterface {
// Returns the frame number since the start of the current episode
int getEpisodeFrameNumber() const;

// Returns the current game screen
const ALEScreen& getScreen() const;

//This method should receive an empty vector to fill it with
//the grayscale colours
void getScreenGrayscale(std::vector<unsigned char>& grayscale_output_buffer) const;
void getScreenGrayscale(std::vector<uint8_t>& grayscale_output_buffer) const;

//This method should receive a vector to fill it with
//the RGB colours. The first positions contain the red colours,
//followed by the green colours and then the blue colours
void getScreenRGB(std::vector<unsigned char>& output_rgb_buffer) const;
void getScreenRGB(std::vector<uint8_t>& output_rgb_buffer) const;

// Returns the current RAM content
const ALERAM& getRAM() const;
void getRAM(std::array<uint8_t, 128>& output_ram_buffer) const;

// Set byte at memory address. This can be useful to change the environment
// for example if you were trying to learn a causal model of RAM locations.
void setRAM(size_t memory_index, byte_t value);
void setRAM(size_t memory_index, uint8_t value);

// This makes a copy of the environment state. By defualt this copy does *not* include pseudorandomness
// making it suitable for planning purposes. If `include_prng` is set to true, then the
Expand Down Expand Up @@ -200,6 +201,7 @@ class ALEInterface {
std::unique_ptr<stella::Settings> theSettings;
std::unique_ptr<RomSettings> romSettings;
std::unique_ptr<StellaEnvironment> environment;
std::unique_ptr<FrameProcessor> frameProcessor;
int max_num_frames; // Maximum number of frames for each episode

public:
Expand Down
2 changes: 2 additions & 0 deletions src/common/ColourPalette.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ uint8_t ColourPalette::getGrayscale(int val) const {

uint32_t ColourPalette::getRGB(int val) const { return m_palette[val]; }

uint8_t ColourPalette::convertRGBToGrayscale(uint32_t rgb) const { return (uint8_t) convertGrayscale(rgb); }

void ColourPalette::applyPaletteRGB(uint8_t* dst_buffer, uint8_t* src_buffer,
std::size_t src_size) {
uint8_t* p = src_buffer;
Expand Down
2 changes: 2 additions & 0 deletions src/common/ColourPalette.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class ColourPalette {
/** Returns the byte-sized grayscale value for this palette index. */
uint8_t getGrayscale(int val) const;

uint8_t convertRGBToGrayscale(uint32_t rgb) const;

/** Applies the current RGB palette to the src_buffer and returns the results in dst_buffer
* For each byte in src_buffer, three bytes are returned in dst_buffer
* 8 bits => 24 bits
Expand Down
34 changes: 17 additions & 17 deletions src/common/ScreenExporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ static void writePNGChunk(std::ofstream& out, const char* type, uint8_t* data,
out.write((const char*)temp, 4);
}

static void writePNGHeader(std::ofstream& out, const ALEScreen& screen,
static void writePNGHeader(std::ofstream& out, const stella::MediaSource& media,
bool doubleWidth = true) {
int width = doubleWidth ? screen.width() * 2 : screen.width();
int height = screen.height();
int width = doubleWidth ? media.width() * 2 : media.width();
int height = media.height();
// PNG file header
uint8_t header[8] = {137, 80, 78, 71, 13, 10, 26, 10};
out.write((const char*)header, sizeof(header));
Expand All @@ -88,27 +88,28 @@ static void writePNGHeader(std::ofstream& out, const ALEScreen& screen,
writePNGChunk(out, "IHDR", ihdr, sizeof(ihdr));
}

static void writePNGData(std::ofstream& out, const ALEScreen& screen,
static void writePNGData(std::ofstream& out, const stella::MediaSource& media,
const ColourPalette& palette,
bool doubleWidth = true) {
int dataWidth = screen.width();
int dataWidth = media.width();
int width = doubleWidth ? dataWidth * 2 : dataWidth;
int height = screen.height();
int height = media.height();

// If so desired, double the width

// Fill the buffer with scanline data
int rowbytes = width * 3;

std::vector<uint8_t> buffer((rowbytes + 1) * height, 0);
uint8_t* currentFrameBuffer = media.currentFrameBuffer();
uint8_t* buf_ptr = &buffer[0];

for (int i = 0; i < height; i++) {
*buf_ptr++ = 0; // first byte of row is filter type
for (int j = 0; j < dataWidth; j++) {
int r, g, b;

palette.getRGB(screen.getArray()[i * dataWidth + j], r, g, b);
palette.getRGB(currentFrameBuffer[i * dataWidth + j], r, g, b);
// Double the pixel width, if so desired
int jj = doubleWidth ? 2 * j : j;

Expand Down Expand Up @@ -147,15 +148,14 @@ static void writePNGEnd(std::ofstream& out) {
writePNGChunk(out, "IEND", 0, 0);
}

ScreenExporter::ScreenExporter(ColourPalette& palette)
: m_palette(palette), m_frame_number(0), m_frame_field_width(6) {}
ScreenExporter::ScreenExporter(stella::MediaSource& media, ColourPalette& palette)
: m_media(media), m_palette(palette), m_frame_number(0), m_frame_field_width(6) {}

ScreenExporter::ScreenExporter(ColourPalette& palette, const std::string& path)
: m_palette(palette), m_frame_number(0), m_frame_field_width(6),
ScreenExporter::ScreenExporter(stella::MediaSource& media, ColourPalette& palette, const std::string& path)
: m_media(media), m_palette(palette), m_frame_number(0), m_frame_field_width(6),
m_path(path) {}

void ScreenExporter::save(const ALEScreen& screen,
const std::string& filename) const {
void ScreenExporter::save(const std::string& filename) const {
// Open file for writing
std::ofstream out(filename.c_str(), std::ios_base::binary);
if (!out.good()) {
Expand All @@ -165,14 +165,14 @@ void ScreenExporter::save(const ALEScreen& screen,
}

// Now write the PNG proper
writePNGHeader(out, screen, true);
writePNGData(out, screen, m_palette, true);
writePNGHeader(out, m_media, true);
writePNGData(out, m_media, m_palette, true);
writePNGEnd(out);

out.close();
}

void ScreenExporter::saveNext(const ALEScreen& screen) {
void ScreenExporter::saveNext() {
// Must have specified a directory.
assert(!m_path.empty());

Expand All @@ -185,7 +185,7 @@ void ScreenExporter::saveNext(const ALEScreen& screen) {
<< m_frame_number << ".png";

// Save the png
save(screen, oss.str());
save(oss.str());

m_frame_number++;
}
Expand Down
11 changes: 6 additions & 5 deletions src/common/ScreenExporter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,27 @@

#include "common/Constants.h"
#include "common/ColourPalette.hpp"
#include "environment/ale_screen.hpp"
#include "emucore/MediaSrc.hxx"

namespace ale {

class ScreenExporter {
public:
/** Creates a new ScreenExporter which can be used to save screens using save(filename). */
ScreenExporter(ColourPalette& palette);
ScreenExporter(stella::MediaSource& media, ColourPalette& palette);

/** Creates a new ScreenExporter which will save frames successively in the directory provided.
* Frames are sequentially named with 6 digits, starting at 000000. */
ScreenExporter(ColourPalette& palette, const std::string& path);
ScreenExporter(stella::MediaSource& media, ColourPalette& palette, const std::string& path);

/** Save the given screen to the given filename. No paths are created. */
void save(const ALEScreen& screen, const std::string& filename) const;
void save(const std::string& filename) const;

/** Save the given screen according to our own internal numbering. */
void saveNext(const ALEScreen& screen);
void saveNext();

private:
stella::MediaSource& m_media;
ColourPalette& m_palette;

/** The next frame number. */
Expand Down
2 changes: 2 additions & 0 deletions src/emucore/Settings.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ void Settings::setDefaultSettings() {
boolSettings.insert(std::pair<std::string, bool>("restricted_action_set", false));
intSettings.insert(std::pair<std::string, int>("random_seed", -1));
boolSettings.insert(std::pair<std::string, bool>("color_averaging", false));
boolSettings.insert(std::pair<std::string, bool>("frame_average", false));
boolSettings.insert(std::pair<std::string, bool>("frame_max", false));
boolSettings.insert(std::pair<std::string, bool>("send_rgb", false));
intSettings.insert(std::pair<std::string, int>("frame_skip", 1));
floatSettings.insert(std::pair<std::string, float>("repeat_action_probability", 0.25));
Expand Down
65 changes: 0 additions & 65 deletions src/environment/ale_ram.hpp

This file was deleted.

Loading
Loading