Skip to content

Commit 05a20b3

Browse files
committed
Merge remote-tracking branch 'upstream/facebook_opengo_nets' into next_lizzie
2 parents 91b1f43 + 12dc829 commit 05a20b3

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

src/Network.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ static std::array<float, 256> ip1_val_b;
8888

8989
static std::array<float, 256> ip2_val_w;
9090
static std::array<float, 1> ip2_val_b;
91+
static bool value_head_not_stm;
9192

9293
// Symmetry helper
9394
static std::array<std::array<int, BOARD_SQUARES>, 8> symmetry_nn_idx_table;
@@ -194,8 +195,12 @@ std::vector<float> Network::zeropad_U(const std::vector<float>& U,
194195
std::pair<int, int> Network::load_v1_network(std::istream& wtfile) {
195196
// Count size of the network
196197
myprintf("Detecting residual layers...");
197-
// We are version 1
198-
myprintf("v%d...", 1);
198+
// We are version 1 or 2
199+
if (value_head_not_stm) {
200+
myprintf("v%d...", 2);
201+
} else {
202+
myprintf("v%d...", 1);
203+
}
199204
// First line was the version number
200205
auto linecount = size_t{1};
201206
auto channels = 0;
@@ -325,11 +330,18 @@ std::pair<int, int> Network::load_network_file(const std::string& filename) {
325330
auto iss = std::stringstream{line};
326331
// First line is the file format version id
327332
iss >> format_version;
328-
if (iss.fail() || format_version != FORMAT_VERSION) {
333+
if (iss.fail() || (format_version != 1 && format_version != 2)) {
329334
myprintf("Weights file is the wrong version.\n");
330335
return {0, 0};
331336
} else {
332-
assert(format_version == FORMAT_VERSION);
337+
// Version 2 networks are identical to v1, except
338+
// that they return the score for black instead of
339+
// the player to move. This is used by ELF Open Go.
340+
if (format_version == 2) {
341+
value_head_not_stm = true;
342+
} else {
343+
value_head_not_stm = false;
344+
}
333345
return load_v1_network(buffer);
334346
}
335347
}
@@ -886,6 +898,13 @@ Network::Netresult Network::get_scored_moves(
886898
result = get_scored_moves_internal(planes, rand_sym);
887899
}
888900

901+
// v2 format (ELF Open Go) returns black value, not stm
902+
if (value_head_not_stm) {
903+
if (state->board.get_to_move() == FastBoard::WHITE) {
904+
result.winrate = 1.0f - result.winrate;
905+
}
906+
}
907+
889908
// Insert result into cache.
890909
NNCache::get_NNCache().insert(state->board.get_hash(), result);
891910

src/Network.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class Network {
5959
const int symmetry = -1,
6060
const bool skip_cache = false);
6161
// File format version
62-
static constexpr auto FORMAT_VERSION = 1;
6362
static constexpr auto INPUT_MOVES = 8;
6463
static constexpr auto INPUT_CHANNELS = 2 * INPUT_MOVES + 2;
6564
static constexpr auto OUTPUTS_POLICY = 2;

training/elf/elf_convert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def write_block(f, b):
3939
print(key, state[key].shape)
4040

4141
with open('elf_converted_weights.txt', 'w') as f:
42-
f.write('1\n')
42+
# version 2 means value head is for black, not for side to move
43+
f.write('2\n')
4344
b = convert_block(state, 'init_conv')
4445

4546
# Permutate input planes

0 commit comments

Comments
 (0)