Skip to content

Commit

Permalink
Refactor json_encode_np
Browse files Browse the repository at this point in the history
  • Loading branch information
tlbtlbtlb committed Dec 30, 2016
1 parent 1059ccd commit ae89569
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
21 changes: 2 additions & 19 deletions gym/monitoring/monitor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gym import error, version
from gym.monitoring import stats_recorder, video_recorder
from gym.utils import atomic_write, closer
from gym.utils.json_utils import json_encode_np

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -180,7 +181,7 @@ def _flush(self, force=False):
'videos': [(os.path.basename(v), os.path.basename(m))
for v, m in self.videos],
'env_info': self._env_info(),
}, f, default=json_encode)
}, f, default=json_encode_np)

def close(self):
"""Flush all monitor data to disk and close any open rending windows."""
Expand Down Expand Up @@ -408,21 +409,3 @@ def collapse_env_infos(env_infos, training_dir):
if key not in first:
raise error.Error("env_info {} from training directory {} is missing expected key {}. This is unexpected and likely indicates a bug in gym.".format(first, training_dir, key))
return first


def json_encode(obj):
"""
JSON can't serialize numpy types, so convert to pure python
"""
if isinstance(obj, np.ndarray):
return list(obj)
elif isinstance(obj, np.float32):
return float(obj)
elif isinstance(obj, np.float64):
return float(obj)
elif isinstance(obj, np.int32):
return int(obj)
elif isinstance(obj, np.int64):
return int(obj)
else:
return obj
3 changes: 2 additions & 1 deletion gym/monitoring/stats_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from gym import error
from gym.utils import atomic_write
from gym.utils.json_utils import json_encode_np

class StatsRecorder(object):
def __init__(self, directory, file_prefix, autoreset=False, env_id=None):
Expand Down Expand Up @@ -99,4 +100,4 @@ def flush(self):
'episode_lengths': self.episode_lengths,
'episode_rewards': self.episode_rewards,
'episode_types': self.episode_types,
}, f)
}, f, default=json_encode_np)
18 changes: 18 additions & 0 deletions gym/utils/json_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np

def json_encode_np(obj):
"""
JSON can't serialize numpy types, so convert to pure python
"""
if isinstance(obj, np.ndarray):
return list(obj)
elif isinstance(obj, np.float32):
return float(obj)
elif isinstance(obj, np.float64):
return float(obj)
elif isinstance(obj, np.int32):
return int(obj)
elif isinstance(obj, np.int64):
return int(obj)
else:
return obj

0 comments on commit ae89569

Please sign in to comment.