1+ from __future__ import unicode_literals
2+ from __future__ import print_function
3+ from __future__ import division
4+ from __future__ import absolute_import
5+ from builtins import * # NOQA
6+ from future import standard_library
7+ standard_library .install_aliases ()
18import random
29
10+ import numpy as np
11+
312
413class PrioritizedBuffer (object ):
5- def __init__ (self , capacity = None ):
14+
15+ def __init__ (self , capacity = None , wait_priority_after_sampling = True ):
616 self .capacity = capacity
717 self .data = []
818 self .priority_tree = SumTree ()
919 self .data_inf = []
20+ self .wait_priority_after_sampling = wait_priority_after_sampling
1021 self .flag_wait_priority = False
1122
1223 def __len__ (self ):
1324 return len (self .data ) + len (self .data_inf )
1425
15- def append (self , value ):
16- # new values are the most prioritized
17- self .data_inf .append (value )
18- if self .capacity is not None and len (self ) > self .capacity :
26+ def append (self , value , priority = None ):
27+ if self .capacity is not None and len (self ) == self .capacity :
1928 self .pop ()
29+ if priority is not None :
30+ # Append with a given priority
31+ i = len (self .data )
32+ self .data .append (value )
33+ self .priority_tree [i ] = priority
34+ else :
35+ # Append with the highest priority
36+ self .data_inf .append (value )
2037
2138 def _pop_random_data_inf (self ):
2239 assert self .data_inf
2340 n = len (self .data_inf )
2441 i = random .randrange (n )
2542 ret = self .data_inf [i ]
26- self .data_inf [i ] = self .data_inf [n - 1 ]
43+ self .data_inf [i ] = self .data_inf [n - 1 ]
2744 self .data_inf .pop ()
2845 return ret
2946
@@ -33,47 +50,96 @@ def pop(self):
3350 Not prioritized.
3451 """
3552 assert len (self ) > 0
36- assert not self .flag_wait_priority
53+ assert (not self .wait_priority_after_sampling or
54+ not self .flag_wait_priority )
3755 n = len (self .data )
3856 if n == 0 :
3957 return self ._pop_random_data_inf ()
4058 i = random .randrange (0 , n )
4159 # remove i-th
42- self .priority_tree [i ] = self .priority_tree [n - 1 ]
43- del self .priority_tree [n - 1 ]
60+ self .priority_tree [i ] = self .priority_tree [n - 1 ]
61+ del self .priority_tree [n - 1 ]
4462 ret = self .data [i ]
45- self .data [i ] = self .data [n - 1 ]
46- del self .data [n - 1 ]
63+ self .data [i ] = self .data [n - 1 ]
64+ del self .data [n - 1 ]
4765 return ret
4866
49- def sample (self , n ):
50- """Sample n distinct elements"""
67+ def _prioritized_sample_indices_and_probabilities (self , n ):
5168 assert 0 <= n <= len (self )
52- assert not self .flag_wait_priority
5369 indices , probabilities = self .priority_tree .prioritized_sample (
54- max (0 , n - len (self .data_inf )), remove = True )
55- sampled = []
56- for i in indices :
57- sampled .append (self .data [i ])
58- while len (sampled ) < n and len (self .data_inf ) > 0 :
70+ max (0 , n - len (self .data_inf )),
71+ remove = self .wait_priority_after_sampling )
72+ while len (indices ) < n :
5973 i = len (self .data )
6074 e = self ._pop_random_data_inf ()
6175 self .data .append (e )
6276 del self .priority_tree [i ]
6377 indices .append (i )
6478 probabilities .append (None )
65- sampled .append (self .data [i ])
79+ return indices , probabilities
80+
81+ def _sample_indices_and_probabilities (self , n , uniform_ratio ):
82+ if uniform_ratio > 0 :
83+ # Mix uniform samples and prioritized samples
84+ n_uniform = np .random .binomial (n , uniform_ratio )
85+ n_prioritized = n - n_uniform
86+ pr_indices , pr_probs = \
87+ self ._prioritized_sample_indices_and_probabilities (
88+ n_prioritized )
89+ un_indices , un_probs = \
90+ self ._uniform_sample_indices_and_probabilities (
91+ n_uniform )
92+ indices = pr_indices + un_indices
93+ # Note: when uniform samples and prioritized samples are mixed,
94+ # resulting probabilities are not the true probabilities for each
95+ # entry to be sampled.
96+ probabilities = pr_probs + un_probs
97+ return indices , probabilities
98+ else :
99+ # Only prioritized samples
100+ return self ._prioritized_sample_indices_and_probabilities (n )
101+
102+ def sample (self , n , uniform_ratio = 0 ):
103+ """Sample data along with their corresponding probabilities.
104+
105+ Args:
106+ n (int): Number of data to sample.
107+ uniform_ratio (float): Ratio of uniformly sampled data.
108+ Returns:
109+ sampled data (list)
110+ probabitilies (list)
111+ """
112+ assert (not self .wait_priority_after_sampling or
113+ not self .flag_wait_priority )
114+ indices , probabilities = self ._sample_indices_and_probabilities (
115+ n , uniform_ratio = uniform_ratio )
116+ sampled = [self .data [i ] for i in indices ]
66117 self .sampled_indices = indices
67118 self .flag_wait_priority = True
68119 return sampled , probabilities
69120
70121 def set_last_priority (self , priority ):
71- assert self .flag_wait_priority
122+ assert (not self .wait_priority_after_sampling or
123+ self .flag_wait_priority )
72124 assert all ([p > 0.0 for p in priority ])
73125 assert len (self .sampled_indices ) == len (priority )
74126 for i , p in zip (self .sampled_indices , priority ):
75127 self .priority_tree [i ] = p
76128 self .flag_wait_priority = False
129+ self .sampled_indices = []
130+
131+ def _uniform_sample_indices_and_probabilities (self , n ):
132+ indices = random .sample (range (len (self .data )),
133+ max (0 , n - len (self .data_inf )))
134+ probabilities = [1 / len (self )] * len (indices )
135+ while len (indices ) < n :
136+ i = len (self .data )
137+ e = self ._pop_random_data_inf ()
138+ self .data .append (e )
139+ del self .priority_tree [i ]
140+ indices .append (i )
141+ probabilities .append (None )
142+ return indices , probabilities
77143
78144
79145class SumTree (object ):
@@ -119,17 +185,18 @@ def _center(self):
119185
120186 def _allocindex (self , ix ):
121187 if self .bd is None :
122- self .bd = (ix , ix + 1 )
188+ self .bd = (ix , ix + 1 )
123189 while ix >= self .bd [1 ]:
124- r_bd = (self .bd [1 ], self .bd [1 ]* 2 - self .bd [0 ])
190+ r_bd = (self .bd [1 ], self .bd [1 ] * 2 - self .bd [0 ])
125191 l = SumTree (self .bd , self .l , self .r , self .s )
192+
126193 r = SumTree (bd = r_bd )._initdescendant ()
127194 self .bd = (l .bd [0 ], r .bd [1 ])
128195 self .l = l
129196 self .r = r
130197 # no need to update self.s because self.r.s == 0
131198 while ix < self .bd [0 ]:
132- l_bd = (self .bd [0 ]* 2 - self .bd [1 ], self .bd [0 ])
199+ l_bd = (self .bd [0 ] * 2 - self .bd [1 ], self .bd [0 ])
133200 l = SumTree (bd = l_bd )._initdescendant ()
134201 r = SumTree (self .bd , self .l , self .r , self .s )
135202 self .bd = (l .bd [0 ], r .bd [1 ])
0 commit comments