This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Enable multi-process training with multi-node Coach. #240

Open · wants to merge 1 commit into base: master
@@ -74,8 +74,8 @@ def create_worker_server_and_device(cluster_spec: tf.train.ClusterSpec, task_index
return server.target, device


def create_monitored_session(target: tf.train.Server, task_index: int,
checkpoint_dir: str, checkpoint_save_secs: int, config: tf.ConfigProto=None) -> tf.Session:
def create_monitored_session(target: tf.train.Server, task_index: int, checkpoint_dir: str, checkpoint_save_secs: int,
scaffold: tf.train.Scaffold, config: tf.ConfigProto=None) -> tf.Session:
"""
Create a monitored session for the worker
:param target: the target string for the tf.Session
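For context on the signature change above: the extra scaffold argument lets the caller build the tf.train.Saver before the monitored session finalizes the graph. Below is a minimal sketch of how such a session might be assembled with the standard TF1 API; the function name and body are illustrative, not the PR's implementation.

import tensorflow as tf

def create_monitored_session_sketch(target, task_index, checkpoint_dir,
                                    checkpoint_save_secs, scaffold, config=None):
    # The chief worker (task_index == 0) owns checkpointing; the explicit
    # Scaffold carries the Saver that the caller constructed.
    return tf.train.MonitoredTrainingSession(master=target,
                                             is_chief=(task_index == 0),
                                             checkpoint_dir=checkpoint_dir,
                                             save_checkpoint_secs=checkpoint_save_secs,
                                             scaffold=scaffold,
                                             config=config)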
2 changes: 1 addition & 1 deletion rl_coach/architectures/tensorflow_components/savers.py
@@ -28,7 +28,7 @@ def __init__(self, name):
# if the graph is finalized, savers must already have been added. This happens
# in the case of a MonitoredSession
self._variables = tf.global_variables()

# the target network is never saved or restored directly from a checkpoint, so we remove all of its variables from the list
# the target network will be synced back from the online network in graph_manager.improve(...), at the beginning of the run flow.
self._variables = [v for v in self._variables if '/target' not in v.name]
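The comment above explains why variables under a '/target' scope are excluded from checkpoints. As a standalone illustration only (not the saver class itself), the same filter applied to a plain tf.train.Saver looks like this:

import tensorflow as tf

# checkpoint only the online-network variables; anything under a '/target'
# scope is skipped and gets re-synced from the online network after restore
saveable_vars = [v for v in tf.global_variables() if '/target' not in v.name]
saver = tf.train.Saver(var_list=saveable_vars)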
4 changes: 2 additions & 2 deletions rl_coach/base_parameters.py
@@ -583,8 +583,8 @@ def __init__(self, framework_type: Frameworks=Frameworks.tensorflow, evaluate_only


class DistributedTaskParameters(TaskParameters):
def __init__(self, framework_type: Frameworks, parameters_server_hosts: str, worker_hosts: str, job_type: str,
task_index: int, evaluate_only: int=None, num_tasks: int=None,
def __init__(self, framework_type: Frameworks=None, parameters_server_hosts: str=None, worker_hosts: str=None,
job_type: str=None, task_index: int=None, evaluate_only: int=None, num_tasks: int=None,
num_training_tasks: int=None, use_cpu: bool=False, experiment_path=None, dnd=None,
shared_memory_scratchpad=None, seed=None, checkpoint_save_secs=None, checkpoint_restore_dir=None,
checkpoint_save_dir=None, export_onnx_graph: bool=False, apply_stop_condition: bool=False):
135 changes: 53 additions & 82 deletions rl_coach/coach.py
@@ -34,7 +34,8 @@
from multiprocessing.managers import BaseManager
import subprocess
from rl_coach.graph_managers.graph_manager import HumanPlayScheduleParameters, GraphManager
from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, get_base_dir
from rl_coach.utils import list_all_presets, short_dynamic_import, get_open_port, SharedMemoryScratchPad, \
get_base_dir, start_multi_threaded_learning
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.environments.environment import SingleLevelSelection
from rl_coach.memories.backend.redis import RedisPubSubMemoryBackendParameters
@@ -87,6 +88,17 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'

def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
ckpt_inside_container = "/checkpoint"
non_dist_task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
experiment_path=args.experiment_path,
seed=args.seed,
use_cpu=args.use_cpu,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)

memory_backend_params = None
if args.memory_backend_params:
@@ -102,15 +114,18 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
graph_manager.data_store_params = data_store_params

if args.distributed_coach_run_type == RunType.TRAINER:
if not args.distributed_training:
task_parameters = non_dist_task_parameters
task_parameters.checkpoint_save_dir = ckpt_inside_container
training_worker(
graph_manager=graph_manager,
task_parameters=task_parameters,
args=args,
is_multi_node_test=args.is_multi_node_test
)

if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
task_parameters.checkpoint_restore_dir = ckpt_inside_container
non_dist_task_parameters.checkpoint_restore_dir = ckpt_inside_container

data_store = None
if args.data_store_params:
@@ -120,7 +135,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
graph_manager=graph_manager,
data_store=data_store,
num_workers=args.num_workers,
task_parameters=task_parameters
task_parameters=non_dist_task_parameters
)
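To summarize the branching above: the trainer keeps the DistributedTaskParameters it was handed only when --distributed_training is set, and rollout workers always run with the plain TaskParameters whose checkpoint_restore_dir points inside the container. A condensed, purely illustrative restatement (select_task_parameters is a hypothetical helper, not part of the PR):

def select_task_parameters(is_trainer, distributed_training, dist_params, non_dist_params):
    # the trainer keeps distributed parameters only when distributed training was requested;
    # rollout workers (and non-distributed trainers) fall back to the plain TaskParameters
    if is_trainer and distributed_training:
        return dist_params
    return non_dist_params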


@@ -552,6 +567,11 @@ def get_argument_parser(self) -> argparse.ArgumentParser:
parser.add_argument('-dc', '--distributed_coach',
help="(flag) Use distributed Coach.",
action='store_true')
parser.add_argument('-dt', '--distributed_training',
help="(flag) Use distributed training with Coach. "
"Only used together with the --distributed_coach flag; "
"ignored if --distributed_coach is not set.",
action='store_true')
parser.add_argument('-dcp', '--distributed_coach_config_path',
help="(string) Path to config file when using distributed rollout workers."
"Only distributed Coach parameters should be provided through this config file."
@@ -607,18 +627,31 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace
atexit.register(logger.summarize_experiment)
screen.change_terminal_title(args.experiment_name)

task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
experiment_path=args.experiment_path,
seed=args.seed,
use_cpu=args.use_cpu,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)
if args.num_workers == 1:
task_parameters = TaskParameters(
framework_type=args.framework,
evaluate_only=args.evaluate,
experiment_path=args.experiment_path,
seed=args.seed,
use_cpu=args.use_cpu,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)
else:
task_parameters = DistributedTaskParameters(
framework_type=args.framework,
use_cpu=args.use_cpu,
num_training_tasks=args.num_workers,
experiment_path=args.experiment_path,
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)

# open dashboard
if args.open_dashboard:
@@ -633,78 +666,16 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namespace

# Single-threaded runs
if args.num_workers == 1:
self.start_single_threaded(task_parameters, graph_manager, args)
self.start_single_threaded_learning(task_parameters, graph_manager, args)
else:
self.start_multi_threaded(graph_manager, args)
global start_graph
start_multi_threaded_learning(start_graph, (graph_manager, task_parameters),
task_parameters, graph_manager, args)

def start_single_threaded(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
def start_single_threaded_learning(self, task_parameters, graph_manager: 'GraphManager', args: argparse.Namespace):
# Start the training or evaluation
start_graph(graph_manager=graph_manager, task_parameters=task_parameters)

def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
total_tasks = args.num_workers
if args.evaluation_worker:
total_tasks += 1

ps_hosts = "localhost:{}".format(get_open_port())
worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(total_tasks)])

# Shared memory
class CommManager(BaseManager):
pass
CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
comm_manager = CommManager()
comm_manager.start()
shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()

def start_distributed_task(job_type, task_index, evaluation_worker=False,
shared_memory_scratchpad=shared_memory_scratchpad):
task_parameters = DistributedTaskParameters(
framework_type=args.framework,
parameters_server_hosts=ps_hosts,
worker_hosts=worker_hosts,
job_type=job_type,
task_index=task_index,
evaluate_only=0 if evaluation_worker else None, # 0 value for evaluation worker as it should run infinitely
use_cpu=args.use_cpu,
num_tasks=total_tasks, # training tasks + 1 evaluation task
num_training_tasks=args.num_workers,
experiment_path=args.experiment_path,
shared_memory_scratchpad=shared_memory_scratchpad,
seed=args.seed+task_index if args.seed is not None else None, # each worker gets a different seed
checkpoint_save_secs=args.checkpoint_save_secs,
checkpoint_restore_dir=args.checkpoint_restore_dir,
checkpoint_save_dir=args.checkpoint_save_dir,
export_onnx_graph=args.export_onnx_graph,
apply_stop_condition=args.apply_stop_condition
)
# we assume that only the evaluation workers are rendering
graph_manager.visualization_parameters.render = args.render and evaluation_worker
p = Process(target=start_graph, args=(graph_manager, task_parameters))
# p.daemon = True
p.start()
return p

# parameter server
parameter_server = start_distributed_task("ps", 0)

# training workers
# wait a bit before spawning the non chief workers in order to make sure the session is already created
workers = []
workers.append(start_distributed_task("worker", 0))
time.sleep(2)
for task_index in range(1, args.num_workers):
workers.append(start_distributed_task("worker", task_index))

# evaluation worker
if args.evaluation_worker or args.render:
evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)

# wait for all workers
[w.join() for w in workers]
if args.evaluation_worker:
evaluation_worker.terminate()


def main():
launcher = CoachLauncher()
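The removed start_multi_threaded body above shows everything the relocated helper has to do: spin up a parameter server, the training workers, and an optional evaluation worker, each in its own process with its own DistributedTaskParameters. Below is a minimal sketch of what a start_multi_threaded_learning helper in rl_coach/utils could look like, reconstructed from that removed code and matching the call site in run_graph_manager. The exact signature and internals of the PR's helper may differ; this sketch rebuilds the per-worker parameters from args rather than deriving them from the task_parameters argument, and it assumes it lives in rl_coach/utils next to get_open_port and SharedMemoryScratchPad.

import time
from multiprocessing import Process
from multiprocessing.managers import BaseManager

from rl_coach.base_parameters import DistributedTaskParameters


def start_multi_threaded_learning(target, target_args, task_parameters, graph_manager, args):
    # target is the start_graph function; target_args / task_parameters are unused in
    # this simplified sketch, which rebuilds per-worker parameters from args instead
    total_tasks = args.num_workers + (1 if args.evaluation_worker else 0)

    ps_hosts = "localhost:{}".format(get_open_port())
    worker_hosts = ",".join("localhost:{}".format(get_open_port()) for _ in range(total_tasks))

    # shared memory scratchpad served by a manager process
    class CommManager(BaseManager):
        pass
    CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad,
                         exposed=['add', 'get', 'internal_call'])
    comm_manager = CommManager()
    comm_manager.start()
    scratchpad = comm_manager.SharedMemoryScratchPad()

    def start_distributed_task(job_type, task_index, evaluation_worker=False):
        worker_params = DistributedTaskParameters(
            framework_type=args.framework,
            parameters_server_hosts=ps_hosts,
            worker_hosts=worker_hosts,
            job_type=job_type,
            task_index=task_index,
            evaluate_only=0 if evaluation_worker else None,  # evaluation worker runs indefinitely
            use_cpu=args.use_cpu,
            num_tasks=total_tasks,
            num_training_tasks=args.num_workers,
            experiment_path=args.experiment_path,
            shared_memory_scratchpad=scratchpad,
            seed=args.seed + task_index if args.seed is not None else None,
            checkpoint_save_secs=args.checkpoint_save_secs,
            checkpoint_restore_dir=args.checkpoint_restore_dir,
            checkpoint_save_dir=args.checkpoint_save_dir,
            export_onnx_graph=args.export_onnx_graph,
            apply_stop_condition=args.apply_stop_condition)
        # only evaluation workers render
        graph_manager.visualization_parameters.render = args.render and evaluation_worker
        p = Process(target=target, args=(graph_manager, worker_params))
        p.start()
        return p

    start_distributed_task("ps", 0)  # parameter server
    workers = [start_distributed_task("worker", 0)]  # chief worker first
    time.sleep(2)  # give the chief time to create the session
    workers += [start_distributed_task("worker", i) for i in range(1, args.num_workers)]

    if args.evaluation_worker or args.render:
        evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)

    [w.join() for w in workers]
    if args.evaluation_worker:
        evaluation_worker.terminate()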
20 changes: 14 additions & 6 deletions rl_coach/data_stores/s3_data_store.py
@@ -88,16 +88,20 @@ def save_to_store(self):
# Acquire lock
self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

ckpt_state_filename = CheckpointStateFile.checkpoint_state_filename
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
if state_file.exists():
ckpt_state = state_file.read()
ckpt_name_prefix = ckpt_state.name

if ckpt_state_filename is not None and ckpt_name_prefix is not None:
checkpoint_file = None
for root, dirs, files in os.walk(self.params.checkpoint_dir):
for filename in files:
if filename == CheckpointStateFile.checkpoint_state_filename:
if filename == ckpt_state_filename:
checkpoint_file = (root, filename)
continue
if filename.startswith(ckpt_state.name):
if filename.startswith(ckpt_name_prefix):
abs_name = os.path.abspath(os.path.join(root, filename))
rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)
@@ -131,6 +135,8 @@ def load_from_store(self):
"""
try:
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
ckpt_state_filename = state_file.filename
ckpt_state_file_path = state_file.path

# wait until lock is removed
while True:
@@ -139,7 +145,7 @@ def load_from_store(self):
if next(objects, None) is None:
try:
# fetch checkpoint state file from S3
self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
self.mc.fget_object(self.params.bucket_name, ckpt_state_filename, ckpt_state_file_path)
except Exception as e:
continue
break
@@ -156,10 +162,12 @@ def load_from_store(self):
)
except Exception as e:
pass
state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))
ckpt_state = state_file.read()
ckpt_name_prefix = ckpt_state.name if ckpt_state is not None else None

checkpoint_state = state_file.read()
if checkpoint_state is not None:
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
if ckpt_name_prefix is not None:
objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=ckpt_name_prefix, recursive=True)
for obj in objects:
filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
if not os.path.exists(filename):
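For reference, a condensed sketch of the load path above: poll until the writer's lock object is gone, then fetch the checkpoint-state file that names the latest checkpoint prefix. It uses only the minio client calls that already appear in the diff (list_objects_v2, fget_object); the function and parameter names are illustrative, not part of the data store.

import os
import time


def wait_for_unlock_and_fetch_state(mc, bucket_name, checkpoint_dir, lockfile_name, state_filename):
    # poll until the lock object disappears, then pull the checkpoint-state file
    state_path = os.path.join(checkpoint_dir, state_filename)
    while True:
        if next(mc.list_objects_v2(bucket_name, prefix=lockfile_name), None) is None:
            try:
                mc.fget_object(bucket_name, state_filename, state_path)
                return state_path
            except Exception:
                pass  # the state file may not be uploaded yet; keep polling
        time.sleep(1)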
23 changes: 16 additions & 7 deletions rl_coach/graph_managers/graph_manager.py
@@ -226,11 +226,15 @@ def _create_session_tf(self, task_parameters: TaskParameters):
else:
checkpoint_dir = task_parameters.checkpoint_save_dir

self.checkpoint_saver = tf.train.Saver()
scaffold = tf.train.Scaffold(saver=self.checkpoint_saver)

self.sess = create_monitored_session(target=task_parameters.worker_target,
task_index=task_parameters.task_index,
checkpoint_dir=checkpoint_dir,
checkpoint_save_secs=task_parameters.checkpoint_save_secs,
config=config)
config=config,
scaffold=scaffold)
# set the session for all the modules
self.set_session(self.sess)
else:
@@ -258,9 +262,11 @@ def create_session(self, task_parameters: TaskParameters):
raise ValueError('Invalid framework {}'.format(task_parameters.framework_type))

# Create parameter saver
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())
if not isinstance(task_parameters, DistributedTaskParameters):
self.checkpoint_saver = SaverCollection()
for level in self.level_managers:
self.checkpoint_saver.update(level.collect_savers())

# restore from checkpoint if given
self.restore_checkpoint()

@@ -540,8 +546,9 @@ def improve(self):
count_end = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
while self.total_steps_counters[RunPhase.TRAIN] < count_end:
self.train_and_act(self.steps_between_evaluation_periods)
if self.evaluate(self.evaluation_steps):
break
if self.task_parameters.task_index == 0 or self.task_parameters.task_index is None:
if self.evaluate(self.evaluation_steps):
break

def restore_checkpoint(self):
self.verify_graph_was_created()
@@ -599,7 +606,9 @@ def save_checkpoint(self):
if not isinstance(self.task_parameters, DistributedTaskParameters):
saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
else:
saved_checkpoint_path = checkpoint_path
# FIXME: explicitly managing a Saver inside a monitored training session is not recommended.
# https://github.com/tensorflow/tensorflow/issues/8425#issuecomment-286927528.
saved_checkpoint_path = self.checkpoint_saver.save(self.sess._tf_sess(), checkpoint_path)

# this is required in order for agents to save additional information like a DND for example
[manager.save_checkpoint(checkpoint_name) for manager in self.level_managers]
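The FIXME above flags the known awkwardness of driving a Saver yourself inside a monitored session. A minimal, self-contained sketch of the pattern the distributed branch relies on is shown below; _tf_sess() is an internal TensorFlow accessor for the raw session (which is exactly why the FIXME exists), and the checkpoint path is illustrative.

import os
import tensorflow as tf

os.makedirs('/tmp/checkpoints', exist_ok=True)  # illustrative checkpoint location

# build the Saver before the monitored session finalizes the graph,
# then checkpoint through the raw tf.Session wrapped by the MonitoredSession
step = tf.Variable(0, name='global_step', trainable=False)
saver = tf.train.Saver()
scaffold = tf.train.Scaffold(saver=saver)

with tf.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
    # ... training steps would go here ...
    saver.save(sess._tf_sess(), '/tmp/checkpoints/my_checkpoint')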
4 changes: 3 additions & 1 deletion rl_coach/tests/test_dist_coach.py
@@ -73,7 +73,9 @@ def get_tests():
"""
tests = [
'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1'
'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1',
'rl_coach/coach.py -p CartPole_ClippedPPO -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2',
'rl_coach/coach.py -p Mujoco_ClippedPPO -lvl inverted_pendulum -dc -e sample -dcp {template} --dump_worker_logs -asc --is_multi_node_test --seed 1 -n 2'
]
return tests
