Backward register #423
Conversation
Force-pushed from 9f79739 to 01bee17
affine: tl.constexpr,
input_grad_mask: tl.constexpr,
weight_grad_mask: tl.constexpr,
bias_grad_mask: tl.constexpr,
The backward kernel may also need an is_train arg to distinguish between the train and non-train cases.
We can leave it for future work, though.
It's a bit complex; I'll fix it later.
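For context, a hypothetical sketch of how an is_train flag could be threaded through the backward kernel as a compile-time constant. Only the four constexpr flags shown in the diff above come from the PR; the kernel name, the is_train parameter, and the body are invented for illustration:

```python
import triton
import triton.language as tl


@triton.jit
def norm_backward_kernel(
    # pointer and stride arguments elided for brevity
    affine: tl.constexpr,
    input_grad_mask: tl.constexpr,
    weight_grad_mask: tl.constexpr,
    bias_grad_mask: tl.constexpr,
    is_train: tl.constexpr,  # hypothetical flag, not in the PR's diff
):
    # With is_train as a constexpr, the branch below is resolved at
    # compile time: training would read the saved batch statistics,
    # inference would read the running statistics instead.
    if is_train:
        pass  # read save_mean / save_invstd
    else:
        pass  # read running_mean / running_var
```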
running_var=None,
save_mean=None,
save_invstd=None,
train=False,
The kernel should also be able to handle the train=True case.
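A hedged sketch of the wrapper-side dispatch on train that this comment asks for. The argument names save_mean, save_invstd, and train follow the snippet above; the helper name, the running_mean argument, and the eps default are assumptions, not the PR's actual code:

```python
import torch


def pick_backward_stats(
    running_mean=None,
    running_var=None,
    save_mean=None,
    save_invstd=None,
    train=False,
    eps=1e-5,
):
    # Training: reuse the per-batch statistics saved by the forward pass.
    if train:
        return save_mean, save_invstd
    # Inference: fall back to the running statistics kept by the module.
    return running_mean, torch.rsqrt(running_var + eps)
```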
src/flag_gems/ops/dropout.py (Outdated)

def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)


def dropout(input, p, train):
The train arg is optional.
done
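A minimal sketch of the requested signature, assuming defaults that mirror torch.nn.functional.dropout and assuming NativeDropout.apply returns an (output, mask) pair like aten::native_dropout; the exact defaults and return handling in the merged code may differ:

```python
def dropout(input, p=0.5, train=True):
    # p and train now carry defaults, so callers may omit them.
    if not train or p == 0.0:
        return input
    out, _mask = NativeDropout.apply(input, p, train)
    return out
```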
logging.debug("GEMS NATIVE DROPOUT FORWARD")
assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
device = input.device
input = input.contiguous()
Add a note that we'll remove contiguous enforcement in the future.
done
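A small sketch of what such a note could look like; the helper name is hypothetical and the exact comment wording in the PR may differ:

```python
import torch


def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
    # TODO: drop this enforcement once the kernels accept non-contiguous
    # inputs; until then, copy only when a copy is actually needed.
    return t if t.is_contiguous() else t.contiguous()
```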
indices = indices.contiguous()
weight = weight.contiguous()
Track this refactor in the TODOs.
done
src/flag_gems/ops/groupnorm.py (Outdated)

mean = mean.contiguous()
rstd = rstd.contiguous()
weight = None if weight is None else weight.contiguous()
group_size = C // group
Should this use cdiv?
fixed.
src/flag_gems/ops/groupnorm.py (Outdated)

BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
BLOCK_HW_SIZE=triton.next_power_of_2(HW),
HxW,
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // group),
Should this be cdiv(C, group)?
ditto
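For reference, a small sketch of the ceiling division the two comments above suggest, using the C, group, and HW names from the snippets; the concrete shapes are invented for illustration:

```python
import triton

C, group, HW = 34, 8, 49  # example shapes where C is not divisible by group

group_size = triton.cdiv(C, group)                     # 5 rather than 34 // 8 == 4
BLOCK_GROUP_SIZE = triton.next_power_of_2(group_size)  # 8
BLOCK_HW_SIZE = triton.next_power_of_2(HW)             # 64
```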
src/flag_gems/ops/dropout.py (Outdated)

def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)


def dropout(input, p, train):
I realized we didn't handle the train=False case correctly in the previous version. Let's fix that.
done.
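A hedged sketch of one way the train=False path can be handled, mirroring aten::native_dropout, which skips the random mask in eval mode; the actual implementation launches a Triton kernel, which the last two lines only stand in for:

```python
import torch


def native_dropout(x, p=0.5, train=True):
    # Eval mode (or p == 0): dropout is the identity, so skip the kernel
    # and return an all-True mask alongside the untouched values.
    if not train or p == 0.0:
        return x.clone(), torch.ones_like(x, dtype=torch.bool)
    # Training mode: stand-in for the Triton kernel's semantics.
    mask = torch.rand_like(x) >= p
    return x * mask / (1.0 - p), mask
```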
Force-pushed from cdcef25 to 0eb24e2
Force-pushed from bd86725 to 1cc1ab5
"mean": 1, | ||
"sum": 2, | ||
} | ||
|
I suggest using torch's torch/nn/_reduction.py for this mapping.
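The referenced helper maps reduction strings to the same integers as the hand-written dict; a short sketch of its behavior (note that torch.nn._reduction is a private module, which is the main trade-off of reusing it):

```python
from torch.nn import _reduction as _Reduction

# 'none' -> 0, 'mean' -> 1, 'sum' -> 2
assert _Reduction.get_enum("none") == 0
assert _Reduction.get_enum("mean") == 1
assert _Reduction.get_enum("sum") == 2
```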
with flag_gems.use_gems():
    res_out = torch.sigmoid_(inp * 1.0)
    res_out = torch.sigmoid_(res_inp)
Why was there inp * 1.0?
I don't know either.
LGTM
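A plausible reading is that inp * 1.0 existed only to create a fresh tensor so the in-place sigmoid_ would not mutate the shared test input; cloning makes that intent explicit. The snippet below is an illustrative rewrite, not the PR's actual test:

```python
import torch
import flag_gems

inp = torch.randn(128, 128, device="cuda")
res_inp = inp.clone()  # writable copy, keeps the original input intact

with flag_gems.use_gems():
    res_out = torch.sigmoid_(res_inp)
```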
… directory, which are registered as AutogradCUDA before
…um to convert reduction string to integer
PR Category: Operator
Type of Change: New Feature
Description:
- Register backward functions as aten interfaces (see the registration sketch below).
- Implement the threshold operator incidentally.
Issue
Progress
Performance
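To illustrate the first description item, a hypothetical torch.library registration sketch; the operator names and the Python functions being registered are placeholders, not the PR's exact code:

```python
import torch

aten_lib = torch.library.Library("aten", "IMPL")

# Previously: forward + backward were bundled inside an autograd.Function
# and registered under the AutogradCUDA key. After this PR, the backward
# is exposed as its own aten implementation under the CUDA key, so
# PyTorch's autograd engine wires forward and backward together itself.
aten_lib.impl("native_dropout", native_dropout, "CUDA")
aten_lib.impl("native_dropout_backward", native_dropout_backward, "CUDA")
```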