Backward register #423

Merged
30 commits merged from bwd into master on May 19, 2025

Conversation

@StrongSpoon (Collaborator) commented Jan 16, 2025

PR Category

Operator

Type of Change

New Feature

Description

Register backward functions as ATen interfaces.
Implement the threshold operator along the way.
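
For context, here is a hedged sketch of what registering the new threshold backward as an ATen interface could look like. The kernel, the wrapper, and the torch.library registration below are illustrative assumptions rather than the exact code in this PR; flag_gems has its own registration machinery, and names like threshold_backward_kernel are made up here.

import torch
import triton
import triton.language as tl


@triton.jit
def threshold_backward_kernel(grad_ptr, inp_ptr, out_ptr, n_elements, threshold,
                              BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    grad = tl.load(grad_ptr + offsets, mask=mask)
    inp = tl.load(inp_ptr + offsets, mask=mask)
    # the gradient passes through only where the input exceeded the threshold
    tl.store(out_ptr + offsets, tl.where(inp > threshold, grad, 0.0), mask=mask)


def threshold_backward(grad_output, self, threshold):
    grad_output = grad_output.contiguous()
    self = self.contiguous()
    out = torch.empty_like(self)
    n_elements = self.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    threshold_backward_kernel[grid](grad_output, self, out, n_elements, threshold,
                                    BLOCK_SIZE=1024)
    return out


# hypothetical registration: override aten::threshold_backward for the CUDA dispatch key
aten_lib = torch.library.Library("aten", "IMPL")
aten_lib.impl("threshold_backward", threshold_backward, "CUDA")

Once registered this way, torch.ops.aten.threshold_backward dispatches to the Triton wrapper for CUDA tensors; the forward threshold op would be registered analogously.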

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

@StrongSpoon force-pushed the bwd branch 2 times, most recently from 9f79739 to 01bee17, on February 6, 2025 09:26
@StrongSpoon marked this pull request as ready for review February 11, 2025 02:04
affine: tl.constexpr,
input_grad_mask: tl.constexpr,
weight_grad_mask: tl.constexpr,
bias_grad_mask: tl.constexpr,

Contributor: The backward kernel may need an is_train arg as well, to distinguish between the train and non-train cases.

Contributor: We can leave it for future work, though.

Collaborator (Author): It's a bit complex; will fix it later QAQ
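
As a rough illustration of that suggestion, an is_train flag could be passed as a tl.constexpr so the branch is specialized at compile time. This is a toy elementwise kernel with hypothetical names, not the backward kernel from this PR:

import triton
import triton.language as tl


@triton.jit
def scale_grad_kernel(grad_ptr, out_ptr, save_invstd_ptr, running_var_ptr, eps,
                      n_elements, is_train: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    grad = tl.load(grad_ptr + offsets, mask=mask)
    if is_train:
        # training: reuse the inverse std saved by the forward pass
        invstd = tl.load(save_invstd_ptr + offsets, mask=mask)
    else:
        # inference: recompute it from the running variance
        var = tl.load(running_var_ptr + offsets, mask=mask)
        invstd = 1.0 / tl.sqrt(var + eps)
    tl.store(out_ptr + offsets, grad * invstd, mask=mask)

Because is_train is a constexpr, each specialization compiles only one of the two branches, so the eval path pays no runtime cost for the check.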

running_var=None,
save_mean=None,
save_invstd=None,
train=False,

Contributor: The kernel should be able to handle the train=True case.
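
For reference, the usual contract is that training mode consumes the batch statistics saved by the forward pass (save_mean, save_invstd), while eval mode falls back to the running statistics. A small hedged sketch of that selection with hypothetical names; the full backward also needs the extra training-mode gradient terms and the output mask handling:

import torch


def pick_bn_backward_stats(running_mean, running_var, save_mean, save_invstd,
                           train, eps=1e-5):
    # sketch: choose which statistics the backward kernel should consume
    if train:
        # training: per-batch mean and inverse std saved by the forward pass
        return save_mean, save_invstd
    # inference: derive them from the running statistics tracked by the module
    return running_mean, torch.rsqrt(running_var + eps)

Only the statistics fed to the kernel differ between the two modes; the kernel launch itself can stay the same.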


def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)


def dropout(input, p, train):

Contributor: Arg train is optional.

Collaborator (Author): Done.
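
One way to make it optional at the wrapper level is to give it (and p) the same defaults as torch.nn.functional.dropout. A hedged sketch, assuming NativeDropout from the diff above returns (output, mask) the way aten::native_dropout does:

def dropout(input, p=0.5, train=True):
    # defaults let existing call sites omit p and train;
    # the autograd function still receives both explicitly
    output, _mask = NativeDropout.apply(input, p, train)
    return output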

logging.debug("GEMS NATIVE DROPOUT FORWARD")
assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
device = input.device
input = input.contiguous()

Contributor: Add a note that we'll remove contiguous enforcement in the future.

Collaborator (Author): Done.
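
The requested note amounts to an inline TODO next to the enforcement, for example:

# TODO: remove the contiguous enforcement once non-contiguous inputs are supported
input = input.contiguous()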

Comment on lines +119 to +121
indices = indices.contiguous()
weight = weight.contiguous()

Contributor: Refactor this in TODOs.

Collaborator (Author): Done.

mean = mean.contiguous()
rstd = rstd.contiguous()
weight = None if weight is None else weight.contiguous()
group_size = C // group

Contributor: cdiv?

Collaborator (Author): Fixed.
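
The point of the suggestion is that floor division silently drops the remainder when C is not a multiple of group, whereas ceiling division covers it; the same applies to the BLOCK_GROUP_SIZE computation in the next thread. A quick illustration:

import triton

C, group = 10, 4
print(C // group)             # 2: the trailing partial group would be missed
print(triton.cdiv(C, group))  # 3: rounds up so every channel is covered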

BLOCK_GROUP_SIZE=triton.next_power_of_2(C // num_groups),
BLOCK_HW_SIZE=triton.next_power_of_2(HW),
HxW,
BLOCK_GROUP_SIZE=triton.next_power_of_2(C // group),

Contributor: cdiv(C, group)?

Collaborator (Author): Ditto.


def native_dropout(x, p=0.5, train=True):
    return NativeDropout.apply(x, p, train)


def dropout(input, p, train):

Contributor: I realized we didn't handle train=False correctly in the previous version. Let's fix that.

Collaborator (Author): Done.
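
Handling train=False typically means skipping the kernel launch entirely, since dropout is the identity in eval mode. A hedged sketch of the early-return path, reusing the NativeDropout autograd function from the diff above:

import torch


def native_dropout(x, p=0.5, train=True):
    if not train or p == 0.0:
        # eval mode (or p == 0): no elements are dropped, the mask is all ones,
        # and there is no need to draw random numbers
        return x, torch.ones_like(x, dtype=torch.bool)
    return NativeDropout.apply(x, p, train)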

"mean": 1,
"sum": 2,
}

with flag_gems.use_gems():
    res_out = torch.sigmoid_(inp * 1.0)
    res_out = torch.sigmoid_(res_inp)

Collaborator: Why was there inp * 1.0?

Collaborator (Author): I don't know either.
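
Neither side is sure, but one plausible reading is that inp * 1.0 only served to hand the in-place sigmoid_ a fresh tensor so inp itself stayed untouched for the reference computation; the new res_inp makes that copy explicit. A sketch of the resulting test pattern, with made-up shapes and device:

import torch
import flag_gems

inp = torch.randn(1024, device="cuda")
ref_inp = inp.clone()  # copy mutated by the PyTorch reference
res_inp = inp.clone()  # copy mutated by the flag_gems kernel

ref_out = torch.sigmoid_(ref_inp)
with flag_gems.use_gems():
    res_out = torch.sigmoid_(res_inp)

torch.testing.assert_close(res_out, ref_out)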

iclementine previously approved these changes May 14, 2025

@iclementine (Collaborator) left a comment: LGTM

iclementine previously approved these changes May 15, 2025
@FlagOpen locked and limited conversation to collaborators May 16, 2025
@FlagOpen unlocked this conversation May 16, 2025
@Bowen12992 marked this pull request as draft May 16, 2025 06:41
@Bowen12992 marked this pull request as ready for review May 16, 2025 06:41
@iclementine merged commit 337267c into master May 19, 2025
11 of 15 checks passed
@iclementine deleted the bwd branch May 19, 2025 03:27