@@ -10,6 +10,16 @@ use tokio::runtime::{Handle, Runtime};
1010
1111const LEGACY_TRANSACTION_CONTROL : i32 = -1 ;
1212
13+ enum ListOrTuple < ' py > {
14+ List ( & ' py PyList ) ,
15+ Tuple ( & ' py PyTuple ) ,
16+ }
17+
18+ struct ListOrTupleIterator < ' py > {
19+ index : usize ,
20+ inner : & ' py ListOrTuple < ' py >
21+ }
22+
1323fn rt ( ) -> Handle {
1424 static RT : OnceLock < Runtime > = OnceLock :: new ( ) ;
1525
@@ -286,7 +296,7 @@ impl Connection {
286296 fn execute (
287297 self_ : PyRef < ' _ , Self > ,
288298 sql : String ,
289- parameters : Option < & PyTuple > ,
299+ parameters : Option < ListOrTuple > ,
290300 ) -> PyResult < Cursor > {
291301 let cursor = Connection :: cursor ( & self_) ?;
292302 rt ( ) . block_on ( async { execute ( & cursor, sql, parameters) . await } ) ?;
@@ -300,7 +310,7 @@ impl Connection {
300310 ) -> PyResult < Cursor > {
301311 let cursor = Connection :: cursor ( & self_) ?;
302312 for parameters in parameters. unwrap ( ) . iter ( ) {
303- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
313+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
304314 rt ( ) . block_on ( async { execute ( & cursor, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
305315 }
306316 Ok ( cursor)
@@ -419,7 +429,7 @@ impl Cursor {
419429 fn execute < ' a > (
420430 self_ : PyRef < ' a , Self > ,
421431 sql : String ,
422- parameters : Option < & PyTuple > ,
432+ parameters : Option < ListOrTuple > ,
423433 ) -> PyResult < pyo3:: PyRef < ' a , Self > > {
424434 rt ( ) . block_on ( async { execute ( & self_, sql, parameters) . await } ) ?;
425435 Ok ( self_)
@@ -431,7 +441,7 @@ impl Cursor {
431441 parameters : Option < & PyList > ,
432442 ) -> PyResult < pyo3:: PyRef < ' a , Cursor > > {
433443 for parameters in parameters. unwrap ( ) . iter ( ) {
434- let parameters = parameters. extract :: < & PyTuple > ( ) ?;
444+ let parameters = parameters. extract :: < ListOrTuple > ( ) ?;
435445 rt ( ) . block_on ( async { execute ( & self_, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
436446 }
437447 Ok ( self_)
@@ -575,7 +585,11 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> {
575585 Ok ( ( ) )
576586}
577587
578- async fn execute ( cursor : & Cursor , sql : String , parameters : Option < & PyTuple > ) -> PyResult < ( ) > {
588+ async fn execute < ' py > (
589+ cursor : & Cursor ,
590+ sql : String ,
591+ parameters : Option < ListOrTuple < ' py > > ,
592+ ) -> PyResult < ( ) > {
579593 if cursor. conn . borrow ( ) . as_ref ( ) . is_none ( ) {
580594 return Err ( PyValueError :: new_err ( "Connection already closed" ) ) ;
581595 }
@@ -599,7 +613,10 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) ->
599613 } else if let Ok ( value) = param. extract :: < & [ u8 ] > ( ) {
600614 libsql_core:: Value :: Blob ( value. to_vec ( ) )
601615 } else {
602- return Err ( PyValueError :: new_err ( "Unsupported parameter type" ) ) ;
616+ return Err ( PyValueError :: new_err ( format ! (
617+ "Unsupported parameter type {}" ,
618+ param. to_string( )
619+ ) ) ) ;
603620 } ;
604621 params. push ( param) ;
605622 }
@@ -676,6 +693,44 @@ fn convert_row(py: Python, row: libsql_core::Row, column_count: i32) -> PyResult
676693
677694create_exception ! ( libsql, Error , pyo3:: exceptions:: PyException ) ;
678695
696+ impl < ' py > FromPyObject < ' py > for ListOrTuple < ' py > {
697+ fn extract ( ob : & ' py PyAny ) -> PyResult < Self > {
698+ if let Ok ( list) = ob. downcast :: < PyList > ( ) {
699+ Ok ( ListOrTuple :: List ( list) )
700+ } else if let Ok ( tuple) = ob. downcast :: < PyTuple > ( ) {
701+ Ok ( ListOrTuple :: Tuple ( tuple) )
702+ } else {
703+ Err ( PyValueError :: new_err (
704+ "Expected a list or tuple for parameters" ,
705+ ) )
706+ }
707+ }
708+ }
709+
710+ impl < ' py > ListOrTuple < ' py > {
711+ pub fn iter ( & self ) -> ListOrTupleIterator {
712+ ListOrTupleIterator {
713+ index : 0 ,
714+ inner : self ,
715+ }
716+ }
717+ }
718+
719+ impl < ' py > Iterator for ListOrTupleIterator < ' py > {
720+ type Item = & ' py PyAny ;
721+
722+ fn next ( & mut self ) -> Option < Self :: Item > {
723+ let rv = match self . inner {
724+ ListOrTuple :: List ( list) => list. get_item ( self . index ) ,
725+ ListOrTuple :: Tuple ( tuple) => tuple. get_item ( self . index ) ,
726+ } ;
727+
728+ rv. ok ( ) . map ( |item| {
729+ self . index += 1 ;
730+ item
731+ } )
732+ }
733+ }
679734#[ pymodule]
680735fn libsql ( py : Python , m : & PyModule ) -> PyResult < ( ) > {
681736 let _ = tracing_subscriber:: fmt:: try_init ( ) ;
0 commit comments