Skip to content

Commit 8c56620

Browse files
authored
Merge pull request #63 from libsql/fix-memory-leak
Fix memory leak
2 parents 3c23110 + b33f6b4 commit 8c56620

File tree

2 files changed

+21
-41
lines changed

2 files changed

+21
-41
lines changed

package-lock.json

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/lib.rs

+19-39
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use neon::types::JsPromise;
33
use neon::{prelude::*, types::JsBigInt};
44
use once_cell::sync::OnceCell;
55
use std::cell::RefCell;
6-
use std::sync::{Arc, Weak};
6+
use std::sync::Arc;
77
use tokio::{runtime::Runtime, sync::Mutex};
88
use tracing::trace;
99

@@ -18,7 +18,6 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
1818
struct Database {
1919
db: Arc<Mutex<libsql::Database>>,
2020
conn: RefCell<Option<Arc<Mutex<libsql::Connection>>>>,
21-
stmts: Arc<Mutex<Vec<Arc<Mutex<libsql::Statement>>>>>,
2221
default_safe_integers: RefCell<bool>,
2322
}
2423

@@ -29,7 +28,6 @@ impl Database {
2928
Database {
3029
db: Arc::new(Mutex::new(db)),
3130
conn: RefCell::new(Some(Arc::new(Mutex::new(conn)))),
32-
stmts: Arc::new(Mutex::new(vec![])),
3331
default_safe_integers: RefCell::new(false),
3432
}
3533
}
@@ -91,15 +89,11 @@ impl Database {
9189
}
9290

9391
fn js_close(mut cx: FunctionContext) -> JsResult<JsUndefined> {
92+
// the conn will be closed when the last statement in discarded. In most situation that
93+
// means immediately because you don't want to hold on a statement for longer that its
94+
// database is alive.
9495
trace!("Closing database");
9596
let db: Handle<'_, JsBox<Database>> = cx.this()?;
96-
for stmt in db.stmts.blocking_lock().iter() {
97-
let mut stmt = stmt.blocking_lock();
98-
stmt.finalize();
99-
}
100-
db.stmts.blocking_lock().clear();
101-
let conn = db.get_conn();
102-
conn.blocking_lock().close();
10397
db.conn.replace(None);
10498
Ok(cx.undefined())
10599
}
@@ -185,13 +179,9 @@ impl Database {
185179
let result = rt.block_on(async { conn.lock().await.prepare(&sql).await });
186180
let stmt = result.or_else(|err| throw_libsql_error(&mut cx, err))?;
187181
let stmt = Arc::new(Mutex::new(stmt));
188-
{
189-
let mut stmts = db.stmts.blocking_lock();
190-
stmts.push(stmt.clone());
191-
}
192182
let stmt = Statement {
193-
conn: Arc::downgrade(&conn),
194-
stmt: Arc::downgrade(&stmt),
183+
conn: conn.clone(),
184+
stmt,
195185
raw: RefCell::new(false),
196186
safe_ints: RefCell::new(*db.default_safe_integers.borrow()),
197187
};
@@ -207,18 +197,13 @@ impl Database {
207197
let safe_ints = *db.default_safe_integers.borrow();
208198
let rt = runtime(&mut cx)?;
209199
let conn = db.get_conn();
210-
let stmts = db.stmts.clone();
211200
rt.spawn(async move {
212201
match conn.lock().await.prepare(&sql).await {
213202
Ok(stmt) => {
214203
let stmt = Arc::new(Mutex::new(stmt));
215-
{
216-
let mut stmts = stmts.lock().await;
217-
stmts.push(stmt.clone());
218-
}
219204
let stmt = Statement {
220-
conn: Arc::downgrade(&conn),
221-
stmt: Arc::downgrade(&stmt),
205+
conn: conn.clone(),
206+
stmt,
222207
raw: RefCell::new(false),
223208
safe_ints: RefCell::new(safe_ints),
224209
};
@@ -371,8 +356,8 @@ pub fn convert_sqlite_code(code: i32) -> String {
371356
}
372357
}
373358
struct Statement {
374-
conn: Weak<Mutex<libsql::Connection>>,
375-
stmt: Weak<Mutex<libsql::Statement>>,
359+
conn: Arc<Mutex<libsql::Connection>>,
360+
stmt: Arc<Mutex<libsql::Statement>>,
376361
raw: RefCell<bool>,
377362
safe_ints: RefCell<bool>,
378363
}
@@ -415,8 +400,7 @@ fn js_value_to_value(
415400
impl Statement {
416401
fn js_raw(mut cx: FunctionContext) -> JsResult<JsNull> {
417402
let stmt: Handle<'_, JsBox<Statement>> = cx.this()?;
418-
let raw_stmt = stmt.stmt.upgrade().unwrap();
419-
let raw_stmt = raw_stmt.blocking_lock();
403+
let raw_stmt = stmt.stmt.blocking_lock();
420404
if raw_stmt.columns().is_empty() {
421405
return cx.throw_error("The raw() method is only for statements that return data");
422406
}
@@ -434,14 +418,13 @@ impl Statement {
434418
let stmt: Handle<'_, JsBox<Statement>> = cx.this()?;
435419
let params = cx.argument::<JsValue>(0)?;
436420
let params = convert_params(&mut cx, &stmt, params)?;
437-
let raw_stmt = stmt.stmt.upgrade().unwrap();
438-
let mut raw_stmt = raw_stmt.blocking_lock();
421+
let mut raw_stmt = stmt.stmt.blocking_lock();
439422
raw_stmt.reset();
440423
let fut = raw_stmt.execute(params);
441424
let rt = runtime(&mut cx)?;
442425
let result = rt.block_on(fut);
443426
let changes = result.or_else(|err| throw_libsql_error(&mut cx, err))?;
444-
let raw_conn = stmt.conn.upgrade().unwrap();
427+
let raw_conn = stmt.conn.clone();
445428
let last_insert_rowid = raw_conn.blocking_lock().last_insert_rowid();
446429
let info = cx.empty_object();
447430
let changes = cx.number(changes as f64);
@@ -456,8 +439,7 @@ impl Statement {
456439
let params = cx.argument::<JsValue>(0)?;
457440
let params = convert_params(&mut cx, &stmt, params)?;
458441
let safe_ints = *stmt.safe_ints.borrow();
459-
let raw_stmt = stmt.stmt.upgrade().unwrap();
460-
let mut raw_stmt = raw_stmt.blocking_lock();
442+
let mut raw_stmt = stmt.stmt.blocking_lock();
461443
let fut = raw_stmt.query(params);
462444
let rt = runtime(&mut cx)?;
463445
let result = rt.block_on(fut);
@@ -488,9 +470,8 @@ impl Statement {
488470
let params = cx.argument::<JsValue>(0)?;
489471
let params = convert_params(&mut cx, &stmt, params)?;
490472
let rt = runtime(&mut cx)?;
491-
let raw_stmt = stmt.stmt.upgrade().unwrap();
492473
let result = rt.block_on(async move {
493-
let mut raw_stmt = raw_stmt.lock().await;
474+
let mut raw_stmt = stmt.stmt.lock().await;
494475
raw_stmt.reset();
495476
raw_stmt.query(params).await
496477
});
@@ -507,16 +488,16 @@ impl Statement {
507488
let stmt: Handle<'_, JsBox<Statement>> = cx.this()?;
508489
let params = cx.argument::<JsValue>(0)?;
509490
let params = convert_params(&mut cx, &stmt, params)?;
510-
let raw_stmt = stmt.stmt.upgrade().unwrap();
511491
{
512-
let mut raw_stmt = raw_stmt.blocking_lock();
492+
let mut raw_stmt = stmt.stmt.blocking_lock();
513493
raw_stmt.reset();
514494
}
515495
let (deferred, promise) = cx.promise();
516496
let channel = cx.channel();
517497
let rt = runtime(&mut cx)?;
518498
let raw = *stmt.raw.borrow();
519499
let safe_ints = *stmt.safe_ints.borrow();
500+
let raw_stmt = stmt.stmt.clone();
520501
rt.spawn(async move {
521502
let result = {
522503
let mut raw_stmt = raw_stmt.lock().await;
@@ -547,8 +528,7 @@ impl Statement {
547528
fn js_columns(mut cx: FunctionContext) -> JsResult<JsValue> {
548529
let stmt: Handle<'_, JsBox<Statement>> = cx.this()?;
549530
let result = cx.empty_array();
550-
let stmt = stmt.stmt.upgrade().unwrap();
551-
let raw_stmt = stmt.blocking_lock();
531+
let raw_stmt = stmt.stmt.blocking_lock();
552532
for (i, col) in raw_stmt.columns().iter().enumerate() {
553533
let column = cx.empty_object();
554534
let column_name = cx.string(col.name());
@@ -665,7 +645,7 @@ fn convert_params_object(
665645
v: Handle<'_, JsObject>,
666646
) -> NeonResult<libsql::params::Params> {
667647
let mut params = vec![];
668-
let stmt = stmt.stmt.upgrade().unwrap();
648+
let stmt = &stmt.stmt;
669649
let raw_stmt = stmt.blocking_lock();
670650
for idx in 0..raw_stmt.parameter_count() {
671651
let name = raw_stmt.parameter_name((idx + 1) as i32).unwrap();

0 commit comments

Comments
 (0)