Add res30 PolicyValueNet and gobang #122

Status: Open. Wants to merge 71 commits into base: master.

71 commits:
6baef04
Decrease train batch num, batch size and increase evaluation batch nu…
minkefusiji May 12, 2021
980015a
Update batch size to 128 for training
minkefusiji May 12, 2021
d3908ee
Add output file
minkefusiji May 12, 2021
85e9b91
Add output file
minkefusiji May 12, 2021
08f1eec
Add output file
minkefusiji May 12, 2021
60017ce
Add residual bottleneck blocks
minkefusiji May 12, 2021
752e35b
Add batch normalization for bottleneck output
minkefusiji May 12, 2021
01d950f
Increase total batch num and training batch size
minkefusiji May 12, 2021
53d1df4
Update batch normalization
minkefusiji May 13, 2021
6c09794
Update batch normalization
minkefusiji May 13, 2021
3b3bd6c
Revert speedup configs
minkefusiji May 13, 2021
d743bd3
Add last 8 moves to state
minkefusiji May 14, 2021
fc0bc92
Add last 8 move to state
minkefusiji May 14, 2021
ecb1aa9
Update board to 9,9,5
minkefusiji May 14, 2021
d629baa
Update last 8 moves state
minkefusiji May 14, 2021
ab42ae1
Update last 8 moves state
minkefusiji May 14, 2021
78e7bc0
Enlarge last 8 moves to last 16 moves
minkefusiji May 14, 2021
1fca177
Enlarge last 8 moves to last 16 moves
minkefusiji May 14, 2021
bc243dc
Add INPUT_STATE_CHANNEL_SIZE constant
minkefusiji May 15, 2021
df2bfc9
Add log output
minkefusiji May 15, 2021
86525a0
Update output file path
minkefusiji May 15, 2021
e8883cb
Update last move step number
minkefusiji May 15, 2021
d7f4dbd
Add forbiden hands
minkefusiji May 16, 2021
5f876ee
Add empty states check for forbidden check
minkefusiji May 17, 2021
fc8cb1d
Update feature input
minkefusiji May 17, 2021
c7f4ae4
Remove forbidden hands
minkefusiji May 17, 2021
d6711a4
Update human play
minkefusiji May 18, 2021
12ad33b
Add last 16 moves as feature input
minkefusiji May 18, 2021
94840c0
Update human play
minkefusiji May 18, 2021
7504e21
Training from last best model
minkefusiji May 24, 2021
b8c9fe8
Train model with more playouts and total batch num and evaluate with …
minkefusiji May 24, 2021
4e6993f
Train model from scratch
minkefusiji May 24, 2021
6bf5b5b
Add gobang UI
minkefusiji May 27, 2021
60fa1d8
Change canvas tag
minkefusiji May 27, 2021
27882a8
Use current policy model in gobang
minkefusiji May 27, 2021
bb6509c
Reformat some texts
minkefusiji May 28, 2021
cdc651b
Use SGD+Momentum to replace Adam when retrain
minkefusiji Jun 2, 2021
7fe960f
Add comment for leaf value
minkefusiji Jun 2, 2021
36dcf0b
Retrain model from last version
minkefusiji Jun 2, 2021
019481f
Fix momentum optimizer error
minkefusiji Jun 2, 2021
7601fce
Create slots for MomentumOptimizer
minkefusiji Jun 3, 2021
091dd89
Use momentum optimizer to replace adam
minkefusiji Jun 3, 2021
f67a4d7
Use momentum optimizer to replace adam
minkefusiji Jun 3, 2021
0229a77
Init momentum optimizer
minkefusiji Jun 3, 2021
2d95ef1
Init momentum optimizer
minkefusiji Jun 3, 2021
76c3130
Use momentum optimizer to replace adam
minkefusiji Jun 3, 2021
c8f62fc
Add baseline with forbidden model
minkefusiji Jun 9, 2021
abec5f4
Add dedicate graphs for different policy value nets to avoid conflict
minkefusiji Jun 9, 2021
28d7567
Update output folder
minkefusiji Jun 9, 2021
7b68465
Add arguments to train
minkefusiji Jun 9, 2021
ae9d4e1
Add baseline and last16move states for game
minkefusiji Jun 9, 2021
2420c37
Update init model
minkefusiji Jun 9, 2021
2fbe0c5
Enable forbidden hands by default
minkefusiji Jun 9, 2021
4669ac0
Fix intermediate result file path
minkefusiji Jun 9, 2021
fdea5e9
Update intermediate result
minkefusiji Jun 9, 2021
534a7da
Update output path for res30
minkefusiji Jun 9, 2021
685bdb1
Add value loss and policy loss output
minkefusiji Jun 9, 2021
ad1d975
Add loss function to output folder path
minkefusiji Jun 9, 2021
eb29d43
Remove batch_norm import for baseline model
minkefusiji Jun 16, 2021
0d3a70d
Add res30 player and baseline player
minkefusiji Jun 16, 2021
359a97f
Add loss plot
minkefusiji Jun 16, 2021
f16ed17
Enable forbidden hands for humanplay and gobang
minkefusiji Jun 16, 2021
e8d28a3
Add evaluation pipeline
minkefusiji Jun 16, 2021
38aee05
Update loss plot
minkefusiji Jun 16, 2021
29b6728
Update loss plot
minkefusiji Jun 16, 2021
c57396e
Update loss plot
minkefusiji Jun 16, 2021
6af622b
Update loss plot
minkefusiji Jun 16, 2021
48e084f
Add intermediate result preprocess
minkefusiji Jun 16, 2021
02aa3a8
Add res30_l+ best model without forbidden hands
minkefusiji Jun 17, 2021
0590347
Update MCTS player name
minkefusiji Jun 17, 2021
700f9bc
Add baseline model
minkefusiji Jun 17, 2021
42 changes: 41 additions & 1 deletion README.md
@@ -54,11 +54,51 @@ and then execute: ``python train.py`` (To use GPU in PyTorch, set ``use_gpu=Tru

The models (best_policy.model and current_policy.model) will be saved every few updates (default 50).


To train with TensorFlow and the ResNet30 model, uncomment the line
```
# from policy_value_net_res_tensorflow import PolicyValueNetRes30 # Tensorflow
```
Then execute:
```
python train.py -h
-h, --help show this help message and exit
--ModelName {baseline,res30}, -m {baseline,res30}
--LossFunction {lv,lp,l+,lx}, -l {lv,lp,l+,lx}
--EnableForbiddenHands, -fh Enable forbidden hands
```
baseline_l+:
```
python train.py --ModelName baseline --LossFunction l+ --EnableForbiddenHands True
```

baseline_lp:
```
python train.py --ModelName baseline --LossFunction lp --EnableForbiddenHands True
```

res30_l+:
```
python train.py --ModelName res30 --LossFunction l+ --EnableForbiddenHands True
```

res30_lp:
```
python train.py --ModelName res30 --LossFunction lp --EnableForbiddenHands True
```
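The command lines above suggest a small argparse interface in train.py. A minimal sketch of what that parser might look like; train.py itself is not shown in this diff, so the defaults and help strings here are assumptions:

```python
import argparse

def build_parser():
    # Mirrors the options shown in the `python train.py -h` output above;
    # defaults are illustrative, not taken from the PR.
    parser = argparse.ArgumentParser(description='Train a Gomoku policy-value net')
    parser.add_argument('--ModelName', '-m', choices=['baseline', 'res30'],
                        default='baseline', help='which policy-value net to train')
    parser.add_argument('--LossFunction', '-l', choices=['lv', 'lp', 'l+', 'lx'],
                        default='l+', help='which loss combination to optimize')
    parser.add_argument('--EnableForbiddenHands', '-fh', default=False,
                        help='enable forbidden hands')
    return parser

# e.g. reproduce the res30_lp invocation from above:
args = build_parser().parse_args(['--ModelName', 'res30', '--LossFunction', 'lp'])
print(args.ModelName, args.LossFunction)
```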

Human play with AI

```
pip install tensorflow==1.14.0
python gobang_res30.py
```

**Note:** the 4 provided models were trained using Theano/Lasagne; to use them with PyTorch, please refer to [issue 5](https://github.com/junxiaosong/AlphaZero_Gomoku/issues/5).

**Tips for training:**
1. It is good to start with a 6 * 6 board and 4 in a row. For this case, we may obtain a reasonably good model within 500~1000 self-play games in about 2 hours.
2. For the case of 8 * 8 board and 5 in a row, it may need 2000~3000 self-play games to get a good model, and it may take about 2 days on a single PC.

### Further reading
My article describing some details about the implementation in Chinese: [https://zhuanlan.zhihu.com/p/32089487](https://zhuanlan.zhihu.com/p/32089487)
91 changes: 91 additions & 0 deletions evaluate.py
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
"""
An implementation of the evaluation pipeline of AlphaZero for Gomoku

@author: Chunlei Wang
"""

import random
import numpy as np
from collections import defaultdict, deque
from game import Board, Game
from mcts_alphaZero import MCTSPlayer
#from policy_value_net import PolicyValueNet # Theano and Lasagne
#from policy_value_net_pytorch import PolicyValueNet # Pytorch
from policy_value_net_tensorflow import PolicyValueNet # Tensorflow
#from policy_value_net_keras import PolicyValueNet # Keras
from policy_value_net_res_tensorflow import PolicyValueNetRes30 # Tensorflow
from datetime import datetime
import utils
import os

OUTPUT_DIR = "evaluation/" + datetime.utcnow().strftime("%Y%m%d%H%M%S")
os.makedirs(OUTPUT_DIR, exist_ok=True)
EVALUATION_OUTPUT = OUTPUT_DIR + "/evaluation.txt"

class EvaluationPipeline():
    def __init__(self, current_model, baseline_model):
        # params of the board and the game
        self.board_width = 9
        self.board_height = 9
        self.n_in_row = 5
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row,
                           forbidden_hands=True)
        self.game = Game(self.board)
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5

        self.baseline_policy_value_net = PolicyValueNet(self.board_width,
                                                        self.board_height,
                                                        'l+',
                                                        model_file=baseline_model)

        self.current_policy_value_net = PolicyValueNetRes30(self.board_width,
                                                            self.board_height,
                                                            'l+',
                                                            model_file=current_model)

    def policy_evaluate(self, n_games=100):
        """
        Evaluate the trained policy by playing against the baseline MCTS player
        """
        current_mcts_player = MCTSPlayer(self.current_policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)

        baseline_mcts_player = MCTSPlayer(self.baseline_policy_value_net.policy_value_fn,
                                          c_puct=self.c_puct,
                                          n_playout=self.n_playout)

        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          baseline_mcts_player,
                                          start_player=i % 2,
                                          is_shown=1)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games

        output = "Evaluation games: {}, num_playouts: {}, win: {}, lose: {}, tie: {}, win ratio: {}".format(
            n_games,
            self.n_playout,
            win_cnt[1], win_cnt[2], win_cnt[-1], win_ratio)

        utils.log(output, EVALUATION_OUTPUT)

        return win_ratio

    def run(self):
        """run the evaluation pipeline"""
        try:
            win_ratio = self.policy_evaluate()
            return win_ratio
        except KeyboardInterrupt:
            print('\n\rquit')


if __name__ == '__main__':
    evaluation_pipeline = EvaluationPipeline(current_model='output/current_policy.model',
                                             baseline_model='output/baseline_policy.model')
    evaluation_pipeline.run()
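The win ratio computed in policy_evaluate counts a tie (winner -1) as half a win for the evaluated player. A self-contained sketch of that bookkeeping, separated from the game loop (function name is illustrative, not from the PR):

```python
from collections import defaultdict

def win_ratio_from_results(winners, n_games):
    """Count wins for player 1, treating ties (winner == -1) as half a win,
    mirroring the formula used in policy_evaluate."""
    win_cnt = defaultdict(int)
    for w in winners:
        win_cnt[w] += 1
    return 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games

# e.g. 6 wins, 3 losses, 1 tie out of 10 games:
print(win_ratio_from_results([1] * 6 + [2] * 3 + [-1], 10))  # -> 0.65
```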
142 changes: 134 additions & 8 deletions game.py
@@ -6,10 +6,36 @@
from __future__ import print_function
import numpy as np

INPUT_STATE_CHANNEL_SIZE = 19

class Board(object):
    """board for the game"""

    # piece encoding used by the forbidden-hand patterns:
    # 0: blank, 1: black, 2: white
    forbidden_hands_of_three_patterns = [
        [0, 1, 1, 1, 0],
        [0, 1, 0, 1, 1, 0],
        [0, 1, 1, 0, 1, 0],
    ]

    forbidden_hands_of_four_patterns = [
        [0, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 0, 1],
        [0, 1, 0, 1, 1, 1],
        [1, 1, 1, 0, 1, 0],
        [1, 0, 1, 1, 1, 0],
        [2, 1, 1, 1, 1, 0],
        [2, 1, 1, 1, 0, 1],
        [2, 1, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 2],
        [1, 1, 1, 0, 1, 2],
        [1, 0, 1, 1, 1, 2],
    ]

    def __init__(self, **kwargs):
        self.width = int(kwargs.get('width', 8))
        self.height = int(kwargs.get('height', 8))
@@ -19,17 +45,20 @@ def __init__(self, **kwargs):
        self.states = {}
        # need how many pieces in a row to win
        self.n_in_row = int(kwargs.get('n_in_row', 5))
        self.forbidden_hands = bool(kwargs.get('forbidden_hands', False))
        self.players = [1, 2]  # player1 and player2

    def init_board(self, start_player=0):
        if self.width < self.n_in_row or self.height < self.n_in_row:
            raise Exception('board width and height can not be '
                            'less than {}'.format(self.n_in_row))
        self.start_player = start_player
        self.current_player = self.players[start_player]  # start player
        # keep available moves in a list
        self.availables = list(range(self.width * self.height))
        self.states = {}
        self.last_move = -1
        self.last_16_move = [0] * (INPUT_STATE_CHANNEL_SIZE - 3)

    def move_to_location(self, move):
        """
@@ -48,13 +77,41 @@ def location_to_move(self, location):
            return -1
        h = location[0]
        w = location[1]
        if h < 0 or h >= self.height:
            return -1
        if w < 0 or w >= self.width:
            return -1

        move = h * self.width + w
        if move not in range(self.width * self.height):
            return -1
        return move

    def current_last16move_state(self):
        """return the board state from the perspective of the current res30 player.
        state shape: INPUT_STATE_CHANNEL_SIZE*width*height
        """

        square_state = np.zeros((INPUT_STATE_CHANNEL_SIZE, self.width, self.height))
        if self.states:
            moves, players = np.array(list(zip(*self.states.items())))
            move_curr = moves[players == self.current_player]
            move_oppo = moves[players != self.current_player]

            square_state[0][move_curr // self.width,
                            move_curr % self.height] = 1.0
            square_state[1][move_oppo // self.width,
                            move_oppo % self.height] = 1.0
            # indicate the last 16 move location
            for i in range(INPUT_STATE_CHANNEL_SIZE - 3):
                square_state[2 + i][np.array(self.last_16_move[i::2]) // self.width,
                                    np.array(self.last_16_move[i::2]) % self.height] = 1.0
        if len(self.states) % 2 == 0:
            square_state[INPUT_STATE_CHANNEL_SIZE - 1][:, :] = 1.0  # indicate the colour to play
        return square_state[:, ::-1, :]

    def current_state(self):
        """return the board state from the perspective of the current baseline player.
        state shape: 4*width*height
        """

@@ -82,13 +139,18 @@ def do_move(self, move):
            else self.players[1]
        )
        self.last_move = move
        self.last_16_move.pop(0)
        self.last_16_move.append(move)

    def has_a_winner(self):
        width = self.width
        height = self.height
        states = self.states
        n = self.n_in_row

        if (self.forbidden_hands and self.states
                and self.states[self.last_move] == self.players[self.start_player]
                and self.check_forbidden_hands()):
            return True, self.players[(self.start_player + 1) % 2]

        moved = list(set(range(width * height)) - set(self.availables))
        if len(moved) < self.n_in_row * 2 - 1:
            return False, -1
@@ -116,6 +178,71 @@ def has_a_winner(self):

        return False, -1

    def check_forbidden_hands(self):
        directions = [
            [1, 0],
            [1, 1],
            [0, 1],
            [-1, 1],
        ]

        patterns_of_three_matches = [
            1 if self.check_forbidden_pattern(p, d) else 0
            for d in directions
            for p in self.forbidden_hands_of_three_patterns
        ]

        patterns_of_four_matches = [
            1 if self.check_forbidden_pattern(p, d) else 0
            for d in directions
            for p in self.forbidden_hands_of_four_patterns
        ]

        if sum(patterns_of_three_matches) > 1 or sum(patterns_of_four_matches) > 1:
            return True
        return False

    def check_forbidden_pattern(self, pattern, direction):
        for (i, x) in enumerate(pattern):
            if x == 1:
                pieces = self.collect_pieces(self.last_move, direction, i, len(pattern))
                if pieces != [] and Board.list_equal(pieces, pattern):
                    return True

        return False

    @staticmethod
    def list_equal(list1, list2):
        if len(list1) != len(list2):
            return False

        for (a, b) in zip(list1, list2):
            if a != b:
                return False

        return True

    def collect_pieces(self, move, direction, look_back, length):
        cur_location = self.move_to_location(move)
        start_location = [
            cur_location[0] - direction[0] * look_back,
            cur_location[1] - direction[1] * look_back,
        ]

        pieces = []
        for i in range(length):
            location = [
                start_location[0] + i * direction[0],
                start_location[1] + i * direction[1],
            ]
            move = self.location_to_move(location)
            if move == -1:
                return []
            else:
                if move in self.states:
                    pieces.append(1 if self.states[move] == self.players[self.start_player] else 2)
                else:
                    pieces.append(0)
        return pieces

    def game_end(self):
        """Check whether the game is ended or not"""
        win, winner = self.has_a_winner()
@@ -180,14 +307,13 @@ def start_play(self, player1, player2, start_player=0, is_shown=1):
            self.graphic(self.board, player1.player, player2.player)
            end, winner = self.board.game_end()
            if end:
                if winner != -1:
                    print("Game end. Winner is", players[winner])
                else:
                    print("Game end. Tie")
                return winner

    def start_self_play(self, player, model_name, is_shown=0, temp=1e-3):
        """ start a self-play game using a MCTS player, reuse the search tree,
        and store the self-play data: (state, mcts_probs, z) for training
        """
@@ -199,7 +325,7 @@ def start_self_play(self, player, is_shown=0, temp=1e-3):
                                             temp=temp,
                                             return_prob=1)
            # store the data
            states.append(self.board.current_state() if model_name == 'baseline'
                          else self.board.current_last16move_state())
            mcts_probs.append(move_probs)
            current_players.append(self.board.current_player)
            # perform a move
@@ -220,4 +346,4 @@ def start_self_play(self, player, is_shown=0, temp=1e-3):
                print("Game end. Winner is player:", winner)
            else:
                print("Game end. Tie")
            return winner, zip(states, mcts_probs, winners_z)
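The forbidden-hand logic in game.py slides each hand-written pattern over lines of pieces via collect_pieces and compares element-wise with list_equal. A self-contained sketch of that comparison on a single 1-D row of pieces; the function names here are illustrative, not from the PR:

```python
def matches_at(line, pattern, offset):
    """Element-wise compare a pattern against a slice of a 1-D line of pieces
    (0 = blank, 1 = current player, 2 = opponent), like Board.list_equal."""
    window = line[offset:offset + len(pattern)]
    return len(window) == len(pattern) and all(a == b for a, b in zip(window, pattern))

def find_open_three(line):
    """Return every offset where the basic 'open three' pattern [0,1,1,1,0]
    from forbidden_hands_of_three_patterns occurs in the line."""
    pattern = [0, 1, 1, 1, 0]
    return [i for i in range(len(line)) if matches_at(line, pattern, i)]

# A 9-cell row with an open three starting at index 2:
row = [0, 0, 0, 1, 1, 1, 0, 0, 0]
print(find_open_three(row))  # -> [2]
```

The real check counts matches across all four board directions and flags a forbidden hand when more than one three (or four) pattern matches at once, i.e. a double-three or double-four.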