From 9e6e1dc950f06bbd000d5b6438f39113e8902082 Mon Sep 17 00:00:00 2001
From: zeripath <art27@cantab.net>
Date: Wed, 8 Dec 2021 19:08:16 +0000
Subject: [PATCH] Improve checkBranchName (#17901)

The current implementation of checkBranchName is highly inefficient,
involving opening the repository, then listing all of the branch names
and checking them individually, before then using the opened repo to
get the tags.

This PR avoids this by simply walking the references from show-ref
instead of opening the repository (in the nogogit case).

Signed-off-by: Andrew Thornton <art27@cantab.net>
---
 modules/context/repo.go            |  4 +--
 modules/git/repo_branch.go         | 11 +++++--
 modules/git/repo_branch_gogit.go   | 26 ++++++++++++++-
 modules/git/repo_branch_nogogit.go | 50 ++++++++++++++++++++--------
 modules/git/repo_branch_test.go    |  8 ++---
 routers/web/repo/branch.go         |  4 +--
 routers/web/repo/compare.go        |  4 +--
 routers/web/repo/issue.go          |  2 +-
 services/repository/adopt.go       |  2 +-
 services/repository/branch.go      | 53 ++++++++++++++----------------
 10 files changed, 106 insertions(+), 58 deletions(-)

diff --git a/modules/context/repo.go b/modules/context/repo.go
index 159fd07d9d18..b2844c04c4c1 100644
--- a/modules/context/repo.go
+++ b/modules/context/repo.go
@@ -584,7 +584,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) {
 	}
 	ctx.Data["Tags"] = tags
 
-	brs, _, err := ctx.Repo.GitRepo.GetBranches(0, 0)
+	brs, _, err := ctx.Repo.GitRepo.GetBranchNames(0, 0)
 	if err != nil {
 		ctx.ServerError("GetBranches", err)
 		return
@@ -810,7 +810,7 @@ func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context
 		if len(ctx.Params("*")) == 0 {
 			refName = ctx.Repo.Repository.DefaultBranch
 			if !ctx.Repo.GitRepo.IsBranchExist(refName) {
-				brs, _, err := ctx.Repo.GitRepo.GetBranches(0, 0)
+				brs, _, err := ctx.Repo.GitRepo.GetBranchNames(0, 0)
 				if err != nil {
 					ctx.ServerError("GetBranches", err)
 					return
diff --git a/modules/git/repo_branch.go b/modules/git/repo_branch.go
index 98b1bc8ae7c7..01933d7ade6a 100644
--- a/modules/git/repo_branch.go
+++ b/modules/git/repo_branch.go
@@ -95,7 +95,12 @@ func GetBranchesByPath(path string, skip, limit int) ([]*Branch, int, error) {
 	}
 	defer gitRepo.Close()
 
-	brs, countAll, err := gitRepo.GetBranches(skip, limit)
+	return gitRepo.GetBranches(skip, limit)
+}
+
+// GetBranches returns a slice of *git.Branch
+func (repo *Repository) GetBranches(skip, limit int) ([]*Branch, int, error) {
+	brs, countAll, err := repo.GetBranchNames(skip, limit)
 	if err != nil {
 		return nil, 0, err
 	}
@@ -103,9 +108,9 @@ func GetBranchesByPath(path string, skip, limit int) ([]*Branch, int, error) {
 	branches := make([]*Branch, len(brs))
 	for i := range brs {
 		branches[i] = &Branch{
-			Path:    path,
+			Path:    repo.Path,
 			Name:    brs[i],
-			gitRepo: gitRepo,
+			gitRepo: repo,
 		}
 	}
 
diff --git a/modules/git/repo_branch_gogit.go b/modules/git/repo_branch_gogit.go
index 6bf14b399986..d159aafd6f3b 100644
--- a/modules/git/repo_branch_gogit.go
+++ b/modules/git/repo_branch_gogit.go
@@ -9,6 +9,7 @@
 package git
 
 import (
+	"context"
 	"strings"
 
 	"github.com/go-git/go-git/v5/plumbing"
@@ -52,7 +53,7 @@ func (repo *Repository) IsBranchExist(name string) bool {
 
 // GetBranches returns branches from the repository, skipping skip initial branches and
 // returning at most limit branches, or all branches if limit is 0.
-func (repo *Repository) GetBranches(skip, limit int) ([]string, int, error) {
+func (repo *Repository) GetBranchNames(skip, limit int) ([]string, int, error) {
 	var branchNames []string
 
 	branches, err := repo.gogitRepo.Branches()
@@ -79,3 +80,26 @@ func (repo *Repository) GetBranches(skip, limit int) ([]string, int, error) {
 
 	return branchNames, count, nil
 }
+
+// WalkReferences walks all the references from the repository
+func WalkReferences(ctx context.Context, repoPath string, walkfn func(string) error) (int, error) {
+	repo, err := OpenRepositoryCtx(ctx, repoPath)
+	if err != nil {
+		return 0, err
+	}
+	defer repo.Close()
+
+	i := 0
+	iter, err := repo.gogitRepo.References()
+	if err != nil {
+		return i, err
+	}
+	defer iter.Close()
+
+	err = iter.ForEach(func(ref *plumbing.Reference) error {
+		err := walkfn(string(ref.Name()))
+		i++
+		return err
+	})
+	return i, err
+}
diff --git a/modules/git/repo_branch_nogogit.go b/modules/git/repo_branch_nogogit.go
index 1928c7515bca..55952acda4a3 100644
--- a/modules/git/repo_branch_nogogit.go
+++ b/modules/git/repo_branch_nogogit.go
@@ -61,14 +61,29 @@ func (repo *Repository) IsBranchExist(name string) bool {
 	return repo.IsReferenceExist(BranchPrefix + name)
 }
 
-// GetBranches returns branches from the repository, skipping skip initial branches and
+// GetBranchNames returns branches from the repository, skipping skip initial branches and
 // returning at most limit branches, or all branches if limit is 0.
-func (repo *Repository) GetBranches(skip, limit int) ([]string, int, error) {
+func (repo *Repository) GetBranchNames(skip, limit int) ([]string, int, error) {
 	return callShowRef(repo.Ctx, repo.Path, BranchPrefix, "--heads", skip, limit)
 }
 
+// WalkReferences walks all the references from the repository
+func WalkReferences(ctx context.Context, repoPath string, walkfn func(string) error) (int, error) {
+	return walkShowRef(ctx, repoPath, "", 0, 0, walkfn)
+}
+
 // callShowRef return refs, if limit = 0 it will not limit
 func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit int) (branchNames []string, countAll int, err error) {
+	countAll, err = walkShowRef(ctx, repoPath, arg, skip, limit, func(branchName string) error {
+		branchName = strings.TrimPrefix(branchName, prefix)
+		branchNames = append(branchNames, branchName)
+
+		return nil
+	})
+	return
+}
+
+func walkShowRef(ctx context.Context, repoPath, arg string, skip, limit int, walkfn func(string) error) (countAll int, err error) {
 	stdoutReader, stdoutWriter := io.Pipe()
 	defer func() {
 		_ = stdoutReader.Close()
@@ -77,7 +92,11 @@ func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit
 
 	go func() {
 		stderrBuilder := &strings.Builder{}
-		err := NewCommandContext(ctx, "show-ref", arg).RunInDirPipeline(repoPath, stdoutWriter, stderrBuilder)
+		args := []string{"show-ref"}
+		if arg != "" {
+			args = append(args, arg)
+		}
+		err := NewCommandContext(ctx, args...).RunInDirPipeline(repoPath, stdoutWriter, stderrBuilder)
 		if err != nil {
 			if stderrBuilder.Len() == 0 {
 				_ = stdoutWriter.Close()
@@ -94,10 +113,10 @@ func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit
 	for i < skip {
 		_, isPrefix, err := bufReader.ReadLine()
 		if err == io.EOF {
-			return branchNames, i, nil
+			return i, nil
 		}
 		if err != nil {
-			return nil, 0, err
+			return 0, err
 		}
 		if !isPrefix {
 			i++
@@ -112,39 +131,42 @@ func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit
 			_, err = bufReader.ReadSlice(' ')
 		}
 		if err == io.EOF {
-			return branchNames, i, nil
+			return i, nil
 		}
 		if err != nil {
-			return nil, 0, err
+			return 0, err
 		}
 
 		branchName, err := bufReader.ReadString('\n')
 		if err == io.EOF {
 			// This shouldn't happen... but we'll tolerate it for the sake of peace
-			return branchNames, i, nil
+			return i, nil
 		}
 		if err != nil {
-			return nil, i, err
+			return i, err
 		}
-		branchName = strings.TrimPrefix(branchName, prefix)
+
 		if len(branchName) > 0 {
 			branchName = branchName[:len(branchName)-1]
 		}
-		branchNames = append(branchNames, branchName)
+		err = walkfn(branchName)
+		if err != nil {
+			return i, err
+		}
 		i++
 	}
 	// count all refs
 	for limit != 0 {
 		_, isPrefix, err := bufReader.ReadLine()
 		if err == io.EOF {
-			return branchNames, i, nil
+			return i, nil
 		}
 		if err != nil {
-			return nil, 0, err
+			return 0, err
 		}
 		if !isPrefix {
 			i++
 		}
 	}
-	return branchNames, i, nil
+	return i, nil
 }
diff --git a/modules/git/repo_branch_test.go b/modules/git/repo_branch_test.go
index 05d5237e6a65..ac5f5deea9be 100644
--- a/modules/git/repo_branch_test.go
+++ b/modules/git/repo_branch_test.go
@@ -17,21 +17,21 @@ func TestRepository_GetBranches(t *testing.T) {
 	assert.NoError(t, err)
 	defer bareRepo1.Close()
 
-	branches, countAll, err := bareRepo1.GetBranches(0, 2)
+	branches, countAll, err := bareRepo1.GetBranchNames(0, 2)
 
 	assert.NoError(t, err)
 	assert.Len(t, branches, 2)
 	assert.EqualValues(t, 3, countAll)
 	assert.ElementsMatch(t, []string{"branch1", "branch2"}, branches)
 
-	branches, countAll, err = bareRepo1.GetBranches(0, 0)
+	branches, countAll, err = bareRepo1.GetBranchNames(0, 0)
 
 	assert.NoError(t, err)
 	assert.Len(t, branches, 3)
 	assert.EqualValues(t, 3, countAll)
 	assert.ElementsMatch(t, []string{"branch1", "branch2", "master"}, branches)
 
-	branches, countAll, err = bareRepo1.GetBranches(5, 1)
+	branches, countAll, err = bareRepo1.GetBranchNames(5, 1)
 
 	assert.NoError(t, err)
 	assert.Len(t, branches, 0)
@@ -48,7 +48,7 @@ func BenchmarkRepository_GetBranches(b *testing.B) {
 	defer bareRepo1.Close()
 
 	for i := 0; i < b.N; i++ {
-		_, _, err := bareRepo1.GetBranches(0, 0)
+		_, _, err := bareRepo1.GetBranchNames(0, 0)
 		if err != nil {
 			b.Fatal(err)
 		}
diff --git a/routers/web/repo/branch.go b/routers/web/repo/branch.go
index 05b45eba4b20..9c2518059699 100644
--- a/routers/web/repo/branch.go
+++ b/routers/web/repo/branch.go
@@ -165,14 +165,14 @@ func redirect(ctx *context.Context) {
 // loadBranches loads branches from the repository limited by page & pageSize.
 // NOTE: May write to context on error.
 func loadBranches(ctx *context.Context, skip, limit int) ([]*Branch, int) {
-	defaultBranch, err := repo_service.GetBranch(ctx.Repo.Repository, ctx.Repo.Repository.DefaultBranch)
+	defaultBranch, err := ctx.Repo.GitRepo.GetBranch(ctx.Repo.Repository.DefaultBranch)
 	if err != nil {
 		log.Error("loadBranches: get default branch: %v", err)
 		ctx.ServerError("GetDefaultBranch", err)
 		return nil, 0
 	}
 
-	rawBranches, totalNumOfBranches, err := repo_service.GetBranches(ctx.Repo.Repository, skip, limit)
+	rawBranches, totalNumOfBranches, err := ctx.Repo.GitRepo.GetBranches(skip, limit)
 	if err != nil {
 		log.Error("GetBranches: %v", err)
 		ctx.ServerError("GetBranches", err)
diff --git a/routers/web/repo/compare.go b/routers/web/repo/compare.go
index 54d7e77f2d4e..4cd817a39966 100644
--- a/routers/web/repo/compare.go
+++ b/routers/web/repo/compare.go
@@ -660,7 +660,7 @@ func getBranchesAndTagsForRepo(repo *models.Repository) (branches, tags []string
 	}
 	defer gitRepo.Close()
 
-	branches, _, err = gitRepo.GetBranches(0, 0)
+	branches, _, err = gitRepo.GetBranchNames(0, 0)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -711,7 +711,7 @@ func CompareDiff(ctx *context.Context) {
 		return
 	}
 
-	headBranches, _, err := ci.HeadGitRepo.GetBranches(0, 0)
+	headBranches, _, err := ci.HeadGitRepo.GetBranchNames(0, 0)
 	if err != nil {
 		ctx.ServerError("GetBranches", err)
 		return
diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go
index f0857b18c0eb..398aa26cc49d 100644
--- a/routers/web/repo/issue.go
+++ b/routers/web/repo/issue.go
@@ -690,7 +690,7 @@ func RetrieveRepoMetas(ctx *context.Context, repo *models.Repository, isPull boo
 		return nil
 	}
 
-	brs, _, err := ctx.Repo.GitRepo.GetBranches(0, 0)
+	brs, _, err := ctx.Repo.GitRepo.GetBranchNames(0, 0)
 	if err != nil {
 		ctx.ServerError("GetBranches", err)
 		return nil
diff --git a/services/repository/adopt.go b/services/repository/adopt.go
index 3f4045a77831..5503155ab038 100644
--- a/services/repository/adopt.go
+++ b/services/repository/adopt.go
@@ -142,7 +142,7 @@ func adoptRepository(ctx context.Context, repoPath string, u *user_model.User, r
 
 		repo.DefaultBranch = strings.TrimPrefix(repo.DefaultBranch, git.BranchPrefix)
 	}
-	branches, _, _ := gitRepo.GetBranches(0, 0)
+	branches, _, _ := gitRepo.GetBranchNames(0, 0)
 	found := false
 	hasDefault := false
 	hasMaster := false
diff --git a/services/repository/branch.go b/services/repository/branch.go
index f33bac762181..08310134bdf6 100644
--- a/services/repository/branch.go
+++ b/services/repository/branch.go
@@ -5,8 +5,10 @@
 package repository
 
 import (
+	"context"
 	"errors"
 	"fmt"
+	"strings"
 
 	"code.gitea.io/gitea/models"
 	user_model "code.gitea.io/gitea/models/user"
@@ -20,7 +22,7 @@ import (
 // CreateNewBranch creates a new repository branch
 func CreateNewBranch(doer *user_model.User, repo *models.Repository, oldBranchName, branchName string) (err error) {
 	// Check if branch name can be used
-	if err := checkBranchName(repo, branchName); err != nil {
+	if err := checkBranchName(git.DefaultContext, repo, branchName); err != nil {
 		return err
 	}
 
@@ -65,44 +67,39 @@ func GetBranches(repo *models.Repository, skip, limit int) ([]*git.Branch, int,
 }
 
 // checkBranchName validates branch name with existing repository branches
-func checkBranchName(repo *models.Repository, name string) error {
-	gitRepo, err := git.OpenRepository(repo.RepoPath())
-	if err != nil {
-		return err
-	}
-	defer gitRepo.Close()
-
-	branches, _, err := GetBranches(repo, 0, 0)
-	if err != nil {
-		return err
-	}
-
-	for _, branch := range branches {
-		if branch.Name == name {
+func checkBranchName(ctx context.Context, repo *models.Repository, name string) error {
+	_, err := git.WalkReferences(ctx, repo.RepoPath(), func(refName string) error {
+		branchRefName := strings.TrimPrefix(refName, git.BranchPrefix)
+		switch {
+		case branchRefName == name:
 			return models.ErrBranchAlreadyExists{
-				BranchName: branch.Name,
+				BranchName: name,
 			}
-		} else if (len(branch.Name) < len(name) && branch.Name+"/" == name[0:len(branch.Name)+1]) ||
-			(len(branch.Name) > len(name) && name+"/" == branch.Name[0:len(name)+1]) {
+		// If branchRefName like a/b but we want to create a branch named a then we have a conflict
+		case strings.HasPrefix(branchRefName, name+"/"):
 			return models.ErrBranchNameConflict{
-				BranchName: branch.Name,
+				BranchName: branchRefName,
+			}
+			// Conversely if branchRefName like a but we want to create a branch named a/b then we also have a conflict
+		case strings.HasPrefix(name, branchRefName+"/"):
+			return models.ErrBranchNameConflict{
+				BranchName: branchRefName,
+			}
+		case refName == git.TagPrefix+name:
+			return models.ErrTagAlreadyExists{
+				TagName: name,
 			}
 		}
-	}
+		return nil
+	})
 
-	if _, err := gitRepo.GetTag(name); err == nil {
-		return models.ErrTagAlreadyExists{
-			TagName: name,
-		}
-	}
-
-	return nil
+	return err
 }
 
 // CreateNewBranchFromCommit creates a new repository branch
 func CreateNewBranchFromCommit(doer *user_model.User, repo *models.Repository, commit, branchName string) (err error) {
 	// Check if branch name can be used
-	if err := checkBranchName(repo, branchName); err != nil {
+	if err := checkBranchName(git.DefaultContext, repo, branchName); err != nil {
 		return err
 	}