Skip to content

Commit 21e5041

Browse files
committed
Start performing a lot more error checking in Send()
1 parent 4274e69 commit 21e5041

File tree

2 files changed

+55
-25
lines changed

2 files changed

+55
-25
lines changed

client.go

+47-25
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/tls"
66
"io"
77
"log"
8+
"sync"
89
"time"
910
)
1011

@@ -27,20 +28,28 @@ func (b *buffer) Add(v interface{}) *list.Element {
2728
return e
2829
}
2930

31+
type serializedNotif struct {
32+
id uint32
33+
b []byte
34+
}
35+
3036
type Client struct {
3137
Conn *Conn
3238
FailedNotifs chan NotificationResult
3339

34-
notifs chan Notification
35-
id uint32
40+
notifs chan serializedNotif
41+
42+
id uint32
43+
idm sync.Mutex
3644
}
3745

3846
func newClientWithConn(gw string, conn Conn) Client {
3947
c := Client{
4048
Conn: &conn,
4149
FailedNotifs: make(chan NotificationResult),
42-
id: uint32(1),
43-
notifs: make(chan Notification),
50+
notifs: make(chan serializedNotif),
51+
id: 1,
52+
idm: sync.Mutex{},
4453
}
4554

4655
go c.runLoop()
@@ -73,10 +82,37 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err
7382
}
7483

7584
func (c *Client) Send(n Notification) error {
76-
c.notifs <- n
85+
// Set identifier if not specified
86+
if n.Identifier == 0 {
87+
n.Identifier = c.nextID()
88+
} else if c.id < n.Identifier {
89+
c.setID(n.Identifier)
90+
}
91+
92+
b, err := n.ToBinary()
93+
if err != nil {
94+
return err
95+
}
96+
97+
c.notifs <- serializedNotif{b: b, id: n.Identifier}
7798
return nil
7899
}
79100

101+
func (c *Client) setID(n uint32) {
102+
c.idm.Lock()
103+
defer c.idm.Unlock()
104+
105+
c.id = n
106+
}
107+
108+
func (c *Client) nextID() uint32 {
109+
c.idm.Lock()
110+
defer c.idm.Unlock()
111+
112+
c.id++
113+
return c.id
114+
}
115+
80116
func (c *Client) reportFailedPush(v interface{}, err *Error) {
81117
failedNotif, ok := v.(Notification)
82118
if !ok || v == nil {
@@ -93,7 +129,7 @@ func (c *Client) requeue(cursor *list.Element) {
93129
// If `cursor` is not nil, this means there are notifications that
94130
// need to be delivered (or redelivered)
95131
for ; cursor != nil; cursor = cursor.Next() {
96-
if n, ok := cursor.Value.(Notification); ok {
132+
if n, ok := cursor.Value.(serializedNotif); ok {
97133
go func() { c.notifs <- n }()
98134
}
99135
}
@@ -103,11 +139,11 @@ func (c *Client) handleError(err *Error, buffer *buffer) *list.Element {
103139
cursor := buffer.Back()
104140

105141
for cursor != nil {
106-
// Get notification
107-
n, _ := cursor.Value.(Notification)
142+
// Get serialized notification
143+
n, _ := cursor.Value.(serializedNotif)
108144

109145
// If the notification, move cursor after the trouble notification
110-
if n.Identifier == err.Identifier {
146+
if n.id == err.Identifier {
111147
go c.reportFailedPush(cursor.Value, err)
112148

113149
next := cursor.Next()
@@ -143,7 +179,7 @@ func (c *Client) runLoop() {
143179
// Connection open, listen for notifs and errors
144180
for {
145181
var err error
146-
var n Notification
182+
var n serializedNotif
147183

148184
// Check for notifications or errors. There is a chance we'll send notifications
149185
// if we already have an error since `select` will "pseudorandomly" choose a
@@ -169,21 +205,7 @@ func (c *Client) runLoop() {
169205
// Add to list
170206
cursor = sent.Add(n)
171207

172-
// Set identifier if not specified
173-
if n.Identifier == 0 {
174-
n.Identifier = c.id
175-
c.id++
176-
} else if c.id < n.Identifier {
177-
c.id = n.Identifier + 1
178-
}
179-
180-
b, err := n.ToBinary()
181-
if err != nil {
182-
// TODO
183-
continue
184-
}
185-
186-
_, err = c.Conn.Write(b)
208+
_, err = c.Conn.Write(n.b)
187209

188210
if err == io.EOF {
189211
log.Println("EOF trying to write notification")

notification.go

+8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ const (
1515
PriorityPowerConserve = 5
1616
)
1717

18+
const (
19+
validDeviceTokenLength = 64
20+
)
21+
1822
const (
1923
commandID = 2
2024

@@ -93,6 +97,10 @@ func (p *Payload) MarshalJSON() ([]byte, error) {
9397
func (n Notification) ToBinary() ([]byte, error) {
9498
b := []byte{}
9599

100+
if len(n.DeviceToken) != validDeviceTokenLength {
101+
return b, errors.New(ErrInvalidToken)
102+
}
103+
96104
binTok, err := hex.DecodeString(n.DeviceToken)
97105
if err != nil {
98106
return b, fmt.Errorf("convert token to hex error: %s", err)

0 commit comments

Comments
 (0)