
Remove unnecessary device synchronizations from finegrained FP8 matmul #43349

Open
readleyj wants to merge 1 commit into huggingface:main from readleyj:remove-fp8-device-syncs

Conversation


@readleyj readleyj commented Jan 19, 2026

The synchronizations are unnecessary and kill performance. Before and after traces:
(before/after trace screenshots)

cc @SunMarc

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43349&sha=be30ec

@vasqu
Contributor

vasqu commented Jan 19, 2026

cc @MekkCyber

Contributor

@MekkCyber MekkCyber left a comment


Thanks @readleyj

Comment on lines -475 to -477
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch_accelerator_module.synchronize()
Contributor

I remember these were added to avoid NaN values in the output; are you sure they're not necessary?

Author

Hey, @MekkCyber. Yep, I just tested locally. I ran a stress test without the syncs and didn't see any NaNs:

```python
import time, types, torch
from transformers.integrations.finegrained_fp8 import FP8Expert, FP8Linear

def naninf(x): return bool(torch.isnan(x).any().item() or torch.isinf(x).any().item())
def sync():
    if torch.cuda.is_available(): torch.cuda.synchronize()

@torch.no_grad()
def fp8linear(b,s,in_f,out_f,dtype,iters,block=(128,128)):
    assert in_f % block[1] == 0 and out_f % block[0] == 0
    m = FP8Linear(in_features=in_f,out_features=out_f,bias=True,dtype=torch.float8_e4m3fn,block_size=block,activation_scheme="dynamic").cuda()
    w = torch.randn(out_f,in_f,device="cuda",dtype=torch.float16).clamp(-2,2)
    m.weight.copy_(w.to(torch.float8_e4m3fn)); m.weight_scale_inv.fill_(1.0)
    m.bias.copy_(torch.randn(out_f,device="cuda",dtype=torch.float16).to(dtype))
    x = torch.randn(b,s,in_f,device="cuda",dtype=dtype).contiguous()
    for _ in range(10): y = m(x)  # warmup
    t0=time.time()
    for i in range(iters):
        y = m(x)
        if naninf(y): raise RuntimeError(f"FP8Linear NaN/Inf shape={(b,s,in_f,out_f)} dtype={dtype} i={i}")
        if i % 25 == 0: x = (x + 0.01*torch.randn_like(x)).contiguous()
    sync()
    print(f"FP8Linear OK shape={(b,s,in_f,out_f)} dtype={str(dtype).replace('torch.','')} iters={iters} sec/iter={(time.time()-t0)/iters:.6f}")

@torch.no_grad()
def fp8expert(tokens,h,inter,experts,topk,dtype,iters,block=(128,128)):
    assert h % block[1] == 0 and inter % block[0] == 0
    cfg = types.SimpleNamespace(num_local_experts=experts, hidden_size=h, intermediate_size=inter, hidden_act="silu")
    moe = FP8Expert(cfg, block_size=block, dtype=torch.float8_e4m3fn).cuda()
    moe.gate_up_proj.copy_(torch.randn_like(moe.gate_up_proj,dtype=torch.float16).clamp(-2,2).to(torch.float8_e4m3fn))
    moe.down_proj.copy_(torch.randn_like(moe.down_proj,dtype=torch.float16).clamp(-2,2).to(torch.float8_e4m3fn))
    moe.gate_up_proj_scale_inv.fill_(1.0); moe.down_proj_scale_inv.fill_(1.0)
    hs = torch.randn(tokens,h,device="cuda",dtype=dtype).contiguous()
    idx = torch.randint(0,experts,(tokens,topk),device="cuda")
    wts = torch.rand(tokens,topk,device="cuda",dtype=dtype)
    for _ in range(5): out = moe(hs,idx,wts)  # warmup
    t0=time.time()
    for i in range(iters):
        out = moe(hs,idx,wts)
        if naninf(out): raise RuntimeError(f"FP8Expert NaN/Inf tokens={tokens} h={h} inter={inter} dtype={dtype} i={i}")
        if i % 10 == 0:
            hs = (hs + 0.01*torch.randn_like(hs)).contiguous()
            wts = torch.rand_like(wts)
    sync()
    print(f"FP8Expert OK tokens={tokens} h={h} inter={inter} experts={experts} topk={topk} dtype={str(dtype).replace('torch.','')} iters={iters} sec/iter={(time.time()-t0)/iters:.6f}")

def main():
    assert torch.cuda.is_available() and hasattr(torch,"float8_e4m3fn")
    torch.manual_seed(0); torch.backends.cuda.matmul.allow_tf32 = True
    dtypes = [torch.bfloat16, torch.float16]
    lin_shapes = [(1,1,256,256),(1,16,256,256),(4,8,256,256),(2,32,512,256),(2,32,512,512),(8,8,1024,512)]
    moe_shapes = [(64,256,256,4,2),(128,256,256,4,2),(256,256,512,4,2),(256,512,512,8,2)]
    print("=== FP8Linear ===")
    for dt in dtypes:
        for b,s,in_f,out_f in lin_shapes: fp8linear(b,s,in_f,out_f,dt,iters=400)
    print("=== FP8Expert ===")
    for dt in dtypes:
        for t,h,inter,e,k in moe_shapes: fp8expert(t,h,inter,e,k,dt,iters=150)
    print("All OK")

if __name__ == "__main__":
    main()
```

Output:

```
=== FP8Linear ===
FP8Linear OK shape=(1, 1, 256, 256) dtype=bfloat16 iters=400 sec/iter=0.000321
FP8Linear OK shape=(1, 16, 256, 256) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(4, 8, 256, 256) dtype=bfloat16 iters=400 sec/iter=0.000201
FP8Linear OK shape=(2, 32, 512, 256) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(2, 32, 512, 512) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(8, 8, 1024, 512) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(1, 1, 256, 256) dtype=float16 iters=400 sec/iter=0.000201
FP8Linear OK shape=(1, 16, 256, 256) dtype=float16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(4, 8, 256, 256) dtype=float16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(2, 32, 512, 256) dtype=float16 iters=400 sec/iter=0.000206
FP8Linear OK shape=(2, 32, 512, 512) dtype=float16 iters=400 sec/iter=0.000213
FP8Linear OK shape=(8, 8, 1024, 512) dtype=float16 iters=400 sec/iter=0.000206
=== FP8Expert ===
FP8Expert OK tokens=64 h=256 inter=256 experts=4 topk=2 dtype=bfloat16 iters=150 sec/iter=0.002137
FP8Expert OK tokens=128 h=256 inter=256 experts=4 topk=2 dtype=bfloat16 iters=150 sec/iter=0.002126
FP8Expert OK tokens=256 h=256 inter=512 experts=4 topk=2 dtype=bfloat16 iters=150 sec/iter=0.002144
FP8Expert OK tokens=256 h=512 inter=512 experts=8 topk=2 dtype=bfloat16 iters=150 sec/iter=0.004046
FP8Expert OK tokens=64 h=256 inter=256 experts=4 topk=2 dtype=float16 iters=150 sec/iter=0.002133
FP8Expert OK tokens=128 h=256 inter=256 experts=4 topk=2 dtype=float16 iters=150 sec/iter=0.002140
FP8Expert OK tokens=256 h=256 inter=512 experts=4 topk=2 dtype=float16 iters=150 sec/iter=0.002143
FP8Expert OK tokens=256 h=512 inter=512 experts=8 topk=2 dtype=float16 iters=150 sec/iter=0.004019
All OK
```

I don't see any reason a sync would be necessary: all the ops are GPU ops, and there is no CPU interaction, which is when you would need a sync. Is it possible the NaNs were there regardless of the sync and were addressed by #43154?
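A minimal sketch of the point above (not part of the PR; `matmul_chain` is a hypothetical helper): PyTorch dispatches CUDA kernels asynchronously, but a single stream executes them in submission order, so a chain of device-side ops needs no explicit `synchronize()` for correctness. A sync only matters at a CPU boundary, e.g. timing with `time.time()` or reading a value on the host.

```python
import torch

def matmul_chain(x, w, n=4):
    # Every op here stays on the tensor's device; stream ordering alone
    # guarantees each matmul sees the previous result, with no explicit
    # torch.cuda.synchronize() in between.
    y = x
    for _ in range(n):
        y = y @ w
    return y

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 8, device=device)
w = torch.eye(8, device=device)  # identity weight: the chain is a no-op
y = matmul_chain(x, w)

# The CPU boundary is where synchronization actually happens: .item()
# (like .cpu() or print) implicitly waits for the device to finish.
val = y[0, 0].item()
assert torch.allclose(y, x)
```

Calls like `.item()`, `.cpu()`, and `tensor.tolist()` synchronize implicitly, which is why an explicit sync inside a purely device-side matmul path only adds stalls.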

@readleyj readleyj requested a review from MekkCyber January 26, 2026 00:58
Member

@SunMarc SunMarc left a comment


LGTM! Just a nit.

Comment on lines -475 to -477
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch_accelerator_module.synchronize()
Member

Can you comment out the synchronizations instead of removing them, in case we face the NaN issue again in the future? Also, can you check that we get correct generation with a model like Qwen FP8 or DeepSeek? Thanks!

4 participants