
update hstack #775


Merged: 1 commit merged into master on Jul 11, 2025
Conversation

meinie0826 (Collaborator)

PR Category

Type of Change

Description

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT (see the test sketch below).
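
As a rough illustration of the kind of unit test the checklist refers to, here is a minimal sketch that relies only on the public flag_gems.use_gems() context manager; the test name and parametrization are assumptions, not the actual FlagGems test suite.

import pytest
import torch
import flag_gems

@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
@pytest.mark.parametrize("shape", [(64, 64), (512, 1024), (64, 64, 16)])
def test_hstack_matches_torch(shape, dtype):
    # Build three equally shaped CUDA inputs and compare FlagGems against eager PyTorch.
    tensors = [torch.randn(shape, dtype=dtype, device="cuda") for _ in range(3)]
    ref = torch.hstack(tensors)
    with flag_gems.use_gems():  # dispatch torch.hstack to the FlagGems implementation
        out = torch.hstack(tensors)
    torch.testing.assert_close(out, ref)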

Performance

benchmark/test_tensor_concat_perf.py 
Operator: hstack  Performance Test (dtype=torch.float16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.019616            0.005952               3.296          [[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]]
SUCCESS               0.007072            0.006336               1.116          [[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]]
SUCCESS               0.008352            0.007360               1.135          [[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]]
SUCCESS               0.009536            0.008352               1.142          [[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]]
SUCCESS               0.011776            0.010560               1.115          [[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]]
SUCCESS               0.006656            0.006016               1.106          [[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]]
SUCCESS               0.006944            0.006176               1.124          [[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]]
SUCCESS               0.009344            0.008480               1.102          [[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]]
SUCCESS               0.006560            0.005984               1.096          [[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]]
SUCCESS               0.007232            0.006336               1.141          [[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]]


Operator: hstack  Performance Test (dtype=torch.float32, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.006848            0.005984               1.144          [[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]]
SUCCESS               0.007328            0.006624               1.106          [[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]]
SUCCESS               0.009312            0.008384               1.111          [[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]]
SUCCESS               0.011424            0.010352               1.104          [[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]]
SUCCESS               0.015712            0.014880               1.056          [[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]]
SUCCESS               0.006880            0.006656               1.034          [[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]]
SUCCESS               0.007072            0.006304               1.122          [[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]]
SUCCESS               0.011520            0.010464               1.101          [[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]]
SUCCESS               0.006752            0.005984               1.128          [[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]]
SUCCESS               0.007712            0.006688               1.153          [[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]]


Operator: hstack  Performance Test (dtype=torch.bfloat16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.006592            0.005984               1.102          [[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]]
SUCCESS               0.007264            0.006336               1.146          [[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]]
SUCCESS               0.008256            0.007424               1.112          [[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]]
SUCCESS               0.009440            0.008352               1.130          [[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]]
SUCCESS               0.011584            0.010720               1.081          [[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]]
SUCCESS               0.006656            0.006016               1.106          [[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]]
SUCCESS               0.006976            0.006208               1.124          [[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]]
SUCCESS               0.009472            0.008352               1.134          [[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]]
SUCCESS               0.006752            0.005984               1.128          [[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]]
SUCCESS               0.007232            0.006336               1.141          [[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]]
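
For context, the Gems latencies above come from routing torch.hstack through FlagGems. Below is a minimal sketch of how a single case could be timed by hand; it uses only torch.hstack, CUDA events, and flag_gems.use_gems(), and is an approximation of the harness rather than the actual benchmark/test_tensor_concat_perf.py code.

import torch
import flag_gems

# Three equally shaped CUDA tensors, matching one of the benchmark cases above.
tensors = [torch.randn(512, 1024, dtype=torch.float16, device="cuda") for _ in range(3)]

def time_ms(fn, iters=100):
    # Rough CUDA-event timing; the real harness may warm up and average differently.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

torch_ms = time_ms(lambda: torch.hstack(tensors))
with flag_gems.use_gems():  # route torch.hstack to the FlagGems kernel
    gems_ms = time_ms(lambda: torch.hstack(tensors))
print(f"torch {torch_ms:.6f} ms | gems {gems_ms:.6f} ms | speedup {torch_ms / gems_ms:.3f}")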

@meinie0826 force-pushed the op/hstack branch 2 times, most recently from 428c681 to 8ef6853, on July 11, 2025 08:08
@iclementine (Collaborator) left a comment:

LGTM

@tongxin (Contributor) commented on Jul 11, 2025:

Cool. That's one more operator optimized in one day!

@@ -55,19 +113,93 @@ def hstack(
for tensor number {tensor_num + 1} in the list."
A contributor commented on the diff:

Looks like the shape checking can be further polished.

if tensor.select(dim, 0).shape != inp0.select(dim, 0).shape:
    raise RuntimeError(f'Sizes of tensors must match except in dimension {dim}. ')

meinie0826 (Collaborator, Author) replied:

I didn't review the check code carefully. I assumed that, since the previous PR was merged, it already covers the scope of this check. If you have any suggestions, feel free to make the modifications directly.
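
For illustration only (not part of the PR), here is a minimal sketch of how the suggested check could sit inside an hstack-style wrapper; the function name and the cat-based fallback are hypothetical, and the real FlagGems implementation may differ.

import torch

def hstack_checked(tensors):
    # hstack concatenates along dim 1 for >=2-D inputs and along dim 0 for 1-D inputs.
    inp0 = tensors[0]
    dim = 1 if inp0.dim() > 1 else 0
    for tensor in tensors[1:]:
        # Comparing the slices' shapes checks every dimension except the concat dim.
        if tensor.select(dim, 0).shape != inp0.select(dim, 0).shape:
            raise RuntimeError(
                f"Sizes of tensors must match except in dimension {dim}."
            )
    return torch.cat(tensors, dim=dim)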

@meinie0826 merged commit 86e4a7e into master on Jul 11, 2025
10 of 14 checks passed
@meinie0826 deleted the op/hstack branch on July 11, 2025 at 15:30