Description
Describe the bug
When a sparse matrix is built from an ndarray, it is constructed with an incorrect `num_triplets`, which leads to a buffer-overflow read.
To Reproduce
The following PoC is adapted from `test_sparse_matrix.py`:
```python
import taichi as ti

arch = ti.cpu  # or ti.cuda
ti.init(arch=arch)


def test_build_sparse_matrix_frome_ndarray(dtype, storage_format):
    n = 8
    triplets = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=n)
    A = ti.linalg.SparseMatrix(n=10, m=10, dtype=ti.f32, storage_format=storage_format)

    @ti.kernel
    def fill(triplets: ti.types.ndarray()):
        for i in range(n):
            triplet = ti.Vector([i, i, i], dt=ti.f32)
            triplets[i] = triplet

    fill(triplets)
    A.build_from_ndarray(triplets)
    for i in range(n):
        assert A[i, i] == i


test_build_sparse_matrix_frome_ndarray(ti.f32, "col_major")
```
Additional comments
In `make_sparse_matrix_from_ndarray` (taichi/program/sparse_matrix.cpp:378), `num_triplets` is computed as `ndarray.get_nelement() * ndarray.get_element_size() / 3`. Let N be `ndarray.get_nelement()` and M be `ndarray.get_element_size()`. Since each of the N elements is a 3-vector of M-byte scalars, only 3*N*M bytes are accessible from `data_ptr`.
```cpp
void make_sparse_matrix_from_ndarray(Program *prog,
                                     SparseMatrix &sm,
                                     const Ndarray &ndarray) {
  std::string sdtype = taichi::lang::data_type_name(sm.get_data_type());
  auto data_ptr = prog->get_ndarray_data_ptr_as_int(&ndarray);
  // Bug: N * M / 3 over-counts the triplets; the ndarray holds only N of them.
  auto num_triplets = ndarray.get_nelement() * ndarray.get_element_size() / 3;
  if (sdtype == "f32") {
    build_ndarray_template<float32>(sm, data_ptr, num_triplets);
  } else if (sdtype == "f64") {
    build_ndarray_template<float64>(sm, data_ptr, num_triplets);
  } else {
    TI_ERROR("Unsupported sparse matrix data type {}!", sdtype);
  }
}
```
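To make the arithmetic concrete, here is a small Python sketch of the `num_triplets` computation for the PoC. It assumes, following the analysis above, that `get_nelement()` returns the number of 3-vectors (N = 8) and that `get_element_size()` returns the byte size of one f32 scalar (M = 4). The helper names are hypothetical and are not Taichi APIs.

```python
# Hypothetical helpers mirroring the C++ expression in
# make_sparse_matrix_from_ndarray; they are not part of Taichi.

def buggy_num_triplets(nelement: int, element_size: int) -> int:
    # Mirrors: ndarray.get_nelement() * ndarray.get_element_size() / 3
    # (integer division on sizes in C++, i.e. floor division here).
    return nelement * element_size // 3


def fixed_num_triplets(nelement: int) -> int:
    # Proposed fix: exactly one triplet per ndarray element.
    return nelement


N = 8  # number of 3-vectors in the PoC ndarray (shape=n, n=8)
M = 4  # assumed byte size of one f32 scalar

print(buggy_num_triplets(N, M))  # 10
print(fixed_num_triplets(N))     # 8
```

Under these assumptions the current code asks for 10 triplets from a buffer that only holds 8.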
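Then, in `build_ndarray_template` (taichi/program/sparse_matrix.cpp:373), `data_ptr` is cast to a `T *` array and indices 0 through 3*(num_triplets-1)+2 are read. Substituting num_triplets = N*M/3 gives 3*(N*M/3-1)+2 = N*M-1 (slightly less when N*M is not divisible by 3, because of integer division).
So the last read is `((T*) data_ptr)[N*M-1]`, at byte offset (N*M-1)*M, while the last valid element is `((T*) data_ptr)[3*N-1]`. Because M > 3 for both f32 (4 bytes) and f64 (8 bytes), N*M-1 exceeds 3*N-1 for all but very small N, and the loop reads past the end of the buffer.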
```cpp
template <typename T>
void build_ndarray_template(SparseMatrix &sm,
                            intptr_t data_ptr,
                            size_t num_triplets) {
  using V = Eigen::Triplet<T>;
  std::vector<V> triplets;
  T *data = reinterpret_cast<T *>(data_ptr);
  // Reads data[0] .. data[3 * num_triplets - 1] with no bounds check.
  for (int i = 0; i < num_triplets; i++) {
    triplets.push_back(
        V(data[i * 3], data[i * 3 + 1], taichi_union_cast<T>(data[i * 3 + 2])));
  }
  sm.build_triplets(static_cast<void *>(&triplets));
}
```
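The read pattern of `build_ndarray_template` can be simulated in Python to show where it leaves the buffer. This is a sketch under the same assumptions (N = 8 vectors, M = 4 assumed as the f32 size); the helper names are hypothetical. Python lists are bounds-checked, so the overflow surfaces as an `IndexError` instead of a silent out-of-bounds read.

```python
def highest_index_read(num_triplets: int) -> int:
    # build_ndarray_template reads data[i*3], data[i*3+1] and data[i*3+2]
    # for i in [0, num_triplets), so the last index touched is 3*num_triplets - 1.
    return 3 * (num_triplets - 1) + 2


def read_triplets(data: list, num_triplets: int) -> list:
    # Same index pattern as the C++ loop. Python lists are bounds-checked, so an
    # out-of-range read raises IndexError instead of reading past the end.
    return [(data[i * 3], data[i * 3 + 1], data[i * 3 + 2]) for i in range(num_triplets)]


N, M = 8, 4  # PoC: 8 vectors; M assumed to be sizeof(f32)
data = [float(i) for i in range(N) for _ in range(3)]  # 3*N scalars, (i, i, i) per vector
buggy = N * M // 3  # 10, as computed by the current code

print(highest_index_read(buggy), len(data) - 1)  # 29 23
try:
    read_triplets(data, buggy)
except IndexError:
    print("out-of-bounds read")
```

With the buggy count the loop would touch index 29, while only indices 0 to 23 exist.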
To fix this, `num_triplets` should simply be `ndarray.get_nelement()`, since each element of the ndarray holds exactly one (row, col, value) triplet.
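A minimal Python model of the proposed fix, under the same assumptions (a flat buffer of 3*N scalars, one (row, col, value) triplet per element). It only models the indexing and is not the actual C++ patch; `build_triplets` is a hypothetical name. With `num_triplets = N`, every read stays in bounds and the entries match what the PoC's `fill` kernel writes, so `A[i, i] == i` holds.

```python
def build_triplets(data: list, nelement: int) -> dict:
    # Proposed fix: one triplet per ndarray element (num_triplets = get_nelement()).
    num_triplets = nelement
    if 3 * num_triplets > len(data):
        raise ValueError("would read past the end of the buffer")
    entries = {}
    for i in range(num_triplets):
        row, col, val = data[3 * i], data[3 * i + 1], data[3 * i + 2]
        # Indices are stored as f32 in the PoC, so convert them back to ints.
        key = (int(row), int(col))
        # Duplicate (row, col) entries are summed, as Eigen's setFromTriplets does by default.
        entries[key] = entries.get(key, 0.0) + val
    return entries


n = 8
data = [float(i) for i in range(n) for _ in range(3)]  # what fill() writes
A = build_triplets(data, n)
assert all(A[(i, i)] == i for i in range(n))
print(len(A))  # 8
```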
I will open a PR for this.
Thank you! :)