flash_attention_v2_backward #10495
Conversation
Speed stats:
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
View latest API docs preview at: https://oneflow-staging.oss-cn-beijing.aliyuncs.com/docs/Oneflow-Inc/oneflow/pr/10495/
auto grad_k_ = (*output_)[1];
auto grad_v_ = (*output_)[2];

// auto grad_q_padded = JUST(functional::Transpose(grad_q_, {0, 2, 1, 3}));
Delete this commented-out line.
if dtype == flow.float16:
    test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))
    error_tol = 1e-2
Isn't tolerance=1e-2 too loose for float16?
This 1e-2 follows the reference check:
test_case.assertTrue(np.allclose(ref_out, fused_out, atol=1e-2, rtol=1e-2))
and in testing, tightening it to 1e-3 did not pass either.
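As background for the tolerance discussion, here is a minimal standalone sketch (not the PR's actual test; shapes and names are illustrative) of why a float16 kernel compared against a float32 reference usually cannot meet a 1e-3 tolerance: each float16 element carries only a 10-bit mantissa, and those per-element rounding errors accumulate over the reduction dimension.

```python
import numpy as np

# float16 keeps a 10-bit mantissa, so its machine epsilon is 2**-10,
# i.e. only ~3 decimal digits of precision per element.
eps16 = np.finfo(np.float16).eps
print(eps16)  # 0.0009765625

# Quantization error alone: round float32 inputs through float16, then
# accumulate a 64-element dot product in float32 (so the accumulation
# itself adds no extra error).
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 64)).astype(np.float32)
k = rng.standard_normal((8, 64)).astype(np.float32)

q16 = q.astype(np.float16).astype(np.float32)
k16 = k.astype(np.float16).astype(np.float32)

# Largest mismatch against the float32 reference matmul; per-element
# errors of order eps16 compound across the 64-wide reduction, so the
# result routinely exceeds 1e-3 while staying well under 1e-1.
err = np.abs(q @ k.T - q16 @ k16.T).max()
print(err)
```

A real fused fp16 kernel also accumulates (at least partially) in reduced precision and uses far longer reductions than 64, so atol/rtol around 1e-2 is a common choice for such comparisons.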
flash attention v2 backward operator
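For reviewers less familiar with what the fused backward kernel computes, here is a hedged NumPy sketch of the standard (unfused) attention backward math that a flash-attention-v2 backward operator implements; all names and shapes are illustrative, not OneFlow's API, and the real kernel recomputes the softmax blockwise instead of materializing it.

```python
import numpy as np

def attn_fwd(q, k, v):
    """Reference attention forward: softmax(q k^T / sqrt(d)) v."""
    d = q.shape[-1]
    s = q @ k.T / np.sqrt(d)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)  # row-wise softmax over keys
    return p @ v, p

def attn_bwd(q, k, v, p, do):
    """Analytic gradients (dq, dk, dv) given upstream gradient do."""
    d = q.shape[-1]
    dv = p.T @ do
    dp = do @ v.T
    # softmax backward: ds_ij = p_ij * (dp_ij - sum_k dp_ik * p_ik)
    ds = p * (dp - (dp * p).sum(axis=-1, keepdims=True))
    dq = ds @ k / np.sqrt(d)
    dk = ds.T @ q / np.sqrt(d)
    return dq, dk, dv

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 8)) for _ in range(3))
do = rng.standard_normal((4, 8))  # upstream gradient of the output

out, p = attn_fwd(q, k, v)
dq, dk, dv = attn_bwd(q, k, v, p, do)

# Sanity check dq[0, 0] against a one-sided finite difference of the
# scalar loss L = sum(out * do); the two should agree closely.
eps = 1e-6
qp = q.copy()
qp[0, 0] += eps
num = ((attn_fwd(qp, k, v)[0] - out) * do).sum() / eps
print(abs(num - dq[0, 0]))
```

A fused test like the one in this PR compares gradients from a kernel implementing this math against such a reference, which is where the float16 tolerance question above comes from.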