From d67e40684f43b0eb744cad26e0265002f033dbc3 Mon Sep 17 00:00:00 2001
From: Jason Song <i@wolfogre.com>
Date: Mon, 3 Apr 2023 16:42:38 +0800
Subject: [PATCH] Improve LoadUnitConfig to handle invalid or duplicate units
 (#23736)

The old code just parses an invalid key to `TypeInvalid` and uses it as
normal, and duplicate keys will be kept.

So this PR will ignore invalid key and log warning and also deduplicate
valid units.
---
 models/unit/unit.go        | 39 ++++++++++++++++------------
 models/unit/unit_test.go   | 53 ++++++++++++++++++++++++++++++++++++++
 routers/api/v1/org/team.go |  2 +-
 3 files changed, 77 insertions(+), 17 deletions(-)
 create mode 100644 models/unit/unit_test.go

diff --git a/models/unit/unit.go b/models/unit/unit.go
index 883f443cbe65..3d5a8842cd63 100644
--- a/models/unit/unit.go
+++ b/models/unit/unit.go
@@ -151,7 +151,11 @@ func validateDefaultRepoUnits(defaultUnits, settingDefaultUnits []Type) []Type {
 
 // LoadUnitConfig load units from settings
 func LoadUnitConfig() {
-	DisabledRepoUnits = FindUnitTypes(setting.Repository.DisabledRepoUnits...)
+	var invalidKeys []string
+	DisabledRepoUnits, invalidKeys = FindUnitTypes(setting.Repository.DisabledRepoUnits...)
+	if len(invalidKeys) > 0 {
+		log.Warn("Invalid keys in disabled repo units: %s", strings.Join(invalidKeys, ", "))
+	}
 	// Check that must units are not disabled
 	for i, disabledU := range DisabledRepoUnits {
 		if !disabledU.CanDisable() {
@@ -160,9 +164,15 @@ func LoadUnitConfig() {
 		}
 	}
 
-	setDefaultRepoUnits := FindUnitTypes(setting.Repository.DefaultRepoUnits...)
+	setDefaultRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultRepoUnits...)
+	if len(invalidKeys) > 0 {
+		log.Warn("Invalid keys in default repo units: %s", strings.Join(invalidKeys, ", "))
+	}
 	DefaultRepoUnits = validateDefaultRepoUnits(DefaultRepoUnits, setDefaultRepoUnits)
-	setDefaultForkRepoUnits := FindUnitTypes(setting.Repository.DefaultForkRepoUnits...)
+	setDefaultForkRepoUnits, invalidKeys := FindUnitTypes(setting.Repository.DefaultForkRepoUnits...)
+	if len(invalidKeys) > 0 {
+		log.Warn("Invalid keys in default fork repo units: %s", strings.Join(invalidKeys, ", "))
+	}
 	DefaultForkRepoUnits = validateDefaultRepoUnits(DefaultForkRepoUnits, setDefaultForkRepoUnits)
 }
 
@@ -334,22 +344,19 @@ var (
 	}
 )
 
-// FindUnitTypes give the unit key names and return unit
-func FindUnitTypes(nameKeys ...string) (res []Type) {
+// FindUnitTypes give the unit key names and return valid unique units and invalid keys
+func FindUnitTypes(nameKeys ...string) (res []Type, invalidKeys []string) {
+	m := map[Type]struct{}{}
 	for _, key := range nameKeys {
-		var found bool
-		for t, u := range Units {
-			if strings.EqualFold(key, u.NameKey) {
-				res = append(res, t)
-				found = true
-				break
-			}
-		}
-		if !found {
-			res = append(res, TypeInvalid)
+		t := TypeFromKey(key)
+		if t == TypeInvalid {
+			invalidKeys = append(invalidKeys, key)
+		} else if _, ok := m[t]; !ok {
+			res = append(res, t)
+			m[t] = struct{}{}
 		}
 	}
-	return res
+	return res, invalidKeys
 }
 
 // TypeFromKey give the unit key name and return unit
diff --git a/models/unit/unit_test.go b/models/unit/unit_test.go
new file mode 100644
index 000000000000..50d781719771
--- /dev/null
+++ b/models/unit/unit_test.go
@@ -0,0 +1,53 @@
+// Copyright 2023 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package unit
+
+import (
+	"testing"
+
+	"code.gitea.io/gitea/modules/setting"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestLoadUnitConfig(t *testing.T) {
+	defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []Type) {
+		DisabledRepoUnits = disabledRepoUnits
+		DefaultRepoUnits = defaultRepoUnits
+		DefaultForkRepoUnits = defaultForkRepoUnits
+	}(DisabledRepoUnits, DefaultRepoUnits, DefaultForkRepoUnits)
+	defer func(disabledRepoUnits, defaultRepoUnits, defaultForkRepoUnits []string) {
+		setting.Repository.DisabledRepoUnits = disabledRepoUnits
+		setting.Repository.DefaultRepoUnits = defaultRepoUnits
+		setting.Repository.DefaultForkRepoUnits = defaultForkRepoUnits
+	}(setting.Repository.DisabledRepoUnits, setting.Repository.DefaultRepoUnits, setting.Repository.DefaultForkRepoUnits)
+
+	t.Run("regular", func(t *testing.T) {
+		setting.Repository.DisabledRepoUnits = []string{"repo.issues"}
+		setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls"}
+		setting.Repository.DefaultForkRepoUnits = []string{"repo.releases"}
+		LoadUnitConfig()
+		assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits)
+		assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
+		assert.Equal(t, []Type{TypeCode, TypeReleases}, DefaultForkRepoUnits)
+	})
+	t.Run("invalid", func(t *testing.T) {
+		setting.Repository.DisabledRepoUnits = []string{"repo.issues", "invalid.1"}
+		setting.Repository.DefaultRepoUnits = []string{"repo.code", "invalid.2", "repo.releases", "repo.issues", "repo.pulls"}
+		setting.Repository.DefaultForkRepoUnits = []string{"invalid.3", "repo.releases"}
+		LoadUnitConfig()
+		assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits)
+		assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
+		assert.Equal(t, []Type{TypeCode, TypeReleases}, DefaultForkRepoUnits)
+	})
+	t.Run("duplicate", func(t *testing.T) {
+		setting.Repository.DisabledRepoUnits = []string{"repo.issues", "repo.issues"}
+		setting.Repository.DefaultRepoUnits = []string{"repo.code", "repo.releases", "repo.issues", "repo.pulls", "repo.code"}
+		setting.Repository.DefaultForkRepoUnits = []string{"repo.releases", "repo.releases"}
+		LoadUnitConfig()
+		assert.Equal(t, []Type{TypeIssues}, DisabledRepoUnits)
+		assert.Equal(t, []Type{TypeCode, TypeReleases, TypePullRequests}, DefaultRepoUnits)
+		assert.Equal(t, []Type{TypeCode, TypeReleases}, DefaultForkRepoUnits)
+	})
+}
diff --git a/routers/api/v1/org/team.go b/routers/api/v1/org/team.go
index 0c6926759a76..597f84620604 100644
--- a/routers/api/v1/org/team.go
+++ b/routers/api/v1/org/team.go
@@ -135,7 +135,7 @@ func GetTeam(ctx *context.APIContext) {
 }
 
 func attachTeamUnits(team *organization.Team, units []string) {
-	unitTypes := unit_model.FindUnitTypes(units...)
+	unitTypes, _ := unit_model.FindUnitTypes(units...)
 	team.Units = make([]*organization.TeamUnit, 0, len(units))
 	for _, tp := range unitTypes {
 		team.Units = append(team.Units, &organization.TeamUnit{