Hi, thanks for your great work on Transformer Engine!
I am working on a project that requires high-performance batched matrix multiplication (i.e., 3D tensor multiplication) where both inputs are stored in the FP8 data type. However, I noticed that te.Linear only takes a single input tensor and multiplies it with its own internal weight, so it does not fit my case, where I need to multiply two arbitrary 3D tensors (see the sketch below).
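To make the use case concrete, here is a minimal sketch of what I am after. The shapes are arbitrary, and torch.bmm in BF16 only stands in for the batched FP8 GEMM I actually need; the te.Linear part just shows why its interface does not cover my case:

```python
import torch
import transformer_engine.pytorch as te

# Reference semantics of the operation I need: C[i] = A[i] @ B[i] for every
# batch index i, except that in my case A and B are already stored in FP8.
batch, m, k, n = 8, 128, 256, 64
a = torch.randn(batch, m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(batch, k, n, device="cuda", dtype=torch.bfloat16)
c_ref = torch.bmm(a, b)  # what I want, but with a and b as FP8 tensors

# te.Linear, by contrast, takes a single activation tensor and multiplies it
# with its own internal weight parameter, so my second tensor `b` has no place here:
layer = te.Linear(k, n, bias=False, params_dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True):
    out = layer(a.reshape(batch * m, k))  # FP8 GEMM, but only against layer.weight
```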
Could you please advise which function or API in Transformer Engine is recommended for performing batched matrix multiplication (GEMM) directly on two FP8 3D tensors? Is there a public interface for this use case, or is it only available through the lower-level generic_gemm/general_gemm functions? Either way, could you share an example or best practice for this scenario?