Skip to content

Commit d082741

Browse files
author
dvilaverde
committed
fixing bad connection error when reading large compressed packets
1 parent 8551be2 commit d082741

File tree

1 file changed

+44
-23
lines changed

1 file changed

+44
-23
lines changed

packet/conn.go

+44-23
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"crypto/sha1"
1010
"crypto/x509"
1111
"encoding/pem"
12+
goErrors "errors"
1213
"io"
1314
"net"
1415
"sync"
@@ -108,28 +109,10 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
108109

109110
if c.Compression != MYSQL_COMPRESS_NONE {
110111
if !c.compressedReaderActive {
111-
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
112-
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
113-
}
114-
115-
compressedSequence := c.compressedHeader[3]
116-
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
117-
if compressedSequence != c.CompressedSequence {
118-
return nil, errors.Errorf("invalid compressed sequence %d != %d",
119-
compressedSequence, c.CompressedSequence)
120-
}
121-
122-
if uncompressedLength > 0 {
123-
var err error
124-
switch c.Compression {
125-
case MYSQL_COMPRESS_ZLIB:
126-
c.compressedReader, err = zlib.NewReader(c.reader)
127-
case MYSQL_COMPRESS_ZSTD:
128-
c.compressedReader, err = zstd.NewReader(c.reader)
129-
}
130-
if err != nil {
131-
return nil, err
132-
}
112+
var err error
113+
c.compressedReader, err = c.newCompressedPacketReader()
114+
if err != nil {
115+
return nil, err
133116
}
134117
c.compressedReaderActive = true
135118
}
@@ -167,6 +150,32 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
167150
return result, nil
168151
}
169152

153+
func (c *Conn) newCompressedPacketReader() (io.Reader, error) {
154+
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
155+
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
156+
}
157+
158+
compressedSequence := c.compressedHeader[3]
159+
if compressedSequence != c.CompressedSequence {
160+
return nil, errors.Errorf("invalid compressed sequence %d != %d",
161+
compressedSequence, c.CompressedSequence)
162+
}
163+
164+
compressedLength := int(uint32(c.compressedHeader[0]) | uint32(c.compressedHeader[1])<<8 | uint32(c.compressedHeader[2])<<16)
165+
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
166+
if uncompressedLength > 0 {
167+
limitedReader := io.LimitReader(c.reader, int64(compressedLength))
168+
switch c.Compression {
169+
case MYSQL_COMPRESS_ZLIB:
170+
return zlib.NewReader(limitedReader)
171+
case MYSQL_COMPRESS_ZSTD:
172+
return zstd.NewReader(limitedReader)
173+
}
174+
}
175+
176+
return nil, nil
177+
}
178+
170179
func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
171180
for n > 0 {
172181
bcap := cap(c.copyNBuf)
@@ -175,8 +184,20 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
175184
}
176185
buf := c.copyNBuf[:bcap]
177186

178-
rd, err := io.ReadAtLeast(src, buf, bcap)
187+
var rd int
188+
rd, err = io.ReadAtLeast(src, buf, bcap)
189+
179190
n -= int64(rd)
191+
// if we've read to EOF and we have compression then advance the sequence number
192+
// and reset the compressed reader to continue reading the remaining bytes
193+
// in the next compressed packet.
194+
195+
if goErrors.Is(err, io.ErrUnexpectedEOF) && c.Compression != MYSQL_COMPRESS_NONE && rd < bcap {
196+
c.CompressedSequence++
197+
c.compressedReader, err = c.newCompressedPacketReader()
198+
rd, err = io.ReadAtLeast(c.compressedReader, buf[rd:], bcap-rd)
199+
n -= int64(rd)
200+
}
180201

181202
if err != nil {
182203
return written, errors.Trace(err)

0 commit comments

Comments
 (0)