From abcf5a7b5e2c3df951b8048317a99a89b040b489 Mon Sep 17 00:00:00 2001
From: wxiaoguang <wxiaoguang@gmail.com>
Date: Tue, 23 May 2023 09:29:15 +0800
Subject: [PATCH] Fix install page context, make the install page tests really
 test (#24858)

Fix #24856

Rename "context.contextKey" to "context.WebContextKey", this context is
for web context only. But the Context itself is not renamed, otherwise
it would cause a lot of changes (if we really want to rename it, there
could be a separate PR).

The old test code doesn't really test, the "install page" gets broken
not only one time, so use new test code to make sure the "install page"
could work.
---
 cmd/web.go                     |  4 +---
 modules/context/context.go     | 12 +++++-----
 modules/context/package.go     |  2 +-
 modules/web/handler.go         |  2 +-
 routers/install/install.go     |  3 ++-
 routers/install/routes.go      |  3 +--
 routers/install/routes_test.go | 41 ++++++++++++++++++++++++----------
 routers/web/web.go             |  2 +-
 services/auth/auth.go          |  2 +-
 9 files changed, 43 insertions(+), 28 deletions(-)

diff --git a/cmd/web.go b/cmd/web.go
index bc344db54017..da6c987ff845 100644
--- a/cmd/web.go
+++ b/cmd/web.go
@@ -142,10 +142,8 @@ func runWeb(ctx *cli.Context) error {
 				return err
 			}
 		}
-		installCtx, cancel := context.WithCancel(graceful.GetManager().HammerContext())
-		c := install.Routes(installCtx)
+		c := install.Routes()
 		err := listen(c, false)
-		cancel()
 		if err != nil {
 			log.Critical("Unable to open listener for installer. Is Gitea already running?")
 			graceful.GetManager().DoGracefulShutdown()
diff --git a/modules/context/context.go b/modules/context/context.go
index 1e15081479ff..9e351432c4cd 100644
--- a/modules/context/context.go
+++ b/modules/context/context.go
@@ -68,12 +68,12 @@ func (ctx *Context) TrHTMLEscapeArgs(msg string, args ...string) string {
 	return ctx.Locale.Tr(msg, trArgs...)
 }
 
-type contextKeyType struct{}
+type webContextKeyType struct{}
 
-var contextKey interface{} = contextKeyType{}
+var WebContextKey = webContextKeyType{}
 
-func GetContext(req *http.Request) *Context {
-	ctx, _ := req.Context().Value(contextKey).(*Context)
+func GetWebContext(req *http.Request) *Context {
+	ctx, _ := req.Context().Value(WebContextKey).(*Context)
 	return ctx
 }
 
@@ -86,7 +86,7 @@ type ValidateContext struct {
 func GetValidateContext(req *http.Request) (ctx *ValidateContext) {
 	if ctxAPI, ok := req.Context().Value(apiContextKey).(*APIContext); ok {
 		ctx = &ValidateContext{Base: ctxAPI.Base}
-	} else if ctxWeb, ok := req.Context().Value(contextKey).(*Context); ok {
+	} else if ctxWeb, ok := req.Context().Value(WebContextKey).(*Context); ok {
 		ctx = &ValidateContext{Base: ctxWeb.Base}
 	} else {
 		panic("invalid context, expect either APIContext or Context")
@@ -135,7 +135,7 @@ func Contexter() func(next http.Handler) http.Handler {
 			ctx.PageData = map[string]any{}
 			ctx.Data["PageData"] = ctx.PageData
 
-			ctx.Base.AppendContextValue(contextKey, ctx)
+			ctx.Base.AppendContextValue(WebContextKey, ctx)
 			ctx.Base.AppendContextValueFunc(git.RepositoryContextKey, func() any { return ctx.Repo.GitRepo })
 
 			ctx.Csrf = PrepareCSRFProtector(csrfOpts, ctx)
diff --git a/modules/context/package.go b/modules/context/package.go
index b1fd7088ddc4..805203278711 100644
--- a/modules/context/package.go
+++ b/modules/context/package.go
@@ -150,7 +150,7 @@ func PackageContexter() func(next http.Handler) http.Handler {
 			}
 			defer baseCleanUp()
 
-			ctx.Base.AppendContextValue(contextKey, ctx)
+			ctx.Base.AppendContextValue(WebContextKey, ctx)
 			next.ServeHTTP(ctx.Resp, ctx.Req)
 		})
 	}
diff --git a/modules/web/handler.go b/modules/web/handler.go
index 5013bac93f64..c8aebd90518d 100644
--- a/modules/web/handler.go
+++ b/modules/web/handler.go
@@ -22,7 +22,7 @@ type ResponseStatusProvider interface {
 // TODO: decouple this from the context package, let the context package register these providers
 var argTypeProvider = map[reflect.Type]func(req *http.Request) ResponseStatusProvider{
 	reflect.TypeOf(&context.APIContext{}):     func(req *http.Request) ResponseStatusProvider { return context.GetAPIContext(req) },
-	reflect.TypeOf(&context.Context{}):        func(req *http.Request) ResponseStatusProvider { return context.GetContext(req) },
+	reflect.TypeOf(&context.Context{}):        func(req *http.Request) ResponseStatusProvider { return context.GetWebContext(req) },
 	reflect.TypeOf(&context.PrivateContext{}): func(req *http.Request) ResponseStatusProvider { return context.GetPrivateContext(req) },
 }
 
diff --git a/routers/install/install.go b/routers/install/install.go
index 89b91a5a483a..4635cd7cb60f 100644
--- a/routers/install/install.go
+++ b/routers/install/install.go
@@ -59,7 +59,7 @@ func Contexter() func(next http.Handler) http.Handler {
 	return func(next http.Handler) http.Handler {
 		return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
 			base, baseCleanUp := context.NewBaseContext(resp, req)
-			ctx := context.Context{
+			ctx := &context.Context{
 				Base:    base,
 				Flash:   &middleware.Flash{},
 				Render:  rnd,
@@ -67,6 +67,7 @@ func Contexter() func(next http.Handler) http.Handler {
 			}
 			defer baseCleanUp()
 
+			ctx.AppendContextValue(context.WebContextKey, ctx)
 			ctx.Data.MergeFrom(middleware.CommonTemplateContextData())
 			ctx.Data.MergeFrom(middleware.ContextData{
 				"locale":        ctx.Locale,
diff --git a/routers/install/routes.go b/routers/install/routes.go
index 52c07cfa26e6..f09a22b1e601 100644
--- a/routers/install/routes.go
+++ b/routers/install/routes.go
@@ -4,7 +4,6 @@
 package install
 
 import (
-	goctx "context"
 	"fmt"
 	"html"
 	"net/http"
@@ -18,7 +17,7 @@ import (
 )
 
 // Routes registers the installation routes
-func Routes(ctx goctx.Context) *web.Route {
+func Routes() *web.Route {
 	base := web.NewRoute()
 	base.Use(common.ProtocolMiddlewares()...)
 	base.RouteMethods("/assets/*", "GET, HEAD", public.AssetsHandlerFunc("/assets/"))
diff --git a/routers/install/routes_test.go b/routers/install/routes_test.go
index e3d2a4246740..fcbd05297742 100644
--- a/routers/install/routes_test.go
+++ b/routers/install/routes_test.go
@@ -1,24 +1,41 @@
-// Copyright 2021 The Gitea Authors. All rights reserved.
+// Copyright 2023 The Gitea Authors. All rights reserved.
 // SPDX-License-Identifier: MIT
 
 package install
 
 import (
-	"context"
+	"net/http/httptest"
+	"path/filepath"
 	"testing"
 
+	"code.gitea.io/gitea/models/unittest"
+
 	"github.com/stretchr/testify/assert"
 )
 
 func TestRoutes(t *testing.T) {
-	// TODO: this test seems not really testing the handlers
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	base := Routes(ctx)
-	assert.NotNil(t, base)
-	r := base.R.Routes()[1]
-	routes := r.SubRoutes.Routes()[0]
-	assert.EqualValues(t, "/", routes.Pattern)
-	assert.Nil(t, routes.SubRoutes)
-	assert.Len(t, routes.Handlers, 2)
+	r := Routes()
+	assert.NotNil(t, r)
+
+	w := httptest.NewRecorder()
+	req := httptest.NewRequest("GET", "/", nil)
+	r.ServeHTTP(w, req)
+	assert.EqualValues(t, 200, w.Code)
+	assert.Contains(t, w.Body.String(), `class="page-content install"`)
+
+	w = httptest.NewRecorder()
+	req = httptest.NewRequest("GET", "/no-such", nil)
+	r.ServeHTTP(w, req)
+	assert.EqualValues(t, 404, w.Code)
+
+	w = httptest.NewRecorder()
+	req = httptest.NewRequest("GET", "/assets/img/gitea.svg", nil)
+	r.ServeHTTP(w, req)
+	assert.EqualValues(t, 200, w.Code)
+}
+
+func TestMain(m *testing.M) {
+	unittest.MainTest(m, &unittest.TestOptions{
+		GiteaRootPath: filepath.Join("..", ".."),
+	})
 }
diff --git a/routers/web/web.go b/routers/web/web.go
index c230d33398c8..395fc9425f22 100644
--- a/routers/web/web.go
+++ b/routers/web/web.go
@@ -1405,7 +1405,7 @@ func registerRoutes(m *web.Route) {
 	}
 
 	m.NotFound(func(w http.ResponseWriter, req *http.Request) {
-		ctx := context.GetContext(req)
+		ctx := context.GetWebContext(req)
 		ctx.NotFound("", nil)
 	})
 }
diff --git a/services/auth/auth.go b/services/auth/auth.go
index 905c776e5871..c7fdc56cbed0 100644
--- a/services/auth/auth.go
+++ b/services/auth/auth.go
@@ -92,7 +92,7 @@ func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore
 	middleware.SetLocaleCookie(resp, user.Language, 0)
 
 	// Clear whatever CSRF has right now, force to generate a new one
-	if ctx := gitea_context.GetContext(req); ctx != nil {
+	if ctx := gitea_context.GetWebContext(req); ctx != nil {
 		ctx.Csrf.DeleteCookie(ctx)
 	}
 }