diff --git a/.travis.yml b/.travis.yml index 18c443a..687544f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ before_script: - go get github.com/onsi/gomega - go get code.google.com/p/go.tools/cmd/cover - go install github.com/onsi/ginkgo/ginkgo -script: ginkgo -r --skipMeasurements --cover --trace +script: ginkgo -r --skipMeasurements --cover --trace --race env: global: - PATH=$HOME/gopath/bin:$PATH diff --git a/apns_suite_test.go b/apns_suite_test.go index b0bcca4..858a5b9 100644 --- a/apns_suite_test.go +++ b/apns_suite_test.go @@ -5,8 +5,46 @@ import ( . "github.com/onsi/gomega" "testing" + "time" ) +type mockConn struct { + connect func() error + read func([]byte) (int, error) + readWithTimeout func([]byte, time.Time) (int, error) +} + +func (m *mockConn) Connect() error { + if m.connect != nil { + return m.connect() + } + + return nil +} + +func (m *mockConn) Read(b []byte) (int, error) { + if m.read != nil { + return m.read(b) + } + return 0, nil +} + +func (m *mockConn) Write([]byte) (int, error) { + return 0, nil +} + +func (m *mockConn) Close() error { + return nil +} + +func (m *mockConn) ReadWithTimeout(b []byte, t time.Time) (int, error) { + if m.readWithTimeout != nil { + return m.readWithTimeout(b, t) + } + + return 0, nil +} + func TestApns(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Apns Suite") diff --git a/client.go b/client.go index de7ab1a..b9159bb 100644 --- a/client.go +++ b/client.go @@ -1,57 +1,33 @@ package apns import ( - "container/list" "crypto/tls" - "io" - "log" + "sync" "time" ) -type buffer struct { - size int - *list.List -} - -func newBuffer(size int) *buffer { - return &buffer{size, list.New()} -} - -func (b *buffer) Add(v interface{}) *list.Element { - e := b.PushBack(v) - - if b.Len() > b.size { - b.Remove(b.Front()) - } - - return e -} - type Client struct { - Conn *Conn - FailedNotifs chan NotificationResult + conn Conn - notifs chan Notification - id uint32 + sess Session + sessm sync.Mutex } -func newClientWithConn(gw string, conn Conn) Client { - c := Client{ - Conn: &conn, - FailedNotifs: make(chan NotificationResult), - id: uint32(1), - notifs: make(chan Notification), - } +func newClientWithConn(conn Conn) (Client, error) { + c := Client{conn: conn} - go c.runLoop() + sess := newSession(conn) + err := sess.Connect() + if err != nil { + return c, err + } - return c + return Client{conn, sess, sync.Mutex{}}, nil } -func NewClientWithCert(gw string, cert tls.Certificate) Client { +func NewClientWithCert(gw string, cert tls.Certificate) (Client, error) { conn := NewConnWithCert(gw, cert) - - return newClientWithConn(gw, conn) + return newClientWithConn(conn) } func NewClient(gw string, cert string, key string) (Client, error) { @@ -60,7 +36,7 @@ func NewClient(gw string, cert string, key string) (Client, error) { return Client{}, err } - return newClientWithConn(gw, conn), nil + return newClientWithConn(conn) } func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, error) { @@ -69,151 +45,48 @@ func NewClientWithFiles(gw string, certFile string, keyFile string) (Client, err return Client{}, err } - return newClientWithConn(gw, conn), nil + return newClientWithConn(conn) } func (c *Client) Send(n Notification) error { - c.notifs <- n - return nil -} - -func (c *Client) reportFailedPush(v interface{}, err *Error) { - failedNotif, ok := v.(Notification) - if !ok || v == nil { - return + if c.sess.Disconnected() { + c.reconnectAndRequeue() } - select { - case c.FailedNotifs <- NotificationResult{Notif: failedNotif, Err: *err}: - default: - } + return c.sess.Send(n) } -func (c *Client) requeue(cursor *list.Element) { - // If `cursor` is not nil, this means there are notifications that - // need to be delivered (or redelivered) - for ; cursor != nil; cursor = cursor.Next() { - if n, ok := cursor.Value.(Notification); ok { - go func() { c.notifs <- n }() - } - } -} - -func (c *Client) handleError(err *Error, buffer *buffer) *list.Element { - cursor := buffer.Back() - - for cursor != nil { - // Get notification - n, _ := cursor.Value.(Notification) +func (c *Client) reconnectAndRequeue() { + c.sessm.Lock() + defer c.sessm.Unlock() - // If the notification, move cursor after the trouble notification - if n.Identifier == err.Identifier { - go c.reportFailedPush(cursor.Value, err) + // Pull off undelivered notifications + notifs := c.sess.RequeueableNotifications() - next := cursor.Next() - - buffer.Remove(cursor) - return next - } - - cursor = cursor.Prev() - } - - return cursor -} + // Reconnect + c.sess = nil -func (c *Client) runLoop() { - sent := newBuffer(50) - cursor := sent.Front() + for c.sess == nil { + sess := newSession(c.conn) - // APNS connection - for { - err := c.Conn.Connect() + err := sess.Connect() if err != nil { - // TODO Probably want to exponentially backoff... + // TODO retry policy + // TODO connect error channel + // Keep trying to connect time.Sleep(1 * time.Second) continue } - // Start reading errors from APNS - errs := readErrs(c.Conn) - - c.requeue(cursor) - - // Connection open, listen for notifs and errors - for { - var err error - var n Notification - - // Check for notifications or errors. There is a chance we'll send notifications - // if we already have an error since `select` will "pseudorandomly" choose a - // ready channels. It turns out to be fine because the connection will already - // be closed and it'll requeue. We could check before we get to this select - // block, but it doesn't seem worth the extra code and complexity. - select { - case err = <-errs: - case n = <-c.notifs: - } - - // If there is an error we understand, find the notification that failed, - // move the cursor right after it. - if nErr, ok := err.(*Error); ok { - cursor = c.handleError(nErr, sent) - break - } - - if err != nil { - break - } - - // Add to list - cursor = sent.Add(n) - - // Set identifier if not specified - if n.Identifier == 0 { - n.Identifier = c.id - c.id++ - } else if c.id < n.Identifier { - c.id = n.Identifier + 1 - } - - b, err := n.ToBinary() - if err != nil { - // TODO - continue - } - - _, err = c.Conn.Write(b) - - if err == io.EOF { - log.Println("EOF trying to write notification") - break - } - - if err != nil { - log.Println("err writing to apns", err.Error()) - break - } - - cursor = cursor.Next() - } + c.sess = sess } -} - -func readErrs(c *Conn) chan error { - errs := make(chan error) - - go func() { - p := make([]byte, 6, 6) - _, err := c.Read(p) - if err != nil { - errs <- err - return - } - e := NewError(p) - errs <- &e - }() + for _, n := range notifs { + // TODO handle error from sending + c.sess.Send(n) + } +} - return errs +var newSession = func(c Conn) Session { + return NewSession(c) } diff --git a/client_test.go b/client_test.go index c9dfd47..7dbee14 100644 --- a/client_test.go +++ b/client_test.go @@ -1,38 +1,91 @@ -package apns_test +package apns import ( - "bytes" - "encoding/binary" + "errors" "io/ioutil" "os" - "time" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "github.com/timehop/apns" + "github.com/timehop/tcptest" ) +type mockSession struct { + sendCB func(n Notification) error + requeueNotifs []Notification + disconnectedState bool +} + +func (m *mockSession) Send(n Notification) error { + if m.sendCB == nil { + return nil + } + + return m.sendCB(n) +} + +func (m *mockSession) Connect() error { + return nil +} + +func (m *mockSession) RequeueableNotifications() []Notification { + if len(m.requeueNotifs) == 0 { + return []Notification{} + } + + return m.requeueNotifs +} + +func (m *mockSession) Disconnect() { + m.disconnectedState = true +} + +func (m *mockSession) Disconnected() bool { + return m.disconnectedState +} + +type badConnMockSession struct { + *mockSession +} + +func (_ badConnMockSession) Connect() error { + return errors.New("whatev") +} + var _ = Describe("Client", func() { - Describe(".NewConn", func() { + BeforeEach(func() { + newSession = func(_ Conn) Session { return &mockSession{} } + }) + + Describe(".NewClient", func() { Context("bad cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClient(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClient(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) Context("valid cert/key pair", func() { It("should create a valid client", func() { - c, err := apns.NewClient(apns.ProductionGateway, DummyCert, DummyKey) + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) + }) + }) + + Context("bad connection", func() { + It("should error out", func() { + newSession = func(_ Conn) Session { return badConnMockSession{} } + + _, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).NotTo(BeNil()) }) }) }) - Describe(".NewConnWithFiles", func() { + Describe(".NewClientWithFiles", func() { Context("missing cert/key pair", func() { It("should error out", func() { - _, err := apns.NewClientWithFiles(apns.ProductionGateway, "missing", "missing_also") + _, err := NewClientWithFiles(ProductionGateway, "missing", "missing_also") Expect(err).NotTo(BeNil()) }) }) @@ -42,11 +95,11 @@ var _ = Describe("Client", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -61,324 +114,92 @@ var _ = Describe("Client", func() { }) It("should create a valid client", func() { - c, err := apns.NewClientWithFiles(apns.ProductionGateway, certFile.Name(), keyFile.Name()) + _, err := NewClientWithFiles(ProductionGateway, certFile.Name(), keyFile.Name()) Expect(err).To(BeNil()) - Expect(c.Conn).NotTo(BeNil()) }) }) }) - Describe("#Send", func() { - Context("simple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + Describe("Send", func() { + Context("connected", func() { + Context("valid push", func() { + It("should not return an error", func() { + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) }) }) - }) - Context("simple write with buffer", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - for i := 0; i < 54; i++ { - Expect(c.Send(apns.Notification{})).To(BeNil()) + Context("invalid notification", func() { + It("should return an error", func() { + newSession = func(_ Conn) Session { + return &mockSession{ + sendCB: func(_ Notification) error { + return errors.New("") + }, + } } - close(mockDone) - close(d) - }) - }) - }) - - Context("multiple write", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - Expect(c.Send(apns.Notification{})).To(BeNil()) - Expect(c.Send(apns.Notification{})).To(BeNil()) - - close(mockDone) - close(d) - }) - }) - }) - - Context("bad push", func() { - n := apns.Notification{Identifier: 9, ID: "some_rando"} - nb, _ := n.ToBinary() - nbcb := make([]byte, len(nb)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(9)) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - serverAction{action: readAction, data: nbcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(nb)) - }}, - - // Bad push results in a close - serverAction{action: writeAction, data: errPayload.Bytes()}, - serverAction{action: closeAction, data: []byte{}}, - }, - } - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true - - go func() { - n := <-c.FailedNotifs - - Expect(n.Notif.Identifier).To(Equal(uint32(9))) - Expect(n.Notif.ID).To(Equal("some_rando")) - - close(mockDone) - close(d) - }() - - Expect(c.Send(n)).To(BeNil()) - }) - }) - }) - - Context("closed, reconnect", func() { - done := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n1b, _ := n1.ToBinary() - n1bcb := make([]byte, len(n1b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - as := [][]serverAction{ - []serverAction{ - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - done <- true - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - - close(mockDone) - close(d) - }}, - }, - } - - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - <-done - time.Sleep(5 * time.Millisecond) - - // Good - Expect(c.Send(n1)).To(BeNil()) + err = c.Send(Notification{DeviceToken: "lol"}) + Expect(err).NotTo(BeNil()) }) }) }) - Context("good, close, good, requeue of last good", func() { - closed := make(chan bool) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - closed <- true - }}, - }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - - close(mockDone) - close(d) - }}, - }, + Context("disconnected", func() { + It("should reconnect", func() { + newSessCount := 0 + newSession = func(_ Conn) Session { + newSessCount += 1 + return &mockSession{} } - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n1)).To(BeNil()) + c.sess.Disconnect() - <-closed - time.Sleep(5 * time.Millisecond) + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n2)).To(BeNil()) - }) + Expect(newSessCount).To(Equal(2)) }) }) - Context("good, bad, good, requeue of last good", func() { - It("should not return an error", func(d Done) { - mockDone := make(chan interface{}) - - n1 := apns.Notification{Identifier: 1} - n2 := apns.Notification{Identifier: 2} - n3 := apns.Notification{Identifier: 3} - - n1b, _ := n1.ToBinary() - n2b, _ := n2.ToBinary() - n3b, _ := n3.ToBinary() - - n1bcb := make([]byte, len(n1b)) - n2bcb := make([]byte, len(n2b)) - n3bcb := make([]byte, len(n3b)) - - errPayload := bytes.NewBuffer([]byte{}) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint8(8)) - binary.Write(errPayload, binary.BigEndian, uint32(2)) - - as := [][]serverAction{ - []serverAction{ - // Connect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Handshake - }}, - - // Read first good notification - serverAction{action: readAction, data: n1bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n1b)) - }}, - - // Read bad notification - serverAction{action: readAction, data: n2bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n2b)) - }}, - - // Read second good notification - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - }}, - - // Write error - serverAction{action: writeAction, data: errPayload.Bytes(), cb: func(a serverAction) { - }}, - - // Close on error - serverAction{action: closeAction, cb: func(a serverAction) { - }}, + It("should reconnect and requeue", func() { + newSessCount := 0 + sendCount := 0 + + newSession = func(_ Conn) Session { + newSessCount += 1 + return &mockSession{ + requeueNotifs: []Notification{ + Notification{}, + Notification{}, + Notification{}, }, - []serverAction{ - // Reconnect - serverAction{action: readAction, data: []byte{}, cb: func(a serverAction) { - // Reconnected - }}, - - // Requeue - serverAction{action: readAction, data: n3bcb, cb: func(a serverAction) { - Expect(a.data).To(Equal(n3b)) - - close(mockDone) - close(d) - }}, + sendCB: func(_ Notification) error { + sendCount += 1 + return nil }, } + } - withMockServerAsync(as, mockDone, func(s *mockTLSServer) { - c, _ := apns.NewClient(s.Address(), DummyCert, DummyKey) - c.Conn.Conf.InsecureSkipVerify = true + c, err := NewClient(SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n1)).To(BeNil()) + c.sess.Disconnect() - // Bad - Expect(c.Send(n2)).To(BeNil()) + err = c.Send(Notification{DeviceToken: "0000000000000000000000000000000000000000000000000000000000000000"}) + Expect(err).To(BeNil()) - // Good - Expect(c.Send(n3)).To(BeNil()) - }) - }) + Expect(newSessCount).To(Equal(2)) + Expect(sendCount).To(Equal(4)) }) }) }) diff --git a/conn.go b/conn.go index d3aa712..b14f8b8 100644 --- a/conn.go +++ b/conn.go @@ -2,8 +2,10 @@ package apns import ( "crypto/tls" + "io" "net" "strings" + "time" ) const ( @@ -15,9 +17,16 @@ const ( ) // Conn is a wrapper for the actual TLS connections made to Apple -type Conn struct { - NetConn net.Conn - Conf *tls.Config +type Conn interface { + io.ReadWriteCloser + + Connect() error + ReadWithTimeout(p []byte, deadline time.Time) (int, error) +} + +type conn struct { + netConn net.Conn + tls *tls.Config gateway string connected bool @@ -25,19 +34,20 @@ type Conn struct { func NewConnWithCert(gw string, cert tls.Certificate) Conn { gatewayParts := strings.Split(gw, ":") - conf := tls.Config{ - Certificates: []tls.Certificate{cert}, - ServerName: gatewayParts[0], + tls := tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: gatewayParts[0], + InsecureSkipVerify: true, } - return Conn{gateway: gw, Conf: &conf} + return &conn{gateway: gw, tls: &tls} } // NewConnWithFiles creates a new Conn from certificate and key in the specified files func NewConn(gw string, crt string, key string) (Conn, error) { cert, err := tls.X509KeyPair([]byte(crt), []byte(key)) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil @@ -47,49 +57,49 @@ func NewConn(gw string, crt string, key string) (Conn, error) { func NewConnWithFiles(gw string, certFile string, keyFile string) (Conn, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return Conn{}, err + return &conn{}, err } return NewConnWithCert(gw, cert), nil } // Connect actually creates the TLS connection -func (c *Conn) Connect() error { +func (c *conn) Connect() error { // Make sure the existing connection is closed - if c.NetConn != nil { - c.NetConn.Close() + if c.netConn != nil { + c.netConn.Close() } - conn, err := net.Dial("tcp", c.gateway) + tlsConn, err := tls.Dial("tcp", c.gateway, c.tls) if err != nil { return err } - tlsConn := tls.Client(conn, c.Conf) - err = tlsConn.Handshake() - if err != nil { - return err - } - - c.NetConn = tlsConn + c.netConn = tlsConn return nil } -func (c *Conn) Close() error { - if c.NetConn != nil { - return c.NetConn.Close() +func (c *conn) Close() error { + if c.netConn != nil { + return c.netConn.Close() } return nil } // Read reads data from the connection -func (c *Conn) Read(p []byte) (int, error) { - i, err := c.NetConn.Read(p) - return i, err +func (c *conn) Read(p []byte) (int, error) { + return c.netConn.Read(p) +} + +// ReadWithTimeout reads data from the connection and returns an error +// after duration +func (c *conn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { + c.netConn.SetReadDeadline(deadline) + return c.netConn.Read(p) } // Write writes data from the connection -func (c *Conn) Write(p []byte) (int, error) { - return c.NetConn.Write(p) +func (c *conn) Write(p []byte) (int, error) { + return c.netConn.Write(p) } diff --git a/conn_test.go b/conn_test.go index e910e6c..d388a6d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,232 +1,15 @@ package apns_test import ( - "bytes" - "crypto/tls" - "fmt" - "io" "io/ioutil" - "log" "net" "os" - "strings" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) -var DummyCert = `-----BEGIN CERTIFICATE----- -MIIC9TCCAd+gAwIBAgIQf3bEgFWUb+q6eK5ySkV/gjALBgkqhkiG9w0BAQUwEjEQ -MA4GA1UEChMHQWNtZSBDbzAeFw0xNDA2MzAwNDI5MDhaFw0xNTA2MzAwNDI5MDha -MBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK -AoIBAQDhAgWrrFZBtCfVEPg1tSIr9fuSUoeundb556IUr9uOmOHaYK7r3/I43acw -bVIfaenFxwUUf8YakQzTjOa5qSfK/Eylyw2ezBJtNUEqcHw0f+y66+jJbZa4clPa -tL6ezaMS/syXPpvNU8+16jdVdTJzqdBdSGAZMOCeumUWDNdlfBmHPVq1JMy0uGmO -XDoZK2Ir0/3LUfjk9R2wdm1VLrJAml7F0L0FhBHHXgHOSFM2ixjGflffaiuTCxhW -1z1NTo9XjWUQh2iM9Udf+xVnJLGLZ0EMFr2qihuK604Fp4SlNHEF+UWUn+j0PYo+ -LbzM9oKJcdVD0XI36vrn3rGPHO9vAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIAoDAT -BgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuCCWxv -Y2FsaG9zdDALBgkqhkiG9w0BAQUDggEBAGJ/3I4KKlbEwLAC5ut4ZZ9V8WF4sHkI -Lj7e4vx2pPi6hf9miV1ff01NrpfUna7flwL9yD7Ybl7jRRIB4rIcKk+U5djGsT3H -ScGkbIMKrr08drWw1g4JU6PBH7xTfzGxNRERrnmrbJV0jCo9Tt8i53IpPtp6Z2Q1 -8ydtPhU+Bpe2YoNr1w1fSV1JHXqjKV8RlGkCNSi4ozPOO8RbAYnBT3d9XSGoX//q -RGJUf3wC/rCxJkN63Moxuy3vxV2TmiqccHOrXJSJ8P/4PpPV/xuBk5k4HS1Nfmew -d9WHHn6bMJE9arVvWAiu9teCadVffuS2cl2cicN4XB6Ui0aDqhG2Exw= ------END CERTIFICATE-----` - -var DummyKey = `-----BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEA4QIFq6xWQbQn1RD4NbUiK/X7klKHrp3W+eeiFK/bjpjh2mCu -69/yON2nMG1SH2npxccFFH/GGpEM04zmuaknyvxMpcsNnswSbTVBKnB8NH/suuvo -yW2WuHJT2rS+ns2jEv7Mlz6bzVPPteo3VXUyc6nQXUhgGTDgnrplFgzXZXwZhz1a -tSTMtLhpjlw6GStiK9P9y1H45PUdsHZtVS6yQJpexdC9BYQRx14BzkhTNosYxn5X -32orkwsYVtc9TU6PV41lEIdojPVHX/sVZySxi2dBDBa9qoobiutOBaeEpTRxBflF -lJ/o9D2KPi28zPaCiXHVQ9FyN+r6596xjxzvbwIDAQABAoIBAFzW+cIA5MJNdFX8 -n32BlGzxHPEd7nAFHmuUwJKqkPwAZsg1NleK2qXOByr7IHRnvhZl7Nmtcu8JRHKR -Y63ddtbRTUrnQmJwL3YyEAZTzVvYILRrnGxoNFU8jw7hnvllPdEbow0QvzZ0S3Lz -BgvTxJJm0dt7fnNGcJftrsHvYHy1dptaR4hPv0xV5G7RPrbTl94llKfi745tp5Wd -xGpnjcBXoAnzCVRij1tHfSYubRJ2MJV0kzG3oVdRV2P/zWaout8BlhLCURv4sRUX -7FfCNa/z+G6AlROjCKJUP9YIUbxBEa/aP8YlSiyLRi1jFbMWcnKWQUdqS19m73Ap -a1LJFPECgYEA+Ve5DegcrWnUb2HsHD38HlmEg6S+/jg2P4TsuLZBtvO4/vzRx/qq -pwuuMm2CsvXr4nVmMEsMlSzYdsnaXIlWqyVDCOwIWR5VYT2GDWqQLaIXPlFaISzN -27tHd64KUtR1fMJUwQVK/MUORUbpYoAnSIil2SlYkWUhF024fNP8CxcCgYEA5wP4 -HLiqU2rqe7vSAF/8fHwPleTzuCfMCVZm0aegUzQQQtklZoVE/BBwEGHdXflq1veq -pHeC8bNR4BF6ZgeSWgbLVF3msquy47QeNElHA2muJd3qmNWz4LXo1Pxb8KXcnXri -QZ+r3Y8obWTFQYq7gGQGPLXGTV3bhLGIyrT4lWkCgYAgZ2MYSJL5gmhmNT6fCPsr -4oxTI2Ti2uFJ7fdppd3ybcgb8zU8HPpyjRUNXqf+o/EM1B78pbQz6skS3vau0fZe -dZA5p5sKIeQMqBc0xSWJmKgWpDHnX9A8/yCxj/+tdgjytrqW/x4YrW9GV4nbEDaK -uZ98EmB9PLxJMAOKzW3S7wKBgQDD4PCy4b3CR2iVC9dva/P5VXQdo+knX884p6M8 -58YgZofXNqnouN2aYRG0QlbiBMcbiRqOo6tK58JnnEpNUuQ8I4Cqg4hGPSHMwv/N -U8i70xLPltABUUpZIcVPOr92WBytBvHrtMiUb3tW7lf3T/vWTHmhZnvDQ+8LH0Ge -pz4T6QKBgQCoBJKOd781IQmT6i5hHSYJlsP6ymaaaQniJPVpnci/jf8+2QtponQY -scgnaBLBasLQ6GfKSRtcyidEi9wwxpVj0tw2p567jeNcIveD0TOYFf0RHEfrs+D4 -VdRgai/v2NbFZLDnzeGVuYypXu6R78isJfHtz/a0aEave8yB3CRiDw== ------END RSA PRIVATE KEY-----` - -// To be able to run in parallel -var mockPort = 50000 - -// Mock Addr -type mockAddr struct { -} - -func (m mockAddr) Network() string { - return "localhost:56789" -} - -func (m mockAddr) String() string { - return "localhost:56789" -} - -// Mock TLS connection -type mockTLSNetConn struct { - bb *bytes.Buffer - err error -} - -func (t mockTLSNetConn) Read(p []byte) (int, error) { - r := bytes.NewReader(t.bb.Bytes()) - return r.Read(p) -} - -func (t mockTLSNetConn) Write(p []byte) (int, error) { - return t.bb.Write(p) -} - -func (t mockTLSNetConn) Close() error { - return t.err -} - -func (m mockTLSNetConn) LocalAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) RemoteAddr() net.Addr { - return mockAddr{} -} - -func (m mockTLSNetConn) SetDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (m mockTLSNetConn) SetWriteDeadline(t time.Time) error { - return nil -} - -type serverAction struct { - action string - data []byte - cb func(s serverAction) -} - -const ( - readAction = "read" - writeAction = "write" - closeAction = "close" -) - -type mockTLSServer struct { - Port int - Server net.Listener - ConnectionActionGroups [][]serverAction -} - -func (m *mockTLSServer) portStr() string { - if m.Port == 0 { - mockPort = mockPort + 1 - m.Port = mockPort - } - - return fmt.Sprint(m.Port) -} - -func (m *mockTLSServer) Address() string { - return "localhost:" + m.portStr() -} - -func (m *mockTLSServer) start() { - cert, err := tls.X509KeyPair([]byte(DummyCert), []byte(DummyKey)) - if err != nil { - log.Panic(err) - } - - config := tls.Config{Certificates: []tls.Certificate{cert}, ClientAuth: tls.RequireAnyClientCert} - - m.Server, err = tls.Listen("tcp", "localhost:"+m.portStr(), &config) - go func() { - for i := 0; i < len(m.ConnectionActionGroups); i++ { - g := m.ConnectionActionGroups[i] - - // Wait for a connection. - conn, err := m.Server.Accept() - if err != nil { - if strings.Contains(err.Error(), "use of closed network connection") { - return - } else { - log.Fatal(err) - } - } - // Handle the connection in a new goroutine. - // The loop then returns to accepting, so that - // multiple connections may be served concurrently. - go func(c net.Conn) { - for j := 0; j < len(g); j++ { - a := g[j] - switch a.action { - case readAction: - c.Read(a.data) - case writeAction: - c.Write(a.data) - case closeAction: - c.Close() - - if a.cb != nil { - a.cb(a) - } - return - } - - if a.cb != nil { - a.cb(a) - } - } - }(conn) - } - - // No more connection action groups - }() -} - -func (m *mockTLSServer) stop() { - if m.Server != nil { - m.Server.Close() - } -} - -var withMockServer = func(as [][]serverAction, cb func(s *mockTLSServer)) { - d := make(chan interface{}) - withMockServerAsync(as, d, func(s *mockTLSServer) { - cb(s) - close(d) - }) -} - -var withMockServerAsync = func(as [][]serverAction, d chan interface{}, cb func(s *mockTLSServer)) { - s := &mockTLSServer{} - s.ConnectionActionGroups = as - - s.start() - - cb(s) - - <-d - s.stop() -} - // Tests var _ = Describe("Conn", func() { Describe(".NewConn", func() { @@ -239,7 +22,7 @@ var _ = Describe("Conn", func() { Context("valid key/cert pair", func() { It("should not return an error", func() { - _, err := apns.NewConn(apns.SandboxGateway, DummyCert, DummyKey) + _, err := apns.NewConn(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -259,11 +42,11 @@ var _ = Describe("Conn", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -295,117 +78,122 @@ var _ = Describe("Conn", func() { }) Context("server up", func() { - as := [][]serverAction{[]serverAction{serverAction{action: readAction, data: []byte{}}}} - Context("with untrusted certs", func() { It("should return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - err := conn.Connect() - Expect(err).NotTo(BeNil()) + s := tcptest.NewTLSServer(func(c net.Conn) {}) + defer s.Close() - close(d) - }) + conn, err := apns.NewConn(s.Addr, "not trusted", "not even a little") + Expect(err).NotTo(BeNil()) + + err = conn.Connect() + Expect(err).NotTo(BeNil()) + + close(d) }) }) Context("trusting the certs", func() { It("should not return an error", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + s := tcptest.NewUnstartedServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) - err := conn.Connect() - Expect(err).To(BeNil()) + s.StartTLS() + defer s.Close() - close(d) - }) - }) + conn, err := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + err = conn.Connect() + Expect(err).To(BeNil()) + + close(d) + }, 10) }) Context("with existing connection", func() { It("should not return an error", func(d Done) { - as = [][]serverAction{ - []serverAction{serverAction{action: readAction, data: []byte{}}}, - []serverAction{serverAction{action: readAction, data: []byte{}}}, - } + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect + }) + defer s.Close() - withMockServer(as, func(s *mockTLSServer) { - conn, _ := apns.NewConn(s.Address(), DummyCert, DummyKey) - conn.Conf.InsecureSkipVerify = true + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) - conn.Connect() + conn.Connect() - err := conn.Connect() - Expect(err).To(BeNil()) + err := conn.Connect() + Expect(err).To(BeNil()) - close(d) - }) - }) + close(d) + }, 10) }) }) }) Describe("#Read", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte("hello!"))} - - pp := make([]byte, 6) - bytes.NewReader(rwc.bb.Bytes()).Read(pp) + It("should read out 'hello!'", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte("hello!")) + }) + defer s.Close() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - It("should read out 'hello!'", func() { p := make([]byte, 6) conn.Read(p) Expect(p).To(Equal([]byte("hello!"))) }) }) +}) - Describe("#Write", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} +var _ = Describe("#Write", func() { + It("should read out 'hello!'", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + b := make([]byte, 6) + c.Read(b) - It("should write out 'world!'", func() { - conn.Write([]byte("world!")) - Expect(rwc.bb.String()).To(Equal("world!")) + Expect(string(b)).To(Equal("hello!")) + close(d) }) - }) - Describe("#Close", func() { - Context("with connection", func() { - Context("no error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + conn.Write([]byte("hello!")) + }, 10) +}) - It("should return no error", func() { - Expect(rwc.Close()).To(BeNil()) +var _ = Describe("#Close", func() { + Context("with connection", func() { + Context("no error", func() { + It("should return no error", func() { + s := tcptest.NewTLSServer(func(c net.Conn) { + defer c.Close() + c.Write([]byte{}) // Connect }) - }) - - Context("with error", func() { - rwc := mockTLSNetConn{bb: bytes.NewBuffer([]byte{})} - - conn, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - conn.NetConn = rwc + defer s.Close() - rwc.err = io.EOF - It("should return that error", func() { - Expect(rwc.Close()).To(Equal(io.EOF)) - }) + conn, _ := apns.NewConn(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + conn.Connect() + Expect(conn.Close()).To(BeNil()) }) }) + }) - Context("without connection", func() { - c, _ := apns.NewConn(apns.ProductionGateway, DummyCert, DummyKey) - It("should not return an error", func() { - Expect(c.Close()).To(BeNil()) - }) + Context("without connection", func() { + It("should not return an error", func() { + conn, _ := apns.NewConn("localhost:12345", string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(conn.Close()).To(BeNil()) }) }) }) diff --git a/error.go b/error.go index 5425868..3371bea 100644 --- a/error.go +++ b/error.go @@ -3,6 +3,11 @@ package apns import ( "bytes" "encoding/binary" + "errors" +) + +var ( + ErrDisconnected = errors.New("disconnected from gateway") ) const ( diff --git a/example/example.go b/example/example.go index 1b670ac..637d0af 100644 --- a/example/example.go +++ b/example/example.go @@ -10,10 +10,10 @@ import ( func main() { c, err := apns.NewClientWithFiles(apns.ProductionGateway, "apns.crt", "apns.key") if err != nil { - log.Fatal("Could not create client", err.Error()) + log.Fatal("Could not create client: ", err.Error()) } - i := 0 + i := 1 for { fmt.Print("Enter ' ': ") diff --git a/feedback.go b/feedback.go index 488bf1b..a0093b1 100644 --- a/feedback.go +++ b/feedback.go @@ -9,7 +9,7 @@ import ( ) type Feedback struct { - Conn *Conn + Conn Conn } type FeedbackTuple struct { @@ -40,7 +40,7 @@ func feedbackTupleFromBytes(b []byte) FeedbackTuple { func NewFeedbackWithCert(gw string, cert tls.Certificate) Feedback { conn := NewConnWithCert(gw, cert) - return Feedback{Conn: &conn} + return Feedback{Conn: conn} } func NewFeedback(gw string, cert string, key string) (Feedback, error) { @@ -49,7 +49,7 @@ func NewFeedback(gw string, cert string, key string) (Feedback, error) { return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, error) { @@ -58,7 +58,7 @@ func NewFeedbackWithFiles(gw string, certFile string, keyFile string) (Feedback, return Feedback{}, err } - return Feedback{Conn: &conn}, nil + return Feedback{Conn: conn}, nil } // Receive returns a read only channel for APNs feedback. The returned channel @@ -80,9 +80,7 @@ func (f Feedback) receive(fc chan FeedbackTuple) { for { b := make([]byte, 38) - f.Conn.NetConn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - - _, err := f.Conn.Read(b) + _, err := f.Conn.ReadWithTimeout(b, time.Now().Add(100*time.Millisecond)) if err != nil { close(fc) return diff --git a/feedback_test.go b/feedback_test.go index 29978b4..96dadb4 100644 --- a/feedback_test.go +++ b/feedback_test.go @@ -4,12 +4,15 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "io" "io/ioutil" + "net" "os" "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/timehop/apns" + "github.com/timehop/tcptest" ) var _ = Describe("Feedback", func() { @@ -23,7 +26,7 @@ var _ = Describe("Feedback", func() { Context("valid cert/key pair", func() { It("should create a valid client", func() { - _, err := apns.NewFeedback(apns.ProductionGateway, DummyCert, DummyKey) + _, err := apns.NewFeedback(apns.SandboxGateway, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) Expect(err).To(BeNil()) }) }) @@ -42,11 +45,11 @@ var _ = Describe("Feedback", func() { BeforeEach(func() { certFile, _ = ioutil.TempFile("", "cert.pem") - certFile.Write([]byte(DummyCert)) + certFile.Write([]byte(tcptest.LocalhostCert)) certFile.Close() keyFile, _ = ioutil.TempFile("", "key.pem") - keyFile.Write([]byte(DummyKey)) + keyFile.Write([]byte(tcptest.LocalhostKey)) keyFile.Close() }) @@ -70,11 +73,13 @@ var _ = Describe("Feedback", func() { Describe("#Receive", func() { Context("could not connect", func() { It("should not receive anything", func() { - s := &mockTLSServer{} - - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true + m := mockConn{ + connect: func() error { + return io.EOF + }, + } + f := apns.Feedback{Conn: &m} c := f.Receive() r := 0 @@ -87,89 +92,87 @@ var _ = Describe("Feedback", func() { }) Context("times out", func() { - as := [][]serverAction{ - []serverAction{ - serverAction{action: readAction, data: []byte{}}, - }, - } - - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - It("should not receive anything", func() { - c := f.Receive() - - r := 0 - for _ = range c { - r += 1 - } - - Expect(r).To(Equal(0)) - }) + It("should not receive anything", func() { + m := mockConn{ + readWithTimeout: func(b []byte, t time.Time) (int, error) { + return 0, net.UnknownNetworkError("") + }, + } + + f := apns.Feedback{Conn: &m} + c := f.Receive() + + r := 0 + for _ = range c { + r += 1 + } + + Expect(r).To(Equal(0)) }) }) + }) + + Context("with feedback", func() { + f1 := bytes.NewBuffer([]byte{}) + f2 := bytes.NewBuffer([]byte{}) + f3 := bytes.NewBuffer([]byte{}) - Context("with feedback", func() { - f1 := bytes.NewBuffer([]byte{}) - f2 := bytes.NewBuffer([]byte{}) - f3 := bytes.NewBuffer([]byte{}) - - // The final token strings - t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" - t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" - t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" - - bt1, _ := hex.DecodeString(t1) - bt2, _ := hex.DecodeString(t2) - bt3, _ := hex.DecodeString(t3) - - binary.Write(f1, binary.BigEndian, uint32(1404358249)) - binary.Write(f1, binary.BigEndian, uint16(len(bt1))) - binary.Write(f1, binary.BigEndian, bt1) - - binary.Write(f2, binary.BigEndian, uint32(1404352249)) - binary.Write(f2, binary.BigEndian, uint16(len(bt2))) - binary.Write(f2, binary.BigEndian, bt2) - - binary.Write(f3, binary.BigEndian, uint32(1394352249)) - binary.Write(f3, binary.BigEndian, uint16(len(bt3))) - binary.Write(f3, binary.BigEndian, bt3) - - as := [][]serverAction{ - []serverAction{ - serverAction{action: writeAction, data: f1.Bytes()}, - serverAction{action: writeAction, data: f2.Bytes()}, - serverAction{action: writeAction, data: f3.Bytes()}, - }, - } - - It("should receive feedback", func(d Done) { - withMockServer(as, func(s *mockTLSServer) { - f, _ := apns.NewFeedback(s.Address(), DummyCert, DummyKey) - f.Conn.Conf.InsecureSkipVerify = true - - c := f.Receive() - - r1 := <-c - Expect(r1.Timestamp).To(Equal(time.Unix(1404358249, 0))) - Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) - Expect(r1.DeviceToken).To(Equal(t1)) - - r2 := <-c - Expect(r2.Timestamp).To(Equal(time.Unix(1404352249, 0))) - Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) - Expect(r2.DeviceToken).To(Equal(t2)) - - r3 := <-c - Expect(r3.Timestamp).To(Equal(time.Unix(1394352249, 0))) - Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) - Expect(r3.DeviceToken).To(Equal(t3)) - - <-c - close(d) - }) + // The final token strings + t1 := "00a18269661e9406aea59a5620b05c7c0e371574fa6f251951de8d7a5a292535" + t2 := "00a1a4b7294fcfbc5293f63d4298fcecd9c20a893befd45adceead5fc92d3319" + t3 := "00a1b7893d5e85eb8bb7bf0846b464d075248555118ae893b06e96cfb8d678e3" + + bt1, _ := hex.DecodeString(t1) + bt2, _ := hex.DecodeString(t2) + bt3, _ := hex.DecodeString(t3) + + binary.Write(f1, binary.BigEndian, uint32(1404358249)) + binary.Write(f1, binary.BigEndian, uint16(len(bt1))) + binary.Write(f1, binary.BigEndian, bt1) + + binary.Write(f2, binary.BigEndian, uint32(1404352249)) + binary.Write(f2, binary.BigEndian, uint16(len(bt2))) + binary.Write(f2, binary.BigEndian, bt2) + + binary.Write(f3, binary.BigEndian, uint32(1394352249)) + binary.Write(f3, binary.BigEndian, uint16(len(bt3))) + binary.Write(f3, binary.BigEndian, bt3) + + It("should receive feedback", func(d Done) { + s := tcptest.NewTLSServer(func(c net.Conn) { + c.Write(f1.Bytes()) + c.Write(f2.Bytes()) + c.Write(f3.Bytes()) + + // TODO(bw) this doesn't seem right + c.Write([]byte{0}) + + c.Close() }) - }) + defer s.Close() + + f, err := apns.NewFeedback(s.Addr, string(tcptest.LocalhostCert), string(tcptest.LocalhostKey)) + Expect(err).To(BeNil()) + + c := f.Receive() + + r1 := <-c + Expect(r1.Timestamp.Unix()).To(Equal(int64(1404358249))) + Expect(r1.TokenLength).To(Equal(uint16(len(bt1)))) + Expect(r1.DeviceToken).To(Equal(t1)) + + r2 := <-c + Expect(r2.Timestamp.Unix()).To(Equal(int64(1404352249))) + Expect(r2.TokenLength).To(Equal(uint16(len(bt2)))) + Expect(r2.DeviceToken).To(Equal(t2)) + + r3 := <-c + Expect(r3.Timestamp.Unix()).To(Equal(int64(1394352249))) + Expect(r3.TokenLength).To(Equal(uint16(len(bt3)))) + Expect(r3.DeviceToken).To(Equal(t3)) + + <-c + close(d) + }, 10) }) }) diff --git a/notification.go b/notification.go index a82557e..ad967c0 100644 --- a/notification.go +++ b/notification.go @@ -15,6 +15,10 @@ const ( PriorityPowerConserve = 5 ) +const ( + validDeviceTokenLength = 64 +) + const ( commandID = 2 @@ -133,6 +137,10 @@ func NewPayload() *Payload { func (n Notification) ToBinary() ([]byte, error) { b := []byte{} + if len(n.DeviceToken) != validDeviceTokenLength { + return b, errors.New(ErrInvalidToken) + } + binTok, err := hex.DecodeString(n.DeviceToken) if err != nil { return b, fmt.Errorf("convert token to hex error: %s", err) diff --git a/notification_test.go b/notification_test.go index cea990a..e76e253 100644 --- a/notification_test.go +++ b/notification_test.go @@ -193,7 +193,20 @@ var _ = Describe("Notifications", func() { Describe("#ToBinary", func() { Context("invalid token format", func() { n := apns.NewNotification() - n.DeviceToken = "totally not a valid token" + n.DeviceToken = "totally not a valid token length" + + It("should return an error", func() { + _, err := n.ToBinary() + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(Equal(apns.ErrInvalidToken)) + }) + + // Expect(err.Error()).To(ContainSubstring("convert token to hex error")) + }) + + Context("non-convertable token", func() { + n := apns.NewNotification() + n.DeviceToken = "123456789012345678901234567890zz123456789012345678901234567890zz" It("should return an error", func() { _, err := n.ToBinary() diff --git a/session.go b/session.go new file mode 100644 index 0000000..36e4576 --- /dev/null +++ b/session.go @@ -0,0 +1,229 @@ +package apns + +import ( + "container/list" + "errors" + "io" + "sync" +) + +type SessionError struct { + Notification Notification + Err Error +} + +func (s SessionError) Error() string { + return s.Err.Error() +} + +type Session interface { + Send(n Notification) error + Connect() error + RequeueableNotifications() []Notification + Disconnect() + Disconnected() bool +} + +type buffer struct { + size int + m sync.Mutex + *list.List +} + +func newBuffer(size int) *buffer { + return &buffer{size, sync.Mutex{}, list.New()} +} + +func (b *buffer) Add(v interface{}) *list.Element { + b.m.Lock() + defer b.m.Unlock() + + e := b.PushBack(v) + + if b.Len() > b.size { + b.Remove(b.Front()) + } + + return e +} + +type sessionState int + +const ( + sessionStateNew sessionState = 1 << iota + sessionStateConnected + sessionStateDisconnected +) + +type session struct { + b *buffer + + conn Conn + connm sync.Mutex + + st sessionState + stm sync.Mutex + + id uint32 + idm sync.Mutex + + err SessionError +} + +func NewSession(conn Conn) Session { + return &session{ + st: sessionStateNew, + stm: sync.Mutex{}, + conn: conn, + connm: sync.Mutex{}, + idm: sync.Mutex{}, + b: newBuffer(50), + } +} + +func (s *session) Connect() error { + if s.isNew() { + return errors.New("can't connect unless the session is new") + } + + go s.readErrors() + return nil +} + +func (s *session) isNew() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st != sessionStateNew +} + +func (s *session) Disconnected() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st == sessionStateDisconnected +} + +func (s *session) Connnected() bool { + s.stm.Lock() + defer s.stm.Unlock() + + return s.st == sessionStateConnected +} + +func (s *session) Send(n Notification) error { + // If disconnected, error out + if !s.Connnected() { + return errors.New("not connected") + } + + // Set identifier if not specified + n.Identifier = s.determineIdentifier(n.Identifier) + + // Serialize + b, err := n.ToBinary() + if err != nil { + return err + } + + // Add to buffer + s.b.Add(n) + + // Send synchronously + return s.send(b) +} + +func (s *session) send(b []byte) error { + s.connm.Lock() + defer s.connm.Unlock() + + _, err := s.conn.Write(b) + if err == io.EOF { + s.Disconnect() + return err + } + + return err +} + +func (s *session) Disconnect() { + // Disconnect + s.transitionState(sessionStateDisconnected) +} + +func (s *session) RequeueableNotifications() []Notification { + notifs := []Notification{} + + // If still connected, return nothing + if s.st != sessionStateDisconnected { + return notifs + } + + // Walk back to last known good notification and return the slice + var e *list.Element + for e = s.b.Front(); e != nil; e = e.Next() { + if n, ok := e.Value.(Notification); ok && n.Identifier == s.err.Notification.Identifier { + break + } + } + + // Start right after errored ID and get the rest of the list + for e = e.Next(); e != nil; e = e.Next() { + n, ok := e.Value.(Notification) + if !ok { + continue + } + + notifs = append(notifs, n) + } + + return notifs +} + +func (s *session) transitionState(st sessionState) { + s.stm.Lock() + defer s.stm.Unlock() + + s.st = st +} + +func (s *session) determineIdentifier(n uint32) uint32 { + s.idm.Lock() + defer s.idm.Unlock() + + // If the id passed in is 0, that means it wasn't + // set so get the next ID. Otherwise, set it to that + // identifier. + if n == 0 { + s.id++ + } else { + s.id = n + } + + return s.id +} + +func (s *session) readErrors() { + p := make([]byte, 6, 6) + + _, err := s.conn.Read(p) + // TODO(bw) not sure what to do here. It's unclear what errors + // come out of this and how we handle it. + if err != nil { + return + } + + s.Disconnect() + + e := NewError(p) + + for cursor := s.b.Back(); cursor != nil; cursor = cursor.Prev() { + // Get serialized notification + n, _ := cursor.Value.(Notification) + + // If the notification, move cursor after the trouble notification + if n.Identifier == e.Identifier { + s.err = SessionError{n, e} + } + } +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..67a598f --- /dev/null +++ b/session_test.go @@ -0,0 +1,94 @@ +package apns + +import ( + "time" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +type mockConn struct{} + +func (m mockConn) Read(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Write(b []byte) (int, error) { + return 0, nil +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) Connect() error { + return nil +} + +func (m mockConn) ReadWithTimeout(p []byte, deadline time.Time) (int, error) { + return 0, nil +} + +var _ = Describe("Session", func() { + Describe("NewSession", func() { + It("creates a session", func() { + s := NewSession(mockConn{}) + Expect(s).NotTo(BeNil()) + }) + }) + + Describe("Connect", func() { + Context("new state", func() { + It("should not return an error", func() { + s := NewSession(mockConn{}) + + err := s.Connect() + Expect(err).To(BeNil()) + }) + }) + + Context("not new state", func() { + It("should return an error", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + err := s.Connect() + Expect(err).NotTo(BeNil()) + }) + }) + }) + + Describe("Disconnected", func() { + Context("not connected", func() { + It("should not be true", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.transitionState(sessionStateDisconnected) + + Expect(s.Disconnected()).To(BeTrue()) + }) + }) + + Context("connected", func() { + It("should be false", func() { + sess := NewSession(mockConn{}) + + s := sess.(*session) + s.Connect() + + Expect(s.Disconnected()).To(BeFalse()) + }) + }) + }) + + Describe("Send", func() { + }) + + Describe("Disconnect", func() { + }) + + Describe("RequeueableNotifications", func() { + }) +})