Skip to content

Commit 33a63a0

Browse files
authored
Add support for org-level discussions in list_discussion_categories tool (#819)
* hide implementation detail for org-level queries * update tests * autogen * made tests consistent with other tests
1 parent ff6e859 commit 33a63a0

File tree

3 files changed

+124
-43
lines changed

3 files changed

+124
-43
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ The following sets of tools are available (all are on by default):
458458

459459
- **list_discussion_categories** - List discussion categories
460460
- `owner`: Repository owner (string, required)
461-
- `repo`: Repository name (string, required)
461+
- `repo`: Repository name. If not provided, discussion categories will be queried at the organisation level. (string, optional)
462462

463463
- **list_discussions** - List discussions
464464
- `after`: Cursor for pagination. Use the endCursor from the previous page's PageInfo for GraphQL APIs. (string, optional)

pkg/github/discussions.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ func GetDiscussionComments(getGQLClient GetGQLClientFn, t translations.Translati
443443

444444
func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
445445
return mcp.NewTool("list_discussion_categories",
446-
mcp.WithDescription(t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository")),
446+
mcp.WithDescription(t("TOOL_LIST_DISCUSSION_CATEGORIES_DESCRIPTION", "List discussion categories with their id and name, for a repository or organisation.")),
447447
mcp.WithToolAnnotation(mcp.ToolAnnotation{
448448
Title: t("TOOL_LIST_DISCUSSION_CATEGORIES_USER_TITLE", "List discussion categories"),
449449
ReadOnlyHint: ToBoolPtr(true),
@@ -453,19 +453,23 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl
453453
mcp.Description("Repository owner"),
454454
),
455455
mcp.WithString("repo",
456-
mcp.Required(),
457-
mcp.Description("Repository name"),
456+
mcp.Description("Repository name. If not provided, discussion categories will be queried at the organisation level."),
458457
),
459458
),
460459
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
461-
// Decode params
462-
var params struct {
463-
Owner string
464-
Repo string
460+
owner, err := RequiredParam[string](request, "owner")
461+
if err != nil {
462+
return mcp.NewToolResultError(err.Error()), nil
465463
}
466-
if err := mapstructure.Decode(request.Params.Arguments, &params); err != nil {
464+
repo, err := OptionalParam[string](request, "repo")
465+
if err != nil {
467466
return mcp.NewToolResultError(err.Error()), nil
468467
}
468+
// when not provided, default to the .github repository
469+
// this will query discussion categories at the organisation level
470+
if repo == "" {
471+
repo = ".github"
472+
}
469473

470474
client, err := getGQLClient(ctx)
471475
if err != nil {
@@ -490,8 +494,8 @@ func ListDiscussionCategories(getGQLClient GetGQLClientFn, t translations.Transl
490494
} `graphql:"repository(owner: $owner, name: $repo)"`
491495
}
492496
vars := map[string]interface{}{
493-
"owner": githubv4.String(params.Owner),
494-
"repo": githubv4.String(params.Repo),
497+
"owner": githubv4.String(owner),
498+
"repo": githubv4.String(repo),
495499
"first": githubv4.Int(25),
496500
}
497501
if err := client.Query(ctx, &q, vars); err != nil {

pkg/github/discussions_test.go

Lines changed: 109 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ func Test_GetDiscussion(t *testing.T) {
484484
assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner", "repo", "discussionNumber"})
485485

486486
// Use exact string query that matches implementation output
487-
qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}"
487+
qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,url,category{name}}}}"
488488

489489
vars := map[string]interface{}{
490490
"owner": "owner",
@@ -638,17 +638,33 @@ func Test_GetDiscussionComments(t *testing.T) {
638638
}
639639

640640
func Test_ListDiscussionCategories(t *testing.T) {
641+
mockClient := githubv4.NewClient(nil)
642+
toolDef, _ := ListDiscussionCategories(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper)
643+
assert.Equal(t, "list_discussion_categories", toolDef.Name)
644+
assert.NotEmpty(t, toolDef.Description)
645+
assert.Contains(t, toolDef.Description, "or organisation")
646+
assert.Contains(t, toolDef.InputSchema.Properties, "owner")
647+
assert.Contains(t, toolDef.InputSchema.Properties, "repo")
648+
assert.ElementsMatch(t, toolDef.InputSchema.Required, []string{"owner"})
649+
641650
// Use exact string query that matches implementation output
642651
qListCategories := "query($first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussionCategories(first: $first){nodes{id,name},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}"
643652

644-
// Variables matching what GraphQL receives after JSON marshaling/unmarshaling
645-
vars := map[string]interface{}{
653+
// Variables for repository-level categories
654+
varsRepo := map[string]interface{}{
646655
"owner": "owner",
647656
"repo": "repo",
648657
"first": float64(25),
649658
}
650659

651-
mockResp := githubv4mock.DataResponse(map[string]any{
660+
// Variables for organization-level categories (using .github repo)
661+
varsOrg := map[string]interface{}{
662+
"owner": "owner",
663+
"repo": ".github",
664+
"first": float64(25),
665+
}
666+
667+
mockRespRepo := githubv4mock.DataResponse(map[string]any{
652668
"repository": map[string]any{
653669
"discussionCategories": map[string]any{
654670
"nodes": []map[string]any{
@@ -665,37 +681,98 @@ func Test_ListDiscussionCategories(t *testing.T) {
665681
},
666682
},
667683
})
668-
matcher := githubv4mock.NewQueryMatcher(qListCategories, vars, mockResp)
669-
httpClient := githubv4mock.NewMockedHTTPClient(matcher)
670-
gqlClient := githubv4.NewClient(httpClient)
671684

672-
tool, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper)
673-
assert.Equal(t, "list_discussion_categories", tool.Name)
674-
assert.NotEmpty(t, tool.Description)
675-
assert.Contains(t, tool.InputSchema.Properties, "owner")
676-
assert.Contains(t, tool.InputSchema.Properties, "repo")
677-
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo"})
685+
mockRespOrg := githubv4mock.DataResponse(map[string]any{
686+
"repository": map[string]any{
687+
"discussionCategories": map[string]any{
688+
"nodes": []map[string]any{
689+
{"id": "789", "name": "Announcements"},
690+
{"id": "101", "name": "General"},
691+
{"id": "112", "name": "Ideas"},
692+
},
693+
"pageInfo": map[string]any{
694+
"hasNextPage": false,
695+
"hasPreviousPage": false,
696+
"startCursor": "",
697+
"endCursor": "",
698+
},
699+
"totalCount": 3,
700+
},
701+
},
702+
})
678703

679-
request := createMCPRequest(map[string]interface{}{"owner": "owner", "repo": "repo"})
680-
result, err := handler(context.Background(), request)
681-
require.NoError(t, err)
704+
tests := []struct {
705+
name string
706+
reqParams map[string]interface{}
707+
vars map[string]interface{}
708+
mockResponse githubv4mock.GQLResponse
709+
expectError bool
710+
expectedCount int
711+
expectedCategories []map[string]string
712+
}{
713+
{
714+
name: "list repository-level discussion categories",
715+
reqParams: map[string]interface{}{
716+
"owner": "owner",
717+
"repo": "repo",
718+
},
719+
vars: varsRepo,
720+
mockResponse: mockRespRepo,
721+
expectError: false,
722+
expectedCount: 2,
723+
expectedCategories: []map[string]string{
724+
{"id": "123", "name": "CategoryOne"},
725+
{"id": "456", "name": "CategoryTwo"},
726+
},
727+
},
728+
{
729+
name: "list org-level discussion categories (no repo provided)",
730+
reqParams: map[string]interface{}{
731+
"owner": "owner",
732+
// repo is not provided, it will default to ".github"
733+
},
734+
vars: varsOrg,
735+
mockResponse: mockRespOrg,
736+
expectError: false,
737+
expectedCount: 3,
738+
expectedCategories: []map[string]string{
739+
{"id": "789", "name": "Announcements"},
740+
{"id": "101", "name": "General"},
741+
{"id": "112", "name": "Ideas"},
742+
},
743+
},
744+
}
682745

683-
text := getTextResult(t, result).Text
746+
for _, tc := range tests {
747+
t.Run(tc.name, func(t *testing.T) {
748+
matcher := githubv4mock.NewQueryMatcher(qListCategories, tc.vars, tc.mockResponse)
749+
httpClient := githubv4mock.NewMockedHTTPClient(matcher)
750+
gqlClient := githubv4.NewClient(httpClient)
684751

685-
var response struct {
686-
Categories []map[string]string `json:"categories"`
687-
PageInfo struct {
688-
HasNextPage bool `json:"hasNextPage"`
689-
HasPreviousPage bool `json:"hasPreviousPage"`
690-
StartCursor string `json:"startCursor"`
691-
EndCursor string `json:"endCursor"`
692-
} `json:"pageInfo"`
693-
TotalCount int `json:"totalCount"`
752+
_, handler := ListDiscussionCategories(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper)
753+
754+
req := createMCPRequest(tc.reqParams)
755+
res, err := handler(context.Background(), req)
756+
text := getTextResult(t, res).Text
757+
758+
if tc.expectError {
759+
require.True(t, res.IsError)
760+
return
761+
}
762+
require.NoError(t, err)
763+
764+
var response struct {
765+
Categories []map[string]string `json:"categories"`
766+
PageInfo struct {
767+
HasNextPage bool `json:"hasNextPage"`
768+
HasPreviousPage bool `json:"hasPreviousPage"`
769+
StartCursor string `json:"startCursor"`
770+
EndCursor string `json:"endCursor"`
771+
} `json:"pageInfo"`
772+
TotalCount int `json:"totalCount"`
773+
}
774+
require.NoError(t, json.Unmarshal([]byte(text), &response))
775+
assert.Equal(t, tc.expectedCategories, response.Categories)
776+
})
694777
}
695-
require.NoError(t, json.Unmarshal([]byte(text), &response))
696-
assert.Len(t, response.Categories, 2)
697-
assert.Equal(t, "123", response.Categories[0]["id"])
698-
assert.Equal(t, "CategoryOne", response.Categories[0]["name"])
699-
assert.Equal(t, "456", response.Categories[1]["id"])
700-
assert.Equal(t, "CategoryTwo", response.Categories[1]["name"])
701778
}

0 commit comments

Comments
 (0)