Skip to content

Commit d0b6231

Browse files
authored
more documentation for libero examples (Physical-Intelligence#344)
2 parents bf25a4d + 4a10482 commit d0b6231

File tree

4 files changed

+161
-20
lines changed

4 files changed

+161
-20
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero
158158

159159
This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run the Libero evaluation script to query the server. For instructions how to install Libero and run the evaluation script, see the [Libero README](examples/libero/README.md).
160160

161+
If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
162+
163+
161164

162165
### More Examples
163166

docs/remote_inference.md

+32-3
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,39 @@ pip install -e .
3333
Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
3434

3535
```python
36+
from openpi_client import image_tools
3637
from openpi_client import websocket_client_policy
3738

38-
policy_client = websocket_client_policy.WebsocketClientPolicy(host="10.32.255.0", port=8000)
39-
action_chunk = policy_client.infer(example)["actions"]
39+
# Outside of episode loop, initialize the policy client.
40+
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
41+
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
42+
43+
for step in range(num_steps):
44+
# Inside the episode loop, construct the observation.
45+
# Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
46+
# We provide utilities for resizing images + uint8 conversion so you match the training routines.
47+
# The typical resize_size for pre-trained pi0 models is 224.
48+
# Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
49+
observation = {
50+
"observation/image": image_tools.convert_to_uint8(
51+
image_tools.resize_with_pad(img, 224, 224)
52+
),
53+
"observation/wrist_image": image_tools.convert_to_uint8(
54+
image_tools.resize_with_pad(wrist_img, 224, 224)
55+
),
56+
"observation/state": state,
57+
"prompt": task_instruction,
58+
}
59+
60+
# Call the policy server with the current observation.
61+
# This returns an action chunk of shape (action_horizon, action_dim).
62+
# Note that you typically only need to call the policy every N steps and execute steps
63+
# from the predicted action chunk open-loop in the remaining steps.
64+
action_chunk = client.infer(observation)["actions"]
65+
66+
# Execute the actions in the environment.
67+
...
68+
4069
```
4170

42-
Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `example` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
71+
Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).

src/openpi/policies/libero_policy.py

+44-7
Original file line numberDiff line numberDiff line change
@@ -28,45 +28,72 @@ def _parse_image(image) -> np.ndarray:
2828

2929
@dataclasses.dataclass(frozen=True)
3030
class LiberoInputs(transforms.DataTransformFn):
31+
"""
32+
This class is used to convert inputs to the model to the expected format. It is used for both training and inference.
33+
34+
For your own dataset, you can copy this class and modify the keys based on the comments below to pipe
35+
the correct elements of your dataset into the model.
36+
"""
37+
3138
# The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST).
39+
# Do not change this for your own dataset.
3240
action_dim: int
3341

3442
# Determines which model will be used.
43+
# Do not change this for your own dataset.
3544
model_type: _model.ModelType = _model.ModelType.PI0
3645

3746
def __call__(self, data: dict) -> dict:
38-
mask_padding = self.model_type == _model.ModelType.PI0 # We don't mask for pi0-FAST.
39-
40-
# Get the state. We are padding from 8 to the model action dim.
41-
# For pi0-FAST, we don't pad the state (action_dim = 7, which is < 8, so pad is skipped).
47+
# We only mask padding for pi0 model, not pi0-FAST. Do not change this for your own dataset.
48+
mask_padding = self.model_type == _model.ModelType.PI0
49+
50+
# We pad the proprioceptive input to the action dimension of the model.
51+
# For pi0-FAST, we don't pad the state. For Libero, we don't need to differentiate
52+
# since the pi0-FAST action_dim = 7, which is < state_dim = 8, so pad is skipped.
53+
# Keep this for your own dataset, but if your dataset stores the proprioceptive input
54+
# in a different key than "observation/state", you should change it below.
4255
state = transforms.pad_to_dim(data["observation/state"], self.action_dim)
4356

4457
# Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
45-
# stores as float32 (C,H,W), gets skipped for policy inference
58+
# stores as float32 (C,H,W), gets skipped for policy inference.
59+
# Keep this for your own dataset, but if your dataset stores the images
60+
# in a different key than "observation/image" or "observation/wrist_image",
61+
# you should change it below.
62+
# Pi0 models support three image inputs at the moment: one third-person view,
63+
# and two wrist views (left and right). If your dataset does not have a particular type
64+
# of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the
65+
# right wrist image below.
4666
base_image = _parse_image(data["observation/image"])
4767
wrist_image = _parse_image(data["observation/wrist_image"])
4868

69+
# Create inputs dict. Do not change the keys in the dict below.
4970
inputs = {
5071
"state": state,
5172
"image": {
5273
"base_0_rgb": base_image,
5374
"left_wrist_0_rgb": wrist_image,
75+
# Pad any non-existent images with zero-arrays of the appropriate shape.
5476
"right_wrist_0_rgb": np.zeros_like(base_image),
5577
},
5678
"image_mask": {
5779
"base_0_rgb": np.True_,
5880
"left_wrist_0_rgb": np.True_,
81+
# Mask any non-existent images with False (if ``mask_padding`` is True).
5982
"right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
6083
},
6184
}
6285

86+
# Pad actions to the model action dimension. Keep this for your own dataset.
6387
# Actions are only available during training.
6488
if "actions" in data:
65-
# We are padding from 7 to the model action dim.
89+
# We are padding to the model action dim.
6690
# For pi0-FAST, this is a no-op (since action_dim = 7).
6791
actions = transforms.pad_to_dim(data["actions"], self.action_dim)
6892
inputs["actions"] = actions
6993

94+
# Pass the prompt (aka language instruction) to the model.
95+
# Keep this for your own dataset (but modify the key if the instruction is not
96+
# stored in "prompt"; the output dict always needs to have the key "prompt").
7097
if "prompt" in data:
7198
inputs["prompt"] = data["prompt"]
7299

@@ -75,6 +102,16 @@ def __call__(self, data: dict) -> dict:
75102

76103
@dataclasses.dataclass(frozen=True)
77104
class LiberoOutputs(transforms.DataTransformFn):
105+
"""
106+
This class is used to convert outputs from the model back the the dataset specific format. It is
107+
used for inference only.
108+
109+
For your own dataset, you can copy this class and modify the action dimension based on the comments below.
110+
"""
111+
78112
def __call__(self, data: dict) -> dict:
79-
# Only return the first 7 dims.
113+
# Only return the first N actions -- since we padded actions above to fit the model action
114+
# dimension, we need to now parse out the correct number of actions in the return dict.
115+
# For Libero, we only return the first 7 actions (since the rest is padding).
116+
# For your own dataset, replace `7` with the action dimension of your dataset.
80117
return {"actions": np.asarray(data["actions"][:, :7])}

src/openpi/training/config.py

+82-10
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,22 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
251251

252252
@dataclasses.dataclass(frozen=True)
253253
class LeRobotLiberoDataConfig(DataConfigFactory):
254+
"""
255+
This config is used to configure transforms that are applied at various parts of the data pipeline.
256+
For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
257+
comments below.
258+
"""
259+
254260
@override
255261
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
256-
# Make inputs look like they come from the Libero environment
262+
# The repack transform is *only* applied to the data coming from the dataset,
263+
# and *not* during inference. We can use it to make inputs from the dataset look
264+
# as close as possible to those coming from the inference environment (e.g. match the keys).
265+
# Below, we match the keys in the dataset (which we defined in the data conversion script) to
266+
# the keys we use in our inference pipeline (defined in the inference script for libero).
267+
# For your own dataset, first figure out what keys your environment passes to the policy server
268+
# and then modify the mappings below so your dataset's keys get matched to those target keys.
269+
# The repack transform simply remaps key names here.
257270
repack_transform = _transforms.Group(
258271
inputs=[
259272
_transforms.RepackTransform(
@@ -268,22 +281,38 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
268281
]
269282
)
270283

271-
# Prepare data for policy training
272-
# Convert images to uint8 numpy arrays, add masks
284+
# The data transforms are applied to the data coming from the dataset *and* during inference.
285+
# Below, we define the transforms for data going into the model (``inputs``) and the transforms
286+
# for data coming out of the model (``outputs``) (the latter is only used during inference).
287+
# We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
288+
# how to modify the transforms to match your dataset. Once you created your own transforms, you can
289+
# replace the transforms below with your own.
273290
data_transforms = _transforms.Group(
274291
inputs=[libero_policy.LiberoInputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
275292
outputs=[libero_policy.LiberoOutputs()],
276293
)
277-
# Use delta actions (not for gripper)
278-
delta_action_mask = _transforms.make_bool_mask(6, -1)
279-
data_transforms = data_transforms.push(
280-
inputs=[_transforms.DeltaActions(delta_action_mask)],
281-
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
282-
)
294+
295+
# One additional data transform: pi0 models are trained on delta actions (relative to the first
296+
# state in each action chunk). IF your data has ``absolute`` actions (e.g. target joint angles)
297+
# you can uncomment the following line to convert the actions to delta actions. The only exception
298+
# is for the gripper actions which are always absolute.
299+
# In the example below, we would apply the delta conversion to the first 6 actions (joints) and
300+
# leave the 7th action (gripper) unchanged, i.e. absolute.
301+
# In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
302+
# apply a separate delta conversion (that's why it's commented out). Choose whether to apply this
303+
# transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.
304+
305+
# delta_action_mask = _transforms.make_bool_mask(6, -1)
306+
# data_transforms = data_transforms.push(
307+
# inputs=[_transforms.DeltaActions(delta_action_mask)],
308+
# outputs=[_transforms.AbsoluteActions(delta_action_mask)],
309+
# )
283310

284311
# Model transforms include things like tokenizing the prompt and action targets
312+
# You do not need to change anything here for your own dataset.
285313
model_transforms = ModelTransformFactory()(model_config)
286314

315+
# We return all data transforms for training and inference. No need to change anything here.
287316
return dataclasses.replace(
288317
self.create_base_config(assets_dirs),
289318
repack_transforms=repack_transform,
@@ -442,21 +471,41 @@ def __post_init__(self) -> None:
442471
#
443472
# Fine-tuning Libero configs.
444473
#
474+
# These train configs define the hyperparameters for fine-tuning the base model on your own dataset.
475+
# They are used to define key elements like the dataset you are training on, the base checkpoint you
476+
# are using, and other hyperparameters like how many training steps to run or what learning rate to use.
477+
# For your own dataset, you can copy this class and modify the dataset name, and data transforms based on
478+
# the comments below.
445479
TrainConfig(
480+
# Change the name to reflect your model and dataset.
446481
name="pi0_libero",
482+
# Here you define the model config -- In this example we use pi0 as the model
483+
# architecture and perform *full* finetuning. in the examples below we show how to modify
484+
# this to perform *low-memory* (LORA) finetuning and use pi0-FAST as an alternative architecture.
447485
model=pi0.Pi0Config(),
486+
# Here you define the dataset you are training on. In this example we use the Libero
487+
# dataset. For your own dataset, you can change the repo_id to point to your dataset.
488+
# Also modify the DataConfig to use the new config you made for your dataset above.
448489
data=LeRobotLiberoDataConfig(
449490
repo_id="physical-intelligence/libero",
450491
base_config=DataConfig(
451492
local_files_only=False, # Set to True for local-only datasets.
493+
# This flag determines whether we load the prompt (i.e. the task instruction) from the
494+
# ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
495+
# a field called ``prompt`` in the input dict. The recommended setting is True.
452496
prompt_from_task=True,
453497
),
454498
),
499+
# Here you define which pre-trained checkpoint you want to load to initialize the model.
500+
# This should match the model config you chose above -- i.e. in this case we use the pi0 base model.
455501
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
502+
# Below you can define other hyperparameters like the learning rate, number of training steps, etc.
503+
# Check the base TrainConfig class for a full list of available hyperparameters.
456504
num_train_steps=30_000,
457505
),
458506
TrainConfig(
459507
name="pi0_libero_low_mem_finetune",
508+
# Here is an example of loading a pi0 model for LoRA fine-tuning.
460509
model=pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
461510
data=LeRobotLiberoDataConfig(
462511
repo_id="physical-intelligence/libero",
@@ -467,13 +516,28 @@ def __post_init__(self) -> None:
467516
),
468517
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
469518
num_train_steps=30_000,
519+
# The freeze filter defines which parameters should be frozen during training.
520+
# We have a convenience function in the model config that returns the default freeze filter
521+
# for the given model config for LoRA finetuning. Just make sure it matches the model config
522+
# you chose above.
470523
freeze_filter=pi0.Pi0Config(
471524
paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
472525
).get_freeze_filter(),
526+
# Turn off EMA for LoRA finetuning.
473527
ema_decay=None,
474528
),
475529
TrainConfig(
476530
name="pi0_fast_libero",
531+
# Here is an example of loading a pi0-FAST model for full finetuning.
532+
# Modify action_dim and action_horizon to match your dataset (action horizon is equal to
533+
# the desired action chunk length).
534+
# The max_token_len is the maximum number of (non-image) tokens the model can handle.
535+
# This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens.
536+
# Choosing this value too small may chop off tokens at the end of your sequence (the code will throw
537+
# a warning), while choosing it too large will waste memory (since we pad each batch element to the
538+
# max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for
539+
# two-arm robots. Generally, err on the lower side here first, and potentially increase the value if
540+
# you see many warnings being thrown during training.
477541
model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180),
478542
data=LeRobotLiberoDataConfig(
479543
repo_id="physical-intelligence/libero",
@@ -482,12 +546,17 @@ def __post_init__(self) -> None:
482546
prompt_from_task=True,
483547
),
484548
),
549+
# Note that we load the pi0-FAST base model checkpoint here.
485550
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
486551
num_train_steps=30_000,
487552
),
488553
TrainConfig(
489554
name="pi0_fast_libero_low_mem_finetune",
490-
model=pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora"),
555+
# Here is an example of loading a pi0-FAST model for LoRA finetuning.
556+
# For setting action_dim, action_horizon, and max_token_len, see the comments above.
557+
model=pi0_fast.Pi0FASTConfig(
558+
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
559+
),
491560
data=LeRobotLiberoDataConfig(
492561
repo_id="physical-intelligence/libero",
493562
base_config=DataConfig(
@@ -497,9 +566,12 @@ def __post_init__(self) -> None:
497566
),
498567
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
499568
num_train_steps=30_000,
569+
# Again, make sure to match the model config above when extracting the freeze filter
570+
# that specifies which parameters should be frozen during LoRA finetuning.
500571
freeze_filter=pi0_fast.Pi0FASTConfig(
501572
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
502573
).get_freeze_filter(),
574+
# Turn off EMA for LoRA finetuning.
503575
ema_decay=None,
504576
),
505577
#

0 commit comments

Comments
 (0)