
add std_mean op #1971

Open · khushi-411 wants to merge 1 commit into main

Conversation

khushi-411 (Contributor)

As per the title.

khushi-411 marked this pull request as ready for review April 16, 2025 17:55
mruberry requested a review from beverlylytle April 16, 2025 18:17
@mruberry (Collaborator)

@beverlylytle, would you like to review this PR?

@beverlylytle (Collaborator) left a comment


This looks good, thank you! I was looking through PyTorch's issues for std_mean, and I'll leave two thoughts for your consideration:

  1. Did you happen to check inputs containing float('inf')? torch.std_mean returns NaN as the mean of a tensor containing inf (pytorch/pytorch#138570); see the small illustration after this list.
  2. I don't think there's anything to be done here, but I thought it was interesting: torch.std_mean is slower than separate torch.mean and torch.std calls on CPU (pytorch/pytorch#122191).
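
A minimal snippet of the first point, with outputs as described in pytorch/pytorch#138570 (they are taken from the issue report, not verified here):

```python
import torch

# As reported in pytorch/pytorch#138570 (not verified here): the fused op can
# return NaN for the mean of a tensor containing inf, while a standalone
# torch.mean on the same tensor returns inf.
a = torch.tensor([1.0, 2.0, float("inf")])
std, mean = torch.std_mean(a)
print(std, mean)      # reportedly: nan nan
print(torch.mean(a))  # inf
```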

sample_input_generator=std_sample_generator,
error_input_generator=std_error_generator,
torch_reference=torch.std_mean,
dtypes=(datatypes.floating,),
Collaborator

There are checks in the meta for complex types, but they are omitted from testing. I know there are other issues with testing complex types, but were they left out here for a reason?

@kshitij12345 (Collaborator) left a comment


Does this need to be a prim? I think we can just have a decomposition in torch/__init__.py which calls ltorch.mean and ltorch.std. A fusion executor like nvFuser would generate a good kernel. This way, we won't need to add a prim and a grad rule for it.

Wdyt @khushi-411 @beverlylytle?
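
A rough sketch of that idea, assuming the decomposition lives in thunder/torch/__init__.py and mirrors torch.std_mean's signature (the exact ltorch signatures and registration mechanism are assumptions, not code from this PR):

```python
import thunder.torch as ltorch

# Sketch only, not code from this PR. The signature mirrors
# torch.std_mean(input, dim=None, *, correction=1, keepdim=False); the exact
# ltorch.std / ltorch.mean signatures are assumed to follow torch's.
def std_mean(a, dim=None, *, correction=1, keepdim=False):
    # Two independent reductions; a fusion executor such as nvFuser could fuse
    # them into one kernel, so no new prim or grad rule would be needed.
    return (
        ltorch.std(a, dim, correction=correction, keepdim=keepdim),
        ltorch.mean(a, dim, keepdim=keepdim),
    )
```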

@beverlylytle (Collaborator)

beverlylytle commented Apr 23, 2025

@kshitij12345 you make a good point, but var_mean is a primitive with a distinct nvFuser op for a reason, right? In computing the variance, the mean is computed along the way. std is the square root of the variance, so computing std and mean separately would calculate the mean twice. nvFuser doesn't have a separate op for std_mean. What about a decomposition of std_mean into var_mean followed by sqrt?
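
A minimal sketch of that alternative, again assuming ltorch names and signatures that mirror torch (not code from this PR):

```python
import thunder.torch as ltorch

# Sketch only, not code from this PR. Unlike the std + mean version above,
# var_mean computes the mean only once as part of the variance reduction, and
# std is just an elementwise sqrt of the variance.
def std_mean(a, dim=None, *, correction=1, keepdim=False):
    var, mean = ltorch.var_mean(a, dim, correction=correction, keepdim=keepdim)
    return ltorch.sqrt(var), mean
```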

@kshitij12345 (Collaborator)

Good point. That makes sense to me, thanks!

@mruberry (Collaborator)

> @kshitij12345 you make a good point, but var_mean is a primitive with a distinct nvFuser op for a reason, right? In computing the variance, the mean is computed along the way. std is the square root of the variance, so computing std and mean separately would calculate the mean twice. nvFuser doesn't have a separate op for std_mean. What about a decomposition of std_mean into var_mean followed by sqrt?

var_mean is a primitive because there are several ways to compute the variance, and one popular way is Welford's algorithm (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm), which has one- and two-pass variants. We could require that executors use Welford's algorithm by making a primitive for it, and we could even explicitly require the one- or two-pass version if we wanted, but currently we let executors figure out how they should compute the mean and variance.
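
For reference, a plain-Python sketch of Welford's single-pass variant (illustration only; this is not Thunder or PyTorch code):

```python
def welford_var_mean(xs, correction=1):
    # Single-pass (online) Welford update: keep a running mean and the running
    # sum of squared deviations (m2), so the data is read only once.
    count, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # uses the updated mean
    return m2 / (count - correction), mean

# Example: matches the default correction=1 (sample variance).
print(welford_var_mean([1.0, 2.0, 3.0, 4.0]))  # (1.666..., 2.5)
```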

For std_mean, is it OK to take the sqrt after the variance is computed, or is there a similar issue as with the Welford computation? I think it is OK to take it after, and I believe that's what the PyTorch code does:

https://github.com/pytorch/pytorch/blob/b32b002a6ea879e506453b09a4b206632e530abf/aten/src/ATen/native/cuda/ReduceMomentKernel.cu#L14

https://github.com/pytorch/pytorch/blob/b32b002a6ea879e506453b09a4b206632e530abf/aten/src/ATen/native/SharedReduceOps.h#L97

But I could be mistaken.

@beverlylytle (Collaborator)

> We could require that executors use Welford's algorithm by making a primitive for it, and we could even explicitly require the one- or two-pass version if we wanted

I am inclined against being so prescriptive without a hard reason.

While it's possible that executors may want to provide their own implementations of std_mean in the future, there are none (besides the default) doing so now. Thunder can provide an immediate efficiency improvement for std_mean under nvFuser execution with a var_mean + sqrt implementation (which I do believe is OK), rather than making std_mean a primitive. What do you think @khushi-411 ?

@mruberry (Collaborator)

>> We could require that executors use Welford's algorithm by making a primitive for it, and we could even explicitly require the one- or two-pass version if we wanted
>
> I am inclined against being so prescriptive without a hard reason.
>
> While it's possible that executors may want to provide their own implementations of std_mean in the future, there are none (besides the default) doing so now. Thunder can provide an immediate efficiency improvement for std_mean under nvFuser execution with a var_mean + sqrt implementation (which I do believe is OK), rather than making std_mean a primitive. What do you think @khushi-411 ?

Sounds good; executors can also consume the torch operation directly if they have custom std+mean logic.
