
update op of cat #769


Merged
iclementine merged 6 commits into master from op/cat on Jul 14, 2025

Conversation

meinie0826 (Collaborator) commented Jul 10, 2025

PR Category

Type of Change

Description

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

benchmark/test_tensor_concat_perf.py 
Operator: cat  Performance Test (dtype=torch.float16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.006832            0.005984               1.142          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': 0})
SUCCESS               0.006656            0.005952               1.118          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': -1})
SUCCESS               0.007136            0.006464               1.104          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': 0})
SUCCESS               0.007168            0.006464               1.109          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': -1})
SUCCESS               0.012192            0.007360               1.657          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': 0})
SUCCESS               0.008352            0.007360               1.135          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': -1})
SUCCESS               0.009344            0.008448               1.106          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': 0})
SUCCESS               0.009376            0.008384               1.118          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': -1})
SUCCESS               0.011488            0.010496               1.095          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': 0})
SUCCESS               0.011616            0.010624               1.093          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': -1})
SUCCESS               0.006816            0.006304               1.081          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': 0})
SUCCESS               0.006880            0.006080               1.132          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': -1})
SUCCESS               0.006976            0.006240               1.118          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': 0})
SUCCESS               0.007200            0.006240               1.154          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': -1})
SUCCESS               0.009344            0.008480               1.102          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': 0})
SUCCESS               0.009376            0.008448               1.110          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': -1})
SUCCESS               0.006816            0.005984               1.139          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': 0})
SUCCESS               0.006656            0.006112               1.089          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': -1})
SUCCESS               0.007104            0.006400               1.110          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': 0})
SUCCESS               0.007168            0.006592               1.087          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': -1})


Operator: cat  Performance Test (dtype=torch.float32, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.006880            0.006144               1.120          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': 0})
SUCCESS               0.006912            0.006080               1.137          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': -1})
SUCCESS               0.007392            0.006816               1.085          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': 0})
SUCCESS               0.007584            0.006688               1.134          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': -1})
SUCCESS               0.009152            0.008640               1.059          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': 0})
SUCCESS               0.009344            0.008448               1.106          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': -1})
SUCCESS               0.011360            0.010528               1.079          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': 0})
SUCCESS               0.011264            0.010336               1.090          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': -1})
SUCCESS               0.017792            0.014912               1.193          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': 0})
SUCCESS               0.015840            0.014880               1.065          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': -1})
SUCCESS               0.006848            0.006944               0.986          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': 0})
SUCCESS               0.006688            0.006720               0.995          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': -1})
SUCCESS               0.007104            0.006464               1.099          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': 0})
SUCCESS               0.007296            0.006368               1.146          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': -1})
SUCCESS               0.011360            0.010720               1.060          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': 0})
SUCCESS               0.011360            0.010464               1.086          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': -1})
SUCCESS               0.006752            0.006080               1.111          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': 0})
SUCCESS               0.006816            0.006752               1.009          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': -1})
SUCCESS               0.007808            0.006816               1.146          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': 0})
SUCCESS               0.007968            0.006880               1.158          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': -1})


Operator: cat  Performance Test (dtype=torch.bfloat16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.006656            0.005952               1.118          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': 0})
SUCCESS               0.006656            0.005984               1.112          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': -1})
SUCCESS               0.007136            0.006400               1.115          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': 0})
SUCCESS               0.007104            0.006400               1.110          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': -1})
SUCCESS               0.008448            0.007488               1.128          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': 0})
SUCCESS               0.008448            0.007488               1.128          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': -1})
SUCCESS               0.009568            0.008416               1.137          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': 0})
SUCCESS               0.009504            0.008384               1.134          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': -1})
SUCCESS               0.011776            0.010752               1.095          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': 0})
SUCCESS               0.011808            0.010752               1.098          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': -1})
SUCCESS               0.006816            0.006272               1.087          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': 0})
SUCCESS               0.006880            0.006080               1.132          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': -1})
SUCCESS               0.007168            0.006240               1.149          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': 0})
SUCCESS               0.007008            0.006240               1.123          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': -1})
SUCCESS               0.009568            0.008416               1.137          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': 0})
SUCCESS               0.009504            0.008352               1.138          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': -1})
SUCCESS               0.006784            0.006048               1.122          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': 0})
SUCCESS               0.006656            0.006208               1.072          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': -1})
SUCCESS               0.007104            0.006400               1.110          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': 0})
SUCCESS               0.007392            0.006592               1.121          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': -1})


Operator: cat  Performance Test (dtype=torch.int16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.006656            0.005984               1.112          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': 0})
SUCCESS               0.006656            0.006016               1.106          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': -1})
SUCCESS               0.007296            0.006400               1.140          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': 0})
SUCCESS               0.007296            0.006400               1.140          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': -1})
SUCCESS               0.008480            0.007456               1.137          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': 0})
SUCCESS               0.008448            0.007456               1.133          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': -1})
SUCCESS               0.009536            0.008416               1.133          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': 0})
SUCCESS               0.009280            0.008352               1.111          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': -1})
SUCCESS               0.011744            0.010752               1.092          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': 0})
SUCCESS               0.011584            0.010720               1.081          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': -1})
SUCCESS               0.006624            0.006336               1.045          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': 0})
SUCCESS               0.006656            0.006048               1.101          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': -1})
SUCCESS               0.007136            0.006208               1.149          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': 0})
SUCCESS               0.007200            0.006208               1.160          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': -1})
SUCCESS               0.009568            0.008384               1.141          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': 0})
SUCCESS               0.009472            0.008352               1.134          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': -1})
SUCCESS               0.006656            0.006016               1.106          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': 0})
SUCCESS               0.006656            0.030208               0.220          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': -1})
SUCCESS               0.007520            0.006368               1.181          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': 0})
SUCCESS               0.007168            0.006592               1.087          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': -1})


Operator: cat  Performance Test (dtype=torch.int32, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.006880            0.006176               1.114          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': 0})
SUCCESS               0.006656            0.006080               1.095          ([[torch.Size([64, 64]), torch.Size([64, 64]), torch.Size([64, 64])]], {'dim': -1})
SUCCESS               0.007584            0.006784               1.118          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': 0})
SUCCESS               0.007584            0.006688               1.134          ([[torch.Size([256, 256]), torch.Size([256, 256]), torch.Size([256, 256])]], {'dim': -1})
SUCCESS               0.009088            0.008512               1.068          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': 0})
SUCCESS               0.009248            0.008320               1.112          ([[torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])]], {'dim': -1})
SUCCESS               0.011584            0.010688               1.084          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': 0})
SUCCESS               0.011552            0.010464               1.104          ([[torch.Size([512, 1024]), torch.Size([512, 1024]), torch.Size([512, 1024])]], {'dim': -1})
SUCCESS               0.015552            0.014816               1.050          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': 0})
SUCCESS               0.015968            0.014944               1.069          ([[torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])]], {'dim': -1})
SUCCESS               0.006656            0.006944               0.959          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': 0})
SUCCESS               0.006688            0.006688               1.000          ([[torch.Size([1024, 2]), torch.Size([1024, 2]), torch.Size([1024, 2])]], {'dim': -1})
SUCCESS               0.007136            0.006464               1.104          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': 0})
SUCCESS               0.007104            0.006368               1.116          ([[torch.Size([1024, 32]), torch.Size([1024, 32]), torch.Size([1024, 32])]], {'dim': -1})
SUCCESS               0.011360            0.010688               1.063          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': 0})
SUCCESS               0.011552            0.010496               1.101          ([[torch.Size([1024, 512]), torch.Size([1024, 512]), torch.Size([1024, 512])]], {'dim': -1})
SUCCESS               0.006976            0.006176               1.130          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': 0})
SUCCESS               0.007040            0.006720               1.048          ([[torch.Size([64, 64, 1]), torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]], {'dim': -1})
SUCCESS               0.007584            0.006816               1.113          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': 0})
SUCCESS               0.007744            0.006944               1.115          ([[torch.Size([64, 64, 16]), torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]], {'dim': -1})
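
For reference, a minimal sketch (not the project's benchmark harness) of how one might reproduce a comparison like the rows above. It assumes the flag_gems.use_gems() context manager and triton.testing.do_bench are available; names and sizes here are illustrative only:

import torch
import flag_gems
from triton.testing import do_bench

# Three equally sized float16 tensors, matching one of the rows above.
tensors = [torch.randn(512, 512, device="cuda", dtype=torch.float16) for _ in range(3)]

# Latency in ms for the eager torch.cat.
torch_ms = do_bench(lambda: torch.cat(tensors, dim=0))

# Same call with aten ops dispatched to the Triton implementations.
with flag_gems.use_gems():
    gems_ms = do_bench(lambda: torch.cat(tensors, dim=0))

print(f"torch {torch_ms:.6f} ms | gems {gems_ms:.6f} ms | speedup {torch_ms / gems_ms:.3f}")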

iclementine (Collaborator) commented:

I think this modification no longer supports input tensors that are not contiguous.

meinie0826 (Collaborator, Author) replied:

I think this modification no longer supports input tensors that are not contiguous.

Thank you for the suggestion. I have updated the code, but I am not sure whether it meets your requirements; if convenient, please share further suggestions.
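
A minimal sketch of the kind of input the concern covers: a transposed view is a common non-contiguous tensor, and cat must still read its strides correctly rather than assume a flat layout. This is plain PyTorch, independent of the PR's code:

import torch

x = torch.randn(64, 128, device="cuda")
y = x.t()                      # a view; y.is_contiguous() is False
out = torch.cat([y, y], dim=0)

assert not y.is_contiguous()
# The result must match cat over explicitly contiguous copies.
assert torch.equal(out, torch.cat([y.contiguous(), y.contiguous()], dim=0))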

iclementine previously approved these changes Jul 11, 2025
iclementine self-assigned this Jul 11, 2025
StrongSpoon previously approved these changes Jul 11, 2025
StrongSpoon (Collaborator) left a comment:

lgtm


for j in range(4):
    if j < num_tensors_in_batch:
        tensor = tensors_in_batch[j].contiguous()
StrongSpoon (Collaborator) commented on the lines above:

I recommend reusing flag_gems.contiguous to ensure that the Triton kernel runs even when it is called explicitly.

meinie0826 (Collaborator, Author) replied Jul 11, 2025:

Where is flag_gems.contiguous? I haven't seen it used anywhere.
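
The thread leaves this question open. Purely as a hypothetical illustration of what the reviewer may mean — routing the copy through a Triton kernel so the operation stays on the Triton path even when invoked explicitly — a 2-D sketch could look like the following. The helper's name and its existence in flag_gems are assumptions, not the project's confirmed API:

import torch
import triton
import triton.language as tl

@triton.jit
def _strided_copy_2d(src_ptr, dst_ptr, M, N, stride_m, stride_n, BLOCK: tl.constexpr):
    # Read with the source strides, write a flat (contiguous) layout.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < M * N
    row = offs // N
    col = offs % N
    val = tl.load(src_ptr + row * stride_m + col * stride_n, mask=mask)
    tl.store(dst_ptr + offs, val, mask=mask)

def contiguous(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: no-op for contiguous inputs, Triton copy otherwise.
    if t.is_contiguous():
        return t
    assert t.dim() == 2, "sketch covers the 2-D case only"
    out = torch.empty(t.shape, device=t.device, dtype=t.dtype)
    grid = (triton.cdiv(t.numel(), 1024),)
    _strided_copy_2d[grid](t, out, t.shape[0], t.shape[1], t.stride(0), t.stride(1), BLOCK=1024)
    return out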

@triton.jit
def copy_func(x):
    return x

def copy_func_kernel(
StrongSpoon (Collaborator) commented on the lines above:

Is the copy_func_kernel still useful? If not, I suggest deleting it.

meinie0826 (Collaborator, Author) replied:

Sorry, I didn't notice that. I will remove it!

meinie0826 dismissed stale reviews from StrongSpoon and iclementine via 4296a36 on July 11, 2025 09:15
iclementine merged commit ed2db4a into master on Jul 14, 2025
10 of 14 checks passed
iclementine deleted the op/cat branch on July 14, 2025 08:42