Skip to content

Commit c34acf9

Browse files
nemithtz70s
authored andcommitted
Add support for OK packets representing EOF
Fixes: #805
1 parent 46351a8 commit c34acf9

File tree

2 files changed

+84
-36
lines changed

2 files changed

+84
-36
lines changed

connection.go

+11-14
Original file line numberDiff line numberDiff line change
@@ -180,16 +180,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
180180

181181
// Read Result
182182
columnCount, err := stmt.readPrepareResultPacket()
183-
if err == nil {
184-
if stmt.paramCount > 0 {
185-
if err = mc.readUntilEOF(); err != nil {
186-
return nil, err
187-
}
188-
}
183+
if err != nil {
184+
return stmt, err
185+
}
189186

190-
if columnCount > 0 {
191-
err = mc.readUntilEOF()
192-
}
187+
if err := mc.readPackets(stmt.paramCount); err != nil {
188+
return nil, err
189+
}
190+
191+
if err := mc.readPackets(int(columnCount)); err != nil {
192+
return nil, err
193193
}
194194

195195
return stmt, err
@@ -415,11 +415,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
415415
rows.mc = mc
416416
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
417417

418-
if resLen > 0 {
419-
// Columns
420-
if err := mc.readUntilEOF(); err != nil {
421-
return nil, err
422-
}
418+
if err := mc.readPackets(resLen); err != nil {
419+
return nil, err
423420
}
424421

425422
dest := make([]driver.Value, resLen)

packets.go

+73-22
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
235235
if len(data) > pos {
236236
// character set [1 byte]
237237
// status flags [2 bytes]
238+
pos += 1 + 2
239+
238240
// capability flags (upper 2 bytes) [2 bytes]
241+
mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
242+
pos += 2
243+
239244
// length of auth-plugin-data [1 byte]
240245
// reserved (all [00]) [10 bytes]
241-
pos += 1 + 2 + 2 + 1 + 10
246+
pos += +1 + 10
242247

243248
// second part of the password cipher [mininum 13 bytes],
244249
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -286,6 +291,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
286291
clientLocalFiles |
287292
clientPluginAuth |
288293
clientMultiResults |
294+
mc.flags&clientDeprecateEOF |
289295
mc.flags&clientLongFlag
290296

291297
if mc.cfg.ClientFoundRows {
@@ -610,18 +616,19 @@ func readStatus(b []byte) statusFlag {
610616
// Ok Packet
611617
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
612618
func (mc *mysqlConn) handleOkPacket(data []byte) error {
613-
var n, m int
614-
615-
// 0x00 [1 byte]
616-
619+
// 0x00 or 0xFE [1 byte]
620+
n := 1
621+
var l int
617622
// Affected rows [Length Coded Binary]
618-
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
623+
mc.affectedRows, _, l = readLengthEncodedInteger(data[n:])
624+
n += l
619625

620626
// Insert id [Length Coded Binary]
621-
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
627+
mc.insertId, _, l = readLengthEncodedInteger(data[n:])
628+
n += l
622629

623630
// server_status [2 bytes]
624-
mc.status = readStatus(data[1+n+m : 1+n+m+2])
631+
mc.status = readStatus(data[n : n+2])
625632
if mc.status&statusMoreResultsExists != 0 {
626633
return nil
627634
}
@@ -631,19 +638,24 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
631638
return nil
632639
}
633640

641+
// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet
642+
// acting as an EOF.
643+
func isEOFPacket(data []byte) bool {
644+
return data[0] == iEOF && len(data) < 9
645+
}
646+
634647
// Read Packets as Field Packets until EOF-Packet or an Error appears
635648
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
636649
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
637650
columns := make([]mysqlField, count)
638651

639-
for i := 0; ; i++ {
652+
for i := 0; i < count; i++ {
640653
data, err := mc.readPacket()
641654
if err != nil {
642655
return nil, err
643656
}
644657

645-
// EOF Packet
646-
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
658+
if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) {
647659
if i == count {
648660
return columns, nil
649661
}
@@ -729,9 +741,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
729741
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
730742
//}
731743
}
744+
return columns, nil
732745
}
733746

734-
// Read Packets as Field Packets until EOF-Packet or an Error appears
747+
// Read Packets as Field Packets until EOF/OK-Packet or an Error appears
735748
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
736749
func (rows *textRows) readRow(dest []driver.Value) error {
737750
mc := rows.mc
@@ -746,9 +759,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
746759
}
747760

748761
// EOF Packet
749-
if data[0] == iEOF && len(data) == 5 {
750-
// server_status [2 bytes]
751-
rows.mc.status = readStatus(data[3:])
762+
if isEOFPacket(data) {
763+
if mc.flags&clientDeprecateEOF == 0 {
764+
// server_status [2 bytes]
765+
rows.mc.status = readStatus(data[3:])
766+
} else {
767+
if err := mc.handleOkPacket(data); err != nil {
768+
return err
769+
}
770+
}
752771
rows.rs.done = true
753772
if !rows.HasNextResultSet() {
754773
rows.mc = nil
@@ -808,18 +827,44 @@ func (mc *mysqlConn) readUntilEOF() error {
808827
return err
809828
}
810829

811-
switch data[0] {
812-
case iERR:
830+
switch {
831+
case data[0] == iERR:
813832
return mc.handleErrorPacket(data)
814-
case iEOF:
815-
if len(data) == 5 {
833+
case isEOFPacket(data):
834+
if mc.flags&clientDeprecateEOF == 0 {
816835
mc.status = readStatus(data[3:])
836+
} else {
837+
return mc.handleOkPacket(data)
817838
}
818839
return nil
819840
}
820841
}
821842
}
822843

844+
func (mc *mysqlConn) readPackets(num int) error {
845+
846+
// we need to read EOF as well
847+
if mc.flags&clientDeprecateEOF == 0 {
848+
num++
849+
}
850+
851+
for i := 0; i < num; i++ {
852+
data, err := mc.readPacket()
853+
if err != nil {
854+
return err
855+
}
856+
857+
switch {
858+
case data[0] == iERR:
859+
return mc.handleErrorPacket(data)
860+
case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data):
861+
mc.status = readStatus(data[3:])
862+
return nil
863+
}
864+
}
865+
return nil
866+
}
867+
823868
/******************************************************************************
824869
* Prepared Statements *
825870
******************************************************************************/
@@ -1178,15 +1223,21 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
11781223

11791224
// packet indicator [1 byte]
11801225
if data[0] != iOK {
1181-
// EOF Packet
1182-
if data[0] == iEOF && len(data) == 5 {
1183-
rows.mc.status = readStatus(data[3:])
1226+
if isEOFPacket(data) {
1227+
if rows.mc.flags&clientDeprecateEOF == 0 {
1228+
rows.mc.status = readStatus(data[3:])
1229+
} else {
1230+
if err := rows.mc.handleOkPacket(data); err != nil {
1231+
return err
1232+
}
1233+
}
11841234
rows.rs.done = true
11851235
if !rows.HasNextResultSet() {
11861236
rows.mc = nil
11871237
}
11881238
return io.EOF
11891239
}
1240+
11901241
mc := rows.mc
11911242
rows.mc = nil
11921243

0 commit comments

Comments
 (0)