diff --git a/context.go b/context.go index ff26553..c08c781 100644 --- a/context.go +++ b/context.go @@ -1,7 +1,6 @@ package xingyun import ( - "fmt" "net/http" "github.com/gorilla/context" @@ -23,13 +22,13 @@ func (h ContextHandlerFunc) ServeContext(ctx *Context) { func ToHTTPHandlerFunc(h ContextHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - h.ServeContext(GetContext(r)) + h.ServeContext(getUnInitedContext(r, w)) } } func ToHTTPHandler(h ContextHandler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h.ServeContext(GetContext(r)) + h.ServeContext(getUnInitedContext(r, w)) }) } @@ -61,14 +60,31 @@ type Context struct { // use for user PipeHandler. avoid name conflict PipeHandlerData map[string]interface{} + isInited bool flash *Flash staticData map[string][]string opts *Options xsrf XSRF } -func NewContext(r *http.Request, w http.ResponseWriter, s *Server) *Context { - ctx := &Context{ +func GetContext(r *http.Request) *Context { + obj, ok := context.GetOk(r, CONTEXT_KEY) + if !ok { + panic("can't get context") + } + ctx := obj.(*Context) + if !ctx.isInited { + panic("get uninited context") + } + return ctx +} + +func initContext(r *http.Request, w http.ResponseWriter, s *Server) *Context { + ctx := getUnInitedContext(r, w) + if ctx.isInited { + return ctx + } + *ctx = Context{ ResponseWriter: w, Request: r, Server: s, @@ -78,14 +94,17 @@ func NewContext(r *http.Request, w http.ResponseWriter, s *Server) *Context { Data: map[string]interface{}{}, staticData: map[string][]string{}, } + ctx.isInited = true context.Set(r, CONTEXT_KEY, ctx) return ctx } -func GetContext(r *http.Request) *Context { +func getUnInitedContext(r *http.Request, w http.ResponseWriter) *Context { ctx, ok := context.GetOk(r, CONTEXT_KEY) if !ok { - panic(fmt.Errorf("can't get context")) + newctx := &Context{Request: r, ResponseWriter: w} + context.Set(r, CONTEXT_KEY, newctx) + return newctx } return ctx.(*Context) } diff --git a/context_cookie.go b/context_cookie.go index 0f63c22..f984f80 100644 --- a/context_cookie.go +++ b/context_cookie.go @@ -32,7 +32,6 @@ func (ctx *Context) GetCookie(name string, value interface{}) error { r := ctx.Request cookie, err := r.Cookie(name) if err != nil { - ctx.Logger.Errorf(err.Error()) return err } return cookier.Decode(name, cookie.Value, value) diff --git a/pipe_context.go b/pipe_context.go index 7f35aa1..c23024f 100644 --- a/pipe_context.go +++ b/pipe_context.go @@ -11,7 +11,7 @@ func (s *Server) GetContextPipeHandler() PipeHandler { s.Logger.Tracef("enter") defer s.Logger.Tracef("exit") - NewContext(r, w, s) + initContext(r, w, s) defer context.Clear(r) next.ServeHTTP(w, r) }) diff --git a/pipe_logger.go b/pipe_logger.go index 5c8c39f..5b45e5d 100644 --- a/pipe_logger.go +++ b/pipe_logger.go @@ -5,19 +5,17 @@ import ( "time" "code.1dmy.com/xyz/logex" - - "github.com/codegangsta/negroni" ) func (s *Server) GetLogPipeHandler() PipeHandler { - return PipeHandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.Handler) { + return PipeHandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { s.Logger.Tracef("enter") defer s.Logger.Tracef("exit") start := time.Now() - next.ServeHTTP(rw, r) + next.ServeHTTP(w, r) - res := rw.(negroni.ResponseWriter) + res := w.(ResponseWriter) log := logex.Infof status := res.Status() if status >= 500 && status <= 599 { diff --git a/pipe_recovery.go b/pipe_recovery.go index c95cff1..085b338 100644 --- a/pipe_recovery.go +++ b/pipe_recovery.go @@ -10,13 +10,13 @@ import ( ) func (s *Server) GetRecoverPipeHandler() PipeHandler { - return PipeHandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.Handler) { + return PipeHandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.Handler) { s.Logger.Tracef("enter") defer s.Logger.Tracef("exit") defer func() { if err := recover(); err != nil { - rw.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusInternalServerError) var stacks []string for i := 1; ; i += 1 { @@ -31,6 +31,6 @@ func (s *Server) GetRecoverPipeHandler() PipeHandler { } }() - next.ServeHTTP(rw, r) + next.ServeHTTP(w, r) }) } diff --git a/response_writer.go b/response_writer.go new file mode 100644 index 0000000..524b8d8 --- /dev/null +++ b/response_writer.go @@ -0,0 +1,96 @@ +package xingyun + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about +// the response. It is recommended that middleware handlers use this construct to wrap a responsewriter +// if the functionality calls for it. +type ResponseWriter interface { + http.ResponseWriter + http.Flusher + // Status returns the status code of the response or 0 if the response has not been written. + Status() int + // Written returns whether or not the ResponseWriter has been written. + Written() bool + // Size returns the size of the response body. + Size() int + // Before allows for a function to be called before the ResponseWriter has been written to. This is + // useful for setting headers or any other operations that must happen before a response has been written. + Before(func(ResponseWriter)) +} + +type beforeFunc func(ResponseWriter) + +// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter +func NewResponseWriter(rw http.ResponseWriter) ResponseWriter { + return &responseWriter{rw, 0, 0, nil} +} + +type responseWriter struct { + http.ResponseWriter + status int + size int + beforeFuncs []beforeFunc +} + +func (rw *responseWriter) WriteHeader(s int) { + rw.callBefore() + rw.ResponseWriter.WriteHeader(s) + rw.status = s +} + +func (rw *responseWriter) Write(b []byte) (int, error) { + if !rw.Written() { + // The status will be StatusOK if WriteHeader has not been called yet + rw.WriteHeader(http.StatusOK) + } + size, err := rw.ResponseWriter.Write(b) + rw.size += size + return size, err +} + +func (rw *responseWriter) Status() int { + return rw.status +} + +func (rw *responseWriter) Size() int { + return rw.size +} + +func (rw *responseWriter) Written() bool { + return rw.status != 0 +} + +func (rw *responseWriter) Before(before func(ResponseWriter)) { + rw.beforeFuncs = append(rw.beforeFuncs, before) +} + +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface") + } + return hijacker.Hijack() +} + +func (rw *responseWriter) CloseNotify() <-chan bool { + return rw.ResponseWriter.(http.CloseNotifier).CloseNotify() +} + +func (rw *responseWriter) callBefore() { + for i := len(rw.beforeFuncs) - 1; i >= 0; i-- { + rw.beforeFuncs[i](rw) + } +} + +func (rw *responseWriter) Flush() { + flusher, ok := rw.ResponseWriter.(http.Flusher) + if ok { + flusher.Flush() + } +} diff --git a/server.go b/server.go index 4a31a87..784ec71 100644 --- a/server.go +++ b/server.go @@ -26,6 +26,7 @@ func NewServer(config *Config) *Server { server := &Server{ Router: NewRouter(), Logger: logex.NewLogger(1), + Config: config, } server.StaticDir = http.Dir(config.StaticDir) server.SecureCookie = securecookie.New([]byte(config.CookieSecret), []byte(config.CookieSecret)) @@ -43,7 +44,7 @@ func NewServer(config *Config) *Server { } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.Router.ServeHTTP(w, r) + s.Router.ServeHTTP(NewResponseWriter(w), r) } func (s *Server) NewPipe(name string, handlers ...PipeHandler) *Pipe {