This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Sync workers with global lock #361

Open · wants to merge 2 commits into base: master
47 changes: 46 additions & 1 deletion rl_coach/agents/clipped_ppo_agent.py
@@ -15,12 +15,16 @@
#

import copy
import multiprocessing
import sys
import time
from collections import OrderedDict
from random import shuffle
from typing import Union

import numpy as np

from rl_coach import sync_var
from rl_coach.agents.actor_critic_agent import ActorCriticAgent
from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
@@ -192,6 +196,26 @@ def fill_advantages(self, batch):

def train_network(self, batch, epochs):
batch_results = []

min_batch_size = batch.size
# distributed and synchronous training:
# unify min_batch_size across the worker processes
if hasattr(self.ap.task_parameters, 'num_training_tasks') \
and self.ap.task_parameters.num_training_tasks > 1 \
and not self.networks['main'].network_parameters.async_training:
num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks
with sync_var.global_sync_obj.agent_lock_counter.get_lock():
sync_var.global_sync_obj.agent_lock_counter.value += 1
print('A PID:%s, agent count %s, workers %s' % (
multiprocessing.current_process().pid, sync_var.global_sync_obj.agent_lock_counter.value,
num_workers_to_wait_for))
if sync_var.global_sync_obj.min_batch_size.value > batch.size:
sync_var.global_sync_obj.min_batch_size.value = batch.size
while sync_var.global_sync_obj.agent_lock_counter.value % num_workers_to_wait_for != 0:
time.sleep(0.00001)
sync_var.global_sync_obj.agent_lock_counter.value = 0
min_batch_size = sync_var.global_sync_obj.min_batch_size.value

for j in range(epochs):
batch.shuffle()
batch_results = {
@@ -209,7 +233,11 @@

# TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on
# some of the data
for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):

# There is a bug here: if int(batch.size / self.ap.network_wrappers['main'].batch_size) differs between workers,
# some workers move on to the acting phase while the rest block in wait_for_all_workers_barrier.

for i in range(int(min_batch_size / self.ap.network_wrappers['main'].batch_size)):
start = i * self.ap.network_wrappers['main'].batch_size
end = (i + 1) * self.ap.network_wrappers['main'].batch_size
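
To make the bug described in the comment above concrete: with a per-step batch size of 64, a worker that collected 2048 transitions would run 32 inner iterations while one that collected only 1900 would run 29, so the shorter worker returns to the acting phase while the others block in wait_for_all_workers_barrier. Sharing the global minimum gives every worker the same iteration count. The numbers below are made up purely for illustration:

# Illustrative figures only; none of them come from the PR.
step_batch_size = 64            # self.ap.network_wrappers['main'].batch_size
collected = [2048, 1900, 1984]  # batch.size on each of three workers

# Without unification each worker runs a different number of inner iterations:
print([size // step_batch_size for size in collected])  # [32, 29, 31]

# With the shared minimum every worker runs the same count:
min_batch_size = min(collected)
print(min_batch_size // step_batch_size)  # 29
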

@@ -291,6 +319,23 @@ def train_network(self, batch, epochs):
self.total_kl_divergence_during_training_process = batch_results['kl_divergence']
self.entropy.add_sample(batch_results['entropy'])
self.kl_divergence.add_sample(batch_results['kl_divergence'])

# this process has finished its training loop: reset global_sync_obj.min_batch_size
# and wait for all the other workers to finish as well
if hasattr(self.ap.task_parameters, 'num_training_tasks') \
and self.ap.task_parameters.num_training_tasks > 1 \
and not self.networks['main'].network_parameters.async_training:
num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks
with sync_var.global_sync_obj.agent_release_counter.get_lock():
sync_var.global_sync_obj.agent_release_counter.value += 1
sync_var.global_sync_obj.min_batch_size.value = sys.maxsize
print('B PID:%s, agent count %s, workers %s' % (
multiprocessing.current_process().pid, sync_var.global_sync_obj.agent_release_counter.value,
num_workers_to_wait_for))
while sync_var.global_sync_obj.agent_release_counter.value % num_workers_to_wait_for != 0:
time.sleep(0.00001)
sync_var.global_sync_obj.agent_release_counter.value = 0

return batch_results['losses']

def post_training_commands(self):
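
The new synchronization code reads and writes sync_var.global_sync_obj, but the sync_var module itself is not part of the commits shown in this diff. Below is a minimal sketch of what it might contain, assuming each counter is a multiprocessing.Value shared by worker processes; only the attribute names are taken from the diff, the rest is guesswork:

# Hypothetical rl_coach/sync_var.py; not part of the visible diff.
import multiprocessing
import sys


class _GlobalSync:
    """Shared counters used as spin barriers between worker processes."""

    def __init__(self):
        # Each multiprocessing.Value carries its own lock, which the agent
        # code acquires via .get_lock() before incrementing the counter.
        self.agent_lock_counter = multiprocessing.Value('i', 0)
        self.agent_release_counter = multiprocessing.Value('i', 0)
        self.lock_counter = multiprocessing.Value('i', 0)
        self.release_counter = multiprocessing.Value('i', 0)
        # Starts at sys.maxsize so the first worker to report a batch size
        # always lowers it ('q' is a signed 64-bit integer).
        self.min_batch_size = multiprocessing.Value('q', sys.maxsize)


# Module-level singleton; workers must be forked from a parent that has
# already created it so they all share the same underlying memory.
global_sync_obj = _GlobalSync()

A design note: multiprocessing shared memory only spans workers launched on the same machine from a common parent process, while the TensorFlow-variable barrier this replaces could in principle synchronize workers across machines through the parameter server.
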
26 changes: 22 additions & 4 deletions rl_coach/architectures/tensorflow_components/architecture.py
@@ -20,6 +20,7 @@
import numpy as np
import tensorflow as tf

from rl_coach import sync_var
from rl_coach.architectures.architecture import Architecture
from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
@@ -450,15 +451,32 @@ def wait_for_all_workers_barrier(self, include_only_training_workers: bool=False
:param include_only_training_workers: wait only for training workers or for all the workers?
:return: None
"""
self.wait_for_all_workers_to_lock('lock', include_only_training_workers=include_only_training_workers)
self.sess.run(self.lock_init)
# self.wait_for_all_workers_to_lock('lock', include_only_training_workers=include_only_training_workers)
# self.sess.run(self.lock_init)

# we need to lock again (on a different lock) in order to prevent a situation where one of the workers continues
# and is then able to increase the first lock again by one, only to have a late worker reset it again.
# so we want to make sure that all workers are done resetting the lock before continuing to reuse that lock.

self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
self.sess.run(self.release_init)
# self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
# self.sess.run(self.release_init)

if include_only_training_workers:
num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks
else:
num_workers_to_wait_for = self.ap.task_parameters.num_tasks

with sync_var.global_sync_obj.lock_counter.get_lock():
sync_var.global_sync_obj.lock_counter.value += 1
while sync_var.global_sync_obj.lock_counter.value % num_workers_to_wait_for != 0:
time.sleep(0.00001)
sync_var.global_sync_obj.lock_counter.value = 0

with sync_var.global_sync_obj.release_counter.get_lock():
sync_var.global_sync_obj.release_counter.value += 1
while sync_var.global_sync_obj.release_counter.value % num_workers_to_wait_for != 0:
time.sleep(0.00001)
sync_var.global_sync_obj.release_counter.value = 0

def apply_gradients(self, gradients, scaler=1.):
"""
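
The TensorFlow-queue barrier that is commented out above is replaced by two spinning shared counters, for the reason kept in the comment: with a single counter, a fast worker could increment it again before a slow worker has reset it. The standalone sketch below reproduces that two-phase pattern outside Coach; the helper names are illustrative and not part of the PR:

# Minimal two-phase spin barrier in the style of the PR; runnable on its own.
import multiprocessing
import time


def _wait_on(counter, num_workers):
    # Each worker increments the shared counter, spins until every worker has
    # checked in (the counter is a multiple of num_workers), then resets it.
    with counter.get_lock():
        counter.value += 1
    while counter.value % num_workers != 0:
        time.sleep(0.00001)
    counter.value = 0


def two_phase_barrier(lock_counter, release_counter, num_workers):
    # First phase: wait for all workers to arrive.
    _wait_on(lock_counter, num_workers)
    # Second phase: wait on a different counter so that a fast worker cannot
    # re-increment the first one before a slow worker has reset it.
    _wait_on(release_counter, num_workers)


def _worker(lock_counter, release_counter, num_workers, rank):
    time.sleep(0.01 * rank)  # stagger arrival to exercise the barrier
    two_phase_barrier(lock_counter, release_counter, num_workers)
    print('worker %d passed the barrier' % rank)


if __name__ == '__main__':
    n = 4
    lock_counter = multiprocessing.Value('i', 0)
    release_counter = multiprocessing.Value('i', 0)
    procs = [multiprocessing.Process(target=_worker,
                                     args=(lock_counter, release_counter, n, r))
             for r in range(n)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

Every "passed the barrier" line is printed only after all four workers have arrived, which is the behaviour the wait_for_all_workers_barrier rewrite above relies on.
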
9 changes: 9 additions & 0 deletions rl_coach/check_ckpt.py
@@ -0,0 +1,9 @@
import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join('./experiments/kevin_test/checkpoint', "model.ckpt-67")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) # tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))
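
pywrap_tensorflow is a TensorFlow-internal module; the same inspection can be done through the public checkpoint utilities. A small alternative sketch, assuming a TensorFlow version that exposes tf.train.list_variables and tf.train.load_checkpoint, reusing the same checkpoint path as above:

# Same listing via the public checkpoint API instead of pywrap_tensorflow.
import os
import tensorflow as tf

checkpoint_path = os.path.join('./experiments/kevin_test/checkpoint', "model.ckpt-67")

# List variable names and shapes without loading the tensor values.
for name, shape in tf.train.list_variables(checkpoint_path):
    print("tensor_name: ", name, shape)

# Load individual tensors only when their values are needed.
reader = tf.train.load_checkpoint(checkpoint_path)
# print(reader.get_tensor(name))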