Skip to content

Commit 0e70f03

Browse files
committed
import testify,The server needs to be started explicitly
1 parent dd46246 commit 0e70f03

8 files changed

+197
-228
lines changed

benchmark_test.go

+25-25
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func serverClientSetup() *serverClient {
5353

5454
func (setup *serverClient) Close() {
5555
setup.clientTCPHandler.Close()
56-
setup.slave.Close()
56+
setup.slave.Shutdown()
5757
}
5858

5959
func BenchmarkModbusWrite1968MultipleCoils(b *testing.B) {
@@ -145,7 +145,8 @@ func Example() {
145145
log.Printf("%v\n", err)
146146
return
147147
}
148-
defer serv.Close()
148+
defer serv.Shutdown()
149+
go serv.Serve()
149150

150151
// Wait for the server to start
151152
time.Sleep(1 * time.Millisecond)
@@ -179,37 +180,36 @@ func Example() {
179180

180181
// Override the default ReadDiscreteInputs funtion.
181182
func ExampleServer_RegisterFunctionHandler() {
182-
serv := NewServer()
183-
184183
// Override ReadDiscreteInputs function.
185-
serv.RegisterFunctionHandler(2,
186-
func(s *Server, frame Framer) ([]byte, *Exception) {
187-
register, numRegs, endRegister := registerAddressAndNumber(frame)
188-
// Check the request is within the allocated memory
189-
if endRegister > 65535 {
190-
return []byte{}, &IllegalDataAddress
191-
}
192-
dataSize := numRegs / 8
193-
if (numRegs % 8) != 0 {
194-
dataSize++
195-
}
196-
data := make([]byte, 1+dataSize)
197-
data[0] = byte(dataSize)
198-
for i := range s.DiscreteInputs[register:endRegister] {
199-
// Return all 1s, regardless of the value in the DiscreteInputs array.
200-
shift := uint(i) % 8
201-
data[1+i/8] |= byte(1 << shift)
202-
}
203-
return data, &Success
204-
})
184+
wf := WithRegisterFunction(2, func(s *Server, frame Framer) ([]byte, *Exception) {
185+
register, numRegs, endRegister := registerAddressAndNumber(frame)
186+
// Check the request is within the allocated memory
187+
if endRegister > 65535 {
188+
return []byte{}, &IllegalDataAddress
189+
}
190+
dataSize := numRegs / 8
191+
if (numRegs % 8) != 0 {
192+
dataSize++
193+
}
194+
data := make([]byte, 1+dataSize)
195+
data[0] = byte(dataSize)
196+
for i := range s.DiscreteInputs[register:endRegister] {
197+
// Return all 1s, regardless of the value in the DiscreteInputs array.
198+
shift := uint(i) % 8
199+
data[1+i/8] |= byte(1 << shift)
200+
}
201+
return data, &Success
202+
})
203+
serv := NewServer(wf)
205204

206205
// Start the server.
207206
err := serv.ListenTCP("localhost:4321")
208207
if err != nil {
209208
log.Printf("%v\n", err)
210209
return
211210
}
212-
defer serv.Close()
211+
defer serv.Shutdown()
212+
go serv.Serve()
213213

214214
// Wait for the server to start
215215
time.Sleep(1 * time.Millisecond)

functions_test.go

+28-46
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ package mbserver
22

33
import (
44
"encoding/json"
5+
"errors"
56
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
610
)
711

812
func isEqual(a interface{}, b interface{}) bool {
@@ -43,9 +47,8 @@ func TestReadCoils(t *testing.T) {
4347
// 2 bytes, 0b1000011, 0b00000001
4448
expect := []byte{2, 131, 1}
4549
got := response.GetData()
46-
if !isEqual(expect, got) {
47-
t.Errorf("expected %v, got %v", expect, got)
48-
}
50+
51+
assert.Equal(t, expect, got)
4952
}
5053

5154
// Function 2
@@ -76,9 +79,8 @@ func TestReadDiscreteInputs(t *testing.T) {
7679
}
7780
expect := []byte{2, 129, 3}
7881
got := response.GetData()
79-
if !isEqual(expect, got) {
80-
t.Errorf("expected %v, got %v", expect, got)
81-
}
82+
83+
assert.Equal(t, expect, got)
8284
}
8385

8486
// Function 3
@@ -106,9 +108,7 @@ func TestReadHoldingRegisters(t *testing.T) {
106108
}
107109
expect := []byte{6, 0, 1, 0, 2, 255, 255}
108110
got := response.GetData()
109-
if !isEqual(expect, got) {
110-
t.Errorf("expected %v, got %v", expect, got)
111-
}
111+
assert.Equal(t, expect, got)
112112
}
113113

114114
// Function 4
@@ -136,9 +136,7 @@ func TestReadInputRegisters(t *testing.T) {
136136
}
137137
expect := []byte{6, 0, 1, 0, 2, 255, 255}
138138
got := response.GetData()
139-
if !isEqual(expect, got) {
140-
t.Errorf("expected %v, got %v", expect, got)
141-
}
139+
assert.Equal(t, expect, got)
142140
}
143141

144142
// Function 5
@@ -157,15 +155,11 @@ func TestWriteSingleCoil(t *testing.T) {
157155
req.frame = &frame
158156
response := s.handle(&req)
159157
exception := GetException(response)
160-
if exception != Success {
161-
t.Errorf("expected Success, got %v", exception.String())
162-
t.FailNow()
163-
}
164-
expect := 1
158+
require.Equalf(t, exception, Success, "expected Success, got %v", exception.String())
159+
160+
expect := uint8(1)
165161
got := s.Coils[65535]
166-
if !isEqual(expect, got) {
167-
t.Errorf("expected %v, got %v\n", expect, got)
168-
}
162+
assert.Equal(t, expect, got)
169163
}
170164

171165
// Function 6
@@ -184,15 +178,11 @@ func TestWriteHoldingRegister(t *testing.T) {
184178
req.frame = &frame
185179
response := s.handle(&req)
186180
exception := GetException(response)
187-
if exception != Success {
188-
t.Errorf("expected Success, got %v", exception.String())
189-
t.FailNow()
190-
}
191-
expect := 6
181+
require.Equalf(t, exception, Success, "expected Success, got %v", exception.String())
182+
183+
expect := uint16(6)
192184
got := s.HoldingRegisters[5]
193-
if !isEqual(expect, got) {
194-
t.Errorf("expected %v, got %v\n", expect, got)
195-
}
185+
assert.Equal(t, expect, got)
196186
}
197187

198188
// Function 15
@@ -217,9 +207,7 @@ func TestWriteMultipleCoils(t *testing.T) {
217207
}
218208
expect := []byte{1, 1}
219209
got := s.Coils[1:3]
220-
if !isEqual(expect, got) {
221-
t.Errorf("expected %v, got %v\n", expect, got)
222-
}
210+
assert.Equal(t, expect, got)
223211
}
224212

225213
// Function 16
@@ -244,27 +232,21 @@ func TestWriteHoldingRegisters(t *testing.T) {
244232
}
245233
expect := []uint16{3, 4}
246234
got := s.HoldingRegisters[1:3]
247-
if !isEqual(expect, got) {
248-
t.Errorf("expected %v, got %v\n", expect, got)
249-
}
235+
assert.Equal(t, expect, got)
250236
}
251237

252238
func TestBytesToUint16(t *testing.T) {
253239
bytes := []byte{1, 2, 3, 4}
254240
got := BytesToUint16(bytes)
255241
expect := []uint16{258, 772}
256-
if !isEqual(expect, got) {
257-
t.Errorf("expected %v, got %v\n", expect, got)
258-
}
242+
assert.Equal(t, expect, got)
259243
}
260244

261245
func TestUint16ToBytes(t *testing.T) {
262246
values := []uint16{1, 2, 3}
263247
got := Uint16ToBytes(values)
264248
expect := []byte{0, 1, 0, 2, 0, 3}
265-
if !isEqual(expect, got) {
266-
t.Errorf("expected %v, got %v\n", expect, got)
267-
}
249+
assert.Equal(t, expect, got)
268250
}
269251

270252
func TestOutOfBounds(t *testing.T) {
@@ -285,22 +267,22 @@ func TestOutOfBounds(t *testing.T) {
285267
frame.Function = 1
286268
response := s.handle(&req)
287269
exception := GetException(response)
288-
if exception != IllegalDataAddress {
270+
if !errors.Is(exception, IllegalDataAddress) {
289271
t.Errorf("expected IllegalDataAddress, got %v", exception.String())
290272
}
291273

292274
frame.Function = 2
293275
response = s.handle(&req)
294276
exception = GetException(response)
295-
if exception != IllegalDataAddress {
277+
if !errors.Is(exception, IllegalDataAddress) {
296278
t.Errorf("expected IllegalDataAddress, got %v", exception.String())
297279
}
298280

299281
SetDataWithRegisterAndNumberAndBytes(&frame, 65535, 2, []byte{3})
300282
frame.Function = 15
301283
response = s.handle(&req)
302284
exception = GetException(response)
303-
if exception != IllegalDataAddress {
285+
if !errors.Is(exception, IllegalDataAddress) {
304286
t.Errorf("expected IllegalDataAddress, got %v", exception.String())
305287
}
306288

@@ -310,22 +292,22 @@ func TestOutOfBounds(t *testing.T) {
310292
frame.Function = 3
311293
response = s.handle(&req)
312294
exception = GetException(response)
313-
if exception != IllegalDataAddress {
295+
if !errors.Is(exception, IllegalDataAddress) {
314296
t.Errorf("expected IllegalDataAddress, got %v", exception.String())
315297
}
316298

317299
frame.Function = 4
318300
response = s.handle(&req)
319301
exception = GetException(response)
320-
if exception != IllegalDataAddress {
302+
if !errors.Is(exception, IllegalDataAddress) {
321303
t.Errorf("expected IllegalDataAddress, got %v", exception.String())
322304
}
323305

324306
SetDataWithRegisterAndNumberAndValues(&frame, 65535, 2, []uint16{0, 0})
325307
frame.Function = 16
326308
response = s.handle(&req)
327309
exception = GetException(response)
328-
if exception != IllegalDataAddress {
310+
if !errors.Is(exception, IllegalDataAddress) {
329311
t.Errorf("expected IllegalDataAddress, got %v", exception.String())
330312
}
331313
}

go.mod

+7
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,11 @@ go 1.23
55
require (
66
github.com/goburrow/modbus v0.1.0
77
github.com/goburrow/serial v0.1.0
8+
github.com/stretchr/testify v1.9.0
9+
)
10+
11+
require (
12+
github.com/davecgh/go-spew v1.1.1 // indirect
13+
github.com/pmezard/go-difflib v1.0.0 // indirect
14+
gopkg.in/yaml.v3 v3.0.1 // indirect
815
)

go.sum

+10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
13
github.com/goburrow/modbus v0.1.0 h1:DejRZY73nEM6+bt5JSP6IsFolJ9dVcqxsYbpLbeW/ro=
24
github.com/goburrow/modbus v0.1.0/go.mod h1:Kx552D5rLIS8E7TyUwQ/UdHEqvX5T8tyiGBTlzMcZBg=
35
github.com/goburrow/serial v0.1.0 h1:v2T1SQa/dlUqQiYIT8+Cu7YolfqAi3K96UmhwYyuSrA=
46
github.com/goburrow/serial v0.1.0/go.mod h1:sAiqG0nRVswsm1C97xsttiYCzSLBmUZ/VSlVLZJ8haA=
7+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
8+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
9+
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
10+
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
11+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
12+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
13+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
14+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

server.go

+8-9
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type Server struct {
2424
requestChan chan *Request
2525
l *slog.Logger
2626

27-
function [256](func(*Server, Framer) ([]byte, *Exception))
27+
function [256]function
2828
DiscreteInputs []byte
2929
Coils []byte
3030
HoldingRegisters []uint16
@@ -131,8 +131,6 @@ func NewServer(opts ...OptionFunc) *Server {
131131
s.requestChan = make(chan *Request)
132132
s.closeSignalChan = make(chan struct{})
133133

134-
go s.handler()
135-
136134
return s
137135
}
138136

@@ -147,15 +145,15 @@ func (s *Server) handle(request *Request) Framer {
147145

148146
response := request.frame.Copy()
149147

150-
function := request.frame.GetFunction()
151-
if s.function[function] != nil {
152-
data, exception = s.function[function](s, request.frame)
148+
funcCode := request.frame.GetFunction()
149+
if s.function[funcCode] != nil {
150+
data, exception = s.function[funcCode](s, request.frame)
153151
response.SetData(data)
154152
} else {
155153
exception = &IllegalFunction
156154
}
157155

158-
if exception != &Success {
156+
if *exception == Success {
159157
response.SetException(exception)
160158
}
161159

@@ -169,8 +167,9 @@ func (s *Server) handler() {
169167
case <-s.closeSignalChan:
170168
return
171169
case request := <-s.requestChan:
172-
response := s.handle(request)
173-
request.conn.Write(response.Bytes())
170+
response := s.handle(request)
171+
request.conn.Write(response.Bytes())
172+
}
174173
}
175174
}
176175

0 commit comments

Comments
 (0)