Skip to content

Support optional resultset metadata #1150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
@@ -95,6 +95,7 @@ Tan Jinhua <312841925 at qq.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Tzu-Chiao Yeh <su3g4284zo6y7 at gmail.com>
Vladimir Kovpak <cn007b at gmail.com>
Vladyslav Zhelezniak <zhvladi at gmail.com>
Xiangyu Hu <xiangyu.hu at outlook.com>
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -399,6 +399,7 @@ Examples:
* `autocommit=1`: `SET autocommit=1`
* [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'`
* [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'`
* metata=none`](https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_resultset_metadata): `SET resultset_metadata=none` (note that this is only applicable to MySQL 8.0+ versions).


#### Examples
49 changes: 35 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
@@ -21,20 +21,22 @@ import (
)

type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
cfg *Config
maxAllowedPacket int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
optionalResultSetMetadata bool
resultSetMetadata uint8

// for context support (Go 1.8+)
watching bool
@@ -392,6 +394,10 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
}
}

if mc.optionalResultSetMetadata && mc.resultSetMetadata == resultSetMetadataNone {
return mc.readIgnoreColumns(rows, resLen)
}

// Columns
rows.rs.columns, err = mc.readColumns(resLen)
return rows, err
@@ -400,6 +406,21 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
return nil, mc.markBadConn(err)
}

func (mc *mysqlConn) readIgnoreColumns(rows *textRows, resLen int) (*textRows, error) {
data, err := mc.readPacket()
if err != nil {
errLog.Print(err)
return nil, err
}
// Expected an EOF packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
// Set empty columnNames, we will first read these columnNames via rows.Columns().
rows.rs.columnNames = make([]string, resLen)
return rows, nil
}
return nil, ErrOptionalResultSetMetadataPkt
}

// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
15 changes: 15 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ import (
"context"
"database/sql/driver"
"net"
"strings"
)

type connector struct {
@@ -88,6 +89,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
plugin = defaultAuthPlugin
}

// Set the optionalResultSetMetadata ahead to set the client capability flag.
if resultSetMetadata, ok := mc.cfg.Params["resultset_metadata"]; ok {
upperVal := strings.ToUpper(resultSetMetadata)
switch upperVal {
case resultSetMetadataSysVarNone:
mc.optionalResultSetMetadata = true
mc.resultSetMetadata = resultSetMetadataNone
case resultSetMetadataSysVarFull:
mc.optionalResultSetMetadata = true
mc.resultSetMetadata = resultSetMetadataFull
}
// To be consistent with other params, in case the param is passed wrongly still send to MySQL to let the server side rejects it.
}

// Send Client Authentication Packet
authResp, err := mc.auth(authData, plugin)
if err != nil {
14 changes: 14 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@ const (
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
clientOptionalResultSetMetadata
)

const (
@@ -172,3 +173,16 @@ const (
cachingSha2PasswordFastAuthSuccess = 3
cachingSha2PasswordPerformFullAuthentication = 4
)

const (
// One-byte metadata flag
// https://dev.mysql.com/worklog/task/?id=8134
resultSetMetadataNone uint8 = iota
resultSetMetadataFull
)

const (
// ResultSet Metadata system var
resultSetMetadataSysVarNone = "NONE"
resultSetMetadataSysVarFull = "FULL"
)
45 changes: 45 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ var (
prot string
addr string
dbname string
vendor string
dsn string
netAddr string
available bool
@@ -202,6 +203,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
func maybeSkip(t *testing.T, err error, skipErrno uint16) {
mySQLErr, ok := err.(*MySQLError)
if !ok {
errLog.Print("non match")
return
}

@@ -1345,6 +1347,49 @@ func TestFoundRows(t *testing.T) {
})
}

func TestOptionalResultSetMetadata(t *testing.T) {
runTests(t, dsn+"&resultset_metadata=none", func(dbt *DBTest) {
_, err := dbt.db.Exec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
if err == ErrNoOptionalResultMetadataSet {
t.Skip("server does not support resultset metadata")
} else if err != nil {
dbt.Fatal(err)
}
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")

row := dbt.db.QueryRow("SELECT id, data FROM test WHERE id = 1")
id, data := 0, 0
err = row.Scan(&id, &data)
if err != nil {
dbt.Fatal(err)
}

if id != 1 && data != 0 {
dbt.Fatal("invalid result")
}
})
runTests(t, dsn+"&resultset_metadata=full", func(dbt *DBTest) {
_, err := dbt.db.Exec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
if err == ErrNoOptionalResultMetadataSet {
t.Skip("server does not support resultset metadata")
} else if err != nil {
dbt.Fatal(err)
}
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")

row := dbt.db.QueryRow("SELECT id, data FROM test WHERE id = 1")
id, data := 0, 0
err = row.Scan(&id, &data)
if err != nil {
dbt.Fatal(err)
}

if id != 1 && data != 0 {
dbt.Fatal("invalid result")
}
})
}

func TestTLS(t *testing.T) {
tlsTestReq := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
32 changes: 16 additions & 16 deletions dsn.go
Original file line number Diff line number Diff line change
@@ -34,22 +34,22 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
3 changes: 3 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
@@ -44,6 +44,9 @@ var testDSNs = []struct {
}, {
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false},
}, {
"user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false},
}, {
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true},
26 changes: 14 additions & 12 deletions errors.go
Original file line number Diff line number Diff line change
@@ -17,18 +17,20 @@ import (

// Various errors the driver might return. Can change between driver versions.
var (
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication.")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrNativePassword = errors.New("this user requires mysql native password authentication")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
ErrNoOptionalResultMetadataSet = errors.New("requested optional resultset metadata but server does not support")
ErrOptionalResultSetMetadataPkt = errors.New("malformed optional resultset metadata packets")

// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
25 changes: 24 additions & 1 deletion packets.go
Original file line number Diff line number Diff line change
@@ -234,10 +234,18 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 1 + 2
// capability flags (upper 2 bytes) [2 bytes]
upperFlags := clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
mc.flags |= upperFlags << 16
pos += 2
if mc.flags&clientOptionalResultSetMetadata == 0 && mc.optionalResultSetMetadata {
return nil, "", ErrNoOptionalResultMetadataSet
}

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += 1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -300,6 +308,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientFlags |= clientMultiStatements
}

if mc.optionalResultSetMetadata {
clientFlags |= clientOptionalResultSetMetadata
}

// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLen := len(authResp)
@@ -554,6 +566,17 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
return int(num), nil
}

// Sniff one extra byte for resultset metadata if we set capability
// CLIENT_OPTIONAL_RESULTSET_METADTA
// https://dev.mysql.com/worklog/task/?id=8134
if len(data) == 2 && mc.flags&clientOptionalResultSetMetadata != 0 {
// ResultSet metadata flag check
if mc.resultSetMetadata != data[1] {
return 0, ErrOptionalResultSetMetadataPkt
}
return int(num), nil
}

return 0, ErrMalformPkt
}
return 0, err