@@ -12,10 +12,8 @@ def test_run(model_class):
1212 env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1"
1313 env = make_vec_env (env_id , n_envs = 2 )
1414
15- # FIXME: need to set the discount factor manually
1615 n_steps = 2
1716 gamma = 0.99
18- discount = gamma ** n_steps
1917
2018 model = model_class (
2119 "MlpPolicy" ,
@@ -29,7 +27,7 @@ def test_run(model_class):
2927 policy_kwargs = dict (net_arch = [64 ]),
3028 learning_starts = 100 ,
3129 buffer_size = int (2e4 ),
32- gamma = discount ,
30+ gamma = gamma ,
3331 )
3432
3533 model .learn (total_timesteps = 150 )
@@ -103,11 +101,11 @@ def test_nstep_early_termination(done_at, n_steps):
103101
104102 base_idx = 0
105103 batch = buffer ._get_samples (np .array ([base_idx ]))
106- actual = batch .rewards .numpy (). item ()
104+ actual = batch .rewards .item ()
107105
108106 expected = compute_expected_nstep_reward (gamma = 0.99 , n_steps = n_steps , stop_idx = done_at - base_idx )
109107 np .testing .assert_allclose (actual , expected , rtol = 1e-4 )
110- assert batch .dones .numpy (). item () == 1.0
108+ assert batch .dones .item () == 1.0
111109
112110
113111@pytest .mark .parametrize ("truncated_at" , [1 , 2 ])
@@ -117,46 +115,51 @@ def test_nstep_early_truncation(truncated_at):
117115
118116 base_idx = 0
119117 batch = buffer ._get_samples (np .array ([base_idx ]))
120- actual = batch .rewards .numpy (). item ()
118+ actual = batch .rewards .item ()
121119
122120 expected = compute_expected_nstep_reward (gamma = 0.99 , n_steps = 3 , stop_idx = truncated_at - base_idx )
123121 np .testing .assert_allclose (actual , expected , rtol = 1e-4 )
124- assert batch .dones .numpy (). item () == 0.0
122+ assert batch .dones .item () == 0.0
125123
126124
127125@pytest .mark .parametrize ("n_steps" , [3 , 5 ])
128- def test_nstep_no_termination_or_truncation (n_steps ):
126+ def test_nstep_no_terminations (n_steps ):
129127 buffer = create_buffer (n_steps = n_steps )
130128 fill_buffer (buffer , length = 10 ) # no done or truncation
129+ gamma = 0.99
131130
132131 base_idx = 3
133132 batch = buffer ._get_samples (np .array ([base_idx ]))
134- actual = batch .rewards .numpy ().item ()
135-
136- expected = compute_expected_nstep_reward (gamma = 0.99 , n_steps = n_steps )
133+ actual = batch .rewards .item ()
134+ # Discount factor for bootstrapping with target Q-Value
135+ np .testing .assert_allclose (batch .discounts .item (), gamma ** n_steps )
136+ expected = compute_expected_nstep_reward (gamma = gamma , n_steps = n_steps )
137137 np .testing .assert_allclose (actual , expected , rtol = 1e-4 )
138- assert batch .dones .numpy (). item () == 0.0
138+ assert batch .dones .item () == 0.0
139139
140140 # Check that self.pos-1 truncation is set when buffer is full
141141 # Note: buffer size is 10, here we are erasing past transitions
142142 fill_buffer (buffer , length = 2 )
143143 # We create a tmp truncation to not sample across episodes
144144 base_idx = 0
145145 batch = buffer ._get_samples (np .array ([base_idx ]))
146- actual = batch .rewards .numpy (). item ()
146+ actual = batch .rewards .item ()
147147 # Note: compute_expected_nstep assumes base_idx=0
148148 expected = compute_expected_nstep_reward (gamma = 0.99 , n_steps = n_steps , stop_idx = buffer .pos - 1 )
149149 np .testing .assert_allclose (actual , expected , rtol = 1e-4 )
150- assert batch .dones .numpy ().item () == 0.0
150+ assert batch .dones .item () == 0.0
151+ # Discount factor for bootstrapping with target Q-Value
152+ # (bigger than gamma ** n_steps because of truncation at n_steps=2)
153+ np .testing .assert_allclose (batch .discounts .item (), gamma ** 2 )
151154
152155 # Set done=1 manually, the tmp truncation should not be set (it would set batch.done=False)
153156 buffer .dones [buffer .pos - 1 , :] = True
154157 batch = buffer ._get_samples (np .array ([base_idx ]))
155- actual = batch .rewards .numpy (). item ()
158+ actual = batch .rewards .item ()
156159 # Note: compute_expected_nstep assumes base_idx=0
157160 expected = compute_expected_nstep_reward (gamma = 0.99 , n_steps = n_steps , stop_idx = buffer .pos - 1 )
158161 np .testing .assert_allclose (actual , expected , rtol = 1e-4 )
159- assert batch .dones .numpy (). item () == 1.0
162+ assert batch .dones .item () == 1.0
160163
161164
162165def test_match_normal_buffer ():
@@ -168,12 +171,12 @@ def test_match_normal_buffer():
168171
169172 base_idx = 3
170173 batch1 = buffer ._get_samples (np .array ([base_idx ]))
171- actual1 = batch1 .rewards .numpy (). item ()
174+ actual1 = batch1 .rewards .item ()
172175
173176 batch2 = ref_buffer ._get_samples (np .array ([base_idx ]))
174177
175178 expected = compute_expected_nstep_reward (gamma = 0.99 , n_steps = 1 )
176179 np .testing .assert_allclose (actual1 , expected , rtol = 1e-4 )
177- assert batch1 .dones .numpy (). item () == 0.0
180+ assert batch1 .dones .item () == 0.0
178181
179182 np .testing .assert_allclose (batch1 .rewards .numpy (), batch2 .rewards .numpy (), rtol = 1e-4 )
0 commit comments