diff --git a/pkg/github/discussions.go b/pkg/github/discussions.go index 68ed014b2..ab631ecb2 100644 --- a/pkg/github/discussions.go +++ b/pkg/github/discussions.go @@ -11,7 +11,6 @@ import ( "github.com/github/github-mcp-server/pkg/scopes" "github.com/github/github-mcp-server/pkg/translations" "github.com/github/github-mcp-server/pkg/utils" - "github.com/go-viper/mapstructure/v2" "github.com/google/go-github/v87/github" "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -313,15 +312,19 @@ func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool { }, []scopes.Scope{scopes.Repo}, func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Decode params - var params struct { - Owner string - Repo string - DiscussionNumber int32 + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - if err := mapstructure.WeakDecode(args, ¶ms); err != nil { + repo, err := RequiredParam[string](args, "repo") + if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } + discussionNumber, err := RequiredInt(args, "discussionNumber") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + client, err := deps.GetGQLClient(ctx) if err != nil { return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil @@ -345,9 +348,9 @@ func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool { } `graphql:"repository(owner: $owner, name: $repo)"` } vars := map[string]any{ - "owner": githubv4.String(params.Owner), - "repo": githubv4.String(params.Repo), - "discussionNumber": githubv4.Int(params.DiscussionNumber), + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "discussionNumber": githubv4.Int(discussionNumber), } if err := client.Query(ctx, &q, vars); err != nil { return utils.NewToolResultError(err.Error()), nil, nil @@ -384,7 +387,7 @@ func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool { result := utils.NewToolResultText(string(out)) // Discussion content is user-authored (untrusted); confidentiality // follows repo visibility. - result = attachRepoVisibilityIFCLabelLazy(ctx, deps, params.Owner, params.Repo, result, ifc.LabelRepoUserContent) + result = attachRepoVisibilityIFCLabelLazy(ctx, deps, owner, repo, result, ifc.LabelRepoUserContent) return result, nil, nil }, ) @@ -425,13 +428,16 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve }, []scopes.Scope{scopes.Repo}, func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { - // Decode params - var params struct { - Owner string - Repo string - DiscussionNumber int32 + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil } - if err := mapstructure.WeakDecode(args, ¶ms); err != nil { + discussionNumber, err := RequiredInt(args, "discussionNumber") + if err != nil { return utils.NewToolResultError(err.Error()), nil, nil } @@ -467,9 +473,9 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve } vars := map[string]any{ - "owner": githubv4.String(params.Owner), - "repo": githubv4.String(params.Repo), - "discussionNumber": githubv4.Int(params.DiscussionNumber), + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "discussionNumber": githubv4.Int(discussionNumber), "first": githubv4.Int(*paginationParams.First), } if paginationParams.After != nil { @@ -592,7 +598,7 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve result := utils.NewToolResultText(string(out)) // Discussion comments are user-authored (untrusted); confidentiality // follows repo visibility. - result = attachRepoVisibilityIFCLabelLazy(ctx, deps, params.Owner, params.Repo, result, ifc.LabelRepoUserContent) + result = attachRepoVisibilityIFCLabelLazy(ctx, deps, owner, repo, result, ifc.LabelRepoUserContent) return result, nil, nil }, ) diff --git a/pkg/github/discussions_test.go b/pkg/github/discussions_test.go index 36fdb6c43..3a090500f 100644 --- a/pkg/github/discussions_test.go +++ b/pkg/github/discussions_test.go @@ -590,8 +590,47 @@ func Test_GetDiscussion(t *testing.T) { } } +func Test_GetDiscussionRequiredParams(t *testing.T) { + t.Parallel() + + toolDef := GetDiscussion(translations.NullTranslationHelper) + handler := toolDef.Handler(BaseDeps{GQLClient: githubv4.NewClient(githubv4mock.NewMockedHTTPClient())}) + + tests := []struct { + name string + requestArgs map[string]any + expectedErrMsg string + }{ + { + name: "missing owner", + requestArgs: map[string]any{"repo": "repo", "discussionNumber": float64(1)}, + expectedErrMsg: "missing required parameter: owner", + }, + { + name: "missing repo", + requestArgs: map[string]any{"owner": "owner", "discussionNumber": float64(1)}, + expectedErrMsg: "missing required parameter: repo", + }, + { + name: "missing discussionNumber", + requestArgs: map[string]any{"owner": "owner", "repo": "repo"}, + expectedErrMsg: "missing required parameter: discussionNumber", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := createMCPRequest(tc.requestArgs) + res, err := handler(ContextWithDeps(context.Background(), BaseDeps{}), &req) + require.NoError(t, err) + require.True(t, res.IsError) + assert.Contains(t, getTextResult(t, res).Text, tc.expectedErrMsg) + }) + } +} + func Test_GetDiscussionWithStringNumber(t *testing.T) { - // Test that WeakDecode handles string discussionNumber from MCP clients + // Test that RequiredInt handles string discussionNumber from MCP clients toolDef := GetDiscussion(translations.NullTranslationHelper) qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,closed,isAnswered,answerChosenAt,url,category{name}}}}" @@ -723,8 +762,47 @@ func Test_GetDiscussionComments(t *testing.T) { assert.Equal(t, "This is the second comment", response.Comments[1].Body) } +func Test_GetDiscussionCommentsRequiredParams(t *testing.T) { + t.Parallel() + + toolDef := GetDiscussionComments(translations.NullTranslationHelper) + handler := toolDef.Handler(BaseDeps{GQLClient: githubv4.NewClient(githubv4mock.NewMockedHTTPClient())}) + + tests := []struct { + name string + requestArgs map[string]any + expectedErrMsg string + }{ + { + name: "missing owner", + requestArgs: map[string]any{"repo": "repo", "discussionNumber": float64(1)}, + expectedErrMsg: "missing required parameter: owner", + }, + { + name: "missing repo", + requestArgs: map[string]any{"owner": "owner", "discussionNumber": float64(1)}, + expectedErrMsg: "missing required parameter: repo", + }, + { + name: "missing discussionNumber", + requestArgs: map[string]any{"owner": "owner", "repo": "repo"}, + expectedErrMsg: "missing required parameter: discussionNumber", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := createMCPRequest(tc.requestArgs) + res, err := handler(ContextWithDeps(context.Background(), BaseDeps{}), &req) + require.NoError(t, err) + require.True(t, res.IsError) + assert.Contains(t, getTextResult(t, res).Text, tc.expectedErrMsg) + }) + } +} + func Test_GetDiscussionCommentsWithStringNumber(t *testing.T) { - // Test that WeakDecode handles string discussionNumber from MCP clients + // Test that RequiredInt handles string discussionNumber from MCP clients toolDef := GetDiscussionComments(translations.NullTranslationHelper) qGetComments := "query($after:String$discussionNumber:Int!$first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){comments(first: $first, after: $after){nodes{id,body,isAnswer},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}}"