Skip to content

Commit 3a868fb

Browse files
committed
fix one parameter
1 parent db7cdff commit 3a868fb

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/run.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def warm_start_simulation():
318318
cumulative_turns = 0
319319

320320
res = {}
321+
warm_start_run_epochs = 0
321322
for episode in xrange(warm_start_epochs):
322323
dialog_manager.initialize_episode()
323324
episode_over = False
@@ -331,13 +332,15 @@ def warm_start_simulation():
331332
else: print ("warm_start simulation episode %s: Fail" % (episode))
332333
cumulative_turns += dialog_manager.state_tracker.turn_count
333334

335+
warm_start_run_epochs += 1
336+
334337
if len(agent.experience_replay_pool) >= agent.experience_replay_pool_size:
335338
break
336-
339+
337340
agent.warm_start = 2
338-
res['success_rate'] = float(successes)/simulation_epoch_size
339-
res['ave_reward'] = float(cumulative_reward)/simulation_epoch_size
340-
res['ave_turns'] = float(cumulative_turns)/simulation_epoch_size
341+
res['success_rate'] = float(successes)/warm_start_run_epochs
342+
res['ave_reward'] = float(cumulative_reward)/warm_start_run_epochs
343+
res['ave_turns'] = float(cumulative_turns)/warm_start_run_epochs
341344
print ("Warm_Start %s epochs, success rate %s, ave reward %s, ave turns %s" % (episode+1, res['success_rate'], res['ave_reward'], res['ave_turns']))
342345
print ("Current experience replay buffer size %s" % (len(agent.experience_replay_pool)))
343346

0 commit comments

Comments
 (0)