@@ -251,9 +251,22 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
251
251
252
252
@dataclasses .dataclass (frozen = True )
253
253
class LeRobotLiberoDataConfig (DataConfigFactory ):
254
+ """
255
+ This config is used to configure transforms that are applied at various parts of the data pipeline.
256
+ For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
257
+ comments below.
258
+ """
259
+
254
260
@override
255
261
def create (self , assets_dirs : pathlib .Path , model_config : _model .BaseModelConfig ) -> DataConfig :
256
- # Make inputs look like they come from the Libero environment
262
+ # The repack transform is *only* applied to the data coming from the dataset,
263
+ # and *not* during inference. We can use it to make inputs from the dataset look
264
+ # as close as possible to those coming from the inference environment (e.g. match the keys).
265
+ # Below, we match the keys in the dataset (which we defined in the data conversion script) to
266
+ # the keys we use in our inference pipeline (defined in the inference script for libero).
267
+ # For your own dataset, first figure out what keys your environment passes to the policy server
268
+ # and then modify the mappings below so your dataset's keys get matched to those target keys.
269
+ # The repack transform simply remaps key names here.
257
270
repack_transform = _transforms .Group (
258
271
inputs = [
259
272
_transforms .RepackTransform (
@@ -268,22 +281,38 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
268
281
]
269
282
)
270
283
271
- # Prepare data for policy training
272
- # Convert images to uint8 numpy arrays, add masks
284
+ # The data transforms are applied to the data coming from the dataset *and* during inference.
285
+ # Below, we define the transforms for data going into the model (``inputs``) and the transforms
286
+ # for data coming out of the model (``outputs``) (the latter is only used during inference).
287
+ # We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
288
+ # how to modify the transforms to match your dataset. Once you have created your own transforms, you can
289
+ # replace the transforms below with your own.
273
290
data_transforms = _transforms .Group (
274
291
inputs = [libero_policy .LiberoInputs (action_dim = model_config .action_dim , model_type = model_config .model_type )],
275
292
outputs = [libero_policy .LiberoOutputs ()],
276
293
)
277
- # Use delta actions (not for gripper)
278
- delta_action_mask = _transforms .make_bool_mask (6 , - 1 )
279
- data_transforms = data_transforms .push (
280
- inputs = [_transforms .DeltaActions (delta_action_mask )],
281
- outputs = [_transforms .AbsoluteActions (delta_action_mask )],
282
- )
294
+
295
+ # One additional data transform: pi0 models are trained on delta actions (relative to the first
296
+ # state in each action chunk). If your data has ``absolute`` actions (e.g. target joint angles)
297
+ # you can uncomment the following lines to convert the actions to delta actions. The only exception
298
+ # is for the gripper actions which are always absolute.
299
+ # In the example below, we would apply the delta conversion to the first 6 actions (joints) and
300
+ # leave the 7th action (gripper) unchanged, i.e. absolute.
301
+ # In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
302
+ # apply a separate delta conversion (that's why it's commented out). Choose whether to apply this
303
+ # transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.
304
+
305
+ # delta_action_mask = _transforms.make_bool_mask(6, -1)
306
+ # data_transforms = data_transforms.push(
307
+ # inputs=[_transforms.DeltaActions(delta_action_mask)],
308
+ # outputs=[_transforms.AbsoluteActions(delta_action_mask)],
309
+ # )
283
310
284
311
# Model transforms include things like tokenizing the prompt and action targets
312
+ # You do not need to change anything here for your own dataset.
285
313
model_transforms = ModelTransformFactory ()(model_config )
286
314
315
+ # We return all data transforms for training and inference. No need to change anything here.
287
316
return dataclasses .replace (
288
317
self .create_base_config (assets_dirs ),
289
318
repack_transforms = repack_transform ,
@@ -442,21 +471,41 @@ def __post_init__(self) -> None:
442
471
#
443
472
# Fine-tuning Libero configs.
444
473
#
474
+ # These train configs define the hyperparameters for fine-tuning the base model on your own dataset.
475
+ # They are used to define key elements like the dataset you are training on, the base checkpoint you
476
+ # are using, and other hyperparameters like how many training steps to run or what learning rate to use.
477
+ # For your own dataset, you can copy this config and modify the dataset name and data transforms based on
478
+ # the comments below.
445
479
TrainConfig (
480
+ # Change the name to reflect your model and dataset.
446
481
name = "pi0_libero" ,
482
+ # Here you define the model config -- In this example we use pi0 as the model
483
+ # architecture and perform *full* finetuning. In the examples below we show how to modify
484
+ # this to perform *low-memory* (LoRA) finetuning and use pi0-FAST as an alternative architecture.
447
485
model = pi0 .Pi0Config (),
486
+ # Here you define the dataset you are training on. In this example we use the Libero
487
+ # dataset. For your own dataset, you can change the repo_id to point to your dataset.
488
+ # Also modify the DataConfig to use the new config you made for your dataset above.
448
489
data = LeRobotLiberoDataConfig (
449
490
repo_id = "physical-intelligence/libero" ,
450
491
base_config = DataConfig (
451
492
local_files_only = False , # Set to True for local-only datasets.
493
+ # This flag determines whether we load the prompt (i.e. the task instruction) from the
494
+ # ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
495
+ # a field called ``prompt`` in the input dict. The recommended setting is True.
452
496
prompt_from_task = True ,
453
497
),
454
498
),
499
+ # Here you define which pre-trained checkpoint you want to load to initialize the model.
500
+ # This should match the model config you chose above -- i.e. in this case we use the pi0 base model.
455
501
weight_loader = weight_loaders .CheckpointWeightLoader ("s3://openpi-assets/checkpoints/pi0_base/params" ),
502
+ # Below you can define other hyperparameters like the learning rate, number of training steps, etc.
503
+ # Check the base TrainConfig class for a full list of available hyperparameters.
456
504
num_train_steps = 30_000 ,
457
505
),
458
506
TrainConfig (
459
507
name = "pi0_libero_low_mem_finetune" ,
508
+ # Here is an example of loading a pi0 model for LoRA fine-tuning.
460
509
model = pi0 .Pi0Config (paligemma_variant = "gemma_2b_lora" , action_expert_variant = "gemma_300m_lora" ),
461
510
data = LeRobotLiberoDataConfig (
462
511
repo_id = "physical-intelligence/libero" ,
@@ -467,13 +516,28 @@ def __post_init__(self) -> None:
467
516
),
468
517
weight_loader = weight_loaders .CheckpointWeightLoader ("s3://openpi-assets/checkpoints/pi0_base/params" ),
469
518
num_train_steps = 30_000 ,
519
+ # The freeze filter defines which parameters should be frozen during training.
520
+ # We have a convenience function in the model config that returns the default freeze filter
521
+ # for the given model config for LoRA finetuning. Just make sure it matches the model config
522
+ # you chose above.
470
523
freeze_filter = pi0 .Pi0Config (
471
524
paligemma_variant = "gemma_2b_lora" , action_expert_variant = "gemma_300m_lora"
472
525
).get_freeze_filter (),
526
+ # Turn off EMA for LoRA finetuning.
473
527
ema_decay = None ,
474
528
),
475
529
TrainConfig (
476
530
name = "pi0_fast_libero" ,
531
+ # Here is an example of loading a pi0-FAST model for full finetuning.
532
+ # Modify action_dim and action_horizon to match your dataset (action horizon is equal to
533
+ # the desired action chunk length).
534
+ # The max_token_len is the maximum number of (non-image) tokens the model can handle.
535
+ # This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens.
536
+ # Choosing this value too small may chop off tokens at the end of your sequence (the code will throw
537
+ # a warning), while choosing it too large will waste memory (since we pad each batch element to the
538
+ # max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for
539
+ # two-arm robots. Generally, err on the lower side here first, and potentially increase the value if
540
+ # you see many warnings being thrown during training.
477
541
model = pi0_fast .Pi0FASTConfig (action_dim = 7 , action_horizon = 10 , max_token_len = 180 ),
478
542
data = LeRobotLiberoDataConfig (
479
543
repo_id = "physical-intelligence/libero" ,
@@ -482,12 +546,17 @@ def __post_init__(self) -> None:
482
546
prompt_from_task = True ,
483
547
),
484
548
),
549
+ # Note that we load the pi0-FAST base model checkpoint here.
485
550
weight_loader = weight_loaders .CheckpointWeightLoader ("s3://openpi-assets/checkpoints/pi0_fast_base/params" ),
486
551
num_train_steps = 30_000 ,
487
552
),
488
553
TrainConfig (
489
554
name = "pi0_fast_libero_low_mem_finetune" ,
490
- model = pi0_fast .Pi0FASTConfig (paligemma_variant = "gemma_2b_lora" ),
555
+ # Here is an example of loading a pi0-FAST model for LoRA finetuning.
556
+ # For setting action_dim, action_horizon, and max_token_len, see the comments above.
557
+ model = pi0_fast .Pi0FASTConfig (
558
+ action_dim = 7 , action_horizon = 10 , max_token_len = 180 , paligemma_variant = "gemma_2b_lora"
559
+ ),
491
560
data = LeRobotLiberoDataConfig (
492
561
repo_id = "physical-intelligence/libero" ,
493
562
base_config = DataConfig (
@@ -497,9 +566,12 @@ def __post_init__(self) -> None:
497
566
),
498
567
weight_loader = weight_loaders .CheckpointWeightLoader ("s3://openpi-assets/checkpoints/pi0_fast_base/params" ),
499
568
num_train_steps = 30_000 ,
569
+ # Again, make sure to match the model config above when extracting the freeze filter
570
+ # that specifies which parameters should be frozen during LoRA finetuning.
500
571
freeze_filter = pi0_fast .Pi0FASTConfig (
501
572
action_dim = 7 , action_horizon = 10 , max_token_len = 180 , paligemma_variant = "gemma_2b_lora"
502
573
).get_freeze_filter (),
574
+ # Turn off EMA for LoRA finetuning.
503
575
ema_decay = None ,
504
576
),
505
577
#
0 commit comments