Skip to content

Commit

Permalink
update Atari envs to v4 and warn Python 2 users.
Browse files Browse the repository at this point in the history
  • Loading branch information
siemanko committed May 25, 2017
1 parent 0071b85 commit 7327a15
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 47 deletions.
37 changes: 0 additions & 37 deletions baselines/common/tf_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,43 +355,6 @@ def dropout(x, pkeep, phase=None, mask=None):
return switch(phase, mask * x, pkeep * x)


def batchnorm(x, name, phase, updates, gamma=0.96, epsilon=1e-8):
    """Batch-normalize ``x`` along axis 0 and apply a learned scale and shift.

    Four variables are created via ``tf.get_variable`` under ``name``:
    non-trainable ``/mean`` and ``/var`` running statistics, and trainable
    ``/scaling`` and ``/translation`` parameters, each of shape ``[1, k]``
    where ``k = x.get_shape()[1]``.

    Parameters
    ----------
    x : tensor
        Input to normalize. Statistics are taken over axis 0 and the
        per-feature variables are sized from axis 1, so ``x`` is presumably
        a ``(batch, features)`` tensor -- confirm against callers.
    name : str
        Prefix for the names of the variables created here.
    phase : tensor
        Passed to ``switch`` to choose between the batch-statistics output
        and the running-statistics output. Presumably true during training;
        verify against ``switch``.
    updates : list
        Mutated in place: the two ``tf.assign`` ops that refresh the running
        mean and variance are appended. This function does not run them.
    gamma : float
        Decay of the exponential moving averages (weight kept from the
        previous running value).
    epsilon : float
        Small constant added to the variance before the square root so a
        feature with zero variance does not cause a division by zero.

    Returns
    -------
    tensor
        The normalized input multiplied by ``scaling`` plus ``translation``.
    """
    k = x.get_shape()[1]
    runningmean = tf.get_variable(name + "/mean",
                                  shape=[1, k],
                                  initializer=tf.constant_initializer(0.0),
                                  trainable=False)
    runningvar = tf.get_variable(name + "/var",
                                 shape=[1, k],
                                 initializer=tf.constant_initializer(1e-4),
                                 trainable=False)
    testy = (x - runningmean) / tf.sqrt(runningvar + epsilon)

    mean_ = mean(x, axis=0, keepdims=True)
    # Variance must be centered on the batch mean. The previous code used
    # mean(x**2), the uncentered second moment, which over-estimates the
    # variance whenever the mean is non-zero.
    var_ = mean(tf.square(x - mean_), axis=0, keepdims=True)
    std = tf.sqrt(var_ + epsilon)
    trainy = (x - mean_) / std

    # Exponential moving averages of the batch statistics.
    updates.extend([
        tf.assign(runningmean, runningmean * gamma + mean_ * (1 - gamma)),
        tf.assign(runningvar, runningvar * gamma + var_ * (1 - gamma))
    ])

    y = switch(phase, trainy, testy)

    scaling = tf.get_variable(name + "/scaling",
                              shape=[1, k],
                              initializer=tf.constant_initializer(1.0),
                              trainable=True)

    translation = tf.get_variable(name + "/translation",
                                  shape=[1, k],
                                  initializer=tf.constant_initializer(0.0),
                                  trainable=True)

    return y * scaling + translation


# ================================================================
# Theano-like Function
# ================================================================
Expand Down
2 changes: 1 addition & 1 deletion baselines/deepq/experiments/atari/enjoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def parse_args():


def make_env(game_name):
env = gym.make(game_name + "NoFrameskip-v3")
env = gym.make(game_name + "NoFrameskip-v4")
env = SimpleMonitor(env)
env = wrap_dqn(env)
return env
Expand Down
2 changes: 1 addition & 1 deletion baselines/deepq/experiments/atari/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def parse_args():


def make_env(game_name):
env = gym.make(game_name + "NoFrameskip-v3")
env = gym.make(game_name + "NoFrameskip-v4")
monitored_env = SimpleMonitor(env) # puts rewards and number of steps in info, before environment is wrapped
env = wrap_dqn(monitored_env) # applies a bunch of modification to simplify the observation space (downsample, make b/w)
return env, monitored_env
Expand Down
2 changes: 1 addition & 1 deletion baselines/deepq/experiments/atari/wang2015_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def make_env(game_name):
env = gym.make(game_name + "NoFrameskip-v3")
env = gym.make(game_name + "NoFrameskip-v4")
env_monitored = SimpleMonitor(env)
env = wrap_dqn(env_monitored)
return env_monitored, env
Expand Down
2 changes: 1 addition & 1 deletion baselines/deepq/experiments/enjoy_pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def main():
env = gym.make("PongNoFrameskip-v3")
env = gym.make("PongNoFrameskip-v4")
env = ScaledFloatFrame(wrap_dqn(env))
act = deepq.load("pong_model.pkl")

Expand Down
2 changes: 1 addition & 1 deletion baselines/deepq/experiments/train_pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def main():
env = gym.make("PongNoFrameskip-v3")
env = gym.make("PongNoFrameskip-v4")
env = ScaledFloatFrame(wrap_dqn(env))
model = deepq.models.cnn_to_mlp(
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
Expand Down
11 changes: 6 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from setuptools import setup, find_packages
import os


repo_dir = os.path.dirname(os.path.abspath(__file__))
import sys

if sys.version_info.major != 3:
    # Warn only; installation is still attempted. Keep this a plain
    # single-argument print() call so the file also parses under Python 2,
    # which is exactly where this message needs to be shown.
    print("This package is only compatible with Python 3, but you are running "
          "Python {}. The installation will likely fail.".format(sys.version_info.major))

setup(name='baselines',
packages=[package for package in find_packages()
if package.startswith('baselines')],
install_requires=[
'gym',
'scipy',
'tqdm',
'joblib',
Expand All @@ -22,4 +23,4 @@
author="OpenAI",
url='https://github.com/openai/baselines',
author_email="[email protected]",
version="0.1.0")
version="0.1.3")

0 comments on commit 7327a15

Please sign in to comment.