diff --git a/cmd/pulls/merge.go b/cmd/pulls/merge.go index 3213db3..cf127dc 100644 --- a/cmd/pulls/merge.go +++ b/cmd/pulls/merge.go @@ -44,13 +44,27 @@ var CmdPullsMerge = cli.Command{ ctx := context.InitCommand(cmd) ctx.Ensure(context.CtxRequirement{RemoteRepo: true}) - if ctx.Args().Len() != 1 { - return fmt.Errorf("Must specify a PR index") - } + var idx int64 + var err error + if ctx.Args().Len() == 1 { + idx, err = utils.ArgToIndex(ctx.Args().First()) + if err != nil { + return err + } + } else { + if ctx.LocalRepo == nil { + return fmt.Errorf("Must specify a PR index") + } - idx, err := utils.ArgToIndex(ctx.Args().First()) - if err != nil { - return err + branch, _, err := ctx.LocalRepo.TeaGetCurrentBranchNameAndSHA() + if err != nil { + return err + } + + idx, err = GetPullIndexByBranch(ctx, branch) + if err != nil { + return err + } } success, _, err := ctx.Login.Client().MergePullRequest(ctx.Owner, ctx.Repo, idx, gitea.MergePullRequestOption{ @@ -68,3 +82,19 @@ var CmdPullsMerge = cli.Command{ return nil }, } + +func GetPullIndexByBranch(ctx *context.TeaContext, branch string) (int64, error) { + prs, _, err := ctx.Login.Client().ListRepoPullRequests(ctx.Owner, ctx.Repo, gitea.ListPullRequestsOptions{ + State: gitea.StateOpen, + }) + if err != nil { + return 0, err + } + + for _, pr := range prs { + if pr.Head.Ref == branch { + return pr.Index, nil + } + } + return 0, fmt.Errorf("No open PR for branch %s", branch) +}