diff --git a/modules/context/repo.go b/modules/context/repo.go index a77d075de2b2..c1f60a13624f 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -6,6 +6,7 @@ package context import ( + "context" "fmt" "io/ioutil" "net/url" @@ -393,7 +394,7 @@ func RepoIDAssignment() func(ctx *Context) { } // RepoAssignment returns a middleware to handle repository assignment -func RepoAssignment(ctx *Context) { +func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { var ( owner *models.User err error @@ -529,12 +530,12 @@ func RepoAssignment(ctx *Context) { ctx.Repo.GitRepo = gitRepo // We opened it, we should close it - defer func() { + cancel = func() { // If it's been set to nil then assume someone else has closed it. if ctx.Repo.GitRepo != nil { ctx.Repo.GitRepo.Close() } - }() + } // Stop at this point when the repo is empty. if ctx.Repo.Repository.IsEmpty { @@ -619,6 +620,7 @@ func RepoAssignment(ctx *Context) { ctx.Data["GoDocDirectory"] = prefix + "{/dir}" ctx.Data["GoDocFile"] = prefix + "{/dir}/{file}#L{line}" } + return } // RepoRefType type of repo reference @@ -643,7 +645,7 @@ const ( // RepoRef handles repository reference names when the ref name is not // explicitly given -func RepoRef() func(*Context) { +func RepoRef() func(*Context) context.CancelFunc { // since no ref name is explicitly specified, ok to just use branch return RepoRefByType(RepoRefBranch) } @@ -722,8 +724,8 @@ func getRefName(ctx *Context, pathType RepoRefType) string { // RepoRefByType handles repository reference name for a specific type // of repository reference -func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context) { - return func(ctx *Context) { +func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context) context.CancelFunc { + return func(ctx *Context) (cancel context.CancelFunc) { // Empty repository does not have reference information. if ctx.Repo.Repository.IsEmpty { return @@ -742,12 +744,12 @@ func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context return } // We opened it, we should close it - defer func() { + cancel = func() { // If it's been set to nil then assume someone else has closed it. if ctx.Repo.GitRepo != nil { ctx.Repo.GitRepo.Close() } - }() + } } // Get default branch. @@ -844,6 +846,7 @@ func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context return } ctx.Data["CommitsCount"] = ctx.Repo.CommitsCount + return } } diff --git a/modules/web/route.go b/modules/web/route.go index 6f9e76bdf389..319d08f5981c 100644 --- a/modules/web/route.go +++ b/modules/web/route.go @@ -5,6 +5,7 @@ package web import ( + goctx "context" "fmt" "net/http" "reflect" @@ -27,6 +28,7 @@ func Wrap(handlers ...interface{}) http.HandlerFunc { switch t := handler.(type) { case http.HandlerFunc, func(http.ResponseWriter, *http.Request), func(ctx *context.Context), + func(ctx *context.Context) goctx.CancelFunc, func(*context.APIContext), func(*context.PrivateContext), func(http.Handler) http.Handler: @@ -48,6 +50,15 @@ func Wrap(handlers ...interface{}) http.HandlerFunc { if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { return } + case func(ctx *context.Context) goctx.CancelFunc: + ctx := context.GetContext(req) + cancel := t(ctx) + if cancel != nil { + defer cancel() + } + if ctx.Written() { + return + } case func(ctx *context.Context): ctx := context.GetContext(req) t(ctx) @@ -94,6 +105,23 @@ func Middle(f func(ctx *context.Context)) func(netx http.Handler) http.Handler { } } +// MiddleCancel wrap a context function as a chi middleware +func MiddleCancel(f func(ctx *context.Context) goctx.CancelFunc) func(netx http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + ctx := context.GetContext(req) + cancel := f(ctx) + if cancel != nil { + defer cancel() + } + if ctx.Written() { + return + } + next.ServeHTTP(ctx.Resp, ctx.Req) + }) + } +} + // MiddleAPI wrap a context function as a chi middleware func MiddleAPI(f func(ctx *context.APIContext)) func(netx http.Handler) http.Handler { return func(next http.Handler) http.Handler { @@ -163,6 +191,8 @@ func (r *Route) Use(middlewares ...interface{}) { r.R.Use(t) case func(*context.Context): r.R.Use(Middle(t)) + case func(*context.Context) goctx.CancelFunc: + r.R.Use(MiddleCancel(t)) case func(*context.APIContext): r.R.Use(MiddleAPI(t)) default: