Incorrect calculation of num_triplets at SparseMatrix construction #8501

@junwha

Describe the bug
A sparse matrix built from an ndarray is constructed with an incorrect num_triplets, which leads to an out-of-bounds (buffer-overflow) read.

To Reproduce
The PoC below is adapted from test_sparse_matrix.py.

import taichi as ti
arch = ti.cpu # or ti.cuda
ti.init(arch=arch)

def test_build_sparse_matrix_frome_ndarray(dtype, storage_format):
    n = 8
    triplets = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=n)
    A = ti.linalg.SparseMatrix(n=10, m=10, dtype=ti.f32, storage_format=storage_format)

    @ti.kernel
    def fill(triplets: ti.types.ndarray()):
        for i in range(n):
            triplet = ti.Vector([i, i, i], dt=ti.f32)
            triplets[i] = triplet

    fill(triplets)
    A.build_from_ndarray(triplets)

    for i in range(n):
        assert A[i, i] == i

test_build_sparse_matrix_frome_ndarray(ti.f32, "col_major")

Additional comments
In make_sparse_matrix_from_ndarray (taichi/program/sparse_matrix.cpp:378), num_triplets is computed as ndarray.get_nelement() * ndarray.get_element_size() / 3. Let N = ndarray.get_nelement() and M = ndarray.get_element_size(). Only 3*N*M bytes are accessible from data_ptr.

void make_sparse_matrix_from_ndarray(Program *prog,
                                     SparseMatrix &sm,
                                     const Ndarray &ndarray) {
  std::string sdtype = taichi::lang::data_type_name(sm.get_data_type());
  auto data_ptr = prog->get_ndarray_data_ptr_as_int(&ndarray);
  auto num_triplets = ndarray.get_nelement() * ndarray.get_element_size() / 3;
  if (sdtype == "f32") {
    build_ndarray_template<float32>(sm, data_ptr, num_triplets);
  } else if (sdtype == "f64") {
    build_ndarray_template<float64>(sm, data_ptr, num_triplets);
  } else {
    TI_ERROR("Unsupported sparse matrix data type {}!", sdtype);
  }
}

In build_ndarray_template (taichi/program/sparse_matrix.cpp:373), data_ptr is cast to an array of T, and indices 0 through 3*(num_triplets-1)+2 are read. With num_triplets = N*M/3, the last index is 3*(N*M/3-1)+2 = N*M-1.
Thus the last read is at ((char*) data_ptr + (N*M-1)*M), which lies past the last valid element at ((T*) data_ptr + 3*N-1) whenever M > 3 (for example, M = 4 for f32).
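To make the arithmetic concrete, here is a small sketch that plugs in the PoC's values. It assumes M = 4 (the size of one f32, following the reasoning above); the helper names are hypothetical and only mirror the formulas in this report, not Taichi's API. Because the C++ division is an integer division, the exact last index can be slightly smaller than N*M-1 when N*M is not a multiple of 3, but it is still out of bounds.

```python
def buggy_num_triplets(n_elements: int, element_size: int) -> int:
    # Mirrors get_nelement() * get_element_size() / 3 with integer division.
    return n_elements * element_size // 3


def last_index_read(n_elements: int, element_size: int) -> int:
    # The loop reads data[i * 3 + 2] for i up to num_triplets - 1.
    return 3 * (buggy_num_triplets(n_elements, element_size) - 1) + 2


def last_valid_index(n_elements: int) -> int:
    # The buffer holds 3 values of type T per triplet.
    return 3 * n_elements - 1


N, M = 8, 4  # PoC values; M = 4 is an assumption (f32 scalar size)
print(buggy_num_triplets(N, M), last_index_read(N, M), last_valid_index(N))
```

For these values the loop runs for 10 triplets instead of 8 and reads up to index 29, while the last valid index is 23.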

template <typename T>
void build_ndarray_template(SparseMatrix &sm,
                            intptr_t data_ptr,
                            size_t num_triplets) {
  using V = Eigen::Triplet<T>;
  std::vector<V> triplets;
  T *data = reinterpret_cast<T *>(data_ptr);
  for (int i = 0; i < num_triplets; i++) {
    triplets.push_back(
        V(data[i * 3], data[i * 3 + 1], taichi_union_cast<T>(data[i * 3 + 2])));
  }
  sm.build_triplets(static_cast<void *>(&triplets));
}

To fix this, num_triplets should be set to ndarray.get_nelement().
I will open a PR for this.
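As a rough illustration of why the proposed count is the right one, the sketch below models the ndarray as a flat Python list of 3*N values and rebuilds the triplets the same way the C++ loop does. The function build_triplets is a hypothetical stand-in, not Taichi code, and the buggy count again assumes an element size of 4 bytes. Python raises IndexError where the C++ code would silently read past the buffer.

```python
def build_triplets(flat, num_triplets):
    # Same indexing as the C++ loop: three consecutive values per triplet.
    return [(flat[3 * i], flat[3 * i + 1], flat[3 * i + 2])
            for i in range(num_triplets)]


n = 8
# Flat buffer equivalent to the PoC's fill kernel: triplet i is (i, i, i).
flat = [float(i) for i in range(n) for _ in range(3)]

# Proposed fix: one triplet per ndarray element.
fixed = build_triplets(flat, n)

# Buggy count (assumed element size of 4): runs past the end of the buffer.
try:
    build_triplets(flat, n * 4 // 3)
    overflowed = False
except IndexError:
    overflowed = True
```

With the count set to n, every index stays inside the buffer and each triplet matches what the kernel wrote.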

Thank you!:)
