forked from maurock/snake-ga
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bayesian optimization added, code ported to pytorch
Showing 6 changed files with 167 additions and 129 deletions.
This file was deleted.
README.md
@@ -1,18 +1,25 @@
 # Deep Reinforcement Learning
 ## Project: Train AI to play Snake
+*UPDATE:*
+
+This project has been recently updated:
+- The code of Deep Reinforcement Learning was ported from Keras/TF to Pytorch. To see the original version of the code in Keras/TF, please refer to this repository: [snake-ga-tf](https://github.com/maurock/snake-ga-tf).
+- I added Bayesian Optimization to optimize some parameters of Deep RL.
+
 ## Introduction
-The goal of this project is to develop an AI Bot able to learn how to play the popular game Snake from scratch. In order to do it, I implemented a Deep Reinforcement Learning algorithm. This approach consists in giving the system parameters related to its state, and a positive or negative reward based on its actions. No rules about the game are given, and initially the Bot has no information on what it needs to do. The goal for the system is to figure it out and elaborate a strategy to maximize the score - or the reward.
-We are going to see how a Deep Q-Learning algorithm learns how to play snake, scoring up to 50 points and showing a solid strategy after only 5 minutes of training.
+The goal of this project is to develop an AI Bot able to learn how to play the popular game Snake from scratch. In order to do it, I implemented a Deep Reinforcement Learning algorithm. This approach consists in giving the system parameters related to its state, and a positive or negative reward based on its actions. No rules about the game are given, and initially the Bot has no information on what it needs to do. The goal for the system is to figure it out and elaborate a strategy to maximize the score - or the reward. \
+We are going to see how a Deep Q-Learning algorithm learns how to play Snake, scoring up to 50 points and showing a solid strategy after only 5 minutes of training. \
+Additionally, it is possible to run the Bayesian Optimization method to find the optimal parameters of the Deep neural network, as well as some parameters of the Deep RL approach.
 
 ## Install
-This project requires Python 3.6 with the pygame library installed, as well as Keras with Tensorflow backend.
+This project requires Python 3.6 with the pygame library installed, as well as Pytorch. \
 The full list of requirements is in `requirements.txt`.
 ```bash
 git clone [email protected]:maurock/snake-ga.git
 ```
 
 ## Run
-To run the game, executes in the snake-ga folder:
+To run and show the game, executes in the snake-ga folder:
 
 ```python
 python snakeClass.py --display=True --speed=50
@@ -22,15 +29,24 @@ Arguments description:
 - --display - Type bool, default True, display or not game view
 - --speed - Type integer, default 50, game speed
 
-This will run and show the agent. The default configuration loads the file *weights/weights.hdf5* and runs a test.
-The Deep neural network can be customized in the file snakeClass.py modifying the dictionary *params* in the function *define_parameters()*
+The default configuration loads the file *weights/weights.hdf5* and runs a test.
+The parameters of the Deep neural network can be changed in *snakeClass.py* by modifying the dictionary `params` in the function `define_parameters()`
 
 To train the agent, set in the file snakeClass.py:
 - params['load_weights'] = False
 - params['train'] = True
 
 In snakeClass.py you can set argument *--display*=False and *--speed*=0, if you do not want to see the game running. This speeds up the training phase.
 
+## Optimize Deep RL with Bayesian Optimization
+To optimize the Deep neural network and additional parameters, run:
+
+```python
+python snakeClass.py --bayesianopt=True
+```
+
+This method uses Bayesian optimization to optimize some parameters of Deep RL. The parameters and the features' search space can be modified in *bayesOpt.py*, by editing the `optim_params` dictionary in `optimize_RL`.
+
 ## For Mac users
 It seems there is a OSX specific problem, since many users cannot see the game running.
 To fix this problem, in update_screen(), add this line.
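The train/test switch described in the README hunk above can be pictured with a short sketch. Only the `load_weights` and `train` keys (and the *weights/weights.hdf5* path) come from the README; every other key and value below is a hypothetical placeholder, not the repository's actual configuration.

```python
# Sketch of the params dictionary returned by define_parameters() in
# snakeClass.py. Only 'load_weights' and 'train' are named in the README;
# the remaining keys and all values are illustrative placeholders.
def define_parameters():
    params = {
        'load_weights': True,       # default: reuse weights/weights.hdf5
        'train': False,             # default: run a test, not training
        'weights_path': 'weights/weights.hdf5',
        'learning_rate': 0.0005,    # placeholder value
    }
    return params

# To train the agent instead of testing, flip the two flags the README names:
params = define_parameters()
params['load_weights'] = False
params['train'] = True
```

Flipping both flags together matters: training from scratch with `load_weights` still `True` would start from the shipped test weights rather than a fresh network.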
requirements.txt
@@ -1,65 +1,9 @@
-absl-py==0.8.0
-astor==0.8.0
-blinker==1.4
-brotlipy==0.7.0
-cachetools==4.1.0
-certifi==2020.4.5.2
-cffi==1.14.0
-chardet==3.0.4
-click==7.1.2
-cmake-example==0.0.1
-cryptography==2.9.2
-cycler==0.10.0
-gast==0.2.2
-google-auth==1.14.1
-google-auth-oauthlib==0.4.1
-google-pasta==0.1.7
-grpcio==1.27.2
-h5py==2.10.0
-idna==2.9
-Keras==2.3.1
-Keras-Applications==1.0.8
-Keras-Preprocessing==1.1.0
-kiwisolver==1.2.0
-Markdown==3.1.1
-matplotlib==3.2.0
-mkl-fft==1.1.0
-mkl-random==1.1.1
-mkl-service==2.3.0
-msgpack-numpy==0.4.4.3
-numpy==1.18.1
-oauthlib==3.1.0
-opt-einsum==3.1.0
+Keras==2.2.4
+numpy==1.17.2
+torch==1.4.0
+seaborn==0.9.0
+pygame==1.9.3
+pandas==0.25.1
-protobuf==3.12.3
-pyasn1==0.4.8
-pyasn1-modules==0.2.7
-pycparser==2.20
-pygame==1.9.6
-PyJWT==1.7.1
-pyOpenSSL==19.1.0
-pyparsing==2.4.7
-pyreadline==2.1
-PySocks==1.7.1
-python-dateutil==2.8.1
-pytz==2020.1
-PyYAML==5.3.1
-requests==2.23.0
-requests-oauthlib==1.3.0
-rsa==4.0
-scipy==1.4.1
-seaborn==0.10.1
-six==1.15.0
-tabulate==0.8.3
-tensorboard==2.2.1
-tensorboard-plugin-wit==1.6.0
-tensorflow==2.1.0
-tensorflow-estimator==1.14.0
-tensorpack==0.9.4
-termcolor==1.1.0
-tgan==0.1.0
-urllib3==1.25.9
-Werkzeug==0.16.1
-win-inet-pton==1.1.0
-wincertstore==0.2
-wrapt==1.11.2
+GPyOpt==1.2.6
+numpy==1.19.4
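The GPyOpt pin above supports the commit's new `--bayesianopt` mode. As a toy illustration of the technique itself, here is a self-contained sketch of Bayesian optimization: a Gaussian-process surrogate plus the expected-improvement acquisition, maximizing a 1-D function. The repository uses the GPyOpt library for this; every name and value below is illustrative, not the repository's API.

```python
# Toy Bayesian optimization: GP surrogate + expected improvement.
# Illustrative only -- the repository delegates this to GPyOpt.
import math
import numpy as np

def objective(x):
    # Stand-in for "average game score achieved with hyperparameter x".
    return -(x - 0.3) ** 2

def rbf(a, b, length=0.15):
    # Squared-exponential kernel between two 1-D point sets.
    return np.exp(-0.5 * ((a[:, None] - b[None, :]) / length) ** 2)

def gp_posterior(x_obs, y_obs, x_cand, noise=1e-6):
    # GP posterior mean and std at candidate points, given observations.
    K_inv = np.linalg.inv(rbf(x_obs, x_obs) + noise * np.eye(len(x_obs)))
    Ks = rbf(x_obs, x_cand)                      # shape (n_obs, n_cand)
    mu = Ks.T @ K_inv @ y_obs
    var = 1.0 - np.einsum('ij,ik,kj->j', Ks, K_inv, Ks)
    return mu, np.sqrt(np.clip(var, 1e-12, None))

def expected_improvement(mu, sigma, best):
    # EI acquisition for maximization.
    z = (mu - best) / sigma
    cdf = 0.5 * (1.0 + np.array([math.erf(v / math.sqrt(2.0)) for v in z]))
    pdf = np.exp(-0.5 * z ** 2) / math.sqrt(2.0 * math.pi)
    return (mu - best) * cdf + sigma * pdf

rng = np.random.default_rng(0)
x_obs = rng.uniform(0.0, 1.0, size=3)        # a few random initial trials
y_obs = objective(x_obs)
candidates = np.linspace(0.0, 1.0, 201)      # discrete search space

for _ in range(10):
    mu, sigma = gp_posterior(x_obs, y_obs, candidates)
    x_next = candidates[np.argmax(expected_improvement(mu, sigma, y_obs.max()))]
    x_obs = np.append(x_obs, x_next)
    y_obs = np.append(y_obs, objective(x_next))

best_x = float(x_obs[np.argmax(y_obs)])
print(f"best hyperparameter found: {best_x:.3f}")
```

The same loop structure applies to the repository's case: each "objective evaluation" is a full training run of the Snake agent, which is why a sample-efficient optimizer is used instead of grid search.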