[RLlib] Introduce RLlib Intro template #392
base: main
@@ -0,0 +1,41 @@
head_node_type:
  name: head
  instance_type: m5.2xlarge
  resources:
    cpu: 0
worker_node_types:
- name: cpu-worker
  instance_type: m5.2xlarge
  min_workers: 0
  max_workers: 100
  use_spot: false
- name: gpu-worker-t4-1
  instance_type: g4dn.2xlarge
  resources:
    cpu:
    gpu:
    memory:
    object_store_memory:
    custom_resources:
      "accelerator_type:T4": 1
  min_workers: 0
  max_workers: 100
- name: gpu-worker-t4-4
  instance_type: g4dn.12xlarge
  resources:
    cpu:
    gpu:
    memory:
    object_store_memory:
    custom_resources:
      "accelerator_type:T4": 1
  min_workers: 0
  max_workers: 100
aws:
  TagSpecifications:
  - ResourceType: instance
    Tags:
    - Key: as-feature-enable-multi-az-serve
      Value: "true"
    - Key: as-feature-multi-zone
      Value: "true"
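The `custom_resources` entries advertise the GPU type as a schedulable Ray resource. As a minimal sketch of how application code can target those nodes (assuming a cluster launched with this config; the task body is illustrative only):

import ray

ray.init()

# Request one GPU on a node advertising the T4 custom resource, so the
# task is scheduled onto a gpu-worker-t4 node rather than a CPU worker.
@ray.remote(num_gpus=1, resources={"accelerator_type:T4": 1})
def gpu_task():
    import torch
    return torch.cuda.get_device_name(0)

print(ray.get(gpu_task.remote()))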
@@ -0,0 +1,59 @@
head_node_type:
  name: head
  instance_type: n2-standard-8
  resources:
    cpu: 0
worker_node_types:
- name: cpu-worker
  instance_type: n2-standard-8
  min_workers: 0
  max_workers: 100
- name: gpu-worker-t4-1
  instance_type: n1-standard-8-nvidia-t4-16gb-1
  resources:
    cpu:
    gpu:
    memory:
    object_store_memory:
    custom_resources:
      "accelerator_type:T4": 1
  min_workers: 0
  max_workers: 100
- name: gpu-worker-l4-1
  instance_type: g2-standard-12-nvidia-l4-1
  resources:
    cpu:
    gpu:
    memory:
    object_store_memory:
    custom_resources:
      "accelerator_type:L4": 1
  min_workers: 0
  max_workers: 100
- name: gpu-worker-l4-2
  instance_type: g2-standard-24-nvidia-l4-2
  resources:
    cpu:
    gpu:
    memory:
    object_store_memory:
    custom_resources:
      "accelerator_type:L4": 1
  min_workers: 0
  max_workers: 100
- name: gpu-worker-l4-4
  instance_type: g2-standard-48-nvidia-l4-4
  resources:
    cpu:
    gpu:
    memory:
    object_store_memory:
    custom_resources:
      "accelerator_type:L4": 1
  min_workers: 0
  max_workers: 100
gcp_advanced_configurations_json:
  instance_properties:
    labels:
      as-feature-multi-zone: 'true'
      as-feature-enable-multi-az-serve: 'true'
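Since both cloud configs expose the accelerator type as a custom resource, you can confirm at runtime which node types have actually joined the cluster. A quick sketch (the output depends on which workers have scaled up):

import ray

ray.init()

# Aggregate resources across the cluster, including custom entries
# such as "accelerator_type:T4" or "accelerator_type:L4".
for name, quantity in sorted(ray.cluster_resources().items()):
    print(f"{name}: {quantity}")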
@@ -0,0 +1,75 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reinforcement Learning with RLlib\n",
    "\n",
    "**⏱️ Time to complete**: 5 min\n",
    "\n",
    "RLlib is Ray's library for reinforcement learning. Built on Ray, it is highly scalable and fault-tolerant.\n",
    "This template walks you through a quick, entry-level training run. Specifically, you use RLlib's main APIs to define a training workload, kick it off, and observe its metrics to see how training is progressing. In addition, you can explore common use cases like introducing your own environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"ray[rllib]\" torch \"gymnasium[atari,accept-rom-license,mujoco]\" opencv-python-headless\n",
    "# !RAY_VERSION=2.39.0  # Set this when releasing an image instead of hardcoding it here\n",
    "# !curl -O https://raw.githubusercontent.com/ray-project/ray/refs/heads/releases/${RAY_VERSION}/rllib/tuned_examples/ppo/atari_ppo.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning Pong from images\n",
    "\n",
    "We start by learning the classic [Pong](https://en.wikipedia.org/wiki/Pong) video game.\n",
    "Expect this to scale the cluster to 4 GPUs and 96 CPUs. It should take around 5 minutes to learn Pong."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional W&B logging: append --wandb-key=... --wandb-project=my_atari_tests\n",
    "!python atari_ppo.py --env ale_py:ALE/Pong-v5 --num-gpus 4 --num-env-runners 95"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can find a multitude of further examples [here](https://github.com/ray-project/ray/tree/master/rllib/examples) and run them from the command line the same way."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
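The notebook calls out introducing your own environment as a common next step. A minimal sketch of that pattern (the `MyEnv` class and the `my_env` name are hypothetical; `tune.register_env` and `PPOConfig` are the same APIs the `atari_ppo.py` script below uses):

import gymnasium as gym
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig

class MyEnv(gym.Env):
    # A toy one-step environment, purely for illustration.
    def __init__(self, config=None):
        self.observation_space = gym.spaces.Box(-1.0, 1.0, (4,))
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        # Returns: observation, reward, terminated, truncated, info.
        return self.observation_space.sample(), 1.0, True, False, {}

# Register under a name RLlib can resolve, then point the config at it.
tune.register_env("my_env", lambda cfg: MyEnv(cfg))

config = PPOConfig().environment("my_env")
algo = config.build()
results = algo.train()  # one training iteration on the custom env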
@@ -0,0 +1,100 @@
import gymnasium as gym

from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module.frame_stacking import FrameStackingEnvToModule
from ray.rllib.connectors.learner.frame_stacking import FrameStackingLearner
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
from ray.rllib.utils.test_utils import add_rllib_example_script_args

parser = add_rllib_example_script_args(
    default_reward=float("inf"),
    default_timesteps=3000000,
    default_iters=100000000000,
)
parser.set_defaults(
    enable_new_api_stack=True,
    env="ale_py:ALE/Pong-v5",
)
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

def _make_env_to_module_connector(env):
    return FrameStackingEnvToModule(num_frames=4)


def _make_learner_connector(input_observation_space, input_action_space):
    return FrameStackingLearner(num_frames=4)


# Create a custom Atari setup (w/o the usual RLlib-hard-coded framestacking in it).
# We would like our frame-stacking connector to do this job instead.
def _env_creator(cfg):
    return wrap_atari_for_new_api_stack(
        gym.make("ale_py:ALE/Pong-v5", **cfg, render_mode="rgb_array"),
        # Perform frame-stacking through the ConnectorV2 API.
        framestack=None,
    )


tune.register_env("env", _env_creator)

config = (
    PPOConfig()
    .environment(
        "env",
        env_config={
            # Make analogous to old v4 + NoFrameskip.
            "frameskip": 1,
            "full_action_space": False,
            "repeat_action_probability": 0.0,
        },
        clip_rewards=True,
    )
    .env_runners(
        env_to_module_connector=_make_env_to_module_connector,
    )
    .training(
        learner_connector=_make_learner_connector,
        train_batch_size_per_learner=4000,  # 5000 on old yaml example

Review comment: remove "# 5000 on old yaml example".
Review comment: Note: With the approach we are establishing now, as per our last offline conversation, this won't work anymore, since we copy the config from the tuned example in the Ray repo.

        minibatch_size=128,  # 500 on old yaml example
        lambda_=0.95,
        kl_coeff=0.5,
        clip_param=0.1,
        vf_clip_param=10.0,
        entropy_coeff=0.01,
        num_epochs=10,
        lr=0.00015 * args.num_gpus,
        grad_clip=100.0,
        grad_clip_by="global_norm",
    )
    .rl_module(
        model_config=DefaultModelConfig(
            conv_filters=[[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]],
            conv_activation="relu",
            head_fcnet_hiddens=[256],
            vf_share_layers=True,
        ),
    )
)

config.resources(num_gpus=0)

config.learners(
    num_learners=4,
    num_gpus_per_learner=1,
)

Review comment: Let's move these into the main config bracket above.

config.env_runners(num_env_runners=95)

Review comment: this, too (move these into the main config bracket above).

algo = config.build()

for _ in range(100):
    results = algo.train()
    print(f"loss={results['learners']['default_policy']['total_loss']}")