diff --git a/conncheck.go b/conncheck.go index 0ea721720..35ff1eac6 100644 --- a/conncheck.go +++ b/conncheck.go @@ -12,17 +12,14 @@ package mysql import ( - "errors" - "io" + "fmt" "net" "syscall" -) -var errUnexpectedRead = errors.New("unexpected read from socket") + "golang.org/x/sys/unix" +) func connCheck(conn net.Conn) error { - var sysErr error - sysConn, ok := conn.(syscall.Conn) if !ok { return nil @@ -32,24 +29,22 @@ func connCheck(conn net.Conn) error { return err } - err = rawConn.Read(func(fd uintptr) bool { - var buf [1]byte - n, err := syscall.Read(int(fd), buf[:]) - switch { - case n == 0 && err == nil: - sysErr = io.EOF - case n > 0: - sysErr = errUnexpectedRead - case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: - sysErr = nil - default: - sysErr = err + var pollErr error + err = rawConn.Control(func(fd uintptr) { + fds := []unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN | unix.POLLERR}, + } + n, err := unix.Poll(fds, 0) + if err != nil { + pollErr = fmt.Errorf("poll: %w", err) + } + if n > 0 { + // fmt.Errorf("poll: %v", fds[0].Revents) + pollErr = errUnexpectedEvent } - return true }) if err != nil { return err } - - return sysErr + return pollErr } diff --git a/go.mod b/go.mod index 77bbb8dbf..2d95971d5 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/go-sql-driver/mysql go 1.18 + +require golang.org/x/sys v0.10.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..55a3ff26d --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/packets.go b/packets.go index 66635c55b..ad757a346 100644 --- a/packets.go +++ b/packets.go @@ -14,6 +14,7 @@ import ( "database/sql/driver" "encoding/binary" "encoding/json" + "errors" "fmt" "io" "math" @@ -44,12 +45,24 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // check packet sync [8 bit] if data[3] != mc.sequence { + var syncErr error if data[3] > mc.sequence { - return nil, ErrPktSyncMul + syncErr = ErrPktSyncMul + } else { + syncErr = ErrPktSync } - return nil, ErrPktSync + + if prevData != nil { + return nil, syncErr + } else { + // log and ignore seqno mismatch error. + // MySQL sometimes sends wrong sequence no. + mc.cfg.Logger.Print(syncErr) + mc.sequence = data[3] + 1 + } + } else { + mc.sequence++ } - mc.sequence++ // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long @@ -89,6 +102,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } } +// used in conncheck.go +var errUnexpectedEvent = errors.New("recieved unexpected event") + // Write packet buffer 'data' func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 @@ -111,18 +127,29 @@ func (mc *mysqlConn) writePacket(data []byte) error { } var err error if mc.cfg.CheckConnLiveness { - if mc.cfg.ReadTimeout != 0 { - err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) - } - if err == nil { - err = connCheck(conn) + err = connCheck(conn) + if err != nil { + if err == errUnexpectedEvent { + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) + var data []byte + data, err = mc.readPacket() + + if err == nil { + if data[0] == iERR { + err = mc.handleErrorPacket(data) + } else { + err = fmt.Errorf("unexpected packet: % x", data[:128]) + } + } else { + err = fmt.Errorf("readPacket(): %w", err) + } + } + + mc.cfg.Logger.Print("checkConn() failed: ", err) + mc.Close() + return driver.ErrBadConn } } - if err != nil { - mc.cfg.Logger.Print("closing bad idle connection: ", err) - mc.Close() - return driver.ErrBadConn - } } for { diff --git a/packets_test.go b/packets_test.go index f429087e9..5a6882fcc 100644 --- a/packets_test.go +++ b/packets_test.go @@ -11,6 +11,7 @@ package mysql import ( "bytes" "errors" + "fmt" "net" "testing" "time" @@ -132,31 +133,57 @@ func TestReadPacketSingleByte(t *testing.T) { } } +type mockLogger struct { + bytes.Buffer +} + +func (ml *mockLogger) Print(v ...any) { + ml.WriteString(fmt.Sprint(v...) + "\n") +} + func TestReadPacketWrongSequenceID(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ buf: newBuffer(conn), + cfg: NewConfig(), } + logger := &mockLogger{} + mc.cfg.Logger = Logger(logger) // too low sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 mc.sequence = 1 - _, err := mc.readPacket() - if err != ErrPktSync { - t.Errorf("expected ErrPktSync, got %v", err) + data, err := mc.readPacket() + if err != nil { + t.Errorf("expected nil, got %v", err) + } + if len(data) != 1 || data[0] != 0xff { + t.Errorf("expected [0xff], got % x", data) + } + logMsg := logger.String() + if logMsg != ErrPktSync.Error()+"\n" { + t.Errorf("expected ErrPktSync.Error(), got %q", logMsg) } // reset conn.reads = 0 mc.sequence = 0 mc.buf = newBuffer(conn) + logger.Reset() // too high sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} - _, err = mc.readPacket() - if err != ErrPktSyncMul { - t.Errorf("expected ErrPktSyncMul, got %v", err) + data, err = mc.readPacket() + if err != nil { + t.Errorf("expected nil, got %v", err) + } + if len(data) != 1 || data[0] != 0xff { + t.Errorf("expected [0xff], got % x", data) + } + logMsg = logger.String() + if logMsg != ErrPktSyncMul.Error()+"\n" { + t.Errorf("expected ErrPktSync.Error(), got %q", logMsg) } }