Remove unnecessary device synchronizations from finegrained FP8 matmul#43349
Remove unnecessary device synchronizations from finegrained FP8 matmul #43349 — readleyj wants to merge 1 commit into huggingface:main from their branch
Conversation
|
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43349&sha=be30ec |
|
cc @MekkCyber |
| # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the | ||
| # preceding operations are ready before proceeding | ||
| torch_accelerator_module.synchronize() |
There was a problem hiding this comment.
I remember it was necessary to add these to avoid having nan values in the output, are you sure it's not necessary ?
There was a problem hiding this comment.
Hey, @MekkCyber. Yep, I just tested locally — I ran a stress test without the syncs and didn't see any NaNs:
import time, types, torch
from transformers.integrations.finegrained_fp8 import FP8Expert, FP8Linear
def naninf(x):
    """Return True when tensor *x* contains any NaN or Inf element."""
    has_nan = torch.isnan(x).any().item()
    has_inf = torch.isinf(x).any().item()
    return bool(has_nan or has_inf)
def sync():
    """Block the host until all queued CUDA work finishes; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()
@torch.no_grad()
def fp8linear(b, s, in_f, out_f, dtype, iters, block=(128, 128)):
    """Stress-test one FP8Linear layer on CUDA.

    Builds an FP8Linear of shape (in_f -> out_f), fills its weight/scale/bias
    with deterministic-ish random data, then runs `iters` forward passes,
    raising RuntimeError the first time the output contains NaN/Inf.
    `block` is checked as (out-dim divisor, in-dim divisor) against the
    layer dimensions. Prints a timing summary on success.
    """
    # Dimensions must tile evenly into the FP8 quantization blocks.
    assert in_f % block[1] == 0 and out_f % block[0] == 0
    layer = FP8Linear(
        in_features=in_f,
        out_features=out_f,
        bias=True,
        dtype=torch.float8_e4m3fn,
        block_size=block,
        activation_scheme="dynamic",
    ).cuda()
    # NOTE: keep the RNG call order identical to the original script so the
    # generated tensors match under the same seed.
    dense_w = torch.randn(out_f, in_f, device="cuda", dtype=torch.float16).clamp(-2, 2)
    layer.weight.copy_(dense_w.to(torch.float8_e4m3fn))
    layer.weight_scale_inv.fill_(1.0)
    layer.bias.copy_(torch.randn(out_f, device="cuda", dtype=torch.float16).to(dtype))
    inp = torch.randn(b, s, in_f, device="cuda", dtype=dtype).contiguous()
    # Warmup passes so kernel compilation/caching doesn't pollute the timing.
    for _ in range(10):
        out = layer(inp)
    start = time.time()
    for step in range(iters):
        out = layer(inp)
        if naninf(out):
            raise RuntimeError(f"FP8Linear NaN/Inf shape={(b,s,in_f,out_f)} dtype={dtype} i={step}")
        # Periodically perturb the input so successive iterations differ.
        if step % 25 == 0:
            inp = (inp + 0.01 * torch.randn_like(inp)).contiguous()
    sync()
    print(f"FP8Linear OK shape={(b,s,in_f,out_f)} dtype={str(dtype).replace('torch.','')} iters={iters} sec/iter={(time.time()-start)/iters:.6f}")
@torch.no_grad()
def fp8expert(tokens, h, inter, experts, topk, dtype, iters, block=(128, 128)):
    """Stress-test an FP8Expert MoE block on CUDA.

    Constructs an FP8Expert with `experts` local experts (hidden size `h`,
    intermediate size `inter`, silu activation), routes `tokens` tokens with
    `topk` experts each, and runs `iters` forward passes, raising
    RuntimeError if the output ever contains NaN/Inf. `block` is checked as
    (inter-dim divisor, hidden-dim divisor). Prints a timing summary.
    """
    # Hidden/intermediate sizes must tile evenly into the FP8 blocks.
    assert h % block[1] == 0 and inter % block[0] == 0
    cfg = types.SimpleNamespace(num_local_experts=experts, hidden_size=h, intermediate_size=inter, hidden_act="silu")
    module = FP8Expert(cfg, block_size=block, dtype=torch.float8_e4m3fn).cuda()
    # NOTE: keep the RNG call order identical to the original script so the
    # generated tensors match under the same seed.
    module.gate_up_proj.copy_(torch.randn_like(module.gate_up_proj, dtype=torch.float16).clamp(-2, 2).to(torch.float8_e4m3fn))
    module.down_proj.copy_(torch.randn_like(module.down_proj, dtype=torch.float16).clamp(-2, 2).to(torch.float8_e4m3fn))
    module.gate_up_proj_scale_inv.fill_(1.0)
    module.down_proj_scale_inv.fill_(1.0)
    hidden = torch.randn(tokens, h, device="cuda", dtype=dtype).contiguous()
    routing = torch.randint(0, experts, (tokens, topk), device="cuda")
    weights = torch.rand(tokens, topk, device="cuda", dtype=dtype)
    # Warmup passes so kernel compilation/caching doesn't pollute the timing.
    for _ in range(5):
        result = module(hidden, routing, weights)
    start = time.time()
    for step in range(iters):
        result = module(hidden, routing, weights)
        if naninf(result):
            raise RuntimeError(f"FP8Expert NaN/Inf tokens={tokens} h={h} inter={inter} dtype={dtype} i={step}")
        # Periodically perturb inputs so successive iterations differ.
        if step % 10 == 0:
            hidden = (hidden + 0.01 * torch.randn_like(hidden)).contiguous()
            weights = torch.rand_like(weights)
    sync()
    print(f"FP8Expert OK tokens={tokens} h={h} inter={inter} experts={experts} topk={topk} dtype={str(dtype).replace('torch.','')} iters={iters} sec/iter={(time.time()-start)/iters:.6f}")
def main():
    """Run the FP8Linear and FP8Expert stress suites across dtypes/shapes."""
    # Requires a CUDA device and a torch build with float8 support.
    assert torch.cuda.is_available() and hasattr(torch, "float8_e4m3fn")
    torch.manual_seed(0)
    torch.backends.cuda.matmul.allow_tf32 = True
    test_dtypes = [torch.bfloat16, torch.float16]
    # (batch, seq, in_features, out_features)
    linear_cases = [(1, 1, 256, 256), (1, 16, 256, 256), (4, 8, 256, 256), (2, 32, 512, 256), (2, 32, 512, 512), (8, 8, 1024, 512)]
    # (tokens, hidden, intermediate, experts, topk)
    expert_cases = [(64, 256, 256, 4, 2), (128, 256, 256, 4, 2), (256, 256, 512, 4, 2), (256, 512, 512, 8, 2)]
    print("=== FP8Linear ===")
    for dt in test_dtypes:
        for b, s, in_f, out_f in linear_cases:
            fp8linear(b, s, in_f, out_f, dt, iters=400)
    print("=== FP8Expert ===")
    for dt in test_dtypes:
        for t, h, inter, e, k in expert_cases:
            fp8expert(t, h, inter, e, k, dt, iters=150)
    print("All OK")
if __name__ == "__main__":
    main()

=== FP8Linear ===
FP8Linear OK shape=(1, 1, 256, 256) dtype=bfloat16 iters=400 sec/iter=0.000321
FP8Linear OK shape=(1, 16, 256, 256) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(4, 8, 256, 256) dtype=bfloat16 iters=400 sec/iter=0.000201
FP8Linear OK shape=(2, 32, 512, 256) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(2, 32, 512, 512) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(8, 8, 1024, 512) dtype=bfloat16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(1, 1, 256, 256) dtype=float16 iters=400 sec/iter=0.000201
FP8Linear OK shape=(1, 16, 256, 256) dtype=float16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(4, 8, 256, 256) dtype=float16 iters=400 sec/iter=0.000202
FP8Linear OK shape=(2, 32, 512, 256) dtype=float16 iters=400 sec/iter=0.000206
FP8Linear OK shape=(2, 32, 512, 512) dtype=float16 iters=400 sec/iter=0.000213
FP8Linear OK shape=(8, 8, 1024, 512) dtype=float16 iters=400 sec/iter=0.000206
=== FP8Expert ===
FP8Expert OK tokens=64 h=256 inter=256 experts=4 topk=2 dtype=bfloat16 iters=150 sec/iter=0.002137
FP8Expert OK tokens=128 h=256 inter=256 experts=4 topk=2 dtype=bfloat16 iters=150 sec/iter=0.002126
FP8Expert OK tokens=256 h=256 inter=512 experts=4 topk=2 dtype=bfloat16 iters=150 sec/iter=0.002144
FP8Expert OK tokens=256 h=512 inter=512 experts=8 topk=2 dtype=bfloat16 iters=150 sec/iter=0.004046
FP8Expert OK tokens=64 h=256 inter=256 experts=4 topk=2 dtype=float16 iters=150 sec/iter=0.002133
FP8Expert OK tokens=128 h=256 inter=256 experts=4 topk=2 dtype=float16 iters=150 sec/iter=0.002140
FP8Expert OK tokens=256 h=256 inter=512 experts=4 topk=2 dtype=float16 iters=150 sec/iter=0.002143
FP8Expert OK tokens=256 h=512 inter=512 experts=8 topk=2 dtype=float16 iters=150 sec/iter=0.004019
All OK
I don't see any reason a sync would be necessary, all the ops are GPU ops, there is no CPU interaction which is when you would need a sync. Is it possible the nans were there regardless of the sync and were addressed by #43154?
| # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the | ||
| # preceding operations are ready before proceeding | ||
| torch_accelerator_module.synchronize() |
There was a problem hiding this comment.
Can you comment out the synchronizations instead of removing them, in case we face the NaN issue again in the future? Also, can you check that we get a correct generation with a model like Qwen FP8 or DeepSeek? Thanks!
The synchronizations are unnecessary and kill performance. Before and after traces


cc @SunMarc