Skip to content

Commit

Permalink
1, bug fix. 2, add custom ResponseWriter for get Status Code
Browse files Browse the repository at this point in the history
  • Loading branch information
dworld committed Nov 6, 2014
1 parent 7e8340a commit 12e9ccb
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 18 deletions.
33 changes: 26 additions & 7 deletions context.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package xingyun

import (
"fmt"
"net/http"

"github.com/gorilla/context"
Expand All @@ -23,13 +22,13 @@ func (h ContextHandlerFunc) ServeContext(ctx *Context) {

func ToHTTPHandlerFunc(h ContextHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.ServeContext(GetContext(r))
h.ServeContext(getUnInitedContext(r, w))
}
}

func ToHTTPHandler(h ContextHandler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.ServeContext(GetContext(r))
h.ServeContext(getUnInitedContext(r, w))
})
}

Expand Down Expand Up @@ -61,14 +60,31 @@ type Context struct {
// use for user PipeHandler. avoid name conflict
PipeHandlerData map[string]interface{}

isInited bool
flash *Flash
staticData map[string][]string
opts *Options
xsrf XSRF
}

func NewContext(r *http.Request, w http.ResponseWriter, s *Server) *Context {
ctx := &Context{
func GetContext(r *http.Request) *Context {
obj, ok := context.GetOk(r, CONTEXT_KEY)
if !ok {
panic("can't get context")
}
ctx := obj.(*Context)
if !ctx.isInited {
panic("get uninited context")
}
return ctx
}

func initContext(r *http.Request, w http.ResponseWriter, s *Server) *Context {
ctx := getUnInitedContext(r, w)
if ctx.isInited {
return ctx
}
*ctx = Context{
ResponseWriter: w,
Request: r,
Server: s,
Expand All @@ -78,14 +94,17 @@ func NewContext(r *http.Request, w http.ResponseWriter, s *Server) *Context {
Data: map[string]interface{}{},
staticData: map[string][]string{},
}
ctx.isInited = true
context.Set(r, CONTEXT_KEY, ctx)
return ctx
}

func GetContext(r *http.Request) *Context {
func getUnInitedContext(r *http.Request, w http.ResponseWriter) *Context {
ctx, ok := context.GetOk(r, CONTEXT_KEY)
if !ok {
panic(fmt.Errorf("can't get context"))
newctx := &Context{Request: r, ResponseWriter: w}
context.Set(r, CONTEXT_KEY, newctx)
return newctx
}
return ctx.(*Context)
}
Expand Down
1 change: 0 additions & 1 deletion context_cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ func (ctx *Context) GetCookie(name string, value interface{}) error {
r := ctx.Request
cookie, err := r.Cookie(name)
if err != nil {
ctx.Logger.Errorf(err.Error())
return err
}
return cookier.Decode(name, cookie.Value, value)
Expand Down
2 changes: 1 addition & 1 deletion pipe_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func (s *Server) GetContextPipeHandler() PipeHandler {
s.Logger.Tracef("enter")
defer s.Logger.Tracef("exit")

NewContext(r, w, s)
initContext(r, w, s)
defer context.Clear(r)
next.ServeHTTP(w, r)
})
Expand Down
8 changes: 3 additions & 5 deletions pipe_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@ import (
"time"

"code.1dmy.com/xyz/logex"

"github.com/codegangsta/negroni"
)

func (s *Server) GetLogPipeHandler() PipeHandler {
return PipeHandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.Handler) {
return PipeHandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
s.Logger.Tracef("enter")
defer s.Logger.Tracef("exit")

start := time.Now()
next.ServeHTTP(rw, r)
next.ServeHTTP(w, r)

res := rw.(negroni.ResponseWriter)
res := w.(ResponseWriter)
log := logex.Infof
status := res.Status()
if status >= 500 && status <= 599 {
Expand Down
6 changes: 3 additions & 3 deletions pipe_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ import (
)

func (s *Server) GetRecoverPipeHandler() PipeHandler {
return PipeHandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.Handler) {
return PipeHandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) {
s.Logger.Tracef("enter")
defer s.Logger.Tracef("exit")

defer func() {
if err := recover(); err != nil {
rw.WriteHeader(http.StatusInternalServerError)
w.WriteHeader(http.StatusInternalServerError)

var stacks []string
for i := 1; ; i += 1 {
Expand All @@ -31,6 +31,6 @@ func (s *Server) GetRecoverPipeHandler() PipeHandler {
}
}()

next.ServeHTTP(rw, r)
next.ServeHTTP(w, r)
})
}
96 changes: 96 additions & 0 deletions response_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package xingyun

import (
"bufio"
"fmt"
"net"
"net/http"
)

// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about
// the response. It is recommended that middleware handlers use this construct to wrap a responsewriter
// if the functionality calls for it.
type ResponseWriter interface {
http.ResponseWriter
http.Flusher
// Status returns the status code of the response or 0 if the response has not been written.
Status() int
// Written returns whether or not the ResponseWriter has been written.
Written() bool
// Size returns the size of the response body.
Size() int
// Before allows for a function to be called before the ResponseWriter has been written to. This is
// useful for setting headers or any other operations that must happen before a response has been written.
Before(func(ResponseWriter))
}

type beforeFunc func(ResponseWriter)

// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter
func NewResponseWriter(rw http.ResponseWriter) ResponseWriter {
return &responseWriter{rw, 0, 0, nil}
}

type responseWriter struct {
http.ResponseWriter
status int
size int
beforeFuncs []beforeFunc
}

func (rw *responseWriter) WriteHeader(s int) {
rw.callBefore()
rw.ResponseWriter.WriteHeader(s)
rw.status = s
}

func (rw *responseWriter) Write(b []byte) (int, error) {
if !rw.Written() {
// The status will be StatusOK if WriteHeader has not been called yet
rw.WriteHeader(http.StatusOK)
}
size, err := rw.ResponseWriter.Write(b)
rw.size += size
return size, err
}

func (rw *responseWriter) Status() int {
return rw.status
}

func (rw *responseWriter) Size() int {
return rw.size
}

func (rw *responseWriter) Written() bool {
return rw.status != 0
}

func (rw *responseWriter) Before(before func(ResponseWriter)) {
rw.beforeFuncs = append(rw.beforeFuncs, before)
}

func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface")
}
return hijacker.Hijack()
}

func (rw *responseWriter) CloseNotify() <-chan bool {
return rw.ResponseWriter.(http.CloseNotifier).CloseNotify()
}

func (rw *responseWriter) callBefore() {
for i := len(rw.beforeFuncs) - 1; i >= 0; i-- {
rw.beforeFuncs[i](rw)
}
}

func (rw *responseWriter) Flush() {
flusher, ok := rw.ResponseWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}
3 changes: 2 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func NewServer(config *Config) *Server {
server := &Server{
Router: NewRouter(),
Logger: logex.NewLogger(1),
Config: config,
}
server.StaticDir = http.Dir(config.StaticDir)
server.SecureCookie = securecookie.New([]byte(config.CookieSecret), []byte(config.CookieSecret))
Expand All @@ -43,7 +44,7 @@ func NewServer(config *Config) *Server {
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.Router.ServeHTTP(w, r)
s.Router.ServeHTTP(NewResponseWriter(w), r)
}

func (s *Server) NewPipe(name string, handlers ...PipeHandler) *Pipe {
Expand Down

0 comments on commit 12e9ccb

Please sign in to comment.