A simplified, highly flexible, commented and (hopefully) easy-to-understand implementation of self-play reinforcement learning based on the AlphaGo Zero paper (Silver et al.). It is designed to be easy to adapt to any two-player, turn-based adversarial game and any deep learning framework of your choice. A sample implementation has been provided for the game of Othello in PyTorch, Keras and TensorFlow. An accompanying tutorial can be found here.
To use a game of your choice, subclass the classes in Game.py and NeuralNet.py and implement their functions. Example implementations for Othello can be found in othello/OthelloGame.py and othello/{pytorch,keras,tensorflow}/NNet.py.
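As a rough illustration of what a game implementation involves, the sketch below defines a toy game (single-pile Nim) in plain Python. The class and method names are assumptions made for this example and are not taken from the repository; check Game.py for the actual methods and signatures you need to implement.

```python
class NimGame:
    """Toy two-player game: players alternately remove 1 to 3 stones,
    and whoever takes the last stone wins.

    Hypothetical interface for illustration only, not the repository's Game class.
    """

    MAX_TAKE = 3

    def __init__(self, stones=7):
        self.stones = stones

    def get_init_board(self):
        # The "board" is just the number of stones left in the pile.
        return self.stones

    def get_action_size(self):
        # Action a (0-based) means "remove a + 1 stones".
        return self.MAX_TAKE

    def get_valid_moves(self, board):
        # Binary mask over actions: 1 if legal, 0 otherwise.
        return [1 if a + 1 <= board else 0 for a in range(self.MAX_TAKE)]

    def get_next_state(self, board, player, action):
        # Returns the new board and the player to move next (players are +1 and -1).
        if not self.get_valid_moves(board)[action]:
            raise ValueError("illegal move")
        return board - (action + 1), -player

    def get_game_ended(self, board, player):
        # `player` is assumed to be the player to move. Returns 0 while the
        # game is running; once the pile is empty, the player to move has
        # lost (-1), because the opponent took the last stone.
        if board > 0:
            return 0
        return -1


game = NimGame(stones=5)
board, player = game.get_init_board(), 1
board, player = game.get_next_state(board, player, 2)  # player 1 removes 3 stones
print(board, player, game.get_valid_moves(board))
```

A real game such as Othello would additionally need a board representation the neural network can consume, which is why a matching network class accompanies each game.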
Coach.py contains the core training loop and MCTS.py performs the Monte Carlo Tree Search. The parameters for the self-play can be specified in main.py. Additional neural network parameters are in othello/{pytorch,keras,tensorflow}/NNet.py (cuda flag, batch size, epochs, learning rate etc.).
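To give a feel for the search step, here is a minimal sketch of the action-selection rule from the AlphaGo Zero paper, which picks the action maximising Q + U with U = c_puct * P * sqrt(N_parent) / (1 + N_child). The function names, the `stats` layout and the default `c_puct` are assumptions for this example; the code in MCTS.py may differ in its details.

```python
import math


def puct_score(q, prior, parent_visits, child_visits, c_puct=1.0):
    """Q + U, where U = c_puct * P * sqrt(N_parent) / (1 + N_child)."""
    return q + c_puct * prior * math.sqrt(parent_visits) / (1 + child_visits)


def select_action(stats, valid_actions, c_puct=1.0):
    """Pick the valid action with the highest PUCT score.

    `stats` maps action -> (q, prior, visit_count). This layout is
    hypothetical and chosen only to keep the example short.
    """
    parent_visits = sum(stats[a][2] for a in valid_actions)
    return max(
        valid_actions,
        key=lambda a: puct_score(
            stats[a][0], stats[a][1], parent_visits, stats[a][2], c_puct
        ),
    )


stats = {
    0: (0.50, 0.2, 10),  # well explored, decent value estimate
    1: (0.10, 0.7, 1),   # barely explored, strong network prior
    2: (0.00, 0.1, 0),   # never visited, weak prior
}
best = select_action(stats, valid_actions=[0, 1, 2])
print(best)  # action 1: its large prior and low visit count dominate
```

The exploration term shrinks as an action accumulates visits, so the search gradually shifts from following the network's prior to trusting the measured values.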
To start training a model for Othello:
python main.py
Choose your framework and game in main.py.
We trained a PyTorch model for 6x6 Othello (~80 iterations, 100 episodes per iteration and 25 MCTS simulations per turn). This took about 3 days on an NVIDIA Tesla K80. The pretrained model (PyTorch) can be found in pretrained_models/othello/pytorch/. You can play a game against it using pit.py. Below is the performance of the model against a random and a greedy baseline as a function of the number of training iterations.
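The kind of head-to-head evaluation described above can be sketched as a simple loop that plays many games and tallies wins. The example below is not pit.py: it uses a toy single-pile Nim game in place of Othello, and two baseline players in place of a trained model. All names here are hypothetical.

```python
import random


def random_player(stones, rng):
    # Baseline: remove a uniformly random legal number of stones (1 to 3).
    return rng.randint(1, min(3, stones))


def greedy_player(stones, rng):
    # Baseline: always remove as many stones as allowed.
    return min(3, stones)


def play_game(first, second, stones=10, rng=None):
    """Play one game of single-pile Nim.

    Returns 0 if `first` wins and 1 if `second` wins.
    """
    if rng is None:
        rng = random.Random()
    players = (first, second)
    turn = 0
    while True:
        stones -= players[turn](stones, rng)
        if stones == 0:
            return turn  # the player who takes the last stone wins
        turn = 1 - turn


def pit(player_a, player_b, num_games=100, seed=0):
    """Return (wins_a, wins_b), alternating who moves first."""
    rng = random.Random(seed)
    wins = [0, 0]
    for g in range(num_games):
        if g % 2 == 0:
            winner = play_game(player_a, player_b, rng=rng)
        else:
            # Sides are swapped, so map the result back to (a, b) order.
            winner = 1 - play_game(player_b, player_a, rng=rng)
        wins[winner] += 1
    return tuple(wins)


wins_greedy, wins_random = pit(greedy_player, random_player)
print(f"greedy: {wins_greedy}, random: {wins_random}")
```

Alternating the starting player matters in games where moving first is an advantage; without it, the win counts would partly measure seat order rather than playing strength.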
A concise description of our algorithm can be found here.
In addition, Evgeny Tyurin has contributed rules and a trained model for TicTacToe.
While the current code is fairly functional, we could benefit from the following contributions:
- Game logic files for more games that follow the specifications in Game.py, along with their neural networks
- Neural networks in other frameworks
- Pre-trained models for different game configurations
- An asynchronous version of the code: parallel processes for self-play, neural net training and model comparison
- Asynchronous MCTS as described in the paper
Thanks to pytorch-classification and progress.