Skip to content

Commit 33cb7b8

Browse files
committed
ability to send multiple resultsets from handleQuery
1 parent 0c5789d commit 33cb7b8

File tree

4 files changed

+56
-25
lines changed

4 files changed

+56
-25
lines changed

server/caching_sha2_cache_test.go

+11-6
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (s *cacheTestSuite) onAccept(c *C) {
105105
}
106106

107107
func (s *cacheTestSuite) onConn(conn net.Conn, c *C) {
108-
//co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s})
108+
// co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s})
109109
co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testCacheHandler{s})
110110
c.Assert(err, IsNil)
111111
for {
@@ -137,7 +137,7 @@ func (s *cacheTestSuite) TestCache(c *C) {
137137
t2 := time.Now()
138138

139139
d1 := int(t2.Sub(t1).Nanoseconds() / 1e6)
140-
//log.Debugf("first connection took %d milliseconds", d1)
140+
// log.Debugf("first connection took %d milliseconds", d1)
141141

142142
c.Assert(d1, GreaterEqual, delay)
143143

@@ -154,7 +154,7 @@ func (s *cacheTestSuite) TestCache(c *C) {
154154
t4 := time.Now()
155155

156156
d2 := int(t4.Sub(t3).Nanoseconds() / 1e6)
157-
//log.Debugf("second connection took %d milliseconds", d2)
157+
// log.Debugf("second connection took %d milliseconds", d2)
158158

159159
c.Assert(d2, Less, delay)
160160
if s.db != nil {
@@ -178,7 +178,7 @@ func (h *testCacheHandler) handleQuery(query string, binary bool) (*mysql.Result
178178
case "select":
179179
var r *mysql.Resultset
180180
var err error
181-
//for handle go mysql driver select @@max_allowed_packet
181+
// for handle go mysql driver select @@max_allowed_packet
182182
if strings.Contains(strings.ToLower(query), "max_allowed_packet") {
183183
r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{
184184
{mysql.MaxPayloadLen},
@@ -209,8 +209,13 @@ func (h *testCacheHandler) handleQuery(query string, binary bool) (*mysql.Result
209209
return nil, nil
210210
}
211211

212-
func (h *testCacheHandler) HandleQuery(query string) (*mysql.Result, error) {
213-
return h.handleQuery(query, false)
212+
func (h *testCacheHandler) HandleQuery(query string) ([]*mysql.Result, error) {
213+
res, err := h.handleQuery(query, false)
214+
if err != nil {
215+
return nil, err
216+
}
217+
218+
return []*mysql.Result{res}, nil
214219
}
215220

216221
func (h *testCacheHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {

server/command.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,24 @@ import (
99
)
1010

1111
type Handler interface {
12-
//handle COM_INIT_DB command, you can check whether the dbName is valid, or other.
12+
// handle COM_INIT_DB command, you can check whether the dbName is valid, or other.
1313
UseDB(dbName string) error
14-
//handle COM_QUERY command, like SELECT, INSERT, UPDATE, etc...
15-
//If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the response, otherwise, we will send Result
16-
HandleQuery(query string) (*Result, error)
17-
//handle COM_FILED_LIST command
14+
// handle COM_QUERY command, like SELECT, INSERT, UPDATE, etc...
15+
// If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the response, otherwise, we will send Result
16+
HandleQuery(query string) ([]*Result, error)
17+
// handle COM_FILED_LIST command
1818
HandleFieldList(table string, fieldWildcard string) ([]*Field, error)
19-
//handle COM_STMT_PREPARE, params is the param number for this statement, columns is the column number
20-
//context will be used later for statement execute
19+
// handle COM_STMT_PREPARE, params is the param number for this statement, columns is the column number
20+
// context will be used later for statement execute
2121
HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error)
22-
//handle COM_STMT_EXECUTE, context is the previous one set in prepare
23-
//query is the statement prepare query, and args is the params for this statement
22+
// handle COM_STMT_EXECUTE, context is the previous one set in prepare
23+
// query is the statement prepare query, and args is the params for this statement
2424
HandleStmtExecute(context interface{}, query string, args []interface{}) (*Result, error)
25-
//handle COM_STMT_CLOSE, context is the previous one set in prepare
26-
//this handler has no response
25+
// handle COM_STMT_CLOSE, context is the previous one set in prepare
26+
// this handler has no response
2727
HandleStmtClose(context interface{}) error
28-
//handle any other command that is not currently handled by the library,
29-
//default implementation for this method will return an ER_UNKNOWN_ERROR
28+
// handle any other command that is not currently handled by the library,
29+
// default implementation for this method will return an ER_UNKNOWN_ERROR
3030
HandleOtherCommand(cmd byte, data []byte) error
3131
}
3232

@@ -134,7 +134,7 @@ type EmptyHandler struct {
134134
func (h EmptyHandler) UseDB(dbName string) error {
135135
return nil
136136
}
137-
func (h EmptyHandler) HandleQuery(query string) (*Result, error) {
137+
func (h EmptyHandler) HandleQuery(query string) ([]*Result, error) {
138138
return nil, fmt.Errorf("not supported now")
139139
}
140140

server/resp.go

+20
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,26 @@ func (c *Conn) writeValue(value interface{}) error {
185185
} else {
186186
return c.writeOK(v)
187187
}
188+
case []*Result:
189+
if len(v) == 0 {
190+
return c.writeValue(nil)
191+
}
192+
193+
c.status |= SERVER_MORE_RESULTS_EXISTS
194+
195+
for i, res := range v {
196+
if i == len(v)-1 {
197+
c.status &= ^SERVER_MORE_RESULTS_EXISTS
198+
}
199+
200+
if err := c.writeValue(res); err != nil {
201+
c.status &= ^SERVER_MORE_RESULTS_EXISTS
202+
203+
return err
204+
}
205+
}
206+
207+
return nil
188208
case []*Field:
189209
return c.writeFieldList(v)
190210
case *Stmt:

server/server_test.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func Test(t *testing.T) {
6363
inMemProvider.AddUser(*testUser, *testPassword)
6464

6565
servers := prepareServerConf()
66-
//no TLS
66+
// no TLS
6767
for _, svr := range servers {
6868
Suite(&serverTestSuite{
6969
server: svr,
@@ -135,7 +135,7 @@ func (s *serverTestSuite) onAccept(c *C) {
135135
}
136136

137137
func (s *serverTestSuite) onConn(conn net.Conn, c *C) {
138-
//co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s})
138+
// co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s})
139139
co, err := NewCustomizedConn(conn, s.server, s.credProvider, &testHandler{s})
140140
c.Assert(err, IsNil)
141141
// set SSL if defined
@@ -225,7 +225,7 @@ func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, err
225225
case "select":
226226
var r *mysql.Resultset
227227
var err error
228-
//for handle go mysql driver select @@max_allowed_packet
228+
// for handle go mysql driver select @@max_allowed_packet
229229
if strings.Contains(strings.ToLower(query), "max_allowed_packet") {
230230
r, err = mysql.BuildSimpleResultset([]string{"@@max_allowed_packet"}, [][]interface{}{
231231
{mysql.MaxPayloadLen},
@@ -256,8 +256,14 @@ func (h *testHandler) handleQuery(query string, binary bool) (*mysql.Result, err
256256
return nil, nil
257257
}
258258

259-
func (h *testHandler) HandleQuery(query string) (*mysql.Result, error) {
260-
return h.handleQuery(query, false)
259+
func (h *testHandler) HandleQuery(query string) ([]*mysql.Result, error) {
260+
res, err := h.handleQuery(query, false)
261+
if err != nil {
262+
return nil, err
263+
264+
}
265+
266+
return []*mysql.Result{res}, nil
261267
}
262268

263269
func (h *testHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) {

0 commit comments

Comments
 (0)