Qwen3 MoE Fused

Update: Transformers v5, which supports fused MoE kernels, will be released soon, and we will implement the corresponding LoRA support in PEFT. This repo only supports Transformers v4.

The Qwen3 MoE model (and every other MoE model) in HF Transformers is notoriously slow because it uses a Python for loop over the experts. The purpose of this repo is to fine-tune Qwen3-30B-A3B on a single GPU with 24 GB VRAM while achieving high throughput. The implementation is compatible with the HF Transformers ecosystem, including LoRA, bitsandbytes 4-bit quantization, and Unsloth. See example_train_30b_a3b_unsloth.py for the usage.

Fused linear layer

The critical part is to implement the moe_fused_linear function:

output[b, o] = sum_i weight[selected_experts[b], o, i] * input[b, i]
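For reference, these semantics can be written as a naive PyTorch function (a plain sketch for clarity, not the Triton kernel used in this repo; the name and signature are illustrative):

```python
import torch

def moe_fused_linear_reference(input, weight, selected_experts):
    # input:            (batch, in_features)
    # weight:           (num_experts, out_features, in_features)
    # selected_experts: (batch,) integer expert index per token
    # returns:          (batch, out_features)
    gathered = weight[selected_experts]  # (batch, out_features, in_features)
    # output[b, o] = sum_i weight[selected_experts[b], o, i] * input[b, i]
    return torch.einsum("boi,bi->bo", gathered, input)
```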

There are already several good implementations, such as triton-kernels, llama.cpp, vLLM, and fanshiqing/grouped_gemm. torch._grouped_mm is also being implemented. We need to sort the input rows by their selected expert to improve the memory coalescing of weight accesses; more optimizations are explained in https://pytorch.org/blog/accelerating-moes-with-a-triton-persistent-cache-aware-grouped-gemm-kernel/
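The sorting step can be illustrated at the PyTorch level like this (a sketch of the idea only; in the repo the grouping is handled around the Triton grouped GEMM, and the function below is hypothetical):

```python
import torch

def sort_tokens_by_expert(input, selected_experts, num_experts):
    # Group tokens routed to the same expert into contiguous blocks, so the
    # kernel can reuse each expert's weight matrix across a whole block.
    order = torch.argsort(selected_experts)
    sorted_input = input[order]
    sorted_experts = selected_experts[order]
    # tokens_per_expert[e] tells the grouped GEMM how many rows belong to expert e
    tokens_per_expert = torch.bincount(sorted_experts, minlength=num_experts)
    # After the matmul, the outputs are scattered back with the inverse permutation:
    #   output[order] = sorted_output
    return sorted_input, tokens_per_expert, order
```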

The implementation in this repo is largely based on the MoE kernel in Unsloth, which is based on the Triton grouped GEMM. I've added strides, masks, and autotune configs for small or 'thin' matrices, which are needed for LoRA.

I aim to keep the code readable and easy to follow. I only used the most mature features of Triton, such as load and store, rather than things like TMA and swizzle. I've benchmarked it on an RTX 3080, and it's close to the theoretical fp16 and bf16 performance.

LoRA

The LoRA for the fused linear layer is defined by first creating a LoRA for the linear layer in each expert, then stacking them along the expert dimension. For the weight tensor with shape (num_experts, out_features, in_features), the two LoRA weights have shapes lora_A: (num_experts, lora_rank, in_features) and lora_B: (num_experts, out_features, lora_rank). Therefore, we can losslessly convert between the fused and the unfused formats, and a previously trained LoRA can continue to be trained.
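A shape-only sketch of this stacking (the sizes here are made-up example numbers; the actual PEFT integration lives in this repo's code):

```python
import torch

# Example sizes only, not tied to any particular model config
num_experts, out_features, in_features, lora_rank = 8, 768, 2048, 16

# Unfused: one LoRA pair per expert, as if each expert were a plain nn.Linear
lora_A_per_expert = [torch.randn(lora_rank, in_features) for _ in range(num_experts)]
lora_B_per_expert = [torch.zeros(out_features, lora_rank) for _ in range(num_experts)]

# Fused: stack along the expert dimension
lora_A = torch.stack(lora_A_per_expert)  # (num_experts, lora_rank, in_features)
lora_B = torch.stack(lora_B_per_expert)  # (num_experts, out_features, lora_rank)

# Per-expert delta weight (up to the usual lora_alpha / lora_rank scaling):
#   delta_weight[e] = lora_B[e] @ lora_A[e]
delta_weight = torch.einsum("eor,eri->eoi", lora_B, lora_A)  # (num_experts, out_features, in_features)
```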

The functions in qwen3_moe_fused/convert.py can convert a model or a LoRA between the fused and the unfused formats. After you train a LoRA in the fused format, you can convert it to the unfused format, then convert that to other formats such as GGUF. llama.cpp already supports this kind of LoRA. Support in vLLM is being implemented, see vllm-project/vllm#21229.

TODO

License

The files in qwen3_moe_fused/grouped_gemm/ are modified from the Unsloth MoE kernels, so they are AGPLv3-licensed; see the explanation. For a more robust and performant integration, it's possible to use the MIT-licensed triton-kernels as an alternative.

The rest of this repo, including files modified from Transformers, PEFT, and bitsandbytes, is Apache-2.0 licensed.
