[in progress] Bulk load for MySQL
murfffi committed Dec 2, 2024
1 parent c3e8cde commit 6f73721
Showing 3 changed files with 130 additions and 4 deletions.
24 changes: 21 additions & 3 deletions drivers/drivers_test.go
@@ -435,6 +435,7 @@ func TestCopy(t *testing.T) {

testCases := []struct {
dbName string
+ testCase string
setupQueries []setupQuery
src string
dest string
@@ -449,7 +450,8 @@ func TestCopy(t *testing.T) {
dest: "staff_copy",
},
{
dbName: "pgsql",
dbName: "pgsql",
testCase: "schemaInDest",
setupQueries: []setupQuery{
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -467,7 +469,8 @@ func TestCopy(t *testing.T) {
dest: "staff_copy",
},
{
dbName: "pgx",
dbName: "pgx",
testCase: "schemaInDest",
setupQueries: []setupQuery{
{query: "DROP TABLE staff_copy"},
{query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
@@ -484,6 +487,17 @@ func TestCopy(t *testing.T) {
src: "select staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update from staff",
dest: "staff_copy(staff_id, first_name, last_name, address_id, picture, email, store_id, active, username, password, last_update)",
},
+ {
+ dbName: "mysql",
+ testCase: "bulkCopy",
+ setupQueries: []setupQuery{
+ {query: "SET GLOBAL local_infile = ON"},
+ {query: "DROP TABLE staff_copy"},
+ {query: "CREATE TABLE staff_copy AS SELECT * FROM staff WHERE 0=1", check: true},
+ },
+ src: "select staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update from staff",
+ dest: "staff_copy(staff_id, first_name, last_name, address_id, email, store_id, active, username, password, last_update)",
+ },
{
dbName: "sqlserver",
setupQueries: []setupQuery{
@@ -508,7 +522,11 @@
continue
}

- t.Run(test.dbName, func(t *testing.T) {
+ testName := test.dbName
+ if test.testCase != "" {
+ testName += "-" + test.testCase
+ }
+ t.Run(testName, func(t *testing.T) {

// TODO test copy from a different DB, maybe csvq?
// TODO test copy from same DB
108 changes: 108 additions & 0 deletions drivers/mysql/copy.go
@@ -0,0 +1,108 @@
package mysql

import (
"context"
"database/sql"
"encoding/csv"
"fmt"
"io"
"os"
"reflect"
"strings"

"github.com/go-sql-driver/mysql"
"github.com/xo/usql/drivers"
)

func copyRows(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
localInfileSupported := false
row := db.QueryRowContext(ctx, "SELECT @@GLOBAL.local_infile")
err := row.Scan(&localInfileSupported)
if err == nil && localInfileSupported && !hasBlobColumn(rows) {
return bulkCopy(ctx, db, rows, table)
} else {
return drivers.CopyWithInsert(func(int) string { return "?" })(ctx, db, rows, table)
}
}

func bulkCopy(ctx context.Context, db *sql.DB, rows *sql.Rows, table string) (int64, error) {
mysql.RegisterReaderHandler("data", func() io.Reader {
return toCsvReader(rows)
})
defer mysql.DeregisterReaderHandler("data")
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
var cnt int64
res, err := tx.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE %s",
strings.Replace(table, "(", " FIELDS TERMINATED BY ',' (", 1)))
if err != nil {
tx.Rollback()
} else {
err = tx.Commit()
if err == nil {
cnt, err = res.RowsAffected()
}
}
return cnt, err
}

func hasBlobColumn(rows *sql.Rows) bool {
columnTypes, err := rows.ColumnTypes()
if err != nil {
return false
}
for _, ct := range columnTypes {
if ct.DatabaseTypeName() == "BLOB" {
return true
}
}
return false
}

func toCsvReader(rows *sql.Rows) io.Reader {
r, w := io.Pipe()
go writeAsCsv(rows, w)
return r
}

// writeAsCsv writes the rows in a CSV format compatible with LOAD DATA INFILE
func writeAsCsv(rows *sql.Rows, w *io.PipeWriter) {
defer w.Close() // noop if already closed
columnTypes, err := rows.ColumnTypes()
if err != nil {
w.CloseWithError(err)
return
}
values := make([]interface{}, len(columnTypes))
valueRefs := make([]reflect.Value, len(columnTypes))
for i := 0; i < len(columnTypes); i++ {
valueRefs[i] = reflect.New(columnTypes[i].ScanType())
values[i] = valueRefs[i].Interface()
}
record := make([]string, len(values))
csvWriter := csv.NewWriter(io.MultiWriter(w, os.Stdout))
for rows.Next() {
if err = rows.Err(); err != nil {
break
}
err = rows.Scan(values...)
if err != nil {
break
}
for i, valueRef := range valueRefs {
// NB: Does not work for BLOBs. Use regular copy if there are BLOB columns
record[i] = fmt.Sprintf("%v", valueRef.Elem().Interface())
}
err = csvWriter.Write(record)
if err != nil {
break
}
}
if err == nil {
csvWriter.Flush()
err = csvWriter.Error()
}
w.CloseWithError(err) // same as w.Close(), if err is nil
}
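
For context, the bulk path above builds on go-sql-driver/mysql's reader-handler support for LOAD DATA LOCAL INFILE: a named io.Reader is registered with mysql.RegisterReaderHandler and the statement references it through the pseudo-path Reader::<name>. Below is a minimal, self-contained sketch of that mechanism; the DSN, target table, column list, and sample rows are placeholders, and the server-side local_infile variable must be ON, as the new test case arranges.

package main

import (
    "database/sql"
    "fmt"
    "io"
    "log"
    "strings"

    "github.com/go-sql-driver/mysql"
)

func main() {
    // Placeholder DSN; assumes a reachable MySQL server with local_infile enabled.
    db, err := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/sakila")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // Register a named reader; the driver streams it when a statement
    // references the pseudo-path Reader::staff.
    mysql.RegisterReaderHandler("staff", func() io.Reader {
        return strings.NewReader("1,Mike\n2,Jon\n")
    })
    defer mysql.DeregisterReaderHandler("staff")

    // The CSV field order must match the column list, mirroring how bulkCopy
    // rewrites the destination "table(col, ...)" expression.
    res, err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::staff' INTO TABLE staff_copy" +
        " FIELDS TERMINATED BY ',' (staff_id, first_name)")
    if err != nil {
        log.Fatal(err)
    }
    n, _ := res.RowsAffected()
    fmt.Println("rows loaded:", n)
}
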
2 changes: 1 addition & 1 deletion drivers/mysql/mysql.go
@@ -45,7 +45,7 @@ func init() {
NewMetadataWriter: func(db drivers.DB, w io.Writer, opts ...metadata.ReaderOption) metadata.Writer {
return metadata.NewDefaultWriter(mymeta.NewReader(db, opts...))(db, w)
},
- Copy: drivers.CopyWithInsert(func(int) string { return "?" }),
+ Copy: copyRows,
NewCompleter: mymeta.NewCompleter,
}, "memsql", "vitess", "tidb")
}
