Commit f5174e1

feat: bulk load for MySQL
The PR implements bulk loading for MySQL using the "LOAD DATA from io.Reader" feature of github.com/go-sql-driver/mysql (https://github.com/go-sql-driver/mysql?tab=readme-ov-file#load-data-local-infile-support). As expected, bulk loading this way is significantly faster: 1 million rows in the "staff" table from the test schema are inserted in 15 seconds vs. 120 seconds using INSERT, an 8x improvement.

Note that LOAD DATA LOCAL INFILE is disabled by default on MySQL 8+ servers and must be enabled beforehand with SET GLOBAL local_infile = ON. MySQL doesn't seem to have any remote bulk-loading option that is enabled by default.

The PR also extends TestCopy in drivers_test.go with a comparison of the copied data to ensure MySQL bulk loading is safe across data types.

Testing Done: tests in drivers_test.go
1 parent e0e0807 commit f5174e1
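
For illustration, here is a minimal sketch of the driver feature the commit builds on: register an io.Reader handler with go-sql-driver/mysql and reference it from LOAD DATA LOCAL INFILE via the Reader:: pseudo-path. This is not part of the commit; the DSN, table name, and CSV payload are placeholder assumptions, and the server is assumed to have local_infile enabled.

package main

import (
	"database/sql"
	"io"
	"log"
	"strings"

	"github.com/go-sql-driver/mysql"
)

func main() {
	// Placeholder DSN; the target MySQL server must have been configured with
	// SET GLOBAL local_infile = ON for LOAD DATA LOCAL INFILE to be accepted.
	db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/sakila")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// The handler returns the data stream; the driver reads from it while the
	// statement below executes.
	mysql.RegisterReaderHandler("example", func() io.Reader {
		return strings.NewReader("1,Mike,Hillyer\n2,Jon,Stephens\n")
	})
	defer mysql.DeregisterReaderHandler("example")

	// 'Reader::example' tells the driver to stream from the registered handler
	// instead of opening a file on disk. staff_copy is a hypothetical table.
	res, err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::example' INTO TABLE staff_copy" +
		" FIELDS TERMINATED BY ',' (staff_id, first_name, last_name)")
	if err != nil {
		log.Fatal(err)
	}
	n, _ := res.RowsAffected()
	log.Printf("loaded %d rows", n)
}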

File tree: 3 files changed (+249 -9 lines)

drivers/drivers_test.go

+115 -8
@@ -7,11 +7,13 @@ import (
 	"bytes"
 	"context"
 	"database/sql"
+	"errors"
 	"flag"
 	"fmt"
 	"log"
 	"net/url"
 	"os"
+	"reflect"
 	"regexp"
 	"strings"
 	"testing"
@@ -435,9 +437,11 @@ func TestCopy(t *testing.T) {
 
 	testCases := []struct {
 		dbName       string
+		testCase     string
 		setupQueries []setupQuery
 		src          string
 		dest         string
+		destCmpQuery string
 	}{
 		{
 			dbName: "pgsql",
@@ -449,7 +453,8 @@ func TestCopy(t *testing.T) {
 			dest: "staff_copy",
 		},
 		{
-			dbName: "pgsql",
+			dbName:   "pgsql",
+			testCase: "schemaInDest",
 			setupQueries: []setupQuery{
 				{query: "DROP TABLE staff_copy"},
 				{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -466,8 +471,9 @@ func TestCopy(t *testing.T) {
 			src:  "select * from staff",
 			dest: "staff_copy",
 		},
-		{
-			dbName: "pgx",
+		{ // this holds even select iterates over table in a ran
+			dbName:   "pgx",
+			testCase: "schemaInDest",
 			setupQueries: []setupQuery{
 				{query: "DROP TABLE staff_copy"},
 				{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -478,12 +484,22 @@ func TestCopy(t *testing.T) {
 		{
 			dbName: "mysql",
 			setupQueries: []setupQuery{
-				{query: "DROP TABLE staff_copy"},
 				{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
 			},
 			src:  "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
 			dest: "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
 		},
+		{
+			dbName:   "mysql",
+			testCase: "bulkCopy",
+			setupQueries: []setupQuery{
+				{query: "SET GLOBAL local_infile = ON"},
+				{query: "DROP TABLE staff_copy"},
+				{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
+			},
+			src:  "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
+			dest: "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)",
+		},
 		{
 			dbName: "sqlserver",
 			setupQueries: []setupQuery{
@@ -497,9 +513,11 @@ func TestCopy(t *testing.T) {
 			dbName: "csvq",
 			setupQueries: []setupQuery{
 				{query: "CREATE TABLE IF NOT EXISTS staff_copy AS SELECT * FROM `staff.csv` WHERE 0=1", check: true},
+				{query: "DELETE from staff_copy", check: true},
 			},
-			src:  "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
-			dest: "staff_copy",
+			src:          "select first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
+			dest:         "staff_copy",
+			destCmpQuery: "select first_name, last_name, address_id, email, store_id, active, username, password, datetime(last_update) from staff_copy",
 		},
 	}
 	for _, test := range testCases {
@@ -508,7 +526,11 @@ func TestCopy(t *testing.T) {
 			continue
 		}
 
-		t.Run(test.dbName, func(t *testing.T) {
+		testName := test.dbName
+		if test.testCase != "" {
+			testName += "-" + test.testCase
+		}
+		t.Run(testName, func(t *testing.T) {
 
 			// TODO test copy from a different DB, maybe csvq?
 			// TODO test copy from same DB
@@ -524,7 +546,7 @@ func TestCopy(t *testing.T) {
 				t.Fatalf("Could not get rows to copy: %v", err)
 			}
 
-			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			ctx, cancel := context.WithTimeout(context.Background(), 500*time.Second)
 			defer cancel()
 			var rlen int64 = 1
 			n, err := drivers.Copy(ctx, db.URL, nil, nil, rows, test.dest)
@@ -534,10 +556,95 @@ func TestCopy(t *testing.T) {
 			if n != rlen {
 				t.Fatalf("Expected to copy %d rows but got %d", rlen, n)
 			}
+
+			checkSameData(t, ctx, pg.DB, test.src, db.DB, test.destCmpQuery)
 		})
 	}
 }
 
+// checkSameData fails the test if the data in the srcDB."staff" table differs from the destDB."staff_copy" table
+func checkSameData(t *testing.T, ctx context.Context, srcDB *sql.DB, srcQuery string, destDB *sql.DB, destCmpQuery string) {
+	if destCmpQuery == "" {
+		srcQuery = strings.ToLower(srcQuery)
+		if !strings.Contains(srcQuery, "from staff") {
+			t.Fatalf("destCmpQuery needs to be configured if src '%s' is not for table 'staff'", srcQuery)
+		}
+		// if destCmpQuery needs special syntax, configure it in the test case definitions above
+		destCmpQuery = strings.Replace(srcQuery, "from staff", "from staff_copy", 1)
+	}
+	srcValues, srcColumnTypes, err := getSrcRow(ctx, srcDB, srcQuery)
+	if err != nil {
+		t.Fatalf("Could not get src row from database: %v", err)
+	}
+	destValues, err := getDestRow(ctx, destDB, destCmpQuery, srcColumnTypes)
+	if err != nil {
+		t.Fatalf("Could not get dest row from database: %v", err)
+	}
+	// Comparing more than 1 row is more complex because SELECT result order is undefined without order by
+	adjustDates(srcValues, destValues)
+	if !reflect.DeepEqual(srcValues, destValues) {
+		t.Fatalf("Source and dest row don't match: \n%v\n vs \n%v", srcValues, destValues)
+	}
+}
+
+// adjustDates removes sub-second differences between any dates in the two rows, because
+// the differences are likely caused by a difference in precision and not by a copy issue
+func adjustDates(src []interface{}, dest []interface{}) {
+	for i, v := range src {
+		srcDate, okSrc := v.(time.Time)
+		destDate, okDest := dest[i].(time.Time)
+		if okSrc && okDest && srcDate.Sub(destDate).Abs() <= time.Second {
+			dest[i] = srcDate
+		}
+	}
+}
+
+func getSrcRow(ctx context.Context, db *sql.DB, query string) ([]interface{}, []*sql.ColumnType, error) {
+	rows, err := db.QueryContext(ctx, query)
+	if err != nil {
+		return nil, nil, err
+	}
+	defer rows.Close()
+	columnTypes, err := rows.ColumnTypes()
+	if err != nil {
+		return nil, nil, err
+	}
+	values, err := readRow(rows, columnTypes)
+	return values, columnTypes, err
+}
+
+func getDestRow(ctx context.Context, db *sql.DB, query string, columnTypes []*sql.ColumnType) ([]interface{}, error) {
+	rows, err := db.QueryContext(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	return readRow(rows, columnTypes)
+}
+
+func readRow(rows *sql.Rows, columnTypes []*sql.ColumnType) ([]interface{}, error) {
+	if !rows.Next() {
+		return nil, errors.New("exactly one row expected but got 0")
+	}
+	// some DB drivers don't handle reading into *any well, so use *reportedType instead
+	values := make([]interface{}, len(columnTypes))
+	for i := 0; i < len(columnTypes); i++ {
+		values[i] = reflect.New(columnTypes[i].ScanType()).Interface()
+	}
+	err := rows.Scan(values...)
+	if err != nil {
+		return nil, err
+	}
+	if rows.Next() {
+		return nil, errors.New("exactly one row expected but more found")
+	}
+	// dereference the pointers
+	for i, v := range values {
+		values[i] = reflect.ValueOf(v).Elem().Interface()
+	}
+	return values, nil
+}
+
 // filesEqual compares the files at paths a and b and returns an error if
 // the content is not equal. Ignore is a regex. All matches will be removed
 // from the file contents before comparison.

drivers/mysql/copy.go

+133
@@ -0,0 +1,133 @@
+package mysql
+
+import (
+	"context"
+	"database/sql"
+	"encoding/csv"
+	"fmt"
+	"io"
+	"reflect"
+	"strings"
+
+	"github.com/go-sql-driver/mysql"
+	"github.com/xo/usql/drivers"
+)
+
+func copyRows(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
+	localInfileSupported := false
+	row := db.QueryRowContext(ctx, "SELECT @@GLOBAL.local_infile")
+	err := row.Scan(&localInfileSupported)
+	if err == nil && localInfileSupported && !hasBlobColumn(rows) {
+		return bulkCopy(ctx, db, rows, table)
+	} else {
+		return drivers.CopyWithInsert(func(int) string { return "?" })(ctx, db, rows, table)
+	}
+}
+
+func bulkCopy(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
+	mysql.RegisterReaderHandler("data", func() io.Reader {
+		return toCsvReader(rows)
+	})
+	defer mysql.DeregisterReaderHandler("data")
+	tx, err := db.BeginTx(ctx, nil)
+	if err != nil {
+		return 0, err
+	}
+	var cnt int64
+	csvSpec := " FIELDS TERMINATED BY ',' "
+	stmt := fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s",
+		// if there is a column list, csvSpec goes between the table name and the list
+		strings.Replace(table, "(", csvSpec+" (", 1))
+	// if there wasn't a column list in the table spec, csvSpec goes at the end
+	if !strings.Contains(table, "(") {
+		stmt += csvSpec
+	}
+	res, err := tx.ExecContext(ctx, stmt)
+	if err != nil {
+		tx.Rollback()
+	} else {
+		err = tx.Commit()
+		if err == nil {
+			cnt, err = res.RowsAffected()
+		}
+	}
+	return cnt, err
+}
+
+func hasBlobColumn(rows *sql.Rows) bool {
+	columnTypes, err := rows.ColumnTypes()
+	if err != nil {
+		return false
+	}
+	for _, ct := range columnTypes {
+		if ct.DatabaseTypeName() == "BLOB" {
+			return true
+		}
+	}
+	return false
+}
+
+// toCsvReader converts the rows to CSV, compatible with LOAD DATA, and creates a reader over the CSV
+// as required by the MySQL driver
+func toCsvReader(rows *sql.Rows) io.Reader {
+	r, w := io.Pipe()
+	// Writes to w block until the driver is ready to read data, or the driver closes the reader.
+	// The driver code always closes the reader if it implements io.Closer -
+	// https://github.com/go-sql-driver/mysql/blob/575e1b288d624fb14bf56532689f3ec1c1989149/infile.go#L112
+	// In turn, that guarantees our goroutine will exit and won't leak.
+	go writeAsCsv(rows, w)
+	return r
+}
+
+// writeAsCsv writes the rows in a CSV format compatible with LOAD DATA INFILE
+func writeAsCsv(rows *sql.Rows, w *io.PipeWriter) {
+	defer w.Close() // noop if already closed
+	columnTypes, err := rows.ColumnTypes()
+	if err != nil {
+		w.CloseWithError(err)
+		return
+	}
+	values := make([]interface{}, len(columnTypes))
+	valueRefs := make([]reflect.Value, len(columnTypes))
+	for i := 0; i < len(columnTypes); i++ {
+		valueRefs[i] = reflect.New(columnTypes[i].ScanType())
+		values[i] = valueRefs[i].Interface()
+	}
+	record := make([]string, len(values))
+	csvWriter := csv.NewWriter(w)
+	for rows.Next() {
+		if err = rows.Err(); err != nil {
+			break
+		}
+		err = rows.Scan(values...)
+		if err != nil {
+			break
+		}
+		for i, valueRef := range valueRefs {
+			val := valueRef.Elem().Interface()
+			val = toIntIfBool(val)
+			// NB: There is no nice way to store BLOBs for use in LOAD DATA.
+			// Use regular copy if there are BLOB columns. See fallback code in copyRows.
+			record[i] = fmt.Sprintf("%v", val)
+		}
+		err = csvWriter.Write(record) // may block but not forever, see toCsvReader
+		if err != nil {
+			break
+		}
+	}
+	if err == nil {
+		csvWriter.Flush() // may block but not forever, see toCsvReader
+		err = csvWriter.Error()
+	}
+	w.CloseWithError(err) // same as w.Close(), if err is nil
+}
+
+func toIntIfBool(val interface{}) interface{} {
+	if boolVal, ok := val.(bool); ok {
+		val = 0
+		if boolVal {
+			val = 1
+		}
+	}
+	return val
+}
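
As an aside (not part of the commit), here is a small standalone sketch of how the statement assembly in bulkCopy above behaves for the two destination spec shapes: with a column list the FIELDS clause is inserted between the table name and the list, otherwise it is appended at the end. buildLoadData is a hypothetical helper that mirrors that logic.

package main

import (
	"fmt"
	"strings"
)

// buildLoadData mirrors the statement assembly in bulkCopy (hypothetical helper).
func buildLoadData(table string) string {
	csvSpec := " FIELDS TERMINATED BY ',' "
	stmt := fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s",
		// with a column list, the FIELDS clause goes between the table name and the list
		strings.Replace(table, "(", csvSpec+" (", 1))
	// without a column list, the FIELDS clause goes at the end
	if !strings.Contains(table, "(") {
		stmt += csvSpec
	}
	return stmt
}

func main() {
	// LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE staff_copy FIELDS TERMINATED BY ','  (staff_id, first_name)
	fmt.Println(buildLoadData("staff_copy(staff_id, first_name)"))
	// LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE staff_copy FIELDS TERMINATED BY ','
	fmt.Println(buildLoadData("staff_copy"))
}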

drivers/mysql/mysql.go

+1 -1
@@ -45,7 +45,7 @@ func init() {
 		NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer {
 			return metadata.NewDefaultWriter(mymeta.NewReader(db, opts...))(db, w)
 		},
-		Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
+		Copy: copyRows,
 		NewCompleter: mymeta.NewCompleter,
 	}, "memsql", "vitess", "tidb")
 }

0 commit comments
