
flash_attention_v2_backward #10495

Merged: 5 commits into master, May 6, 2024

Conversation

cccddd77 (Contributor):

flash attention v2 backward operator
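For context, a minimal numpy sketch of the math a scaled-dot-product-attention backward has to compute for a single head. This is a naive reference only, not the fused kernel added in this PR and not OneFlow API; `q`, `k`, `v`, `do`, and the function name are hypothetical:

```python
import numpy as np

def attention_fwd_bwd(q, k, v, do, scale):
    """Naive single-head reference for what a fused attention backward computes."""
    # Forward: S = QK^T * scale, P = softmax(S), O = PV.
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    o = p @ v

    # Backward, given the upstream gradient do = dL/dO:
    dv = p.T @ do                                         # dV = P^T dO
    dp = do @ v.T                                         # dP = dO V^T
    ds = p * (dp - (dp * p).sum(axis=-1, keepdims=True))  # softmax backward
    dq = (ds @ k) * scale                                 # dQ = dS K * scale
    dk = (ds.T @ q) * scale                               # dK = dS^T Q * scale
    return o, dq, dk, dv

# Example: one head, seq_len=128, head_dim=64.
rng = np.random.default_rng(0)
q, k, v, do = (rng.standard_normal((128, 64)) for _ in range(4))
o, dq, dk, dv = attention_fwd_bwd(q, k, v, do, scale=64 ** -0.5)
```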

Contributor:

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@cccddd77 requested review from oneflow-ci-bot and removed the request for oneflow-ci-bot on April 23, 2024 at 00:53
Contributor commented on the diff:

// Unpack gradient outputs of the backward op; (*output_)[0] holds grad_q_.
auto grad_k_ = (*output_)[1];
auto grad_v_ = (*output_)[2];

// auto grad_q_padded = JUST(functional::Transpose(grad_q_, {0, 2, 1, 3}));
Contributor:

Delete this commented-out line.
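(For reference, a minimal numpy sketch of what the commented-out `{0, 2, 1, 3}` permutation would do, assuming a hypothetical `(batch, num_heads, seq_len, head_dim)` layout:)

```python
import numpy as np

# Hypothetical gradient tensor in (batch, num_heads, seq_len, head_dim) layout.
grad_q = np.zeros((2, 8, 128, 64))

# Permutation (0, 2, 1, 3) swaps the head and sequence axes:
# (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim).
grad_q_padded = np.transpose(grad_q, (0, 2, 1, 3))
assert grad_q_padded.shape == (2, 128, 8, 64)
```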

if dtype == flow.float16:
    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))
    error_tol = 1e-2
Contributor:

Isn't a tolerance of 1e-2 too large for float16?

cccddd77 (Contributor, Author):

This 1e-2 was taken from the existing reference:

test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))

and in testing, a tolerance of 1e-3 also failed to pass.
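A minimal sketch of the dtype-dependent tolerance pattern under discussion; the helper name and the float32 tolerance are illustrative, not from this PR:

```python
import numpy as np

def check_close(test_case, ref_out, fused_out, dtype):
    # float16 kernels accumulate noticeably more rounding error than float32,
    # so they get a looser tolerance (1e-2, per the discussion above).
    tol = 1e-2 if dtype == np.float16 else 1e-5  # float32 value is illustrative
    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=tol, rtol=tol))
```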

github-actions bot commented May 6, 2024:

Speed stats:
GPU Name: NVIDIA GeForce RTX 3080 Ti 

❌ OneFlow resnet50 time: 43.7ms (= 4371.3ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 58.0ms (= 5797.6ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.33 (= 58.0ms / 43.7ms)

OneFlow resnet50 time: 26.2ms (= 2616.5ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 38.1ms (= 3812.5ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.46 (= 38.1ms / 26.2ms)

OneFlow resnet50 time: 19.7ms (= 3932.1ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.3ms (= 7060.7ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.80 (= 35.3ms / 19.7ms)

OneFlow resnet50 time: 17.9ms (= 3571.5ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 31.5ms (= 6297.8ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.76 (= 31.5ms / 17.9ms)

OneFlow resnet50 time: 16.8ms (= 3353.0ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 29.5ms (= 5903.2ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.76 (= 29.5ms / 16.8ms)

OneFlow swin dataloader time: 0.201s (= 40.171s / 200, num_workers=1)
PyTorch swin dataloader time: 0.127s (= 25.467s / 200, num_workers=1)
Relative speed: 0.634 (= 0.127s / 0.201s)

OneFlow swin dataloader time: 0.054s (= 10.830s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.583s / 200, num_workers=4)
Relative speed: 0.608 (= 0.033s / 0.054s)

OneFlow swin dataloader time: 0.031s (= 6.216s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.320s / 200, num_workers=8)
Relative speed: 0.534 (= 0.017s / 0.031s)

❌ OneFlow resnet50 time: 49.2ms (= 4924.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 65.9ms (= 6586.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.34 (= 65.9ms / 49.2ms)

OneFlow resnet50 time: 36.4ms (= 3638.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 47.1ms (= 4710.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.29 (= 47.1ms / 36.4ms)

OneFlow resnet50 time: 27.9ms (= 5587.4ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 42.5ms (= 8501.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.52 (= 42.5ms / 27.9ms)

OneFlow resnet50 time: 25.5ms (= 5100.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 39.0ms (= 7800.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.53 (= 39.0ms / 25.5ms)

OneFlow resnet50 time: 24.5ms (= 4901.7ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 35.7ms (= 7149.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.46 (= 35.7ms / 24.5ms)

@MARD1NO merged commit ea585f6 into master on May 6, 2024. 20 checks passed.
@MARD1NO deleted the flash_attention_v2_backward branch on May 6, 2024 at 10:30.