LoRA and Transformers TP #3079

Open
michaelbenayoun wants to merge 28 commits into huggingface:main from michaelbenayoun:lora_and_tp

Conversation

@michaelbenayoun
Member

@michaelbenayoun michaelbenayoun commented Mar 3, 2026

The goal of this PR is to integrate Transformers' API for Tensor Parallelism into PEFT, starting with LoRA.

As #3044 pointed out, there are issues.

First, the code used there fails. This is because, to create adapters, PEFT looks at the parent module's attributes rather than the actual weights.

For instance, for LoRA it checks the in_features and out_features attributes of torch.nn.Linear instead of the weight's shape, which fails under TP because the two no longer match. I addressed this issue here: huggingface/transformers#44421.

On top of that, we need to handle a few things to make it work. There are two cases:

  • Column linears: the output is sharded, so our adapters should also produce sharded outputs. lora_A should be a regular non-sharded linear, and lora_B should be a column linear, just like the base layer.
  • Row linears: the input comes in sharded, and we should produce an un-sharded output. lora_A should be a row linear, and lora_B a regular non-sharded linear.
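
The column-linear case can be sanity-checked with plain tensors. In this minimal sketch (the shapes and names are illustrative, not the PR's actual code), the per-rank shards are simulated by slicing: sharding lora_B along its output dimension while keeping lora_A replicated yields exactly the output shard a column linear produces on each rank.

```python
import torch

torch.manual_seed(0)
in_f, out_f, r, world = 8, 6, 2, 2

x = torch.randn(4, in_f)
W = torch.randn(out_f, in_f)  # base column linear: weight sharded along dim 0
A = torch.randn(r, in_f)      # lora_A: regular linear, replicated on every rank
B = torch.randn(out_f, r)     # lora_B: column linear, sharded along dim 0

reference = x @ (W + B @ A).T  # un-sharded output of base + LoRA

# simulate each rank holding a row-shard of W and B
shards = []
for rank in range(world):
    rows = slice(rank * out_f // world, (rank + 1) * out_f // world)
    shards.append(x @ (W[rows] + B[rows] @ A).T)

# concatenating the per-rank outputs recovers the full result
assert torch.allclose(torch.cat(shards, dim=-1), reference, atol=1e-5)
```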

To do that, we need to do two things:

  • If we create the adapter from scratch (not loaded from an existing checkpoint), we only need to add the Transformers TensorParallelLayer hooks.
  • If we load the adapter from an existing checkpoint, we also need to shard the loaded weights accordingly.

This PR provides such features and a test file to check that everything works as expected.

Next: add similar support for the Embedding layer.

@BenjaminBossan
Member

Thanks for taking care of this @michaelbenayoun. LMK if I can help with anything. If you have a minimal example to test this, that would be great.

A different approach that may work is to detect if a TP plan is being used and then initialize the corresponding PEFT layer differently using the sharded layers, I'm not sure what approach would be more robust.

@michaelbenayoun michaelbenayoun marked this pull request as ready for review March 7, 2026 00:03
dist.destroy_process_group()


def _test_training(rank, world_size, port):
Member

Is it possible to test using the following methodology? https://github.com/huggingface/transformers/blob/main/tests/test_training_mixin.py#L387

Take a tiny random model, overfit the same sample abcdefg... for several steps until the loss and grad_norm have decreased by 70% -> save and load -> generate? (Since it has overfit the sample, it should perfectly predict the sequence.)
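
A framework-free sketch of that methodology, using a toy next-token model instead of a real tiny checkpoint (all names here are illustrative):

```python
import torch

torch.manual_seed(0)

# stand-in for a tiny LM: learn to predict token i+1 from token i
vocab, dim = 10, 16
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab, dim),
    torch.nn.Linear(dim, vocab),
)
seq = torch.arange(8)  # the fixed "abcdefg..." sample, as token ids
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for _ in range(200):  # overfit the single sample
    logits = model(seq[:-1])
    loss = torch.nn.functional.cross_entropy(logits, seq[1:])
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

assert losses[-1] < 0.3 * losses[0]  # loss dropped by more than 70%
pred = model(seq[:-1]).argmax(dim=-1)
assert torch.equal(pred, seq[1:])  # the overfit model replays the sequence
```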


Member Author

Done

device_mesh = getattr(base_layer, "_hf_device_mesh", None)
if device_mesh is not None and tp_plan in ("colwise", "rowwise"):
pg = device_mesh.get_group()
src = torch.distributed.get_global_rank(pg, 0)
Member

Styling, but better to import torch.distributed as dist and then use dist (to be consistent with the transformers side)?

Member Author

Done

@3outeille
Member

Nice PR! Several questions that pop into my mind:

  • Regarding TP row linear + LoRA, should the all_reduce happen before or after the LoRA computation?
  • Given the importance of MoE as well, it would be good to have TP + PEFT + MoE support, what do you think?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@michaelbenayoun
Member Author

Nice PR! Several questions that pop into my mind:

  • Regarding TP row linear + LoRA, should the all_reduce happen before or after the LoRA computation?
  • Given the importance of MoE as well, it would be good to have TP + PEFT + MoE support, what do you think?

  • TP row linear + LoRA will do this: lora_A gets sharded inputs, computes its output, and all-reduces, just like a regular row linear. Then lora_B gets the un-sharded result.
  • I agree, and the same goes for embeddings. I suggest we postpone that to other PRs so this one doesn't grow too large.
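
That row-linear ordering can be checked with plain tensors. In this minimal sketch (not the PR's code), the per-rank shards are simulated by slicing and a plain sum stands in for dist.all_reduce:

```python
import torch

torch.manual_seed(0)
in_f, out_f, r, world = 8, 6, 2, 2

x = torch.randn(4, in_f)
A = torch.randn(r, in_f)   # lora_A: row linear, weight sharded along dim 1
B = torch.randn(out_f, r)  # lora_B: regular linear, replicated

ref = x @ A.T  # un-sharded lora_A output

# each rank only sees its input shard; partial results are all-reduced
partials = []
for rank in range(world):
    cols = slice(rank * in_f // world, (rank + 1) * in_f // world)
    partials.append(x[:, cols] @ A[:, cols].T)
reduced = sum(partials)  # stands in for dist.all_reduce

assert torch.allclose(reduced, ref, atol=1e-5)
out = reduced @ B.T  # lora_B then runs on the un-sharded activation
```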

Member

@BenjaminBossan BenjaminBossan left a comment


Thanks for adding support for TP to LoRA. I'm not too knowledgeable when it comes to TP, so I can't judge the details there, so this review focuses more on the PEFT integration itself.

Right now, this is very LoRA specific. I think the same idea should work with other PEFT methods too, but it's not quite trivial to write the code in a generic way. So I'm fine with the approach here and we can adjust once/if there is demand for TP in other PEFT methods.

One question that I had is: Do you know the minimum transformers version that would be required to run this? The whole TP module seems to be from one year ago, but I'm not sure if later changes are required for this to actually work. If you know, could you please add a small section to the docs (https://github.com/huggingface/peft/blob/main/docs/source/developer_guides/lora.md) mentioning that TP is supported and requires transformers > x.y.z?

Moreover, although we don't have CI for this, we generally try to support older transformers versions as much as possible. Code like getattr(base_layer, "_hf_tp_plan", None) should always be fine, as this would just return None for older versions. But importing from transformers.integrations.tensor_parallel would fail. So how about importing it locally, only when needed?

Regarding the CI, it currently fails, most likely because the TP tests are being run on CPU runners. I made some suggestions that should hopefully resolve this. However, even if I run the tests locally on a machine with 2 GPUs, I get an error:

E           tp_base = AutoModelForCausalLM.from_pretrained(MODEL_ID, tp_plan="auto")
E         File "/home/name/work/forks/transformers/src/transformers/models/auto/auto_factory.py", line 381, in from_pretrained
E           return model_class.from_pretrained(
E                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~^
E               pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
E               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E           )
E           ^
E         File "/home/name/work/forks/transformers/src/transformers/modeling_utils.py", line 3989, in from_pretrained
E           device_map, device_mesh, tp_size = initialize_tensor_parallelism(
E                                              ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
E               tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
E               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E           )
E           ^
E         File "/home/name/work/forks/transformers/src/transformers/integrations/tensor_parallel.py", line 81, in initialize_tensor_parallelism
E           current_device.set_device(int(os.environ["LOCAL_RANK"]))
E                                         ~~~~~~~~~~^^^^^^^^^^^^^^
E         File "<frozen os>", line 717, in __getitem__
E       KeyError: 'LOCAL_RANK'


# We create and initialize the TensorParallelLayer on the fly,
# and we set the `empty_param` attribute depending on the proper
# state dict key to shard
Member

Could you please extend this comment as to why this is needed?

Member Author

Done.

Comment on lines +611 to +622
if isinstance(tp_layer, ColwiseParallel):
key = f"{name}.lora_B.{adapter_name}.weight"
tp_layer.empty_param = peft_model_state_dict[key]
peft_model_state_dict[key] = tp_layer.shard_tensor(
peft_model_state_dict[key], device=device, dtype=dtype
)
elif isinstance(tp_layer, RowwiseParallel):
key = f"{name}.lora_A.{adapter_name}.weight"
tp_layer.empty_param = peft_model_state_dict[key]
peft_model_state_dict[key] = tp_layer.shard_tensor(
peft_model_state_dict[key], device=device, dtype=dtype
)
Member

Should do the same, right?

Suggested change
if isinstance(tp_layer, ColwiseParallel):
key = f"{name}.lora_B.{adapter_name}.weight"
tp_layer.empty_param = peft_model_state_dict[key]
peft_model_state_dict[key] = tp_layer.shard_tensor(
peft_model_state_dict[key], device=device, dtype=dtype
)
elif isinstance(tp_layer, RowwiseParallel):
key = f"{name}.lora_A.{adapter_name}.weight"
tp_layer.empty_param = peft_model_state_dict[key]
peft_model_state_dict[key] = tp_layer.shard_tensor(
peft_model_state_dict[key], device=device, dtype=dtype
)
if isinstance(tp_layer, ColwiseParallel):
key = f"{name}.lora_B.{adapter_name}.weight"
elif isinstance(tp_layer, RowwiseParallel):
key = f"{name}.lora_A.{adapter_name}.weight"
tp_layer.empty_param = peft_model_state_dict[key]
peft_model_state_dict[key] = tp_layer.shard_tensor(
peft_model_state_dict[key], device=device, dtype=dtype
)

Member Author

Done

Comment on lines +286 to +303
if tp_plan == "colwise":
add_tensor_parallel_hooks_to_module(
self.model,
lora_module.lora_B[adapter_name],
tp_plan,
f"{current_key}.lora_B.{adapter_name}",
tp_plan,
device_mesh,
)
elif tp_plan == "rowwise":
add_tensor_parallel_hooks_to_module(
self.model,
lora_module.lora_A[adapter_name],
tp_plan,
f"{current_key}.lora_A.{adapter_name}",
tp_plan,
device_mesh,
)
Member

Should do the same but with less repetition, right?

Suggested change
if tp_plan == "colwise":
add_tensor_parallel_hooks_to_module(
self.model,
lora_module.lora_B[adapter_name],
tp_plan,
f"{current_key}.lora_B.{adapter_name}",
tp_plan,
device_mesh,
)
elif tp_plan == "rowwise":
add_tensor_parallel_hooks_to_module(
self.model,
lora_module.lora_A[adapter_name],
tp_plan,
f"{current_key}.lora_A.{adapter_name}",
tp_plan,
device_mesh,
)
if tp_plan == "colwise":
    tp_module = lora_module.lora_B[adapter_name]
    tp_layer_name = f"{current_key}.lora_B.{adapter_name}"
else:
    tp_module = lora_module.lora_A[adapter_name]
    tp_layer_name = f"{current_key}.lora_A.{adapter_name}"
add_tensor_parallel_hooks_to_module(
    model=self.model,
    module=tp_module,
    tp_plan=tp_plan,
    layer_name=tp_layer_name,
    current_module_plan=tp_plan,
    device_mesh=device_mesh,
)

Also, it looks like the tp_plan argument to add_tensor_parallel_hooks_to_module is not used at all, why is it needed?

Member Author

You are right about the tp_plan parameter. I opened a PR here: huggingface/transformers#44768. It should be merged before the release, so I will make sure to include the changes here if it happens.

Member Author

Other refactor mentioned, done.

@@ -0,0 +1,393 @@
# Copyright 2025-present the HuggingFace Inc. team.
Member

Suggested change
# Copyright 2025-present the HuggingFace Inc. team.
# Copyright 2026-present the HuggingFace Inc. team.

Member Author

Done

Comment on lines +35 to +36
MODEL_ID = "Qwen/Qwen3-0.6B"
TINY_MODEL_ID = "amazingvince/zephyr-smol_llama-100m-sft-full"
Member

Is there a specific reason to use these models? Otherwise, I'd like to move to models we're already using as they are already cached (the CI is already close to triggering rate limits from the Hub so we have to be careful).

Member Author

Alright, I get you.
For MODEL_ID, the idea is just to use a small Qwen3 model. We can use any LLM tbh.

For TINY_MODEL_ID, it is a bit more nuanced. Because we are running training steps and checking that overfitting happens, we cannot simply take a tiny randomly initialized model: there is a glass ceiling on what the LoRA adapters can learn when the base model is just full of garbage. So we need an actually trained model, small enough to run fast in the CI, and I managed to find this small 100m fine-tuned model.

_teardown_dist()


def _test_training_overfit(rank, world_size, port):
Member

This looks more like an integration test to me. We could put a separate script for this into tests/training/ and then invoke it from the tests_training target in the Makefile (peft/Makefile, line 65 at 2513f57).

This ensures it's only run in the correct context and when it's needed (not as part of the regular CI, which runs on CPU).

logger.info(f"{Colors.GREEN}✓ Generated sequence matches training sequence{Colors.RESET}")


def _test_lora_weight_synchronization(rank, world_size, port):
Member

This test and the rest below should be put into tests/test_gpu_examples.py. Let's put them all in the same test class to make it clear they belong together. Then decorate the class with @pytest.mark.multi_gpu_tests. That way, we know that the tests only run on the multi GPU runner.

Member Author

It also works on any runner with multiple CPU cores. But I did exactly as you suggested.

Comment on lines +371 to +372
@unittest.skipUnless(_is_tp_available(), "transformers TP integration not available")
class TestLoraTP(unittest.TestCase):
Member

Suggested change
@unittest.skipUnless(_is_tp_available(), "transformers TP integration not available")
class TestLoraTP(unittest.TestCase):
@pytest.mark.skipif(not _is_tp_available(), reason="transformers TP integration not available")
class TestLoraTensorParallel:

As we're away from unittest.

Member Author

Done.
