Hi,
I've been looking at the implementation of PrioritizedReplayBuffer, which uses SumSegmentTree. It appears to follow the standard and efficient approach for prioritized experience replay.
I wanted to raise a discussion point regarding the behavior of priority updates when the replay buffer wraps around and overwrites existing data.
Scenario:
- The buffer is full (`self.size == self.capacity`).
- The sample method is called, returning a batch including a transition `T_old` originally stored at index i (returned in samples["indices"]).
- Before update_priorities is called for index i, new transitions are added via the add method. Due to the circular buffer logic (`self.pos = (self.pos + 1) % self.capacity`), the data at index i (including `T_old`) is overwritten by a new transition `T_new`.
- Later, update_priorities is called with the originally sampled index i and a priority calculated based on the TD error of `T_old`.
- Inside update_priorities, the line `self.sum_tree.update(idx, priority)` (and similarly for min_tree) updates the priority value stored at the fixed leaf index i in the segment tree.
Observation:
Because index i now corresponds to the slot holding `T_new`, the priority update derived from `T_old`'s error is applied to the priority associated with `T_new`. While `T_new` might be relevant, it's not the transition that generated the specific TD error being used for the update.
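To make the scenario concrete, here is a minimal reproduction sketch. It does not use the repository's classes: `ToyPrioritizedBuffer` is a hypothetical stand-in that keeps priorities in a flat list instead of a segment tree, and the sampled index is hard-coded rather than drawn by a real `sample` call. The wrap-around logic is the same `self.pos = (self.pos + 1) % self.capacity` quoted above.

```python
class ToyPrioritizedBuffer:
    """Hypothetical stand-in for a circular prioritized buffer (not the repository's class)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = [None] * capacity
        self.priorities = [0.0] * capacity  # flat list in place of a segment tree
        self.pos = 0
        self.size = 0

    def add(self, transition, priority=1.0):
        # Write into the current slot, then advance with wrap-around.
        self.data[self.pos] = transition
        self.priorities[self.pos] = priority
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def update_priorities(self, indices, priorities):
        # Updates are keyed by slot index only; there is no check of what the slot holds.
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority


buf = ToyPrioritizedBuffer(capacity=3)
for name in ("T0", "T1", "T2"):
    buf.add(name)                      # buffer is now full and pos has wrapped to 0

sampled_index = 0                      # pretend sample() returned index 0, holding T_old = "T0"
buf.add("T_new")                       # overwrites slot 0 before the priority update arrives
buf.update_priorities([sampled_index], [5.0])  # 5.0 stands for a priority from T0's TD error

slot_content = buf.data[sampled_index]
slot_priority = buf.priorities[sampled_index]
print(slot_content, slot_priority)
```

In this toy run the slot that receives the priority computed for `T0` is the one now holding `T_new`, which is the mismatch described above.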
Discussion Point:
This behavior is characteristic of the standard, efficient SumTree PER implementation – it prioritizes O(log N) complexity over guaranteeing that an update signal always matches the exact transition that generated it (if that transition has been overwritten).
While efficient, this behavior means the priority of one transition (`T_new`) is being influenced directly by the error signal from a potentially unrelated, previous transition (`T_old`).
Potential Alternatives (with Trade-offs):
An alternative approach involves using unique IDs for transitions and an ID-to-index map. This allows checking if the transition at the target index still matches the originally sampled ID before applying the update, ensuring accuracy but adding complexity and memory overhead (though retaining O(log N) complexity).
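A rough sketch of the ID check, under simplifying assumptions: `IdCheckedBuffer` and its fields are hypothetical names, priorities live in a flat list rather than a segment tree, and a per-slot ID array is used in place of a full ID-to-index map (for a plain circular buffer it may be enough to answer "does this slot still hold what I sampled?"). The sampler would need to return the IDs alongside the indices.

```python
class IdCheckedBuffer:
    """Hypothetical circular buffer that tags each stored transition with a unique ID."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = [None] * capacity
        self.priorities = [0.0] * capacity  # flat list in place of a segment tree
        self.ids = [-1] * capacity          # ID of the transition currently in each slot
        self.next_id = 0
        self.pos = 0

    def add(self, transition, priority=1.0):
        self.data[self.pos] = transition
        self.priorities[self.pos] = priority
        self.ids[self.pos] = self.next_id   # a new ID marks the slot as overwritten
        self.next_id += 1
        self.pos = (self.pos + 1) % self.capacity

    def update_priorities(self, indices, sampled_ids, priorities):
        """Apply each update only if the slot still holds the sampled transition."""
        applied = 0
        for idx, sampled_id, priority in zip(indices, sampled_ids, priorities):
            if self.ids[idx] != sampled_id:
                continue                    # stale update: the slot was overwritten
            self.priorities[idx] = priority
            applied += 1
        return applied


guarded = IdCheckedBuffer(capacity=3)
for name in ("T0", "T1", "T2"):
    guarded.add(name)

sampled_index = 0
sampled_id = guarded.ids[sampled_index]     # ID recorded at sampling time
guarded.add("T_new")                        # slot 0 is overwritten and gets a fresh ID
applied = guarded.update_priorities([sampled_index], [sampled_id], [5.0])
print(applied, guarded.priorities[sampled_index])
```

Here the stale update is dropped and `T_new` keeps the priority it was inserted with. The extra cost in this sketch is one integer per slot plus one comparison per update; whether dropping stale updates is preferable to applying them is exactly the open question.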
I would be glad to hear your perspectives on this. In particular, I wonder whether this priority mismatch can introduce bias into learning, since it changes the probabilities with which the replay buffer samples transitions.
Thank you for this great repository and your support of the RL community.