Remove generic kernel invocations of MatrixLinewiseOp #2682

Open
divyegala opened this issue May 21, 2025 · 0 comments · May be fixed by #2701, rapidsai/cuvs#1018 or rapidsai/cuml#6900
divyegala commented May 21, 2025

The runtime dispatch on alongLines can be converted to a compile-time choice in two places:

  1. if (alongLines)
       return matrixLinewiseVecRows<Type, IdxType, VecBytes, BlockSize, Lambda, Vecs...>(
         out, in, lineLen, nLines, op, stream, vecs...);
     else
       return matrixLinewiseVecCols<Type, IdxType, VecBytes, BlockSize, Lambda, Vecs...>(
         out, in, lineLen, nLines, op, stream, vecs...);
     }
  2. if (alongLines)
       return matrixLinewiseVecRowsSpan<Type,
                                        IdxType,
                                        LayoutPolicy,
                                        VecBytes,
                                        BlockSize,
                                        Lambda,
                                        Vecs...>(out, in, lineLen, nLines, op, stream, vecs...);
     else
       return matrixLinewiseVecColsSpan<Type,
                                        IdxType,
                                        LayoutPolicy,
                                        VecBytes,
                                        BlockSize,
                                        Lambda,
                                        Vecs...>(out, in, lineLen, nLines, op, stream, vecs...);
     }
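
As a rough illustration of the idea, the sketch below moves the branch into a template parameter so that only the selected path is instantiated. It is host-only C++17, not the RAFT implementation: `Path`, `linewise_rows`, `linewise_cols`, and `linewise_run` are hypothetical stand-ins, the op is hard-coded to addition, and the vector indexing (by position within a line for the "rows" path, by line for the "cols" path) is an assumption made for the example.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical marker so a caller can observe which path was taken.
enum class Path { Rows, Cols };

// Stand-in for the "rows" path (assumption: vec is indexed by the
// position within a line, so it has lineLen entries).
template <typename T>
Path linewise_rows(T* out, const T* in, std::size_t lineLen, std::size_t nLines, const T* vec)
{
  for (std::size_t l = 0; l < nLines; ++l)
    for (std::size_t i = 0; i < lineLen; ++i)
      out[l * lineLen + i] = in[l * lineLen + i] + vec[i];
  return Path::Rows;
}

// Stand-in for the "cols" path (assumption: vec is indexed by the
// line number, so it has nLines entries).
template <typename T>
Path linewise_cols(T* out, const T* in, std::size_t lineLen, std::size_t nLines, const T* vec)
{
  for (std::size_t l = 0; l < nLines; ++l)
    for (std::size_t i = 0; i < lineLen; ++i)
      out[l * lineLen + i] = in[l * lineLen + i] + vec[l];
  return Path::Cols;
}

// AlongLines is a template parameter: `if constexpr` discards the other
// branch, so each instantiation references only one of the two paths.
template <bool AlongLines, typename T>
Path linewise_run(T* out, const T* in, std::size_t lineLen, std::size_t nLines, const T* vec)
{
  if constexpr (AlongLines) {
    return linewise_rows(out, in, lineLen, nLines, vec);
  } else {
    return linewise_cols(out, in, lineLen, nLines, vec);
  }
}
```

A caller that knows the direction statically would write `linewise_run<true>(...)` and never pull in the other path; with a runtime `bool`, both paths are compiled for every `Lambda` and `Vecs...` combination.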

This struct has callers in:
raft/matrix/linewise_op.cuh

  1. detail::MatrixLinewiseOp<16, 256>::run<m_t, idx_t>(out.data_handle(),
  2. detail::MatrixLinewiseOp<16, 256>::runPadded<m_t, idx_t>(out,

raft/matrix/matrix.cuh

  1. detail::MatrixLinewiseOp<16, 256>::run<m_t, idx_t, Lambda, Vecs...>(

The change will also propagate to several other primitives that use the two APIs above. For example, in raft/linalg/matrix_vector_op.cuh, both rowMajor and bcastAlongRows could become template parameters. Once the original API changes, the effect will cascade down to other call sites as well.

bool along_lines = rowMajor == bcastAlongRows;
if (rowMajor) {
  matrix::linewise_op<MatT, IdxType, row_major, Lambda>(
    handle,
    make_device_matrix_view<const MatT, IdxType, row_major>(matrix, N, D),
    make_device_matrix_view<MatT, IdxType, row_major>(out, N, D),
    along_lines,
    op,
    make_device_vector_view<const VecT, IdxType>(vec, bcastAlongRows ? N : D));
} else {
  matrix::linewise_op<MatT, IdxType, col_major, Lambda>(
    handle,
    make_device_matrix_view<const MatT, IdxType, col_major>(matrix, N, D),
    make_device_matrix_view<MatT, IdxType, col_major>(out, N, D),
    along_lines,
    op,
    make_device_vector_view<const VecT, IdxType>(vec, bcastAlongRows ? N : D));
}
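
One possible shape for this cascade, sketched under assumptions: if rowMajor and bcastAlongRows were template parameters, both the layout and along_lines could be derived at compile time. Every name below (`row_major_tag`, `col_major_tag`, `matrix_vector_traits`, `dispatch_bools`, `resolve_along_lines`) is hypothetical and only illustrates the pattern; none of it is existing RAFT API.

```cpp
#include <type_traits>

// Hypothetical layout tags standing in for raft::row_major / raft::col_major.
struct row_major_tag {};
struct col_major_tag {};

// With both flags as template parameters, the layout and the
// along_lines value are fixed at compile time.
template <bool RowMajor, bool BcastAlongRows>
struct matrix_vector_traits {
  using layout = std::conditional_t<RowMajor, row_major_tag, col_major_tag>;
  static constexpr bool along_lines = (RowMajor == BcastAlongRows);
};

static_assert(matrix_vector_traits<true, true>::along_lines, "row-major, along rows");
static_assert(!matrix_vector_traits<false, true>::along_lines, "col-major, along rows");
static_assert(std::is_same<matrix_vector_traits<false, false>::layout, col_major_tag>::value,
              "col-major layout tag");

// Optional shim for callers that still hold runtime bools: it maps them
// to integral constants. f must return the same type for all four cases.
template <typename F>
auto dispatch_bools(bool rowMajor, bool bcastAlongRows, F&& f)
{
  if (rowMajor) {
    if (bcastAlongRows) return f(std::true_type{}, std::true_type{});
    return f(std::true_type{}, std::false_type{});
  }
  if (bcastAlongRows) return f(std::false_type{}, std::true_type{});
  return f(std::false_type{}, std::false_type{});
}

// Example use of the shim: recover along_lines from runtime flags.
inline bool resolve_along_lines(bool rowMajor, bool bcastAlongRows)
{
  return dispatch_bools(rowMajor, bcastAlongRows, [](auto rm, auto bc) {
    return matrix_vector_traits<decltype(rm)::value, decltype(bc)::value>::along_lines;
  });
}
```

Note that a shim like `dispatch_bools` still instantiates all four combinations, so it only preserves the old interface. Any reduction in generated kernels would come from call sites that pass the flags as compile-time constants.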
