@@ -3,13 +3,23 @@ use pyo3::create_exception;
33use pyo3:: exceptions:: PyValueError ;
44use pyo3:: prelude:: * ;
55use pyo3:: types:: { PyList , PyTuple } ;
6- use std:: cell:: { OnceCell , RefCell } ;
6+ use std:: cell:: RefCell ;
77use std:: sync:: { Arc , OnceLock } ;
88use std:: time:: Duration ;
99use tokio:: runtime:: { Handle , Runtime } ;
1010
1111const LEGACY_TRANSACTION_CONTROL : i32 = -1 ;
1212
13+ enum ListOrTuple < ' py > {
14+ List ( & ' py PyList ) ,
15+ Tuple ( & ' py PyTuple ) ,
16+ }
17+
18+ struct ListOrTupleIterator < ' py > {
19+ index : usize ,
20+ inner : & ' py ListOrTuple < ' py >
21+ }
22+
1323fn rt ( ) -> Handle {
1424 static RT : OnceLock < Runtime > = OnceLock :: new ( ) ;
1525
@@ -286,7 +296,7 @@ impl Connection {
286296 fn execute (
287297 self_ : PyRef < ' _ , Self > ,
288298 sql : String ,
289- parameters : Option < & PyTuple > ,
299+ parameters : Option < ListOrTuple > ,
290300 ) -> PyResult < Cursor > {
291301 let cursor = Connection :: cursor ( & self_) ?;
292302 rt ( ) . block_on ( async { execute ( & cursor, sql, parameters) . await } ) ?;
@@ -300,7 +310,7 @@ impl Connection {
300310 ) -> PyResult < Cursor > {
301311 let cursor = Connection :: cursor ( & self_) ?;
302312 for parameters in parameters. unwrap ( ) . iter ( ) {
303- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
313+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
304314 rt ( ) . block_on ( async { execute ( & cursor, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
305315 }
306316 Ok ( cursor)
@@ -396,7 +406,7 @@ impl Cursor {
396406 fn execute < ' a > (
397407 self_ : PyRef < ' a , Self > ,
398408 sql : String ,
399- parameters : Option < & PyTuple > ,
409+ parameters : Option < ListOrTuple > ,
400410 ) -> PyResult < pyo3:: PyRef < ' a , Self > > {
401411 rt ( ) . block_on ( async { execute ( & self_, sql, parameters) . await } ) ?;
402412 Ok ( self_)
@@ -408,7 +418,7 @@ impl Cursor {
408418 parameters : Option < & PyList > ,
409419 ) -> PyResult < pyo3:: PyRef < ' a , Cursor > > {
410420 for parameters in parameters. unwrap ( ) . iter ( ) {
411- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
421+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
412422 rt ( ) . block_on ( async { execute ( & self_, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
413423 }
414424 Ok ( self_)
@@ -552,7 +562,11 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> {
552562 Ok ( ( ) )
553563}
554564
555- async fn execute ( cursor : & Cursor , sql : String , parameters : Option < & PyTuple > ) -> PyResult < ( ) > {
565+ async fn execute < ' py > (
566+ cursor : & Cursor ,
567+ sql : String ,
568+ parameters : Option < ListOrTuple < ' py > > ,
569+ ) -> PyResult < ( ) > {
556570 if cursor. conn . borrow ( ) . as_ref ( ) . is_none ( ) {
557571 return Err ( PyValueError :: new_err ( "Connection already closed" ) ) ;
558572 }
@@ -576,7 +590,10 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) ->
576590 } else if let Ok ( value) = param. extract :: < & [ u8 ] > ( ) {
577591 libsql_core:: Value :: Blob ( value. to_vec ( ) )
578592 } else {
579- return Err ( PyValueError :: new_err ( "Unsupported parameter type" ) ) ;
593+ return Err ( PyValueError :: new_err ( format ! (
594+ "Unsupported parameter type {}" ,
595+ param. to_string( )
596+ ) ) ) ;
580597 } ;
581598 params. push ( param) ;
582599 }
@@ -653,6 +670,44 @@ fn convert_row(py: Python, row: libsql_core::Row, column_count: i32) -> PyResult
653670
654671create_exception ! ( libsql, Error , pyo3:: exceptions:: PyException ) ;
655672
673+ impl < ' py > FromPyObject < ' py > for ListOrTuple < ' py > {
674+ fn extract ( ob : & ' py PyAny ) -> PyResult < Self > {
675+ if let Ok ( list) = ob. downcast :: < PyList > ( ) {
676+ Ok ( ListOrTuple :: List ( list) )
677+ } else if let Ok ( tuple) = ob. downcast :: < PyTuple > ( ) {
678+ Ok ( ListOrTuple :: Tuple ( tuple) )
679+ } else {
680+ Err ( PyValueError :: new_err (
681+ "Expected a list or tuple for parameters" ,
682+ ) )
683+ }
684+ }
685+ }
686+
687+ impl < ' py > ListOrTuple < ' py > {
688+ pub fn iter ( & self ) -> ListOrTupleIterator {
689+ ListOrTupleIterator {
690+ index : 0 ,
691+ inner : self ,
692+ }
693+ }
694+ }
695+
696+ impl < ' py > Iterator for ListOrTupleIterator < ' py > {
697+ type Item = & ' py PyAny ;
698+
699+ fn next ( & mut self ) -> Option < Self :: Item > {
700+ let rv = match self . inner {
701+ ListOrTuple :: List ( list) => list. get_item ( self . index ) ,
702+ ListOrTuple :: Tuple ( tuple) => tuple. get_item ( self . index ) ,
703+ } ;
704+
705+ rv. ok ( ) . map ( |item| {
706+ self . index += 1 ;
707+ item
708+ } )
709+ }
710+ }
656711#[ pymodule]
657712fn libsql ( py : Python , m : & PyModule ) -> PyResult < ( ) > {
658713 let _ = tracing_subscriber:: fmt:: try_init ( ) ;
0 commit comments