@@ -88,6 +88,7 @@ static std::array<float, 256> ip1_val_b;
88
88
89
89
static std::array<float , 256 > ip2_val_w;
90
90
static std::array<float , 1 > ip2_val_b;
91
+ static bool value_head_not_stm;
91
92
92
93
// Symmetry helper
93
94
static std::array<std::array<int , BOARD_SQUARES>, 8 > symmetry_nn_idx_table;
@@ -194,8 +195,12 @@ std::vector<float> Network::zeropad_U(const std::vector<float>& U,
194
195
std::pair<int , int > Network::load_v1_network (std::istream& wtfile) {
195
196
// Count size of the network
196
197
myprintf (" Detecting residual layers..." );
197
- // We are version 1
198
- myprintf (" v%d..." , 1 );
198
+ // We are version 1 or 2
199
+ if (value_head_not_stm) {
200
+ myprintf (" v%d..." , 2 );
201
+ } else {
202
+ myprintf (" v%d..." , 1 );
203
+ }
199
204
// First line was the version number
200
205
auto linecount = size_t {1 };
201
206
auto channels = 0 ;
@@ -325,11 +330,18 @@ std::pair<int, int> Network::load_network_file(const std::string& filename) {
325
330
auto iss = std::stringstream{line};
326
331
// First line is the file format version id
327
332
iss >> format_version;
328
- if (iss.fail () || format_version != FORMAT_VERSION ) {
333
+ if (iss.fail () || ( format_version != 1 && format_version != 2 ) ) {
329
334
myprintf (" Weights file is the wrong version.\n " );
330
335
return {0 , 0 };
331
336
} else {
332
- assert (format_version == FORMAT_VERSION);
337
+ // Version 2 networks are identical to v1, except
338
+ // that they return the score for black instead of
339
+ // the player to move. This is used by ELF Open Go.
340
+ if (format_version == 2 ) {
341
+ value_head_not_stm = true ;
342
+ } else {
343
+ value_head_not_stm = false ;
344
+ }
333
345
return load_v1_network (buffer);
334
346
}
335
347
}
@@ -886,6 +898,13 @@ Network::Netresult Network::get_scored_moves(
886
898
result = get_scored_moves_internal (planes, rand_sym);
887
899
}
888
900
901
+ // v2 format (ELF Open Go) returns black value, not stm
902
+ if (value_head_not_stm) {
903
+ if (state->board .get_to_move () == FastBoard::WHITE) {
904
+ result.winrate = 1 .0f - result.winrate ;
905
+ }
906
+ }
907
+
889
908
// Insert result into cache.
890
909
NNCache::get_NNCache ().insert (state->board .get_hash (), result);
891
910
0 commit comments