Skip to content

Commit

Permalink
some tools for checking if environments have changed
Browse files Browse the repository at this point in the history
  • Loading branch information
joschu committed May 1, 2016
1 parent 2030bd4 commit fbb6033
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
37 changes: 37 additions & 0 deletions misc/check_envs_for_change.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
ENVS = ["Ant-v0", "HalfCheetah-v0", "Hopper-v0", "Humanoid-v0", "InvertedDoublePendulum-v0", "Reacher-v0", "Swimmer-v0", "Walker2d-v0"]
OLD_COMMIT = "HEAD"

# ================================================================

import subprocess, gym
from gym import utils
from os import path

def cap(cmd):
"Call and print command"
print utils.colorize(cmd, "green")
subprocess.check_call(cmd,shell=True)

# ================================================================

gymroot = path.abspath(path.dirname(path.dirname(gym.__file__)))
oldgymroot = "/tmp/old-gym"
comparedir = "/tmp/gym-comparison"

oldgymbase = path.basename(oldgymroot)

print "gym root", gymroot
thisdir = path.abspath(path.dirname(__file__))
print "this directory", thisdir
cap("rm -rf %(oldgymroot)s %(comparedir)s && mkdir %(comparedir)s && cd /tmp && git clone %(gymroot)s %(oldgymbase)s"%locals())
for env in ENVS:
print utils.colorize("*"*50 + "\nENV: %s" % env, "red")
writescript = path.join(thisdir, "write_rollout_data.py")
outfileA = path.join(comparedir, env) + "-A.npz"
cap("python %(writescript)s %(env)s %(outfileA)s"%locals())
outfileB = path.join(comparedir, env) + "-B.npz"
cap("python %(writescript)s %(env)s %(outfileB)s --gymdir=%(oldgymroot)s"%locals())

comparescript = path.join(thisdir, "compare_rollout_data.py")
cap("python %(comparescript)s %(outfileA)s %(outfileB)s"%locals())

26 changes: 26 additions & 0 deletions misc/compare_rollout_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import argparse, numpy as np

def main():
parser = argparse.ArgumentParser()
parser.add_argument("file1")
parser.add_argument("file2")
args = parser.parse_args()
file1 = np.load(args.file1)
file2 = np.load(args.file2)

for k in sorted(file1.keys()):
arr1 = file1[k]
arr2 = file2[k]
if arr1.shape == arr2.shape:
if np.allclose(file1[k], file2[k]):
print "%s: matches!"%k
continue
else:
print "%s: arrays are not equal. Difference = %g"%(k, np.abs(arr1 - arr2).max())
else:
print "%s: arrays have different shape! %s vs %s"%(k, arr1.shape, arr2.shape)
print "first 30 els:\n1. %s\n2. %s"%(arr1.flat[:30], arr2.flat[:30])


if __name__ == "__main__":
main()
55 changes: 55 additions & 0 deletions misc/write_rollout_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
This script does a few rollouts with an environment and writes the data to an npz file
Its purpose is to help with verifying that you haven't functionally changed an environment.
(If you have, you should bump the version number.)
"""
import argparse, numpy as np, collections, sys
from os import path


class RandomAgent(object):
def __init__(self, ac_space):
self.ac_space = ac_space
def act(self, _):
return self.ac_space.sample()

def rollout(env, agent, timestep_limit):
"""
Simulate the env and agent for timestep_limit steps
"""
ob = env.reset()
data = collections.defaultdict(list)
for _ in xrange(timestep_limit):
data["observation"].append(ob)
action = agent.act(ob)
data["action"].append(action)
ob,rew,done,_ = env.step(action)
data["reward"].append(rew)
if done:
break
return data

def main():
parser = argparse.ArgumentParser()
parser.add_argument("envid")
parser.add_argument("outfile")
parser.add_argument("--gymdir")

args = parser.parse_args()
if args.gymdir:
sys.path.insert(0, args.gymdir)
import gym
from gym import utils
print utils.colorize("gym directory: %s"%path.dirname(gym.__file__), "yellow")
env = gym.make(args.envid)
agent = RandomAgent(env.action_space)
alldata = {}
for i in xrange(2):
np.random.seed(i)
data = rollout(env, agent, env.spec.timestep_limit)
for (k, v) in data.items():
alldata["%i-%s"%(i, k)] = v
np.savez(args.outfile, **alldata)

if __name__ == "__main__":
main()

0 comments on commit fbb6033

Please sign in to comment.