9
9
"crypto/sha1"
10
10
"crypto/x509"
11
11
"encoding/pem"
12
+ goErrors "errors"
12
13
"io"
13
14
"net"
14
15
"sync"
@@ -108,28 +109,10 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
108
109
109
110
if c .Compression != MYSQL_COMPRESS_NONE {
110
111
if ! c .compressedReaderActive {
111
- if _ , err := io .ReadFull (c .reader , c .compressedHeader [:7 ]); err != nil {
112
- return nil , errors .Wrapf (ErrBadConn , "io.ReadFull(compressedHeader) failed. err %v" , err )
113
- }
114
-
115
- compressedSequence := c .compressedHeader [3 ]
116
- uncompressedLength := int (uint32 (c .compressedHeader [4 ]) | uint32 (c .compressedHeader [5 ])<< 8 | uint32 (c .compressedHeader [6 ])<< 16 )
117
- if compressedSequence != c .CompressedSequence {
118
- return nil , errors .Errorf ("invalid compressed sequence %d != %d" ,
119
- compressedSequence , c .CompressedSequence )
120
- }
121
-
122
- if uncompressedLength > 0 {
123
- var err error
124
- switch c .Compression {
125
- case MYSQL_COMPRESS_ZLIB :
126
- c .compressedReader , err = zlib .NewReader (c .reader )
127
- case MYSQL_COMPRESS_ZSTD :
128
- c .compressedReader , err = zstd .NewReader (c .reader )
129
- }
130
- if err != nil {
131
- return nil , err
132
- }
112
+ var err error
113
+ c .compressedReader , err = c .newCompressedPacketReader ()
114
+ if err != nil {
115
+ return nil , err
133
116
}
134
117
c .compressedReaderActive = true
135
118
}
@@ -167,6 +150,32 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
167
150
return result , nil
168
151
}
169
152
153
+ func (c * Conn ) newCompressedPacketReader () (io.Reader , error ) {
154
+ if _ , err := io .ReadFull (c .reader , c .compressedHeader [:7 ]); err != nil {
155
+ return nil , errors .Wrapf (ErrBadConn , "io.ReadFull(compressedHeader) failed. err %v" , err )
156
+ }
157
+
158
+ compressedSequence := c .compressedHeader [3 ]
159
+ if compressedSequence != c .CompressedSequence {
160
+ return nil , errors .Errorf ("invalid compressed sequence %d != %d" ,
161
+ compressedSequence , c .CompressedSequence )
162
+ }
163
+
164
+ compressedLength := int (uint32 (c .compressedHeader [0 ]) | uint32 (c .compressedHeader [1 ])<< 8 | uint32 (c .compressedHeader [2 ])<< 16 )
165
+ uncompressedLength := int (uint32 (c .compressedHeader [4 ]) | uint32 (c .compressedHeader [5 ])<< 8 | uint32 (c .compressedHeader [6 ])<< 16 )
166
+ if uncompressedLength > 0 {
167
+ limitedReader := io .LimitReader (c .reader , int64 (compressedLength ))
168
+ switch c .Compression {
169
+ case MYSQL_COMPRESS_ZLIB :
170
+ return zlib .NewReader (limitedReader )
171
+ case MYSQL_COMPRESS_ZSTD :
172
+ return zstd .NewReader (limitedReader )
173
+ }
174
+ }
175
+
176
+ return nil , nil
177
+ }
178
+
170
179
func (c * Conn ) copyN (dst io.Writer , src io.Reader , n int64 ) (written int64 , err error ) {
171
180
for n > 0 {
172
181
bcap := cap (c .copyNBuf )
@@ -175,8 +184,20 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
175
184
}
176
185
buf := c .copyNBuf [:bcap ]
177
186
178
- rd , err := io .ReadAtLeast (src , buf , bcap )
187
+ var rd int
188
+ rd , err = io .ReadAtLeast (src , buf , bcap )
189
+
179
190
n -= int64 (rd )
191
+ // if we've read to EOF and we have compression then advance the sequence number
192
+ // and reset the compressed reader to continue reading the remaining bytes
193
+ // in the next compressed packet.
194
+
195
+ if goErrors .Is (err , io .ErrUnexpectedEOF ) && c .Compression != MYSQL_COMPRESS_NONE && rd < bcap {
196
+ c .CompressedSequence ++
197
+ c .compressedReader , err = c .newCompressedPacketReader ()
198
+ rd , err = io .ReadAtLeast (c .compressedReader , buf [rd :], bcap - rd )
199
+ n -= int64 (rd )
200
+ }
180
201
181
202
if err != nil {
182
203
return written , errors .Trace (err )
0 commit comments