@@ -7,11 +7,13 @@ import (
7
7
"bytes"
8
8
"context"
9
9
"database/sql"
10
+ "errors"
10
11
"flag"
11
12
"fmt"
12
13
"log"
13
14
"net/url"
14
15
"os"
16
+ "reflect"
15
17
"regexp"
16
18
"strings"
17
19
"testing"
@@ -435,9 +437,11 @@ func TestCopy(t *testing.T) {
435
437
436
438
testCases := []struct {
437
439
dbName string
440
+ testCase string
438
441
setupQueries []setupQuery
439
442
src string
440
443
dest string
444
+ destCmpQuery string
441
445
}{
442
446
{
443
447
dbName : "pgsql" ,
@@ -449,7 +453,8 @@ func TestCopy(t *testing.T) {
449
453
dest : "staff_copy" ,
450
454
},
451
455
{
452
- dbName : "pgsql" ,
456
+ dbName : "pgsql" ,
457
+ testCase : "schemaInDest" ,
453
458
setupQueries : []setupQuery {
454
459
{query : "DROP TABLE staff_copy" },
455
460
{query : "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1" , check : true },
@@ -466,8 +471,9 @@ func TestCopy(t *testing.T) {
466
471
src : "select * from staff" ,
467
472
dest : "staff_copy" ,
468
473
},
469
- {
470
- dbName : "pgx" ,
474
+ { // this holds even select iterates over table in a ran
475
+ dbName : "pgx" ,
476
+ testCase : "schemaInDest" ,
471
477
setupQueries : []setupQuery {
472
478
{query : "DROP TABLE staff_copy" },
473
479
{query : "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1" , check : true },
@@ -478,12 +484,22 @@ func TestCopy(t *testing.T) {
478
484
{
479
485
dbName : "mysql" ,
480
486
setupQueries : []setupQuery {
481
- {query : "DROP TABLE staff_copy" },
482
487
{query : "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1" , check : true },
483
488
},
484
489
src : "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff" ,
485
490
dest : "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)" ,
486
491
},
492
+ {
493
+ dbName : "mysql" ,
494
+ testCase : "bulkCopy" ,
495
+ setupQueries : []setupQuery {
496
+ {query : "SET GLOBAL local_infile = ON" },
497
+ {query : "DROP TABLE staff_copy" },
498
+ {query : "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1" , check : true },
499
+ },
500
+ src : "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff" ,
501
+ dest : "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)" ,
502
+ },
487
503
{
488
504
dbName : "sqlserver" ,
489
505
setupQueries : []setupQuery {
@@ -497,9 +513,11 @@ func TestCopy(t *testing.T) {
497
513
dbName : "csvq" ,
498
514
setupQueries : []setupQuery {
499
515
{query : "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1" , check : true },
516
+ {query : "DELETE from staff_copy" , check : true },
500
517
},
501
- src : "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff" ,
502
- dest : "staff_copy" ,
518
+ src : "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff" ,
519
+ dest : "staff_copy" ,
520
+ destCmpQuery : "select first_name, last_name, address_id, email, store_id, active, username, password, datetime(last_update) from staff_copy" ,
503
521
},
504
522
}
505
523
for _ , test := range testCases {
@@ -508,7 +526,11 @@ func TestCopy(t *testing.T) {
508
526
continue
509
527
}
510
528
511
- t .Run (test .dbName , func (t * testing.T ) {
529
+ testName := test .dbName
530
+ if test .testCase != "" {
531
+ testName += "-" + test .testCase
532
+ }
533
+ t .Run (testName , func (t * testing.T ) {
512
534
513
535
// TODO test copy from a different DB, maybe csvq?
514
536
// TODO test copy from same DB
@@ -524,7 +546,7 @@ func TestCopy(t *testing.T) {
524
546
t .Fatalf ("Could not get rows to copy: %v" , err )
525
547
}
526
548
527
- ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
549
+ ctx , cancel := context .WithTimeout (context .Background (), 500 * time .Second )
528
550
defer cancel ()
529
551
var rlen int64 = 1
530
552
n , err := drivers .Copy (ctx , db .URL , nil , nil , rows , test .dest )
@@ -534,10 +556,95 @@ func TestCopy(t *testing.T) {
534
556
if n != rlen {
535
557
t .Fatalf ("Expected to copy %d rows but got %d" , rlen , n )
536
558
}
559
+
560
+ checkSameData (t , ctx , pg .DB , test .src , db .DB , test .destCmpQuery )
537
561
})
538
562
}
539
563
}
540
564
565
+ // checkSameData fails the test if the data in the srcDB."staff" table is different than destDB."staff_copy" table
566
+ func checkSameData (t * testing.T , ctx context.Context , srcDB * sql.DB , srcQuery string , destDB * sql.DB , destCmpQuery string ) {
567
+ if destCmpQuery == "" {
568
+ srcQuery = strings .ToLower (srcQuery )
569
+ if ! strings .Contains (srcQuery , "from staff" ) {
570
+ t .Fatalf ("destCmpQuery needs to be configured if src '%s' is not for table 'staff'" , srcQuery )
571
+ }
572
+ // if destCmpQuery needs special syntax, configure it in the test case definitions above
573
+ destCmpQuery = strings .Replace (srcQuery , "from staff" , "from staff_copy" , 1 )
574
+ }
575
+ srcValues , srcColumnTypes , err := getSrcRow (ctx , srcDB , srcQuery )
576
+ if err != nil {
577
+ t .Fatalf ("Could not get src row from database: %v" , err )
578
+ }
579
+ destValues , err := getDestRow (ctx , destDB , destCmpQuery , srcColumnTypes )
580
+ if err != nil {
581
+ t .Fatalf ("Could not get dest row from database: %v" , err )
582
+ }
583
+ // Comparing more than 1 row is more complex because SELECT result order is undefined without order by
584
+ adjustDates (srcValues , destValues )
585
+ if ! reflect .DeepEqual (srcValues , destValues ) {
586
+ t .Fatalf ("Source and dest row don't match: \n %v\n vs \n %v" , srcValues , destValues )
587
+ }
588
+ }
589
+
590
+ // adjustDates removes sub-second differences between any dates in the two rows, because
591
+ // the difference are likely caused by difference in precision and not by a copy issue
592
+ func adjustDates (src []interface {}, dest []interface {}) {
593
+ for i , v := range src {
594
+ srcDate , okSrc := v .(time.Time )
595
+ destDate , okDest := dest [i ].(time.Time )
596
+ if okSrc && okDest && srcDate .Sub (destDate ).Abs () <= time .Second {
597
+ dest [i ] = srcDate
598
+ }
599
+ }
600
+ }
601
+
602
+ func getSrcRow (ctx context.Context , db * sql.DB , query string ) ([]interface {}, []* sql.ColumnType , error ) {
603
+ rows , err := db .QueryContext (ctx , query )
604
+ if err != nil {
605
+ return nil , nil , err
606
+ }
607
+ defer rows .Close ()
608
+ columnTypes , err := rows .ColumnTypes ()
609
+ if err != nil {
610
+ return nil , nil , err
611
+ }
612
+ values , err := readRow (rows , columnTypes )
613
+ return values , columnTypes , err
614
+ }
615
+
616
+ func getDestRow (ctx context.Context , db * sql.DB , query string , columnTypes []* sql.ColumnType ) ([]interface {}, error ) {
617
+ rows , err := db .QueryContext (ctx , query )
618
+ if err != nil {
619
+ return nil , err
620
+ }
621
+ defer rows .Close ()
622
+ return readRow (rows , columnTypes )
623
+ }
624
+
625
+ func readRow (rows * sql.Rows , columnTypes []* sql.ColumnType ) ([]interface {}, error ) {
626
+ if ! rows .Next () {
627
+ return nil , errors .New ("exactly one row expected but got 0" )
628
+ }
629
+ // some DB drivers don't handle reading into *any well so use *reportedType instead
630
+ values := make ([]interface {}, len (columnTypes ))
631
+ for i := 0 ; i < len (columnTypes ); i ++ {
632
+ values [i ] = reflect .New (columnTypes [i ].ScanType ()).Interface ()
633
+ }
634
+ err := rows .Scan (values ... )
635
+ if err != nil {
636
+ return nil , err
637
+ }
638
+ if rows .Next () {
639
+ return nil , errors .New ("exactly one row expected but more found" )
640
+ }
641
+ // dereference the pointers
642
+ for i , v := range values {
643
+ values [i ] = reflect .ValueOf (v ).Elem ().Interface ()
644
+ }
645
+ return values , nil
646
+ }
647
+
541
648
// filesEqual compares the files at paths a and b and returns an error if
542
649
// the content is not equal. Ignore is a regex. All matches will be removed
543
650
// from the file contents before comparison.
0 commit comments