From f813a1c59267053b7e13df0bb74d4b67438973db Mon Sep 17 00:00:00 2001
From: kevin
Date: Mon, 24 Jun 2019 00:48:06 +0800
Subject: [PATCH] 1. Use a global sync lock instead of the lock_counter and
 release_counter TensorFlow variables: counter access is faster and the
 counters can no longer be left non-zero. 2. Fix a sync bug: in
 clipped_ppo_agent, for example, if
 int(batch.size / self.ap.network_wrappers['main'].batch_size) differs
 between workers, some workers end up in the acting phase while the others
 block in wait_for_all_workers_barrier.

---
 rl_coach/agents/clipped_ppo_agent.py               | 47 +++++++++-
 .../tensorflow_components/architecture.py          | 26 +++++-
 rl_coach/check_ckpt.py                             |  9 ++
 rl_coach/coach.py                                  | 93 +++++++++++++------
 rl_coach/sync_var.py                               | 13 +++
 5 files changed, 153 insertions(+), 35 deletions(-)
 create mode 100644 rl_coach/check_ckpt.py
 create mode 100644 rl_coach/sync_var.py

diff --git a/rl_coach/agents/clipped_ppo_agent.py b/rl_coach/agents/clipped_ppo_agent.py
index cc29f3339..325c57862 100644
--- a/rl_coach/agents/clipped_ppo_agent.py
+++ b/rl_coach/agents/clipped_ppo_agent.py
@@ -15,12 +15,16 @@
 #
 
 import copy
+import multiprocessing
+import sys
+import time
 from collections import OrderedDict
 from random import shuffle
 from typing import Union
 
 import numpy as np
 
+from rl_coach import sync_var
 from rl_coach.agents.actor_critic_agent import ActorCriticAgent
 from rl_coach.agents.policy_optimization_agent import PolicyGradientRescaler
 from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
@@ -192,6 +196,26 @@ def fill_advantages(self, batch):
 
     def train_network(self, batch, epochs):
         batch_results = []
+
+        min_batch_size = batch.size
+        # distributed synchronous training:
+        # unify min_batch_size across the worker processes
+        if hasattr(self.ap.task_parameters, 'num_training_tasks') \
+                and self.ap.task_parameters.num_training_tasks > 1 \
+                and not self.networks['main'].network_parameters.async_training:
+            num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks
+            with sync_var.global_sync_obj.agent_lock_counter.get_lock():
+                sync_var.global_sync_obj.agent_lock_counter.value += 1
+                print('A PID:%s, agent count %s, workers %s' % (
+                    multiprocessing.current_process().pid, sync_var.global_sync_obj.agent_lock_counter.value,
+                    num_workers_to_wait_for))
+                if sync_var.global_sync_obj.min_batch_size.value > batch.size:
+                    sync_var.global_sync_obj.min_batch_size.value = batch.size
+            while sync_var.global_sync_obj.agent_lock_counter.value % num_workers_to_wait_for != 0:
+                time.sleep(0.00001)
+            sync_var.global_sync_obj.agent_lock_counter.value = 0
+            min_batch_size = sync_var.global_sync_obj.min_batch_size.value
+
         for j in range(epochs):
             batch.shuffle()
             batch_results = {
@@ -209,7 +233,11 @@ def train_network(self, batch, epochs):
 
             # TODO-fixme if batch.size / self.ap.network_wrappers['main'].batch_size is not an integer, we do not train on
             # some of the data
-            for i in range(int(batch.size / self.ap.network_wrappers['main'].batch_size)):
+
+            # use min_batch_size: if int(batch.size / self.ap.network_wrappers['main'].batch_size) differs between
+            # workers, some workers reach the acting phase while the others block in wait_for_all_workers_barrier
+
+            for i in range(int(min_batch_size / self.ap.network_wrappers['main'].batch_size)):
                 start = i * self.ap.network_wrappers['main'].batch_size
                 end = (i + 1) * self.ap.network_wrappers['main'].batch_size
 
@@ -291,6 +319,23 @@ def train_network(self, batch, epochs):
         self.total_kl_divergence_during_training_process = batch_results['kl_divergence']
         self.entropy.add_sample(batch_results['entropy'])
         self.kl_divergence.add_sample(batch_results['kl_divergence'])
+
+        # training loop finished: reset global_sync_obj.min_batch_size
+        # and wait for all the workers to get here
+        if hasattr(self.ap.task_parameters, 'num_training_tasks') \
+                and self.ap.task_parameters.num_training_tasks > 1 \
+                and not self.networks['main'].network_parameters.async_training:
+            num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks
+            with sync_var.global_sync_obj.agent_release_counter.get_lock():
+                sync_var.global_sync_obj.agent_release_counter.value += 1
+                sync_var.global_sync_obj.min_batch_size.value = sys.maxsize
+                print('B PID:%s, agent count %s, workers %s' % (
+                    multiprocessing.current_process().pid, sync_var.global_sync_obj.agent_release_counter.value,
+                    num_workers_to_wait_for))
+            while sync_var.global_sync_obj.agent_release_counter.value % num_workers_to_wait_for != 0:
+                time.sleep(0.00001)
+            sync_var.global_sync_obj.agent_release_counter.value = 0
+
         return batch_results['losses']
 
     def post_training_commands(self):
diff --git a/rl_coach/architectures/tensorflow_components/architecture.py b/rl_coach/architectures/tensorflow_components/architecture.py
index 907593658..46b87e45d 100644
--- a/rl_coach/architectures/tensorflow_components/architecture.py
+++ b/rl_coach/architectures/tensorflow_components/architecture.py
@@ -20,6 +20,7 @@
 import numpy as np
 import tensorflow as tf
 
+from rl_coach import sync_var
 from rl_coach.architectures.architecture import Architecture
 from rl_coach.architectures.tensorflow_components.savers import GlobalVariableSaver
 from rl_coach.base_parameters import AgentParameters, DistributedTaskParameters
@@ -450,15 +451,32 @@ def wait_for_all_workers_barrier(self, include_only_training_workers: bool=False
         :param include_only_training_workers: wait only for training workers or for all the workers?
         :return: None
         """
-        self.wait_for_all_workers_to_lock('lock', include_only_training_workers=include_only_training_workers)
-        self.sess.run(self.lock_init)
+        # self.wait_for_all_workers_to_lock('lock', include_only_training_workers=include_only_training_workers)
+        # self.sess.run(self.lock_init)
 
         # we need to lock again (on a different lock) in order to prevent a situation where one of the workers continue
         # and then was able to first increase the lock again by one, only to have a late worker to reset it again.
         # so we want to make sure that all workers are done resetting the lock before continuting to reuse that lock.
 
-        self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
-        self.sess.run(self.release_init)
+        # self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
+        # self.sess.run(self.release_init)
+
+        if include_only_training_workers:
+            num_workers_to_wait_for = self.ap.task_parameters.num_training_tasks
+        else:
+            num_workers_to_wait_for = self.ap.task_parameters.num_tasks
+
+        with sync_var.global_sync_obj.lock_counter.get_lock():
+            sync_var.global_sync_obj.lock_counter.value += 1
+        while sync_var.global_sync_obj.lock_counter.value % num_workers_to_wait_for != 0:
+            time.sleep(0.00001)
+        sync_var.global_sync_obj.lock_counter.value = 0
+
+        with sync_var.global_sync_obj.release_counter.get_lock():
+            sync_var.global_sync_obj.release_counter.value += 1
+        while sync_var.global_sync_obj.release_counter.value % num_workers_to_wait_for != 0:
+            time.sleep(0.00001)
+        sync_var.global_sync_obj.release_counter.value = 0
 
     def apply_gradients(self, gradients, scaler=1.):
         """
diff --git a/rl_coach/check_ckpt.py b/rl_coach/check_ckpt.py
new file mode 100644
index 000000000..d0e04e77f
--- /dev/null
+++ b/rl_coach/check_ckpt.py
@@ -0,0 +1,9 @@
+import os
+from tensorflow.python import pywrap_tensorflow
+
+checkpoint_path = os.path.join('./experiments/kevin_test/checkpoint', "model.ckpt-67")
+reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)  # tf.train.NewCheckpointReader
+var_to_shape_map = reader.get_variable_to_shape_map()
+for key in var_to_shape_map:
+    print("tensor_name: ", key)
+    # print(reader.get_tensor(key))
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 7edfbe5b0..038b3675c 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import multiprocessing
 import sys
+
+import numpy
+
 sys.path.append('.')
 
 import copy
 from configparser import ConfigParser, Error
 import os
-from rl_coach import logger
+from rl_coach import logger, sync_var
 import traceback
 from rl_coach.logger import screen, failed_imports
 import argparse
@@ -29,8 +32,7 @@
 import json
 from rl_coach.base_parameters import Frameworks, VisualizationParameters, TaskParameters, DistributedTaskParameters, \
     RunType, DistributedCoachSynchronizationType
-from rl_coach.core_types import TotalStepsCounter, RunPhase, PlayingStepsType, TrainingSteps, EnvironmentEpisodes, \
-    EnvironmentSteps, StepMethod, Transition
+from rl_coach.core_types import EnvironmentEpisodes
 from multiprocessing import Process
 from multiprocessing.managers import BaseManager
 import subprocess
@@ -49,7 +51,6 @@
 from rl_coach.training_worker import training_worker
 from rl_coach.rollout_worker import rollout_worker
 
-
 if len(set(failed_imports)) > 0:
     screen.warning("Warning: failed to import the following packages - {}".format(', '.join(set(failed_imports))))
 
@@ -71,20 +72,31 @@ def open_dashboard(experiment_path):
     subprocess.Popen(cmd, shell=True, executable="bash")
 
 
-def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'):
+def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters',
+                global_sync_obj: sync_var.SyncVar = None):
     """
     Runs the graph_manager using the configured task_parameters.
    This stand-alone method is a convenience for multiprocessing.
""" + sync_var.global_sync_obj = global_sync_obj graph_manager.create_graph(task_parameters) # let the adventure begin - if task_parameters.evaluate_only is not None: - steps_to_evaluate = task_parameters.evaluate_only if task_parameters.evaluate_only > 0 \ - else sys.maxsize - graph_manager.evaluate(EnvironmentSteps(steps_to_evaluate)) - else: - graph_manager.improve() + # if task_parameters.evaluate_only is not None: + # steps_to_evaluate = task_parameters.evaluate_only if task_parameters.evaluate_only > 0 \ + # else sys.maxsize + # graph_manager.evaluate(EnvironmentSteps(steps_to_evaluate)) + # else: + # graph_manager.improve() + # graph_manager.close() + + graph_manager.sync() + graph_manager.heatup(EnvironmentEpisodes(300)) + + for train_epoch in range(10000): + time.sleep(numpy.random.randint(10)) + graph_manager.train_and_act(EnvironmentEpisodes(600)) + graph_manager.evaluate(EnvironmentEpisodes(100)) graph_manager.close() @@ -95,7 +107,8 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters): if args.memory_backend_params: memory_backend_params = json.loads(args.memory_backend_params) memory_backend_params['run_type'] = str(args.distributed_coach_run_type) - graph_manager.agent_params.memory.register_var('memory_backend_params', construct_memory_params(memory_backend_params)) + graph_manager.agent_params.memory.register_var('memory_backend_params', + construct_memory_params(memory_backend_params)) data_store_params = None if args.data_store_params: @@ -142,7 +155,8 @@ def handle_distributed_coach_orchestrator(args): pass trainer_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.TRAINER)] + arg_list - rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', str(RunType.ROLLOUT_WORKER)] + arg_list + rollout_command = ['python3', 'rl_coach/coach.py', '--distributed_coach_run_type', + str(RunType.ROLLOUT_WORKER)] + arg_list if '--experiment_name' not in rollout_command: rollout_command = rollout_command + ['--experiment_name', args.experiment_name] @@ -157,13 +171,16 @@ def handle_distributed_coach_orchestrator(args): ds_params_instance = None if args.data_store == "s3": ds_params = DataStoreParameters("s3", "", "") - ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, bucket_name=args.s3_bucket_name, - creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container, expt_dir=args.experiment_path) + ds_params_instance = S3DataStoreParameters(ds_params=ds_params, end_point=args.s3_end_point, + bucket_name=args.s3_bucket_name, + creds_file=args.s3_creds_file, checkpoint_dir=ckpt_inside_container, + expt_dir=args.experiment_path) elif args.data_store == "nfs": ds_params = DataStoreParameters("nfs", "kubernetes", "") ds_params_instance = NFSDataStoreParameters(ds_params) - worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), num_replicas=args.num_workers) + worker_run_type_params = RunTypeParameters(args.image, rollout_command, run_type=str(RunType.ROLLOUT_WORKER), + num_replicas=args.num_workers) trainer_run_type_params = RunTypeParameters(args.image, trainer_command, run_type=str(RunType.TRAINER)) orchestration_params = KubernetesParameters([worker_run_type_params, trainer_run_type_params], @@ -242,7 +259,8 @@ def get_graph_manager_from_args(self, args: argparse.Namespace) -> 'GraphManager env_params = short_dynamic_import(args.environment_type, ignore_module_case=True)() env_params.human_control = True 
             schedule_params = HumanPlayScheduleParameters()
-            graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params, VisualizationParameters())
+            graph_manager = BasicRLGraphManager(HumanAgentParameters(), env_params, schedule_params,
+                                                VisualizationParameters())
 
         # Set framework
         # Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
@@ -401,7 +419,8 @@ def get_config_args(self, parser: argparse.ArgumentParser, arguments=None) -> ar
         # validate the checkpoints args
         if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
             # If distributed trainer, the checkpoint dir is not yet available so skipping the check in that case.
-            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER, RunType.ROLLOUT_WORKER]):
+            if not (args.distributed_coach and args.distributed_coach_run_type in [RunType.TRAINER,
+                                                                                   RunType.ROLLOUT_WORKER]):
                 screen.error("The requested checkpoint folder to load from does not exist.")
 
         # validate the checkpoints args
@@ -432,7 +451,8 @@ def get_config_args(self, parser: argparse.ArgumentParser, arguments=None) -> ar
         args.framework = Frameworks[args.framework.lower()]
 
         # checkpoints
-        args.checkpoint_save_dir = os.path.join(args.experiment_path, 'checkpoint') if args.checkpoint_save_secs is not None else None
+        args.checkpoint_save_dir = os.path.join(args.experiment_path,
+                                                'checkpoint') if args.checkpoint_save_secs is not None else None
 
         if args.export_onnx_graph and not args.checkpoint_save_secs:
             screen.warning("Exporting ONNX graphs requires setting the --checkpoint_save_secs flag. "
@@ -487,9 +507,9 @@ def get_argument_parser(self) -> argparse.ArgumentParser:
                             action='store_true')
         parser.add_argument('--evaluate',
                             help="(int) Run evaluation only, for at least the given number of steps (note that complete "
-                                 "episodes are evaluated). This is a convenient way to disable training in order "
-                                 "to evaluate an existing checkpoint. If value is 0, or no value is provided, "
-                                 "evaluation will run for an infinite number of steps.",
+                                "episodes are evaluated). This is a convenient way to disable training in order "
+                                "to evaluate an existing checkpoint. If value is 0, or no value is provided, "
+                                "evaluation will run for an infinite number of steps.",
                             nargs='?',
                             const=0,
                             type=int)
@@ -683,6 +703,7 @@ def start_multi_threaded(graph_manager: 'GraphManager', args: argparse.Namespace
     # Shared memory
     class CommManager(BaseManager):
         pass
+
     CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
     comm_manager = CommManager()
     comm_manager.start()
@@ -692,21 +713,22 @@ class CommManager(BaseManager):
         raise ValueError("Multi-Process runs only support restoring checkpoints from a directory, "
                          "and not from a file. ")
") - def start_distributed_task(job_type, task_index, evaluation_worker=False, - shared_memory_scratchpad=shared_memory_scratchpad): + def start_distributed_task(job_type, task_index, global_sync_obj: sync_var.SyncVar = None, + evaluation_worker=False, shared_memory_scratchpad=shared_memory_scratchpad): task_parameters = DistributedTaskParameters( framework_type=args.framework, parameters_server_hosts=ps_hosts, worker_hosts=worker_hosts, job_type=job_type, task_index=task_index, - evaluate_only=0 if evaluation_worker else None, # 0 value for evaluation worker as it should run infinitely + evaluate_only=0 if evaluation_worker else None, + # 0 value for evaluation worker as it should run infinitely use_cpu=args.use_cpu, num_tasks=total_tasks, # training tasks + 1 evaluation task num_training_tasks=args.num_workers, experiment_path=args.experiment_path, shared_memory_scratchpad=shared_memory_scratchpad, - seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed + seed=args.seed + task_index if args.seed is not None else None, # each worker gets a different seed checkpoint_save_secs=args.checkpoint_save_secs, checkpoint_restore_path=args.checkpoint_restore_dir, # MonitoredTrainingSession only supports a dir checkpoint_save_dir=args.checkpoint_save_dir, @@ -715,7 +737,7 @@ def start_distributed_task(job_type, task_index, evaluation_worker=False, ) # we assume that only the evaluation workers are rendering graph_manager.visualization_parameters.render = args.render and evaluation_worker - p = Process(target=start_graph, args=(graph_manager, task_parameters)) + p = Process(target=start_graph, args=(graph_manager, task_parameters, global_sync_obj)) # p.daemon = True p.start() return p @@ -723,14 +745,24 @@ def start_distributed_task(job_type, task_index, evaluation_worker=False, # parameter server parameter_server = start_distributed_task("ps", 0) + # initial the global sync variable + global_sync_obj = sync_var.SyncVar + global_sync_obj.lock_counter = multiprocessing.Value('i', 0, lock=True) + global_sync_obj.release_counter = multiprocessing.Value('i', 0, lock=True) + global_sync_obj.agent_lock_counter = multiprocessing.Value('i', 0, lock=True) + global_sync_obj.agent_release_counter = multiprocessing.Value('i', 0, lock=True) + global_sync_obj.min_batch_size = multiprocessing.Value('l', sys.maxsize, lock=True) + # training workers # wait a bit before spawning the non chief workers in order to make sure the session is already created workers = [] - workers.append(start_distributed_task("worker", 0)) + workers.append(start_distributed_task("worker", 0, global_sync_obj)) time.sleep(2) for task_index in range(1, args.num_workers): - workers.append(start_distributed_task("worker", task_index)) + # start workers one by one, prevent the checkpoint file access conflict + time.sleep(1) + workers.append(start_distributed_task("worker", task_index, global_sync_obj)) # evaluation worker if args.evaluation_worker or args.render: @@ -747,6 +779,7 @@ class CoachInterface(CoachLauncher): This class is used as an interface to use coach as library. It can take any of the command line arguments (with the respective names) as arguments to the class. 
""" + def __init__(self, **kwargs): parser = self.get_argument_parser() diff --git a/rl_coach/sync_var.py b/rl_coach/sync_var.py new file mode 100644 index 000000000..52182680f --- /dev/null +++ b/rl_coach/sync_var.py @@ -0,0 +1,13 @@ +class SyncVar: + # 全局同步变量 + def __init__(self): + # for architecture/wait_for_all_workers_barrier + self.lock_counter = None + self.release_counter = None + # for agent/train_network + self.min_batch_size = None + self.agent_lock_counter = None + self.agent_release_counter = None + + +global global_sync_obj