diff --git a/.github/workflows/deployment.yml b/.github/workflows/deployment.yml index a7b03f40d20..850cc19b799 100644 --- a/.github/workflows/deployment.yml +++ b/.github/workflows/deployment.yml @@ -309,7 +309,7 @@ jobs: rpmsign --addsign dist/*.rpm - name: Attest release artifacts if: inputs.environment == 'production' - uses: actions/attest-build-provenance@bd77c077858b8d561b7a36cbe48ef4cc642ca39d # v2.2.2 + uses: actions/attest-build-provenance@db473fddc028af60658334401dc6fa3ffd8669fd # v2.3.0 with: subject-path: "dist/gh_*" - name: Run createrepo diff --git a/api/export_pr.go b/api/export_pr.go index 7ae1a4ff4e9..9b030c39ed7 100644 --- a/api/export_pr.go +++ b/api/export_pr.go @@ -28,6 +28,24 @@ func (issue *Issue) ExportData(fields []string) map[string]interface{} { }) } data[f] = items + case "closedByPullRequestsReferences": + items := make([]map[string]interface{}, 0, len(issue.ClosedByPullRequestsReferences.Nodes)) + for _, n := range issue.ClosedByPullRequestsReferences.Nodes { + items = append(items, map[string]interface{}{ + "id": n.ID, + "number": n.Number, + "url": n.URL, + "repository": map[string]interface{}{ + "id": n.Repository.ID, + "name": n.Repository.Name, + "owner": map[string]interface{}{ + "id": n.Repository.Owner.ID, + "login": n.Repository.Owner.Login, + }, + }, + }) + } + data[f] = items default: sf := fieldByName(v, f) data[f] = sf.Interface() @@ -143,7 +161,6 @@ func (pr *PullRequest) ExportData(fields []string) map[string]interface{} { items := make([]map[string]interface{}, 0, len(pr.ClosingIssuesReferences.Nodes)) for _, n := range pr.ClosingIssuesReferences.Nodes { items = append(items, map[string]interface{}{ - "id": n.ID, "number": n.Number, "url": n.URL, diff --git a/api/export_pr_test.go b/api/export_pr_test.go index 09a1dffe870..1f310693e68 100644 --- a/api/export_pr_test.go +++ b/api/export_pr_test.go @@ -107,6 +107,70 @@ func TestIssue_ExportData(t *testing.T) { } `), }, + { + name: "linked pull requests", + fields: []string{"closedByPullRequestsReferences"}, + inputJSON: heredoc.Doc(` + { "closedByPullRequestsReferences": { "nodes": [ + { + "id": "I_123", + "number": 123, + "url": "https://github.com/cli/cli/pull/123", + "repository": { + "id": "R_123", + "name": "cli", + "owner": { + "id": "O_123", + "login": "cli" + } + } + }, + { + "id": "I_456", + "number": 456, + "url": "https://github.com/cli/cli/pull/456", + "repository": { + "id": "R_456", + "name": "cli", + "owner": { + "id": "O_456", + "login": "cli" + } + } + } + ] } } + `), + outputJSON: heredoc.Doc(` + { "closedByPullRequestsReferences": [ + { + "id": "I_123", + "number": 123, + "repository": { + "id": "R_123", + "name": "cli", + "owner": { + "id": "O_123", + "login": "cli" + } + }, + "url": "https://github.com/cli/cli/pull/123" + }, + { + "id": "I_456", + "number": 456, + "repository": { + "id": "R_456", + "name": "cli", + "owner": { + "id": "O_456", + "login": "cli" + } + }, + "url": "https://github.com/cli/cli/pull/456" + } + ] } + `), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -120,7 +184,14 @@ func TestIssue_ExportData(t *testing.T) { enc := json.NewEncoder(&buf) enc.SetIndent("", "\t") require.NoError(t, enc.Encode(exported)) - assert.Equal(t, tt.outputJSON, buf.String()) + + var gotData interface{} + dec = json.NewDecoder(&buf) + require.NoError(t, dec.Decode(&gotData)) + var expectData interface{} + require.NoError(t, json.Unmarshal([]byte(tt.outputJSON), &expectData)) + + assert.Equal(t, expectData, gotData) }) } } diff --git a/api/queries_comments.go b/api/queries_comments.go index 5cc84a3e43f..8af17fd2ae6 100644 --- a/api/queries_comments.go +++ b/api/queries_comments.go @@ -44,6 +44,10 @@ type CommentCreateInput struct { SubjectId string } +type CommentDeleteInput struct { + CommentId string +} + type CommentUpdateInput struct { Body string CommentId string @@ -99,6 +103,27 @@ func CommentUpdate(client *Client, repoHost string, params CommentUpdateInput) ( return mutation.UpdateIssueComment.IssueComment.URL, nil } +func CommentDelete(client *Client, repoHost string, params CommentDeleteInput) error { + var mutation struct { + DeleteIssueComment struct { + ClientMutationID string + } `graphql:"deleteIssueComment(input: $input)"` + } + + variables := map[string]interface{}{ + "input": githubv4.DeleteIssueCommentInput{ + ID: githubv4.ID(params.CommentId), + }, + } + + err := client.Mutate(repoHost, "CommentDelete", &mutation, variables) + if err != nil { + return err + } + + return nil +} + func (c Comment) Identifier() string { return c.ID } diff --git a/api/queries_issue.go b/api/queries_issue.go index 094b6b198c7..24e0b4f4c29 100644 --- a/api/queries_issue.go +++ b/api/queries_issue.go @@ -38,12 +38,35 @@ type Issue struct { Comments Comments Author Author Assignees Assignees + AssignedActors AssignedActors Labels Labels ProjectCards ProjectCards ProjectItems ProjectItems Milestone *Milestone ReactionGroups ReactionGroups IsPinned bool + + ClosedByPullRequestsReferences ClosedByPullRequestsReferences +} + +type ClosedByPullRequestsReferences struct { + Nodes []struct { + ID string + Number int + URL string + Repository struct { + ID string + Name string + Owner struct { + ID string + Login string + } + } + } + PageInfo struct { + HasNextPage bool + EndCursor string + } } // return values for Issue.Typename @@ -69,6 +92,61 @@ func (a Assignees) Logins() []string { return logins } +type AssignedActors struct { + Nodes []Actor + TotalCount int +} + +func (a AssignedActors) Logins() []string { + logins := make([]string, len(a.Nodes)) + for i, a := range a.Nodes { + logins[i] = a.Login + } + return logins +} + +// DisplayNames returns a list of display names for the assigned actors. +func (a AssignedActors) DisplayNames() []string { + // These display names are used for populating the "default" assigned actors + // from the AssignedActors type. But, this is only one piece of the puzzle + // as later, other queries will fetch the full list of possible assignable + // actors from the repository, and the two lists will be reconciled. + // + // It's important that the display names are the same between the defaults + // (the values returned here) and the full list (the values returned by + // other repository queries). Any discrepancy would result in an + // "invalid default", which means an assigned actor will not be matched + // to an assignable actor and not presented as a "default" selection. + // Not being presented as a default would cause the actor to be potentially + // unassigned if the edits were submitted. + // + // To prevent this, we need shared logic to look up an actor's display name. + // However, our API types between assignedActors and the full list of + // assignableActors are different. So, as an attempt to maintain + // consistency we convert the assignedActors to the same types as the + // repository's assignableActors, treating the assignableActors DisplayName + // methods as the sources of truth. + // TODO KW: make this comment less of a wall of text if needed. + var displayNames []string + for _, a := range a.Nodes { + if a.TypeName == "User" { + u := NewAssignableUser( + a.ID, + a.Login, + a.Name, + ) + displayNames = append(displayNames, u.DisplayName()) + } else if a.TypeName == "Bot" { + b := NewAssignableBot( + a.ID, + a.Login, + ) + displayNames = append(displayNames, b.DisplayName()) + } + } + return displayNames +} + type Labels struct { Nodes []IssueLabel TotalCount int diff --git a/api/queries_pr.go b/api/queries_pr.go index 5b941bb428a..525418a11f0 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -84,6 +84,7 @@ type PullRequest struct { } Assignees Assignees + AssignedActors AssignedActors Labels Labels ProjectCards ProjectCards ProjectItems ProjectItems diff --git a/api/queries_repo.go b/api/queries_repo.go index 27e21eb32ac..efbcfcb197d 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -146,6 +146,18 @@ type GitHubUser struct { Name string `json:"name"` } +// Actor is a superset of User and Bot, among others. +// At the time of writing, some of these fields +// are not directly supported by the Actor type and +// instead are only available on the User or Bot types +// directly. +type Actor struct { + ID string `json:"id"` + Login string `json:"login"` + Name string `json:"name"` + TypeName string `json:"__typename"` +} + // BranchRef is the branch name in a GitHub repository type BranchRef struct { Name string `json:"name"` @@ -674,13 +686,14 @@ func RepoFindForks(client *Client, repo ghrepo.Interface, limit int) ([]*Reposit } type RepoMetadataResult struct { - CurrentLogin string - AssignableUsers []RepoAssignee - Labels []RepoLabel - Projects []RepoProject - ProjectsV2 []ProjectV2 - Milestones []RepoMilestone - Teams []OrgTeam + CurrentLogin string + AssignableUsers []AssignableUser + AssignableActors []AssignableActor + Labels []RepoLabel + Projects []RepoProject + ProjectsV2 []ProjectV2 + Milestones []RepoMilestone + Teams []OrgTeam } func (m *RepoMetadataResult) MembersToIDs(names []string) ([]string, error) { @@ -688,12 +701,27 @@ func (m *RepoMetadataResult) MembersToIDs(names []string) ([]string, error) { for _, assigneeLogin := range names { found := false for _, u := range m.AssignableUsers { - if strings.EqualFold(assigneeLogin, u.Login) { - ids = append(ids, u.ID) + if strings.EqualFold(assigneeLogin, u.Login()) { + ids = append(ids, u.ID()) + found = true + break + } + } + + // Look for ID in assignable actors if not found in assignable users + for _, a := range m.AssignableActors { + if strings.EqualFold(assigneeLogin, a.Login()) { + ids = append(ids, a.ID()) + found = true + break + } + if strings.EqualFold(assigneeLogin, a.DisplayName()) { + ids = append(ids, a.ID()) found = true break } } + if !found { return nil, fmt.Errorf("'%s' not found", assigneeLogin) } @@ -738,34 +766,37 @@ func (m *RepoMetadataResult) LabelsToIDs(names []string) ([]string, error) { return ids, nil } -// ProjectsToIDs returns two arrays: +// ProjectsTitlesToIDs returns two arrays: // - the first contains IDs of projects V1 // - the second contains IDs of projects V2 // - if neither project V1 or project V2 can be found with a given name, then an error is returned -func (m *RepoMetadataResult) ProjectsToIDs(names []string) ([]string, []string, error) { +func (m *RepoMetadataResult) ProjectsTitlesToIDs(titles []string) ([]string, []string, error) { var ids []string var idsV2 []string - for _, projectName := range names { - id, found := m.projectNameToID(projectName) + for _, title := range titles { + id, found := m.v1ProjectNameToID(title) if found { ids = append(ids, id) continue } - idV2, found := m.projectV2TitleToID(projectName) + idV2, found := m.v2ProjectTitleToID(title) if found { idsV2 = append(idsV2, idV2) continue } - return nil, nil, fmt.Errorf("'%s' not found", projectName) + return nil, nil, fmt.Errorf("'%s' not found", title) } return ids, idsV2, nil } -func (m *RepoMetadataResult) projectNameToID(projectName string) (string, bool) { +// We use the word "titles" when referring to v1 and v2 projects. +// In reality, v1 projects really have "names", so there is a bit of a +// mismatch we just need to gloss over. +func (m *RepoMetadataResult) v1ProjectNameToID(name string) (string, bool) { for _, p := range m.Projects { - if strings.EqualFold(projectName, p.Name) { + if strings.EqualFold(name, p.Name) { return p.ID, true } } @@ -773,9 +804,9 @@ func (m *RepoMetadataResult) projectNameToID(projectName string) (string, bool) return "", false } -func (m *RepoMetadataResult) projectV2TitleToID(projectTitle string) (string, bool) { +func (m *RepoMetadataResult) v2ProjectTitleToID(title string) (string, bool) { for _, p := range m.ProjectsV2 { - if strings.EqualFold(projectTitle, p.Title) { + if strings.EqualFold(title, p.Title) { return p.ID, true } } @@ -783,8 +814,8 @@ func (m *RepoMetadataResult) projectV2TitleToID(projectTitle string) (string, bo return "", false } -func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []string, projectsV1Support gh.ProjectsV1Support) ([]string, error) { - paths := make([]string, 0, len(projectNames)) +func ProjectTitlesToPaths(client *Client, repo ghrepo.Interface, titles []string, projectsV1Support gh.ProjectsV1Support) ([]string, error) { + paths := make([]string, 0, len(titles)) matchedPaths := map[string]struct{}{} // TODO: ProjectsV1Cleanup @@ -796,9 +827,9 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s return nil, err } - for _, projectName := range projectNames { + for _, title := range titles { for _, p := range v1Projects { - if strings.EqualFold(projectName, p.Name) { + if strings.EqualFold(title, p.Name) { pathParts := strings.Split(p.ResourcePath, "/") var path string if pathParts[1] == "orgs" || pathParts[1] == "users" { @@ -807,7 +838,7 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s path = fmt.Sprintf("%s/%s/%s", pathParts[1], pathParts[2], pathParts[4]) } paths = append(paths, path) - matchedPaths[projectName] = struct{}{} + matchedPaths[title] = struct{}{} break } } @@ -820,15 +851,15 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s return nil, err } - for _, projectName := range projectNames { + for _, title := range titles { // If we already found a v1 project with this name, skip it - if _, ok := matchedPaths[projectName]; ok { + if _, ok := matchedPaths[title]; ok { continue } found := false for _, p := range v2Projects { - if strings.EqualFold(projectName, p.Title) { + if strings.EqualFold(title, p.Title) { pathParts := strings.Split(p.ResourcePath, "/") var path string if pathParts[1] == "orgs" || pathParts[1] == "users" { @@ -843,7 +874,7 @@ func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []s } if !found { - return nil, fmt.Errorf("'%s' not found", projectName) + return nil, fmt.Errorf("'%s' not found", title) } } @@ -882,12 +913,13 @@ func (m *RepoMetadataResult) Merge(m2 *RepoMetadataResult) { } type RepoMetadataInput struct { - Assignees bool - Reviewers bool - Labels bool - ProjectsV1 bool - ProjectsV2 bool - Milestones bool + Assignees bool + ActorAssignees bool + Reviewers bool + Labels bool + ProjectsV1 bool + ProjectsV2 bool + Milestones bool } // RepoMetadata pre-fetches the metadata for attaching to issues and pull requests @@ -896,14 +928,37 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput var g errgroup.Group if input.Assignees || input.Reviewers { - g.Go(func() error { - users, err := RepoAssignableUsers(client, repo) - if err != nil { - err = fmt.Errorf("error fetching assignees: %w", err) - } - result.AssignableUsers = users - return err - }) + if input.ActorAssignees { + g.Go(func() error { + actors, err := RepoAssignableActors(client, repo) + if err != nil { + return fmt.Errorf("error fetching assignable actors: %w", err) + } + result.AssignableActors = actors + + // Filter actors for users to use for pull request reviewers, + // skip retrieving the same info through RepoAssignableUsers(). + var users []AssignableUser + for _, a := range actors { + if _, ok := a.(AssignableUser); !ok { + continue + } + users = append(users, a.(AssignableUser)) + } + result.AssignableUsers = users + return nil + }) + } else { + // Not using Actors, fetch legacy assignable users. + g.Go(func() error { + users, err := RepoAssignableUsers(client, repo) + if err != nil { + err = fmt.Errorf("error fetching assignable users: %w", err) + } + result.AssignableUsers = users + return err + }) + } } if input.Reviewers { @@ -1067,12 +1122,16 @@ func RepoResolveMetadataIDs(client *Client, repo ghrepo.Interface, input RepoRes result.Teams = append(result.Teams, t) } default: - user := RepoAssignee{} + user := struct { + Id string + Login string + Name string + }{} err := json.Unmarshal(v, &user) if err != nil { return result, err } - result.AssignableUsers = append(result.AssignableUsers, user) + result.AssignableUsers = append(result.AssignableUsers, NewAssignableUser(user.Id, user.Login, user.Name)) } } @@ -1124,26 +1183,99 @@ func RepoProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) return projects, nil } -type RepoAssignee struct { - ID string - Login string - Name string +// Expected login for Copilot when retrieved as an Actor +// This is returned from assignable actors and issue/pr assigned actors. +// We use this to check if the actor is Copilot. +const CopilotActorLogin = "copilot-swe-agent" + +type AssignableActor interface { + DisplayName() string + ID() string + Login() string + + sealedAssignableActor() +} + +// Always a user +type AssignableUser struct { + id string + login string + name string +} + +func NewAssignableUser(id, login, name string) AssignableUser { + return AssignableUser{ + id: id, + login: login, + name: name, + } } // DisplayName returns a formatted string that uses Login and Name to be displayed e.g. 'Login (Name)' or 'Login' -func (ra RepoAssignee) DisplayName() string { - if ra.Name != "" { - return fmt.Sprintf("%s (%s)", ra.Login, ra.Name) +func (u AssignableUser) DisplayName() string { + if u.name != "" { + return fmt.Sprintf("%s (%s)", u.login, u.name) + } + return u.login +} + +func (u AssignableUser) ID() string { + return u.id +} + +func (u AssignableUser) Login() string { + return u.login +} + +func (u AssignableUser) Name() string { + return u.name +} + +func (u AssignableUser) sealedAssignableActor() {} + +type AssignableBot struct { + id string + login string +} + +func NewAssignableBot(id, login string) AssignableBot { + return AssignableBot{ + id: id, + login: login, + } +} + +func (b AssignableBot) DisplayName() string { + if b.login == CopilotActorLogin { + return "Copilot (AI)" } - return ra.Login + return b.Login() +} + +func (b AssignableBot) ID() string { + return b.id } +func (b AssignableBot) Login() string { + return b.login +} + +func (b AssignableBot) Name() string { + return "" +} + +func (b AssignableBot) sealedAssignableActor() {} + // RepoAssignableUsers fetches all the assignable users for a repository -func RepoAssignableUsers(client *Client, repo ghrepo.Interface) ([]RepoAssignee, error) { +func RepoAssignableUsers(client *Client, repo ghrepo.Interface) ([]AssignableUser, error) { type responseData struct { Repository struct { AssignableUsers struct { - Nodes []RepoAssignee + Nodes []struct { + ID string + Login string + Name string + } PageInfo struct { HasNextPage bool EndCursor string @@ -1158,7 +1290,7 @@ func RepoAssignableUsers(client *Client, repo ghrepo.Interface) ([]RepoAssignee, "endCursor": (*githubv4.String)(nil), } - var users []RepoAssignee + var users []AssignableUser for { var query responseData err := client.Query(repo.RepoHost(), "RepositoryAssignableUsers", &query, variables) @@ -1166,7 +1298,15 @@ func RepoAssignableUsers(client *Client, repo ghrepo.Interface) ([]RepoAssignee, return nil, err } - users = append(users, query.Repository.AssignableUsers.Nodes...) + for _, node := range query.Repository.AssignableUsers.Nodes { + user := AssignableUser{ + id: node.ID, + login: node.Login, + name: node.Name, + } + + users = append(users, user) + } if !query.Repository.AssignableUsers.PageInfo.HasNextPage { break } @@ -1176,6 +1316,72 @@ func RepoAssignableUsers(client *Client, repo ghrepo.Interface) ([]RepoAssignee, return users, nil } +// RepoAssignableActors fetches all the assignable actors for a repository on +// GitHub hosts that support Actor assignees. +func RepoAssignableActors(client *Client, repo ghrepo.Interface) ([]AssignableActor, error) { + type responseData struct { + Repository struct { + SuggestedActors struct { + Nodes []struct { + User struct { + ID string + Login string + Name string + TypeName string `graphql:"__typename"` + } `graphql:"... on User"` + Bot struct { + ID string + Login string + TypeName string `graphql:"__typename"` + } `graphql:"... on Bot"` + } + PageInfo struct { + HasNextPage bool + EndCursor string + } + } `graphql:"suggestedActors(first: 100, after: $endCursor, capabilities: CAN_BE_ASSIGNED)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + + variables := map[string]interface{}{ + "owner": githubv4.String(repo.RepoOwner()), + "name": githubv4.String(repo.RepoName()), + "endCursor": (*githubv4.String)(nil), + } + + var actors []AssignableActor + for { + var query responseData + err := client.Query(repo.RepoHost(), "RepositoryAssignableActors", &query, variables) + if err != nil { + return nil, err + } + + for _, node := range query.Repository.SuggestedActors.Nodes { + if node.User.TypeName == "User" { + actor := AssignableUser{ + id: node.User.ID, + login: node.User.Login, + name: node.User.Name, + } + actors = append(actors, actor) + } else if node.Bot.TypeName == "Bot" { + actor := AssignableBot{ + id: node.Bot.ID, + login: node.Bot.Login, + } + actors = append(actors, actor) + } + } + + if !query.Repository.SuggestedActors.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Repository.SuggestedActors.PageInfo.EndCursor) + } + return actors, nil +} + type RepoLabel struct { ID string Name string diff --git a/api/queries_repo_test.go b/api/queries_repo_test.go index 72ed357760b..9040a001802 100644 --- a/api/queries_repo_test.go +++ b/api/queries_repo_test.go @@ -187,7 +187,7 @@ func Test_RepoMetadata(t *testing.T) { expectedProjectIDs := []string{"TRIAGEID", "ROADMAPID"} expectedProjectV2IDs := []string{"TRIAGEV2ID", "ROADMAPV2ID", "MONALISAV2ID"} - projectIDs, projectV2IDs, err := result.ProjectsToIDs([]string{"triage", "roadmap", "triagev2", "roadmapv2", "monalisav2"}) + projectIDs, projectV2IDs, err := result.ProjectsTitlesToIDs([]string{"triage", "roadmap", "triagev2", "roadmapv2", "monalisav2"}) if err != nil { t.Errorf("error resolving projects: %v", err) } @@ -273,7 +273,7 @@ func Test_ProjectNamesToPaths(t *testing.T) { } } } } `)) - projectPaths, err := ProjectNamesToPaths(client, repo, []string{"Triage", "Roadmap", "TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Supported) + projectPaths, err := ProjectTitlesToPaths(client, repo, []string{"Triage", "Roadmap", "TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Supported) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -331,7 +331,7 @@ func Test_ProjectNamesToPaths(t *testing.T) { } } } } `)) - projectPaths, err := ProjectNamesToPaths(client, repo, []string{"TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Unsupported) + projectPaths, err := ProjectTitlesToPaths(client, repo, []string{"TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Unsupported) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -374,7 +374,7 @@ func Test_ProjectNamesToPaths(t *testing.T) { } } } } `)) - _, err := ProjectNamesToPaths(client, repo, []string{"TriageV2"}, gh.ProjectsV1Unsupported) + _, err := ProjectTitlesToPaths(client, repo, []string{"TriageV2"}, gh.ProjectsV1Unsupported) require.Equal(t, err, fmt.Errorf("'TriageV2' not found")) }) } @@ -526,17 +526,17 @@ func Test_RepoMilestones(t *testing.T) { func TestDisplayName(t *testing.T) { tests := []struct { name string - assignee RepoAssignee + assignee AssignableUser want string }{ { name: "assignee with name", - assignee: RepoAssignee{"123", "octocat123", "Octavious Cath"}, + assignee: AssignableUser{"123", "octocat123", "Octavious Cath"}, want: "octocat123 (Octavious Cath)", }, { name: "assignee without name", - assignee: RepoAssignee{"123", "octocat123", ""}, + assignee: AssignableUser{"123", "octocat123", ""}, want: "octocat123", }, } diff --git a/api/query_builder.go b/api/query_builder.go index 4c45da3c1b2..a2432673b74 100644 --- a/api/query_builder.go +++ b/api/query_builder.go @@ -20,6 +20,25 @@ func shortenQuery(q string) string { return strings.Map(squeeze, q) } +var assignedActors = shortenQuery(` + assignedActors(first: 10) { + nodes { + ...on User { + id, + login, + name, + __typename + } + ...on Bot { + id, + login, + __typename + } + }, + totalCount + } +`) + var issueComments = shortenQuery(` comments(first: 100) { nodes { @@ -56,6 +75,25 @@ var issueCommentLast = shortenQuery(` } `) +var issueClosedByPullRequestsReferences = shortenQuery(` + closedByPullRequestsReferences(first: 100) { + nodes { + id, + number, + url, + repository { + id, + name, + owner { + id, + login + } + } + } + pageInfo{hasNextPage,endCursor} + } +`) + var prReviewRequests = shortenQuery(` reviewRequests(first: 100) { nodes { @@ -296,6 +334,7 @@ var sharedIssuePRFields = []string{ var issueOnlyFields = []string{ "isPinned", "stateReason", + "closedByPullRequestsReferences", } var IssueFields = append(sharedIssuePRFields, issueOnlyFields...) @@ -346,6 +385,8 @@ func IssueGraphQL(fields []string) string { q = append(q, `headRepository{id,name}`) case "assignees": q = append(q, `assignees(first:100){nodes{id,login,name},totalCount}`) + case "assignedActors": + q = append(q, assignedActors) case "labels": q = append(q, `labels(first:100){nodes{id,name,description,color},totalCount}`) case "projectCards": @@ -388,6 +429,8 @@ func IssueGraphQL(fields []string) string { q = append(q, StatusCheckRollupGraphQLWithCountByState()) case "closingIssuesReferences": q = append(q, prClosingIssuesReferences) + case "closedByPullRequestsReferences": + q = append(q, issueClosedByPullRequestsReferences) default: q = append(q, field) } diff --git a/go.mod b/go.mod index 3562f24a696..bf8b033fe06 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24 github.com/cli/oauth v1.1.1 github.com/cli/safeexec v1.0.1 - github.com/cpuguy83/go-md2man/v2 v2.0.6 + github.com/cpuguy83/go-md2man/v2 v2.0.7 github.com/creack/pty v1.1.24 github.com/digitorus/timestamp v0.0.0-20231217203849-220c5c2851b7 github.com/distribution/reference v0.6.0 diff --git a/go.sum b/go.sum index 2ac25c2f895..e0ecad6a7c1 100644 --- a/go.sum +++ b/go.sum @@ -150,8 +150,9 @@ github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb h1:EDmT6Q9Zs+SbUo github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb/go.mod h1:ZjrT6AXHbDs86ZSdt/osfBi5qfexBrKUdONk989Wnk4= github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= -github.com/cpuguy83/go-md2man/v2 v2.0.6 h1:XJtiaUW6dEEqVuZiMTn1ldk455QWwEIsMIJlo5vtkx0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= +github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= diff --git a/internal/featuredetection/feature_detection.go b/internal/featuredetection/feature_detection.go index fba317f5874..a2f34a60bec 100644 --- a/internal/featuredetection/feature_detection.go +++ b/internal/featuredetection/feature_detection.go @@ -18,11 +18,13 @@ type Detector interface { } type IssueFeatures struct { - StateReason bool + StateReason bool + ActorIsAssignable bool } var allIssueFeatures = IssueFeatures{ - StateReason: true, + StateReason: true, + ActorIsAssignable: true, } type PullRequestFeatures struct { @@ -70,7 +72,8 @@ func (d *detector) IssueFeatures() (IssueFeatures, error) { } features := IssueFeatures{ - StateReason: false, + StateReason: false, + ActorIsAssignable: false, // replaceActorsForAssignable GraphQL mutation unavailable on GHES } var featureDetection struct { diff --git a/internal/featuredetection/feature_detection_test.go b/internal/featuredetection/feature_detection_test.go index f1152da2cf7..2c7d190716a 100644 --- a/internal/featuredetection/feature_detection_test.go +++ b/internal/featuredetection/feature_detection_test.go @@ -23,7 +23,8 @@ func TestIssueFeatures(t *testing.T) { name: "github.com", hostname: "github.com", wantFeatures: IssueFeatures{ - StateReason: true, + StateReason: true, + ActorIsAssignable: true, }, wantErr: false, }, @@ -31,7 +32,8 @@ func TestIssueFeatures(t *testing.T) { name: "ghec data residency (ghe.com)", hostname: "stampname.ghe.com", wantFeatures: IssueFeatures{ - StateReason: true, + StateReason: true, + ActorIsAssignable: true, }, wantErr: false, }, @@ -42,7 +44,8 @@ func TestIssueFeatures(t *testing.T) { `query Issue_fields\b`: `{"data": {}}`, }, wantFeatures: IssueFeatures{ - StateReason: false, + StateReason: false, + ActorIsAssignable: false, }, wantErr: false, }, diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 00947b8f4ce..2b8104e9af4 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -5,6 +5,7 @@ package prompter_test import ( "fmt" "io" + "slices" "strings" "testing" "time" @@ -32,6 +33,9 @@ import ( // are sufficient to ensure that the accessible prompter behaves roughly as expected // but doesn't mandate that prompts always look exactly the same. func TestAccessiblePrompter(t *testing.T) { + + beforePasswordSendTimeout := 100 * time.Microsecond + t.Run("Select", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) @@ -51,6 +55,73 @@ func TestAccessiblePrompter(t *testing.T) { assert.Equal(t, 0, selectValue) }) + t.Run("Select - blank input returns default value", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValue := "12345abcdefg" + options := []string{"1", "2", dummyDefaultValue} + + go func() { + // Wait for prompt to appear + _, err := console.ExpectString("Input a number between 1 and 3:") + require.NoError(t, err) + + // Just press enter to accept the default + _, err = console.SendLine("") + require.NoError(t, err) + }() + + selectValue, err := p.Select("Select a number", dummyDefaultValue, options) + require.NoError(t, err) + + expectedIndex := slices.Index(options, dummyDefaultValue) + assert.Equal(t, expectedIndex, selectValue) + }) + + t.Run("Select - default value is in prompt and in readable format", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValue := "12345abcdefg" + options := []string{"1", "2", dummyDefaultValue} + + go func() { + // Wait for prompt to appear + _, err := console.ExpectString("Select a number (default: 12345abcdefg)") + require.NoError(t, err) + + // Just press enter to accept the default + _, err = console.SendLine("") + require.NoError(t, err) + }() + + selectValue, err := p.Select("Select a number", dummyDefaultValue, options) + require.NoError(t, err) + + expectedIndex := slices.Index(options, dummyDefaultValue) + assert.Equal(t, expectedIndex, selectValue) + }) + + t.Run("Select - invalid defaults are excluded from prompt", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValue := "foo" + options := []string{"1", "2"} + + go func() { + // Wait for prompt to appear without the invalid default value + _, err := console.ExpectString("Select a number \r\n") + require.NoError(t, err) + + // Select option 2 + _, err = console.SendLine("2") + require.NoError(t, err) + }() + + selectValue, err := p.Select("Select a number", dummyDefaultValue, options) + require.NoError(t, err) + assert.Equal(t, 1, selectValue) + }) + t.Run("MultiSelect", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) @@ -97,6 +168,62 @@ func TestAccessiblePrompter(t *testing.T) { assert.Equal(t, []int{1}, multiSelectValue) }) + t.Run("MultiSelect - default value is in prompt and in readable format", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValues := []string{"foo", "bar"} + options := []string{"1", "2"} + options = append(options, dummyDefaultValues...) + + go func() { + // Wait for prompt to appear + _, err := console.ExpectString("Select a number (defaults: foo, bar)") + require.NoError(t, err) + + // Don't select anything because the defaults should be selected. + + // This confirms selections + _, err = console.SendLine("0") + require.NoError(t, err) + }() + + multiSelectValues, err := p.MultiSelect("Select a number", dummyDefaultValues, options) + require.NoError(t, err) + var expectedIndices []int + + // Get the indices of the default values within the options slice + // as that's what we expect the prompter to return when no selections are made. + for _, defaultValue := range dummyDefaultValues { + expectedIndices = append(expectedIndices, slices.Index(options, defaultValue)) + } + assert.Equal(t, expectedIndices, multiSelectValues) + }) + + t.Run("MultiSelect - invalid defaults are excluded from prompt", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValues := []string{"foo", "bar"} + options := []string{"1", "2"} + + go func() { + // Wait for prompt to appear without the invalid default values + _, err := console.ExpectString("Select a number \r\n") + require.NoError(t, err) + + // Not selecting anything will fail because there are no defaults. + _, err = console.SendLine("2") + require.NoError(t, err) + + // This confirms selections + _, err = console.SendLine("0") + require.NoError(t, err) + }() + + multiSelectValues, err := p.MultiSelect("Select a number", dummyDefaultValues, options) + require.NoError(t, err) + assert.Equal(t, []int{1}, multiSelectValues) + }) + t.Run("Input", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) @@ -137,6 +264,26 @@ func TestAccessiblePrompter(t *testing.T) { assert.Equal(t, dummyDefaultValue, inputValue) }) + t.Run("Input - default value is in prompt and in readable format", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + dummyDefaultValue := "12345abcdefg" + + go func() { + // Wait for prompt to appear + _, err := console.ExpectString("Enter some characters (default: 12345abcdefg)") + require.NoError(t, err) + + // Enter nothing + _, err = console.SendLine("") + require.NoError(t, err) + }() + + inputValue, err := p.Input("Enter some characters", dummyDefaultValue) + require.NoError(t, err) + assert.Equal(t, dummyDefaultValue, inputValue) + }) + t.Run("Password", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) @@ -147,6 +294,9 @@ func TestAccessiblePrompter(t *testing.T) { _, err := console.ExpectString("Enter password") require.NoError(t, err) + // Wait to ensure huh has time to set the echo mode + time.Sleep(beforePasswordSendTimeout) + // Enter a number _, err = console.SendLine(dummyPassword) require.NoError(t, err) @@ -158,7 +308,12 @@ func TestAccessiblePrompter(t *testing.T) { // Ensure the dummy password is not printed to the screen, // asserting that echo mode is disabled. - _, err = console.ExpectString(" \r\n\r\n") + // + // Note that since console.ExpectString returns successful if the + // expected string matches any part of the stream, we have to use an + // anchored regexp (i.e., with ^ and $) to make sure the password/token + // is not printed at all. + _, err = console.Expect(expect.RegexpPattern("^ \r\n\r\n$")) require.NoError(t, err) }) @@ -200,6 +355,26 @@ func TestAccessiblePrompter(t *testing.T) { require.Equal(t, false, confirmValue) }) + t.Run("Confirm - default value is in prompt and in readable format", func(t *testing.T) { + console := newTestVirtualTerminal(t) + p := newTestAccessiblePrompter(t, console) + defaultValue := true + + go func() { + // Wait for prompt to appear + _, err := console.ExpectString("Are you sure (default: yes)") + require.NoError(t, err) + + // Enter nothing + _, err = console.SendLine("") + require.NoError(t, err) + }() + + confirmValue, err := p.Confirm("Are you sure", defaultValue) + require.NoError(t, err) + require.Equal(t, defaultValue, confirmValue) + }) + t.Run("AuthToken", func(t *testing.T) { console := newTestVirtualTerminal(t) p := newTestAccessiblePrompter(t, console) @@ -210,6 +385,9 @@ func TestAccessiblePrompter(t *testing.T) { _, err := console.ExpectString("Paste your authentication token:") require.NoError(t, err) + // Wait to ensure huh has time to set the echo mode + time.Sleep(beforePasswordSendTimeout) + // Enter some dummy auth token _, err = console.SendLine(dummyAuthToken) require.NoError(t, err) @@ -221,7 +399,12 @@ func TestAccessiblePrompter(t *testing.T) { // Ensure the dummy password is not printed to the screen, // asserting that echo mode is disabled. - _, err = console.ExpectString(" \r\n\r\n") + // + // Note that since console.ExpectString returns successful if the + // expected string matches any part of the stream, we have to use an + // anchored regexp (i.e., with ^ and $) to make sure the password/token + // is not printed at all. + _, err = console.Expect(expect.RegexpPattern("^ \r\n\r\n$")) require.NoError(t, err) }) @@ -243,6 +426,13 @@ func TestAccessiblePrompter(t *testing.T) { _, err = console.ExpectString("token is required") require.NoError(t, err) + // Wait for the retry prompt + _, err = console.ExpectString("Paste your authentication token:") + require.NoError(t, err) + + // Wait to ensure huh has time to set the echo mode + time.Sleep(beforePasswordSendTimeout) + // Now enter some dummy auth token to return control back to the test _, err = console.SendLine(dummyAuthTokenForAfterFailure) require.NoError(t, err) @@ -254,7 +444,12 @@ func TestAccessiblePrompter(t *testing.T) { // Ensure the dummy password is not printed to the screen, // asserting that echo mode is disabled. - _, err = console.ExpectString(" \r\n\r\n") + // + // Note that since console.ExpectString returns successful if the + // expected string matches any part of the stream, we have to use an + // anchored regexp (i.e., with ^ and $) to make sure the password/token + // is not printed at all. + _, err = console.Expect(expect.RegexpPattern("^ \r\n\r\n$")) require.NoError(t, err) }) diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index 1e4f5592a56..c2233fd9266 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -77,10 +77,40 @@ func (p *accessiblePrompter) newForm(groups ...*huh.Group) *huh.Form { WithOutput(p.stdout) } -func (p *accessiblePrompter) Select(prompt, _ string, options []string) (int, error) { +// addDefaultsToPrompt adds default values to the prompt string. +func (p *accessiblePrompter) addDefaultsToPrompt(prompt string, defaultValues []string) string { + // Removing empty defaults from the slice. + defaultValues = slices.DeleteFunc(defaultValues, func(s string) bool { + return s == "" + }) + + // Pluralizing the prompt if there are multiple default values. + if len(defaultValues) == 1 { + prompt = fmt.Sprintf("%s (default: %s)", prompt, defaultValues[0]) + } else if len(defaultValues) > 1 { + prompt = fmt.Sprintf("%s (defaults: %s)", prompt, strings.Join(defaultValues, ", ")) + } + + // Zero-length defaultValues means return prompt unchanged. + return prompt +} + +func (p *accessiblePrompter) Select(prompt, defaultValue string, options []string) (int, error) { var result int + + // Remove invalid default values from the defaults slice. + if !slices.Contains(options, defaultValue) { + defaultValue = "" + } + + prompt = p.addDefaultsToPrompt(prompt, []string{defaultValue}) formOptions := []huh.Option[int]{} for i, o := range options { + // If this option is the default value, assign its index + // to the result variable. huh will treat it as a default selection. + if defaultValue == o { + result = i + } formOptions = append(formOptions, huh.NewOption(o, i)) } @@ -99,12 +129,18 @@ func (p *accessiblePrompter) Select(prompt, _ string, options []string) (int, er func (p *accessiblePrompter) MultiSelect(prompt string, defaults []string, options []string) ([]int, error) { var result []int + + // Remove invalid default values from the defaults slice. + defaults = slices.DeleteFunc(defaults, func(s string) bool { + return !slices.Contains(options, s) + }) + + prompt = p.addDefaultsToPrompt(prompt, defaults) formOptions := make([]huh.Option[int], len(options)) for i, o := range options { // If this option is in the defaults slice, // let's add its index to the result slice and huh // will treat it as a default selection. - // TODO: does an invalid default value constitute a panic? if slices.Contains(defaults, o) { result = append(result, i) } @@ -131,7 +167,7 @@ func (p *accessiblePrompter) MultiSelect(prompt string, defaults []string, optio func (p *accessiblePrompter) Input(prompt, defaultValue string) (string, error) { result := defaultValue - prompt = fmt.Sprintf("%s (%s)", prompt, defaultValue) + prompt = p.addDefaultsToPrompt(prompt, []string{defaultValue}) form := p.newForm( huh.NewGroup( huh.NewInput(). @@ -167,6 +203,13 @@ func (p *accessiblePrompter) Password(prompt string) (string, error) { func (p *accessiblePrompter) Confirm(prompt string, defaultValue bool) (bool, error) { result := defaultValue + + if defaultValue { + prompt = p.addDefaultsToPrompt(prompt, []string{"yes"}) + } else { + prompt = p.addDefaultsToPrompt(prompt, []string{"no"}) + } + form := p.newForm( huh.NewGroup( huh.NewConfirm(). @@ -174,6 +217,7 @@ func (p *accessiblePrompter) Confirm(prompt string, defaultValue bool) (bool, er Value(&result), ), ) + if err := form.Run(); err != nil { return false, err } diff --git a/internal/prompter/test.go b/internal/prompter/test.go index 04375ce761c..dfa124fcad6 100644 --- a/internal/prompter/test.go +++ b/internal/prompter/test.go @@ -141,6 +141,18 @@ func IndexFor(options []string, answer string) (int, error) { return -1, NoSuchAnswerErr(answer, options) } +func IndexesFor(options []string, answers ...string) ([]int, error) { + indexes := make([]int, len(answers)) + for i, answer := range answers { + index, err := IndexFor(options, answer) + if err != nil { + return nil, err + } + indexes[i] = index + } + return indexes, nil +} + func NoSuchPromptErr(prompt string) error { return fmt.Errorf("no such prompt '%s'", prompt) } diff --git a/pkg/cmd/attestation/api/attestation.go b/pkg/cmd/attestation/api/attestation.go index fd6b484a7cd..daec12b5051 100644 --- a/pkg/cmd/attestation/api/attestation.go +++ b/pkg/cmd/attestation/api/attestation.go @@ -1,7 +1,10 @@ package api import ( + "encoding/json" "errors" + "fmt" + "github.com/sigstore/sigstore-go/pkg/bundle" ) @@ -20,3 +23,35 @@ type Attestation struct { type AttestationsResponse struct { Attestations []*Attestation `json:"attestations"` } + +type IntotoStatement struct { + PredicateType string `json:"predicateType"` +} + +func FilterAttestations(predicateType string, attestations []*Attestation) ([]*Attestation, error) { + filteredAttestations := []*Attestation{} + + for _, each := range attestations { + dsseEnvelope := each.Bundle.GetDsseEnvelope() + if dsseEnvelope != nil { + if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { + // Don't fail just because an entry isn't intoto + continue + } + var intotoStatement IntotoStatement + if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { + // Don't fail just because a single entry can't be unmarshalled + continue + } + if intotoStatement.PredicateType == predicateType { + filteredAttestations = append(filteredAttestations, each) + } + } + } + + if len(filteredAttestations) == 0 { + return nil, fmt.Errorf("no attestations found with predicate type: %s", predicateType) + } + + return filteredAttestations, nil +} diff --git a/pkg/cmd/attestation/api/client.go b/pkg/cmd/attestation/api/client.go index 1e99a2a068d..61d0bee52c9 100644 --- a/pkg/cmd/attestation/api/client.go +++ b/pkg/cmd/attestation/api/client.go @@ -27,6 +27,28 @@ const ( // Allow injecting backoff interval in tests. var getAttestationRetryInterval = time.Millisecond * 200 +// FetchParams are the parameters for fetching attestations from the GitHub API +type FetchParams struct { + Digest string + Limit int + Owner string + PredicateType string + Repo string +} + +func (p *FetchParams) Validate() error { + if p.Digest == "" { + return fmt.Errorf("digest must be provided") + } + if p.Limit <= 0 || p.Limit > maxLimitForFlag { + return fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) + } + if p.Repo == "" && p.Owner == "" { + return fmt.Errorf("owner or repo must be provided") + } + return nil +} + // githubApiClient makes REST calls to the GitHub API type githubApiClient interface { REST(hostname, method, p string, body io.Reader, data interface{}) error @@ -39,8 +61,7 @@ type httpClient interface { } type Client interface { - GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) - GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) + GetByDigest(params FetchParams) ([]*Attestation, error) GetTrustDomain() (string, error) } @@ -60,22 +81,11 @@ func NewLiveClient(hc *http.Client, host string, l *ioconfig.Handler) *LiveClien } } -// GetByRepoAndDigest fetches the attestation by repo and digest -func (c *LiveClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { - c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) - url := fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, repo, digest) - return c.getByURL(url, limit) -} - -// GetByOwnerAndDigest fetches attestation by owner and digest -func (c *LiveClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { - c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest) - url := fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, owner, digest) - return c.getByURL(url, limit) -} - -func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) { - attestations, err := c.getAttestations(url, limit) +// GetByDigest fetches the attestation by digest and either owner or repo +// depending on which is provided +func (c *LiveClient) GetByDigest(params FetchParams) ([]*Attestation, error) { + c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", params.Digest) + attestations, err := c.getAttestations(params) if err != nil { return nil, err } @@ -88,40 +98,52 @@ func (c *LiveClient) getByURL(url string, limit int) ([]*Attestation, error) { return bundles, nil } -// GetTrustDomain returns the current trust domain. If the default is used -// the empty string is returned -func (c *LiveClient) GetTrustDomain() (string, error) { - return c.getTrustDomain(MetaPath) -} +func (c *LiveClient) buildRequestURL(params FetchParams) (string, error) { + if err := params.Validate(); err != nil { + return "", err + } -func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, error) { - perPage := limit - if perPage <= 0 || perPage > maxLimitForFlag { - return nil, fmt.Errorf("limit must be greater than 0 and less than or equal to %d", maxLimitForFlag) + var url string + if params.Repo != "" { + // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. + // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. + url = fmt.Sprintf(GetAttestationByRepoAndSubjectDigestPath, params.Repo, params.Digest) + } else { + url = fmt.Sprintf(GetAttestationByOwnerAndSubjectDigestPath, params.Owner, params.Digest) } + perPage := params.Limit if perPage > maxLimitForFetch { perPage = maxLimitForFetch } // ref: https://github.com/cli/go-gh/blob/d32c104a9a25c9de3d7c7b07a43ae0091441c858/example_gh_test.go#L96 url = fmt.Sprintf("%s?per_page=%d", url, perPage) + if params.PredicateType != "" { + url = fmt.Sprintf("%s&predicate_type=%s", url, params.PredicateType) + } + return url, nil +} + +func (c *LiveClient) getAttestations(params FetchParams) ([]*Attestation, error) { + url, err := c.buildRequestURL(params) + if err != nil { + return nil, err + } var attestations []*Attestation var resp AttestationsResponse bo := backoff.NewConstantBackOff(getAttestationRetryInterval) // if no attestation or less than limit, then keep fetching - for url != "" && len(attestations) < limit { + for url != "" && len(attestations) < params.Limit { err := backoff.Retry(func() error { newURL, restErr := c.githubAPI.RESTWithNext(c.host, http.MethodGet, url, nil, &resp) - if restErr != nil { if shouldRetry(restErr) { return restErr - } else { - return backoff.Permanent(restErr) } + return backoff.Permanent(restErr) } url = newURL @@ -140,8 +162,8 @@ func (c *LiveClient) getAttestations(url string, limit int) ([]*Attestation, err return nil, ErrNoAttestationsFound } - if len(attestations) > limit { - return attestations[:limit], nil + if len(attestations) > params.Limit { + return attestations[:params.Limit], nil } return attestations, nil @@ -241,6 +263,12 @@ func shouldRetry(err error) bool { return false } +// GetTrustDomain returns the current trust domain. If the default is used +// the empty string is returned +func (c *LiveClient) GetTrustDomain() (string, error) { + return c.getTrustDomain(MetaPath) +} + func (c *LiveClient) getTrustDomain(url string) (string, error) { var resp MetaResponse diff --git a/pkg/cmd/attestation/api/client_test.go b/pkg/cmd/attestation/api/client_test.go index 787408a4e79..384c7c9c8af 100644 --- a/pkg/cmd/attestation/api/client_test.go +++ b/pkg/cmd/attestation/api/client_test.go @@ -42,78 +42,75 @@ func NewClientWithMockGHClient(hasNextPage bool) Client { } } -func TestGetByDigest(t *testing.T) { - c := NewClientWithMockGHClient(false) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.NoError(t, err) - - require.Equal(t, 5, len(attestations)) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) - require.NoError(t, err) - - require.Equal(t, 5, len(attestations)) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") +var testFetchParamsWithOwner = FetchParams{ + Digest: testDigest, + Limit: DefaultLimit, + Owner: testOwner, + PredicateType: "https://slsa.dev/provenance/v1", } - -func TestGetByDigestGreaterThanLimit(t *testing.T) { - c := NewClientWithMockGHClient(false) - - limit := 3 - // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit) - require.NoError(t, err) - - require.Equal(t, 3, len(attestations)) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") +var testFetchParamsWithRepo = FetchParams{ + Digest: testDigest, + Limit: DefaultLimit, + Repo: testRepo, + PredicateType: "https://slsa.dev/provenance/v1", } -func TestGetByDigestWithNextPage(t *testing.T) { - c := NewClientWithMockGHClient(true) - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.NoError(t, err) - - require.Equal(t, len(attestations), 10) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) - require.NoError(t, err) - - require.Equal(t, len(attestations), 10) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") +type getByTestCase struct { + name string + params FetchParams + limit int + expectedAttestations int + hasNextPage bool } -func TestGetByDigestGreaterThanLimitWithNextPage(t *testing.T) { - c := NewClientWithMockGHClient(true) - - limit := 7 - // The method should return five results when the limit is not set - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, limit) - require.NoError(t, err) - - require.Equal(t, len(attestations), limit) - bundle := (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, limit) - require.NoError(t, err) +var getByTestCases = []getByTestCase{ + { + name: "get by digest with owner", + params: testFetchParamsWithOwner, + expectedAttestations: 5, + }, + { + name: "get by digest with repo", + params: testFetchParamsWithRepo, + expectedAttestations: 5, + }, + { + name: "get by digest with attestations greater than limit", + params: testFetchParamsWithRepo, + limit: 3, + expectedAttestations: 3, + }, + { + name: "get by digest with next page", + params: testFetchParamsWithRepo, + expectedAttestations: 10, + hasNextPage: true, + }, + { + name: "greater than limit with next page", + params: testFetchParamsWithRepo, + limit: 7, + expectedAttestations: 7, + hasNextPage: true, + }, +} - require.Equal(t, len(attestations), limit) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") +func TestGetByDigest(t *testing.T) { + for _, tc := range getByTestCases { + t.Run(tc.name, func(t *testing.T) { + c := NewClientWithMockGHClient(tc.hasNextPage) + + if tc.limit > 0 { + tc.params.Limit = tc.limit + } + attestations, err := c.GetByDigest(tc.params) + require.NoError(t, err) + + require.Equal(t, tc.expectedAttestations, len(attestations)) + bundle := (attestations)[0].Bundle + require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") + }) + } } func TestGetByDigest_NoAttestationsFound(t *testing.T) { @@ -130,12 +127,7 @@ func TestGetByDigest_NoAttestationsFound(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.Error(t, err) - require.IsType(t, ErrNoAttestationsFound, err) - require.Nil(t, attestations) - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) require.IsType(t, ErrNoAttestationsFound, err) require.Nil(t, attestations) @@ -153,11 +145,7 @@ func TestGetByDigest_Error(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) - require.Error(t, err) - require.Nil(t, attestations) - - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) require.Nil(t, attestations) } @@ -362,7 +350,8 @@ func TestGetAttestationsRetries(t *testing.T) { logger: io.NewTestHandler(), } - attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + testFetchParamsWithRepo.Limit = 30 + attestations, err := c.GetByDigest(testFetchParamsWithRepo) require.NoError(t, err) // assert the error path was executed; because this is a paged @@ -373,17 +362,6 @@ func TestGetAttestationsRetries(t *testing.T) { require.Equal(t, len(attestations), 10) bundle := (attestations)[0].Bundle require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") - - // same test as above, but for GetByOwnerAndDigest: - attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit) - require.NoError(t, err) - - // because we haven't reset the mock, we have added 2 more failed requests - fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4) - - require.Equal(t, len(attestations), 10) - bundle = (attestations)[0].Bundle - require.Equal(t, bundle.GetMediaType(), "application/vnd.dev.sigstore.bundle.v0.3+json") } // test total retries @@ -401,7 +379,7 @@ func TestGetAttestationsMaxRetries(t *testing.T) { logger: io.NewTestHandler(), } - _, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit) + _, err := c.GetByDigest(testFetchParamsWithRepo) require.Error(t, err) fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4) diff --git a/pkg/cmd/attestation/api/mock_client.go b/pkg/cmd/attestation/api/mock_client.go index b2fd334c056..b6062b39fb3 100644 --- a/pkg/cmd/attestation/api/mock_client.go +++ b/pkg/cmd/attestation/api/mock_client.go @@ -6,58 +6,49 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/test/data" ) -type MockClient struct { - OnGetByRepoAndDigest func(repo, digest string, limit int) ([]*Attestation, error) - OnGetByOwnerAndDigest func(owner, digest string, limit int) ([]*Attestation, error) - OnGetTrustDomain func() (string, error) +func makeTestAttestation() Attestation { + return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} } -func (m MockClient) GetByRepoAndDigest(repo, digest string, limit int) ([]*Attestation, error) { - return m.OnGetByRepoAndDigest(repo, digest, limit) +type MockClient struct { + OnGetByDigest func(params FetchParams) ([]*Attestation, error) + OnGetTrustDomain func() (string, error) } -func (m MockClient) GetByOwnerAndDigest(owner, digest string, limit int) ([]*Attestation, error) { - return m.OnGetByOwnerAndDigest(owner, digest, limit) +func (m MockClient) GetByDigest(params FetchParams) ([]*Attestation, error) { + return m.OnGetByDigest(params) } func (m MockClient) GetTrustDomain() (string, error) { return m.OnGetTrustDomain() } -func makeTestAttestation() Attestation { - return Attestation{Bundle: data.SigstoreBundle(nil), BundleURL: "https://example.com"} -} - -func OnGetByRepoAndDigestSuccess(repo, digest string, limit int) ([]*Attestation, error) { +func OnGetByDigestSuccess(params FetchParams) ([]*Attestation, error) { att1 := makeTestAttestation() att2 := makeTestAttestation() - return []*Attestation{&att1, &att2}, nil -} + attestations := []*Attestation{&att1, &att2} + if params.PredicateType != "" { + return FilterAttestations(params.PredicateType, attestations) + } -func OnGetByRepoAndDigestFailure(repo, digest string, limit int) ([]*Attestation, error) { - return nil, fmt.Errorf("failed to fetch by repo and digest") + return attestations, nil } -func OnGetByOwnerAndDigestSuccess(owner, digest string, limit int) ([]*Attestation, error) { - att1 := makeTestAttestation() - att2 := makeTestAttestation() - return []*Attestation{&att1, &att2}, nil -} - -func OnGetByOwnerAndDigestFailure(owner, digest string, limit int) ([]*Attestation, error) { - return nil, fmt.Errorf("failed to fetch by owner and digest") +func OnGetByDigestFailure(params FetchParams) ([]*Attestation, error) { + if params.Repo != "" { + return nil, fmt.Errorf("failed to fetch attestations from %s", params.Repo) + } + return nil, fmt.Errorf("failed to fetch attestations from %s", params.Owner) } func NewTestClient() *MockClient { return &MockClient{ - OnGetByRepoAndDigest: OnGetByRepoAndDigestSuccess, - OnGetByOwnerAndDigest: OnGetByOwnerAndDigestSuccess, + OnGetByDigest: OnGetByDigestSuccess, } } func NewFailTestClient() *MockClient { return &MockClient{ - OnGetByRepoAndDigest: OnGetByRepoAndDigestFailure, - OnGetByOwnerAndDigest: OnGetByOwnerAndDigestFailure, + OnGetByDigest: OnGetByDigestFailure, } } diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 6913c07873e..8d1d1dc0511 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -9,7 +9,6 @@ import ( "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" "github.com/cli/cli/v2/pkg/cmd/attestation/auth" "github.com/cli/cli/v2/pkg/cmd/attestation/io" - "github.com/cli/cli/v2/pkg/cmd/attestation/verification" "github.com/cli/cli/v2/pkg/cmdutil" ghauth "github.com/cli/go-gh/v2/pkg/auth" @@ -127,13 +126,16 @@ func runDownload(opts *Options) error { opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath) - params := verification.FetchRemoteAttestationsParams{ + if opts.APIClient == nil { + return fmt.Errorf("no APIClient provided") + } + params := api.FetchParams{ Digest: artifact.DigestWithAlg(), Limit: opts.Limit, Owner: opts.Owner, Repo: opts.Repo, } - attestations, err := verification.GetRemoteAttestations(opts.APIClient, params) + attestations, err := opts.APIClient.GetByDigest(params) if err != nil { if errors.Is(err, api.ErrNoAttestationsFound) { fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath) @@ -144,10 +146,9 @@ func runDownload(opts *Options) error { // Apply predicate type filter to returned attestations if opts.PredicateType != "" { - filteredAttestations := verification.FilterAttestations(opts.PredicateType, attestations) - - if len(filteredAttestations) == 0 { - return fmt.Errorf("no attestations found with predicate type: %s", opts.PredicateType) + filteredAttestations, err := api.FilterAttestations(opts.PredicateType, attestations) + if err != nil { + return fmt.Errorf("failed to filter attestations: %v", err) } attestations = filteredAttestations diff --git a/pkg/cmd/attestation/download/download_test.go b/pkg/cmd/attestation/download/download_test.go index ddcd08c9280..11872daf900 100644 --- a/pkg/cmd/attestation/download/download_test.go +++ b/pkg/cmd/attestation/download/download_test.go @@ -275,7 +275,7 @@ func TestRunDownload(t *testing.T) { t.Run("no attestations found", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) { + OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, api.ErrNoAttestationsFound }, } @@ -291,7 +291,7 @@ func TestRunDownload(t *testing.T) { t.Run("failed to fetch attestations", func(t *testing.T) { opts := baseOpts opts.APIClient = api.MockClient{ - OnGetByOwnerAndDigest: func(repo, digest string, limit int) ([]*api.Attestation, error) { + OnGetByDigest: func(params api.FetchParams) ([]*api.Attestation, error) { return nil, fmt.Errorf("failed to fetch attestations") }, } diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index db419ebaca1..10eb02ac402 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -20,13 +20,6 @@ const SLSAPredicateV1 = "https://slsa.dev/provenance/v1" var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not supported, must be json or jsonl") var ErrEmptyBundleFile = errors.New("provided bundle file is empty") -type FetchRemoteAttestationsParams struct { - Digest string - Limit int - Owner string - Repo string -} - // GetLocalAttestations returns a slice of attestations read from a local bundle file. func GetLocalAttestations(path string) ([]*api.Attestation, error) { var attestations []*api.Attestation @@ -89,28 +82,6 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) { return attestations, nil } -func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsParams) ([]*api.Attestation, error) { - if client == nil { - return nil, fmt.Errorf("api client must be provided") - } - // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. - // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. - if params.Repo != "" { - attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) - } - return attestations, nil - } else if params.Owner != "" { - attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit) - if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) - } - return attestations, nil - } - return nil, fmt.Errorf("owner or repo must be provided") -} - func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) { attestations, err := client.GetAttestations(artifact.NameRef(), artifact.DigestWithAlg()) if err != nil { @@ -121,31 +92,3 @@ func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ( } return attestations, nil } - -type IntotoStatement struct { - PredicateType string `json:"predicateType"` -} - -func FilterAttestations(predicateType string, attestations []*api.Attestation) []*api.Attestation { - filteredAttestations := []*api.Attestation{} - - for _, each := range attestations { - dsseEnvelope := each.Bundle.GetDsseEnvelope() - if dsseEnvelope != nil { - if dsseEnvelope.PayloadType != "application/vnd.in-toto+json" { - // Don't fail just because an entry isn't intoto - continue - } - var intotoStatement IntotoStatement - if err := json.Unmarshal([]byte(dsseEnvelope.Payload), &intotoStatement); err != nil { - // Don't fail just because a single entry can't be unmarshalled - continue - } - if intotoStatement.PredicateType == predicateType { - filteredAttestations = append(filteredAttestations, each) - } - } - } - - return filteredAttestations -} diff --git a/pkg/cmd/attestation/verification/attestation_test.go b/pkg/cmd/attestation/verification/attestation_test.go index 8acff0c37d2..6826e2e4000 100644 --- a/pkg/cmd/attestation/verification/attestation_test.go +++ b/pkg/cmd/attestation/verification/attestation_test.go @@ -157,10 +157,11 @@ func TestFilterAttestations(t *testing.T) { }, } - filtered := FilterAttestations("https://slsa.dev/provenance/v1", attestations) - + filtered, err := api.FilterAttestations("https://slsa.dev/provenance/v1", attestations) require.Len(t, filtered, 1) + require.NoError(t, err) - filtered = FilterAttestations("NonExistentPredicate", attestations) - require.Len(t, filtered, 0) + filtered, err = api.FilterAttestations("NonExistentPredicate", attestations) + require.Nil(t, filtered) + require.Error(t, err) } diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go index bb96c95269f..1b98fabf334 100644 --- a/pkg/cmd/attestation/verify/attestation.go +++ b/pkg/cmd/attestation/verify/attestation.go @@ -1,6 +1,7 @@ package verify import ( + "errors" "fmt" "github.com/cli/cli/v2/internal/text" @@ -10,43 +11,63 @@ import ( ) func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { - if o.BundlePath != "" { - attestations, err := verification.GetLocalAttestations(o.BundlePath) + // Fetch attestations from GitHub API within this if block since predicate type + // filter is done when the API is called + if o.FetchAttestationsFromGitHubAPI() { + if o.APIClient == nil { + errMsg := "✗ No APIClient provided" + return nil, errMsg, errors.New(errMsg) + } + + params := api.FetchParams{ + Digest: a.DigestWithAlg(), + Limit: o.Limit, + Owner: o.Owner, + PredicateType: o.PredicateType, + Repo: o.Repo, + } + + attestations, err := o.APIClient.GetByDigest(params) if err != nil { - msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) + msg := "✗ Loading attestations from GitHub API failed" return nil, msg, err } pluralAttestation := text.Pluralize(len(attestations), "attestation") - msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) + msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) return attestations, msg, nil } - if o.UseBundleFromRegistry { - attestations, err := verification.GetOCIAttestations(o.OCIClient, a) + // Fetch attestations from local bundle or OCI registry + // Predicate type filtering is done after the attestations are fetched + var attestations []*api.Attestation + var err error + var msg string + if o.BundlePath != "" { + attestations, err = verification.GetLocalAttestations(o.BundlePath) if err != nil { - msg := "✗ Loading attestations from OCI registry failed" - return nil, msg, err + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg = fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) + } else { + msg = fmt.Sprintf("Loaded %d attestations from %s", len(attestations), o.BundlePath) + } + } else if o.UseBundleFromRegistry { + attestations, err = verification.GetOCIAttestations(o.OCIClient, a) + if err != nil { + msg = "✗ Loading attestations from OCI registry failed" + } else { + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg = fmt.Sprintf("Loaded %s from OCI registry", pluralAttestation) } - pluralAttestation := text.Pluralize(len(attestations), "attestation") - msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath) - return attestations, msg, nil } - - params := verification.FetchRemoteAttestationsParams{ - Digest: a.DigestWithAlg(), - Limit: o.Limit, - Owner: o.Owner, - Repo: o.Repo, + if err != nil { + return nil, msg, err } - attestations, err := verification.GetRemoteAttestations(o.APIClient, params) + filtered, err := api.FilterAttestations(o.PredicateType, attestations) if err != nil { - msg := "✗ Loading attestations from GitHub API failed" - return nil, msg, err + return nil, err.Error(), err } - pluralAttestation := text.Pluralize(len(attestations), "attestation") - msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) - return attestations, msg, nil + return filtered, msg, nil } func verifyAttestations(art artifact.DigestedArtifact, att []*api.Attestation, sgVerifier verification.SigstoreVerifier, ec verification.EnforcementCriteria) ([]*verification.AttestationProcessingResult, string, error) { diff --git a/pkg/cmd/attestation/verify/attestation_test.go b/pkg/cmd/attestation/verify/attestation_test.go new file mode 100644 index 00000000000..f015805ae5f --- /dev/null +++ b/pkg/cmd/attestation/verify/attestation_test.go @@ -0,0 +1,71 @@ +package verify + +import ( + "testing" + + "github.com/cli/cli/v2/pkg/cmd/attestation/api" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" + "github.com/cli/cli/v2/pkg/cmd/attestation/verification" + "github.com/stretchr/testify/require" +) + +func TestGetAttestations_OCIRegistry_PredicateTypeFiltering(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + OCIClient: oci.MockClient{}, + PredicateType: verification.SLSAPredicateV1, + Repo: "cli/cli", + UseBundleFromRegistry: true, + } + attestations, msg, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Contains(t, msg, "Loaded 2 attestations from OCI registry") + require.Len(t, attestations, 2) + + o.PredicateType = "custom predicate type" + attestations, msg, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Contains(t, msg, "no attestations found with predicate type") + require.Nil(t, attestations) +} + +func TestGetAttestations_LocalBundle_PredicateTypeFiltering(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + BundlePath: "../test/data/sigstore-js-2.1.0-bundle.json", + PredicateType: verification.SLSAPredicateV1, + Repo: "sigstore/sigstore-js", + } + attestations, _, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Len(t, attestations, 1) + + o.PredicateType = "custom predicate type" + attestations, _, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Nil(t, attestations) +} + +func TestGetAttestations_GhAPI_NoAttestationsFound(t *testing.T) { + artifact, err := artifact.NewDigestedArtifact(nil, "../test/data/gh_2.60.1_windows_arm64.zip", "sha256") + require.NoError(t, err) + + o := &Options{ + APIClient: api.NewTestClient(), + PredicateType: verification.SLSAPredicateV1, + Repo: "sigstore/sigstore-js", + } + attestations, _, err := getAttestations(o, *artifact) + require.NoError(t, err) + require.Len(t, attestations, 2) + + o.PredicateType = "custom predicate type" + attestations, _, err = getAttestations(o, *artifact) + require.Error(t, err) + require.Nil(t, attestations) +} diff --git a/pkg/cmd/attestation/verify/options.go b/pkg/cmd/attestation/verify/options.go index 0fbbec55a05..e47c4f4a83b 100644 --- a/pkg/cmd/attestation/verify/options.go +++ b/pkg/cmd/attestation/verify/options.go @@ -53,6 +53,12 @@ func (opts *Options) Clean() { } } +// FetchAttestationsFromGitHubAPI returns true if the command should fetch attestations from the GitHub API +// It checks that a bundle path is not provided and that the "use bundle from registry" flag is not set +func (opts *Options) FetchAttestationsFromGitHubAPI() bool { + return opts.BundlePath == "" && !opts.UseBundleFromRegistry +} + // AreFlagsValid checks that the provided flag combination is valid // and returns an error otherwise func (opts *Options) AreFlagsValid() error { diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index b3bad519aad..b8debc529c1 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -288,14 +288,6 @@ func runVerify(opts *Options) error { // Print the message signifying success fetching attestations opts.Logger.Println(logMsg) - // Apply predicate type filter to returned attestations - filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations) - if len(filteredAttestations) == 0 { - opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found with predicate type: %s\n"), opts.PredicateType) - return fmt.Errorf("no matching predicate found") - } - attestations = filteredAttestations - // print information about the policy that will be enforced against attestations opts.Logger.Println("\nThe following policy criteria will be enforced:") opts.Logger.Println(ec.BuildPolicyInformation()) diff --git a/pkg/cmd/attestation/verify/verify_test.go b/pkg/cmd/attestation/verify/verify_test.go index 092a009d81e..2b821a435d9 100644 --- a/pkg/cmd/attestation/verify/verify_test.go +++ b/pkg/cmd/attestation/verify/verify_test.go @@ -510,7 +510,7 @@ func TestRunVerify(t *testing.T) { err := runVerify(&customOpts) require.Error(t, err) - require.ErrorContains(t, err, "no matching predicate found") + require.ErrorContains(t, err, "no attestations found with predicate type") }) t.Run("with valid OCI artifact with UseBundleFromRegistry flag but no bundle return from registry", func(t *testing.T) { diff --git a/pkg/cmd/gist/delete/delete_test.go b/pkg/cmd/gist/delete/delete_test.go index 24ca2bb33c2..2c4df8d8d6f 100644 --- a/pkg/cmd/gist/delete/delete_test.go +++ b/pkg/cmd/gist/delete/delete_test.go @@ -18,6 +18,7 @@ import ( ghAPI "github.com/cli/go-gh/v2/pkg/api" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCmdDelete(t *testing.T) { @@ -327,11 +328,12 @@ func Test_deleteRun(t *testing.T) { func Test_gistDelete(t *testing.T) { tests := []struct { - name string - httpStubs func(*httpmock.Registry) - hostname string - gistID string - wantErr error + name string + httpStubs func(*httpmock.Registry) + hostname string + gistID string + wantErr error + wantErrString string }{ { name: "successful delete", @@ -343,36 +345,34 @@ func Test_gistDelete(t *testing.T) { }, hostname: "github.com", gistID: "1234", - wantErr: nil, }, { - name: "when an gist is not found, it returns a NotFoundError", + name: "when a gist is not found, it returns a NotFoundError", httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.REST("DELETE", "gists/1234"), httpmock.StatusStringResponse(404, "{}"), ) }, - hostname: "github.com", - gistID: "1234", - wantErr: shared.NotFoundErr, + hostname: "github.com", + gistID: "1234", + wantErr: shared.NotFoundErr, // To make sure we return the pre-defined error instance. + wantErrString: "not found", }, { name: "when there is a non-404 error deleting the gist, that error is returned", httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.REST("DELETE", "gists/1234"), - httpmock.StatusJSONResponse(500, `{"message": "arbitrary error"}`), + httpmock.JSONErrorResponse(500, ghAPI.HTTPError{ + StatusCode: 500, + Message: "arbitrary error", + }), ) }, - hostname: "github.com", - gistID: "1234", - wantErr: api.HTTPError{ - HTTPError: &ghAPI.HTTPError{ - StatusCode: 500, - Message: "arbitrary error", - }, - }, + hostname: "github.com", + gistID: "1234", + wantErrString: "HTTP 500: arbitrary error (https://api.github.com/gists/1234)", }, } @@ -383,12 +383,16 @@ func Test_gistDelete(t *testing.T) { client := api.NewClientFromHTTP(&http.Client{Transport: reg}) err := deleteGist(client, tt.hostname, tt.gistID) - if tt.wantErr != nil { - assert.ErrorAs(t, err, &tt.wantErr) + if tt.wantErrString == "" && tt.wantErr == nil { + require.NoError(t, err) } else { - assert.NoError(t, err) + if tt.wantErrString != "" { + require.EqualError(t, err, tt.wantErrString) + } + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + } } - }) } } diff --git a/pkg/cmd/gpg-key/delete/delete_test.go b/pkg/cmd/gpg-key/delete/delete_test.go index 115e72db239..dc730b100ed 100644 --- a/pkg/cmd/gpg-key/delete/delete_test.go +++ b/pkg/cmd/gpg-key/delete/delete_test.go @@ -177,7 +177,7 @@ func Test_deleteRun(t *testing.T) { opts: DeleteOptions{KeyID: "ABC123", Confirmed: true}, httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "user/gpg_keys"), httpmock.StatusStringResponse(200, keysResp)) - reg.Register(httpmock.REST("DELETE", "user/gpg_keys/123"), httpmock.StatusJSONResponse(404, api.HTTPError{ + reg.Register(httpmock.REST("DELETE", "user/gpg_keys/123"), httpmock.JSONErrorResponse(404, api.HTTPError{ StatusCode: 404, Message: "GPG key 123 not found", })) diff --git a/pkg/cmd/issue/comment/comment.go b/pkg/cmd/issue/comment/comment.go index 706ff791eae..9b779165699 100644 --- a/pkg/cmd/issue/comment/comment.go +++ b/pkg/cmd/issue/comment/comment.go @@ -18,6 +18,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*prShared.CommentableOptions) e InteractiveEditSurvey: prShared.CommentableInteractiveEditSurvey(f.Config, f.IOStreams), ConfirmSubmitSurvey: prShared.CommentableConfirmSubmitSurvey(f.Prompter), ConfirmCreateIfNoneSurvey: prShared.CommentableInteractiveCreateIfNoneSurvey(f.Prompter), + ConfirmDeleteLastComment: prShared.CommentableConfirmDeleteLastComment(f.Prompter), OpenInBrowser: f.Browser.Browse, } @@ -63,7 +64,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*prShared.CommentableOptions) e } fields := []string{"id", "url"} - if opts.EditLast { + if opts.EditLast || opts.DeleteLast { fields = append(fields, "comments") } @@ -96,7 +97,9 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*prShared.CommentableOptions) e cmd.Flags().StringVarP(&bodyFile, "body-file", "F", "", "Read body text from `file` (use \"-\" to read from standard input)") cmd.Flags().BoolP("editor", "e", false, "Skip prompts and open the text editor to write the body in") cmd.Flags().BoolP("web", "w", false, "Open the web browser to write the comment") - cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the same author") + cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLast, "delete-last", false, "Delete the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLastConfirmed, "yes", false, "Skip the delete confirmation prompt when --delete-last is provided") cmd.Flags().BoolVar(&opts.CreateIfNone, "create-if-none", false, "Create a new comment if no comments are found. Can be used only with --edit-last") return cmd diff --git a/pkg/cmd/issue/comment/comment_test.go b/pkg/cmd/issue/comment/comment_test.go index 794dafda4a4..adee53f7e2a 100644 --- a/pkg/cmd/issue/comment/comment_test.go +++ b/pkg/cmd/issue/comment/comment_test.go @@ -2,6 +2,7 @@ package comment import ( "bytes" + "errors" "fmt" "net/http" "os" @@ -31,11 +32,13 @@ func TestNewCmdComment(t *testing.T) { stdin string output shared.CommentableOptions wantsErr bool + isTTY bool }{ { name: "no arguments", input: "", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { @@ -46,6 +49,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -56,6 +60,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -66,6 +71,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "test", }, + isTTY: true, wantsErr: false, }, { @@ -77,6 +83,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "this is on standard input", }, + isTTY: true, wantsErr: false, }, { @@ -87,6 +94,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "a body from file", }, + isTTY: true, wantsErr: false, }, { @@ -97,6 +105,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeEditor, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -107,6 +116,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeWeb, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -118,6 +128,7 @@ func TestNewCmdComment(t *testing.T) { Body: "", EditLast: true, }, + isTTY: true, wantsErr: false, }, { @@ -130,42 +141,110 @@ func TestNewCmdComment(t *testing.T) { EditLast: true, CreateIfNone: true, }, + isTTY: true, wantsErr: false, }, + { + name: "delete last flag non-interactive", + input: "1 --delete-last", + isTTY: false, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation non-interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: false, + wantsErr: false, + }, + { + name: "delete last flag interactive", + input: "1 --delete-last", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation with web flag", + input: "1 --delete-last --yes --web", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with editor flag", + input: "1 --delete-last --yes --editor", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with body flag", + input: "1 --delete-last --yes --body", + isTTY: true, + wantsErr: true, + }, + { + name: "delete pre-confirmation without delete last flag", + input: "1 --yes", + isTTY: true, + wantsErr: true, + }, { name: "body and body-file flags", input: "1 --body 'test' --body-file 'test-file.txt'", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and web flags", input: "1 --editor --web", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and body flags", input: "1 --editor --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "web and body flags", input: "1 --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor, web, and body flags", input: "1 --editor --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "create-if-none flag without edit-last", input: "1 --create-if-none", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, } @@ -173,9 +252,10 @@ func TestNewCmdComment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ios, stdin, _, _ := iostreams.Test() - ios.SetStdoutTTY(true) - ios.SetStdinTTY(true) - ios.SetStderrTTY(true) + isTTY := tt.isTTY + ios.SetStdoutTTY(isTTY) + ios.SetStdinTTY(isTTY) + ios.SetStderrTTY(isTTY) if tt.stdin != "" { _, _ = stdin.WriteString(tt.stdin) @@ -211,6 +291,8 @@ func TestNewCmdComment(t *testing.T) { assert.Equal(t, tt.output.Interactive, gotOpts.Interactive) assert.Equal(t, tt.output.InputType, gotOpts.InputType) assert.Equal(t, tt.output.Body, gotOpts.Body) + assert.Equal(t, tt.output.DeleteLast, gotOpts.DeleteLast) + assert.Equal(t, tt.output.DeleteLastConfirmed, gotOpts.DeleteLastConfirmed) }) } } @@ -220,6 +302,7 @@ func Test_commentRun(t *testing.T) { name string input *shared.CommentableOptions emptyComments bool + comments api.Comments httpStubs func(*testing.T, *httpmock.Registry) stdout string stderr string @@ -255,6 +338,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with interactive editor succeeds if there are comments", @@ -331,6 +415,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "creating new comment with non-interactive editor succeeds", @@ -358,6 +443,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with non-interactive editor succeeds if there are comments", @@ -433,6 +519,117 @@ func Test_commentRun(t *testing.T) { }, stdout: "https://github.com/OWNER/REPO/issues/123#issuecomment-456\n", }, + { + name: "deleting last comment non-interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment non-interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmation declined", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + wantsErr: true, + stdout: "deletion not confirmed", + }, + { + name: "deleting last comment interactively and confirmed with long comment body", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "Lorem ipsum dolor sit amet, consectet lo..." { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "Lorem ipsum dolor sit amet, consectet lorem ipsum again"}, + }}, + wantsErr: false, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, } for _, tt := range tests { ios, _, stdout, stderr := iostreams.Test() @@ -458,6 +655,8 @@ func Test_commentRun(t *testing.T) { if tt.emptyComments { comments.Nodes = []api.Comment{} + } else if len(tt.comments.Nodes) > 0 { + comments = tt.comments } tt.input.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) { @@ -472,6 +671,7 @@ func Test_commentRun(t *testing.T) { err := shared.CommentableRun(tt.input) if tt.wantsErr { assert.Error(t, err) + assert.Equal(t, tt.stderr, stderr.String()) return } assert.NoError(t, err) @@ -508,3 +708,15 @@ func mockCommentUpdate(t *testing.T, reg *httpmock.Registry) { }), ) } + +func mockCommentDelete(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation CommentDelete\b`), + httpmock.GraphQLMutation(` + { "data": { "deleteIssueComment": {} } }`, + func(inputs map[string]interface{}) { + assert.Equal(t, "id1", inputs["id"]) + }, + ), + ) +} diff --git a/pkg/cmd/issue/edit/edit.go b/pkg/cmd/issue/edit/edit.go index 8386cbcfa24..e959cde2b30 100644 --- a/pkg/cmd/issue/edit/edit.go +++ b/pkg/cmd/issue/edit/edit.go @@ -60,11 +60,17 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman Editing issues' projects requires authorization with the %[1]sproject%[1]s scope. To authorize, run %[1]sgh auth refresh -s project%[1]s. + + The %[1]s--add-assignee%[1]s and %[1]s--remove-assignee%[1]s flags both support + the following special values: + - %[1]s@me%[1]s: assign or unassign yourself + - %[1]s@copilot%[1]s: assign or unassign Copilot (not supported on GitHub Enterprise Server) `, "`"), Example: heredoc.Doc(` $ gh issue edit 23 --title "I found a bug" --body "Nothing works" $ gh issue edit 23 --add-label "bug,help wanted" --remove-label "core" $ gh issue edit 23 --add-assignee "@me" --remove-assignee monalisa,hubot + $ gh issue edit 23 --add-assignee "@copilot" $ gh issue edit 23 --add-project "Roadmap" --remove-project v1,v2 $ gh issue edit 23 --milestone "Version 1" $ gh issue edit 23 --remove-milestone @@ -197,9 +203,24 @@ func editRun(opts *EditOptions) error { } } + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, baseRepo.RepoHost()) + } + + issueFeatures, err := opts.Detector.IssueFeatures() + if err != nil { + return err + } + lookupFields := []string{"id", "number", "title", "body", "url"} if editable.Assignees.Edited { - lookupFields = append(lookupFields, "assignees") + if issueFeatures.ActorIsAssignable { + editable.Assignees.ActorAssignees = true + lookupFields = append(lookupFields, "assignedActors") + } else { + lookupFields = append(lookupFields, "assignees") + } } if editable.Labels.Edited { lookupFields = append(lookupFields, "labels") @@ -207,11 +228,6 @@ func editRun(opts *EditOptions) error { if editable.Projects.Edited { // TODO projectsV1Deprecation // Remove this section as we should no longer add projectCards - if opts.Detector == nil { - cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) - opts.Detector = fd.NewDetector(cachedClient, baseRepo.RepoHost()) - } - projectsV1Support := opts.Detector.ProjectsV1() if projectsV1Support == gh.ProjectsV1Supported { lookupFields = append(lookupFields, "projectCards") @@ -254,7 +270,14 @@ func editRun(opts *EditOptions) error { editable.Title.Default = issue.Title editable.Body.Default = issue.Body - editable.Assignees.Default = issue.Assignees.Logins() + // We use Actors as the default assignees if Actors are assignable + // on this GitHub host. + if editable.Assignees.ActorAssignees { + editable.Assignees.Default = issue.AssignedActors.DisplayNames() + editable.Assignees.DefaultLogins = issue.AssignedActors.Logins() + } else { + editable.Assignees.Default = issue.Assignees.Logins() + } editable.Labels.Default = issue.Labels.Names() editable.Projects.Default = append(issue.ProjectCards.ProjectNames(), issue.ProjectItems.ProjectTitles()...) projectItems := map[string]string{} diff --git a/pkg/cmd/issue/edit/edit_test.go b/pkg/cmd/issue/edit/edit_test.go index c9aa4c409f4..4840cbf7ac1 100644 --- a/pkg/cmd/issue/edit/edit_test.go +++ b/pkg/cmd/issue/edit/edit_test.go @@ -118,9 +118,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ IssueNumbers: []int{23}, Editable: prShared.Editable{ - Assignees: prShared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Edited: true, + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Edited: true, + }, }, }, }, @@ -132,9 +134,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ IssueNumbers: []int{23}, Editable: prShared.Editable{ - Assignees: prShared.EditableSlice{ - Remove: []string{"monalisa", "hubot"}, - Edited: true, + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Remove: []string{"monalisa", "hubot"}, + Edited: true, + }, }, }, }, @@ -354,10 +358,12 @@ func Test_editRun(t *testing.T) { Value: "new body", Edited: true, }, - Assignees: prShared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Remove: []string{"octocat"}, - Edited: true, + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, }, Labels: prShared.EditableSlice{ Add: []string{"feature", "TODO", "bug"}, @@ -388,6 +394,7 @@ func Test_editRun(t *testing.T) { mockIssueProjectItemsGet(t, reg) mockRepoMetadata(t, reg) mockIssueUpdate(t, reg) + mockIssueUpdateActorAssignees(t, reg) mockIssueUpdateLabels(t, reg) mockProjectV2ItemUpdate(t, reg) }, @@ -399,10 +406,12 @@ func Test_editRun(t *testing.T) { IssueNumbers: []int{456, 123}, Interactive: false, Editable: prShared.Editable{ - Assignees: prShared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Remove: []string{"octocat"}, - Edited: true, + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, }, Labels: prShared.EditableSlice{ Add: []string{"feature", "TODO", "bug"}, @@ -433,6 +442,8 @@ func Test_editRun(t *testing.T) { mockIssueProjectItemsGet(t, reg) mockIssueUpdate(t, reg) mockIssueUpdate(t, reg) + mockIssueUpdateActorAssignees(t, reg) + mockIssueUpdateActorAssignees(t, reg) mockIssueUpdateLabels(t, reg) mockIssueUpdateLabels(t, reg) mockProjectV2ItemUpdate(t, reg) @@ -449,10 +460,12 @@ func Test_editRun(t *testing.T) { IssueNumbers: []int{123, 9999}, Interactive: false, Editable: prShared.Editable{ - Assignees: prShared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Remove: []string{"octocat"}, - Edited: true, + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, }, Labels: prShared.EditableSlice{ Add: []string{"feature", "TODO", "bug"}, @@ -494,10 +507,12 @@ func Test_editRun(t *testing.T) { IssueNumbers: []int{123, 456}, Interactive: false, Editable: prShared.Editable{ - Assignees: prShared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Remove: []string{"octocat"}, - Edited: true, + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, }, Milestone: prShared.EditableString{ Value: "GA", @@ -509,14 +524,14 @@ func Test_editRun(t *testing.T) { httpStubs: func(t *testing.T, reg *httpmock.Registry) { // Should only be one fetch of metadata. reg.Register( - httpmock.GraphQL(`query RepositoryAssignableUsers\b`), + httpmock.GraphQL(`query RepositoryAssignableActors\b`), httpmock.StringResponse(` - { "data": { "repository": { "assignableUsers": { + { "data": { "repository": { "suggestedActors": { "nodes": [ - { "login": "hubot", "id": "HUBOTID" }, - { "login": "MonaLisa", "id": "MONAID" } + { "login": "hubot", "id": "HUBOTID", "__typename": "Bot" }, + { "login": "MonaLisa", "id": "MONAID", "__typename": "User" } ], - "pageInfo": { "hasNextPage": false } + "pageInfo": { "hasNextPage": false, "endCursor": "Mg" } } } } } `)) reg.Register( @@ -534,6 +549,14 @@ func Test_editRun(t *testing.T) { mockIssueNumberGet(t, reg, 123) mockIssueNumberGet(t, reg, 456) // Updating 123 should succeed. + reg.Register( + httpmock.GraphQLMutationMatcher(`mutation ReplaceActorsForAssignable\b`, func(m map[string]interface{}) bool { + return m["assignableId"] == "123" + }), + httpmock.GraphQLMutation(` + { "data": { "replaceActorsForAssignable": { "__typename": "" } } }`, + func(inputs map[string]interface{}) {}), + ) reg.Register( httpmock.GraphQLMutationMatcher(`mutation IssueUpdate\b`, func(m map[string]interface{}) bool { return m["id"] == "123" @@ -544,8 +567,8 @@ func Test_editRun(t *testing.T) { ) // Updating 456 should fail. reg.Register( - httpmock.GraphQLMutationMatcher(`mutation IssueUpdate\b`, func(m map[string]interface{}) bool { - return m["id"] == "456" + httpmock.GraphQLMutationMatcher(`mutation ReplaceActorsForAssignable\b`, func(m map[string]interface{}) bool { + return m["assignableId"] == "456" }), httpmock.GraphQLMutation(` { "errors": [ { "message": "test error" } ] }`, @@ -591,11 +614,129 @@ func Test_editRun(t *testing.T) { mockIssueProjectItemsGet(t, reg) mockRepoMetadata(t, reg) mockIssueUpdate(t, reg) + mockIssueUpdateActorAssignees(t, reg) mockIssueUpdateLabels(t, reg) mockProjectV2ItemUpdate(t, reg) }, stdout: "https://github.com/OWNER/REPO/issue/123\n", }, + { + name: "interactive prompts with actor assignee display names when actors available", + input: &EditOptions{ + IssueNumbers: []int{123}, + Interactive: true, + FieldsToEditSurvey: func(p prShared.EditPrompter, eo *prShared.Editable) error { + eo.Assignees.Edited = true + return nil + }, + EditFieldsSurvey: func(p prShared.EditPrompter, eo *prShared.Editable, _ string) error { + // Checking that the display name is being used in the prompt. + require.Equal(t, eo.Assignees.Default, []string{"hubot", "MonaLisa (Mona Display Name)"}) + + // Mocking a selection of only MonaLisa in the prompt. + eo.Assignees.Value = []string{"MonaLisa (Mona Display Name)"} + return nil + }, + FetchOptions: prShared.FetchOptions, + DetermineEditor: func() (string, error) { return "vim", nil }, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockIsssueNumberGetWithAssignedActors(t, reg, 123) + reg.Register( + httpmock.GraphQL(`query RepositoryAssignableActors\b`), + httpmock.StringResponse(` + { "data": { "repository": { "suggestedActors": { + "nodes": [ + { "login": "hubot", "id": "HUBOTID", "__typename": "Bot" }, + { "login": "MonaLisa", "id": "MONAID", "name": "Mona Display Name", "__typename": "User" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + mockIssueUpdate(t, reg) + reg.Register( + httpmock.GraphQL(`mutation ReplaceActorsForAssignable\b`), + httpmock.GraphQLMutation(` + { "data": { "replaceActorsForAssignable": { "__typename": "" } } }`, + func(inputs map[string]interface{}) { + // Checking that despite the display name being returned + // from the EditFieldsSurvey, the ID is still + // used in the mutation. + require.Contains(t, inputs["actorIds"], "MONAID") + }), + ) + }, + stdout: "https://github.com/OWNER/REPO/issue/123\n", + }, + { + name: "interactive prompts with user assignee logins when actors unavailable", + input: &EditOptions{ + IssueNumbers: []int{123}, + Interactive: true, + FieldsToEditSurvey: func(p prShared.EditPrompter, eo *prShared.Editable) error { + eo.Assignees.Edited = true + return nil + }, + EditFieldsSurvey: func(p prShared.EditPrompter, eo *prShared.Editable, _ string) error { + // Checking that only the login is used in the prompt (no display name) + require.Equal(t, eo.Assignees.Default, []string{"hubot", "MonaLisa"}) + + // Mocking a selection of only MonaLisa in the prompt. + eo.Assignees.Value = []string{"MonaLisa"} + return nil + }, + FetchOptions: prShared.FetchOptions, + DetermineEditor: func() (string, error) { return "vim", nil }, + Detector: &fd.DisabledDetectorMock{}, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query IssueByNumber\b`), + httpmock.StringResponse(fmt.Sprintf(` + { "data": { "repository": { "hasIssuesEnabled": true, "issue": { + "id": "%[1]d", + "number": %[1]d, + "url": "https://github.com/OWNER/REPO/issue/123", + "assignees": { + "nodes": [ + { + "id": "HUBOTID", + "login": "hubot", + "name": "" + }, + { + "id": "MONAID", + "login": "MonaLisa", + "name": "Mona Display Name" + } + ], + "totalCount": 2 + } + } } } }`, 123)), + ) + reg.Register( + httpmock.GraphQL(`query RepositoryAssignableUsers\b`), + httpmock.StringResponse(` + { "data": { "repository": { "assignableUsers": { + "nodes": [ + { "login": "hubot", "id": "HUBOTID", "name": "" }, + { "login": "MonaLisa", "id": "MONAID", "name": "Mona Display Name" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`mutation IssueUpdate\b`), + httpmock.GraphQLMutation(` + { "data": { "updateIssue": { "__typename": "" } } }`, + func(inputs map[string]interface{}) { + // Checking that we still assigned the expected ID. + require.Contains(t, inputs["assigneeIds"], "MONAID") + }), + ) + }, + stdout: "https://github.com/OWNER/REPO/issue/123\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -654,6 +795,34 @@ func mockIssueNumberGet(_ *testing.T, reg *httpmock.Registry, number int) { ) } +func mockIsssueNumberGetWithAssignedActors(_ *testing.T, reg *httpmock.Registry, number int) { + reg.Register( + httpmock.GraphQL(`query IssueByNumber\b`), + httpmock.StringResponse(fmt.Sprintf(` + { "data": { "repository": { "hasIssuesEnabled": true, "issue": { + "id": "%[1]d", + "number": %[1]d, + "url": "https://github.com/OWNER/REPO/issue/%[1]d", + "assignedActors": { + "nodes": [ + { + "id": "HUBOTID", + "login": "hubot", + "__typename": "Bot" + }, + { + "id": "MONAID", + "login": "MonaLisa", + "name": "Mona Display Name", + "__typename": "User" + } + ], + "totalCount": 2 + } + } } } }`, number)), + ) +} + func mockIssueProjectItemsGet(_ *testing.T, reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`query IssueProjectItems\b`), @@ -670,16 +839,17 @@ func mockIssueProjectItemsGet(_ *testing.T, reg *httpmock.Registry) { func mockRepoMetadata(_ *testing.T, reg *httpmock.Registry) { reg.Register( - httpmock.GraphQL(`query RepositoryAssignableUsers\b`), + httpmock.GraphQL(`query RepositoryAssignableActors\b`), httpmock.StringResponse(` - { "data": { "repository": { "assignableUsers": { + { "data": { "repository": { "suggestedActors": { "nodes": [ - { "login": "hubot", "id": "HUBOTID" }, - { "login": "MonaLisa", "id": "MONAID" } + { "login": "hubot", "id": "HUBOTID", "__typename": "Bot" }, + { "login": "MonaLisa", "id": "MONAID", "name": "Mona Display Name", "__typename": "User" } ], "pageInfo": { "hasNextPage": false } } } } } `)) + reg.Register( httpmock.GraphQL(`query RepositoryLabelList\b`), httpmock.StringResponse(` @@ -767,6 +937,15 @@ func mockIssueUpdate(t *testing.T, reg *httpmock.Registry) { ) } +func mockIssueUpdateActorAssignees(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation ReplaceActorsForAssignable\b`), + httpmock.GraphQLMutation(` + { "data": { "replaceActorsForAssignable": { "__typename": "" } } }`, + func(inputs map[string]interface{}) {}), + ) +} + func mockIssueUpdateLabels(t *testing.T, reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`mutation LabelAdd\b`), @@ -791,6 +970,85 @@ func mockProjectV2ItemUpdate(t *testing.T, reg *httpmock.Registry) { ) } +func TestActorIsAssignable(t *testing.T) { + t.Run("when actors are assignable, query includes assignedActors", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Register( + httpmock.GraphQL(`assignedActors`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we don't care. + _ = editRun(&EditOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + Detector: &fd.EnabledDetectorMock{}, + IssueNumbers: []int{123}, + Editable: prShared.Editable{ + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"monalisa", "octocat"}, + Edited: true, + }, + }, + }, + }) + + reg.Verify(t) + }) + + t.Run("when actors are not assignable, query includes assignees instead", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + // This test should NOT include assignedActors in the query + reg.Exclude(t, httpmock.GraphQL(`assignedActors`)) + // It should include the regular assignees field + reg.Register( + httpmock.GraphQL(`assignees`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we're not really interested in it. + _ = editRun(&EditOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + Detector: &fd.DisabledDetectorMock{}, + IssueNumbers: []int{123}, + Editable: prShared.Editable{ + Assignees: prShared.EditableAssignees{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"monalisa", "octocat"}, + Edited: true, + }, + }, + }, + }) + + reg.Verify(t) + }) +} + // TODO projectsV1Deprecation // Remove this test. func TestProjectsV1Deprecation(t *testing.T) { diff --git a/pkg/cmd/issue/view/http.go b/pkg/cmd/issue/view/http.go index e4f756436c6..4adc71802dc 100644 --- a/pkg/cmd/issue/view/http.go +++ b/pkg/cmd/issue/view/http.go @@ -53,3 +53,42 @@ func preloadIssueComments(client *http.Client, repo ghrepo.Interface, issue *api issue.Comments.PageInfo.HasNextPage = false return nil } + +func preloadClosedByPullRequestsReferences(client *http.Client, repo ghrepo.Interface, issue *api.Issue) error { + if !issue.ClosedByPullRequestsReferences.PageInfo.HasNextPage { + return nil + } + + type response struct { + Node struct { + Issue struct { + ClosedByPullRequestsReferences api.ClosedByPullRequestsReferences `graphql:"closedByPullRequestsReferences(first: 100, after: $endCursor)"` + } `graphql:"...on Issue"` + } `graphql:"node(id: $id)"` + } + + variables := map[string]interface{}{ + "id": githubv4.ID(issue.ID), + "endCursor": githubv4.String(issue.ClosedByPullRequestsReferences.PageInfo.EndCursor), + } + + gql := api.NewClientFromHTTP(client) + + for { + var query response + err := gql.Query(repo.RepoHost(), "closedByPullRequestsReferences", &query, variables) + if err != nil { + return err + } + + issue.ClosedByPullRequestsReferences.Nodes = append(issue.ClosedByPullRequestsReferences.Nodes, query.Node.Issue.ClosedByPullRequestsReferences.Nodes...) + + if !query.Node.Issue.ClosedByPullRequestsReferences.PageInfo.HasNextPage { + break + } + variables["endCursor"] = githubv4.String(query.Node.Issue.ClosedByPullRequestsReferences.PageInfo.EndCursor) + } + + issue.ClosedByPullRequestsReferences.PageInfo.HasNextPage = false + return nil +} diff --git a/pkg/cmd/issue/view/view.go b/pkg/cmd/issue/view/view.go index a9e25513bc9..3b02a3f2dc7 100644 --- a/pkg/cmd/issue/view/view.go +++ b/pkg/cmd/issue/view/view.go @@ -1,7 +1,6 @@ package view import ( - "errors" "fmt" "io" "net/http" @@ -134,6 +133,8 @@ func viewRun(opts *ViewOptions) error { opts.IO.DetectTerminalTheme() opts.IO.StartProgressIndicator() + defer opts.IO.StopProgressIndicator() + lookupFields.Add("id") issue, err := issueShared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, lookupFields.ToSlice()) @@ -144,18 +145,21 @@ func viewRun(opts *ViewOptions) error { if lookupFields.Contains("comments") { // FIXME: this re-fetches the comments connection even though the initial set of 100 were // fetched in the previous request. - err = preloadIssueComments(httpClient, baseRepo, issue) + err := preloadIssueComments(httpClient, baseRepo, issue) + if err != nil { + return err + } } - opts.IO.StopProgressIndicator() - if err != nil { - var loadErr *issueShared.PartialLoadError - if opts.Exporter == nil && errors.As(err, &loadErr) { - fmt.Fprintf(opts.IO.ErrOut, "warning: %s\n", loadErr.Error()) - } else { + + if lookupFields.Contains("closedByPullRequestsReferences") { + err := preloadClosedByPullRequestsReferences(httpClient, baseRepo, issue) + if err != nil { return err } } + opts.IO.StopProgressIndicator() + if opts.WebMode { openURL := issue.URL if opts.IO.IsStdoutTTY() { diff --git a/pkg/cmd/issue/view/view_test.go b/pkg/cmd/issue/view/view_test.go index 391a288fb21..71b0884a1d8 100644 --- a/pkg/cmd/issue/view/view_test.go +++ b/pkg/cmd/issue/view/view_test.go @@ -31,6 +31,7 @@ func TestJSONFields(t *testing.T) { "body", "closed", "comments", + "closedByPullRequestsReferences", "createdAt", "closedAt", "id", diff --git a/pkg/cmd/pr/checkout/checkout_test.go b/pkg/cmd/pr/checkout/checkout_test.go index 40917fd7629..496139423e9 100644 --- a/pkg/cmd/pr/checkout/checkout_test.go +++ b/pkg/cmd/pr/checkout/checkout_test.go @@ -518,7 +518,7 @@ func TestPRCheckout_sameRepo(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -539,7 +539,7 @@ func TestPRCheckout_existingBranch(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -570,7 +570,7 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -590,7 +590,7 @@ func TestPRCheckout_differentRepo(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -613,7 +613,7 @@ func TestPRCheckout_differentRepoForce(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - finder := shared.RunCommandFinder("123", pr, baseRepo) + finder := shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) finder.ExpectFields([]string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}) cs, cmdTeardown := run.Stub() @@ -636,7 +636,7 @@ func TestPRCheckout_differentRepo_existingBranch(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -655,7 +655,7 @@ func TestPRCheckout_detachedHead(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -674,7 +674,7 @@ func TestPRCheckout_differentRepo_currentBranch(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -693,7 +693,7 @@ func TestPRCheckout_differentRepo_invalidBranchName(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:-foo") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -711,7 +711,7 @@ func TestPRCheckout_maintainerCanModify(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") pr.MaintainerCanModify = true - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -732,7 +732,7 @@ func TestPRCheckout_recurseSubmodules(t *testing.T) { http := &httpmock.Registry{} baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -753,7 +753,7 @@ func TestPRCheckout_force(t *testing.T) { http := &httpmock.Registry{} baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -774,7 +774,7 @@ func TestPRCheckout_detach(t *testing.T) { defer http.Verify(t) baseRepo, pr := stubPR("OWNER/REPO:master", "hubot/REPO:feature") - shared.RunCommandFinder("123", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, baseRepo) cs, cmdTeardown := run.Stub() defer cmdTeardown(t) diff --git a/pkg/cmd/pr/close/close_test.go b/pkg/cmd/pr/close/close_test.go index 959af0e04f2..57ee0f0e643 100644 --- a/pkg/cmd/pr/close/close_test.go +++ b/pkg/cmd/pr/close/close_test.go @@ -110,7 +110,7 @@ func TestPrClose(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -133,7 +133,7 @@ func TestPrClose_alreadyClosed(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") pr.State = "CLOSED" pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) output, err := runCommand(http, true, "96") assert.NoError(t, err) @@ -147,7 +147,7 @@ func TestPrClose_deleteBranch_sameRepo(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:blueberries") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -181,7 +181,7 @@ func TestPrClose_deleteBranch_crossRepo(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "hubot/REPO:blueberries") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -213,7 +213,7 @@ func TestPrClose_deleteBranch_sameBranch(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO:main", "OWNER/REPO:trunk") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -248,7 +248,7 @@ func TestPrClose_deleteBranch_notInGitRepo(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO:main", "OWNER/REPO:trunk") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation PullRequestClose\b`), @@ -282,7 +282,7 @@ func TestPrClose_withComment(t *testing.T) { baseRepo, pr := stubPR("OWNER/REPO", "OWNER/REPO:feature") pr.Title = "The title of the PR" - shared.RunCommandFinder("96", pr, baseRepo) + shared.StubFinderForRunCommandStyleTests(t, "96", pr, baseRepo) http.Register( httpmock.GraphQL(`mutation CommentCreate\b`), diff --git a/pkg/cmd/pr/comment/comment.go b/pkg/cmd/pr/comment/comment.go index a2ab4bf9ee0..2eed7d353bd 100644 --- a/pkg/cmd/pr/comment/comment.go +++ b/pkg/cmd/pr/comment/comment.go @@ -16,6 +16,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err InteractiveEditSurvey: shared.CommentableInteractiveEditSurvey(f.Config, f.IOStreams), ConfirmSubmitSurvey: shared.CommentableConfirmSubmitSurvey(f.Prompter), ConfirmCreateIfNoneSurvey: shared.CommentableInteractiveCreateIfNoneSurvey(f.Prompter), + ConfirmDeleteLastComment: shared.CommentableConfirmDeleteLastComment(f.Prompter), OpenInBrowser: f.Browser.Browse, } @@ -43,7 +44,7 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err selector = args[0] } fields := []string{"id", "url"} - if opts.EditLast { + if opts.EditLast || opts.DeleteLast { fields = append(fields, "comments") } finder := shared.NewFinder(f) @@ -75,7 +76,9 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*shared.CommentableOptions) err cmd.Flags().StringVarP(&bodyFile, "body-file", "F", "", "Read body text from `file` (use \"-\" to read from standard input)") cmd.Flags().BoolP("editor", "e", false, "Skip prompts and open the text editor to write the body in") cmd.Flags().BoolP("web", "w", false, "Open the web browser to write the comment") - cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the same author") + cmd.Flags().BoolVar(&opts.EditLast, "edit-last", false, "Edit the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLast, "delete-last", false, "Delete the last comment of the current user") + cmd.Flags().BoolVar(&opts.DeleteLastConfirmed, "yes", false, "Skip the delete confirmation prompt when --delete-last is provided") cmd.Flags().BoolVar(&opts.CreateIfNone, "create-if-none", false, "Create a new comment if no comments are found. Can be used only with --edit-last") return cmd diff --git a/pkg/cmd/pr/comment/comment_test.go b/pkg/cmd/pr/comment/comment_test.go index 0941f25337a..b9d8e153d48 100644 --- a/pkg/cmd/pr/comment/comment_test.go +++ b/pkg/cmd/pr/comment/comment_test.go @@ -2,6 +2,7 @@ package comment import ( "bytes" + "errors" "fmt" "net/http" "os" @@ -31,6 +32,7 @@ func TestNewCmdComment(t *testing.T) { stdin string output shared.CommentableOptions wantsErr bool + isTTY bool }{ { name: "no arguments", @@ -40,12 +42,14 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { name: "two arguments", input: "1 2", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { @@ -56,6 +60,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -66,6 +71,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -76,6 +82,7 @@ func TestNewCmdComment(t *testing.T) { InputType: 0, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -86,6 +93,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "test", }, + isTTY: true, wantsErr: false, }, { @@ -97,6 +105,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "this is on standard input", }, + isTTY: true, wantsErr: false, }, { @@ -107,6 +116,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeInline, Body: "a body from file", }, + isTTY: true, wantsErr: false, }, { @@ -117,6 +127,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeEditor, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -127,6 +138,7 @@ func TestNewCmdComment(t *testing.T) { InputType: shared.InputTypeWeb, Body: "", }, + isTTY: true, wantsErr: false, }, { @@ -138,6 +150,7 @@ func TestNewCmdComment(t *testing.T) { Body: "", EditLast: true, }, + isTTY: true, wantsErr: false, }, { @@ -150,42 +163,110 @@ func TestNewCmdComment(t *testing.T) { EditLast: true, CreateIfNone: true, }, + isTTY: true, wantsErr: false, }, + { + name: "delete last flag non-interactive", + input: "1 --delete-last", + isTTY: false, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation non-interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: false, + wantsErr: false, + }, + { + name: "delete last flag interactive", + input: "1 --delete-last", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation interactive", + input: "1 --delete-last --yes", + output: shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + isTTY: true, + wantsErr: false, + }, + { + name: "delete last flag and pre-confirmation with web flag", + input: "1 --delete-last --yes --web", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with editor flag", + input: "1 --delete-last --yes --editor", + isTTY: true, + wantsErr: true, + }, + { + name: "delete last flag and pre-confirmation with body flag", + input: "1 --delete-last --yes --body", + isTTY: true, + wantsErr: true, + }, + { + name: "delete pre-confirmation without delete last flag", + input: "1 --yes", + isTTY: true, + wantsErr: true, + }, { name: "body and body-file flags", input: "1 --body 'test' --body-file 'test-file.txt'", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and web flags", input: "1 --editor --web", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor and body flags", input: "1 --editor --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "web and body flags", input: "1 --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "editor, web, and body flags", input: "1 --editor --web --body test", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, { name: "create-if-none flag without edit-last", input: "1 --create-if-none", output: shared.CommentableOptions{}, + isTTY: true, wantsErr: true, }, } @@ -193,9 +274,10 @@ func TestNewCmdComment(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ios, stdin, _, _ := iostreams.Test() - ios.SetStdoutTTY(true) - ios.SetStdinTTY(true) - ios.SetStderrTTY(true) + isTTY := tt.isTTY + ios.SetStdoutTTY(isTTY) + ios.SetStdinTTY(isTTY) + ios.SetStderrTTY(isTTY) if tt.stdin != "" { _, _ = stdin.WriteString(tt.stdin) @@ -231,6 +313,8 @@ func TestNewCmdComment(t *testing.T) { assert.Equal(t, tt.output.Interactive, gotOpts.Interactive) assert.Equal(t, tt.output.InputType, gotOpts.InputType) assert.Equal(t, tt.output.Body, gotOpts.Body) + assert.Equal(t, tt.output.DeleteLast, gotOpts.DeleteLast) + assert.Equal(t, tt.output.DeleteLastConfirmed, gotOpts.DeleteLastConfirmed) }) } } @@ -240,6 +324,7 @@ func Test_commentRun(t *testing.T) { name string input *shared.CommentableOptions emptyComments bool + comments api.Comments httpStubs func(*testing.T, *httpmock.Registry) stdout string stderr string @@ -274,6 +359,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with interactive editor succeeds if there are comments", @@ -350,6 +436,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "creating new comment with non-interactive editor succeeds", @@ -377,6 +464,7 @@ func Test_commentRun(t *testing.T) { }, emptyComments: true, wantsErr: true, + stdout: "no comments found for current user", }, { name: "updating last comment with non-interactive editor succeeds if there are comments", @@ -451,6 +539,117 @@ func Test_commentRun(t *testing.T) { }, stdout: "https://github.com/OWNER/REPO/pull/123#issuecomment-456\n", }, + { + name: "deleting last comment non-interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment interactively without any comment", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + }, + emptyComments: true, + wantsErr: true, + stdout: "no comments found for current user", + }, + { + name: "deleting last comment non-interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: false, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and pre-confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + DeleteLastConfirmed: true, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmed", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, + { + name: "deleting last comment interactively and confirmation declined", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "comment body" { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "comment body"}, + }}, + wantsErr: true, + stdout: "deletion not confirmed", + }, + { + name: "deleting last comment interactively and confirmed with long comment body", + input: &shared.CommentableOptions{ + Interactive: true, + DeleteLast: true, + + ConfirmDeleteLastComment: func(body string) (bool, error) { + if body != "Lorem ipsum dolor sit amet, consectet lo..." { + return false, errors.New("unexpected comment body") + } + return true, nil + }, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + mockCommentDelete(t, reg) + }, + comments: api.Comments{Nodes: []api.Comment{ + {ID: "id1", Author: api.CommentAuthor{Login: "octocat"}, URL: "https://github.com/OWNER/REPO/pull/123#issuecomment-111", ViewerDidAuthor: true, Body: "Lorem ipsum dolor sit amet, consectet lorem ipsum again"}, + }}, + wantsErr: false, + stdout: "! Deleted comments cannot be recovered.\n", + stderr: "Comment deleted\n", + }, } for _, tt := range tests { ios, _, stdout, stderr := iostreams.Test() @@ -475,6 +674,8 @@ func Test_commentRun(t *testing.T) { }} if tt.emptyComments { comments.Nodes = []api.Comment{} + } else if len(tt.comments.Nodes) > 0 { + comments = tt.comments } tt.input.RetrieveCommentable = func() (shared.Commentable, ghrepo.Interface, error) { @@ -489,6 +690,7 @@ func Test_commentRun(t *testing.T) { err := shared.CommentableRun(tt.input) if tt.wantsErr { assert.Error(t, err) + assert.Equal(t, tt.stderr, stderr.String()) return } assert.NoError(t, err) @@ -524,3 +726,15 @@ func mockCommentUpdate(t *testing.T, reg *httpmock.Registry) { }), ) } + +func mockCommentDelete(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation CommentDelete\b`), + httpmock.GraphQLMutation(` + { "data": { "deleteIssueComment": {} } }`, + func(inputs map[string]interface{}) { + assert.Equal(t, "id1", inputs["id"]) + }, + ), + ) +} diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 7f960bce446..1d980b68d0d 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -18,6 +18,7 @@ import ( ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" @@ -31,6 +32,7 @@ import ( type CreateOptions struct { // This struct stores user input and factory functions + Detector fd.Detector HttpClient func() (*http.Client, error) GitClient *git.Client Config func() (gh.Config, error) @@ -363,6 +365,20 @@ func createRun(opts *CreateOptions) error { return err } + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + // TODO projectsV1Deprecation + // Remove this section as we should no longer need to detect + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, ctx.PRRefs.BaseRepo().RepoHost()) + } + + projectsV1Support := opts.Detector.ProjectsV1() + client := ctx.Client state, err := NewIssueState(*ctx, *opts) @@ -384,7 +400,7 @@ func createRun(opts *CreateOptions) error { if err != nil { return err } - openURL, err = generateCompareURL(*ctx, *state) + openURL, err = generateCompareURL(*ctx, *state, projectsV1Support) if err != nil { return err } @@ -440,8 +456,7 @@ func createRun(opts *CreateOptions) error { if err != nil { return err } - // TODO wm: revisit project support - return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) + return submitPR(*opts, *ctx, *state, projectsV1Support) } if opts.RecoverFile != "" { @@ -518,7 +533,7 @@ func createRun(opts *CreateOptions) error { } } - openURL, err = generateCompareURL(*ctx, *state) + openURL, err = generateCompareURL(*ctx, *state, projectsV1Support) if err != nil { return err } @@ -537,8 +552,7 @@ func createRun(opts *CreateOptions) error { Repo: ctx.PRRefs.BaseRepo(), State: state, } - // TODO wm: revisit project support - err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PRRefs.BaseRepo(), fetcher, state, gh.ProjectsV1Supported) + err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PRRefs.BaseRepo(), fetcher, state, projectsV1Support) if err != nil { return err } @@ -567,13 +581,11 @@ func createRun(opts *CreateOptions) error { if action == shared.SubmitDraftAction { state.Draft = true - // TODO wm: revisit project support - return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) + return submitPR(*opts, *ctx, *state, projectsV1Support) } if action == shared.SubmitAction { - // TODO wm: revisit project support - return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) + return submitPR(*opts, *ctx, *state, projectsV1Support) } err = errors.New("expected to cancel, preview, or submit") @@ -1216,13 +1228,12 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return pushBranch() } -func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState) (string, error) { +func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState, projectsV1Support gh.ProjectsV1Support) (string, error) { u := ghrepo.GenerateRepoURL( ctx.PRRefs.BaseRepo(), "compare/%s...%s?expand=1", url.PathEscape(ctx.PRRefs.BaseRef()), url.PathEscape(ctx.PRRefs.QualifiedHeadRef())) - // TODO wm: revisit project support - url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state, gh.ProjectsV1Supported) + url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state, projectsV1Support) if err != nil { return "", err } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 2a88b5eee13..bd68f19d967 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -15,6 +15,7 @@ import ( "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/config" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" @@ -1618,6 +1619,7 @@ func Test_createRun(t *testing.T) { } opts := CreateOptions{} + opts.Detector = &fd.EnabledDetectorMock{} opts.Prompter = pm ios, _, stdout, stderr := iostreams.Test() @@ -1850,11 +1852,13 @@ func mustParseQualifiedHeadRef(ref string) shared.QualifiedHeadRef { func Test_generateCompareURL(t *testing.T) { tests := []struct { - name string - ctx CreateContext - state shared.IssueMetadataState - want string - wantErr bool + name string + ctx CreateContext + state shared.IssueMetadataState + httpStubs func(*testing.T, *httpmock.Registry) + projectsV1Support gh.ProjectsV1Support + want string + wantErr bool }{ { name: "basic", @@ -1938,10 +1942,135 @@ func Test_generateCompareURL(t *testing.T) { want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1&template=story.md", wantErr: false, }, + // TODO projectsV1Deprecation + // Clean up these tests, but probably keep one for general project ID resolution. + { + name: "with projects, no v1 support", + ctx: CreateContext{ + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, + }, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + // Ensure no v1 projects are requestd + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "ProjectTitle", "id": "PROJECTV2ID", "resourcePath": "/OWNER/REPO/projects/3" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + }, + state: shared.IssueMetadataState{ + ProjectTitles: []string{"ProjectTitle"}, + }, + projectsV1Support: gh.ProjectsV1Unsupported, + want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1&projects=OWNER%2FREPO%2F3", + wantErr: false, + }, + { + name: "with projects, v1 support", + ctx: CreateContext{ + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, + }, + }, + state: shared.IssueMetadataState{ + ProjectTitles: []string{"ProjectV1Title"}, + }, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + // v1 project query responses + reg.Register( + httpmock.GraphQL(`query RepositoryProjectList\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projects": { + "nodes": [ + { "name": "ProjectV1Title", "id": "PROJECTV1ID", "resourcePath": "/OWNER/REPO/projects/1" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectList\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projects": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + // v2 project query responses + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + }, + projectsV1Support: gh.ProjectsV1Supported, + want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1&projects=OWNER%2FREPO%2F1", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := generateCompareURL(tt.ctx, tt.state) + // If http stubs are provided, register them and inject the registry into a client + // that is provided to generateCompareURL in the ctx. + if tt.httpStubs != nil { + reg := &httpmock.Registry{} + defer reg.Verify(t) + + tt.httpStubs(t, reg) + tt.ctx.Client = api.NewClientFromHTTP(&http.Client{Transport: reg}) + } + + got, err := generateCompareURL(tt.ctx, tt.state, tt.projectsV1Support) if (err != nil) != tt.wantErr { t.Errorf("generateCompareURL() error = %v, wantErr %v", err, tt.wantErr) return @@ -2008,4 +2137,438 @@ func mockRetrieveProjects(_ *testing.T, reg *httpmock.Registry) { `)) } -// TODO interactive metadata tests once: 1) we have test utils for Prompter and 2) metadata questions use Prompter +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + + t.Run("non-interactive submission", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&CreateOptions{ + Detector: &fd.EnabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we're not really interested in it. + _ = createRun(&CreateOptions{ + Detector: &fd.DisabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + }) + + t.Run("interactive submission", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register("git -c log.ShowSignature=false log --pretty=format:%H%x00%s%x00%b%x00 --cherry origin/master...feature", 0, "") + cs.Register(`git rev-parse --show-toplevel`, 0, "") + + // When the command is run + reg := &httpmock.Registry{} + reg.StubRepoResponse("OWNER", "REPO") + + reg.Register( + httpmock.GraphQL(`query PullRequestTemplates\b`), + httpmock.StringResponse(`{ "data": { "repository": { "pullRequestTemplates": [] } } }`), + ) + + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + // Register a handler to check for projects V2 just to avoid the registry panicking, even + // though we return a 500 error. This is because the project lookup is done in parallel + // so the previous error doesn't early exit. + reg.Register( + httpmock.GraphQL(`projectsV2`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + ios, _, _, _ := iostreams.Test() + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) + + pm := &prompter.PrompterMock{} + pm.InputFunc = func(p, _ string) (string, error) { + if p == "Title (required)" { + return "Test Title", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.MarkdownEditorFunc = func(p, _ string, ba bool) (string, error) { + if p == "Body" { + return "Test Body", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + switch p { + case "Choose a template": + return 0, nil + case "What's next?": + return prompter.IndexFor(opts, "Add metadata") + default: + return -1, prompter.NoSuchPromptErr(p) + } + } + pm.MultiSelectFunc = func(p string, _ []string, opts []string) ([]int, error) { + return prompter.IndexesFor(opts, "Projects") + } + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + IO: ios, + Prompter: pm, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Detector: &fd.EnabledDetectorMock{}, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Branch: func() (string, error) { + return "feature", nil + }, + + HeadBranch: "feature", + } + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&opts) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register("git -c log.ShowSignature=false log --pretty=format:%H%x00%s%x00%b%x00 --cherry origin/master...feature", 0, "") + cs.Register(`git rev-parse --show-toplevel`, 0, "") + + // When the command is run + reg := &httpmock.Registry{} + reg.StubRepoResponse("OWNER", "REPO") + + reg.Register( + httpmock.GraphQL(`query PullRequestTemplates\b`), + httpmock.StringResponse(`{ "data": { "repository": { "pullRequestTemplates": [] } } }`), + ) + + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + ios, _, _, _ := iostreams.Test() + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) + + pm := &prompter.PrompterMock{} + pm.InputFunc = func(p, _ string) (string, error) { + if p == "Title (required)" { + return "Test Title", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.MarkdownEditorFunc = func(p, _ string, ba bool) (string, error) { + if p == "Body" { + return "Test Body", nil + } else { + return "", prompter.NoSuchPromptErr(p) + } + } + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + switch p { + case "Choose a template": + return 0, nil + case "What's next?": + return prompter.IndexFor(opts, "Add metadata") + default: + return -1, prompter.NoSuchPromptErr(p) + } + } + pm.MultiSelectFunc = func(p string, _ []string, opts []string) ([]int, error) { + return prompter.IndexesFor(opts, "Projects") + } + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + IO: ios, + Prompter: pm, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Detector: &fd.DisabledDetectorMock{}, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Branch: func() (string, error) { + return "feature", nil + }, + + HeadBranch: "feature", + } + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&opts) + + // Verify that our request did not contain projectCards + reg.Verify(t) + }) + }) + + t.Run("web mode", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&CreateOptions{ + Detector: &fd.EnabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + WebMode: true, + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + + // Ignore the error because we're not really interested in it. + _ = createRun(&CreateOptions{ + Detector: &fd.DisabledDetectorMock{}, + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Finder: shared.NewMockFinder("feature", nil, nil), + + WebMode: true, + + HeadBranch: "feature", + + TitleProvided: true, + BodyProvided: true, + Title: "Test Title", + Body: "Test Body", + + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request did not contain projectCards + reg.Verify(t) + }) + }) +} diff --git a/pkg/cmd/pr/edit/edit.go b/pkg/cmd/pr/edit/edit.go index 3c8d73ad393..c30c4f9bb6b 100644 --- a/pkg/cmd/pr/edit/edit.go +++ b/pkg/cmd/pr/edit/edit.go @@ -3,9 +3,11 @@ package edit import ( "fmt" "net/http" + "time" "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" shared "github.com/cli/cli/v2/pkg/cmd/pr/shared" @@ -25,6 +27,8 @@ type EditOptions struct { Fetcher EditableOptionsFetcher EditorRetriever EditorRetriever Prompter shared.EditPrompter + Detector fd.Detector + BaseRepo func() (ghrepo.Interface, error) SelectorArg string Interactive bool @@ -56,12 +60,21 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman Editing a pull request's projects requires authorization with the %[1]sproject%[1]s scope. To authorize, run %[1]sgh auth refresh -s project%[1]s. + + The %[1]s--add-assignee%[1]s and %[1]s--remove-assignee%[1]s flags both support + the following special values: + - %[1]s@me%[1]s: assign or unassign yourself + - %[1]s@copilot%[1]s: assign or unassign Copilot (not supported on GitHub Enterprise Server) + + The %[1]s--add-reviewer%[1]s and %[1]s--remove-reviewer%[1]s flags do not support + these special values. `, "`"), Example: heredoc.Doc(` $ gh pr edit 23 --title "I found a bug" --body "Nothing works" $ gh pr edit 23 --add-label "bug,help wanted" --remove-label "core" $ gh pr edit 23 --add-reviewer monalisa,hubot --remove-reviewer myorg/team-name $ gh pr edit 23 --add-assignee "@me" --remove-assignee monalisa,hubot + $ gh pr edit 23 --add-assignee "@copilot" $ gh pr edit 23 --add-project "Roadmap" --remove-project v1,v2 $ gh pr edit 23 --milestone "Version 1" $ gh pr edit 23 --remove-milestone @@ -69,6 +82,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { opts.Finder = shared.NewFinder(f) + opts.BaseRepo = f.BaseRepo if len(args) > 0 { opts.SelectorArg = args[0] @@ -192,8 +206,36 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman func editRun(opts *EditOptions) error { findOptions := shared.FindOptions{ Selector: opts.SelectorArg, - Fields: []string{"id", "url", "title", "body", "baseRefName", "reviewRequests", "assignees", "labels", "projectCards", "projectItems", "milestone"}, + Fields: []string{"id", "url", "title", "body", "baseRefName", "reviewRequests", "labels", "projectCards", "projectItems", "milestone"}, + Detector: opts.Detector, + } + + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + if opts.Detector == nil { + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, baseRepo.RepoHost()) } + + issueFeatures, err := opts.Detector.IssueFeatures() + if err != nil { + return err + } + + if issueFeatures.ActorIsAssignable { + findOptions.Fields = append(findOptions.Fields, "assignedActors") + } else { + findOptions.Fields = append(findOptions.Fields, "assignees") + } + pr, repo, err := opts.Finder.Find(findOptions) if err != nil { return err @@ -205,7 +247,12 @@ func editRun(opts *EditOptions) error { editable.Body.Default = pr.Body editable.Base.Default = pr.BaseRefName editable.Reviewers.Default = pr.ReviewRequests.Logins() - editable.Assignees.Default = pr.Assignees.Logins() + if issueFeatures.ActorIsAssignable { + editable.Assignees.ActorAssignees = true + editable.Assignees.Default = pr.AssignedActors.DisplayNames() + } else { + editable.Assignees.Default = pr.Assignees.Logins() + } editable.Labels.Default = pr.Labels.Names() editable.Projects.Default = append(pr.ProjectCards.ProjectNames(), pr.ProjectItems.ProjectTitles()...) projectItems := map[string]string{} @@ -224,10 +271,6 @@ func editRun(opts *EditOptions) error { } } - httpClient, err := opts.HttpClient() - if err != nil { - return err - } apiClient := api.NewClientFromHTTP(httpClient) opts.IO.StartProgressIndicator() @@ -278,8 +321,7 @@ func updatePullRequestReviews(httpClient *http.Client, repo ghrepo.Interface, id if err != nil { return err } - if (userIds == nil || len(*userIds) == 0) && - (teamIds == nil || len(*teamIds) == 0) { + if userIds == nil && teamIds == nil { return nil } union := githubv4.Boolean(false) diff --git a/pkg/cmd/pr/edit/edit_test.go b/pkg/cmd/pr/edit/edit_test.go index 3c4882961ad..3a1e8a54825 100644 --- a/pkg/cmd/pr/edit/edit_test.go +++ b/pkg/cmd/pr/edit/edit_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/cli/cli/v2/api" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" shared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" @@ -165,9 +166,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ SelectorArg: "23", Editable: shared.Editable{ - Assignees: shared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Edited: true, + Assignees: shared.EditableAssignees{ + EditableSlice: shared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Edited: true, + }, }, }, }, @@ -179,9 +182,11 @@ func TestNewCmdEdit(t *testing.T) { output: EditOptions{ SelectorArg: "23", Editable: shared.Editable{ - Assignees: shared.EditableSlice{ - Remove: []string{"monalisa", "hubot"}, - Edited: true, + Assignees: shared.EditableAssignees{ + EditableSlice: shared.EditableSlice{ + Remove: []string{"monalisa", "hubot"}, + Edited: true, + }, }, }, }, @@ -359,10 +364,12 @@ func Test_editRun(t *testing.T) { Remove: []string{"dependabot"}, Edited: true, }, - Assignees: shared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Remove: []string{"octocat"}, - Edited: true, + Assignees: shared.EditableAssignees{ + EditableSlice: shared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, }, Labels: shared.EditableSlice{ Add: []string{"feature", "TODO", "bug"}, @@ -386,6 +393,7 @@ func Test_editRun(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { mockRepoMetadata(reg, false) mockPullRequestUpdate(reg) + mockPullRequestUpdateActorAssignees(reg) mockPullRequestReviewersUpdate(reg) mockPullRequestUpdateLabels(reg) mockProjectV2ItemUpdate(reg) @@ -413,10 +421,12 @@ func Test_editRun(t *testing.T) { Value: "base-branch-name", Edited: true, }, - Assignees: shared.EditableSlice{ - Add: []string{"monalisa", "hubot"}, - Remove: []string{"octocat"}, - Edited: true, + Assignees: shared.EditableAssignees{ + EditableSlice: shared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, }, Labels: shared.EditableSlice{ Add: []string{"feature", "TODO", "bug"}, @@ -440,7 +450,69 @@ func Test_editRun(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { mockRepoMetadata(reg, true) mockPullRequestUpdate(reg) + mockPullRequestUpdateActorAssignees(reg) + mockPullRequestUpdateLabels(reg) + mockProjectV2ItemUpdate(reg) + }, + stdout: "https://github.com/OWNER/REPO/pull/123\n", + }, + { + name: "non-interactive remove all reviewers", + input: &EditOptions{ + SelectorArg: "123", + Finder: shared.NewMockFinder("123", &api.PullRequest{ + URL: "https://github.com/OWNER/REPO/pull/123", + }, ghrepo.New("OWNER", "REPO")), + Interactive: false, + Editable: shared.Editable{ + Title: shared.EditableString{ + Value: "new title", + Edited: true, + }, + Body: shared.EditableString{ + Value: "new body", + Edited: true, + }, + Base: shared.EditableString{ + Value: "base-branch-name", + Edited: true, + }, + Reviewers: shared.EditableSlice{ + Remove: []string{"OWNER/core", "OWNER/external", "monalisa", "hubot", "dependabot"}, + Edited: true, + }, + Assignees: shared.EditableAssignees{ + EditableSlice: shared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, + }, + Labels: shared.EditableSlice{ + Add: []string{"feature", "TODO", "bug"}, + Remove: []string{"docs"}, + Edited: true, + }, + Projects: shared.EditableProjects{ + EditableSlice: shared.EditableSlice{ + Add: []string{"Cleanup", "CleanupV2"}, + Remove: []string{"Roadmap", "RoadmapV2"}, + Edited: true, + }, + }, + Milestone: shared.EditableString{ + Value: "GA", + Edited: true, + }, + }, + Fetcher: testFetcher{}, + }, + httpStubs: func(reg *httpmock.Registry) { + mockRepoMetadata(reg, false) + mockPullRequestUpdate(reg) + mockPullRequestReviewersUpdate(reg) mockPullRequestUpdateLabels(reg) + mockPullRequestUpdateActorAssignees(reg) mockProjectV2ItemUpdate(reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", @@ -460,6 +532,7 @@ func Test_editRun(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { mockRepoMetadata(reg, false) mockPullRequestUpdate(reg) + mockPullRequestUpdateActorAssignees(reg) mockPullRequestReviewersUpdate(reg) mockPullRequestUpdateLabels(reg) mockProjectV2ItemUpdate(reg) @@ -481,11 +554,72 @@ func Test_editRun(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { mockRepoMetadata(reg, true) mockPullRequestUpdate(reg) + mockPullRequestUpdateActorAssignees(reg) mockPullRequestUpdateLabels(reg) mockProjectV2ItemUpdate(reg) }, stdout: "https://github.com/OWNER/REPO/pull/123\n", }, + { + name: "interactive remove all reviewers", + input: &EditOptions{ + SelectorArg: "123", + Finder: shared.NewMockFinder("123", &api.PullRequest{ + URL: "https://github.com/OWNER/REPO/pull/123", + }, ghrepo.New("OWNER", "REPO")), + Interactive: true, + Surveyor: testSurveyor{removeAllReviewers: true}, + Fetcher: testFetcher{}, + EditorRetriever: testEditorRetriever{}, + }, + httpStubs: func(reg *httpmock.Registry) { + mockRepoMetadata(reg, false) + mockPullRequestUpdate(reg) + mockPullRequestReviewersUpdate(reg) + mockPullRequestUpdateActorAssignees(reg) + mockPullRequestUpdateLabels(reg) + mockProjectV2ItemUpdate(reg) + }, + stdout: "https://github.com/OWNER/REPO/pull/123\n", + }, + { + name: "Legacy assignee users are fetched and updated on unsupported GitHub Hosts", + input: &EditOptions{ + Detector: &fd.DisabledDetectorMock{}, + SelectorArg: "123", + Finder: shared.NewMockFinder("123", &api.PullRequest{ + URL: "https://github.com/OWNER/REPO/pull/123", + }, ghrepo.New("OWNER", "REPO")), + Interactive: false, + Editable: shared.Editable{ + Assignees: shared.EditableAssignees{ + EditableSlice: shared.EditableSlice{ + Add: []string{"monalisa", "hubot"}, + Remove: []string{"octocat"}, + Edited: true, + }, + }, + }, + Fetcher: testFetcher{}, + }, + httpStubs: func(reg *httpmock.Registry) { + // Notice there is no call to mockReplaceActorsForAssignable() + // and no GraphQL call to RepositoryAssignableActors below. + reg.Register( + httpmock.GraphQL(`query RepositoryAssignableUsers\b`), + httpmock.StringResponse(` + { "data": { "repository": { "assignableUsers": { + "nodes": [ + { "login": "hubot", "id": "HUBOTID" }, + { "login": "MonaLisa", "id": "MONAID" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + mockPullRequestUpdate(reg) + }, + stdout: "https://github.com/OWNER/REPO/pull/123\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -499,9 +633,11 @@ func Test_editRun(t *testing.T) { tt.httpStubs(reg) httpClient := func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } + baseRepo := func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil } tt.input.IO = ios tt.input.HttpClient = httpClient + tt.input.BaseRepo = baseRepo err := editRun(tt.input) assert.NoError(t, err) @@ -513,16 +649,16 @@ func Test_editRun(t *testing.T) { func mockRepoMetadata(reg *httpmock.Registry, skipReviewers bool) { reg.Register( - httpmock.GraphQL(`query RepositoryAssignableUsers\b`), + httpmock.GraphQL(`query RepositoryAssignableActors\b`), httpmock.StringResponse(` - { "data": { "repository": { "assignableUsers": { - "nodes": [ - { "login": "hubot", "id": "HUBOTID" }, - { "login": "MonaLisa", "id": "MONAID" } - ], - "pageInfo": { "hasNextPage": false } - } } } } - `)) + { "data": { "repository": { "suggestedActors": { + "nodes": [ + { "login": "hubot", "id": "HUBOTID", "__typename": "Bot" }, + { "login": "MonaLisa", "id": "MONAID", "__typename": "User" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) reg.Register( httpmock.GraphQL(`query RepositoryLabelList\b`), httpmock.StringResponse(` @@ -625,6 +761,15 @@ func mockPullRequestUpdate(reg *httpmock.Registry) { httpmock.StringResponse(`{}`)) } +func mockPullRequestUpdateActorAssignees(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation ReplaceActorsForAssignable\b`), + httpmock.GraphQLMutation(` + { "data": { "replaceActorsForAssignable": { "__typename": "" } } }`, + func(inputs map[string]interface{}) {}), + ) +} + func mockPullRequestReviewersUpdate(reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`mutation PullRequestUpdateRequestReviews\b`), @@ -657,7 +802,8 @@ func mockProjectV2ItemUpdate(reg *httpmock.Registry) { type testFetcher struct{} type testSurveyor struct { - skipReviewers bool + skipReviewers bool + removeAllReviewers bool } type testEditorRetriever struct{} @@ -682,7 +828,11 @@ func (s testSurveyor) EditFields(e *shared.Editable, _ string) error { e.Title.Value = "new title" e.Body.Value = "new body" if !s.skipReviewers { - e.Reviewers.Value = []string{"monalisa", "hubot", "OWNER/core", "OWNER/external"} + if s.removeAllReviewers { + e.Reviewers.Remove = []string{"monalisa", "hubot", "OWNER/core", "OWNER/external", "dependabot"} + } else { + e.Reviewers.Value = []string{"monalisa", "hubot", "OWNER/core", "OWNER/external"} + } } e.Assignees.Value = []string{"monalisa", "hubot"} e.Labels.Value = []string{"feature", "TODO", "bug"} @@ -696,3 +846,73 @@ func (s testSurveyor) EditFields(e *shared.Editable, _ string) error { func (t testEditorRetriever) Retrieve() (string, error) { return "vim", nil } + +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + t.Run("when projects v1 is supported, is included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Register( + httpmock.GraphQL(`projectCards`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + f := &cmdutil.Factory{ + IOStreams: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = editRun(&EditOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Detector: &fd.EnabledDetectorMock{}, + + Finder: shared.NewFinder(f), + + SelectorArg: "https://github.com/cli/cli/pull/123", + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, is not included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Exclude(t, httpmock.GraphQL(`projectCards`)) + + f := &cmdutil.Factory{ + IOStreams: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = editRun(&EditOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Detector: &fd.DisabledDetectorMock{}, + + Finder: shared.NewFinder(f), + + SelectorArg: "https://github.com/cli/cli/pull/123", + }) + + // Verify that our request did not contain projectCards + reg.Verify(t) + }) +} diff --git a/pkg/cmd/pr/merge/merge_test.go b/pkg/cmd/pr/merge/merge_test.go index f1c2e37fefb..4ca8c5d06df 100644 --- a/pkg/cmd/pr/merge/merge_test.go +++ b/pkg/cmd/pr/merge/merge_test.go @@ -307,7 +307,7 @@ func TestPrMerge(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -348,7 +348,7 @@ func TestPrMerge_blocked(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -379,7 +379,7 @@ func TestPrMerge_dirty(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -413,7 +413,7 @@ func TestPrMerge_nontty(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -451,7 +451,7 @@ func TestPrMerge_editMessage_nontty(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -490,7 +490,7 @@ func TestPrMerge_withRepoFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -529,7 +529,7 @@ func TestPrMerge_withMatchCommitHeadFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -570,7 +570,7 @@ func TestPrMerge_withAuthorFlag(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -612,7 +612,7 @@ func TestPrMerge_deleteBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -663,7 +663,7 @@ func TestPrMerge_deleteBranch_mergeQueue(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -686,7 +686,7 @@ func TestPrMerge_deleteBranch_nonDefault(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -737,7 +737,7 @@ func TestPrMerge_deleteBranch_onlyLocally(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -785,7 +785,7 @@ func TestPrMerge_deleteBranch_checkoutNewBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -836,7 +836,7 @@ func TestPrMerge_deleteNonCurrentBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "blueberries", &api.PullRequest{ ID: "PR_10", @@ -893,7 +893,7 @@ func Test_nonDivergingPullRequest(t *testing.T) { } stubCommit(pr, "COMMITSHA1") - shared.RunCommandFinder("", pr, baseRepo("OWNER", "REPO", "main")) + shared.StubFinderForRunCommandStyleTests(t, "", pr, baseRepo("OWNER", "REPO", "main")) http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), @@ -933,7 +933,7 @@ func Test_divergingPullRequestWarning(t *testing.T) { } stubCommit(pr, "COMMITSHA1") - shared.RunCommandFinder("", pr, baseRepo("OWNER", "REPO", "main")) + shared.StubFinderForRunCommandStyleTests(t, "", pr, baseRepo("OWNER", "REPO", "main")) http.Register( httpmock.GraphQL(`mutation PullRequestMerge\b`), @@ -964,7 +964,7 @@ func Test_pullRequestWithoutCommits(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "PR_10", @@ -1003,7 +1003,7 @@ func TestPrMerge_rebase(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "2", &api.PullRequest{ ID: "THE-ID", @@ -1044,7 +1044,7 @@ func TestPrMerge_squash(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "3", &api.PullRequest{ ID: "THE-ID", @@ -1084,7 +1084,7 @@ func TestPrMerge_alreadyMerged(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1129,7 +1129,7 @@ func TestPrMerge_alreadyMerged_withMergeStrategy(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1159,7 +1159,7 @@ func TestPrMerge_alreadyMerged_withMergeStrategy_TTY(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1200,7 +1200,7 @@ func TestPrMerge_alreadyMerged_withMergeStrategy_crossRepo(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "4", &api.PullRequest{ ID: "THE-ID", @@ -1239,7 +1239,7 @@ func TestPRMergeTTY(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "THE-ID", @@ -1305,7 +1305,7 @@ func TestPRMergeTTY_withDeleteBranch(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ ID: "THE-ID", @@ -1468,7 +1468,7 @@ func TestPRMergeEmptyStrategyNonTTY(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1495,7 +1495,7 @@ func TestPRTTY_cancelled(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123, Title: "title", MergeStateStatus: "CLEAN"}, ghrepo.New("OWNER", "REPO"), @@ -1679,7 +1679,7 @@ func TestPrInMergeQueue(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1710,7 +1710,7 @@ func TestPrAddToMergeQueueWithMergeMethod(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1748,7 +1748,7 @@ func TestPrAddToMergeQueueClean(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1788,7 +1788,7 @@ func TestPrAddToMergeQueueBlocked(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1828,7 +1828,7 @@ func TestPrAddToMergeQueueAdmin(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", @@ -1897,7 +1897,7 @@ func TestPrAddToMergeQueueAdminWithMergeStrategy(t *testing.T) { http := initFakeHTTP() defer http.Verify(t) - shared.RunCommandFinder( + shared.StubFinderForRunCommandStyleTests(t, "1", &api.PullRequest{ ID: "THE-ID", diff --git a/pkg/cmd/pr/ready/ready_test.go b/pkg/cmd/pr/ready/ready_test.go index 9046ab3aca3..5a6053a17c7 100644 --- a/pkg/cmd/pr/ready/ready_test.go +++ b/pkg/cmd/pr/ready/ready_test.go @@ -124,7 +124,7 @@ func TestPRReady(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -149,7 +149,7 @@ func TestPRReady_alreadyReady(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -166,7 +166,7 @@ func TestPRReadyUndo(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -191,7 +191,7 @@ func TestPRReadyUndo_alreadyDraft(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -208,7 +208,7 @@ func TestPRReady_closed(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "CLOSED", diff --git a/pkg/cmd/pr/reopen/reopen_test.go b/pkg/cmd/pr/reopen/reopen_test.go index 856e191721c..9fb3702c082 100644 --- a/pkg/cmd/pr/reopen/reopen_test.go +++ b/pkg/cmd/pr/reopen/reopen_test.go @@ -53,7 +53,7 @@ func TestPRReopen(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "CLOSED", @@ -78,7 +78,7 @@ func TestPRReopen_alreadyOpen(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "OPEN", @@ -95,7 +95,7 @@ func TestPRReopen_alreadyMerged(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "MERGED", @@ -112,7 +112,7 @@ func TestPRReopen_withComment(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("123", &api.PullRequest{ + shared.StubFinderForRunCommandStyleTests(t, "123", &api.PullRequest{ ID: "THE-ID", Number: 123, State: "CLOSED", diff --git a/pkg/cmd/pr/review/review_test.go b/pkg/cmd/pr/review/review_test.go index f9e00c3b8ee..684617ca97a 100644 --- a/pkg/cmd/pr/review/review_test.go +++ b/pkg/cmd/pr/review/review_test.go @@ -235,7 +235,7 @@ func TestPRReview(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID"}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID"}, ghrepo.New("OWNER", "REPO")) http.Register( httpmock.GraphQL(`mutation PullRequestReviewAdd\b`), @@ -261,7 +261,7 @@ func TestPRReview_interactive(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) http.Register( httpmock.GraphQL(`mutation PullRequestReviewAdd\b`), @@ -293,7 +293,7 @@ func TestPRReview_interactive_no_body(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) pm := &prompter.PrompterMock{ SelectFunc: func(_, _ string, _ []string) (int, error) { return 2, nil }, @@ -308,7 +308,7 @@ func TestPRReview_interactive_blank_approve(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{ID: "THE-ID", Number: 123}, ghrepo.New("OWNER", "REPO")) http.Register( httpmock.GraphQL(`mutation PullRequestReviewAdd\b`), diff --git a/pkg/cmd/pr/shared/commentable.go b/pkg/cmd/pr/shared/commentable.go index f909c755934..015d84a4b16 100644 --- a/pkg/cmd/pr/shared/commentable.go +++ b/pkg/cmd/pr/shared/commentable.go @@ -18,6 +18,7 @@ import ( ) var errNoUserComments = errors.New("no comments found for current user") +var errDeleteNotConfirmed = errors.New("deletion not confirmed") type InputType int @@ -41,11 +42,14 @@ type CommentableOptions struct { InteractiveEditSurvey func(string) (string, error) ConfirmSubmitSurvey func() (bool, error) ConfirmCreateIfNoneSurvey func() (bool, error) + ConfirmDeleteLastComment func(string) (bool, error) OpenInBrowser func(string) error Interactive bool InputType InputType Body string EditLast bool + DeleteLast bool + DeleteLastConfirmed bool CreateIfNone bool Quiet bool Host string @@ -74,6 +78,21 @@ func CommentablePreRun(cmd *cobra.Command, opts *CommentableOptions) error { return cmdutil.FlagErrorf("`--create-if-none` can only be used with `--edit-last`") } + if opts.DeleteLastConfirmed && !opts.DeleteLast { + return cmdutil.FlagErrorf("`--yes` should only be used with `--delete-last`") + } + + if opts.DeleteLast { + if inputFlags > 0 { + return cmdutil.FlagErrorf("should not provide comment body when using `--delete-last`") + } + if opts.IO.CanPrompt() || opts.DeleteLastConfirmed { + opts.Interactive = opts.IO.CanPrompt() + return nil + } + return cmdutil.FlagErrorf("should provide `--yes` to confirm deletion in non-interactive mode") + } + if inputFlags == 0 { if !opts.IO.CanPrompt() { return cmdutil.FlagErrorf("flags required when not running interactively") @@ -92,6 +111,9 @@ func CommentableRun(opts *CommentableOptions) error { return err } opts.Host = repo.RepoHost() + if opts.DeleteLast { + return deleteComment(commentable, opts) + } // Create new comment, bail before complexities of updating the last comment if !opts.EditLast { @@ -236,6 +258,53 @@ func updateComment(commentable Commentable, opts *CommentableOptions) error { return nil } +func deleteComment(commentable Commentable, opts *CommentableOptions) error { + comments := commentable.CurrentUserComments() + if len(comments) == 0 { + return errNoUserComments + } + + lastComment := comments[len(comments)-1] + + cs := opts.IO.ColorScheme() + + if opts.Interactive && !opts.DeleteLastConfirmed { + // This is not an ideal way of truncating a random string that may + // contain emojis or other kind of wide chars. + truncated := lastComment.Body + if len(lastComment.Body) > 40 { + truncated = lastComment.Body[:40] + "..." + } + + fmt.Fprintf(opts.IO.Out, "%s Deleted comments cannot be recovered.\n", cs.WarningIcon()) + ok, err := opts.ConfirmDeleteLastComment(truncated) + if err != nil { + return err + } + if !ok { + return errDeleteNotConfirmed + } + } + + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + apiClient := api.NewClientFromHTTP(httpClient) + params := api.CommentDeleteInput{CommentId: lastComment.Identifier()} + deletionErr := api.CommentDelete(apiClient, opts.Host, params) + if deletionErr != nil { + return deletionErr + } + + if !opts.Quiet { + fmt.Fprintln(opts.IO.ErrOut, "Comment deleted") + } + + return nil +} + func CommentableConfirmSubmitSurvey(p Prompt) func() (bool, error) { return func() (bool, error) { return p.Confirm("Submit?", true) @@ -271,6 +340,12 @@ func CommentableEditSurvey(cf func() (gh.Config, error), io *iostreams.IOStreams } } +func CommentableConfirmDeleteLastComment(p Prompt) func(string) (bool, error) { + return func(body string) (bool, error) { + return p.Confirm(fmt.Sprintf("Delete the comment: %q?", body), true) + } +} + func waitForEnter(r io.Reader) error { scanner := bufio.NewScanner(r) scanner.Scan() diff --git a/pkg/cmd/pr/shared/completion.go b/pkg/cmd/pr/shared/completion.go index e07abc5a73f..c1296be7177 100644 --- a/pkg/cmd/pr/shared/completion.go +++ b/pkg/cmd/pr/shared/completion.go @@ -21,13 +21,13 @@ func RequestableReviewersForCompletion(httpClient *http.Client, repo ghrepo.Inte results := []string{} for _, user := range metadata.AssignableUsers { - if strings.EqualFold(user.Login, metadata.CurrentLogin) { + if strings.EqualFold(user.Login(), metadata.CurrentLogin) { continue } - if user.Name != "" { - results = append(results, fmt.Sprintf("%s\t%s", user.Login, user.Name)) + if user.Name() != "" { + results = append(results, fmt.Sprintf("%s\t%s", user.Login(), user.Name())) } else { - results = append(results, user.Login) + results = append(results, user.Login()) } } for _, team := range metadata.Teams { diff --git a/pkg/cmd/pr/shared/editable.go b/pkg/cmd/pr/shared/editable.go index 0bebb999ac0..2f51f2ae814 100644 --- a/pkg/cmd/pr/shared/editable.go +++ b/pkg/cmd/pr/shared/editable.go @@ -14,7 +14,7 @@ type Editable struct { Body EditableString Base EditableString Reviewers EditableSlice - Assignees EditableSlice + Assignees EditableAssignees Labels EditableSlice Projects EditableProjects Milestone EditableString @@ -38,6 +38,14 @@ type EditableSlice struct { Allowed bool } +// EditableAssignees is a special case of EditableSlice. +// It contains a flag to indicate whether the assignees are actors or not. +type EditableAssignees struct { + EditableSlice + ActorAssignees bool + DefaultLogins []string // For disambiguating actors from display names +} + // ProjectsV2 mutations require a mapping of an item ID to a project ID. // Keep that map along with standard EditableSlice data. type EditableProjects struct { @@ -105,21 +113,56 @@ func (e Editable) AssigneeIds(client *api.Client, repo ghrepo.Interface) (*[]str if !e.Assignees.Edited { return nil, nil } + + // If assignees came in from command line flags, we need to + // curate the final list of assignees from the default list. if len(e.Assignees.Add) != 0 || len(e.Assignees.Remove) != 0 { meReplacer := NewMeReplacer(client, repo.RepoHost()) - s := set.NewStringSet() - s.AddValues(e.Assignees.Default) - add, err := meReplacer.ReplaceSlice(e.Assignees.Add) + copilotReplacer := NewCopilotReplacer() + + replaceSpecialAssigneeNames := func(value []string) ([]string, error) { + replaced, err := meReplacer.ReplaceSlice(value) + if err != nil { + return nil, err + } + + // Only suppported for actor assignees. + if e.Assignees.ActorAssignees { + replaced = copilotReplacer.ReplaceSlice(replaced) + } + + return replaced, nil + } + + assigneeSet := set.NewStringSet() + + // This check below is required because in a non-interactive flow, + // the user gives us a login and not the DisplayName, and when + // we have actor assignees e.Assignees.Default will contain + // DisplayNames and not logins (this is to accommodate special actor + // display names in the interactive flow). + // So, we need to add the default logins here instead of the DisplayNames. + // Otherwise, the value the user provided won't be found in the + // set to be added or removed, causing unexpected behavior. + if e.Assignees.ActorAssignees { + assigneeSet.AddValues(e.Assignees.DefaultLogins) + } else { + assigneeSet.AddValues(e.Assignees.Default) + } + + add, err := replaceSpecialAssigneeNames(e.Assignees.Add) if err != nil { return nil, err } - s.AddValues(add) - remove, err := meReplacer.ReplaceSlice(e.Assignees.Remove) + assigneeSet.AddValues(add) + + remove, err := replaceSpecialAssigneeNames(e.Assignees.Remove) if err != nil { return nil, err } - s.RemoveValues(remove) - e.Assignees.Value = s.ToSlice() + assigneeSet.RemoveValues(remove) + + e.Assignees.Value = assigneeSet.ToSlice() } a, err := e.Metadata.MembersToIDs(e.Assignees.Value) return &a, err @@ -137,7 +180,7 @@ func (e Editable) ProjectIds() (*[]string, error) { s.RemoveValues(e.Projects.Remove) e.Projects.Value = s.ToSlice() } - p, _, err := e.Metadata.ProjectsToIDs(e.Projects.Value) + p, _, err := e.Metadata.ProjectsTitlesToIDs(e.Projects.Value) return &p, err } @@ -171,14 +214,14 @@ func (e Editable) ProjectV2Ids() (*[]string, *[]string, error) { var err error if addTitles.Len() > 0 { - _, addIds, err = e.Metadata.ProjectsToIDs(addTitles.ToSlice()) + _, addIds, err = e.Metadata.ProjectsTitlesToIDs(addTitles.ToSlice()) if err != nil { return nil, nil, err } } if removeTitles.Len() > 0 { - _, removeIds, err = e.Metadata.ProjectsToIDs(removeTitles.ToSlice()) + _, removeIds, err = e.Metadata.ProjectsTitlesToIDs(removeTitles.ToSlice()) if err != nil { return nil, nil, err } @@ -245,6 +288,14 @@ func (es *EditableSlice) clone() EditableSlice { return cpy } +func (ea *EditableAssignees) clone() EditableAssignees { + return EditableAssignees{ + EditableSlice: ea.EditableSlice.clone(), + ActorAssignees: ea.ActorAssignees, + DefaultLogins: ea.DefaultLogins, + } +} + func (ep *EditableProjects) clone() EditableProjects { return EditableProjects{ EditableSlice: ep.EditableSlice.clone(), @@ -378,12 +429,13 @@ func FieldsToEditSurvey(p EditPrompter, editable *Editable) error { func FetchOptions(client *api.Client, repo ghrepo.Interface, editable *Editable) error { input := api.RepoMetadataInput{ - Reviewers: editable.Reviewers.Edited, - Assignees: editable.Assignees.Edited, - Labels: editable.Labels.Edited, - ProjectsV1: editable.Projects.Edited, - ProjectsV2: editable.Projects.Edited, - Milestones: editable.Milestone.Edited, + Reviewers: editable.Reviewers.Edited, + Assignees: editable.Assignees.Edited, + ActorAssignees: editable.Assignees.ActorAssignees, + Labels: editable.Labels.Edited, + ProjectsV1: editable.Projects.Edited, + ProjectsV2: editable.Projects.Edited, + Milestones: editable.Milestone.Edited, } metadata, err := api.RepoMetadata(client, repo, input) if err != nil { @@ -392,7 +444,11 @@ func FetchOptions(client *api.Client, repo ghrepo.Interface, editable *Editable) var users []string for _, u := range metadata.AssignableUsers { - users = append(users, u.Login) + users = append(users, u.Login()) + } + var actors []string + for _, a := range metadata.AssignableActors { + actors = append(actors, a.DisplayName()) } var teams []string for _, t := range metadata.Teams { @@ -416,7 +472,11 @@ func FetchOptions(client *api.Client, repo ghrepo.Interface, editable *Editable) editable.Metadata = *metadata editable.Reviewers.Options = append(users, teams...) - editable.Assignees.Options = users + if editable.Assignees.ActorAssignees { + editable.Assignees.Options = actors + } else { + editable.Assignees.Options = users + } editable.Labels.Options = labels editable.Projects.Options = projects editable.Milestone.Options = milestones diff --git a/pkg/cmd/pr/shared/editable_http.go b/pkg/cmd/pr/shared/editable_http.go index fcc30095ae4..8cd51c34942 100644 --- a/pkg/cmd/pr/shared/editable_http.go +++ b/pkg/cmd/pr/shared/editable_http.go @@ -60,25 +60,78 @@ func UpdateIssue(httpClient *http.Client, repo ghrepo.Interface, id string, isPR if dirtyExcludingLabels(options) { wg.Go(func() error { - return replaceIssueFields(httpClient, repo, id, isPR, options) + // updateIssue mutation does not support Actors so assignment needs to + // be in a separate request when our assignees are Actors. + // Note: this is intentionally done synchronously with updating + // other issue fields to ensure consistency with how legacy + // user assignees are handled. + // https://github.com/cli/cli/pull/10960#discussion_r2086725348 + if options.Assignees.Edited && options.Assignees.ActorAssignees { + apiClient := api.NewClientFromHTTP(httpClient) + assigneeIds, err := options.AssigneeIds(apiClient, repo) + if err != nil { + return err + } + + err = replaceActorAssigneesForEditable(apiClient, repo, id, assigneeIds) + if err != nil { + return err + } + } + err := replaceIssueFields(httpClient, repo, id, isPR, options) + if err != nil { + return err + } + + return nil }) } return wg.Wait() } -func replaceIssueFields(httpClient *http.Client, repo ghrepo.Interface, id string, isPR bool, options Editable) error { - apiClient := api.NewClientFromHTTP(httpClient) - assigneeIds, err := options.AssigneeIds(apiClient, repo) +func replaceActorAssigneesForEditable(apiClient *api.Client, repo ghrepo.Interface, id string, assigneeIds *[]string) error { + type ReplaceActorsForAssignableInput struct { + AssignableID githubv4.ID `json:"assignableId"` + ActorIDs []githubv4.ID `json:"actorIds"` + } + + params := ReplaceActorsForAssignableInput{ + AssignableID: githubv4.ID(id), + ActorIDs: *ghIds(assigneeIds), + } + + var mutation struct { + ReplaceActorsForAssignable struct { + TypeName string `graphql:"__typename"` + } `graphql:"replaceActorsForAssignable(input: $input)"` + } + + variables := map[string]interface{}{"input": params} + err := apiClient.Mutate(repo.RepoHost(), "ReplaceActorsForAssignable", &mutation, variables) if err != nil { return err } + return nil +} + +func replaceIssueFields(httpClient *http.Client, repo ghrepo.Interface, id string, isPR bool, options Editable) error { + apiClient := api.NewClientFromHTTP(httpClient) + projectIds, err := options.ProjectIds() if err != nil { return err } + var assigneeIds *[]string + if !options.Assignees.ActorAssignees { + assigneeIds, err = options.AssigneeIds(apiClient, repo) + if err != nil { + return err + } + } + milestoneId, err := options.MilestoneId() if err != nil { return err diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index e6bb7d66a63..b8d23c9789d 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -10,12 +10,14 @@ import ( "sort" "strconv" "strings" + "testing" "time" "github.com/cli/cli/v2/api" ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" fd "github.com/cli/cli/v2/internal/featuredetection" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmdutil" o "github.com/cli/cli/v2/pkg/option" @@ -54,9 +56,9 @@ type finder struct { } func NewFinder(factory *cmdutil.Factory) PRFinder { - if runCommandFinder != nil { - f := runCommandFinder - runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")} + if finderForRunCommandStyleTests != nil { + f := finderForRunCommandStyleTests + finderForRunCommandStyleTests = &mockFinder{err: errors.New("you must use StubFinderForRunCommandStyleTests to stub PR lookups")} return f } @@ -70,12 +72,23 @@ func NewFinder(factory *cmdutil.Factory) PRFinder { } } -var runCommandFinder PRFinder +var finderForRunCommandStyleTests PRFinder -// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests. -func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { +// StubFinderForRunCommandStyleTests is the NewMockFinder substitute to be used ONLY in runCommand-style tests. +func StubFinderForRunCommandStyleTests(t *testing.T, selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { + // Create a new mock finder and override the "runCommandFinder" variable so that calls to + // NewFinder() will return this mock. This is a bad pattern, and a result of old style runCommand + // tests that would ideally be replaced. The reason we need to do this is that the runCommand style tests + // construct the cobra command via NewCmd* functions, and then Execute them directly, providing no opportunity + // to inject a test double unless it's on the factory, which finder never is, because only PR commands need it. finder := NewMockFinder(selector, pr, repo) - runCommandFinder = finder + finderForRunCommandStyleTests = finder + + // Ensure that at the end of the test, we reset the "runCommandFinder" variable so that tests are isolated, + // at least if they are run sequentially. + t.Cleanup(func() { + finderForRunCommandStyleTests = nil + }) return finder } @@ -89,6 +102,8 @@ type FindOptions struct { BaseBranch string // States lists the possible PR states to scope the PR-for-branch lookup to. States []string + + Detector fd.Detector } func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { @@ -193,9 +208,11 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err fields.AddValues([]string{"id", "number"}) // for additional preload queries below if fields.Contains("isInMergeQueue") || fields.Contains("isMergeQueueEnabled") { - cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) - detector := fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost()) - prFeatures, err := detector.PullRequestFeatures() + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost()) + } + prFeatures, err := opts.Detector.PullRequestFeatures() if err != nil { return nil, nil, err } @@ -211,8 +228,23 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err fields.Remove("projectItems") } + // TODO projectsV1Deprecation + // Remove this block + // When removing this, remember to remove `projectCards` from the list of default fields in pr/view.go + if fields.Contains("projectCards") { + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, f.baseRefRepo.RepoHost()) + } + + if opts.Detector.ProjectsV1() == gh.ProjectsV1Unsupported { + fields.Remove("projectCards") + } + } + var pr *api.PullRequest if f.prNumber > 0 { + // If we have a PR number, let's look it up if numberFieldOnly { // avoid hitting the API if we already have all the information return &api.PullRequest{Number: f.prNumber}, f.baseRefRepo, nil @@ -221,11 +253,16 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err if err != nil { return pr, f.baseRefRepo, err } - } else { + } else if prRefs.BaseRepo() != nil && f.branchName != "" { + // No PR number, but we have a base repo and branch name. pr, err = findForRefs(httpClient, prRefs, opts.States, fields.ToSlice()) if err != nil { return pr, f.baseRefRepo, err } + } else { + // If we don't have a PR number or a base repo and branch name, + // we can't do anything + return nil, f.baseRefRepo, &NotFoundError{fmt.Errorf("no pull requests found")} } g, _ := errgroup.WithContext(context.Background()) diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index e1aae16b114..abc754d1af5 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -165,6 +165,23 @@ func TestFind(t *testing.T) { wantPR: 13, wantRepo: "https://github.com/ORIGINOWNER/REPO", }, + { + name: "pr number zero", + args: args{ + selector: "0", + fields: []string{"number"}, + baseRepoFn: stubBaseRepoFn(ghrepo.New("ORIGINOWNER", "REPO"), nil), + branchFn: func() (string, error) { + return "blueberries", nil + }, + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, + }, + wantErr: true, + }, { name: "number with hash argument", args: args{ diff --git a/pkg/cmd/pr/shared/params.go b/pkg/cmd/pr/shared/params.go index 4f36a80aaa5..1fa45652abd 100644 --- a/pkg/cmd/pr/shared/params.go +++ b/pkg/cmd/pr/shared/params.go @@ -36,7 +36,7 @@ func WithPrAndIssueQueryParams(client *api.Client, baseRepo ghrepo.Interface, ba q.Set("labels", strings.Join(state.Labels, ",")) } if len(state.ProjectTitles) > 0 { - projectPaths, err := api.ProjectNamesToPaths(client, baseRepo, state.ProjectTitles, projectsV1Support) + projectPaths, err := api.ProjectTitlesToPaths(client, baseRepo, state.ProjectTitles, projectsV1Support) if err != nil { return "", fmt.Errorf("could not add to project: %w", err) } @@ -119,7 +119,7 @@ func AddMetadataToIssueParams(client *api.Client, baseRepo ghrepo.Interface, par } params["labelIds"] = labelIDs - projectIDs, projectV2IDs, err := tb.MetadataResult.ProjectsToIDs(tb.ProjectTitles) + projectIDs, projectV2IDs, err := tb.MetadataResult.ProjectsTitlesToIDs(tb.ProjectTitles) if err != nil { return fmt.Errorf("could not add to project: %w", err) } @@ -312,3 +312,26 @@ func (r *MeReplacer) ReplaceSlice(handles []string) ([]string, error) { } return res, nil } + +// CopilotReplacer resolves usages of `@copilot` to Copilot's login. +type CopilotReplacer struct{} + +func NewCopilotReplacer() *CopilotReplacer { + return &CopilotReplacer{} +} + +func (r *CopilotReplacer) replace(handle string) string { + if strings.EqualFold(handle, "@copilot") { + return api.CopilotActorLogin + } + return handle +} + +// ReplaceSlice replaces usages of `@copilot` in a slice with Copilot's login. +func (r *CopilotReplacer) ReplaceSlice(handles []string) []string { + res := make([]string, len(handles)) + for i, h := range handles { + res[i] = r.replace(h) + } + return res +} diff --git a/pkg/cmd/pr/shared/params_test.go b/pkg/cmd/pr/shared/params_test.go index 15f00ca4f22..53eb6328fb6 100644 --- a/pkg/cmd/pr/shared/params_test.go +++ b/pkg/cmd/pr/shared/params_test.go @@ -187,6 +187,67 @@ func TestMeReplacer_Replace(t *testing.T) { } } +func TestCopilotReplacer_ReplaceSlice(t *testing.T) { + type args struct { + handles []string + } + tests := []struct { + name string + args args + want []string + }{ + { + name: "replaces @copilot with copilot-swe-agent", + args: args{ + handles: []string{"monalisa", "@copilot", "hubot"}, + }, + want: []string{"monalisa", "copilot-swe-agent", "hubot"}, + }, + { + name: "handles no @copilot mentions", + args: args{ + handles: []string{"monalisa", "user", "hubot"}, + }, + want: []string{"monalisa", "user", "hubot"}, + }, + { + name: "replaces multiple @copilot mentions", + args: args{ + handles: []string{"@copilot", "user", "@copilot"}, + }, + want: []string{"copilot-swe-agent", "user", "copilot-swe-agent"}, + }, + { + name: "handles @copilot case-insensitively", + args: args{ + handles: []string{"@Copilot", "user", "@CoPiLoT"}, + }, + want: []string{"copilot-swe-agent", "user", "copilot-swe-agent"}, + }, + { + name: "handles nil slice", + args: args{ + handles: nil, + }, + want: []string{}, + }, + { + name: "handles empty slice", + args: args{ + handles: []string{}, + }, + want: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := NewCopilotReplacer() + got := r.ReplaceSlice(tt.args.handles) + require.Equal(t, tt.want, got) + }) + } +} + func Test_QueryHasStateClause(t *testing.T) { tests := []struct { searchQuery string diff --git a/pkg/cmd/pr/shared/survey.go b/pkg/cmd/pr/shared/survey.go index bf4476ca1ed..b6c927a2d9b 100644 --- a/pkg/cmd/pr/shared/survey.go +++ b/pkg/cmd/pr/shared/survey.go @@ -192,7 +192,7 @@ func MetadataSurvey(p Prompt, io *iostreams.IOStreams, baseRepo ghrepo.Interface var reviewers []string for _, u := range metadataResult.AssignableUsers { - if u.Login != metadataResult.CurrentLogin { + if u.Login() != metadataResult.CurrentLogin { reviewers = append(reviewers, u.DisplayName()) } } diff --git a/pkg/cmd/pr/shared/survey_test.go b/pkg/cmd/pr/shared/survey_test.go index 6895b52ac99..7097d0761d4 100644 --- a/pkg/cmd/pr/shared/survey_test.go +++ b/pkg/cmd/pr/shared/survey_test.go @@ -28,9 +28,9 @@ func TestMetadataSurvey_selectAll(t *testing.T) { fetcher := &metadataFetcher{ metadataResult: &api.RepoMetadataResult{ - AssignableUsers: []api.RepoAssignee{ - {Login: "hubot"}, - {Login: "monalisa"}, + AssignableUsers: []api.AssignableUser{ + api.NewAssignableUser("", "hubot", ""), + api.NewAssignableUser("", "monalisa", ""), }, Labels: []api.RepoLabel{ {Name: "help wanted"}, diff --git a/pkg/cmd/pr/view/view.go b/pkg/cmd/pr/view/view.go index 997f74d877d..8a39d113463 100644 --- a/pkg/cmd/pr/view/view.go +++ b/pkg/cmd/pr/view/view.go @@ -10,6 +10,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmd/pr/shared" @@ -22,6 +23,9 @@ import ( type ViewOptions struct { IO *iostreams.IOStreams Browser browser.Browser + // TODO projectsV1Deprecation + // Remove this detector since it is only used for test validation. + Detector fd.Detector Finder shared.PRFinder Exporter cmdutil.Exporter @@ -89,6 +93,7 @@ func viewRun(opts *ViewOptions) error { findOptions := shared.FindOptions{ Selector: opts.SelectorArg, Fields: defaultFields, + Detector: opts.Detector, } if opts.BrowserMode { findOptions.Fields = []string{"url"} diff --git a/pkg/cmd/pr/view/view_test.go b/pkg/cmd/pr/view/view_test.go index 2cd4066b84b..35f7fa5136c 100644 --- a/pkg/cmd/pr/view/view_test.go +++ b/pkg/cmd/pr/view/view_test.go @@ -12,6 +12,7 @@ import ( "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/run" "github.com/cli/cli/v2/pkg/cmd/pr/shared" @@ -176,6 +177,9 @@ func runCommand(rt http.RoundTripper, branch string, isTTY bool, cli string) (*t factory := &cmdutil.Factory{ IOStreams: ios, Browser: browser, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: rt}, nil + }, } cmd := NewCmdView(factory, nil) @@ -404,7 +408,7 @@ func TestPRView_Preview_nontty(t *testing.T) { pr, err := prFromFixtures(tc.fixtures) require.NoError(t, err) - shared.RunCommandFinder("12", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "12", pr, ghrepo.New("OWNER", "REPO")) output, err := runCommand(http, tc.branch, false, tc.args) if err != nil { @@ -608,7 +612,7 @@ func TestPRView_Preview(t *testing.T) { pr, err := prFromFixtures(tc.fixtures) require.NoError(t, err) - shared.RunCommandFinder("12", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "12", pr, ghrepo.New("OWNER", "REPO")) output, err := runCommand(http, tc.branch, true, tc.args) if err != nil { @@ -631,7 +635,7 @@ func TestPRView_web_currentBranch(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", &api.PullRequest{URL: "https://github.com/OWNER/REPO/pull/10"}, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "", &api.PullRequest{URL: "https://github.com/OWNER/REPO/pull/10"}, ghrepo.New("OWNER", "REPO")) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -650,7 +654,7 @@ func TestPRView_web_noResultsForBranch(t *testing.T) { http := &httpmock.Registry{} defer http.Verify(t) - shared.RunCommandFinder("", nil, nil) + shared.StubFinderForRunCommandStyleTests(t, "", nil, nil) _, cmdTeardown := run.Stub() defer cmdTeardown(t) @@ -742,9 +746,9 @@ func TestPRView_tty_Comments(t *testing.T) { if len(tt.fixtures) > 0 { pr, err := prFromFixtures(tt.fixtures) require.NoError(t, err) - shared.RunCommandFinder("123", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, ghrepo.New("OWNER", "REPO")) } else { - shared.RunCommandFinder("123", nil, nil) + shared.StubFinderForRunCommandStyleTests(t, "123", nil, nil) } output, err := runCommand(http, tt.branch, true, tt.cli) @@ -853,9 +857,9 @@ func TestPRView_nontty_Comments(t *testing.T) { if len(tt.fixtures) > 0 { pr, err := prFromFixtures(tt.fixtures) require.NoError(t, err) - shared.RunCommandFinder("123", pr, ghrepo.New("OWNER", "REPO")) + shared.StubFinderForRunCommandStyleTests(t, "123", pr, ghrepo.New("OWNER", "REPO")) } else { - shared.RunCommandFinder("123", nil, nil) + shared.StubFinderForRunCommandStyleTests(t, "123", nil, nil) } output, err := runCommand(http, tt.branch, false, tt.cli) @@ -870,3 +874,74 @@ func TestPRView_nontty_Comments(t *testing.T) { }) } } + +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + t.Run("when projects v1 is supported, is included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Register( + httpmock.GraphQL(`projectCards`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + f := &cmdutil.Factory{ + IOStreams: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = viewRun(&ViewOptions{ + IO: ios, + Finder: shared.NewFinder(f), + Detector: &fd.EnabledDetectorMock{}, + + SelectorArg: "https://github.com/cli/cli/pull/123", + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, is not included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Exclude( + t, + httpmock.GraphQL(`projectCards`), + ) + + f := &cmdutil.Factory{ + IOStreams: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = viewRun(&ViewOptions{ + IO: ios, + Finder: shared.NewFinder(f), + Detector: &fd.DisabledDetectorMock{}, + + SelectorArg: "https://github.com/cli/cli/pull/123", + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) +} diff --git a/pkg/cmd/release/download/download.go b/pkg/cmd/release/download/download.go index cdf6135b633..b907214125b 100644 --- a/pkg/cmd/release/download/download.go +++ b/pkg/cmd/release/download/download.go @@ -165,10 +165,24 @@ func downloadRun(opts *DownloadOptions) error { var toDownload []shared.ReleaseAsset isArchive := false if opts.ArchiveType != "" { - var archiveURL = release.ZipballURL + var archiveURL string if opts.ArchiveType == "tar.gz" { archiveURL = release.TarballURL + } else { + archiveURL = release.ZipballURL + } + + if archiveURL == "" { + errMessage := fmt.Sprintf( + "release %q with tag %q, does not have a %q archive asset.", + release.Name, release.TagName, opts.ArchiveType, + ) + if release.IsDraft { + errMessage += " Most likely, this is because it is a draft." + } + return errors.New(errMessage) } + // create pseudo-Asset with no name and pointing to ZipBallURL or TarBallURL toDownload = append(toDownload, shared.ReleaseAsset{APIURL: archiveURL}) isArchive = true diff --git a/pkg/cmd/release/download/download_test.go b/pkg/cmd/release/download/download_test.go index 78709dd578b..9337c9b6561 100644 --- a/pkg/cmd/release/download/download_test.go +++ b/pkg/cmd/release/download/download_test.go @@ -183,6 +183,7 @@ func Test_downloadRun(t *testing.T) { name string isTTY bool opts DownloadOptions + httpStubs func(*httpmock.Registry) wantErr string wantStdout string wantStderr string @@ -196,6 +197,24 @@ func Test_downloadRun(t *testing.T) { Destination: ".", Concurrency: 2, }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + + reg.Register(httpmock.REST("GET", "assets/1234"), httpmock.StringResponse(`1234`)) + reg.Register(httpmock.REST("GET", "assets/3456"), httpmock.StringResponse(`3456`)) + reg.Register(httpmock.REST("GET", "assets/5678"), httpmock.StringResponse(`5678`)) + }, wantStdout: ``, wantStderr: ``, wantFiles: []string{ @@ -213,6 +232,23 @@ func Test_downloadRun(t *testing.T) { Destination: "tmp/assets", Concurrency: 2, }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + + reg.Register(httpmock.REST("GET", "assets/1234"), httpmock.StringResponse(`1234`)) + reg.Register(httpmock.REST("GET", "assets/3456"), httpmock.StringResponse(`3456`)) + }, wantStdout: ``, wantStderr: ``, wantFiles: []string{ @@ -229,6 +265,20 @@ func Test_downloadRun(t *testing.T) { Destination: ".", Concurrency: 2, }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + }, wantStdout: ``, wantStderr: ``, wantErr: "no assets match the file pattern", @@ -242,6 +292,30 @@ func Test_downloadRun(t *testing.T) { Destination: "tmp/packages", Concurrency: 2, }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + + reg.Register( + httpmock.REST( + "GET", + "repos/OWNER/REPO/zipball/v1.2.3", + ), + httpmock.WithHeader( + httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=zipball.zip", + ), + ) + }, wantStdout: ``, wantStderr: ``, wantFiles: []string{ @@ -257,6 +331,30 @@ func Test_downloadRun(t *testing.T) { Destination: "tmp/packages", Concurrency: 2, }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + + reg.Register( + httpmock.REST( + "GET", + "repos/OWNER/REPO/tarball/v1.2.3", + ), + httpmock.WithHeader( + httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=tarball.tgz", + ), + ) + }, wantStdout: ``, wantStderr: ``, wantFiles: []string{ @@ -273,6 +371,30 @@ func Test_downloadRun(t *testing.T) { Concurrency: 2, ArchiveType: "tar.gz", }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + + reg.Register( + httpmock.REST( + "GET", + "repos/OWNER/REPO/tarball/v1.2.3", + ), + httpmock.WithHeader( + httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=tarball.tgz", + ), + ) + }, wantStdout: ``, wantStderr: ``, wantFiles: []string{ @@ -289,6 +411,22 @@ func Test_downloadRun(t *testing.T) { Concurrency: 2, FilePatterns: []string{"*windows-32bit.zip"}, }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + + reg.Register(httpmock.REST("GET", "assets/1234"), httpmock.StringResponse(`1234`)) + }, wantStdout: ``, wantStderr: ``, wantFiles: []string{ @@ -305,9 +443,85 @@ func Test_downloadRun(t *testing.T) { Concurrency: 2, FilePatterns: []string{"*windows-32bit.zip"}, }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", + "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" + }`) + + reg.Register(httpmock.REST("GET", "assets/1234"), httpmock.StringResponse(`1234`)) + }, wantStdout: `1234`, wantStderr: ``, }, + { + name: "draft release with null tarball_url and zipball_url", + isTTY: true, + opts: DownloadOptions{ + TagName: "v1.2.3", + ArchiveType: "tar.gz", + Destination: "tmp/packages", + Concurrency: 2, + }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "tag_name": "v1.2.3", + "name": "patch-36", + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": null, + "zipball_url": null, + "draft": true + }`) + }, + wantStdout: ``, + wantStderr: ``, + wantErr: "release \"patch-36\" with tag \"v1.2.3\", does not have a \"tar.gz\" archive asset. Most likely, this is because it is a draft.", + }, + { + name: "non-draft release with null tarball_url and zipball_url", + isTTY: true, + opts: DownloadOptions{ + TagName: "v1.2.3", + ArchiveType: "tar.gz", + Destination: "tmp/packages", + Concurrency: 2, + }, + httpStubs: func(reg *httpmock.Registry) { + shared.StubFetchRelease(t, reg, "OWNER", "REPO", "v1.2.3", `{ + "tag_name": "v1.2.3", + "name": "patch-36", + "assets": [ + { "name": "windows-32bit.zip", "size": 12, + "url": "https://api.github.com/assets/1234" }, + { "name": "windows-64bit.zip", "size": 34, + "url": "https://api.github.com/assets/3456" }, + { "name": "linux.tgz", "size": 56, + "url": "https://api.github.com/assets/5678" } + ], + "tarball_url": null, + "zipball_url": null, + "draft": false + }`) + }, + wantStdout: ``, + wantStderr: ``, + wantErr: "release \"patch-36\" with tag \"v1.2.3\", does not have a \"tar.gz\" archive asset.", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -324,41 +538,11 @@ func Test_downloadRun(t *testing.T) { ios.SetStderrTTY(tt.isTTY) fakeHTTP := &httpmock.Registry{} - shared.StubFetchRelease(t, fakeHTTP, "OWNER", "REPO", tt.opts.TagName, `{ - "assets": [ - { "name": "windows-32bit.zip", "size": 12, - "url": "https://api.github.com/assets/1234" }, - { "name": "windows-64bit.zip", "size": 34, - "url": "https://api.github.com/assets/3456" }, - { "name": "linux.tgz", "size": 56, - "url": "https://api.github.com/assets/5678" } - ], - "tarball_url": "https://api.github.com/repos/OWNER/REPO/tarball/v1.2.3", - "zipball_url": "https://api.github.com/repos/OWNER/REPO/zipball/v1.2.3" - }`) - fakeHTTP.Register(httpmock.REST("GET", "assets/1234"), httpmock.StringResponse(`1234`)) - fakeHTTP.Register(httpmock.REST("GET", "assets/3456"), httpmock.StringResponse(`3456`)) - fakeHTTP.Register(httpmock.REST("GET", "assets/5678"), httpmock.StringResponse(`5678`)) - - fakeHTTP.Register( - httpmock.REST( - "GET", - "repos/OWNER/REPO/tarball/v1.2.3", - ), - httpmock.WithHeader( - httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=tarball.tgz", - ), - ) - - fakeHTTP.Register( - httpmock.REST( - "GET", - "repos/OWNER/REPO/zipball/v1.2.3", - ), - httpmock.WithHeader( - httpmock.StringResponse("somedata"), "content-disposition", "attachment; filename=zipball.zip", - ), - ) + defer fakeHTTP.Verify(t) + + if tt.httpStubs != nil { + tt.httpStubs(fakeHTTP) + } tt.opts.IO = ios tt.opts.HttpClient = func() (*http.Client, error) { diff --git a/pkg/cmd/repo/autolink/delete/http_test.go b/pkg/cmd/repo/autolink/delete/http_test.go index a2676178db4..a0aec5e131a 100644 --- a/pkg/cmd/repo/autolink/delete/http_test.go +++ b/pkg/cmd/repo/autolink/delete/http_test.go @@ -7,6 +7,7 @@ import ( "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" + "github.com/cli/go-gh/v2/pkg/api" "github.com/stretchr/testify/require" ) @@ -14,10 +15,10 @@ func TestAutolinkDeleter_Delete(t *testing.T) { repo := ghrepo.New("OWNER", "REPO") tests := []struct { - name string - id string - stubStatus int - stubRespJSON string + name string + id string + stubStatus int + stubResp any expectErr bool expectedErrMsg string @@ -31,17 +32,18 @@ func TestAutolinkDeleter_Delete(t *testing.T) { name: "404 repo or autolink not found", id: "123", stubStatus: http.StatusNotFound, - stubRespJSON: `{}`, // API response not used in output expectErr: true, expectedErrMsg: "error deleting autolink: HTTP 404: Perhaps you are missing admin rights to the repository? (https://api.github.com/repos/OWNER/REPO/autolinks/123)", }, { - name: "500 unexpected error", - id: "123", - stubRespJSON: `{"messsage": "arbitrary error"}`, + name: "500 unexpected error", + id: "123", + stubResp: api.HTTPError{ + Message: "arbitrary error", + }, stubStatus: http.StatusInternalServerError, expectErr: true, - expectedErrMsg: "HTTP 500 (https://api.github.com/repos/OWNER/REPO/autolinks/123)", + expectedErrMsg: "HTTP 500: arbitrary error (https://api.github.com/repos/OWNER/REPO/autolinks/123)", }, } @@ -53,7 +55,7 @@ func TestAutolinkDeleter_Delete(t *testing.T) { http.MethodDelete, fmt.Sprintf("repos/%s/%s/autolinks/%s", repo.RepoOwner(), repo.RepoName(), tt.id), ), - httpmock.StatusJSONResponse(tt.stubStatus, tt.stubRespJSON), + httpmock.StatusJSONResponse(tt.stubStatus, tt.stubResp), ) defer reg.Verify(t) diff --git a/pkg/cmd/run/watch/watch_test.go b/pkg/cmd/run/watch/watch_test.go index d42e8d3d80c..49e56217b41 100644 --- a/pkg/cmd/run/watch/watch_test.go +++ b/pkg/cmd/run/watch/watch_test.go @@ -316,7 +316,7 @@ func TestWatchRun(t *testing.T) { ) reg.Register( httpmock.REST("GET", "repos/OWNER/REPO/actions/runs/1234"), - httpmock.StatusJSONResponse(404, api.HTTPError{ + httpmock.JSONErrorResponse(404, api.HTTPError{ StatusCode: 404, Message: "run 1234 not found", }), diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go index 745c1241743..3b03ae718fd 100644 --- a/pkg/httpmock/stub.go +++ b/pkg/httpmock/stub.go @@ -9,6 +9,8 @@ import ( "os" "regexp" "strings" + + "github.com/cli/go-gh/v2/pkg/api" ) type Matcher func(req *http.Request) bool @@ -161,6 +163,9 @@ func JSONResponse(body interface{}) Responder { } } +// StatusJSONResponse turns the given argument into a JSON response. +// +// The argument is not meant to be a JSON string, unless it's intentional. func StatusJSONResponse(status int, body interface{}) Responder { return func(req *http.Request) (*http.Response, error) { b, _ := json.Marshal(body) @@ -171,6 +176,12 @@ func StatusJSONResponse(status int, body interface{}) Responder { } } +// JSONErrorResponse is a type-safe helper to avoid confusion around the +// provided argument. +func JSONErrorResponse(status int, err api.HTTPError) Responder { + return StatusJSONResponse(status, err) +} + func FileResponse(filename string) Responder { return func(req *http.Request) (*http.Response, error) { f, err := os.Open(filename) diff --git a/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh b/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh index 647a13a4c9c..cea3c72286d 100644 --- a/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh +++ b/test/integration/attestation-cmd/verify/verify-with-internal-github-sigstore.sh @@ -14,3 +14,9 @@ if ! $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owne echo "Failed to verify" exit 1 fi + +# Try to verify when specifying a predicate type that does not match the attestation +if $ghBuildPath attestation verify "$ghCLIArtifact" --digest-alg=sha256 --owner=cli --predicate-type=my-custom-predicate-type; then + echo "Verification should have failed" + exit 1 +fi