Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/facebook_opengo_nets' into nex…
Browse files Browse the repository at this point in the history
…t_lizzie
  • Loading branch information
Ka-zam committed May 3, 2018
2 parents 91b1f43 + 12dc829 commit 05a20b3
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
27 changes: 23 additions & 4 deletions src/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ static std::array<float, 256> ip1_val_b;

static std::array<float, 256> ip2_val_w;
static std::array<float, 1> ip2_val_b;
static bool value_head_not_stm;

// Symmetry helper
static std::array<std::array<int, BOARD_SQUARES>, 8> symmetry_nn_idx_table;
Expand Down Expand Up @@ -194,8 +195,12 @@ std::vector<float> Network::zeropad_U(const std::vector<float>& U,
std::pair<int, int> Network::load_v1_network(std::istream& wtfile) {
// Count size of the network
myprintf("Detecting residual layers...");
// We are version 1
myprintf("v%d...", 1);
// We are version 1 or 2
if (value_head_not_stm) {
myprintf("v%d...", 2);
} else {
myprintf("v%d...", 1);
}
// First line was the version number
auto linecount = size_t{1};
auto channels = 0;
Expand Down Expand Up @@ -325,11 +330,18 @@ std::pair<int, int> Network::load_network_file(const std::string& filename) {
auto iss = std::stringstream{line};
// First line is the file format version id
iss >> format_version;
if (iss.fail() || format_version != FORMAT_VERSION) {
if (iss.fail() || (format_version != 1 && format_version != 2)) {
myprintf("Weights file is the wrong version.\n");
return {0, 0};
} else {
assert(format_version == FORMAT_VERSION);
// Version 2 networks are identical to v1, except
// that they return the score for black instead of
// the player to move. This is used by ELF Open Go.
if (format_version == 2) {
value_head_not_stm = true;
} else {
value_head_not_stm = false;
}
return load_v1_network(buffer);
}
}
Expand Down Expand Up @@ -886,6 +898,13 @@ Network::Netresult Network::get_scored_moves(
result = get_scored_moves_internal(planes, rand_sym);
}

// v2 format (ELF Open Go) returns black value, not stm
if (value_head_not_stm) {
if (state->board.get_to_move() == FastBoard::WHITE) {
result.winrate = 1.0f - result.winrate;
}
}

// Insert result into cache.
NNCache::get_NNCache().insert(state->board.get_hash(), result);

Expand Down
1 change: 0 additions & 1 deletion src/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class Network {
const int symmetry = -1,
const bool skip_cache = false);
// File format version
static constexpr auto FORMAT_VERSION = 1;
static constexpr auto INPUT_MOVES = 8;
static constexpr auto INPUT_CHANNELS = 2 * INPUT_MOVES + 2;
static constexpr auto OUTPUTS_POLICY = 2;
Expand Down
3 changes: 2 additions & 1 deletion training/elf/elf_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def write_block(f, b):
print(key, state[key].shape)

with open('elf_converted_weights.txt', 'w') as f:
f.write('1\n')
# version 2 means value head is for black, not for side to move
f.write('2\n')
b = convert_block(state, 'init_conv')

# Permutate input planes
Expand Down

0 comments on commit 05a20b3

Please sign in to comment.