Vectorizing Replay Buffer
Summary:
Improved the slow iterative sampling of the Dopamine replay buffer to a vectorized version. Originally, the Dopamine buffer was about 5-10x slower than our OpenAIGymMemoryPool.
With this change, it is now slightly faster than OpenAIGymMemoryPool.
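
To illustrate the idea (a minimal standalone sketch with made-up values, not the buffer's actual internals): rather than rejection-sampling one index at a time in a Python loop, the buffer now builds a probability vector that is uniform over the valid slots and draws the whole batch with a single np.random.choice call.

import numpy as np

# Hypothetical buffer state: a boolean mask marking which slots are valid.
capacity = 8
is_valid = np.array([0, 0, 0, 1, 1, 1, 0, 0], dtype=bool)

# Uniform probability over valid slots, zero elsewhere.
p = is_valid.astype(np.float64) / is_valid.sum()
batch_indices = np.random.choice(a=capacity, size=4, replace=True, p=p)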

Reviewed By: kittipatv

Differential Revision: D20243566

fbshipit-source-id: cc91609268057f12c7ade9ad40fec91593a56a4d
kaiwenw authored and facebook-github-bot committed Mar 11, 2020
1 parent 5a74002 commit a37e96a
Showing 3 changed files with 147 additions and 122 deletions.
247 changes: 130 additions & 117 deletions ml/rl/replay_memory/circular_replay_buffer.py
@@ -94,8 +94,6 @@ class ReplayBuffer(object):
Attributes:
add_count: int, counter of how many transitions have been added (including
the blank ones at the beginning of an episode).
invalid_range: np.array, an array with the indices of cursor-related invalid
transitions
"""

def __init__(
@@ -187,6 +185,26 @@ def __init__(
self._cumulative_discount_vector = np.array(
[math.pow(self._gamma, n) for n in range(update_horizon)], dtype=np.float32
)
# Track whether each index is valid for sampling. There are two invalid cases:
# 1) the first stack_size - 1 zero transitions at the start of an episode
# 2) the last update_horizon transitions before the cursor
self._is_index_valid = np.zeros(self._replay_capacity, dtype=np.bool)
self._num_valid_indices = 0
self._num_transitions_in_current_episode = 0
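# As a hypothetical illustration of the two cases (stack_size = 4,
# update_horizon = 3, capacity 10, cursor at 8, no terminal seen yet):
#   slot:    0  1  2  3  4  5  6  7  8  9
#   content: z  z  z  s0 s1 s2 s3 s4 .  .   (z = zero padding)
#   valid:   0  0  0  1  1  0  0  0  0  0
# Slots 0-2 are the stack_size - 1 padding transitions (case 1); slots
# 5-7 are the last update_horizon transitions before the cursor, whose
# multi-step returns are not yet fully observed (case 2).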

@property
def size(self) -> int:
return self._num_valid_indices

def set_index_valid_status(self, idx: int, is_valid: bool):
old_valid = self._is_index_valid[idx]
if not old_valid and is_valid:
self._num_valid_indices += 1
elif old_valid and not is_valid:
self._num_valid_indices -= 1
assert self._num_valid_indices >= 0, f"{self._num_valid_indices} is negative"

self._is_index_valid[idx] = is_valid

def _create_storage(self) -> None:
"""Creates the numpy arrays used to store transitions.
@@ -253,12 +271,46 @@ def add(self, observation, action, reward, terminal, *args):
extra_storage_types.
"""
self._check_add_types(observation, action, reward, terminal, *args)
if self.is_empty() or self._store["terminal"][self.cursor() - 1] == 1:
last_idx = (self.cursor() - 1) % self._replay_capacity
if self.is_empty() or self._store["terminal"][last_idx] == 1:
self._num_transitions_in_current_episode = 0
for _ in range(self._stack_size - 1):
# Child classes can rely on the padding transitions being filled with
# zeros. This is useful when there is a priority argument.
self._add_zero_transition()

# remember, the last update_horizon transitions shouldn't be sampled
cur_idx = self.cursor()
self.set_index_valid_status(idx=cur_idx, is_valid=False)
if self._num_transitions_in_current_episode >= self._update_horizon:
idx = (cur_idx - self._update_horizon) % self._replay_capacity
self.set_index_valid_status(idx=idx, is_valid=True)
self._add(observation, action, reward, terminal, *args)
self._num_transitions_in_current_episode += 1

# mark the next stack_size-1 as invalid (note cursor has advanced by 1)
for i in range(self._stack_size - 1):
idx = (self.cursor() + i) % self._replay_capacity
self.set_index_valid_status(idx=idx, is_valid=False)

if terminal:
# Since the frame (cur_idx) we just inserted was terminal, we now mark
# the last "num_back" transitions as valid for sampling (including cur_idx).
# This is because next_state is not relevant for those terminal (multi-step)
# transitions.
# NOTE: this was not accounted for by the original Dopamine buffer.
# It is not a big problem, since after update_horizon steps,
# the original Dopamine buffer will make these frames
# available for sampling.
# But that is update_horizon steps too late. If we train right
# after an episode terminates, this can result in missing the
# bulk of rewards at the end of the most recent episode.
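# For example, with update_horizon = 3, an episode of length >= 3 ending
# at cur_idx makes cur_idx, cur_idx - 1, and cur_idx - 2 valid right away
# rather than update_horizon steps later.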
num_back = min(
self._num_transitions_in_current_episode, self._update_horizon
)
for i in range(0, num_back):
idx = (cur_idx - i) % self._replay_capacity
self.set_index_valid_status(idx=idx, is_valid=True)

def _add(self, *args):
"""Internal add method to add to the storage arrays.
@@ -320,8 +372,8 @@ def _check_add_types(self, *args):
store_element_shape = tuple(store_element.shape)
if arg_shape != store_element_shape:
raise ValueError(
"arg has shape {}, expected {}".format(
arg_shape, store_element_shape
"arg {} has shape {}, expected {}".format(
store_element.name, arg_shape, store_element_shape
)
)

@@ -386,35 +438,7 @@ def get_terminal_stack(self, index):
)

def is_valid_transition(self, index):
"""Checks if the index contains a valid transition.
Checks for collisions with the end of episodes and the current position
of the cursor.
Args:
index: int, the index to the state in the transition.
Returns:
Is the index valid: Boolean.
"""
# Check the index is in the valid range
if index < 0 or index >= self._replay_capacity:
return False
if not self.is_full():
# The indices and next_indices must be smaller than the cursor.
if index >= self.cursor() - self._update_horizon:
return False
# The first few indices contain the padding states of the first episode.
if index < self._stack_size - 1:
return False

# Skip transitions that straddle the cursor.
if index in set(self.invalid_range):
return False

# If there are terminal flags in any other frame other than the last one
# the stack is not valid, so don't sample it.
if self.get_terminal_stack(index)[:-1].any():
return False

return True
return self._is_index_valid[index]

def _create_batch_arrays(self, batch_size):
"""Create a tuple of arrays with the type of get_transition_elements.
@@ -440,41 +464,16 @@ def sample_index_batch(self, batch_size):
Returns:
list of ints, a batch of valid indices sampled uniformly.
Raises:
RuntimeError: If the batch was not constructed after maximum number of
tries.
RuntimeError: If there are no valid indices to sample.
"""
if self.is_full():
# add_count >= self._replay_capacity > self._stack_size
min_id = self.cursor() - self._replay_capacity + self._stack_size - 1
max_id = self.cursor() - self._update_horizon
else:
# add_count < self._replay_capacity
min_id = self._stack_size - 1
max_id = self.cursor() - self._update_horizon
if max_id <= min_id:
raise RuntimeError(
"Cannot sample a batch with fewer than stack size "
"({}) + update_horizon ({}) transitions.".format(
self._stack_size, self._update_horizon
)
)

indices = []
attempt_count = 0
while len(indices) < batch_size and attempt_count < self._max_sample_attempts:
index = np.random.randint(min_id, max_id) % self._replay_capacity
if self.is_valid_transition(index):
indices.append(index)
else:
attempt_count += 1
if len(indices) != batch_size:
if self._num_valid_indices == 0:
raise RuntimeError(
"Max sample attempts: Tried {} times but only sampled {}"
" valid indices. Batch size is {}".format(
self._max_sample_attempts, len(indices), batch_size
)
f"Cannot sample {batch_size} since there are no valid indices so far."
)

p = self._is_index_valid.astype(np.float64) / float(self._num_valid_indices)
indices = np.random.choice(
a=self._replay_capacity, size=batch_size, replace=True, p=p
)
return indices

def sample_transition_batch(self, batch_size=None, indices=None):
@@ -504,62 +503,76 @@ def sample_transition_batch(self, batch_size=None, indices=None):
batch_size = self._batch_size
if indices is None:
indices = self.sample_index_batch(batch_size)
assert isinstance(
indices, np.ndarray
), f"Indices {indices} have type {type(indices)} instead of np.darray"
assert len(indices) == batch_size

transition_elements = self.get_transition_elements(batch_size)
batch_arrays = self._create_batch_arrays(batch_size)
for batch_element, state_index in enumerate(indices):
trajectory_indices = [
(state_index + j) % self._replay_capacity
for j in range(self._update_horizon)
]
trajectory_terminals = self._store["terminal"][trajectory_indices]
is_terminal_transition = trajectory_terminals.any()
if not is_terminal_transition:
trajectory_length = self._update_horizon
else:
# np.argmax of a bool array returns the index of the first True.
trajectory_length = (
np.argmax(trajectory_terminals.astype(np.bool), 0) + 1
)
next_state_index = state_index + trajectory_length
trajectory_discount_vector = self._cumulative_discount_vector[
:trajectory_length
]
trajectory_rewards = self.get_range(
self._store["reward"], state_index, next_state_index
)

# Fill the contents of each array in the sampled batch.
assert len(transition_elements) == len(batch_arrays)
for element_array, element in zip(batch_arrays, transition_elements):
if element.name == "state":
element_array[batch_element] = self.get_observation_stack(
state_index
)
elif element.name == "reward":
# compute the discounted sum of rewards in the trajectory.
element_array[batch_element] = np.sum(
trajectory_discount_vector * trajectory_rewards, axis=0
)
elif element.name == "next_state":
element_array[batch_element] = self.get_observation_stack(
(next_state_index) % self._replay_capacity
)
elif element.name in ("next_action", "next_reward"):
element_array[batch_element] = self._store[
element.name.lstrip("next_")
][(next_state_index) % self._replay_capacity]
elif element.name == "terminal":
element_array[batch_element] = is_terminal_transition
elif element.name == "indices":
element_array[batch_element] = state_index
elif element.name in self._store.keys():
element_array[batch_element] = self._store[element.name][
state_index
]
# We assume the other elements are filled in by the subclass.
def get_obs_stack_for_indices(indices):
""" Get stack of observations """
# calculate 2d array of indices with size (batch_size, stack_size)
# the ith row contains the indices of the observation stack at indices[i]
stack_indices = indices.reshape(-1, 1) + np.arange(-self._stack_size + 1, 1)
stack_indices %= self._replay_capacity
# transpose to (batch_size, *observation_shape, stack_size)
perm = [0] + list(range(2, len(self._observation_shape) + 2)) + [1]
return self._store["observation"][stack_indices].transpose(perm)

# calculate 2d array of indices with size (batch_size, update_horizon)
# the ith row contains the multistep indices starting at indices[i]
multistep_indices = indices.reshape(-1, 1) + np.arange(self._update_horizon)
multistep_indices %= self._replay_capacity

def get_traj_lengths():
""" Calculate trajectory length, defined to be the number of states
in this multi_step transition until terminal state or end of
multi_step. Dopamine calls multi_step as "update_horizon".
"""
terminals = self._store["terminal"][multistep_indices]
# if trajectory is non-terminal, we'll have traj_length = update_horizon
terminals[:, -1] = True
# argmax finds the first True in each row
traj_lengths = np.argmax(terminals.astype(np.bool), axis=1) + 1
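# e.g. a row [False, True, False] gives argmax 1 -> traj_length 2; an
# all-False row was forced terminal at the last column above, yielding
# traj_length = update_horizon.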
return traj_lengths

traj_lengths = get_traj_lengths()
next_indices = (indices + traj_lengths) % self._replay_capacity

def get_multistep_reward_for_indices():
""" Sums up the reward for trajectory. """
decays = self._gamma ** np.arange(self._update_horizon)
decays = decays.reshape(1, self._update_horizon)
masks = np.arange(self._update_horizon) < traj_lengths.reshape(-1, 1)
rewards = self._store["reward"][multistep_indices] * decays * masks
return rewards.sum(axis=1)

batch_arrays = []
for element in transition_elements:
if element.name == "state":
batch = get_obs_stack_for_indices(indices)
elif element.name == "next_state":
batch = get_obs_stack_for_indices(next_indices)
elif element.name == "reward":
batch = get_multistep_reward_for_indices()
elif element.name == "terminal":
terminal_indices = (next_indices - 1) % self._replay_capacity
batch = self._store["terminal"][terminal_indices].astype(np.bool)
elif element.name == "indices":
batch = indices
elif element.name in ("next_action", "next_reward"):
store_name = element.name.lstrip("next_")
batch = self._store[store_name][next_indices]
elif element.name in self._store.keys():
batch = self._store[element.name][indices]

batch = batch.astype(element.type)
batch_arrays.append(batch)

batch_arrays = tuple(batch_arrays)

# We assume the other elements are filled in by the subclass.
return batch_arrays

def get_transition_elements(self, batch_size=None):
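
The heart of the vectorized sample_transition_batch above is broadcasting a column of start indices against a row of offsets to build a 2-D index grid, then reading the store with one fancy-indexing operation instead of a per-element Python loop. A minimal standalone sketch with made-up values (not the buffer's actual internals):

import numpy as np

capacity, stack_size, update_horizon = 10, 4, 3
indices = np.array([3, 4])  # hypothetical batch of sampled start indices

# (batch_size, stack_size): row i holds the frame indices stacked at indices[i]
stack_indices = (indices.reshape(-1, 1) + np.arange(-stack_size + 1, 1)) % capacity
# -> [[0, 1, 2, 3],
#     [1, 2, 3, 4]]

# (batch_size, update_horizon): row i holds the multi-step indices from indices[i]
multistep_indices = (indices.reshape(-1, 1) + np.arange(update_horizon)) % capacity
# -> [[3, 4, 5],
#     [4, 5, 6]]

A single gather such as store["observation"][stack_indices] then fetches every stack in the batch at once.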
2 changes: 1 addition & 1 deletion ml/rl/replay_memory/prioritized_replay_buffer.py
@@ -155,7 +155,7 @@ def sample_index_batch(self, batch_size):
index = self.sum_tree.sample()
allowed_attempts -= 1
indices[i] = index
return indices
return np.array(indices)

def sample_transition_batch(self, batch_size=None, indices=None):
"""Returns a batch of transitions with extra storage and the priorities.
20 changes: 16 additions & 4 deletions ml/rl/test/replay_memory/circular_replay_buffer_test.py
@@ -342,7 +342,9 @@ def testSampleTransitionBatch(self):
expected_terminal = np.array(
[min((x + num_adds - replay_capacity) % 4, 1) for x in indices]
)
batch = memory.sample_transition_batch(batch_size=len(indices), indices=indices)
batch = memory.sample_transition_batch(
batch_size=len(indices), indices=np.array(indices)
)
(
states,
action,
@@ -408,7 +410,9 @@ def testSampleTransitionBatchExtra(self):
[min((x + num_adds - replay_capacity) % 4, 1) for x in indices]
)
expected_extra2 = np.zeros([len(indices), 2])
batch = memory.sample_transition_batch(batch_size=len(indices), indices=indices)
batch = memory.sample_transition_batch(
batch_size=len(indices), indices=np.array(indices)
)
(
states,
action,
@@ -451,7 +455,9 @@ def testSamplingWithterminalInTrajectory(self):
1 if i == 3 else 0,
) # terminal
indices = [2, 3, 4]
batch = memory.sample_transition_batch(batch_size=len(indices), indices=indices)
batch = memory.sample_transition_batch(
batch_size=len(indices), indices=np.array(indices)
)
states, action, reward, _, _, _, terminal, indices_batch = batch
expected_states = np.array(
[np.full(OBSERVATION_SHAPE + (1,), i, dtype=OBS_DTYPE) for i in indices]
@@ -514,7 +520,13 @@ def testIsTransitionValid(self):

# These valids account for the automatically applied padding (3 blanks at
# the start of each episode).
correct_valids = [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
# correct_valids = [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
# The above comment is for the original Dopamine buffer, which doesn't
# account for terminal frames within the update_horizon frames before
# the cursor. In this case, the frame right before the cursor
# is terminal, so even though it is within [c-update_horizon, c],
# it should still be valid for sampling, as next state doesn't matter.
correct_valids = [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
# The cursor is: ^
for i in range(10):
self.assertEqual(
