Skip to content

Commit bee9460

Browse files
authoredMar 17, 2025··
fix(engine): IsRunning returns true when it's ready (#1301)
* added Listener() method for network transport * check Listener() in IsRunning()
1 parent a7c3cf1 commit bee9460

File tree

9 files changed

+81
-79
lines changed

9 files changed

+81
-79
lines changed
 

‎pkg/app/client/client_test.go

+1-16
Original file line numberDiff line numberDiff line change
@@ -111,25 +111,10 @@ func nextUnixSock() string {
111111
func waitEngineRunning(e *route.Engine) {
112112
for i := 0; i < 100; i++ {
113113
if e.IsRunning() {
114-
break
114+
return
115115
}
116116
time.Sleep(10 * time.Millisecond)
117117
}
118-
opts := e.GetOptions()
119-
network, addr := opts.Network, opts.Addr
120-
if network == "" {
121-
network = "tcp"
122-
}
123-
for i := 0; i < 100; i++ {
124-
conn, err := net.Dial(network, addr)
125-
if err != nil {
126-
time.Sleep(10 * time.Millisecond)
127-
continue
128-
}
129-
conn.Close()
130-
return
131-
}
132-
133118
panic("not running")
134119
}
135120

‎pkg/app/server/hertz_test.go

+1-22
Original file line numberDiff line numberDiff line change
@@ -53,30 +53,15 @@ import (
5353

5454
type routeEngine interface {
5555
IsRunning() bool
56-
GetOptions() *config.Options
5756
}
5857

5958
func waitEngineRunning(e routeEngine) {
6059
for i := 0; i < 100; i++ {
6160
if e.IsRunning() {
62-
break
61+
return
6362
}
6463
time.Sleep(10 * time.Millisecond)
6564
}
66-
opts := e.GetOptions()
67-
network, addr := opts.Network, opts.Addr
68-
if network == "" {
69-
network = "tcp"
70-
}
71-
for i := 0; i < 100; i++ {
72-
conn, err := net.Dial(network, addr)
73-
if err != nil {
74-
time.Sleep(10 * time.Millisecond)
75-
continue
76-
}
77-
conn.Close()
78-
return
79-
}
8065
panic("not running")
8166
}
8267

@@ -840,15 +825,9 @@ type CloseWithoutResetBuffer interface {
840825
}
841826

842827
func TestOnprepare(t *testing.T) {
843-
n := int32(0)
844828
h1 := New(
845829
WithHostPorts("localhost:9333"),
846830
WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
847-
if atomic.AddInt32(&n, 1) == 1 {
848-
// the 1st connection is from waitEngineRunning
849-
conn.Close()
850-
return ctx
851-
}
852831
b, err := conn.Peek(3)
853832
assert.Nil(t, err)
854833
assert.DeepEqual(t, string(b), "GET")

‎pkg/network/netpoll/dial_test.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,17 @@ import (
2727
"github.com/cloudwego/hertz/pkg/common/config"
2828
"github.com/cloudwego/hertz/pkg/common/test/assert"
2929
"github.com/cloudwego/hertz/pkg/common/test/mock"
30+
"github.com/cloudwego/hertz/pkg/network"
3031
)
3132

33+
func getListenerAddr(trans network.Transporter) string {
34+
return trans.(*transporter).Listener().Addr().String()
35+
}
36+
3237
func TestDial(t *testing.T) {
3338
t.Run("NetpollDial", func(t *testing.T) {
3439
const nw = "tcp"
35-
const addr = "localhost:10100"
40+
var addr = "127.0.0.1:0"
3641
transporter := NewTransporter(&config.Options{
3742
Addr: addr,
3843
Network: nw,
@@ -48,6 +53,7 @@ func TestDial(t *testing.T) {
4853
_, err := dial.DialConnection("tcp", "localhost:10101", time.Second, nil) // wrong addr
4954
assert.NotNil(t, err)
5055

56+
addr = getListenerAddr(transporter)
5157
nwConn, err := dial.DialConnection(nw, addr, time.Second, nil)
5258
assert.Nil(t, err)
5359
defer nwConn.Close()

‎pkg/network/netpoll/transport.go

+26-18
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,19 @@ func cancelContext(ctx context.Context) context.Context {
4747
}
4848

4949
type transporter struct {
50-
sync.RWMutex
5150
senseClientDisconnection bool
5251
network string
5352
addr string
5453
keepAliveTimeout time.Duration
5554
readTimeout time.Duration
5655
writeTimeout time.Duration
57-
listener net.Listener
58-
eventLoop netpoll.EventLoop
5956
listenConfig *net.ListenConfig
6057
OnAccept func(conn net.Conn) context.Context
6158
OnConnect func(ctx context.Context, conn network.Conn) context.Context
59+
60+
mu sync.RWMutex
61+
ln net.Listener
62+
el netpoll.EventLoop
6263
}
6364

6465
// For transporter switch
@@ -70,23 +71,31 @@ func NewTransporter(options *config.Options) network.Transporter {
7071
keepAliveTimeout: options.KeepAliveTimeout,
7172
readTimeout: options.ReadTimeout,
7273
writeTimeout: options.WriteTimeout,
73-
listener: nil,
74-
eventLoop: nil,
7574
listenConfig: options.ListenConfig,
7675
OnAccept: options.OnAccept,
7776
OnConnect: options.OnConnect,
7877
}
7978
}
8079

80+
func (t *transporter) Listener() net.Listener {
81+
t.mu.RLock()
82+
defer t.mu.RUnlock()
83+
return t.ln
84+
}
85+
8186
// ListenAndServe binds listen address and keep serving, until an error occurs
8287
// or the transport shutdowns
8388
func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
8489
network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
90+
91+
t.mu.Lock()
8592
if t.listenConfig != nil {
86-
t.listener, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
93+
t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
8794
} else {
88-
t.listener, err = net.Listen(t.network, t.addr)
95+
t.ln, err = net.Listen(t.network, t.addr)
8996
}
97+
ln := t.ln
98+
t.mu.Unlock()
9099

91100
if err != nil {
92101
panic("create netpoll listener fail: " + err.Error())
@@ -127,20 +136,19 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
127136
}
128137

129138
// Create EventLoop
130-
t.Lock()
131-
t.eventLoop, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error {
139+
t.mu.Lock()
140+
t.el, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error {
132141
return onReq(ctx, newConn(connection))
133142
}, opts...)
134-
t.Unlock()
143+
eventLoop := t.el
144+
t.mu.Unlock()
135145
if err != nil {
136146
panic("create netpoll event-loop fail")
137147
}
138148

139149
// Start Server
140-
hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.listener.Addr().String())
141-
t.RLock()
142-
err = t.eventLoop.Serve(t.listener)
143-
t.RUnlock()
150+
hlog.SystemLogger().Infof("HTTP server listening on address=%s", ln.Addr().String())
151+
err = eventLoop.Serve(ln)
144152
if err != nil {
145153
panic("netpoll server exit")
146154
}
@@ -160,11 +168,11 @@ func (t *transporter) Close() error {
160168
func (t *transporter) Shutdown(ctx context.Context) error {
161169
defer func() {
162170
network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
163-
t.RUnlock()
171+
t.mu.RUnlock()
164172
}()
165-
t.RLock()
166-
if t.eventLoop == nil {
173+
t.mu.RLock()
174+
if t.el == nil {
167175
return nil
168176
}
169-
return t.eventLoop.Shutdown(ctx)
177+
return t.el.Shutdown(ctx)
170178
}

‎pkg/network/netpoll/transport_test.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ import (
3434

3535
func TestTransport(t *testing.T) {
3636
const nw = "tcp"
37-
const addr = "localhost:10103"
3837
t.Run("TestDefault", func(t *testing.T) {
38+
var addr = "127.0.0.1:0"
3939
var onConnFlag, onAcceptFlag, onDataFlag int32
4040
transporter := NewTransporter(&config.Options{
4141
Addr: addr,
@@ -57,6 +57,8 @@ func TestTransport(t *testing.T) {
5757
defer transporter.Close()
5858
time.Sleep(100 * time.Millisecond)
5959

60+
addr = getListenerAddr(transporter)
61+
6062
dial := NewDialer()
6163
conn, err := dial.DialConnection(nw, addr, time.Second, nil)
6264
assert.Nil(t, err)
@@ -70,6 +72,7 @@ func TestTransport(t *testing.T) {
7072
})
7173

7274
t.Run("TestSenseClientDisconnection", func(t *testing.T) {
75+
var addr = "127.0.0.1:0"
7376
var onReqFlag int32
7477
transporter := NewTransporter(&config.Options{
7578
Addr: addr,
@@ -86,6 +89,8 @@ func TestTransport(t *testing.T) {
8689
defer transporter.Close()
8790
time.Sleep(100 * time.Millisecond)
8891

92+
addr = getListenerAddr(transporter)
93+
8994
dial := NewDialer()
9095
conn, err := dial.DialConnection(nw, addr, time.Second, nil)
9196
assert.Nil(t, err)
@@ -106,7 +111,7 @@ func TestTransport(t *testing.T) {
106111
})
107112
}}
108113
transporter := NewTransporter(&config.Options{
109-
Addr: addr,
114+
Addr: "127.0.0.1:0",
110115
Network: nw,
111116
ListenConfig: listenCfg,
112117
})

‎pkg/network/standard/dial_test.go

+14-6
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,16 @@ import (
3131

3232
"github.com/cloudwego/hertz/pkg/common/config"
3333
"github.com/cloudwego/hertz/pkg/common/test/assert"
34+
"github.com/cloudwego/hertz/pkg/network"
3435
)
3536

37+
func getListenerAddr(trans network.Transporter) string {
38+
return trans.(*transport).Listener().Addr().String()
39+
}
40+
3641
func TestDial(t *testing.T) {
3742
const nw = "tcp"
38-
const addr = "localhost:10104"
43+
var addr = "127.0.0.1:0"
3944
transporter := NewTransporter(&config.Options{
4045
Addr: addr,
4146
Network: nw,
@@ -47,6 +52,8 @@ func TestDial(t *testing.T) {
4752
defer transporter.Close()
4853
time.Sleep(time.Millisecond * 100)
4954

55+
addr = getListenerAddr(transporter)
56+
5057
dial := NewDialer()
5158
_, err := dial.DialConnection(nw, addr, time.Second, nil)
5259
assert.Nil(t, err)
@@ -58,9 +65,9 @@ func TestDial(t *testing.T) {
5865

5966
func TestDialTLS(t *testing.T) {
6067
const nw = "tcp"
61-
const addr = "localhost:10105"
68+
var addr = "127.0.0.1:0"
6269
data := []byte("abcdefg")
63-
listened := make(chan struct{})
70+
listened := make(chan net.Listener)
6471
go func() {
6572
mockTLSServe(nw, addr, func(conn net.Conn) {
6673
defer conn.Close()
@@ -70,7 +77,8 @@ func TestDialTLS(t *testing.T) {
7077
}()
7178

7279
select {
73-
case <-listened:
80+
case ln := <-listened:
81+
addr = ln.Addr().String()
7482
case <-time.After(time.Second * 5):
7583
t.Fatalf("timeout")
7684
}
@@ -99,7 +107,7 @@ func TestDialTLS(t *testing.T) {
99107
assert.DeepEqual(t, string(data), string(buf))
100108
}
101109

102-
func mockTLSServe(nw, addr string, handle func(conn net.Conn), listened chan struct{}) (err error) {
110+
func mockTLSServe(nw, addr string, handle func(conn net.Conn), listened chan net.Listener) (err error) {
103111
certData, keyData, err := generateTestCertificate("")
104112
if err != nil {
105113
return
@@ -119,7 +127,7 @@ func mockTLSServe(nw, addr string, handle func(conn net.Conn), listened chan str
119127
}
120128
defer ln.Close()
121129

122-
listened <- struct{}{}
130+
listened <- ln
123131
for {
124132
conn, err := ln.Accept()
125133
if err != nil {

‎pkg/network/standard/transport.go

+10-11
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,27 @@ type transport struct {
4444
keepAliveTimeout time.Duration
4545
readTimeout time.Duration
4646
handler network.OnData
47-
ln net.Listener
4847
tls *tls.Config
4948
listenConfig *net.ListenConfig
50-
lock sync.Mutex
5149
OnAccept func(conn net.Conn) context.Context
5250
OnConnect func(ctx context.Context, conn network.Conn) context.Context
5351

5452
// active connections. it +1 after accept and -1 after handler returns
5553
active int32
54+
55+
mu sync.RWMutex
56+
ln net.Listener
5657
}
5758

58-
func (t *transport) listener() net.Listener {
59-
t.lock.Lock()
60-
defer t.lock.Unlock()
59+
func (t *transport) Listener() net.Listener {
60+
t.mu.RLock()
61+
defer t.mu.RUnlock()
6162
return t.ln
6263
}
6364

6465
func (t *transport) serve() (err error) {
6566
network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
66-
t.lock.Lock()
67+
t.mu.Lock()
6768
if t.listenConfig != nil {
6869
t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
6970
} else {
@@ -72,7 +73,7 @@ func (t *transport) serve() (err error) {
7273
// fix concurrency issue
7374
// normally listener must not be changed during serve()
7475
ln := t.ln
75-
t.lock.Unlock()
76+
t.mu.Unlock()
7677
if err != nil {
7778
return err
7879
}
@@ -133,11 +134,9 @@ func (t *transport) Shutdown(ctx context.Context) error {
133134
defer func() {
134135
network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
135136
}()
136-
t.lock.Lock()
137-
if t.ln != nil {
138-
_ = t.ln.Close()
137+
if ln := t.Listener(); ln != nil {
138+
_ = ln.Close()
139139
}
140-
t.lock.Unlock()
141140

142141
tk := time.NewTicker(shutdownTicker)
143142
defer tk.Stop()

‎pkg/network/standard/transport_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ func TestTransporter(t *testing.T) {
5555
<-handlerExit
5656
return nil
5757
})
58-
for trans.listener() == nil { // wait server up
58+
for trans.Listener() == nil { // wait server up
5959
time.Sleep(5 * time.Millisecond)
6060
}
6161

6262
// dial and test
63-
c, err := net.Dial("tcp", trans.listener().Addr().String())
63+
c, err := net.Dial("tcp", trans.Listener().Addr().String())
6464
if err != nil {
6565
t.Fatal(err)
6666
}

‎pkg/route/engine.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import (
4747
"fmt"
4848
"html/template"
4949
"io"
50+
"net"
5051
"path/filepath"
5152
"reflect"
5253
"runtime"
@@ -258,7 +259,18 @@ func (engine *Engine) IsStreamRequestBody() bool {
258259
}
259260

260261
func (engine *Engine) IsRunning() bool {
261-
return atomic.LoadUint32(&engine.status) == statusRunning
262+
if atomic.LoadUint32(&engine.status) != statusRunning {
263+
return false
264+
}
265+
// double check listener
266+
type ListenerIface interface {
267+
Listener() net.Listener
268+
}
269+
v, ok := engine.transport.(ListenerIface)
270+
if ok {
271+
return v.Listener() != nil
272+
}
273+
return true // default behavior if no ListenerIface
262274
}
263275

264276
func (engine *Engine) HijackConnHandle(c network.Conn, h app.HijackHandler) {

0 commit comments

Comments
 (0)
Please sign in to comment.