Skip to content

Conversation

@anhminhnguyenhoang
Copy link

@anhminhnguyenhoang anhminhnguyenhoang commented Jan 16, 2026

Co-authors: @waqahmed-amd-fi @anhminhnguyenhoang

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@anhminhnguyenhoang anhminhnguyenhoang requested a review from a team January 16, 2026 15:15
@anhminhnguyenhoang anhminhnguyenhoang marked this pull request as draft January 16, 2026 15:15
@anhminhnguyenhoang anhminhnguyenhoang changed the title mHC: Manifold-constrained Hyper Connection [WIP] mHC: Manifold-constrained Hyper Connection Jan 16, 2026
waqahmed-amd-fi and others added 5 commits January 19, 2026 04:57
… kernel (#1877)

* Refactor mHC kernel and wrapper to implement equations 14-18 with fused kernel

* improve comments

* Enhance documentation for mhc function: clarify equations, input/output shapes, and activation details

* Enhance documentation in test_mhc.py: clarify equations, input/output shapes, and activation details for mHC kernel tests

* Add _sinkhorn_knopp_log_domain_kernel to the fusion module

* Add logging and sync Sinkhorn-Knopp function for doubly stochastic matrices

* sync log-domain Sinkhorn-Knopp kernel for doubly stochastic matrix projection

* Improve logging in mhc function to include all alpha parameters
else:
assert out.shape == (M, N), f"Output shape mismatch: expected ({M}, {N}), got {out.shape}"
assert out.dtype == x.dtype, f"Output dtype mismatch: expected {x.dtype}, got {out.dtype}"
assert out.device == x.device, f"Output device mismatch"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F541> reported by reviewdog 🐶
f-string without any placeholders

Suggested change
assert out.device == x.device, f"Output device mismatch"
assert out.device == x.device, "Output device mismatch"


# Res-stream: no constraints (identity activation)
# Just verify it exists
assert out_res.shape == (M, n_squared), f"Res-stream shape mismatch"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F541> reported by reviewdog 🐶
f-string without any placeholders

Suggested change
assert out_res.shape == (M, n_squared), f"Res-stream shape mismatch"
assert out_res.shape == (M, n_squared), "Res-stream shape mismatch"

# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from .mhc_ref import *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F403> reported by reviewdog 🐶
from .mhc_ref import * used; unable to detect undefined names

# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

from .mhc_ref import *
from .mla_decode_ref import *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F403> reported by reviewdog 🐶
from .mla_decode_ref import * used; unable to detect undefined names


from .mhc_ref import *
from .mla_decode_ref import *
from .mla_extend_ref import *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F403> reported by reviewdog 🐶
from .mla_extend_ref import * used; unable to detect undefined names

from .mhc_ref import *
from .mla_decode_ref import *
from .mla_extend_ref import *
from .rotary_embedding import *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F403> reported by reviewdog 🐶
from .rotary_embedding import * used; unable to detect undefined names

- H^res: [2n:2n+n²] residual connection (identity) (n² elements)
"""
x_f32 = x.to(torch.float32)
nC = x.shape[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable nC is assigned to but never used

Suggested change
nC = x.shape[1]
x.shape[1]

H_tilde = x_norm @ phi_f32

# Split into three streams
n_squared = n * n
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable n_squared is assigned to but never used

Suggested change
n_squared = n * n
n * n

waqahmed-amd-fi and others added 3 commits January 21, 2026 10:29
* Refactor mHC kernel and wrapper to implement equations 14-18 with fused kernel

* improve comments

* Enhance documentation for mhc function: clarify equations, input/output shapes, and activation details

* Enhance documentation in test_mhc.py: clarify equations, input/output shapes, and activation details for mHC kernel tests

* Add _sinkhorn_knopp_log_domain_kernel to the fusion module

* Add logging and sync Sinkhorn-Knopp function for doubly stochastic matrices

* sync log-domain Sinkhorn-Knopp kernel for doubly stochastic matrix projection

* Improve logging in mhc function to include all alpha parameters

* Fix H dimensions

* Refactor mHC function to return separate output tensors for pre, post, and residual streams

* Refactor mhc_torch to return separate output tensors for pre, post, and residual streams

* Adjust tolerance for is_doubly_stochastic assertion in test_sk_matrix_sizes for bfloat16 precision
Copy link
Author

@anhminhnguyenhoang anhminhnguyenhoang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, I would personally clean up the comments as they look a bit redundant

H_res_torch.to(torch.float32),
atol=1e-2,
rtol=1e-2,
atol=5e-2,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you run into test failure because of this for similar tests that you need to relax the tolerance?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, mainly because of sinkhorn which is an iterative process and returns higher differences due you only 10 iterations. May be we can try 20 for better results?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants